# Source code for devinterp.slt.sampling

"""SGLD sampling with observables, writing results to zarr.

Provides sample() as the main entry point. Internally uses
sample_single_chain from sampler.py for the SGLD inner loop.
"""

from __future__ import annotations

import logging
import tempfile
import time
import warnings
from pathlib import Path
from typing import Any, cast

import torch
import xarray as xr
import zarr
from datasets import Dataset
from torch.utils.data import DataLoader, Dataset as TorchDataset
from zarr.storage import LocalStore

from devinterp.optim.metrics import Metrics
from devinterp.slt.config import (
    SAMPLING_METHODS,
    SamplerConfig,
    SamplingMethodLiteral,
)
from devinterp.slt.lm_loss import LossFn, compute_per_token_loss, make_evaluate_fn
from devinterp.slt.observables import Observable
from devinterp.slt.sampler import (
    ParamMasks,
    _make_feed,
    is_transformer_lens_model,
    set_seed,
    calculate_num_steps,
    sample_single_chain,
)
from devinterp.slt.writing import ZarrWriter
from devinterp.slt.zarr_schema import DataArraySpec, ZarrSchema

SAMPLES_LOSS_DTYPE_STR = "float32"
ZARR_MAX_WRITE_THREADS = 4

logger = logging.getLogger(__name__)

# Type for observable specification: dataset alone (uses default batches_per_draw)
# or (dataset, batches_per_draw) tuple for explicit control.
ObservableSpec = Dataset | tuple[Dataset, int]


