devinterp.slt package¶
Submodules¶
devinterp.slt.callback module¶
- class devinterp.slt.callback.SamplerCallback(device: device | str = 'cpu')¶
- Bases: object - Base class for creating callbacks used in devinterp.slt.sampler.get_results(). Each instantiated callback has its __call__ called on every sample, and its finalize called at the end of sampling (if it exists). Each callback method can access parameters in locals(), so there is no need to pass variables along explicitly. - Parameters:
- device – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. 
- Raises:
- NotImplementedError – if __call__ or sample are not overridden. 
 - Note - mps devices might not work for all callbacks.
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 - get_results(*args, **kwargs)¶
- Does not get called automatically, but functions as an interface to easily access stats calculated by the callback. 
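
The callback lifecycle described above (a per-sample __call__, an optional finalize, and a get_results accessor) can be sketched without the library. In this minimal sketch, MeanLossCallback and toy_sample are hypothetical stand-ins written for illustration; they are not part of devinterp, and the keyword arguments passed to the callback are assumptions.

```python
import random


class MeanLossCallback:
    """Hypothetical callback that tracks the mean loss of each chain."""

    def __init__(self, num_chains: int, num_draws: int):
        self.num_draws = num_draws
        self.loss_sums = [0.0] * num_chains
        self.mean_losses = None

    def __call__(self, chain: int, draw: int, loss: float, **kwargs):
        # Called once per sample; unused sampler variables land in kwargs.
        self.loss_sums[chain] += loss

    def finalize(self):
        # Called once after sampling: reduce per-chain sums to means.
        self.mean_losses = [s / self.num_draws for s in self.loss_sums]

    def get_results(self):
        # Never called automatically; the user reads results from here.
        return {"loss/chain_mean": self.mean_losses}


def toy_sample(callbacks, num_chains: int, num_draws: int, seed: int = 0):
    """Stand-in for a sampler loop (assumption, NOT devinterp's sampler)."""
    rng = random.Random(seed)
    for chain in range(num_chains):
        for draw in range(num_draws):
            loss = 1.0 + rng.random()  # fake loss in [1, 2)
            for cb in callbacks:
                cb(chain=chain, draw=draw, loss=loss, model=None)
    for cb in callbacks:
        if hasattr(cb, "finalize"):
            cb.finalize()


callback = MeanLossCallback(num_chains=2, num_draws=5)
toy_sample([callback], num_chains=2, num_draws=5)
results = callback.get_results()
```

Accepting **kwargs in __call__ mirrors the idea that a callback picks out only the sampler variables it needs and ignores the rest.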
 
devinterp.slt.cov module¶
- class devinterp.slt.cov.BetweenLayerCovarianceAccumulator(model, pairs: Dict[str, Tuple[str, str]], device: device | str = 'cpu', num_evals: int = 3, **accessors: Callable[[Module], Tensor])¶
- Bases: object - A CovarianceAccumulator to compute covariance between arbitrary layers. For use with devinterp.slt.sampler.sample(). - Parameters:
- model (torch.nn.Module) – The model to compute covariances on. 
- pairs (Dict[str, Tuple[str, str]]) – Named pairs of layers to compute covariances on. 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
- num_evals (int, optional) – Number of eigenvectors / eigenvalues to compute. Default is 3 
- accessors (Callable[[nn.Module], torch.Tensor]) – Named functions to access the weights of the layers referenced in pairs. 
 
 - get_results()¶
- Returns:
- A dict with named_pairs keys, with corresponding values {"evals": eigenvalues_of_cov_matrix, "evecs": eigenvectors_of_cov_matrix, "matrix": cov_matrix}. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)).
 
 
- class devinterp.slt.cov.CovarianceAccumulator(num_weights: int, accessors: List[Callable[[Module], Tensor]], device: device | str = 'cpu', num_evals: int = 3)¶
- Bases: SamplerCallback - A callback to iteratively compute and store the covariance matrix of model weights. For passing along to devinterp.slt.sampler.sample(). - Parameters:
- num_weights (int) – Total number of weights. 
- accessors (List[Callable[[nn.Module], torch.Tensor]]) – Functions to access model weights. 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
- num_evals (int, optional) – Number of eigenvalues to compute. Default is 3 
 
 - get_results()¶
- Returns:
- A dict {"evals": eigenvalues_of_cov_matrix, "evecs": eigenvectors_of_cov_matrix, "matrix": cov_matrix}. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)).
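
Computing a covariance matrix iteratively usually means accumulating first and second moments draw by draw and combining them at the end. The sketch below shows that general technique in plain Python. StreamingCovariance is a hypothetical name, and the unbiased (n - 1) normalization is an assumption; the library may normalize differently.

```python
class StreamingCovariance:
    """Accumulates first and second moments of weight vectors, one draw at a time."""

    def __init__(self, num_weights: int):
        self.n = 0
        self.first = [0.0] * num_weights
        self.second = [[0.0] * num_weights for _ in range(num_weights)]

    def update(self, weights):
        # Add this draw's contribution to sum(w) and sum(w w^T).
        self.n += 1
        for i, wi in enumerate(weights):
            self.first[i] += wi
            for j, wj in enumerate(weights):
                self.second[i][j] += wi * wj

    def covariance(self):
        # Unbiased estimate: (sum(w w^T) - n * mean mean^T) / (n - 1).
        n = self.n
        mean = [s / n for s in self.first]
        size = len(mean)
        return [
            [(self.second[i][j] - n * mean[i] * mean[j]) / (n - 1) for j in range(size)]
            for i in range(size)
        ]


acc = StreamingCovariance(num_weights=2)
for draw in [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]:
    acc.update(draw)
cov = acc.covariance()
```

The "evals" and "evecs" entries would then come from an eigendecomposition of the resulting symmetric matrix, for example with torch.linalg.eigh, keeping the num_evals largest eigenvalues.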
 
 
- class devinterp.slt.cov.WithinHeadCovarianceAccumulator(num_heads: int, num_weights_per_head: int, accessors: List[Callable[[Module], Tuple[Tensor, ...]]], device: device | str = 'cpu', num_evals: int = 3)¶
- Bases: object - A CovarianceAccumulator to compute covariance within attention heads. For use with devinterp.slt.sampler.sample(). - Parameters:
- num_heads (int) – The number of attention heads. 
- num_weights_per_head (int) – The number of weights per attention head. 
- accessors (List[Callable[[nn.Module], Tuple[torch.Tensor, ...]]]) – Functions to access attention head weights. 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
- num_evals (int, optional) – Number of eigenvectors / eigenvalues to compute. Default is 3 
 
 - get_results()¶
- Returns:
- A dict {"evals": array_of_eigenvalues_of_cov_matrix_layer_idx_head_idx, "evecs": array_of_eigenvectors_of_cov_matrix_layer_idx_head_idx, "matrix": array_of_cov_matrices_layer_idx_head_idx}. (Only after running devinterp.slt.sampler.sample(..., [covariance_accumulator_instance], ...)).
 
 
devinterp.slt.gradient module¶
- class devinterp.slt.gradient.GradientDistribution(num_chains: int, num_draws: int, min_bins: int = 20, param_names: List[str] | None = None, device: device | str = 'cpu')¶
- Bases: SamplerCallback - Callback for plotting the distribution of gradients as a function of draws. Automatically adjusts the histogram bins as more draws are taken. For use with devinterp.slt.sampler.sample(). - Parameters:
- num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
- num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
- min_bins (int, optional) – Minimum number of bins for histogram approximation. Default is 20 
- param_names (List[str], optional) – List of parameter names to track. If None, all parameters are tracked. Default is None 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. 
 
- Raises:
- ValueError – If gradients are not computed before calling this callback. 
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 - get_results()¶
- Returns:
- A dict {"gradient/distributions": grad_dists}. (Only after running devinterp.slt.sampler.sample(..., [gradient_dist_instance], ...)).
 
 - plot(param_name: str, color='blue', plot_zero=True, chain: int | None = None)¶
- Plots the gradient distribution for a specific parameter. - Parameters:
- param_name (str) – The name of the parameter to plot gradients for. 
- color (str, optional) – The color to plot gradient bins in. Default is blue 
- plot_zero (bool, optional) – Whether to plot the line through y=0. Default is True 
- chain (int, optional) – Index of the chain to plot gradients for. Default is None 
 
- Returns:
- None, but shows the density of the gradient bins over sampling steps. 
 
 
devinterp.slt.llc module¶
- class devinterp.slt.llc.LLCEstimator(num_chains: int, num_draws: int, init_loss: Tensor, device: device | str = 'cpu', eval_field: str = 'loss', nbeta: float | None = None, temperature: float | None = None)¶
- Bases: SamplerCallback - Callback for estimating the Local Learning Coefficient (LLC) in a rolling fashion during a sampling process. It calculates the LLC based on the average loss across draws for each chain: - $$LLC = n\beta \cdot (\textrm{avg\_loss} - \textrm{init\_loss})$$ - For use with devinterp.slt.sampler.sample(). - Parameters:
- num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
- num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
- nbeta (float, optional) – Effective inverse temperature (default: 1., set by sample() to utils.optimal_nbeta(dataloader) = len(batch_size) / np.log(len(batch_size))) 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. 
 
 - get_results()¶
- Returns:
- A dict {"llc/mean": llc_mean, "llc/std": llc_std, "llc-chain/{i}": llc_trace_per_chain, "loss/trace": loss_trace_per_chain}. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance], ...)).
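
As a worked instance of the LLC formula given for this class, the sketch below applies nbeta * (avg_loss - init_loss) to two made-up loss traces. The helper name llc_per_chain, the numbers, and the use of the sample standard deviation across chains are all assumptions for illustration, not the library's implementation.

```python
import math
import statistics


def llc_per_chain(loss_trace, init_loss, nbeta):
    """LLC for one chain: nbeta * (average sampled loss - initial loss)."""
    avg_loss = sum(loss_trace) / len(loss_trace)
    return nbeta * (avg_loss - init_loss)


nbeta = 100.0      # hypothetical effective inverse temperature
init_loss = 0.50   # hypothetical loss at the starting point
loss_traces = [
    [0.52, 0.54, 0.53],  # chain 0, average 0.53
    [0.51, 0.55, 0.56],  # chain 1, average 0.54
]

llcs = [llc_per_chain(trace, init_loss, nbeta) for trace in loss_traces]
llc_mean = statistics.mean(llcs)
llc_std = statistics.stdev(llcs)  # sample std across chains (assumption)
```

With these numbers the per-chain estimates are approximately 3.0 and 4.0, so the mean across chains is about 3.5.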
 
 
- class devinterp.slt.llc.OnlineLLCEstimator(num_chains: int, num_draws: int, init_loss, device='cpu', eval_field='loss', nbeta: float | None = None, temperature: float | None = None)¶
- Bases: SamplerCallback - Callback for estimating the Local Learning Coefficient (LLC) in an online fashion during a sampling process. It calculates LLCs using the same formula as devinterp.slt.llc.LLCEstimator(), but continuously, and includes means and stds across draws (as opposed to just across chains). For use with devinterp.slt.sampler.sample(). - Parameters:
- num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
- num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
- nbeta (float, optional) – Effective inverse temperature (default: 1., set by sample() to utils.optimal_nbeta(dataloader) = len(batch_size) / np.log(len(batch_size))) 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
 
 - get_results()¶
- Returns:
- A dict {"llc/means": llc_means, "llc/stds": llc_stds, "llc/trace": llc_trace_per_chain, "loss/trace": loss_trace_per_chain}. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance], ...)).
 
 
devinterp.slt.loss module¶
- class devinterp.slt.loss.OnlineLossStatistics(base_callback: OnlineLLCEstimator)¶
- Bases: SamplerCallback - Derivative callback that computes various loss statistics for OnlineLLCEstimator(). Must be called after the base OnlineLLCEstimator() has been called at each draw. See the diagnostics notebook for examples of how to use this to diagnose your sample health. - Parameters:
- base_callback (OnlineLLCEstimator()) – Base callback that computes the original loss metric.
 - Note - Requires losses to be computed first, so call using e.g. devinterp.slt.sampler.sample(..., [llc_estimator_instance, ..., online_loss_stats_instance], ...)
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 - get_results()¶
- Returns:
- A dict {"loss/percent_neg_steps": percent_neg_steps, "loss/percent_mean_neg_steps": percent_mean_neg_steps, "loss/percent_thresholded_neg_steps": percent_thresholded_neg_steps, "loss/z_scores": z_scores}. (Only after running devinterp.slt.sampler.sample(..., [llc_estimator_instance, online_loss_stats_instance], ...)).
 
 - loss_hist_by_draw(draw: int = 0, bins: int = 10)¶
- Plots a histogram of chain losses for a given draw index. - Parameters:
- draw (int, optional) – Draw index to plot histogram for. Default is 0 
- bins (int, optional) – Number of histogram bins. Default is 10 
 
 
 
devinterp.slt.mala module¶
- class devinterp.slt.mala.MalaAcceptanceRate(num_chains: int, num_draws: int, nbeta: float, learning_rate: float, device: device | str = 'cpu')¶
- Bases: SamplerCallback - Callback for computing the MALA acceptance rate. - num_draws¶
- Number of samples to draw (should be identical to the value passed to sample()). - Type:
- int 
 
 - num_chains¶
- Number of chains to run (should be identical to the value passed to sample()). - Type:
- int 
 
 - nbeta¶
- Effective inverse temperature used to calculate the LLC. - Type:
- float 
 
 - learning_rate¶
- Learning rate of the model. - Type:
- float 
 
 - device¶
- Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. - Type:
- Union[torch.device, str] 
 
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 
- devinterp.slt.mala.mala_acceptance_probability(prev_params: Tensor | List[Tensor], prev_grads: Tensor | List[Tensor], prev_loss: Tensor, current_params: Tensor | List[Tensor], current_grads: Tensor | List[Tensor], current_loss: Tensor, learning_rate: float) float¶
- Calculate the acceptance probability for a MALA transition. Parameters and gradients can either all be given as a tensor (all of the same shape) or all as lists of tensors (e.g. the parameters of a Module). - Parameters:
- prev_params – The previous point in parameter space. 
- prev_grads – Gradient at the previous point in parameter space. 
- prev_loss – Loss at the previous point in parameter space. 
- current_params – The current point in parameter space. 
- current_grads – Gradient at the current point in parameter space. 
- current_loss – Loss at the current point in parameter space. 
- learning_rate (float) – Learning rate of the model. 
- Returns:
- float – Acceptance probability for the proposed transition. 
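
For orientation, the textbook Metropolis-adjusted Langevin acceptance probability can be written in a few lines. The sketch below is not the library's implementation: it assumes a target density proportional to exp(-loss), a proposal with mean params - (learning_rate / 2) * grad and variance learning_rate, and flat lists of floats instead of tensors. The function name mala_acceptance_sketch is hypothetical.

```python
import math


def mala_acceptance_sketch(prev_params, prev_grads, prev_loss,
                           current_params, current_grads, current_loss,
                           learning_rate):
    """Textbook MALA acceptance probability for a target ~ exp(-loss)."""

    def log_q(to, frm, grad_frm):
        # Log density (up to a constant) of proposing `to` when sitting at `frm`.
        sq_dist = sum(
            (t - f + 0.5 * learning_rate * g) ** 2
            for t, f, g in zip(to, frm, grad_frm)
        )
        return -sq_dist / (2.0 * learning_rate)

    log_ratio = (
        -current_loss + prev_loss                             # target ratio
        + log_q(prev_params, current_params, current_grads)   # reverse move
        - log_q(current_params, prev_params, prev_grads)      # forward move
    )
    # Clamp at 0 before exponentiating so the result is min(1, exp(log_ratio)).
    return math.exp(min(0.0, log_ratio))


# Toy quadratic loss L(x) = x^2 / 2, so the gradient at x is x.
downhill = mala_acceptance_sketch([1.0], [1.0], 0.5, [0.9], [0.9], 0.405, 0.1)
uphill = mala_acceptance_sketch([0.9], [0.9], 0.405, [1.0], [1.0], 0.5, 0.1)
```

In this toy example the downhill move is always accepted, while the reverse uphill move is accepted with probability slightly below one.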
devinterp.slt.norms module¶
- class devinterp.slt.norms.GradientNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')¶
- Bases: SamplerCallback - Callback for computing the norm of the gradients of the optimizer / sampler. - Parameters:
- num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
- num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
- p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
 
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 - get_results()¶
- Returns:
- A dict {"gradient_norm/trace": gradient_norms}. (Only after running devinterp.slt.sampler.sample(..., [grad_norm_instance], ...)).
 
 
- class devinterp.slt.norms.NoiseNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')¶
- Bases: SamplerCallback - Callback for computing the norm of the noise added in the optimizer / sampler. - Parameters:
- num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
- num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
- p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
 
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 - get_results()¶
- Returns:
- A dict {"noise_norm/trace": noise_norms}. (Only after running devinterp.slt.sampler.sample(..., [noise_norm_instance], ...)).
 
 
- class devinterp.slt.norms.WeightNorm(num_chains: int, num_draws: int, p_norm: int = 2, device: device | str = 'cpu')¶
- Bases: SamplerCallback - Callback for computing the norm of the weights over the sampling process. - Parameters:
- num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
- num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
- p_norm (int, optional) – Order of the norm to be computed (e.g., 2 for Euclidean norm). Default is 2 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
 
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 - get_results()¶
- Returns:
- A dict {"weight_norm/trace": weight_norms}. (Only after running devinterp.slt.sampler.sample(..., [weight_norm_instance], ...)).
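
The p_norm parameter shared by the three norm callbacks selects the order of the norm. The sketch below shows one way a single p-norm can be taken over several parameter tensors, using nested lists in place of tensors. The helper p_norm_over_params is hypothetical, and treating all parameters as one flattened vector is an assumption about how the norms are combined.

```python
def p_norm_over_params(tensors, p=2):
    """p-norm of all entries in `tensors`, treated as one flattened vector."""
    total = sum(abs(x) ** p for tensor in tensors for x in tensor)
    return total ** (1.0 / p)


# Two hypothetical parameter "tensors" with entries 3, 0 and 4.
weights = [[3.0, 0.0], [4.0]]
l2 = p_norm_over_params(weights, p=2)  # Euclidean norm: sqrt(9 + 16)
l1 = p_norm_over_params(weights, p=1)  # sum of absolute values
```

Recording one such value per chain and per draw would give a trace shaped like the "gradient_norm/trace", "noise_norm/trace" and "weight_norm/trace" entries listed above.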
 
 
devinterp.slt.sampler module¶
devinterp.slt.trace module¶
- class devinterp.slt.trace.OnlineTraceStatistics(base_callback: SamplerCallback, attribute: str, device: device | str = 'cpu')¶
- Bases: SamplerCallback - Derivative callback that computes mean/std statistics of a specified trace online. Must be called after the base callback has been called at each draw. See the diagnostics notebook for examples of how to use this to diagnose your sample health. - Parameters:
- base_callback (SamplerCallback()) – Base callback that computes some metric.
- attribute (str) – Name of the attribute to compute mean/std statistics of. 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
 
- Raises:
- ValueError – if the underlying trace does not have the requested attribute, num_chains or num_draws.
 - Note - Requires base trace stats to be computed first, so call using e.g. devinterp.slt.sampler.sample(..., [weight_norm_instance, online_trace_stats_instance], ...)
 - finalize(*args, **kwargs)¶
- Gets called at the end of sampling. Can access any variable in locals() when called. Should be used for calculating stats over chains, for example the average chain loss.
 - get_results()¶
- Returns:
- A dict {"{self.attribute}/chain/mean": mean_attribute_by_chain, "{self.attribute}/chain/std": std_attribute_by_chain, "{self.attribute}/draw/mean": mean_attribute_by_draw, "{self.attribute}/draw/std": std_attribute_by_draw}. (Only after running devinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...)).
 
 - sample_at_draw(draw=-1)¶
- Parameters:
- draw (int, optional) – Draw index to return stats at. Default is -1 
- Returns:
- A dict {"{self.attribute}/chain/mean": mean_attribute_of_draw_by_chain, "{self.attribute}/chain/std": std_attribute_of_draw_by_chain, "{self.attribute}/draw/mean": mean_attribute_of_draw, "{self.attribute}/draw/std": std_attribute_of_draw}. (Only after running devinterp.slt.sampler.sample(..., [some_thing_to_calc_stats_of, ..., trace_stats_instance], ...)).
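
Computing mean/std statistics online means updating them one value at a time instead of storing the whole trace. A common way to do that is Welford's algorithm, sketched below. This illustrates the general technique only and is not taken from the library; the class name RunningStats and the use of the sample standard deviation are assumptions.

```python
import math


class RunningStats:
    """Welford's algorithm: running mean and std without storing past values."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, value: float):
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (value - self.mean)

    def std(self) -> float:
        # Sample standard deviation; defined as 0.0 for fewer than two values.
        if self.count < 2:
            return 0.0
        return math.sqrt(self.m2 / (self.count - 1))


stats = RunningStats()
for value in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    stats.update(value)
```

Keeping one such accumulator per chain and one per draw index would be one plausible way to obtain the chain-wise and draw-wise statistics listed in the returned dict.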
 
 
devinterp.slt.wbic module¶
- class devinterp.slt.wbic.OnlineWBICEstimator(num_chains: int, num_draws: int, n: int, device: device | str = 'cpu')¶
- Bases: SamplerCallback - Callback for estimating the Widely Applicable Bayesian Information Criterion (WBIC) in an online fashion. The WBIC used here is just $n \cdot (\textrm{average sampled loss})$. (Watanabe, 2013) - Parameters:
- num_draws (int) – Number of samples to draw (should be identical to num_draws passed to devinterp.slt.sampler.sample)
- num_chains (int) – Number of chains to run (should be identical to num_chains passed to devinterp.slt.sampler.sample)
- n (int) – Number of samples used to calculate the WBIC. 
- device (str | torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
 
 - get_results()¶
- Returns:
- A dict {"wbic/means": wbic_means, "wbic/stds": wbic_stds, "wbic/trace": wbic_trace_per_chain, "loss/trace": loss_trace_per_chain}. (Only after running devinterp.slt.sampler.sample(..., [wbic_estimator_instance], ...)).
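
The WBIC expression used by this class, n times the average sampled loss, is simple enough to check by hand. The sketch below does so for made-up losses; the helper name wbic_estimate and all numbers are hypothetical.

```python
def wbic_estimate(sampled_losses, n: int) -> float:
    """WBIC as described above: n * (average sampled loss)."""
    return n * sum(sampled_losses) / len(sampled_losses)


# Three hypothetical sampled losses with an average of 0.5, and n = 200 samples.
wbic = wbic_estimate([0.5, 0.25, 0.75], n=200)
```

Here the average sampled loss is 0.5, so the estimate comes out to 100.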