# Output Formats
All `devinterp` functions return xarray objects backed by Zarr. This page documents the exact structure of each output.
## Working with Zarr Output
`sample()` writes results to a `.zarr` directory (by default in a temporary
directory, or at `output_path` if specified). The returned `xr.DataTree`
is lazy-loaded — data is read from disk on access.
```python
# Save to a specific path
tree = sample(..., output_path="my_experiment.zarr")

# Reopen later
import xarray as xr
tree = xr.open_datatree("my_experiment.zarr", engine="zarr", consolidated=False)

# Access data
loss = tree.dataset["sampling_loss"]  # lazy xr.DataArray
loss.values  # loads into memory as a numpy array

# Post-process from saved results
from devinterp.slt.llc import compute_llc
result = compute_llc(tree)
```
## sample() → xr.DataTree
The root dataset contains all arrays at the top level.
| Variable | Dimensions | Description |
|---|---|---|
| | `(chain, token_pos)` | Mean per-token loss before sampling, averaged over … |
| `sampling_loss` | `(chain, draw, batch, token_pos)` | Per-token loss on the sampling dataset at each draw |
| | `(chain, draw, batch_{obs}, token_pos)` | Per-token loss on observable … |
| | `(batch_{obs}, token)` | Fixed input IDs for observable … |
| | scalar | Inverse temperature used for sampling |
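These arrays reduce with ordinary xarray operations. The sketch below builds a synthetic stand-in for `sampling_loss` (random values; the real array is lazy-loaded from Zarr) and averages out the batch and token dimensions to get one loss trace per chain:

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for tree.dataset["sampling_loss"]:
# dims (chain, draw, batch, token_pos), as documented above.
loss = xr.DataArray(
    np.random.rand(4, 10, 8, 16),
    dims=("chain", "draw", "batch", "token_pos"),
)

# Average over batch and token_pos: one loss value per chain per draw.
per_draw = loss.mean(dim=["batch", "token_pos"])
```

The same pattern works on the lazy arrays returned by `sample()`; the reduction is only computed when `.values` (or a plot) forces it.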
When `save_metrics=True`, additional per-step SGLD diagnostics are included:
| Variable | Dimensions | Description |
|---|---|---|
| | `(chain, step)` | L2 norm of the scaled gradient component |
| | `(chain, step)` | L2 norm of the unscaled (raw) gradient |
| | `(chain, step)` | L2 norm of the localization force |
| | `(chain, step)` | L2 norm of the weight decay component |
| | `(chain, step)` | L2 norm of the injected noise |
| | `(chain, step)` | L2 distance from initial parameters |
| | `(chain, step)` | Dot product of gradient and prior components |
| | `(chain, step)` | Dot product of gradient and noise |
| | `(chain, step)` | Dot product of prior and noise |
Here the `step` dimension has length `num_burnin_steps + num_draws * num_steps_bw_draws`.
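As a quick sanity check on array sizes (the parameter values below are hypothetical; the real ones come from the sampler configuration):

```python
# Hypothetical sampler settings.
num_burnin_steps = 200
num_draws = 100
num_steps_bw_draws = 4

# Length of the `step` dimension in each diagnostic array.
total_steps = num_burnin_steps + num_draws * num_steps_bw_draws
```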
Metadata is stored in `tree.attrs["metadata"]` as a nested dict containing
the `SamplerConfig`, observable specs, and other configuration.
## compute_llc() → xr.Dataset
| Variable | Dimensions | Description |
|---|---|---|
| | scalar | Mean LLC across chains: … |
| | scalar | Std of LLC across chains |
| | `(chain,)` | LLC per chain: … |
| | scalar | LLC with all dims reduced at once (for bitwise parity with aether) |
| | `(chain, draw)` | Mean loss per chain per draw (averaged over batch and token_pos) |
| | scalar | Mean init loss (averaged over all dims) |
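The mean and std entries are plain reductions of the per-chain values over the `chain` dimension. A minimal sketch with made-up per-chain LLC numbers:

```python
import numpy as np
import xarray as xr

# Made-up per-chain LLC estimates with dims ("chain",), as in the table above.
llc_per_chain = xr.DataArray(np.array([12.1, 11.8, 12.5, 12.0]), dims=("chain",))

llc_mean = llc_per_chain.mean(dim="chain")  # scalar
llc_std = llc_per_chain.std(dim="chain")    # scalar
```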
## compute_bif() → xr.Dataset
With `correlation_method="token"` (the default):
| Variable | Dimensions | Description |
|---|---|---|
| | `(batch_1, batch_2, target_position, target_position_T)` | Token-wise pairwise correlation matrix between all sequence pairs. With … |
| | `(batch, position)` | Concatenated input IDs from all observables |
With `correlation_method="sequence"`:
| Variable | Dimensions | Description |
|---|---|---|
| | `(batch_1, batch_2)` | Sequence-level pairwise correlation (loss averaged over tokens first) |
| | `(batch, position)` | Concatenated input IDs from all observables |
The coordinates `batch_1` and `batch_2` are integer indices into the
concatenated observable sequences; `target_position` is 1-indexed.
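Selecting a single sequence pair can be sketched on a synthetic array with the token-mode dimensions above: `batch_1`/`batch_2` are addressed by integer position, `target_position` by its 1-indexed coordinate labels.

```python
import numpy as np
import xarray as xr

# Synthetic token-wise correlation array with the dims documented above.
corr = xr.DataArray(
    np.random.rand(3, 3, 5, 5),
    dims=("batch_1", "batch_2", "target_position", "target_position_T"),
    coords={
        "target_position": np.arange(1, 6),    # 1-indexed labels
        "target_position_T": np.arange(1, 6),
    },
)

# Correlations between sequence 0 and sequence 1, over all token pairs.
pair = corr.isel(batch_1=0, batch_2=1)

# Correlation at target positions (1, 1), using the 1-indexed labels.
val = pair.sel(target_position=1, target_position_T=1)
```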
## compute_susceptibilities() → xr.DataTree
```
/susceptibilities
  sus  (sus_flat, wr) — susceptibility values
  Coordinates:
    wr               (wr,)       — weight restriction names (excluding "full")
    dataset_id       (sus_flat,) — which observable each entry belongs to
    batch            (sus_flat,) — batch index within that observable
    target_position  (sus_flat,) — token position (1-indexed)
  Attributes:
    dataset_id_to_name — list mapping dataset_id integers to names

/context
  input_ids  (ctx_flat,) — flattened input token IDs
  Coordinates:
    dataset_id  (ctx_flat,) — which observable
    batch       (ctx_flat,) — batch index
    position    (ctx_flat,) — token position (0-indexed, includes BOS)
  Attributes:
    dataset_id_to_name — same mapping as /susceptibilities
```
The `sus_flat` dimension flattens `(batch, target_position)` across all
observables. Use the `dataset_id`, `batch`, and `target_position`
coordinates to reconstruct the original structure.
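A sketch of that reconstruction on synthetic data, filtering the flattened dimension by its `dataset_id` coordinate (the shapes, coordinate values, and restriction names below are made up for illustration):

```python
import numpy as np
import xarray as xr

# Synthetic sus array: 6 flattened entries from two observables,
# 2 weight restrictions, mirroring the layout documented above.
sus = xr.DataArray(
    np.random.rand(6, 2),
    dims=("sus_flat", "wr"),
    coords={
        "wr": ["embed", "unembed"],  # hypothetical restriction names
        "dataset_id": ("sus_flat", [0, 0, 0, 1, 1, 1]),
        "batch": ("sus_flat", [0, 0, 1, 0, 0, 1]),
        "target_position": ("sus_flat", [1, 2, 1, 1, 2, 1]),
    },
)

# Keep only the entries belonging to observable 0; the batch and
# target_position coordinates are subset along with the data.
idx = np.flatnonzero(sus.dataset_id.values == 0)
first_obs = sus.isel(sus_flat=idx)
```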