def sample(
    model: torch.nn.Module,
    dataset: Dataset,
    observables: dict[str, ObservableSpec],
    *,
    lr: float,
    n_beta: float,
    param_masks: ParamMasks | None = None,
    num_chains: int = 4,
    num_draws: int = 200,
    batch_size: int = 32,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    num_init_loss_batches: int = 32,
    init_seed: int = 100,
    batches_per_draw: int = 3,
    obs_seed: int = 1337,
    gradient_accumulation_steps: int = 1,
    localization: float = 0.0,
    noise_level: float = 1.0,
    llc_weight_decay: float = 0.0,
    bounding_box_size: float | None = None,
    sampling_method: SamplingMethodLiteral = "sgmcmc_sgld",
    sampling_method_kwargs: dict[str, Any] | None = None,
    rmsprop_eps: float | None = None,
    rmsprop_alpha: float | None = None,
    shuffle: bool = True,
    match_sampling_input_ids_across_chains: bool = True,
    init_noise: float | None = None,
    device: str | None = None,
    save_metrics: bool = False,
    output_path: str | Path | None = None,
    loss_fn: LossFn | None = None,
) -> xr.DataTree:
    """Run SGLD sampling with observables.

    Args:
        model: PyTorch model.
        dataset: HuggingFace Dataset with "input_ids" column, used for SGLD
            sampling.
        observables: Dict mapping observable names to datasets (or
            (dataset, batches_per_draw) tuples). Each dataset must have an
            "input_ids" column.
        lr: SGLD learning rate.
        n_beta: SGLD inverse temperature.
        param_masks: Which parameters to optimize. None means all parameters
            (full model). Otherwise a dict mapping param names to mask
            tensors (or None for unrestricted).
        num_chains: Number of SGLD chains.
        num_draws: Number of draws per chain.
        batch_size: Batch size for sampling and observables.
        num_burnin_steps: SGLD burn-in steps before drawing.
        num_steps_bw_draws: Steps between draws.
        num_init_loss_batches: Batches for init_loss computation.
        init_seed: Random seed.
        batches_per_draw: Default batches_per_draw for observables (used when
            an observable is specified as just a dataset, not a tuple).
        obs_seed: Seed for deterministic observable sampling.
        gradient_accumulation_steps: Number of micro-batches per optimizer
            step. Effective batch size is
            batch_size * gradient_accumulation_steps.
        localization: Strength of the pull toward initial parameters (gamma
            in Lau et al. 2023). 0 disables localization.
        noise_level: Standard deviation of SGLD noise. Defaults to 1.0;
            changing this breaks the SGLD posterior-sampling guarantee.
        llc_weight_decay: L2 regularization strength (lambda). Applied as a
            Gaussian prior centered at zero.
        bounding_box_size: If set, restricts sampling to a box of this radius
            around the initial parameters. None disables.
        sampling_method: Which SGLD variant to use. "sgmcmc_sgld" is the
            default; "rmsprop_sgld" adds RMSprop-style preconditioning.
        sampling_method_kwargs: Extra kwargs forwarded to the sampling-method
            constructor (e.g. rmsprop's "alpha" / "eps", or
            "add_grad_correction"). Use `rmsprop_eps` / `rmsprop_alpha` as
            convenience aliases for the two most common rmsprop knobs.
        rmsprop_eps: RMSprop stability constant. Only valid when
            sampling_method='rmsprop_sgld'. Shorthand for
            sampling_method_kwargs={"eps": ...}.
        rmsprop_alpha: RMSprop moving-average coefficient. Only valid when
            sampling_method='rmsprop_sgld'. Shorthand for
            sampling_method_kwargs={"alpha": ...}.
        shuffle: Whether to shuffle the sampling dataset. Default True.
        match_sampling_input_ids_across_chains: If True, every chain sees the
            same input_ids in the same order (only the SGLD noise differs
            across chains). If False, each chain gets an
            independently-seeded shuffle.
        init_noise: If set, add Gaussian noise with this std to parameters
            before sampling.
        device: Compute device. None for auto-detect.
        save_metrics: If True, save per-step SGLD diagnostics (gradient
            norms, noise norms, distance from init, etc.) for tuning
            sampling parameters.
        output_path: Path for output zarr. None for a temp directory.
        loss_fn: Optional custom per-token loss
            `(model, input_ids) -> (batch, seq-1)`. Defaults to
            cross-entropy on the model's logits.

    Returns:
        Lazy-loaded DataTree of sampling results.
    """
    start = time.perf_counter()
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # None means "all parameters, unmasked": expand to an explicit dict so
    # downstream code only ever sees the dict form.
    if param_masks is None:
        param_masks = {name: None for name, _ in model.named_parameters()}
    # Next-token loss drops the final position, hence len - 1.
    context_length = len(dataset[0]["input_ids"]) - 1
    if context_length < 1:
        raise ValueError(
            f"Sequences must have length >= 2 for next-token loss "
            f"(got {context_length + 1})."
        )
    ds = cast(TorchDataset, dataset)
    # Copy so the rmsprop_* aliases below don't mutate the caller's dict.
    sampling_method_kwargs = dict(sampling_method_kwargs or {})
    for name, value in (("eps", rmsprop_eps), ("alpha", rmsprop_alpha)):
        if value is None:
            continue
        if sampling_method != "rmsprop_sgld":
            raise ValueError(
                f"rmsprop_{name} can only be set when sampling_method='rmsprop_sgld', "
                f"got sampling_method={sampling_method!r}"
            )
        if name in sampling_method_kwargs:
            raise ValueError(
                f"rmsprop_{name} is set both as a top-level argument and in "
                f"sampling_method_kwargs[{name!r}]; specify only one."
            )
        sampling_method_kwargs[name] = value

    # Build observable objects: a bare dataset uses the default
    # batches_per_draw, a (dataset, int) tuple overrides it per-observable.
    obs_list: list[Observable] = []
    for name, spec in observables.items():
        if isinstance(spec, tuple):
            obs_ds, bpd = spec
        else:
            obs_ds, bpd = spec, batches_per_draw
        obs_list.append(
            Observable(
                dataset=obs_ds,
                task_name=name,
                batches_per_draw=bpd,
                batch_size=batch_size,
                context_length=context_length,
                device=torch.device(device),
                seed=obs_seed,
                loss_fn=loss_fn,
            )
        )

    # Build config (also serialized into the zarr attrs for cache checks).
    config = SamplerConfig(
        lr=lr,
        n_beta=n_beta,
        num_chains=num_chains,
        num_draws=num_draws,
        batch_size=batch_size,
        num_burnin_steps=num_burnin_steps,
        num_steps_bw_draws=num_steps_bw_draws,
        num_init_loss_batches=num_init_loss_batches,
        init_seed=init_seed,
        gradient_accumulation_steps=gradient_accumulation_steps,
        localization=localization,
        noise_level=noise_level,
        llc_weight_decay=llc_weight_decay,
        bounding_box_size=bounding_box_size,
        sampling_method=sampling_method,
        sampling_method_kwargs=sampling_method_kwargs,
        shuffle=shuffle,
        match_sampling_input_ids_across_chains=match_sampling_input_ids_across_chains,
        init_noise=init_noise,
        save_metrics=save_metrics,
    )

    # Cache: if output_path exists, validate and return early on match.
    # _check_cache raises (with a "delete and retry" message) on mismatch.
    if output_path is not None and Path(output_path).exists():
        cached = _check_cache(output_path, config)
        logger.info("sample() using cached output at %s", output_path)
        return cached

    # Warn about non-persisted big runs: results land in a tempdir and are
    # effectively lost when the process exits.
    total_work = (
        num_chains * (num_draws * num_steps_bw_draws + num_burnin_steps) * batch_size
    )
    if output_path is None and total_work > 1000:
        warnings.warn(
            f"Sampling without output_path set -- {total_work} effective "
            "samples will be written to a temp directory and lost when "
            "the process exits. Pass output_path='/path/to/samples.zarr' "
            "to save them.",
            stacklevel=2,
        )

    # Compute numel for metrics if needed: masked params contribute only
    # their unmasked entries, unmasked params contribute everything.
    numel: int | None = None
    if save_metrics:
        name_to_param = dict(model.named_parameters())
        numel = sum(
            int(mask.count_nonzero())
            if mask is not None
            else name_to_param[name].numel()
            for name, mask in param_masks.items()
        )

    # Build zarr schema and store.
    chain_buffer_size = min(50, num_draws)
    schema = _build_sampling_schema(
        config=config,
        context_length=context_length,
        chain_buffer_size=chain_buffer_size,
        observables=obs_list,
        numel=numel,
    )
    if output_path is None:
        output_path = Path(tempfile.mkdtemp()) / "samples.zarr"
    store = LocalStore(output_path)
    _, arrays = schema.create_hierarchy(store)

    # Resolve sampling method. Note this rebinds the local
    # sampling_method_kwargs name: from here on it is the full kwargs dict
    # passed to the optimizer class, merged with the user's extras (which
    # were already copied into config.sampling_method_kwargs above).
    sampling_method_cls = SAMPLING_METHODS.get(config.sampling_method)
    if sampling_method_cls is None:
        raise ValueError(f"Unknown sampling method {config.sampling_method}")
    sampling_method_kwargs = dict(
        nbeta=n_beta,
        lr=lr,
        localization=config.localization,
        noise_level=config.noise_level,
        weight_decay=config.llc_weight_decay,
        bounding_box_size=config.bounding_box_size,
        save_metrics=save_metrics,
        **config.sampling_method_kwargs,
    )

    # Callbacks for zarr writing. All three closures capture `writer`, which
    # is bound later by the `with ZarrWriter.open(...)` block below — they
    # must only be invoked inside that context.
    def on_draw(*, loss, draw, chain, model, **_):
        # Per-draw: record the sampling loss plus each observable's loss.
        writer.push("sampling_loss", loss, chain=chain, draw=draw)
        for obs in obs_list:
            # Observable evaluation must not build a graph.
            assert not torch.is_grad_enabled()
            obs_loss = obs.compute_loss(model)
            writer.push(f"loss_{obs.obs_id}", obs_loss, chain=chain, draw=draw)
        writer.flush_full_buffers()

    # Buffer grad_accum micro-batches per (chain, step) and write once full,
    # since the zarr writer expects a full row per push.
    micro_loss_buf: torch.Tensor | None = None
    micro_ids_buf: torch.Tensor | None = None
    micro_total = gradient_accumulation_steps * batch_size

    def on_micro(
        loss: torch.Tensor,
        input_ids: torch.Tensor,
        chain: int,
        step: int,
        micro_step: int,
    ) -> None:
        nonlocal micro_loss_buf, micro_ids_buf
        # Lazily (re)allocate the CPU staging buffers at the start of each
        # accumulation window.
        if micro_step == 0:
            micro_loss_buf = torch.empty(micro_total, context_length, dtype=loss.dtype)
            micro_ids_buf = torch.empty(
                micro_total, context_length + 1, dtype=input_ids.dtype
            )
        s = slice(micro_step * batch_size, (micro_step + 1) * batch_size)
        assert micro_loss_buf is not None and micro_ids_buf is not None
        micro_loss_buf[s] = loss.cpu()
        # NOTE(review): unlike `loss`, `input_ids` is assigned without an
        # explicit .cpu(); this relies on slice-assignment's cross-device
        # copy when input_ids lives on GPU — confirm intended.
        micro_ids_buf[s] = input_ids
        # Flush the whole row once the last micro-batch of the step lands.
        if micro_step == gradient_accumulation_steps - 1:
            writer.push("sampling_loss_micro", micro_loss_buf, chain=chain, step=step)
            writer.push(
                "sampling_input_ids_micro", micro_ids_buf, chain=chain, step=step
            )
            writer.flush_full_buffers()

    def on_step(chain: int, step: int, optimizer: torch.optim.Optimizer) -> None:
        # Per-step SGLD diagnostics (only wired up when save_metrics=True).
        metrics = optimizer.get_metrics()
        for field_name in Metrics.NORM_FIELDS + Metrics.DOT_FIELDS:
            value = getattr(metrics, field_name).squeeze()
            writer.push(f"metrics_{field_name}", value, chain=chain, step=step)
        writer.flush_full_buffers()

    step_callback = on_step if save_metrics else None

    # Seed and chain setup: a shared dataloader seed makes every chain see
    # identical batches; None lets each chain derive its own seed below.
    dataloader_seed = (
        init_seed if config.match_sampling_input_ids_across_chains else None
    )

    with ZarrWriter.open(
        arrays, chain_buffer_size, torch.device(device), ZARR_MAX_WRITE_THREADS
    ) as writer:
        # Write fixed observable input_ids (constant across draws/chains).
        for obs in obs_list:
            writer.write(f"input_ids_{obs.obs_id}", obs.input_ids)

        # Compute and write init loss (pre-sampling baseline per chain).
        _write_init_loss(writer, config, model, param_masks, ds, device, loss_fn)

        # Run SGLD chains sequentially; each chain gets seed init_seed + idx.
        for chain_idx in range(num_chains):
            sample_single_chain(
                ref_model=model,
                dataset=ds,
                evaluate=make_evaluate_fn(loss_fn),
                param_masks=param_masks,
                num_draws=num_draws,
                num_burnin_steps=num_burnin_steps,
                num_steps_bw_draws=num_steps_bw_draws,
                gradient_accumulation_steps=config.gradient_accumulation_steps,
                sampling_method=sampling_method_cls,
                sampling_method_kwargs=sampling_method_kwargs,
                chain=chain_idx,
                seed=init_seed + chain_idx,
                dataloader_seed=dataloader_seed
                if dataloader_seed is not None
                else init_seed + chain_idx,
                device=device,
                callbacks=[on_draw],
                step_callback=step_callback,
                micro_callback=on_micro,
                batch_size=batch_size,
                init_noise=init_noise,
                shuffle=config.shuffle,
                epoch_mode=config.epoch_mode,
            )

    # Mark as completed so future cache checks can trust the output.
    # (Written after the writer context exits, so all buffers are flushed.)
    zarr.open_group(str(output_path)).attrs["completed"] = 1
    logger.info("sample() total time: %.2f seconds", time.perf_counter() - start)
    return xr.open_datatree(str(output_path), engine="zarr", consolidated=False)
