devinterp.slt package
Submodules
devinterp.slt.bif module
Bayesian Influence Function (BIF) computation.
Computes pairwise correlations between observable loss traces across sequences from SGLD sampling results.
Two entry points:
- bif(): high-level, takes model + dataset, runs sampling + BIF
- compute_bif(): low-level, takes a pre-computed sampling DataTree
- devinterp.slt.bif.bif(model: Module, dataset: Dataset, observables: dict[str, Dataset | tuple[Dataset, int]], *, lr: float, n_beta: float, param_masks: dict[str, Tensor | None] | None = None, correlation_method: Literal['token', 'sequence'] = 'token', reduce_chain_dimension_method: Literal['stack', 'mean'] = 'stack', average_tokenwise_bif: bool = False, compute_covariance: bool = False, bif_batch_size: int = 32, bif_device: str | device | None = None, loss_fn: Callable[[Module, Tensor], Tensor] | None = None, **kwargs) Dataset[source]
Sample and compute BIF in one call.
- Parameters:
model – PyTorch model.
dataset – HuggingFace Dataset with “input_ids” column.
observables – Dict mapping names to datasets (or (dataset, batches_per_draw) tuples).
lr – SGLD learning rate.
n_beta – SGLD inverse temperature.
param_masks – Which parameters to optimize. None for full model.
correlation_method – “token” or “sequence”.
reduce_chain_dimension_method – “stack” (recommended) or “mean”.
average_tokenwise_bif – Average token-wise BIF to scalar per pair.
compute_covariance – Compute covariance instead of correlation.
bif_batch_size – Batch size for BIF block processing.
bif_device – Torch device for BIF computation. None for auto-detect.
loss_fn – Optional custom per-token loss (model, input_ids) -> (batch, seq-1). Defaults to cross-entropy on the model’s logits.
**kwargs – Additional arguments passed to sample() (num_chains, num_draws, batch_size, output_path, etc.)
- Returns:
xr.Dataset with “influences” and “input_ids” variables.
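A minimal usage sketch based on the signature above; the model, datasets, and hyperparameter values are placeholders:

```python
from devinterp.slt.bif import bif

# "model", "train_dataset" and the probe datasets are hypothetical; each
# dataset is a HuggingFace Dataset with an "input_ids" column.
result = bif(
    model,
    train_dataset,
    observables={"probe_code": probe_code_ds, "probe_math": (probe_math_ds, 2)},
    lr=1e-4,
    n_beta=30.0,
    correlation_method="token",
    num_chains=4,    # forwarded to sample() via **kwargs
    num_draws=200,
)
influences = result["influences"]  # variable of the returned xr.Dataset
```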
- devinterp.slt.bif.compute_bif(samples: DataTree, *, correlation_method: Literal['token', 'sequence'] = 'token', reduce_chain_dimension_method: Literal['stack', 'mean'] = 'stack', loss_keys: Literal['all'] | list[str] = 'all', batch_index_range_1: Literal['all'] | Sequence[int] = 'all', batch_index_range_2: Literal['all'] | Sequence[int] = 'all', average_tokenwise_bif: bool = False, compute_covariance: bool = False, batch_size: int = 32, device: str | device | None = None) Dataset[source]
Compute BIF from a sampling DataTree.
- Parameters:
samples – DataTree output from sample().
correlation_method – “token” for token-wise, “sequence” for sequence-level.
reduce_chain_dimension_method – “stack” (recommended) or “mean”.
loss_keys – Which observables to include. “all” auto-discovers.
batch_index_range_1 – Batch indices for first operand.
batch_index_range_2 – Batch indices for second operand.
average_tokenwise_bif – Average token-wise BIF to scalar per pair.
compute_covariance – Compute covariance instead of correlation.
batch_size – Batch size for block processing.
device – Torch device. None for auto-detect.
- Returns:
xr.Dataset with “influences” and “input_ids” variables.
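A sketch of the low-level path: run sample() once and post-process the same DataTree several ways (model and dataset names are placeholders):

```python
from devinterp.slt.sampling import sample
from devinterp.slt.bif import compute_bif

# One sampling run, reused for multiple BIF variants.
samples = sample(model, train_dataset, observables={"probe": probe_ds},
                 lr=1e-4, n_beta=30.0)
token_bif = compute_bif(samples, correlation_method="token")
seq_cov = compute_bif(samples, correlation_method="sequence",
                      compute_covariance=True)
```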
devinterp.slt.config module
Configuration for SGLD sampling.
- class devinterp.slt.config.SamplerConfig(*, lr: float, n_beta: float, batch_size: int = 32, num_chains: int = 4, num_draws: int = 200, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, gradient_accumulation_steps: int = 1, num_init_loss_batches: int = 32, init_seed: int = 100, localization: float = 0.0, noise_level: float = 1.0, llc_weight_decay: float = 0.0, bounding_box_size: float | None = None, sampling_method: ~typing.Literal['sgmcmc_sgld', 'rmsprop_sgld'] = 'sgmcmc_sgld', sampling_method_kwargs: dict[str, ~typing.Any] = <factory>, init_noise: float | None = None, save_metrics: bool = False, shuffle: bool = True, epoch_mode: ~typing.Literal['once', 'cycle'] = 'cycle', match_sampling_input_ids_across_chains: bool = True)[source]
Bases:
BaseModel
Configuration for SGLD sampling.
Validates types and value ranges for all sampler parameters.
- batch_size: PositiveInt
- bounding_box_size: NonNegativeFloat | None
- copy(*, include: AbstractSetIntStr | MappingIntStrAny | None = None, exclude: AbstractSetIntStr | MappingIntStrAny | None = None, update: Dict[str, Any] | None = None, deep: bool = False) Self
Returns a copy of the model.
- !!! warning “Deprecated”
This method is now deprecated; use model_copy instead.
If you need include or exclude, use:
```python
data = self.model_dump(include=include, exclude=exclude, round_trip=True)
data = {**data, **(update or {})}
copied = self.model_validate(data)
```
- Parameters:
include – Optional set or mapping specifying which fields to include in the copied model.
exclude – Optional set or mapping specifying which fields to exclude in the copied model.
update – Optional dictionary of field-value pairs to override field values in the copied model.
deep – If True, the values of fields that are Pydantic models will be deep-copied.
- Returns:
A copy of the model with included, excluded and updated fields as specified.
- epoch_mode: EpochMode
- gradient_accumulation_steps: PositiveInt
- init_noise: NonNegativeFloat | None
- init_seed: int
- llc_weight_decay: NonNegativeFloat
- localization: NonNegativeFloat
- lr: NonNegativeFloat
- match_sampling_input_ids_across_chains: bool
- model_computed_fields = {}
- model_config: ClassVar[ConfigDict] = {'extra': 'forbid'}
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- classmethod model_construct(_fields_set: set[str] | None = None, **values: Any) Self
Creates a new instance of the Model class with validated data.
Creates a new model setting __dict__ and __pydantic_fields_set__ from trusted or pre-validated data. Default values are respected, but no other validation is performed.
- !!! note
model_construct() generally respects the model_config.extra setting on the provided model. That is, if model_config.extra == ‘allow’, then all extra passed values are added to the model instance’s __dict__ and __pydantic_extra__ fields. If model_config.extra == ‘ignore’ (the default), then all extra passed values are ignored. Because no validation is performed with a call to model_construct(), having model_config.extra == ‘forbid’ does not result in an error if extra values are passed, but they will be ignored.
- Parameters:
_fields_set – A set of field names that were originally explicitly set during instantiation. If provided, this is directly used for the [model_fields_set][pydantic.BaseModel.model_fields_set] attribute. Otherwise, the field names from the values argument will be used.
values – Trusted or pre-validated data dictionary.
- Returns:
A new instance of the Model class with validated data.
- model_copy(*, update: Mapping[str, Any] | None = None, deep: bool = False) Self
- !!! abstract “Usage Documentation”
[model_copy](../concepts/models.md#model-copy)
Returns a copy of the model.
- !!! note
The underlying instance’s [__dict__][object.__dict__] attribute is copied. This might have unexpected side effects if you store anything in it, on top of the model fields (e.g. the value of [cached properties][functools.cached_property]).
- Parameters:
update – Values to change/add in the new model. Note: the data is not validated before creating the new model. You should trust this data.
deep – Set to True to make a deep copy of the model.
- Returns:
New model instance.
- model_dump(*, mode: Literal['json', 'python'] | str = 'python', include: set[int] | set[str] | Mapping[int, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | Mapping[str, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | None = None, exclude: set[int] | set[str] | Mapping[int, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | Mapping[str, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | None = None, context: Any | None = None, by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal['none', 'warn', 'error'] = True, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, polymorphic_serialization: bool | None = None) dict[str, Any]
- !!! abstract “Usage Documentation”
[model_dump](../concepts/serialization.md#python-mode)
Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
- Parameters:
mode – The mode in which to_python should run. If mode is ‘json’, the output will only contain JSON serializable types. If mode is ‘python’, the output may contain non-JSON-serializable Python objects.
include – A set of fields to include in the output.
exclude – A set of fields to exclude from the output.
context – Additional context to pass to the serializer.
by_alias – Whether to use the field’s alias in the dictionary key if defined.
exclude_unset – Whether to exclude fields that have not been explicitly set.
exclude_defaults – Whether to exclude fields that are set to their default value.
exclude_none – Whether to exclude fields that have a value of None.
exclude_computed_fields – Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated round_trip parameter instead.
round_trip – If True, dumped values should be valid as input for non-idempotent types such as Json[T].
warnings – How to handle serialization errors. False/”none” ignores them, True/”warn” logs errors, “error” raises a [PydanticSerializationError][pydantic_core.PydanticSerializationError].
fallback – A function to call when an unknown value is encountered. If not provided, a [PydanticSerializationError][pydantic_core.PydanticSerializationError] error is raised.
serialize_as_any – Whether to serialize fields with duck-typing serialization behavior.
polymorphic_serialization – Whether to use model and dataclass polymorphic serialization for this call.
- Returns:
A dictionary representation of the model.
- model_dump_json(*, indent: int | None = None, ensure_ascii: bool = False, include: set[int] | set[str] | Mapping[int, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | Mapping[str, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | None = None, exclude: set[int] | set[str] | Mapping[int, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | Mapping[str, set[int] | set[str] | Mapping[int, IncEx | bool] | Mapping[str, IncEx | bool] | bool] | None = None, context: Any | None = None, by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, exclude_computed_fields: bool = False, round_trip: bool = False, warnings: bool | Literal['none', 'warn', 'error'] = True, fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, polymorphic_serialization: bool | None = None) str
- !!! abstract “Usage Documentation”
[model_dump_json](../concepts/serialization.md#json-mode)
Generates a JSON representation of the model using Pydantic’s to_json method.
- Parameters:
indent – Indentation to use in the JSON output. If None is passed, the output will be compact.
ensure_ascii – If True, the output is guaranteed to have all incoming non-ASCII characters escaped. If False (the default), these characters will be output as-is.
include – Field(s) to include in the JSON output.
exclude – Field(s) to exclude from the JSON output.
context – Additional context to pass to the serializer.
by_alias – Whether to serialize using field aliases.
exclude_unset – Whether to exclude fields that have not been explicitly set.
exclude_defaults – Whether to exclude fields that are set to their default value.
exclude_none – Whether to exclude fields that have a value of None.
exclude_computed_fields – Whether to exclude computed fields. While this can be useful for round-tripping, it is usually recommended to use the dedicated round_trip parameter instead.
round_trip – If True, dumped values should be valid as input for non-idempotent types such as Json[T].
warnings – How to handle serialization errors. False/”none” ignores them, True/”warn” logs errors, “error” raises a [PydanticSerializationError][pydantic_core.PydanticSerializationError].
fallback – A function to call when an unknown value is encountered. If not provided, a [PydanticSerializationError][pydantic_core.PydanticSerializationError] error is raised.
serialize_as_any – Whether to serialize fields with duck-typing serialization behavior.
polymorphic_serialization – Whether to use model and dataclass polymorphic serialization for this call.
- Returns:
A JSON string representation of the model.
- property model_extra: dict[str, Any] | None
Get extra fields set during validation.
- Returns:
A dictionary of extra fields, or None if config.extra is not set to “allow”.
- model_fields = {'batch_size': FieldInfo(annotation=int, required=False, default=32, metadata=[Gt(gt=0)]), 'bounding_box_size': FieldInfo(annotation=Union[Annotated[float, Ge], NoneType], required=False, default=None), 'epoch_mode': FieldInfo(annotation=Literal['once', 'cycle'], required=False, default='cycle'), 'gradient_accumulation_steps': FieldInfo(annotation=int, required=False, default=1, metadata=[Gt(gt=0)]), 'init_noise': FieldInfo(annotation=Union[Annotated[float, Ge], NoneType], required=False, default=None), 'init_seed': FieldInfo(annotation=int, required=False, default=100), 'llc_weight_decay': FieldInfo(annotation=float, required=False, default=0.0, metadata=[Ge(ge=0)]), 'localization': FieldInfo(annotation=float, required=False, default=0.0, metadata=[Ge(ge=0)]), 'lr': FieldInfo(annotation=float, required=True, metadata=[Ge(ge=0)]), 'match_sampling_input_ids_across_chains': FieldInfo(annotation=bool, required=False, default=True), 'n_beta': FieldInfo(annotation=float, required=True, metadata=[Ge(ge=0)]), 'noise_level': FieldInfo(annotation=float, required=False, default=1.0, metadata=[Ge(ge=0)]), 'num_burnin_steps': FieldInfo(annotation=int, required=False, default=0, metadata=[Ge(ge=0)]), 'num_chains': FieldInfo(annotation=int, required=False, default=4, metadata=[Gt(gt=0)]), 'num_draws': FieldInfo(annotation=int, required=False, default=200, metadata=[Gt(gt=0)]), 'num_init_loss_batches': FieldInfo(annotation=int, required=False, default=32, metadata=[Gt(gt=0)]), 'num_steps_bw_draws': FieldInfo(annotation=int, required=False, default=1, metadata=[Gt(gt=0)]), 'sampling_method': FieldInfo(annotation=Literal['sgmcmc_sgld', 'rmsprop_sgld'], required=False, default='sgmcmc_sgld'), 'sampling_method_kwargs': FieldInfo(annotation=dict[str, Any], required=False, default_factory=dict), 'save_metrics': FieldInfo(annotation=bool, required=False, default=False), 'shuffle': FieldInfo(annotation=bool, required=False, default=True)}
- property model_fields_set: set[str]
Returns the set of fields that have been explicitly set on this model instance.
- Returns:
A set of strings representing the fields that have been set, i.e. that were not filled from defaults.
- classmethod model_json_schema(by_alias: bool = True, ref_template: str = '#/$defs/{model}', schema_generator: type[~pydantic.json_schema.GenerateJsonSchema] = <class 'pydantic.json_schema.GenerateJsonSchema'>, mode: ~typing.Literal['validation', 'serialization'] = 'validation', *, union_format: ~typing.Literal['any_of', 'primitive_type_array'] = 'any_of') dict[str, Any]
Generates a JSON schema for a model class.
- Parameters:
by_alias – Whether to use attribute aliases or not.
ref_template – The reference template.
union_format – The format to use when combining schemas from unions together. Can be one of:
- ‘any_of’: Use the [anyOf](https://json-schema.org/understanding-json-schema/reference/combining#anyOf) keyword to combine schemas (the default).
- ‘primitive_type_array’: Use the [type](https://json-schema.org/understanding-json-schema/reference/type) keyword as an array of strings, containing each type of the combination. If any of the schemas is not a primitive type (string, boolean, null, integer or number) or contains constraints/metadata, falls back to any_of.
schema_generator – To override the logic used to generate the JSON schema, as a subclass of GenerateJsonSchema with your desired modifications
mode – The mode in which to generate the schema.
- Returns:
The JSON schema for the given model class.
- classmethod model_parametrized_name(params: tuple[type[Any], ...]) str
Compute the class name for parametrizations of generic classes.
This method can be overridden to achieve a custom naming scheme for generic BaseModels.
- Parameters:
params – Tuple of types of the class. Given a generic class Model with 2 type variables and a concrete model Model[str, int], the value (str, int) would be passed to params.
- Returns:
String representing the new class where params are passed to cls as type variables.
- Raises:
TypeError – Raised when trying to generate concrete names for non-generic models.
- model_post_init(context: Any, /) None
Override this method to perform additional initialization after __init__ and model_construct. This is useful if you want to do some validation that requires the entire model to be initialized.
- classmethod model_rebuild(*, force: bool = False, raise_errors: bool = True, _parent_namespace_depth: int = 2, _types_namespace: MappingNamespace | None = None) bool | None
Try to rebuild the pydantic-core schema for the model.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during the initial attempt to build the schema, and automatic rebuilding fails.
- Parameters:
force – Whether to force the rebuilding of the model schema, defaults to False.
raise_errors – Whether to raise errors, defaults to True.
_parent_namespace_depth – The depth level of the parent namespace, defaults to 2.
_types_namespace – The types namespace, defaults to None.
- Returns:
Returns None if the schema is already “complete” and rebuilding was not required. If rebuilding _was_ required, returns True if rebuilding was successful, otherwise False.
- classmethod model_validate(obj: Any, *, strict: bool | None = None, extra: Literal['allow', 'ignore', 'forbid'] | None = None, from_attributes: bool | None = None, context: Any | None = None, by_alias: bool | None = None, by_name: bool | None = None) Self
Validate a pydantic model instance.
- Parameters:
obj – The object to validate.
strict – Whether to enforce types strictly.
extra – Whether to ignore, allow, or forbid extra data during model validation. See the [extra configuration value][pydantic.ConfigDict.extra] for details.
from_attributes – Whether to extract data from object attributes.
context – Additional context to pass to the validator.
by_alias – Whether to use the field’s alias when validating against the provided input data.
by_name – Whether to use the field’s name when validating against the provided input data.
- Raises:
ValidationError – If the object could not be validated.
- Returns:
The validated model instance.
- classmethod model_validate_json(json_data: str | bytes | bytearray, *, strict: bool | None = None, extra: Literal['allow', 'ignore', 'forbid'] | None = None, context: Any | None = None, by_alias: bool | None = None, by_name: bool | None = None) Self
- !!! abstract “Usage Documentation”
[JSON Parsing](../concepts/json.md#json-parsing)
Validate the given JSON data against the Pydantic model.
- Parameters:
json_data – The JSON data to validate.
strict – Whether to enforce types strictly.
extra – Whether to ignore, allow, or forbid extra data during model validation. See the [extra configuration value][pydantic.ConfigDict.extra] for details.
context – Extra variables to pass to the validator.
by_alias – Whether to use the field’s alias when validating against the provided input data.
by_name – Whether to use the field’s name when validating against the provided input data.
- Returns:
The validated Pydantic model.
- Raises:
ValidationError – If json_data is not a JSON string or the object could not be validated.
- classmethod model_validate_strings(obj: Any, *, strict: bool | None = None, extra: Literal['allow', 'ignore', 'forbid'] | None = None, context: Any | None = None, by_alias: bool | None = None, by_name: bool | None = None) Self
Validate the given object with string data against the Pydantic model.
- Parameters:
obj – The object containing string data to validate.
strict – Whether to enforce types strictly.
extra – Whether to ignore, allow, or forbid extra data during model validation. See the [extra configuration value][pydantic.ConfigDict.extra] for details.
context – Extra variables to pass to the validator.
by_alias – Whether to use the field’s alias when validating against the provided input data.
by_name – Whether to use the field’s name when validating against the provided input data.
- Returns:
The validated Pydantic model.
- n_beta: NonNegativeFloat
- noise_level: NonNegativeFloat
- num_burnin_steps: NonNegativeInt
- num_chains: PositiveInt
- num_draws: PositiveInt
- num_init_loss_batches: PositiveInt
- num_steps_bw_draws: PositiveInt
- sampling_method: SamplingMethodLiteral
- sampling_method_kwargs: dict[str, Any]
- save_metrics: bool
- shuffle: bool
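A minimal construction sketch; only lr and n_beta are required, all other fields have defaults, and the example values are placeholders:

```python
from devinterp.slt.config import SamplerConfig

config = SamplerConfig(lr=1e-4, n_beta=30.0, num_chains=8, num_draws=500)

# Because model_config uses extra='forbid', a misspelled field raises a
# ValidationError instead of being silently ignored:
# SamplerConfig(lr=1e-4, n_beta=30.0, num_drawss=500)  # -> ValidationError
```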
devinterp.slt.covariance module
Batched covariance and correlation computation. Port of aether’s covariance_utils.py — torch-based batched correlation for BIF computation.
- devinterp.slt.covariance.batch_corrcoef(batched_a: Tensor, batched_b: Tensor) Tensor[source]
Batched Pearson correlation between all pairs from batched_a and batched_b.
- Parameters:
batched_a – shape (n_a, series_a, observations), float32 or float64
batched_b – shape (n_b, series_b, observations), float32 or float64
- Returns:
shape (n_a, n_b, series_a + series_b, series_a + series_b)
- devinterp.slt.covariance.batch_cov(batched_a: Tensor, batched_b: Tensor) Tensor[source]
Batched covariance between all pairs from batched_a and batched_b.
- Parameters:
batched_a – shape (n_a, series_a, observations)
batched_b – shape (n_b, series_b, observations)
- Returns:
shape (n_a, n_b, series_a + series_b, series_a + series_b)
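Shape sketch for the two batched routines above (the sizes are arbitrary):

```python
import torch
from devinterp.slt.covariance import batch_corrcoef, batch_cov

# 3 and 5 batches, 7 series each, 200 observations (e.g. chain draws).
a = torch.randn(3, 7, 200)
b = torch.randn(5, 7, 200)
corr = batch_corrcoef(a, b)  # shape (3, 5, 14, 14): series_a + series_b = 14
cov = batch_cov(a, b)        # same shape, covariance instead of correlation
```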
- devinterp.slt.covariance.xr_corrcoef_with_torch_backend(seq1: DataArray, seq2: DataArray, *, device: str | device, compute_covariance: bool = False) Dataset[source]
Full Pearson correlation matrix between rows of seq1 and seq2 using torch.
- Parameters:
seq1 – 2-D DataArray (e.g. batch, chain_draw)
seq2 – 2-D DataArray with same dims as seq1
device – torch device for computation
compute_covariance – if True, compute covariance instead of correlation
- Returns:
Dataset with “correlation” variable of shape (dim_1, dim_1_T)
devinterp.slt.llc module
Local Learning Coefficient (LLC) computation from sampling results.
Computes LLC from the stored per-draw losses, without needing callbacks.
LLC = n_beta * (mean_sampling_loss - init_loss)
Two entry points:
- llc(): high-level, takes model + dataset, runs sampling + LLC
- compute_llc(): low-level, takes a pre-computed sampling DataTree
- devinterp.slt.llc.compute_llc(samples: DataTree) Dataset[source]
Compute LLC from a sampling DataTree.
Matches aether’s “calculate” action with “function: llc”: averages sampling_loss_micro over every step (including burn-in) and every micro-batch, then subtracts the mean init_loss and scales by n_beta.
- Parameters:
samples – DataTree output from sample(), containing sampling_loss_micro, init_loss, and n_beta.
- Returns:
llc_mean: scalar, mean LLC across chains
llc_std: scalar, std of LLC across chains
llc_per_chain: (chain,) LLC per chain
llc_scalar: scalar LLC matching aether’s calculate action
loss_trace: (chain, step) mean loss per chain per step
init_loss: scalar, mean init loss
- Return type:
xr.Dataset
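A sketch of post-processing a previously saved run; the path is a placeholder, and reading the zarr store back with xr.open_datatree assumes a recent xarray with zarr support:

```python
import xarray as xr
from devinterp.slt.llc import compute_llc

# Placeholder path; sample(..., output_path="samples.zarr") writes this store.
samples = xr.open_datatree("samples.zarr", engine="zarr")
llc_ds = compute_llc(samples)
print(float(llc_ds["llc_mean"]), float(llc_ds["llc_std"]))
```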
- devinterp.slt.llc.llc(model: Module, dataset: Dataset, observables: dict[str, Dataset | tuple[Dataset, int]], *, lr: float, n_beta: float, param_masks: dict[str, Tensor | None] | None = None, loss_fn: Callable[[Module, Tensor], Tensor] | None = None, **kwargs) Dataset[source]
Sample and compute LLC in one call.
- Parameters:
model – PyTorch model.
dataset – HuggingFace Dataset with “input_ids” column.
observables – Dict mapping names to datasets (or (dataset, batches_per_draw) tuples).
lr – SGLD learning rate.
n_beta – SGLD inverse temperature.
param_masks – Which parameters to optimize. None for full model.
loss_fn – Optional custom per-token loss (model, input_ids) -> (batch, seq-1). Defaults to cross-entropy on the model’s logits.
**kwargs – Additional arguments passed to sample() (num_chains, num_draws, batch_size, output_path, etc.)
- Returns:
xr.Dataset with llc_mean, llc_std, llc_per_chain, loss_trace, init_loss.
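A one-call usage sketch; model, dataset, and hyperparameter values are placeholders:

```python
from devinterp.slt.llc import llc

# The observables dict can simply reuse the sampling dataset to also record
# a loss trace for it.
llc_ds = llc(
    model,
    train_dataset,
    observables={"pile10k": train_dataset},
    lr=1e-4,
    n_beta=30.0,
    num_chains=4,   # forwarded to sample() via **kwargs
    num_draws=200,
)
print(float(llc_ds["llc_mean"]), float(llc_ds["llc_std"]))
```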
devinterp.slt.lm_loss module
Model-agnostic loss computation for language models.
- exception devinterp.slt.lm_loss.NonFiniteLogitsError[source]
Bases:
ValueError
Raised when logits contain NaNs or Infs.
- add_note()
Exception.add_note(note) – add a note to the exception
- with_traceback()
Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.
- devinterp.slt.lm_loss.compute_per_token_loss(model: Module, input_ids: Tensor) Tensor[source]
Compute per-token cross-entropy loss. Returns shape (batch, seq-1).
- devinterp.slt.lm_loss.lm_cross_entropy_loss(logits: Tensor, input_ids: Tensor) Tensor[source]
Per-token cross entropy loss. Returns shape (batch, seq-1).
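A minimal sketch of the shifted per-token cross-entropy implied by the (batch, seq-1) output shape; this illustrates the convention, not the library’s exact implementation:

```python
import torch
import torch.nn.functional as F

def per_token_ce(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # Predict token t+1 from position t: drop the last logit and the first target.
    shift_logits = logits[:, :-1, :]   # (batch, seq-1, vocab)
    shift_targets = input_ids[:, 1:]   # (batch, seq-1)
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_targets.reshape(-1),
        reduction="none",
    )
    return loss.view(shift_targets.shape)  # (batch, seq-1)
```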
- devinterp.slt.lm_loss.lm_forward_logits(model: Module, input_ids: Tensor) Tensor[source]
Run a forward pass and return logits. Handles HF and TransformerLens models.
- devinterp.slt.lm_loss.make_evaluate_fn(loss_fn: Callable[[Module, Tensor], Tensor] | None = None) Callable[[Module, dict[str, Any]], tuple[Tensor, dict[str, Any]]][source]
Create an evaluation function returning unreduced per-token loss.
Returns (loss, {}) matching the (loss, results) protocol expected by sample_single_chain. If loss_fn is None, uses compute_per_token_loss (cross-entropy on the model’s logits).
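A sketch of a custom loss_fn matching the (model, input_ids) -> (batch, seq-1) protocol; it reuses the helpers documented above, and the downweighting of early positions is an illustrative choice, not part of the API:

```python
import torch
from devinterp.slt.lm_loss import (
    make_evaluate_fn, lm_forward_logits, lm_cross_entropy_loss,
)

def weighted_loss(model: torch.nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    logits = lm_forward_logits(model, input_ids)
    loss = lm_cross_entropy_loss(logits, input_ids)  # (batch, seq-1)
    weights = torch.ones_like(loss)
    weights[:, :10] = 0.5  # downweight the first 10 predicted positions
    return loss * weights

evaluate = make_evaluate_fn(weighted_loss)
# The same callable can also be passed as loss_fn= to llc(), bif(), or sample().
```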
devinterp.slt.observables module
Observable: evaluates probe datasets during SGLD sampling.
Each observable wraps a dataset and computes per-token losses at each draw. Input IDs are fixed (same sequences every draw via DeterministicShuffledSampler).
- class devinterp.slt.observables.DeterministicShuffledSampler(data_source: Dataset, num_samples: int, seed: int = 42)[source]
Bases:
Sampler
A sampler that returns a fixed shuffled order of indices, deterministic from seed.
- class devinterp.slt.observables.Observable(*, dataset: Dataset, task_name: str, batches_per_draw: int, batch_size: int, context_length: int, device: device, seed: int = 1337, loss_fn: Callable[[Module, Tensor], Tensor] | None = None)[source]
Bases:
object
Evaluates a probe dataset at each SGLD draw.
On construction, loads fixed input_ids (same sequences every draw). At each draw, compute_loss(model) returns per-token losses.
- obs_id
Identifier derived from task_name (e.g. “pile_github”).
- input_ids
Fixed input_ids tensor, shape (n_samples, ctx_len+1).
- n_samples
batch_size * batches_per_draw.
- context_length
Number of predicted positions (ctx_len).
devinterp.slt.sampler module
SGLD sampler: runs a single SGLD chain with callbacks.
The inner loop for SGLD sampling. Called by sample() in sampling.py once per chain.
- class devinterp.slt.sampler.MicroCallback(*args, **kwargs)[source]
Bases:
Protocol
Called once per micro-batch (inside the gradient accumulation loop).
- devinterp.slt.sampler.sample_single_chain(ref_model: ~torch.nn.modules.module.Module, dataset: ~torch.utils.data.dataset.Dataset, evaluate: ~typing.Callable[[~torch.nn.modules.module.Module, ~torch.Tensor], tuple[~torch.Tensor, dict[str, ~typing.Any]]], param_masks: dict[str, ~torch.Tensor | None], num_draws: int = 100, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, gradient_accumulation_steps: int = 1, sampling_method: type[~torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, sampling_method_kwargs: dict[str, ~typing.Any] | None = None, chain: int = 0, seed: int | None = None, dataloader_seed: int | None = None, device: str = 'cpu', callbacks: list[~typing.Callable] | None = None, step_callback: ~devinterp.slt.sampler.StepCallback | None = None, micro_callback: ~devinterp.slt.sampler.MicroCallback | None = None, batch_size: int = 32, init_noise: float | None = None, shuffle: bool = True, epoch_mode: ~typing.Literal['once', 'cycle'] = 'cycle') None[source]
Sample a single SGLD chain.
devinterp.slt.sampling module
SGLD sampling with observables, writing results to zarr.
Provides sample() as the main entry point. Internally uses sample_single_chain from sampler.py for the SGLD inner loop.
- devinterp.slt.sampling.sample(model: Module, dataset: Dataset, observables: dict[str, Dataset | tuple[Dataset, int]], *, lr: float, n_beta: float, param_masks: dict[str, Tensor | None] | None = None, num_chains: int = 4, num_draws: int = 200, batch_size: int = 32, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, num_init_loss_batches: int = 32, init_seed: int = 100, batches_per_draw: int = 3, obs_seed: int = 1337, gradient_accumulation_steps: int = 1, localization: float = 0.0, noise_level: float = 1.0, llc_weight_decay: float = 0.0, bounding_box_size: float | None = None, sampling_method: Literal['sgmcmc_sgld', 'rmsprop_sgld'] = 'sgmcmc_sgld', sampling_method_kwargs: dict[str, Any] | None = None, rmsprop_eps: float | None = None, rmsprop_alpha: float | None = None, shuffle: bool = True, match_sampling_input_ids_across_chains: bool = True, init_noise: float | None = None, device: str | None = None, save_metrics: bool = False, output_path: str | Path | None = None, loss_fn: Callable[[Module, Tensor], Tensor] | None = None) DataTree[source]
Run SGLD sampling with observables.
- Parameters:
model – PyTorch model.
dataset – HuggingFace Dataset with “input_ids” column, used for SGLD sampling.
observables – Dict mapping observable names to datasets (or (dataset, batches_per_draw) tuples). Each dataset must have an “input_ids” column.
lr – SGLD learning rate.
n_beta – SGLD inverse temperature.
param_masks – Which parameters to optimize. None means all parameters (full model). Otherwise a dict mapping param names to mask tensors (or None for unrestricted).
num_chains – Number of SGLD chains.
num_draws – Number of draws per chain.
batch_size – Batch size for sampling and observables.
num_burnin_steps – SGLD burn-in steps before drawing.
num_steps_bw_draws – Steps between draws.
num_init_loss_batches – Batches for init_loss computation.
init_seed – Random seed.
batches_per_draw – Default batches_per_draw for observables (used when an observable is specified as just a dataset, not a tuple).
obs_seed – Seed for deterministic observable sampling.
gradient_accumulation_steps – Number of micro-batches per optimizer step. Effective batch size is batch_size * gradient_accumulation_steps.
localization – Strength of the pull toward initial parameters (gamma in Lau et al. 2023). 0 disables localization.
noise_level – Standard deviation of SGLD noise. Defaults to 1.0; changing this breaks the SGLD posterior-sampling guarantee.
llc_weight_decay – L2 regularization strength (lambda). Applied as a Gaussian prior centered at zero.
bounding_box_size – If set, restricts sampling to a box of this radius around the initial parameters. None disables.
sampling_method – Which SGLD variant to use. “sgmcmc_sgld” is the default; “rmsprop_sgld” adds RMSprop-style preconditioning.
sampling_method_kwargs – Extra kwargs forwarded to the sampling-method constructor (e.g. rmsprop’s “alpha” / “eps”, or “add_grad_correction”). Use rmsprop_eps / rmsprop_alpha as convenience aliases for the two most common rmsprop knobs.
rmsprop_eps – RMSprop stability constant. Only valid when sampling_method=’rmsprop_sgld’. Shorthand for sampling_method_kwargs={“eps”: …}.
rmsprop_alpha – RMSprop moving-average coefficient. Only valid when sampling_method=’rmsprop_sgld’. Shorthand for sampling_method_kwargs={“alpha”: …}.
shuffle – Whether to shuffle the sampling dataset. Default True.
match_sampling_input_ids_across_chains – If True, every chain sees the same input_ids in the same order (only the SGLD noise differs across chains). If False, each chain gets an independently-seeded shuffle.
init_noise – If set, add Gaussian noise with this std to parameters before sampling.
device – Compute device. None for auto-detect.
save_metrics – If True, save per-step SGLD diagnostics (gradient norms, noise norms, distance from init, etc.) for tuning sampling parameters.
output_path – Path for output zarr. None for a temp directory.
loss_fn – Optional custom per-token loss (model, input_ids) -> (batch, seq-1). Defaults to cross-entropy on the model’s logits.
- Returns:
Lazy-loaded DataTree of sampling results.
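A usage sketch; model and dataset names and hyperparameter values are placeholders:

```python
from devinterp.slt.sampling import sample

# output_path persists the run so compute_llc / compute_bif can be applied
# later without re-sampling; save_metrics stores per-step SGLD diagnostics.
samples = sample(
    model,
    train_dataset,
    observables={"pile10k": train_dataset, "probe": (probe_ds, 2)},
    lr=1e-4,
    n_beta=30.0,
    num_chains=4,
    num_draws=200,
    save_metrics=True,
    output_path="samples.zarr",
)
```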
devinterp.slt.susceptibilities module
Susceptibility computation.
Computes per-token susceptibilities from SGLD sampling results, as described in appendix C.4 of https://arxiv.org/pdf/2504.18274.
Two entry points:
- susceptibilities(): high-level, takes model + datasets, runs sampling + post-processing
- compute_susceptibilities(): low-level, takes pre-computed sampling DataTrees
- devinterp.slt.susceptibilities.compute_susceptibilities(wr_map: dict[str, DataTree], sampling_task: str, observable_names: list[str] | None = None, include_sampling_task: bool = False) DataTree[source]
Compute per-token susceptibilities from sampling results.
- Parameters:
wr_map – Dict mapping weight restriction names to DataTrees. Must include a “full” key for the unrestricted model. Each DataTree is the output of sample().
sampling_task – Name of the sampling/pretraining dataset task (e.g. “pile10k”). Must appear as an observable.
observable_names – Which observables to compute susceptibilities for. If None, uses all observables (optionally including sampling_task, see include_sampling_task).
include_sampling_task – If True and observable_names is None, include sampling_task in the discovered observables. Default False.
- Returns:
DataTree with /susceptibilities and /context subtrees.
- devinterp.slt.susceptibilities.susceptibilities(model: Module, dataset: Dataset, observables: dict[str, Dataset | tuple[Dataset, int]], weight_restrictions: dict[str, dict[str, Tensor | None] | None], *, sampling_task: str, lr: float, n_beta: float, loss_fn: Callable[[Module, Tensor], Tensor] | None = None, include_sampling_task: bool = False, **kwargs) DataTree[source]
Sample multiple weight restrictions and compute susceptibilities.
- Parameters:
model – PyTorch model.
dataset – HuggingFace Dataset with “input_ids” column.
observables – Dict mapping names to datasets (or (dataset, batches_per_draw) tuples).
weight_restrictions – Dict mapping WR names to param masks. Must include “full” (use None for full model).
sampling_task – Name of the sampling dataset (must be in observables).
lr – SGLD learning rate.
n_beta – SGLD inverse temperature.
loss_fn – Optional custom per-token loss (model, input_ids) -> (batch, seq-1). Defaults to cross-entropy on the model’s logits.
include_sampling_task – If True, compute susceptibilities for the sampling_task observable too (per-token variation within that task is still informative). Default False.
**kwargs – Additional arguments passed to sample() (num_chains, num_draws, batch_size, output_path, etc.). If output_path is provided, each weight restriction’s samples are saved to a separate zarr with the WR name appended (e.g. “samples_full.zarr”, “samples_l0h1.zarr”).
- Returns:
DataTree with /susceptibilities and /context subtrees.
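A usage sketch comparing the full model against a single attention-head restriction; model and dataset names are placeholders:

```python
from devinterp.slt.susceptibilities import susceptibilities
from devinterp.slt.weight_restrictions import create_param_masks

wr = {
    "full": None,                               # unrestricted model (required key)
    "l0h1": create_param_masks(model, "l0h1"),  # layer 0, head 1
}
result = susceptibilities(
    model,
    train_dataset,
    observables={"pile10k": train_dataset, "probe": probe_ds},
    weight_restrictions=wr,
    sampling_task="pile10k",  # must appear in observables
    lr=1e-4,
    n_beta=30.0,
)
```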
devinterp.slt.weight_restrictions module
Weight restrictions for selecting subsets of model parameters.
Parses restriction strings like “full”, “l0”, “l0h1”, “l0g0”, “l0 attn”, “l0 mlp”, “embed”, “unembed” and returns ParamMasks suitable for the SGLD sampler.
Supports HuggingFace models (via auto-detected model_type) and TransformerLens HookedTransformer models, using the structure data in _model_structures.py.
- devinterp.slt.weight_restrictions.create_param_masks(model: Module, restriction: str | list[str]) dict[str, Tensor | None][source]
Create parameter masks from a restriction string.
- Parameters:
model – PyTorch model (HuggingFace or TransformerLens).
restriction – Which parameters to select. Options:
- “full”: all parameters, no masks
- “l0”: all params in layer 0
- “l0h1”: layer 0, head 1 (with per-element masks)
- “l0g0”: layer 0, group 0 (GQA group: Q heads sharing a KV head)
- “l0 attn”: layer 0 attention params
- “l0 mlp”: layer 0 MLP params
- “embed”: embedding params
- “unembed”: unembedding params
- list of the above: union (masks are ORed for shared params)
- Returns:
Dict mapping parameter names to mask tensors (or None for unrestricted).
- devinterp.slt.weight_restrictions.preview_weight_restriction(model: Module, masks: dict[str, Tensor | None], *, plain: bool = False) None[source]
Print a tree of which parameters are selected by a mask dict.
Shows how a weight restriction maps to actual state_dict keys, with a per-key selected/total count. Useful for debugging ambiguous selection patterns (e.g. “does ‘unembed’ include the layer norm?”).
- Parameters:
model – PyTorch model. Used to look up parameter sizes for the “selected of total” counts.
masks – Dict mapping param name to mask tensor (or None for unrestricted). Typically from create_param_masks or user-built.
plain – If True, print just one param name per line (useful for piping).
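A sketch combining the two functions above; the model is a placeholder:

```python
from devinterp.slt.weight_restrictions import (
    create_param_masks, preview_weight_restriction,
)

# Union of layer-0 attention params and the embedding params.
masks = create_param_masks(model, ["l0 attn", "embed"])
preview_weight_restriction(model, masks)              # tree with selected/total counts
preview_weight_restriction(model, masks, plain=True)  # one param name per line
```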
devinterp.slt.writing module
Zarr writing infrastructure for streaming sampling results.
ZarrWriter buffers per-chain data and writes to zarr arrays, optionally using a thread pool for async I/O.
- class devinterp.slt.writing.ZarrWriter(arrays: dict[str, Array], chain_buffer_size: int, buffer_device: device, executor: ThreadPoolExecutor | None = None)[source]
Bases:
object
Buffered writer for streaming per-chain data to zarr arrays.
Use ZarrWriter.open() for a context manager that manages the thread pool.
- arrays: dict[str, Array]
- buffer_device: device
- chain_buffer_size: int
devinterp.slt.zarr_schema module
Zarr schema for creating xarray-compatible zarr stores.
Creates a zarr store with typed arrays and group attributes in one batch call, producing a store that xr.open_datatree reads back as an equivalent DataTree.
- class devinterp.slt.zarr_schema.DataArraySpec(shape: tuple[int, ...], dims: tuple[str, ...], dtype_str: str, chunks: tuple[int, ...] = None, attrs: dict[str, Any] = None)[source]
Bases:
object
Specification for a zarr array: shape, dims, dtype, chunks.
- dims: tuple[str, ...]
- dtype_str: str
- shape: tuple[int, ...]
- class devinterp.slt.zarr_schema.ZarrSchema(arrays_meta: dict[str, ~devinterp.slt.zarr_schema.DataArraySpec] = <factory>, group_attrs: dict[str, dict[str, ~typing.Any]] = <factory>)[source]
Bases:
object
Schema for a flat zarr store (arrays at root level with group attributes).
- arrays_meta: dict[str, DataArraySpec]
- create_hierarchy(store: Store) tuple[Group, dict[str, Array]][source]
Create the zarr store. Returns (root_group, arrays_dict).
- group_attrs: dict[str, dict[str, Any]]