# Source: devinterp.slt.susceptibilities (extracted from rendered documentation)

"""Susceptibility computation.

Computes per-token susceptibilities from SGLD sampling results,
as described in appendix C.4 of https://arxiv.org/pdf/2504.18274.

Two entry points:
- susceptibilities(): high-level, takes model + datasets, runs sampling + post-processing
- compute_susceptibilities(): low-level, takes pre-computed sampling DataTrees
"""

from __future__ import annotations

from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import xarray as xr
from datasets import Dataset

from devinterp.slt.lm_loss import LossFn
from devinterp.slt.observables import to_obs_id
from devinterp.slt.sampler import ParamMasks
from devinterp.slt.sampling import ObservableSpec, sample


def susceptibilities(
    model: torch.nn.Module,
    dataset: Dataset,
    observables: dict[str, ObservableSpec],
    weight_restrictions: dict[str, ParamMasks | None],
    *,
    sampling_task: str,
    lr: float,
    n_beta: float,
    loss_fn: LossFn | None = None,
    include_sampling_task: bool = False,
    **kwargs,
) -> xr.DataTree:
    """Run SGLD sampling once per weight restriction, then derive susceptibilities.

    Args:
        model: PyTorch model.
        dataset: HuggingFace Dataset with an "input_ids" column.
        observables: Dict mapping names to datasets (or
            (dataset, batches_per_draw) tuples).
        weight_restrictions: Dict mapping WR names to param masks. Must contain
            a "full" entry (value None = unrestricted model).
        sampling_task: Name of the sampling dataset; must be a key of
            ``observables``.
        lr: SGLD learning rate.
        n_beta: SGLD inverse temperature.
        loss_fn: Optional custom per-token loss
            ``(model, input_ids) -> (batch, seq-1)``. Defaults to cross-entropy
            on the model's logits.
        include_sampling_task: If True, also compute susceptibilities for the
            sampling_task observable (per-token variation within that task is
            still informative). Default False.
        **kwargs: Forwarded to sample() (num_chains, num_draws, batch_size,
            output_path, ...). If output_path is given, each weight
            restriction's samples go to a separate zarr with the WR name
            appended (e.g. "samples_full.zarr", "samples_l0h1.zarr").

    Returns:
        DataTree with /susceptibilities and /context subtrees.

    Raises:
        ValueError: If "full" is missing, sampling_task is unknown, or no
            probe observable remains besides the sampling task.
    """
    # Validation — keep this order so callers see the same error precedence.
    if "full" not in weight_restrictions:
        raise ValueError(
            "weight_restrictions must include a 'full' key (use None for the full model)."
        )
    if sampling_task not in observables:
        raise ValueError(
            f"sampling_task '{sampling_task}' not found in observables. "
            f"Available: {list(observables)}"
        )
    if not include_sampling_task and all(
        name == sampling_task for name in observables
    ):
        raise ValueError(
            "Need at least one observable besides sampling_task "
            "(or set include_sampling_task=True)."
        )

    # Pull output_path out of the shared kwargs: each WR must write to its own
    # zarr, otherwise every sample() call would overwrite the same store.
    shared_path = kwargs.pop("output_path", None)

    sampled: dict[str, xr.DataTree] = {}
    for wr_name, wr_masks in weight_restrictions.items():
        # A None mask set means "full model": give every named parameter a
        # None (no-op) mask.
        if wr_masks is None:
            wr_masks = {p_name: None for p_name, _ in model.named_parameters()}

        call_kwargs = dict(kwargs)
        if shared_path is not None:
            base = Path(str(shared_path))
            call_kwargs["output_path"] = (
                base.parent / f"{base.stem}_{wr_name}{base.suffix}"
            )

        sampled[wr_name] = sample(
            model=model,
            dataset=dataset,
            observables=observables,
            param_masks=wr_masks,
            lr=lr,
            n_beta=n_beta,
            loss_fn=loss_fn,
            **call_kwargs,
        )

    return compute_susceptibilities(
        sampled, sampling_task, include_sampling_task=include_sampling_task
    )