def _check_cache(output_path: str | Path, config: SamplerConfig) -> xr.DataTree:
    """Validate an existing sample output against the current config.

    Raises RuntimeError with a clear "delete and retry" message if the file
    is unreadable, incomplete, or was produced with different sampler args.
    Otherwise returns the loaded DataTree.
    """
    path_str = str(output_path)
    try:
        existing = xr.open_datatree(path_str, engine="zarr", consolidated=False)
    except Exception as e:
        # Broad catch is deliberate: any open failure means "unusable cache".
        raise RuntimeError(
            f"Output path '{output_path}' exists but couldn't be opened as zarr:\n"
            f" {e!r}\n"
            f"Delete and retry: rm -rf '{output_path}'"
        ) from e
    # The "completed" flag is written only after all chains finished and the
    # writer was closed, so its absence means an interrupted run.
    if existing.attrs.get("completed") != 1:
        raise RuntimeError(
            f"Output path '{output_path}' exists but sampling was incomplete "
            f"(no 'completed' flag — likely interrupted).\n"
            f"Delete and retry: rm -rf '{output_path}'"
        )
    # NOTE(review): attribute-style access to the stored "metadata" attrs —
    # confirm the xarray version in use exposes group attrs this way.
    stored_sampler = existing.metadata["config"]["sampler"]
    # Must mirror the dict written by _build_sampling_schema, including the
    # legacy default fields appended there.
    expected_sampler = {
        **config.model_dump(),
        "scheduler_type": "constant",
        "scheduler_kwargs": None,
        "cores": 1,
        "online": False,
    }
    if stored_sampler != expected_sampler:
        # Report only the differing keys to keep the error readable.
        diffs = []
        all_keys = set(stored_sampler) | set(expected_sampler)
        for key in sorted(all_keys):
            s = stored_sampler.get(key)
            e = expected_sampler.get(key)
            if s != e:
                diffs.append(f" {key}: stored={s!r}, current={e!r}")
        raise RuntimeError(
            f"Output path '{output_path}' has a different sampler config:\n"
            + "\n".join(diffs)
            + f"\nDelete and retry: rm -rf '{output_path}'"
        )
    return existing


