Source code for devinterp.utils

import warnings
from typing import Union

import numpy as np
from torch.utils.data import DataLoader


def default_nbeta(
    dataloader: Union[DataLoader, int], gradient_accumulation_steps: int = 1
) -> float:
    """Default value of n*beta, computed as n / log(n), where n is the effective
    batch size (batch_size * gradient_accumulation_steps). Accepts either a
    DataLoader or the batch size as an int."""
    if isinstance(dataloader, DataLoader):
        batch_size = dataloader.batch_size
    elif isinstance(dataloader, int):
        batch_size = dataloader
    else:
        raise NotImplementedError(
            f"N*beta for data type {type(dataloader)} not implemented, use DataLoader or int instead."
        )

    effective_batch_size = batch_size * gradient_accumulation_steps
    if effective_batch_size <= 1:
        warnings.warn(
            "default nbeta is undefined for batch_size * gradient_accumulation_steps <= 1, falling back to default value of 1"
        )
        return 1
    return effective_batch_size / np.log(effective_batch_size)

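# Usage sketch (illustrative, not part of devinterp.utils): how default_nbeta is
# typically called. The toy TensorDataset and batch_size=512 below are arbitrary
# choices for the example; any DataLoader with a fixed batch_size works the same way.
import torch
from torch.utils.data import DataLoader, TensorDataset

example_loader = DataLoader(TensorDataset(torch.randn(10_000, 8)), batch_size=512)
nbeta = default_nbeta(example_loader)  # 512 / log(512) ≈ 82.1
nbeta_from_int = default_nbeta(512)    # equivalent: pass the effective batch size directly
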

def tokenize_and_concatenate(
    dataset,
    tokenizer,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
):
    """Tokenize and concatenate a text dataset into fixed-length sequences.

    Based on TransformerLens (MIT License, Copyright 2022 TransformerLensOrg):
    https://github.com/TransformerLensOrg/TransformerLens
    Core algorithm unchanged, with local additions: input validation (bos token
    and max_length checks), numpy reshape in place of einops, and the output
    column renamed from "tokens" to "input_ids".

    Joins all text separated by EOS tokens, tokenizes, then reshapes into
    (num_sequences, max_length) chunks.

    Args:
        dataset: HuggingFace text dataset.
        tokenizer: HuggingFace tokenizer with bos_token_id and eos_token_id.
        streaming: If True, disables parallel tokenization.
        max_length: Context window length.
        column_name: Name of the text column.
        add_bos_token: Prepend BOS to each sequence (reduces usable length by 1).
        num_proc: Number of processes for dataset.map().

    Returns:
        Dataset with an "input_ids" column of torch tensors.
    """
    dataset = dataset.select_columns(column_name)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if add_bos_token and tokenizer.bos_token_id is None:
        raise ValueError(
            f"add_bos_token=True but tokenizer has no bos_token_id. "
            f"Use add_bos_token=False for {type(tokenizer).__name__}."
        )

    seq_len = max_length - 1 if add_bos_token else max_length
    if seq_len < 1:
        raise ValueError(
            f"max_length={max_length} is too small"
            f"{' with add_bos_token=True (needs at least 2)' if add_bos_token else ' (needs at least 1)'}."
        )

    def tokenize_function(examples):
        text = examples[column_name]
        # Concatenate all documents into one string, separated by EOS tokens.
        full_text = tokenizer.eos_token.join(text)
        if not full_text.strip():
            return {"input_ids": np.array([], dtype=np.int64)}

        # Tokenize in chunks to keep individual tokenizer calls manageable.
        num_chunks = 20
        chunk_length = (len(full_text) - 1) // num_chunks + 1
        chunks = [
            full_text[i * chunk_length : (i + 1) * chunk_length]
            for i in range(num_chunks)
        ]
        tokens = tokenizer(chunks, return_tensors="np", padding=True)[
            "input_ids"
        ].flatten()
        tokens = tokens[tokens != tokenizer.pad_token_id]

        num_tokens = len(tokens)
        if num_tokens < seq_len:
            # Too few tokens for a full sequence: pad a single sequence.
            num_batches = 1
            tokens = tokens[:seq_len]
            if len(tokens) < seq_len:
                padding = np.full(seq_len - len(tokens), tokenizer.pad_token_id)
                tokens = np.concatenate([tokens, padding], axis=0)
        else:
            # Drop the remainder so the tokens reshape cleanly.
            num_batches = num_tokens // seq_len
            tokens = tokens[: seq_len * num_batches]
        tokens = tokens.reshape(num_batches, seq_len)
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"input_ids": tokens}

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=(num_proc if not streaming else None),
        remove_columns=[column_name],
    )
    if len(tokenized_dataset) == 0 or "input_ids" not in tokenized_dataset.column_names:
        raise ValueError(
            "Tokenization produced no sequences. Check that the input text is not empty."
        )
    tokenized_dataset.set_format(type="torch", columns=["input_ids"])
    return tokenized_dataset
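

# Usage sketch (illustrative, not part of devinterp.utils): tokenizing a HuggingFace
# text dataset into fixed-length sequences and batching it. The dataset name and the
# GPT-2 tokenizer are arbitrary example choices; any dataset with a "text" column and
# any tokenizer exposing bos_token_id and eos_token_id should work the same way.
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # GPT-2 reuses EOS as BOS
tokenized = tokenize_and_concatenate(raw_dataset, gpt2_tokenizer, max_length=256)

example_loader = DataLoader(tokenized, batch_size=32, shuffle=True)
batch = next(iter(example_loader))
print(batch["input_ids"].shape)  # torch.Size([32, 256])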