def compute_susceptibilities(
    wr_map: dict[str, xr.DataTree],
    sampling_task: str,
    observable_names: list[str] | None = None,
    include_sampling_task: bool = False,
) -> xr.DataTree:
    """Compute per-token susceptibilities from pre-computed sampling results.

    Implements the estimator from appendix C.4 of the paper referenced in the
    module docstring (arXiv:2504.18274).

    Args:
        wr_map: Dict mapping weight restriction names to DataTrees. Must
            include a "full" key for the unrestricted model. Each DataTree is
            the output of sample().
        sampling_task: Name of the sampling/pretraining dataset task
            (e.g. "pile10k"). Must appear as an observable.
        observable_names: Which observables to compute susceptibilities for.
            If None, uses all observables (optionally including sampling_task,
            see include_sampling_task).
        include_sampling_task: If True and observable_names is None, include
            sampling_task in the discovered observables. Default False.

    Returns:
        DataTree with /susceptibilities and /context subtrees.

    Raises:
        ValueError: If no probe observables remain after filtering.
        AssertionError: If a WR's loss arrays do not have the expected
            (chain, draw, batch, token_pos) layout.
    """
    full = wr_map["full"].dataset
    sampling_id = to_obs_id(sampling_task)

    # Discover observable ids from the DataTree: every "loss_<id>" data_var
    # except the dedicated "sampling_loss" variable.
    all_obs_ids = [
        str(v)[len("loss_") :]
        for v in full.data_vars
        if str(v).startswith("loss_") and str(v) != "sampling_loss"
    ]
    if observable_names is None:
        obs_ids = (
            all_obs_ids
            if include_sampling_task
            else [oid for oid in all_obs_ids if oid != sampling_id]
        )
    else:
        obs_ids = [to_obs_id(n) for n in observable_names]
    if not obs_ids:
        raise ValueError(
            f"No probe observables to compute susceptibilities for. "
            f"Available loss_* observables: {all_obs_ids!r}; "
            f"sampling_task={sampling_id!r}; "
            f"include_sampling_task={include_sampling_task}. "
            f"Add a probe observable distinct from the sampling task, "
            f"or pass include_sampling_task=True."
        )

    # Pre-compute full-WR quantities (scalars / per-token baselines used for
    # every restricted WR below).
    full_sampling_mean = float(full[f"loss_{sampling_id}"].values.mean())
    init_loss_mean = float(full["init_loss"].values.mean())
    full_delta = {}
    for obs_id in obs_ids:
        # Mean over chain+draw (axes 0, 1), keeping (batch, target_position).
        full_delta[obs_id] = full_sampling_mean - full[f"loss_{obs_id}"].values.mean(
            axis=(0, 1)
        )

    # Compute susceptibilities per WR x observable.
    results: list[tuple[str, str, xr.DataArray, xr.DataArray]] = []
    for wr_name, wr in wr_map.items():
        # The "full" WR only supplies the baselines above; skip it here.
        if wr_name == "full":
            continue
        ds = wr.dataset
        # Verify expected layout since we slice positionally below.
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # raise would survive optimization — confirm whether that matters
        # for deployment.
        for oid in [sampling_id, *obs_ids]:
            expected = ("chain", "draw", f"batch_{oid}", "token_pos")
            assert ds[f"loss_{oid}"].dims == expected, (
                f"loss_{oid} has dims {ds[f'loss_{oid}'].dims}, expected {expected}"
            )
        # phi = per-(chain, draw) mean pretraining loss minus init_loss_mean.
        pt_loss = ds[f"loss_{sampling_id}"].values  # (chain, draw, batch, token_pos)
        pt_mean = pt_loss.mean(axis=(-2, -1))  # (chain, draw)
        phi = pt_mean - init_loss_mean
        E_phi = float(phi.mean())
        E_phi_wr = float((phi * pt_mean).mean())
        for obs_id in obs_ids:
            probe = ds[f"loss_{obs_id}"].values  # (chain, draw, batch, token_pos)
            # Broadcast phi over (batch, token_pos); average over chain+draw.
            E_phi_probe = (phi[:, :, None, None] * probe).mean(axis=(0, 1))
            # Susceptibility estimator; presumably eq. from appendix C.4 —
            # verify against the paper before changing any term.
            sus = E_phi_wr - E_phi_probe - E_phi * full_delta[obs_id]
            ctx_len = sus.shape[1]
            # target_position is 1-based: position t scores prediction of
            # token t+1 (per-token losses have seq-1 entries).
            sus_da = xr.DataArray(
                sus,
                dims=["batch", "target_position"],
                coords={"target_position": np.arange(1, ctx_len + 1)},
            )
            ids = ds[f"input_ids_{obs_id}"].values
            # input_ids may carry leading chain/draw axes; they are identical
            # across draws, so keep the first until 2-D (batch, position).
            while ids.ndim > 2:
                ids = ids[0]
            ids_da = xr.DataArray(
                ids,
                dims=["batch", "position"],
                coords={"position": np.arange(ids.shape[1])},
            )
            results.append((wr_name, obs_id, sus_da, ids_da))

    return _build_output(results)
# ─── Output format ──────────────────────────────────────────────────────────
# The output uses the SusceptibilitiesV2 format: susceptibilities are flattened
# into a 2D array (sus_flat, wr) and input_ids into a 1D array (ctx_flat),
# both with coordinate arrays for traceability. This matches aether's format
# so downstream visualization tools work unchanged.