# ─── Internal helpers ────────────────────────────────────────────────────────


def _build_sampling_schema(
    *,
    config: SamplerConfig,
    context_length: int,
    chain_buffer_size: int,
    observables: list[Observable],
    numel: int | None = None,
) -> ZarrSchema:
    """Build a ZarrSchema for the sampling pipeline.

    Declares every array the writer will push to (losses, micro-batch
    losses/input_ids, per-observable losses, optional metrics) plus the
    group attrs that _check_cache later validates against.
    """
    num_chains = config.num_chains
    num_draws = config.num_draws
    # One chain per chunk; draws chunked to match the writer's buffer size.
    chain_chunk_size = 1
    draw_chunk_size = chain_buffer_size
    chain_draw_chunks = (chain_chunk_size, draw_chunk_size)
    num_steps = calculate_num_steps(
        num_draws=num_draws,
        num_steps_bw_draws=config.num_steps_bw_draws,
        num_burnin_steps=config.num_burnin_steps,
    )
    arrays_meta: dict[str, DataArraySpec] = {}
    arrays_meta["init_loss"] = DataArraySpec(
        dtype_str=SAMPLES_LOSS_DTYPE_STR,
        dims=("chain", "token_pos"),
        shape=(num_chains, context_length),
        chunks=(1, context_length),
    )
    arrays_meta["sampling_loss"] = DataArraySpec(
        dtype_str=SAMPLES_LOSS_DTYPE_STR,
        dims=("chain", "draw", "batch", "token_pos"),
        shape=(num_chains, num_draws, config.batch_size, context_length),
        chunks=chain_draw_chunks + (config.batch_size, context_length),
    )
    # Micro-batch arrays span the effective (accumulated) batch per step.
    micro_batch = config.gradient_accumulation_steps * config.batch_size
    arrays_meta["sampling_loss_micro"] = DataArraySpec(
        dtype_str=SAMPLES_LOSS_DTYPE_STR,
        dims=("chain", "step", "batch_sampling", "token_pos"),
        shape=(num_chains, num_steps, micro_batch, context_length),
        chunks=chain_draw_chunks + (micro_batch, context_length),
    )
    arrays_meta["sampling_input_ids_micro"] = DataArraySpec(
        dtype_str="int64",
        dims=("chain", "step", "batch_sampling", "token"),
        shape=(num_chains, num_steps, micro_batch, context_length + 1),
        chunks=chain_draw_chunks + (micro_batch, context_length + 1),
    )
    # One loss array + one fixed input_ids array per observable; each gets
    # its own batch dimension since n_samples can differ per observable.
    for obs in observables:
        batch_dim = f"batch_{obs.obs_id}"
        arrays_meta[f"loss_{obs.obs_id}"] = DataArraySpec(
            dtype_str=SAMPLES_LOSS_DTYPE_STR,
            dims=("chain", "draw", batch_dim, "token_pos"),
            shape=(num_chains, num_draws, obs.n_samples, context_length),
            chunks=chain_draw_chunks + (obs.n_samples, context_length),
        )
        arrays_meta[f"input_ids_{obs.obs_id}"] = DataArraySpec(
            dtype_str="int64",
            dims=(batch_dim, "token"),
            shape=(obs.n_samples, context_length + 1),
        )
    group_attrs: dict[str, dict[str, Any]] = {
        "": {
            "metadata": {
                "config": {
                    "sampler": {
                        **config.model_dump(),
                        # Fields we removed but can provide defaults for
                        "scheduler_type": "constant",
                        "scheduler_kwargs": None,
                        "cores": 1,
                        "online": False,
                    },
                    "observables": [
                        {
                            "type": "ObserveDifferentDataset",
                            "obs_id": obs.obs_id,
                            "batches_per_draw": obs.batches_per_draw,
                        }
                        for obs in observables
                    ],
                    "chain_buffer_size": chain_buffer_size,
                    "shared_observable_dims": ["token", "token_pos"],
                    "tokenizer_context_length": context_length + 1,
                },
            },
        }
    }
    if config.save_metrics:
        step_chunk_size = chain_buffer_size
        chain_step_chunks = (chain_chunk_size, step_chunk_size)
        for field_name in Metrics.NORM_FIELDS + Metrics.DOT_FIELDS:
            arrays_meta[f"metrics_{field_name}"] = DataArraySpec(
                dtype_str="float32",
                dims=("chain", "step"),
                shape=(num_chains, num_steps),
                chunks=chain_step_chunks,
            )
        assert numel is not None, "numel must be provided when save_metrics=True"
        group_attrs[""]["metrics_numel"] = numel
    return ZarrSchema(
        arrays_meta=arrays_meta,
        group_attrs=group_attrs,
    )


