Source code for devinterp.slt.llc

"""Local Learning Coefficient (LLC) computation from sampling results.

Computes LLC from the stored per-draw losses, without needing callbacks.

    LLC = n_beta * (mean_sampling_loss - init_loss)
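
    For example, n_beta = 100 with mean sampling loss 0.52 and
    init loss 0.50 gives LLC = 100 * (0.52 - 0.50) = 2.0.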

Two entry points:
- llc(): high-level, takes model + dataset, runs sampling + LLC
- compute_llc(): low-level, takes pre-computed sampling DataTree
"""

from __future__ import annotations

import torch
import xarray as xr
from datasets import Dataset

from devinterp.slt.lm_loss import LossFn
from devinterp.slt.sampler import ParamMasks
from devinterp.slt.sampling import ObservableSpec, sample


def llc(
    model: torch.nn.Module,
    dataset: Dataset,
    observables: dict[str, ObservableSpec],
    *,
    lr: float,
    n_beta: float,
    param_masks: ParamMasks | None = None,
    loss_fn: LossFn | None = None,
    **kwargs,
) -> xr.Dataset:
    """Sample and compute LLC in one call.

    Args:
        model: PyTorch model.
        dataset: HuggingFace Dataset with an "input_ids" column.
        observables: Dict mapping names to datasets (or
            (dataset, batches_per_draw) tuples).
        lr: SGLD learning rate.
        n_beta: SGLD inverse temperature.
        param_masks: Which parameters to optimize. None for the full model.
        loss_fn: Optional custom per-token loss
            `(model, input_ids) -> (batch, seq-1)`. Defaults to
            cross-entropy on the model's logits.
        **kwargs: Additional arguments passed to sample() (num_chains,
            num_draws, batch_size, output_path, etc.).

    Returns:
        xr.Dataset with llc_mean, llc_std, llc_per_chain, loss_trace,
        and init_loss.
    """
    samples = sample(
        model=model,
        dataset=dataset,
        observables=observables,
        param_masks=param_masks,
        lr=lr,
        n_beta=n_beta,
        loss_fn=loss_fn,
        **kwargs,
    )
    return compute_llc(samples)
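
A hypothetical usage sketch (not part of this module): `model` and
`token_ids` are assumed to exist, and the "sampling" observable key is
an assumption inferred from the `sampling_loss_micro` variable that
compute_llc() reads below.

    # Hypothetical usage; `model`, `token_ids`, and the "sampling"
    # observable key are assumptions, not defined by this module.
    from datasets import Dataset

    ds = Dataset.from_dict({"input_ids": token_ids})
    result = llc(
        model,
        ds,
        observables={"sampling": ds},
        lr=1e-4,
        n_beta=100.0,
        num_chains=4,   # forwarded to sample() via **kwargs
        num_draws=200,  # forwarded to sample() via **kwargs
    )
    print(float(result.llc_mean), float(result.llc_std))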
def compute_llc(samples: xr.DataTree) -> xr.Dataset:
    """Compute LLC from a sampling DataTree.

    Matches aether's `calculate` action with `function: llc`: averages
    `sampling_loss_micro` over every step (including burn-in) and every
    micro-batch, then subtracts the mean `init_loss` and scales by
    `n_beta`.

    Args:
        samples: DataTree output from sample(), containing
            sampling_loss_micro, init_loss, and n_beta.

    Returns:
        xr.Dataset with:
            llc_mean: scalar, mean LLC across chains
            llc_std: scalar, std of LLC across chains
            llc_per_chain: (chain,) LLC per chain
            llc_scalar: scalar LLC matching aether's calculate action
            loss_trace: (chain, step) mean loss per chain per step
            init_loss: scalar, mean init loss
    """
    ds = samples.dataset
    n_beta = float(ds.metadata["config"]["sampler"]["n_beta"])
    init_loss_mean = float(ds.init_loss.values.mean())

    # Per-chain per-step mean loss (average over batch_sampling and token_pos).
    loss_trace = ds.sampling_loss_micro.mean(dim=["batch_sampling", "token_pos"])

    # LLC per chain = n_beta * (mean loss per chain - init loss).
    llc_per_chain = n_beta * (loss_trace.mean(dim="step") - init_loss_mean)  # (chain,)

    # Scalar LLC matching aether's calculate action reduction order
    # (all dims reduced at once for bitwise parity).
    llc_scalar = n_beta * (float(ds.sampling_loss_micro.mean()) - init_loss_mean)

    return xr.Dataset(
        {
            "llc_mean": llc_per_chain.mean(),
            "llc_std": llc_per_chain.std(),
            "llc_per_chain": llc_per_chain,
            "llc_scalar": llc_scalar,
            "loss_trace": loss_trace,
            "init_loss": init_loss_mean,
        }
    )
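
A minimal sanity check of compute_llc() on a synthetic DataTree: a
sketch assuming only what the function above reads (the
`sampling_loss_micro` and `init_loss` variables, the
chain/step/batch_sampling/token_pos dims, and the nested `metadata`
attr), with shapes chosen arbitrarily. With constant per-draw losses,
every reduction collapses to n_beta * (loss - init_loss).

    import numpy as np

    n_beta = 100.0
    losses = xr.DataArray(
        np.full((2, 5, 3, 7), 0.52),  # (chain, step, batch_sampling, token_pos)
        dims=["chain", "step", "batch_sampling", "token_pos"],
    )
    toy = xr.Dataset(
        {
            "sampling_loss_micro": losses,
            "init_loss": xr.DataArray([0.50, 0.50], dims=["chain"]),
        },
        attrs={"metadata": {"config": {"sampler": {"n_beta": n_beta}}}},
    )
    result = compute_llc(xr.DataTree(dataset=toy))
    # Constant losses: LLC = 100 * (0.52 - 0.50) = 2.0 for every chain.
    assert np.isclose(float(result.llc_scalar), 2.0)
    assert np.isclose(float(result.llc_mean), 2.0)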