def _build_output(
    results: list[tuple[str, str, xr.DataArray, xr.DataArray]],
) -> xr.DataTree:
    """Build SusceptibilitiesV2 DataTree from (wr, obs, sus, ids) tuples.

    Args:
        results: One entry per (weight restriction, observable) pair:
            (wr_name, obs_id, susceptibilities with dims
            (batch, target_position), input_ids with dims (batch, position)).

    Returns:
        DataTree with /susceptibilities ("sus" over (sus_flat, wr)) and
        /context ("input_ids" over ctx_flat) subtrees.

    Raises:
        AssertionError: If an observable's input_ids differ between WRs.
    """
    # Group by WR (first-seen order preserved) and by observable.
    sus_by_wr: dict[str, dict[str, xr.DataArray]] = defaultdict(dict)
    all_ids: dict[str, dict[str, xr.DataArray]] = defaultdict(dict)
    wr_order: list[str] = []
    for wr_name, obs, sus, ids in results:
        sus_by_wr[wr_name][obs] = sus
        all_ids[wr_name][obs] = ids
        if wr_name not in wr_order:
            wr_order.append(wr_name)
    # Observables are ordered by sorted id; dataset_id below is the index
    # into this sorted list.
    obs_order = sorted(sus_by_wr[wr_order[0]].keys())

    # Validate input_ids consistency across WRs — the flattened context is
    # stored once, so every WR must have seen identical inputs per observable.
    first_wr = wr_order[0]
    deduped_ids: dict[str, xr.DataArray] = {}
    for obs in obs_order:
        ref = all_ids[first_wr][obs]
        deduped_ids[obs] = ref
        for wr in wr_order[1:]:
            if not np.array_equal(ref.values, all_ids[wr][obs].values):
                raise AssertionError(
                    f"input_ids for {obs} differ between {first_wr} and {wr}"
                )

    # Flatten input_ids → 1D with (dataset_id, batch, position) coords.
    ids_items = {i: (obs, deduped_ids[obs]) for i, obs in enumerate(obs_order)}
    ids_coords = _flat_coords(ids_items, "ctx_flat")
    ids_data = np.concatenate([v.values.ravel() for _, v in ids_items.values()])

    # Flatten susceptibilities → 2D (sus_flat, wr). Per observable, the WR
    # values are stacked along the trailing axis, then observables are
    # concatenated along sus_flat — matching the coordinate order produced
    # by _flat_coords on the first WR's arrays.
    sus_items = {
        i: (obs, sus_by_wr[wr_order[0]][obs]) for i, obs in enumerate(obs_order)
    }
    sus_coords = _flat_coords(sus_items, "sus_flat")
    sus_data = np.concatenate(
        [
            np.stack([sus_by_wr[wr][obs].values.ravel() for wr in wr_order], axis=-1)
            for obs in obs_order
        ],
        axis=0,
    )

    return xr.DataTree.from_dict(
        {
            "/susceptibilities": xr.Dataset(
                {"sus": xr.DataArray(sus_data, dims=["sus_flat", "wr"])},
                coords=sus_coords.merge(xr.Coordinates({"wr": wr_order})),
                attrs={"dataset_id_to_name": obs_order},
            ),
            "/context": xr.Dataset(
                {"input_ids": xr.DataArray(ids_data, dims=["ctx_flat"])},
                coords=ids_coords,
                attrs={"dataset_id_to_name": obs_order},
            ),
        }
    )


def _flat_coords(
    items: dict[int, tuple[str, xr.DataArray]], dim_name: str
) -> xr.Coordinates:
    """Build flattened coordinates from per-observable DataArrays.

    For each DataArray, emits a "dataset_id" coord (its integer key, repeated
    var.size times) plus one flattened coord per dimension, all raveled in C
    order so they align with var.values.ravel().
    """
    coords: dict[str, list] = {"dataset_id": []}
    for i, (_name, var) in items.items():
        coords["dataset_id"].append(np.full(var.size, i))
        # Dims without an explicit coordinate (e.g. "batch" on the arrays
        # built upstream) are absent from var.coords in xarray; fall back to
        # a 0..n-1 index so every dim still yields a coordinate array instead
        # of raising KeyError.
        axes = [
            var.coords[d].values if d in var.coords else np.arange(var.sizes[d])
            for d in var.dims
        ]
        grids = np.meshgrid(*axes, indexing="ij")
        for dim, grid in zip(var.dims, grids):
            coords.setdefault(str(dim), []).append(grid.ravel())
    return xr.Coordinates(
        {k: (dim_name, np.concatenate(v)) for k, v in coords.items()},
        indexes={},
    )