def _write_init_loss(
    writer: ZarrWriter,
    config: SamplerConfig,
    model: torch.nn.Module,
    param_masks: ParamMasks,
    dataset: TorchDataset,
    device: str,
    loss_fn: LossFn | None,
) -> None:
    """Compute and write the initial (pre-sampling) loss for each chain.

    Mutates the model in place (device move, train mode, optional init
    noise) and restores everything before returning.
    """
    num_batches = config.gradient_accumulation_steps * config.num_init_loss_batches
    fn = loss_fn if loss_fn is not None else compute_per_token_loss
    # Remember original device so we can restore after init_loss computation.
    # nn.Module.to() is in-place — without this, the caller's model would be
    # silently moved to the compute device.
    orig_device = next(model.parameters()).device
    was_training = model.training
    if is_transformer_lens_model(model):
        model_on_device = model.to(device, print_details=False)  # pyright: ignore[reportCallIssue]
    else:
        model_on_device = model.to(device)
    model_on_device.train()
    torch.manual_seed(config.init_seed)
    name_to_param = dict(model_on_device.named_parameters())
    orig_data: dict[str, torch.Tensor] = {}
    if config.init_noise:
        # init_loss is computed at the unrestricted-noise state (noise on all
        # params, ignoring weight_restrictions) to match aether's behavior.
        # The sampling chain itself still applies masked init_noise.
        orig_data = {name: p.data.clone() for name, p in name_to_param.items()}
    for chain_idx in range(config.num_chains):
        set_seed(config.init_seed + chain_idx, device=device)
        # Mirror the sampling chains' dataloader seeding: one shared seed
        # when matching input_ids across chains, else per-chain seeds.
        dataloader_rng = torch.Generator(device="cpu")
        dataloader_rng.manual_seed(
            config.init_seed
            if config.match_sampling_input_ids_across_chains
            else config.init_seed + chain_idx
        )
        loader = DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=config.shuffle,
            generator=dataloader_rng,
            drop_last=True,
        )
        # Fresh noise per chain (params are restored at the end of the loop
        # body, so noise never accumulates across chains).
        if config.init_noise:
            for parameter in name_to_param.values():
                parameter.data.add_(
                    torch.randn_like(parameter.data) * config.init_noise
                )
        feed = _make_feed(loader, config.epoch_mode, num_batches)
        accumulated_loss = 0.0
        with torch.no_grad():
            for i in range(num_batches):
                batch = next(feed)
                input_ids = batch["input_ids"].to(device)
                loss = fn(model_on_device, input_ids)
                # Validate a custom loss_fn's output shape once, on the very
                # first batch of the first chain.
                if chain_idx == 0 and i == 0:
                    expected = input_ids.shape[:-1] + (input_ids.shape[-1] - 1,)
                    if tuple(loss.shape) != tuple(expected):
                        raise ValueError(
                            f"loss_fn must return per-token loss of shape "
                            f"{tuple(expected)}, got {tuple(loss.shape)}"
                        )
                # Running mean over batches (divide each term by num_batches).
                accumulated_loss += loss.detach().float() / num_batches
        # Average over the batch dimension -> per-token-position loss row.
        writer.write("init_loss", accumulated_loss.mean(dim=0), chain=chain_idx)
        if config.init_noise:
            for name, parameter in name_to_param.items():
                parameter.data.copy_(orig_data[name])
    # Restore model to its original device and training mode
    if is_transformer_lens_model(model):
        model.to(orig_device, print_details=False)  # pyright: ignore[reportCallIssue]
    else:
        model.to(orig_device)
    model.train(was_training)