devinterp.optim package

Submodules

devinterp.optim.metrics module

class devinterp.optim.metrics.Metrics(scaled_grad: torch.Tensor = <factory>, unscaled_grad: torch.Tensor = <factory>, localization: torch.Tensor = <factory>, weight_decay: torch.Tensor = <factory>, noise: torch.Tensor = <factory>, distance: torch.Tensor = <factory>, dot_grad_prior: torch.Tensor = <factory>, dot_grad_noise: torch.Tensor = <factory>, dot_prior_noise: torch.Tensor = <factory>, numel: int = 0)[source]

Bases: object

Norms and dot products of SGMCMC parameter update components.

Each step, w += dw where dw = -(scaled_grad + prior) + noise.

Norm fields store L2 norms of the post-preconditioned update components (i.e. actual magnitudes applied to parameters, not raw gradients).

  • scaled_grad: (ε/2) · nβ · G · ∇L — preconditioned gradient

  • unscaled_grad: (ε/2) · nβ · ∇L — raw gradient (no preconditioner)

  • localization: (ε/2) · G · γ(w - w₀) — pull toward initial params

  • weight_decay: (ε/2) · G · λw — L2 regularization

  • noise: √ε · √G · η — stochastic exploration

  • distance: w - w₀ — raw displacement from init

Dot product fields store inner products between the three main component vectors (scaled_grad, combined prior, noise).

  • dot_grad_prior: ⟨scaled_grad, localization + weight_decay⟩

  • dot_grad_noise: ⟨scaled_grad, noise⟩

  • dot_prior_noise: ⟨localization + weight_decay, noise⟩

Cosine similarities can be derived: cos = dot / (norm_a * norm_b).
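
For example, a minimal sketch of deriving one such cosine similarity from a populated Metrics object (assuming an optimizer constructed with save_metrics=True and at least one completed step):

metrics = optimizer.get_metrics()

# Cosine similarity between the preconditioned gradient and the noise term.
eps = 1e-12  # hypothetical guard against division by zero
cos_grad_noise = metrics.dot_grad_noise / (metrics.scaled_grad * metrics.noise + eps)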

Lifecycle (one Metrics per optimizer param group, on the group’s device):

  1. __init__: group["metrics"] = Metrics().to(device)

  2. step(), start: group["metrics"].zero_()

  3. step(), per-p: group["metrics"].add_sum_squared_(…)

    group["metrics"].add_dot_products_(…)

  4. step(), end: group["metrics"].sqrt_norms_()

  5. get_metrics(): combine per-group metrics on CPU

  5. get_metrics(): combine per-group metrics on CPU

DOT_FIELDS: ClassVar[tuple[str, ...]] = ('dot_grad_prior', 'dot_grad_noise', 'dot_prior_noise')
NORM_FIELDS: ClassVar[tuple[str, ...]] = ('scaled_grad', 'unscaled_grad', 'localization', 'weight_decay', 'noise', 'distance')
add_dot_products_(scaled_grad: Tensor, prior: Tensor, noise: Tensor) None[source]

Accumulate dot products between the three main component vectors.

Parameters:
  • scaled_grad – The preconditioned gradient vector.

  • prior – Combined prior vector (localization + weight_decay).

  • noise – The noise vector.

add_sum_squared_(scaled_grad: Tensor, unscaled_grad: Tensor, localization: Tensor, weight_decay: Tensor, noise: Tensor, distance: Tensor) None[source]

Accumulate sum-of-squares for each norm component in-place.

Casts to float32 before squaring to avoid precision loss with bf16/fp16 inputs (where squaring can overflow or underflow in the input dtype).
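
A standalone illustration of the failure mode this guards against (not library code):

import torch

x = torch.tensor([300.0], dtype=torch.float16)
print(x * x)                  # tensor([inf], dtype=torch.float16): 90000 exceeds fp16 max (~65504)
print(x.float() * x.float())  # tensor([90000.]): squaring in float32 is safe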

static aggregate(group_metrics: Iterable[Metrics]) Metrics[source]

Combine per-group metrics into a single Metrics on CPU.

Norms: re-square to get sum-of-squares, accumulate, then sqrt: ||[a; b]|| = sqrt(||a||^2 + ||b||^2).

Dot products: additive across disjoint parameter sets.
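
A worked example of the aggregation rule, with hypothetical per-group values:

import math

# L2 norms of one component (say, noise) for two param groups:
norm_a, norm_b = 3.0, 4.0
combined_norm = math.sqrt(norm_a**2 + norm_b**2)  # 5.0, i.e. ||[a; b]||

# Dot products over disjoint parameter sets simply add:
dot_a, dot_b = 0.2, -0.1
combined_dot = dot_a + dot_b  # 0.1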

distance: Tensor
dot_grad_noise: Tensor
dot_grad_prior: Tensor
dot_prior_noise: Tensor
localization: Tensor
noise: Tensor
numel: int = 0
property prior: Tensor

Combined prior norm: ||[localization; weight_decay]||₂.

scaled_grad: Tensor
sqrt_norms_() None[source]

Convert norm fields from sum-of-squares to L2 norms in-place.

to(device: str | device | int) Metrics[source]

Return a copy of these metrics on the specified device.

unscaled_grad: Tensor
weight_decay: Tensor
zero_() None[source]

Reset all metrics to zero in-place.

devinterp.optim.preconditioner module

class devinterp.optim.preconditioner.CompositePreconditioner(preconditioners: list[Preconditioner])[source]

Bases: Preconditioner

Combines multiple preconditioners by multiplying their coefficients

class devinterp.optim.preconditioner.IdentityPreconditioner[source]

Bases: Preconditioner

Identity preconditioning (i.e., no preconditioning)

class devinterp.optim.preconditioner.MaskPreconditioner(masks: list[Tensor | float])[source]

Bases: Preconditioner

Applies masks to the overall coefficient while keeping other coefficients at 1.0

Stores one mask per parameter in the parameter group.

class devinterp.optim.preconditioner.NHTPreconditioning(diffusion_factor: float = 0.01, eps: float = 1e-08)[source]

Bases: Preconditioner

Nosé-Hoover Thermostat preconditioning

class devinterp.optim.preconditioner.Preconditioner[source]

Bases: ABC

Base class for preconditioners that generate coefficients for MCMC terms

abstract get_coefficients(param: Tensor, grad: Tensor, state: dict) PreconditionerCoefs[source]

Compute coefficients for the gradient, prior, and noise terms. Returns a PreconditionerCoefs containing all coefficients. Each coefficient can be a scalar or a tensor whose shape matches param.
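
A minimal sketch of a concrete subclass, assuming only the interface above (the constant 0.5 is arbitrary, for illustration only):

import torch
from devinterp.optim.preconditioner import Preconditioner, PreconditionerCoefs

class HalvingPreconditioner(Preconditioner):
    """Hypothetical preconditioner that scales the overall update by 0.5."""

    def get_coefficients(
        self, param: torch.Tensor, grad: torch.Tensor, state: dict
    ) -> PreconditionerCoefs:
        # Scalars are valid coefficients; they broadcast against param downstream.
        return PreconditionerCoefs(
            grad_coef=1.0,
            prior_coef=1.0,
            noise_coef=1.0,
            overall_coef=0.5,
            grad_correction=None,
        )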

class devinterp.optim.preconditioner.PreconditionerCoefs(grad_coef: float | Tensor, prior_coef: float | Tensor, noise_coef: float | Tensor, overall_coef: float | Tensor, grad_correction: float | Tensor | None = None)[source]

Bases: NamedTuple

Coefficients returned by preconditioners

count(value, /)

Return number of occurrences of value.

grad_coef: float | Tensor

Alias for field number 0

grad_correction: float | Tensor | None

Alias for field number 4

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

noise_coef: float | Tensor

Alias for field number 2

overall_coef: float | Tensor

Alias for field number 3

prior_coef: float | Tensor

Alias for field number 1

class devinterp.optim.preconditioner.RMSpropPreconditioner(alpha: float = 0.99, eps: float = 0.1, add_grad_correction=False)[source]

Bases: Preconditioner

RMSprop-style diagonal preconditioning

devinterp.optim.prior module

class devinterp.optim.prior.CompositePrior(priors: list[Prior])[source]

Bases: Prior

Combines multiple priors, summing their contributions.

Always wraps priors, even a single non-uniform one (no short-circuiting), so that callers which need to decompose sub-priors (e.g. SGMCMC._update_metrics) can iterate self.priors uniformly.

key: str
class devinterp.optim.prior.GaussianPrior(localization: float, center: Literal['initial'] | Iterable[Tensor] | Real | None = 'initial')[source]

Bases: Prior

Gaussian prior with configurable center and precision

grad(param: Tensor, state: dict[str, Any]) Tensor[source]

Compute gradient of the prior. If state is provided, the prior center is looked up in the state dictionary using the instance key.

Parameters:
  • param – Parameter tensor

  • state – State dictionary

Returns:

Gradient tensor
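
To make this concrete: for a Gaussian prior \(p(w) \propto \exp(-\frac{\gamma}{2}\|w - c\|^2)\) with precision \(\gamma\) (localization) and center \(c\), the gradient of the negative log-density is \(\gamma(w - c)\). A sketch of that computation (not necessarily the library's exact implementation):

import torch

def gaussian_prior_grad(param: torch.Tensor, center: torch.Tensor, localization: float) -> torch.Tensor:
    # d/dw of -log p(w) for p(w) ∝ exp(-γ/2 ||w - c||²) is γ(w - c)
    return localization * (param - center)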

initialize(params: Sequence[Tensor]) dict[Tensor, dict[str, Any]][source]

Initialize centers for all parameters

Parameters:

params – Iterator of model parameters

Returns:

State dictionary containing prior centers

key: str
class devinterp.optim.prior.Prior[source]

Bases: ABC

Abstract base class for parameter priors

abstract grad(param: Tensor, state: dict[str, Any]) Tensor[source]

Compute gradient of the prior

Parameters:
  • param – Parameter tensor

  • state – State dictionary

Returns:

Gradient tensor

abstract initialize(params: Sequence[Tensor]) dict[Tensor, dict[str, Any]][source]

Initialize prior for parameters

Parameters:

params – Iterator of model parameters

Returns:

Updated state dictionary

key: str
class devinterp.optim.prior.UniformPrior(box_size: float = inf)[source]

Bases: Prior

Uniform prior.

key: str

devinterp.optim.sgld module

class devinterp.optim.sgld.SGLD(params: Iterable[Parameter], *, lr: float = 0.01, noise_level: float = 1.0, weight_decay: float = 0.0, localization: float = 0.0, nbeta: Callable[[], float] | float = 1.0, bounding_box_size: float | None = None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]

Bases: Optimizer

Implements Stochastic Gradient Langevin Dynamics (SGLD) optimizer.

This optimizer blends Stochastic Gradient Descent (SGD) with Langevin Dynamics, introducing Gaussian noise to the gradient updates. This makes it sample weights from the posterior distribution, instead of optimizing weights.

This implementation follows Lau et al.’s (2023) implementation, which modifies Welling and Teh (2011) by omitting the learning rate schedule and introducing a localization term that pulls the weights toward their initial values.

The equation for the update is as follows:

\[\Delta w_t = \frac{\epsilon}{2}\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, w_t\right)+\gamma\left(w_0-w_t\right) - \lambda w_t\right) + N(0, \epsilon\sigma^2)\]

where \(w_t\) is the weight at time \(t\), \(\epsilon\) is the learning rate, \((\beta n)\) is the inverse temperature (we’re in the tempered Bayes paradigm), \(n\) is the number of training samples, \(m\) is the batch size, \(\gamma\) is the localization strength, \(\lambda\) is the weight decay strength, and \(\sigma\) is the noise term.

Example

>>> optimizer = SGLD(model.parameters(), lr=0.1, nbeta=utils.default_nbeta(dataloader))
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()

Note

  • localization is unique to this class and serves to guide the weights towards their original values. This is useful for estimating quantities over the local posterior.

  • noise_level is not intended to be changed, except when testing! Doing so will raise a warning.

  • Although this class is a subclass of torch.optim.Optimizer, this is a bit of a misnomer in this case. It’s not used for optimizing in LLC estimation, but rather for sampling from the posterior distribution around a point.

  • Hyperparameter optimization is more of an art than a science. Check out the calibration notebook for how to go about it in a simple case.

Parameters:
  • params (Iterable) – Iterable of parameters to optimize or dicts defining parameter groups. Either model.parameters() or something more fancy, just like other torch.optim.Optimizer classes.

  • lr (float, optional) – Learning rate \(\epsilon\). Default is 0.01

  • noise_level (float, optional) – Amount of Gaussian noise \(\sigma\) introduced into gradient updates. Don’t change this unless you know very well what you’re doing! Default is 1

  • weight_decay (float, optional) – L2 regularization term \(\lambda\), applied as weight decay. Default is 0

  • localization (float, optional) – Strength of the force \(\gamma\) pulling weights back to their initial values. Default is 0

  • nbeta (float or Callable, optional) – Inverse reparameterized temperature (also known as n*beta or ~beta). Typically set via utils.default_nbeta(dataloader) rather than by hand. Default is 1.0

  • bounding_box_size (float, optional) – the size of the bounding box enclosing our trajectory in parameter space. Default is None, in which case no bounding box is used.

  • save_metrics (bool, optional) – Whether to track metrics (scaled_grad, localization, weight_decay, noise norms) during optimization. Use get_metrics() to retrieve. Default is False

Raises:
  • Warning – if noise_level is set to anything other than 1

  • Warning – if nbeta is set to 1
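
A sketch of retrieving tracked metrics, assuming model, loss_fn, and dataloader are already defined:

from devinterp.optim.sgld import SGLD
from devinterp.utils import default_nbeta

optimizer = SGLD(model.parameters(), lr=0.1, nbeta=default_nbeta(dataloader), save_metrics=True)
for data, target in dataloader:
    optimizer.zero_grad()
    loss_fn(model(data), target).backward()
    optimizer.step()

metrics = optimizer.get_metrics()  # aggregated across param groups, on CPU
print(metrics.noise, metrics.localization)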

add_param_group(param_group: dict[str, Any]) None

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters:

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

get_metrics() Metrics[source]

Aggregate metrics across all param groups into a single CPU Metrics.

get_sketches() SketchBuffer[source]

Return a CPU snapshot of the current sketch buffer.

iter_group_metrics() Iterator[Metrics][source]

Yield metrics for each param group.

load_state_dict(state_dict: dict[str, Any]) None

Load the optimizer state.

Parameters:

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

Note

The names of the parameters (if they exist under the “param_names” key of each param group in state_dict()) will not affect the loading process. To use the parameters’ names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in loaded state dict param_groups they will be saved and override the current names, if present, in the optimizer state. If they do not exist in loaded state dict, the optimizer param_names will remain unchanged.

register_load_state_dict_post_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle

Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used.

The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle

Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict inplace or optionally return a new one. If a state_dict is returned, it will be used to be loaded into the optimizer.

The hook will be called with argument self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_state_dict_post_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle

Register a state dict post-hook which will be called after state_dict() is called.

It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict inplace or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_state_dict_pre_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle

Register a state dict pre-hook which will be called before state_dict() is called.

It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_step_post_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], None]) RemovableHandle

Register an optimizer step post hook which will be called after optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None

The optimizer argument is the optimizer instance being used.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_step_pre_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], tuple[tuple[Any, ...], dict[str, Any]] | None]) RemovableHandle

Register an optimizer step pre hook which will be called before optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

state_dict() dict[str, Any]

Return the state of the optimizer as a dict.

It contains two entries:

  • state: a Dict holding current optimization state. Its content

    differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.

  • param_groups: a List containing all parameter groups where each

    parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.

NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameter objects) in order to match state WITHOUT additional verification.

A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3],
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
step(noise_generator: Generator | None = None) None[source]

Perform a single SGLD optimization step.

zero_grad(set_to_none: bool = True) None

Reset the gradients of all optimized torch.Tensor objects.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).

devinterp.optim.sgmcmc module

class devinterp.optim.sgmcmc.SGMCMC(params, *, lr: float = 0.01, noise_level: float = 1.0, nbeta: float = 1.0, prior: Prior | Literal['initial'] | Iterable[Tensor] | float | None = None, preconditioner: Preconditioner | Literal['identity', 'rmsprop'] | None = 'identity', preconditioner_kwargs: dict | None = None, bounding_box_size: float | None = None, mask: Tensor | None = None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]

Bases: Optimizer

Unified Stochastic Gradient Markov Chain Monte Carlo (SGMCMC) optimizer.

This optimizer implements a general SGMCMC framework that unifies several common variants like SGLD, SGHMC, and SGNHT. It supports custom priors and preconditioners for flexible posterior sampling.

The general update rule follows:

\[\Delta\theta_t = \frac{\epsilon}{2}G(\theta_t)\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, \theta_t\right)+\nabla \log p\left(\theta_t\right)\right) + \sqrt{\epsilon G(\theta_t)}\, N(0, \sigma^2)\]

where:

  • ε is learning rate

  • nβ is inverse temperature

  • m is batch size

  • G(θ) is the preconditioner

  • p(θ) is the prior distribution

  • σ is noise level

Parameters:
  • params (Iterable) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – Learning rate ε (default: 0.01)

  • noise_level (float) – Standard deviation σ of the noise (default: 1.0)

  • nbeta (float) – Inverse temperature nβ (default: 1.0)

  • prior (Optional[Union[Prior, Literal["initial"], Iterable[torch.Tensor], float]]) – Prior distribution specification. Can be a Prior instance (e.g. GaussianPrior), a string specifying the center type ("initial"), an iterable of tensor centers, or a float specifying the precision. In SGMCMC, weight_decay and localization should be implemented as priors; see the rmsprop_sgld() method for an example. (default: None)

  • preconditioner (Optional[Union[Preconditioner, str]]) – Preconditioner specification. Can be "identity" for no preconditioning, "rmsprop" for RMSprop-style preconditioning, or a Preconditioner instance. None is treated as "identity". (default: "identity")

  • preconditioner_kwargs (Optional[dict]) – Additional keyword arguments for preconditioner

  • bounding_box_size (Optional[float]) – Size of bounding box around initial parameters

  • mask (Optional[torch.Tensor]) – Boolean mask for restricting updatable parameters

  • save_metrics (bool) – Whether to track metrics during training (default: False)

Example:

from devinterp.utils import default_nbeta

# Basic SGLD-style usage
optimizer = SGMCMC.sgld(
    model.parameters(),
    lr=0.1,
    nbeta=default_nbeta(dataloader)
)

# RMSprop-preconditioned with prior
optimizer = SGMCMC.rmsprop_sgld(
    model.parameters(),
    lr=0.01,
    localization=0.1,
    nbeta=default_nbeta(dataloader)
)

# SGNHT-style with thermostat
optimizer = SGMCMC.sgnht(
    model.parameters(),
    lr=0.01,
    diffusion_factor=0.01,
    nbeta=default_nbeta(dataloader)
)

# Training loop
for data, target in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()

Notes

  • Use the factory methods (sgld, rmsprop_sgld, sgnht) for easier initialization

  • nbeta should typically be set using devinterp.utils.default_nbeta() rather than manually

  • The prior helps explore the local posterior by pulling toward initialization

  • Use save_metrics=True and call get_metrics() to access tracked metrics

References

  • Welling & Teh (2011) - Original SGLD paper

  • Li et al. (2015) - RMSprop-SGLD

  • Ding et al. (2014) - SGNHT

  • Lau et al. (2023) - Implementation with localization term

add_param_group(param_group: dict[str, Any]) None

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters:

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

get_metrics() Metrics[source]

Aggregate metrics across all param groups into a single CPU Metrics.

get_params() Iterator[Tensor][source]

Helper to get all parameters

get_sketches() SketchBuffer[source]

Return a CPU snapshot of the current sketch buffer.

iter_group_metrics() Iterator[Metrics][source]

Yield metrics for each param group.

load_state_dict(state_dict: dict[str, Any]) None

Load the optimizer state.

Parameters:

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

Note

The names of the parameters (if they exist under the “param_names” key of each param group in state_dict()) will not affect the loading process. To use the parameters’ names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in loaded state dict param_groups they will be saved and override the current names, if present, in the optimizer state. If they do not exist in loaded state dict, the optimizer param_names will remain unchanged.

register_load_state_dict_post_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle

Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used.

The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle

Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict inplace or optionally return a new one. If a state_dict is returned, it will be used to be loaded into the optimizer.

The hook will be called with argument self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_state_dict_post_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle

Register a state dict post-hook which will be called after state_dict() is called.

It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict inplace or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_state_dict_pre_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle

Register a state dict pre-hook which will be called before state_dict() is called.

It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_step_post_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], None]) RemovableHandle

Register an optimizer step post hook which will be called after optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None

The optimizer argument is the optimizer instance being used.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_step_pre_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], tuple[tuple[Any, ...], dict[str, Any]] | None]) RemovableHandle

Register an optimizer step pre hook which will be called before optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

classmethod rmsprop_sgld(params, lr=0.01, noise_level=1.0, weight_decay=0.0, localization=0.0, nbeta=1.0, alpha=0.99, eps=0.1, add_grad_correction=False, bounding_box_size=None, mask=None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]

Factory method to create an SGMCMC instance that wraps RMSprop’s adaptive preconditioning with SGLD to perform Bayesian sampling of neural network weights.

The update rule with preconditioning follows:

\[V(\theta_t) = \alpha V(\theta_{t-1}) + (1-\alpha)\,\bar{g}(\theta_t) \odot \bar{g}(\theta_t)\]
\[G(\theta_t) = \mathrm{diag}\left(\frac{1}{\mathit{eps} + \sqrt{V(\theta_t)}}\right)\]
\[\Delta\theta_t = \frac{\epsilon}{2}G(\theta_t)\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, \theta_t\right)+\gamma\left(\theta_0-\theta_t\right) - \lambda\theta_t\right) + \sqrt{\epsilon G(\theta_t)}\, N(0, \sigma^2)\]

where:

  • ε is learning rate

  • nβ is effective dataset size (=dataset size * inverse temperature)

  • m is batch size

  • γ is localization strength

  • λ is weight decay

  • σ is noise level

  • G(θ) is the RMSprop preconditioner

  • V(θ) tracks squared gradient moving average

  • α is the exponential decay rate

Key differences from standard SGLD:

  • Uses RMSprop preconditioner to adapt to local geometry and curvature

  • Scales both the gradients and noise by the preconditioner

  • Handles pathological curvature through adaptive step sizes

Parameters:
  • params – Iterable of parameters to optimize

  • lr – Learning rate (default: 0.01)

  • noise_level – Standard deviation of noise (default: 1.0)

  • weight_decay – Weight decay factor. Applied with preconditioning (Adam-style). Creates a GaussianPrior centered at zero with localization=weight_decay. (default: 0.0)

  • localization – Strength of pull toward initial parameters. Creates a GaussianPrior centered at initialization with localization=localization. (default: 0.0)

  • nbeta – Inverse temperature (default: 1.0)

  • alpha – RMSprop moving average coefficient (default: 0.99)

  • eps – RMSprop stability constant (default: 0.1)

  • add_grad_correction – Whether to add gradient correction term (default: False)

  • bounding_box_size – Size of bounding box around initial parameters (default: None)

  • mask – Boolean mask for restricting updatable parameters (default: None)

  • save_metrics – Whether to track metrics during training (default: False)

Returns:

SGMCMC optimizer instance

classmethod sgld(params, lr=0.01, noise_level=1.0, weight_decay=0.0, localization=0.0, nbeta=1.0, bounding_box_size=None, mask=None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]

Factory method to create an SGMCMC instance that implements Stochastic Gradient Langevin Dynamics (SGLD) with a localization term (Lau et al. 2023).

This optimizer combines Stochastic Gradient Descent (SGD) with Langevin Dynamics, introducing Gaussian noise to the gradient updates. This makes it sample weights from the posterior distribution, instead of finding point estimates through optimization (Welling and Teh 2011).

The update rule follows:

\[\Delta\theta_t = \frac{\epsilon}{2}\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, \theta_t\right)+\gamma\left(\theta_0-\theta_t\right) - \lambda\theta_t\right) + N(0, \epsilon\sigma^2)\]

where:

  • ε is learning rate

  • nβ is inverse temperature

  • m is batch size

  • γ is localization strength

  • λ is weight decay

  • σ is noise level

This follows Lau et al.’s (2023) implementation, which modifies Welling and Teh (2011) by:

  • Omitting the learning rate schedule (this functionality can be recovered by using a separate learning rate scheduler).

  • Adding a localization term that pulls weights toward initialization

  • Using tempered Bayes paradigm with inverse temperature nβ

This allows SGMCMC to be used as a drop-in replacement for SGLD.
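
For instance, the following two constructions are intended to be interchangeable (a sketch, assuming a model and arbitrary hyperparameters):

from devinterp.optim.sgld import SGLD
from devinterp.optim.sgmcmc import SGMCMC

opt_a = SGLD(model.parameters(), lr=0.01, localization=1.0, nbeta=100.0)
opt_b = SGMCMC.sgld(model.parameters(), lr=0.01, localization=1.0, nbeta=100.0)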

Parameters:
  • params – Iterable of parameters to optimize

  • lr – Learning rate (default: 0.01)

  • noise_level – Standard deviation of noise (default: 1.0)

  • weight_decay – Weight decay factor. Applied with preconditioning (Adam-style). Creates a GaussianPrior centered at zero with localization=weight_decay. (default: 0.0)

  • localization – Strength of pull toward initial parameters. Creates a GaussianPrior centered at initialization with localization=localization. (default: 0.0)

  • nbeta – Inverse temperature (default: 1.0)

  • bounding_box_size – Size of bounding box around initial parameters (default: None)

  • mask – Boolean mask for restricting updatable parameters (default: None)

  • save_metrics – Whether to track metrics during training (default: False)

Returns:

SGMCMC optimizer instance

classmethod sgnht(params, lr=0.01, diffusion_factor=0.01, nbeta=1.0, bounding_box_size=None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]

Factory method to create an SGMCMC instance that matches SGNHT’s interface.

This allows SGMCMC to be used as a drop-in replacement for SGNHT.

Parameters:
  • params – Iterable of parameters to optimize

  • lr – Learning rate (default: 0.01)

  • diffusion_factor – Diffusion factor (default: 0.01)

  • nbeta – Inverse temperature (default: 1.0)

  • bounding_box_size – Size of bounding box around initial parameters (default: None)

  • save_metrics – Whether to track metrics (default: False)

Returns:

SGMCMC optimizer instance

state_dict() dict[str, Any]

Return the state of the optimizer as a dict.

It contains two entries:

  • state: a Dict holding current optimization state. Its content

    differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.

  • param_groups: a List containing all parameter groups where each

    parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.

NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameter objects) in order to match state WITHOUT additional verification.

A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3],
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
step(closure=None, noise_generator: Generator | None = None) None[source]

Perform a single optimization step.

Parameters:
  • closure (callable, optional) – A closure that reevaluates the model and returns the loss.

  • noise_generator (torch.Generator, optional) – Generator for reproducible noise.

Returns:

The loss value if closure is provided, else None.

Return type:

Optional[float]

zero_grad(set_to_none: bool = True) None

Reset the gradients of all optimized torch.Tensor objects.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).

devinterp.optim.sgnht module

class devinterp.optim.sgnht.SGNHT(params, lr=0.01, diffusion_factor=0.01, bounding_box_size=None, save_noise=False, save_mala_vars=False, nbeta=1.0, metrics: list[str] | None = None)[source]

Bases: Optimizer

Implements the Stochastic Gradient Nosé-Hoover Thermostat (SGNHT) optimizer. This optimizer blends SGD with an adaptive thermostat variable that controls the magnitude of the injected noise, maintaining the kinetic energy of the system.

It follows Ding et al.’s (2014) implementation.

The equations for the update are as follows:

\[\Delta w_t = \epsilon\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, w_t\right) - \xi_t w_t \right) + \sqrt{2A} N(0, \epsilon)\]
\[\Delta\xi_{t} = \epsilon \left( \frac{1}{n} \|w_t\|^2 - 1 \right)\]

where \(w_t\) is the weight at time \(t\), \(\epsilon\) is the learning rate, \((\beta n)\) is the inverse temperature (we’re in the tempered Bayes paradigm), \(n\) is the number of samples, \(m\) is the batch size, \(\xi_t\) is the thermostat variable at time \(t\), \(A\) is the diffusion factor, and \(N(0, \epsilon)\) represents Gaussian noise with mean 0 and variance \(\epsilon\), scaled by \(\sqrt{2A}\).

Note

  • diffusion_factor is unique to this class; it allows for random parameter changes while keeping them from blowing up, by guiding parameters back to a slowly-changing thermostat value with a friction term.

  • This class does not have an explicit localization term like SGLD() does. If you want to constrain your sampling, use bounding_box_size.

  • Although this class is a subclass of torch.optim.Optimizer, this is a bit of a misnomer in this case. It’s not used for optimizing in LLC estimation, but rather for sampling from the posterior distribution around a point.
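
A usage sketch mirroring the SGLD example above (assuming model, loss_fn, dataloader, data, and target are defined):

from devinterp.optim.sgnht import SGNHT
from devinterp.utils import default_nbeta

optimizer = SGNHT(model.parameters(), lr=0.01, diffusion_factor=0.01, nbeta=default_nbeta(dataloader))
optimizer.zero_grad()
loss_fn(model(data), target).backward()
optimizer.step()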

Parameters:
  • params (Iterable) – Iterable of parameters to optimize or dicts defining parameter groups. Either model.parameters() or something more fancy, just like other torch.optim.Optimizer classes.

  • lr (float, optional) – Learning rate \(\epsilon\). Default is 0.01

  • diffusion_factor (float, optional) – The diffusion factor \(A\) of the thermostat. Default is 0.01

  • bounding_box_size (float, optional) – the size of the bounding box enclosing our trajectory. Default is None

  • nbeta (float, optional) – Effective inverse temperature nβ. Typically set via utils.default_nbeta(dataloader) rather than by hand. Default is 1.0

Raises:
  • Warning – if nbeta is set to 1

  • Warning – if NoiseNorm callback is used

  • Warning – if MALA callback is used

add_param_group(param_group: dict[str, Any]) None

Add a param group to the Optimizer's param_groups.

This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.

Parameters:

param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.

load_state_dict(state_dict: dict[str, Any]) None

Load the optimizer state.

Parameters:

state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().

Note

The names of the parameters (if they exist under the “param_names” key of each param group in state_dict()) will not affect the loading process. To use the parameters’ names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in loaded state dict param_groups they will be saved and override the current names, if present, in the optimizer state. If they do not exist in loaded state dict, the optimizer param_names will remain unchanged.

register_load_state_dict_post_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle

Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used.

The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle

Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict inplace or optionally return a new one. If a state_dict is returned, it will be used to be loaded into the optimizer.

The hook will be called with argument self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_state_dict_post_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle

Register a state dict post-hook which will be called after state_dict() is called.

It should have the following signature:

hook(optimizer, state_dict) -> state_dict or None

The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict inplace or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_state_dict_pre_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle

Register a state dict pre-hook which will be called before state_dict() is called.

It should have the following signature:

hook(optimizer) -> None

The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.

Parameters:
  • hook (Callable) – The user defined hook to be registered.

  • prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_step_post_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], None]) RemovableHandle

Register an optimizer step post hook which will be called after optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None

The optimizer argument is the optimizer instance being used.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

register_step_pre_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], tuple[tuple[Any, ...], dict[str, Any]] | None]) RemovableHandle

Register an optimizer step pre hook which will be called before optimizer step.

It should have the following signature:

hook(optimizer, args, kwargs) -> None or modified args and kwargs

The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.

Parameters:

hook (Callable) – The user defined hook to be registered.

Returns:

a handle that can be used to remove the added hook by calling handle.remove()

Return type:

torch.utils.hooks.RemovableHandle

state_dict() dict[str, Any]

Return the state of the optimizer as a dict.

It contains two entries:

  • state: a Dict holding current optimization state. Its content

    differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.

  • param_groups: a List containing all parameter groups where each

    parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.

NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameter objects) in order to match state WITHOUT additional verification.

A returned state dict might look something like:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0],
            'param_names': ['param0']  (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3],
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
zero_grad(set_to_none: bool = True) None

Reset the gradients of all optimized torch.Tensor objects.

Parameters:

set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).

devinterp.optim.sketch module

class devinterp.optim.sketch.CountSketch(hash_indices: Tensor, hash_signs: Tensor, _output_dim: int)[source]

Bases: object

Count sketch projection for dimensionality reduction.

Projects d-dimensional vectors to k-dimensional sketches while preserving inner products in expectation: E[<Sv, Sw>] = <v, w>.

The sketch is defined by a hash function h: [d] -> [k] (mapping each input coordinate to an output bucket) and a sign function s: [d] -> {-1, +1} (random per coordinate). The sketch of vector v is:

S(v)[j] = sum_{i : h(i) = j} s(i) * v[i]

Both h and s are generated deterministically from a seed. Two sketch vectors are only comparable when produced by the same CountSketch instance (same seed, same input_dim). When used with an optimizer, input_dim is the total trainable parameter count, so different weight restrictions yield incomparable sketches even with the same seed.

This single-row construction is equivalent to what Weinberger et al. call “feature hashing”. The inner product preservation property is proved there.
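
A self-contained sketch of this construction, written directly from the definition above (independent of the class internals):

import torch

d, k = 10_000, 256
gen = torch.Generator().manual_seed(0)
hash_indices = torch.randint(0, k, (d,), generator=gen)        # h: [d] -> [k]
hash_signs = torch.randint(0, 2, (d,), generator=gen) * 2 - 1  # s: [d] -> {-1, +1}

def sketch(v: torch.Tensor) -> torch.Tensor:
    # S(v)[j] = sum_{i : h(i) = j} s(i) * v[i]
    out = torch.zeros(k)
    out.scatter_add_(0, hash_indices, hash_signs * v)
    return out

v, w = torch.randn(d), torch.randn(d)
print(sketch(v) @ sketch(w))  # approximates <v, w> in expectation
print(v @ w)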

References

  • Weinberger et al. (2009) - Feature Hashing for Large Scale Multitask Learning

hash_indices: Tensor
hash_signs: Tensor
scatter_into_(result: Tensor, v: Tensor, offset: int) None[source]

Accumulate one parameter’s contribution into a running sketch buffer.

Exploits linearity: sketching cat(p1, p2, …) is equivalent to accumulating each pi at its offset into the same buffer.

sketch(v: Tensor) Tensor[source]

Apply the full sketch to a flat vector.

class devinterp.optim.sketch.SketchBuffer(scaled_grad: Tensor, unscaled_grad: Tensor, localization: Tensor, weight_decay: Tensor, noise: Tensor)[source]

Bases: object

Per-step accumulation buffers for count sketch metrics.

One buffer per tracked quantity. Lifecycle: zero_() at step start -> accumulate per-param via scatter_into_ -> read.
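
A sketch of reading these buffers through the optimizer API (assuming a model, and that sketching was enabled via sketch_dim; the printed shape is what one would expect, not a documented guarantee):

from devinterp.optim.sgmcmc import SGMCMC

optimizer = SGMCMC.sgld(model.parameters(), lr=0.01, sketch_dim=1024)
# ... after one or more optimizer.step() calls ...
buf = optimizer.get_sketches()  # CPU snapshot of the current SketchBuffer
print(buf.noise.shape)          # expected: a 1024-dim sketch per tracked quantity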

QUANTITIES: ClassVar[tuple[str, ...]] = ('scaled_grad', 'unscaled_grad', 'localization', 'weight_decay', 'noise')
localization: Tensor
noise: Tensor
scaled_grad: Tensor
unscaled_grad: Tensor
weight_decay: Tensor

devinterp.optim.utils module

Module contents