Source code for devinterp.slt.observables

"""Observable: evaluates probe datasets during SGLD sampling.

Each observable wraps a dataset and computes per-token losses at each draw.
Input IDs are fixed (same sequences every draw via DeterministicShuffledSampler).
"""

from __future__ import annotations

import random
from collections.abc import Iterator
import torch
from datasets import Dataset
from torch.utils.data import DataLoader, Sampler

from devinterp.slt.lm_loss import LossFn, compute_per_token_loss


def to_obs_id(task_name: str) -> str:
    """Convert a task name to a valid identifier for use in variable names."""
    result = "".join(c if c.isalnum() or c == "_" else "_" for c in task_name)
    if result and result[0].isdigit():
        result = "_" + result
    return result or "_"
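

# Illustrative usage (a sketch, not part of the module): to_obs_id sanitises
# arbitrary task names into identifier-safe strings, e.g.
#   to_obs_id("pile-github")  -> "pile_github"
#   to_obs_id("2channel/web") -> "_2channel_web"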


class DeterministicShuffledSampler(Sampler):
    """A sampler that returns a fixed shuffled order of indices, deterministic from seed."""

    def __init__(self, data_source: Dataset, num_samples: int, seed: int = 42):
        self.data_source = data_source
        self.num_samples = num_samples
        assert len(data_source) >= num_samples, (
            f"Dataset size ({len(data_source)}) must be >= num_samples ({num_samples})"
        )
        self.indices = random.Random(seed).sample(range(len(data_source)), num_samples)

    def __iter__(self) -> Iterator[int]:
        return iter(self.indices)

    def __len__(self) -> int:
        return self.num_samples
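

# Determinism sketch (illustrative; `dataset` is a placeholder for any object
# supporting __len__ and __getitem__): two samplers built with the same seed
# produce identical index orders, so every SGLD draw sees the same sequences.
#   a = DeterministicShuffledSampler(dataset, num_samples=8, seed=42)
#   b = DeterministicShuffledSampler(dataset, num_samples=8, seed=42)
#   assert list(a) == list(b)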


class Observable:
    """Evaluates a probe dataset at each SGLD draw.

    On construction, loads fixed input_ids (same sequences every draw).
    At each draw, compute_loss(model) returns per-token losses.

    Attributes:
        obs_id: Identifier derived from task_name (e.g. "pile_github").
        input_ids: Fixed input_ids tensor, shape (n_samples, ctx_len+1).
        n_samples: batch_size * batches_per_draw.
        context_length: Number of predicted positions (ctx_len).
    """

    def __init__(
        self,
        *,
        dataset: Dataset,
        task_name: str,
        batches_per_draw: int,
        batch_size: int,
        context_length: int,
        device: torch.device,
        seed: int = 1337,
        loss_fn: LossFn | None = None,
    ):
        self.obs_id = to_obs_id(task_name)
        self.batch_size = batch_size
        self.batches_per_draw = batches_per_draw
        self.n_samples = batch_size * batches_per_draw
        self.context_length = context_length
        self.device = device
        self.loss_fn = loss_fn if loss_fn is not None else compute_per_token_loss

        sampler = DeterministicShuffledSampler(dataset, self.n_samples, seed=seed)
        loader = DataLoader(
            dataset, batch_size=batch_size, sampler=sampler, drop_last=True
        )

        # Load fixed input_ids
        self.input_ids = torch.zeros(
            self.n_samples, context_length + 1, dtype=torch.int64, device=device
        )
        for i, batch in enumerate(loader):
            s, e = i * batch_size, (i + 1) * batch_size
            self.input_ids[s:e] = batch["input_ids"]
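
    # Data contract (inferred from the assignments in __init__ and
    # compute_loss, not a documented guarantee): each dataset example is
    # expected to yield "input_ids" of length context_length + 1, and
    # loss_fn(model, ids) to return losses of shape (batch_size, context_length).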

    def compute_loss(self, model: torch.nn.Module) -> torch.Tensor:
        """Compute per-token loss for the model on this observable's data.

        Returns shape (n_samples, context_length).
        """
        assert not torch.is_grad_enabled()
        loss = torch.zeros(
            self.n_samples, self.context_length, dtype=torch.float32, device=self.device
        )
        for i in range(self.batches_per_draw):
            s, e = i * self.batch_size, (i + 1) * self.batch_size
            loss[s:e] = self.loss_fn(model, self.input_ids[s:e])
        return loss
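

# Illustrative use at a single SGLD draw (a sketch; `probe_dataset` and
# `model` are placeholders, not defined in this module):
#   obs = Observable(
#       dataset=probe_dataset,
#       task_name="pile-github",
#       batches_per_draw=4,
#       batch_size=8,
#       context_length=512,
#       device=torch.device("cuda"),
#   )
#   with torch.no_grad():
#       per_token = obs.compute_loss(model)  # shape (32, 512)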