devinterp.slt package¶
Submodules¶
devinterp.slt.callback module¶
- class devinterp.slt.callback.SamplerCallback(device: device | str = 'cpu')¶
Bases:
object
Base class for creating callbacks used in devinterp.slt.sampler.sample(). Each instantiated callback gets its __call__ called on every sample, and its finalize called at the end of sampling (if it exists). Each callback method can access the sampler’s variables in locals(), so there’s no need to pass them along explicitly.
- Parameters:
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- Raises:
NotImplementedError – If __call__ or sample are not overridden.
Note
mps devices might not work for all callbacks.
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- get_results(*args, **kwargs)¶
Does not get called automatically, but functions as an interface to easily access stats calculated by the callback.
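A minimal sketch of a custom callback, assuming the sampler passes chain, draw, and loss to __call__; the exact variable names available in locals() depend on the sampler loop, so treat them as assumptions. Depending on your version, the base class may also expect a sample() method (see the NotImplementedError note above).

```python
import torch

from devinterp.slt.callback import SamplerCallback


class ChainLossTracker(SamplerCallback):
    """Toy callback that records the loss of every draw in every chain."""

    def __init__(self, num_chains: int, num_draws: int, device="cpu"):
        super().__init__(device=device)
        self.losses = torch.zeros((num_chains, num_draws), device=device)

    def __call__(self, chain: int, draw: int, loss: torch.Tensor, **kwargs):
        # Called once per draw; other sampler variables arrive via **kwargs.
        self.losses[chain, draw] = loss.item()

    def finalize(self, *args, **kwargs):
        # Called once at the end of sampling; aggregate across draws here.
        self.avg_loss_per_chain = self.losses.mean(dim=1)

    def get_results(self):
        # Not called automatically; exposes the collected stats.
        return {
            "loss/trace": self.losses.cpu(),
            "loss/chain_avg": self.avg_loss_per_chain.cpu(),
        }
```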
devinterp.slt.cov module¶
- class devinterp.slt.cov.BetweenLayerCovarianceAccumulator(model, pairs: Dict[str, Tuple[str, str]], device: device | str = 'cpu', num_evals: int = 3, **accessors: Callable[[Module], Tensor])¶
Bases:
object
A CovarianceAccumulator to compute covariance between arbitrary layers. For use with devinterp.slt.sampler.sample().
- Parameters:
model (torch.nn.Module) – The model to compute covariances on.
pairs (Dict[str, Tuple[str, str]]) – Named pairs of layers to compute covariances on.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
num_evals (int) – number of eigenvectors / eigenvalues to compute. Default is 3
accessors (Callable[[nn.Module], torch.Tensor]) – Functions to access attention head weights.
- get_results()¶
- Returns:
A dict with the named pairs as keys, with corresponding values
{"evals": eigenvalues_of_cov_matrix, "evecs": eigenvectors_of_cov_matrix, "matrix": cov_matrix}
. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)
).
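A construction sketch based only on the signature above; the toy model, the accessor keyword names (fc1, fc2), and the flattening convention are assumptions.

```python
import torch.nn as nn

from devinterp.slt.cov import BetweenLayerCovarianceAccumulator

# Hypothetical two-layer model; the accessor keyword names ("fc1", "fc2")
# are arbitrary and are referred back to by the `pairs` dict.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

between_cov = BetweenLayerCovarianceAccumulator(
    model,
    pairs={"fc1-fc2": ("fc1", "fc2")},    # named pair -> (accessor name, accessor name)
    num_evals=3,
    fc1=lambda m: m[0].weight.flatten(),  # accessor: model -> weight tensor (assumed flat)
    fc2=lambda m: m[2].weight.flatten(),
)

# After devinterp.slt.sampler.sample(..., [between_cov], ...):
# between_cov.get_results()["fc1-fc2"]["evals"] holds the top eigenvalues.
```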
- class devinterp.slt.cov.CovarianceAccumulator(num_weights: int, accessors: List[Callable[[Module], Tensor]], device: device | str = 'cpu', num_evals: int = 3)¶
Bases:
SamplerCallback
A callback to iteratively compute and store the covariance matrix of model weights. For passing along to devinterp.slt.sampler.sample().
- Parameters:
num_weights (int) – Total number of weights.
accessors (List[Callable[[nn.Module], torch.Tensor]]) – Functions to access model weights.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
num_evals (int, optional) – Number of eigenvalues to compute. Default is 3
- get_results()¶
- Returns:
A dict
{"evals": eigenvalues_of_cov_matrix, "evecs": eigenvectors_of_cov_matrix, "matrix": cov_matrix}
. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)
)
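A sketch of how the accessors and num_weights fit together, assuming num_weights should equal the total number of elements returned by the accessors (an assumption based on the signature above).

```python
import torch.nn as nn

from devinterp.slt.cov import CovarianceAccumulator

model = nn.Linear(10, 2)  # stand-in model

# Each accessor maps the model to one weight tensor to include in the covariance.
accessors = [lambda m: m.weight.flatten(), lambda m: m.bias]
num_weights = sum(a(model).numel() for a in accessors)  # 10*2 + 2 = 22

cov = CovarianceAccumulator(num_weights=num_weights, accessors=accessors, num_evals=3)

# After devinterp.slt.sampler.sample(..., [cov], ...):
# results = cov.get_results()
# results["evals"], results["evecs"], results["matrix"]
```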
- class devinterp.slt.cov.WithinHeadCovarianceAccumulator(num_heads: int, num_weights_per_head: int, accessors: List[Callable[[Module], Tuple[Tensor, ...]]], device: device | str = 'cpu', num_evals: int = 3)¶
Bases:
object
A CovarianceAccumulator to compute covariance within attention heads. For use with devinterp.slt.sampler.sample().
- Parameters:
num_heads (int) – The number of attention heads.
num_weights_per_head (int) – The number of weights per attention head.
accessors (List[Callable[[nn.Module], Tuple[torch.Tensor, ...]]]) – Functions to access attention head weights.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
num_evals (int, optional) – number of eigenvectors / eigenvalues to compute. Default is 3
- get_results()¶
- Returns:
A dict
{"evals": array_of_eigenvalues_of_cov_matrix_layer_idx_head_idx, "evecs": array_of_eigenvectors_of_cov_matrix_layer_idx_head_idx, "matrix": array_of_cov_matrices_layer_idx_head_idx}
. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)
).
devinterp.slt.gradient module¶
- class devinterp.slt.gradient.GradientDistribution(num_chains: int, num_draws: int, min_bins: int = 20, param_names: List[str] | None = None, device: device | str = 'cpu')¶
Bases:
SamplerCallback
Callback for plotting the distribution of gradients as a function of draws. Automatically adjusts the histogram bins as more draws are taken. For use with devinterp.slt.sampler.sample().
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
min_bins (int, optional) – Minimum number of bins for the histogram approximation. Default is 20
param_names (List[str], optional) – List of parameter names to track. If None, all parameters are tracked. Default is None
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- Raises:
ValueError – If gradients are not computed before calling this callback.
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- get_results()¶
- Returns:
A dict
{"gradient/distributions": grad_dists}
. (Only after running devinterp.slt.sampler.sample(..., [gradient_dist_instance], ...)
)
- plot(param_name: str, color='blue', plot_zero=True, chain: int | None = None)¶
Plots the gradient distribution for a specific parameter.
- Parameters:
param_name (str) – The name of the parameter to plot gradients for.
color (str, optional) – The color to plot gradient bins in. Default is ‘blue’
plot_zero (bool, optional) – Whether to plot the line through y=0. Default is True
chain (int, optional) – The chain to plot gradients for. If None, all chains are plotted. Default is None
- Returns:
None, but shows the density of gradient bins over sampling steps.
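A usage sketch; the parameter name "weight" is hypothetical and should match a name in the model's named_parameters().

```python
from devinterp.slt.gradient import GradientDistribution

num_chains, num_draws = 4, 200

grad_dist = GradientDistribution(
    num_chains=num_chains,
    num_draws=num_draws,
    min_bins=20,
    param_names=["weight"],  # track only this (hypothetical) parameter name
)

# Run devinterp.slt.sampler.sample(..., [grad_dist], ...) with the same
# num_chains / num_draws, then:
# grad_dist.get_results()["gradient/distributions"]
# grad_dist.plot("weight", plot_zero=True)   # all chains
# grad_dist.plot("weight", chain=0)          # a single chain
```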
devinterp.slt.llc module¶
- class devinterp.slt.llc.LLCEstimator(num_chains: int, num_draws: int, init_loss: Tensor, device: device | str = 'cpu', eval_field: str = 'loss', nbeta: float | None = None, temperature: float | None = None)¶
Bases:
SamplerCallback
Callback for estimating the Local Learning Coefficient (LLC) in a rolling fashion during a sampling process. It calculates the LLC based on the average loss across draws for each chain:
$$\textrm{LLC} = n\beta \cdot (\textrm{avg\_loss} - \textrm{init\_loss})$$
For use with devinterp.slt.sampler.sample().
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
nbeta (float, optional) – Effective inverse temperature. Default is 1.0; sample() sets it to utils.optimal_nbeta(dataloader), i.e. batch_size / np.log(batch_size)
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- get_results()¶
- Returns:
A dict
{"llc/mean": llc_mean, "llc/std": llc_std, "llc-chain/{i}": llc_trace_per_chain, "loss/trace": loss_trace_per_chain}
. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance], ...)
).
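A minimal end-to-end sketch. The LLCEstimator constructor follows the signature above, but devinterp.optim.SGLD and the keyword arguments to sample() are assumptions about the rest of the package; check the sampler module of your installed version, which may use different names or require an explicit loss/evaluate function.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from devinterp.optim import SGLD            # assumed sampling optimizer
from devinterp.slt.llc import LLCEstimator
from devinterp.slt.sampler import sample    # keyword names below are assumptions

# Toy model and data.
model = torch.nn.Linear(2, 1)
data = TensorDataset(torch.randn(256, 2), torch.randn(256, 1))
loader = DataLoader(data, batch_size=32, shuffle=True)

num_chains, num_draws = 4, 200
init_loss = torch.tensor(1.0)   # loss of the unperturbed model; compute this yourself
nbeta = 32 / np.log(32)         # effective inverse temperature, here batch_size / log(batch_size)

llc_estimator = LLCEstimator(
    num_chains=num_chains, num_draws=num_draws, init_loss=init_loss, nbeta=nbeta
)

# The exact sample() signature is not documented in this section; adjust as needed.
sample(
    model,
    loader,
    callbacks=[llc_estimator],
    sampling_method=SGLD,
    optimizer_kwargs=dict(lr=1e-4, nbeta=nbeta),
    num_chains=num_chains,
    num_draws=num_draws,
)

results = llc_estimator.get_results()
print(results["llc/mean"], results["llc/std"])
```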
- class devinterp.slt.llc.OnlineLLCEstimator(num_chains: int, num_draws: int, init_loss, device='cpu', eval_field='loss', nbeta: float | None = None, temperature: float | None = None)¶
Bases:
SamplerCallback
Callback for estimating the Local Learning Coefficient (LLC) in an online fashion during a sampling process. It calculates LLCs using the same formula as devinterp.slt.llc.LLCEstimator(), but continuously, and it includes means and stds across draws (as opposed to just across chains). For use with devinterp.slt.sampler.sample().
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
nbeta (float, optional) – Effective inverse temperature. Default is 1.0; sample() sets it to utils.optimal_nbeta(dataloader), i.e. batch_size / np.log(batch_size)
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- get_results()¶
- Returns:
A dict
{"llc/means": llc_means, "llc/stds": llc_stds, "llc/trace": llc_trace_per_chain, "loss/trace": loss_trace_per_chain}
. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance], ...)
).
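Constructed the same way as LLCEstimator; the point of the online variant is the per-draw trace it exposes. A short sketch (the init_loss and nbeta values are placeholders):

```python
import torch

from devinterp.slt.llc import OnlineLLCEstimator

online_llc = OnlineLLCEstimator(
    num_chains=4, num_draws=200, init_loss=torch.tensor(1.0), nbeta=10.0
)

# After devinterp.slt.sampler.sample(..., [online_llc], ...):
# results = online_llc.get_results()
# results["llc/means"] is a per-draw trace (useful for eyeballing convergence),
# whereas LLCEstimator only reports the final mean/std across chains.
```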
devinterp.slt.loss module¶
- class devinterp.slt.loss.OnlineLossStatistics(base_callback: OnlineLLCEstimator)¶
Bases:
SamplerCallback
Derivative callback that computes various loss statistics for OnlineLLCEstimator(). Must be called after the base OnlineLLCEstimator() has been called at each draw.
See the diagnostics notebook for examples on how to use this to diagnose your sample health.
- Parameters:
base_callback (OnlineLLCEstimator()) – Base callback that computes the original loss metric.
Note
Requires losses to be computed first, so call using e.g.
devinterp.slt.sampler.sample(..., [llc_estimator_instance, ..., online_loss_stats_instance], ...)
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- get_results()¶
- Returns:
A dict
{"loss/percent_neg_steps": percent_neg_steps, "loss/percent_mean_neg_steps": percent_mean_neg_steps, "loss/percent_thresholded_neg_steps": percent_thresholded_neg_steps, "loss/z_scores": z_scores}
. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance, online_loss_stats_instance], ...)
)
- loss_hist_by_draw(draw: int = 0, bins: int = 10)¶
Plots a histogram of chain losses for a given draw index.
- Parameters:
draw (int, optional) – Draw index to plot histogram for. Default is 0
bins (int, optional) – number of histogram bins. Default is 10
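A sketch of the required ordering: the base OnlineLLCEstimator must appear before the derivative callback in the callback list (constructor values are placeholders).

```python
import torch

from devinterp.slt.llc import OnlineLLCEstimator
from devinterp.slt.loss import OnlineLossStatistics

online_llc = OnlineLLCEstimator(
    num_chains=4, num_draws=200, init_loss=torch.tensor(1.0), nbeta=10.0
)
loss_stats = OnlineLossStatistics(online_llc)  # derivative callback wraps the base one

# Base callback first, derivative callback after it:
# devinterp.slt.sampler.sample(..., [online_llc, loss_stats], ...)
# loss_stats.get_results()["loss/z_scores"] then holds per-draw loss z-scores, and
# loss_stats.loss_hist_by_draw(draw=100, bins=10) plots the chain losses at draw 100.
```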
devinterp.slt.mala module¶
- class devinterp.slt.mala.MalaAcceptanceRate(num_chains: int, num_draws: int, nbeta: float, learning_rate: float, device: device | str = 'cpu')¶
Bases:
SamplerCallback
Callback for computing MALA acceptance rate.
- num_draws¶
Number of samples to draw. (should be identical to param passed to sample())
- Type:
int
- num_chains¶
Number of chains to run. (should be identical to param passed to sample())
- Type:
int
- nbeta¶
Effective Inverse Temperature used to calculate the LLC.
- Type:
float
- learning_rate¶
Learning rate of the model.
- Type:
float
- device¶
Device to perform computations on, e.g., ‘cpu’ or ‘cuda’.
- Type:
Union[torch.device, str]
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- devinterp.slt.mala.mala_acceptance_probability(prev_params: Tensor | List[Tensor], prev_grads: Tensor | List[Tensor], prev_loss: Tensor, current_params: Tensor | List[Tensor], current_grads: Tensor | List[Tensor], current_loss: Tensor, learning_rate: float) float ¶
Calculate the acceptance probability for a MALA transition. Parameters and gradients can either all be given as tensors (all of the same shape) or all as lists of tensors (e.g. the parameters of a Module).
- Parameters:
prev_params – The previous point in parameter space.
prev_grads – Gradient at the previous point in parameter space.
prev_loss – Loss of the previous point in parameter space.
current_params – The current point in parameter space.
current_grads – Gradient at the current point in parameter space.
current_loss – Loss of the current point in parameter space.
learning_rate (float) – Learning rate of the model.
- Returns:
Acceptance probability for the proposed transition, as a float.
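A small worked call, following the signature above; the numbers are arbitrary, and in practice the gradients and losses would come from forward/backward passes at the two points.

```python
import torch

from devinterp.slt.mala import mala_acceptance_probability

lr = 1e-3
prev_params = torch.tensor([0.5, -0.2])
prev_grads = torch.tensor([0.1, 0.3])
prev_loss = torch.tensor(0.8)

# A hypothetical SGLD-style proposal: gradient step plus Gaussian noise.
current_params = prev_params - lr * prev_grads + torch.randn(2) * (2 * lr) ** 0.5
current_grads = torch.tensor([0.08, 0.25])  # would come from a backward pass in practice
current_loss = torch.tensor(0.79)

accept_prob = mala_acceptance_probability(
    prev_params, prev_grads, prev_loss,
    current_params, current_grads, current_loss,
    learning_rate=lr,
)
print(accept_prob)  # acceptance probability for this proposed transition
```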
devinterp.slt.norms module¶
- class devinterp.slt.norms.GradientNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')¶
Bases:
SamplerCallback
Callback for computing the norm of the gradients of the optimizer / sampler.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- get_results()¶
- Returns:
A dict
{"gradient_norm/trace": gradient_norms}
. (Only after running devinterp.slt.sampler.sample(..., [grad_norm_instance], ...)
)
- class devinterp.slt.norms.NoiseNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')¶
Bases:
SamplerCallback
Callback for computing the norm of the noise added in the optimizer / sampler.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- get_results()¶
- Returns:
A dict
{"noise_norm/trace": noise_norms}
. (Only after running devinterp.slt.sampler.sample(..., [noise_norm_instance], ...)
)
- class devinterp.slt.norms.WeightNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')¶
Bases:
SamplerCallback
Callback for computing the norm of the weights over the sampling process.
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- get_results()¶
- Returns:
A dict
{"weight_norm/trace": weight_norms}
. (Only after running devinterp.slt.sampler.sample(..., [weight_norm_instance], ...)
)
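The three norm callbacks share a constructor, so they can be stacked in a single run; a sketch:

```python
from devinterp.slt.norms import GradientNorm, NoiseNorm, WeightNorm

num_chains, num_draws = 4, 200

grad_norm = GradientNorm(num_chains, num_draws, p_norm=2)
noise_norm = NoiseNorm(num_chains, num_draws, p_norm=2)
weight_norm = WeightNorm(num_chains, num_draws, p_norm=2)

# All three can be passed in one run:
# devinterp.slt.sampler.sample(..., [grad_norm, noise_norm, weight_norm], ...)
# Afterwards each exposes its trace, e.g.:
# grad_norm.get_results()["gradient_norm/trace"]
# noise_norm.get_results()["noise_norm/trace"]
# weight_norm.get_results()["weight_norm/trace"]
```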
devinterp.slt.sampler module¶
devinterp.slt.trace module¶
- class devinterp.slt.trace.OnlineTraceStatistics(base_callback: SamplerCallback, attribute: str, device: device | str = 'cpu')¶
Bases:
SamplerCallback
Derivative callback that computes mean/std statistics of a specified trace online. Must be called after the base callback has been called at each draw.
See the diagnostics notebook for examples on how to use this to diagnose your sample health.
- Parameters:
base_callback (SamplerCallback()) – Base callback that computes some metric.
attribute (str) – Name of the attribute to compute mean/std statistics of.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- Raises:
ValueError – If the underlying trace does not have the requested attribute, num_chains, or num_draws.
Note
Requires base trace stats to be computed first, so call using e.g.
devinterp.slt.sampler.sample(..., [weight_norm_instance, online_trace_stats_instance], ...)
- finalize(*args, **kwargs)¶
Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
- get_results()¶
- Returns:
A dict
"{self.attribute}/chain/mean": mean_attribute_by_chain, "{self.attribute}/chain/std": std_attribute_by_chain, "{self.attribute}/draw/mean": mean_attribute_by_draw, "{self.attribute}/draw/std": std_attribute_by_draw}
. (Only after runningdevinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...)
).
- sample_at_draw(draw=-1)¶
- Parameters:
draw (int, optional) – draw index to return stats at. Default is -1
- Returns:
A dict
"{self.attribute}/chain/mean": mean_attribute_of_draw_by_chain, "{self.attribute}/chain/std": std_attribute_of_draw_by_chain, "{self.attribute}/draw/mean": mean_attribute_of_draw, "{self.attribute}/draw/std": std_attribute_of_draw}
. (Only after runningdevinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...)
).
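A sketch of wrapping a base callback; the attribute name "weight_norms" is a guess at where WeightNorm stores its trace, and the constructor will raise ValueError if it is wrong, so check the attributes of your installed version.

```python
from devinterp.slt.norms import WeightNorm
from devinterp.slt.trace import OnlineTraceStatistics

weight_norm = WeightNorm(num_chains=4, num_draws=200, p_norm=2)

# "weight_norms" is a hypothetical attribute name; OnlineTraceStatistics raises
# ValueError if the base callback has no such attribute.
trace_stats = OnlineTraceStatistics(weight_norm, attribute="weight_norms")

# Base callback first, derivative callback after it:
# devinterp.slt.sampler.sample(..., [weight_norm, trace_stats], ...)
# trace_stats.get_results()["weight_norms/chain/mean"] then holds per-chain means,
# and trace_stats.sample_at_draw(-1) returns the same stats at the final draw.
```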
devinterp.slt.wbic module¶
- class devinterp.slt.wbic.OnlineWBICEstimator(num_chains: int, num_draws: int, n: int, device: device | str = 'cpu')¶
Bases:
SamplerCallback
Callback for estimating the Widely Applicable Bayesian Information Criterion (WBIC) in an online fashion. The WBIC used here is just $n \cdot \textrm{(average sampled loss)}$ (Watanabe, 2013).
- Parameters:
num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
n (int) – Number of samples used to calculate the WBIC.
device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
- get_results()¶
- Returns:
A dict
{"wbic/means": wbic_means, "wbic/stds": wbic_stds, "wbic/trace": wbic_trace_per_chain, "loss/trace": loss_trace_per_chain}
. (Only after running devinterp.slt.sampler.sample(..., [wbic_estimator_instance], ...)
).
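A construction sketch; n here is the number of training samples entering the $n \cdot \textrm{(average sampled loss)}$ formula above, and the other values are placeholders.

```python
from devinterp.slt.wbic import OnlineWBICEstimator

n = 256  # number of training samples used to compute the WBIC
wbic_estimator = OnlineWBICEstimator(num_chains=4, num_draws=200, n=n)

# After devinterp.slt.sampler.sample(..., [wbic_estimator], ...):
# results = wbic_estimator.get_results()
# results["wbic/means"] holds the running WBIC estimate (n * average sampled loss).
```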