devinterp.optim package
Submodules
devinterp.optim.metrics module
- class devinterp.optim.metrics.Metrics(scaled_grad: ~torch.Tensor = <factory>, unscaled_grad: ~torch.Tensor = <factory>, localization: ~torch.Tensor = <factory>, weight_decay: ~torch.Tensor = <factory>, noise: ~torch.Tensor = <factory>, distance: ~torch.Tensor = <factory>, dot_grad_prior: ~torch.Tensor = <factory>, dot_grad_noise: ~torch.Tensor = <factory>, dot_prior_noise: ~torch.Tensor = <factory>, numel: int = 0)[source]
Bases: object
Norms and dot products of SGMCMC parameter update components.
Each step, w += dw where dw = -(scaled_grad + prior) + noise.
Norm fields store L2 norms of the post-preconditioned update components (i.e. actual magnitudes applied to parameters, not raw gradients).
- scaled_grad: (ε/2) · nβ · G · ∇L — preconditioned gradient
- unscaled_grad: (ε/2) · nβ · ∇L — raw gradient (no preconditioner)
- localization: (ε/2) · G · γ(w - w₀) — pull toward initial params
- weight_decay: (ε/2) · G · λw — L2 regularization
- noise: √ε · √G · η — stochastic exploration
- distance: w - w₀ — raw displacement from init
Dot product fields store inner products between the three main component vectors (scaled_grad, combined prior, noise).
- dot_grad_prior: ⟨scaled_grad, localization + weight_decay⟩
- dot_grad_noise: ⟨scaled_grad, noise⟩
- dot_prior_noise: ⟨localization + weight_decay, noise⟩
Cosine similarities can be derived: cos = dot / (norm_a * norm_b).
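For instance, the cosine similarity between the scaled gradient and the other components can be computed from a Metrics instance m (obtained via the optimizer's get_metrics()). This is an illustrative sketch, not library code:

# m.scaled_grad, m.prior and m.noise are scalar L2 norms; the dot_* fields are inner products.
cos_grad_prior = m.dot_grad_prior / (m.scaled_grad * m.prior).clamp_min(1e-12)
cos_grad_noise = m.dot_grad_noise / (m.scaled_grad * m.noise).clamp_min(1e-12)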
Lifecycle (one Metrics per optimizer param group, on the group's device):
- __init__: group["metrics"] = Metrics().to(device)
- step(), start: group["metrics"].zero_()
- step(), per-parameter: group["metrics"].add_sum_squared_(…) and group["metrics"].add_dot_products_(…)
- step(), end: group["metrics"].sqrt_norms_()
- get_metrics(): combine per-group metrics on CPU
- DOT_FIELDS: ClassVar[tuple[str, ...]] = ('dot_grad_prior', 'dot_grad_noise', 'dot_prior_noise')
- NORM_FIELDS: ClassVar[tuple[str, ...]] = ('scaled_grad', 'unscaled_grad', 'localization', 'weight_decay', 'noise', 'distance')
- add_dot_products_(scaled_grad: Tensor, prior: Tensor, noise: Tensor) None[source]
Accumulate dot products between the three main component vectors.
- Parameters:
scaled_grad – The preconditioned gradient vector.
prior – Combined prior vector (localization + weight_decay).
noise – The noise vector.
- add_sum_squared_(scaled_grad: Tensor, unscaled_grad: Tensor, localization: Tensor, weight_decay: Tensor, noise: Tensor, distance: Tensor) None[source]
Accumulate sum-of-squares for each norm component in-place.
Casts to float32 before squaring to avoid precision loss with bf16/fp16 inputs (where squaring can overflow or underflow in the input dtype).
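For instance, squaring even moderately large fp16 values overflows in the input dtype, which is why the cast matters (a standalone illustration, not the method's internal code):

import torch

g = torch.full((4,), 300.0, dtype=torch.float16)
print((g * g).sum())                   # inf: 300**2 = 90000 exceeds the fp16 max (~65504)
print((g.float() * g.float()).sum())   # 360000.0 after casting to float32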
- static aggregate(group_metrics: Iterable[Metrics]) Metrics[source]
Combine per-group metrics into a single Metrics on CPU.
Norms: re-square to get sum-of-squares, accumulate, then sqrt: ||[a; b]|| = sqrt(||a||^2 + ||b||^2).
Dot products: additive across disjoint parameter sets.
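A minimal sketch of the norm-combination rule (standalone arithmetic, not the method itself):

import math

# Per-group L2 norms of the same component, e.g. the noise norm from two param groups.
group_norms = [3.0, 4.0]
combined = math.sqrt(sum(n ** 2 for n in group_norms))  # 5.0 == ||[a; b]||
# Dot products, by contrast, are simply summed across groups.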
- distance: Tensor
- dot_grad_noise: Tensor
- dot_grad_prior: Tensor
- dot_prior_noise: Tensor
- localization: Tensor
- noise: Tensor
- numel: int = 0
- property prior: Tensor
Combined prior norm: ||[localization; weight_decay]||₂.
- scaled_grad: Tensor
- to(device: str | device | int) Metrics[source]
Return a copy of these metrics on the specified device.
- unscaled_grad: Tensor
- weight_decay: Tensor
devinterp.optim.preconditioner module
- class devinterp.optim.preconditioner.CompositePreconditioner(preconditioners: list[Preconditioner])[source]
Bases: Preconditioner
Combines multiple preconditioners by multiplying their coefficients.
- class devinterp.optim.preconditioner.IdentityPreconditioner[source]
Bases: Preconditioner
Identity preconditioning (i.e., no preconditioning).
- class devinterp.optim.preconditioner.MaskPreconditioner(masks: list[Tensor | float])[source]
Bases: Preconditioner
Applies masks to the overall coefficient while keeping the other coefficients at 1.0.
Stores one mask per parameter in the parameter group.
- class devinterp.optim.preconditioner.NHTPreconditioning(diffusion_factor: float = 0.01, eps: float = 1e-08)[source]
Bases: Preconditioner
Nosé-Hoover thermostat preconditioning.
- class devinterp.optim.preconditioner.Preconditioner[source]
Bases: ABC
Base class for preconditioners that generate coefficients for MCMC terms.
- abstract get_coefficients(param: Tensor, grad: Tensor, state: dict) PreconditionerCoefs[source]
Compute coefficients for the gradient, prior, and noise terms. Returns a PreconditionerCoefs containing all coefficients. Each coefficient can be a scalar or a tensor of shape matching param.
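A minimal sketch of a custom preconditioner written against this interface (the class name and constant scale are illustrative assumptions, not part of the library):

from torch import Tensor
from devinterp.optim.preconditioner import Preconditioner, PreconditionerCoefs

class ConstantPreconditioner(Preconditioner):
    """Hypothetical example: scales the overall update by a fixed constant."""

    def __init__(self, scale: float = 0.5):
        self.scale = scale

    def get_coefficients(self, param: Tensor, grad: Tensor, state: dict) -> PreconditionerCoefs:
        # Scalars are valid coefficients; tensors matching param's shape would also work.
        return PreconditionerCoefs(
            grad_coef=1.0,
            prior_coef=1.0,
            noise_coef=1.0,
            overall_coef=self.scale,
            grad_correction=None,
        )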
- class devinterp.optim.preconditioner.PreconditionerCoefs(grad_coef: float | Tensor, prior_coef: float | Tensor, noise_coef: float | Tensor, overall_coef: float | Tensor, grad_correction: float | Tensor | None = None)[source]
Bases: NamedTuple
Coefficients returned by preconditioners.
- count(value, /)
Return number of occurrences of value.
- grad_coef: float | Tensor
Alias for field number 0
- grad_correction: float | Tensor | None
Alias for field number 4
- index(value, start=0, stop=9223372036854775807, /)
Return first index of value.
Raises ValueError if the value is not present.
- noise_coef: float | Tensor
Alias for field number 2
- overall_coef: float | Tensor
Alias for field number 3
- prior_coef: float | Tensor
Alias for field number 1
- class devinterp.optim.preconditioner.RMSpropPreconditioner(alpha: float = 0.99, eps: float = 0.1, add_grad_correction=False)[source]
Bases: Preconditioner
RMSprop-style diagonal preconditioning.
devinterp.optim.prior module
- class devinterp.optim.prior.CompositePrior(priors: list[Prior])[source]
Bases: Prior
Combines multiple priors, summing their contributions.
Always wraps even a single non-uniform prior (no short-circuit). Callers that need to decompose sub-priors (e.g. SGMCMC._update_metrics) can iterate self.priors uniformly.
- key: str
- class devinterp.optim.prior.GaussianPrior(localization: float, center: Literal['initial'] | Iterable[Tensor] | Real | None = 'initial')[source]
Bases: Prior
Gaussian prior with configurable center and precision.
- grad(param: Tensor, state: dict[str, Any]) Tensor[source]
Compute gradient of the prior. If state is provided, the prior center is looked up in the state dictionary using the instance key.
- Parameters:
param – Parameter tensor
state – State dictionary
- Returns:
Gradient tensor
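Conceptually, a Gaussian prior with precision γ centered at w₀ contributes a gradient proportional to γ(w - w₀); a hedged illustration of the returned quantity (the sign convention and internals may differ from the actual implementation):

import torch

gamma = 0.1                      # localization / precision
param = torch.randn(3)
center = torch.zeros(3)          # e.g. a weight-decay-style prior centered at zero
prior_grad = gamma * (param - center)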
- initialize(params: Sequence[Tensor]) dict[Tensor, dict[str, Any]][source]
Initialize centers for all parameters
- Parameters:
params – Iterator of model parameters
- Returns:
State dictionary containing prior centers
- key: str
- class devinterp.optim.prior.Prior[source]
Bases: ABC
Abstract base class for parameter priors.
- abstract grad(param: Tensor, state: dict[str, Any]) Tensor[source]
Compute gradient of the prior
- Parameters:
param – Parameter tensor
state – State dictionary
- Returns:
Gradient tensor
- abstract initialize(params: Sequence[Tensor]) dict[Tensor, dict[str, Any]][source]
Initialize prior for parameters
- Parameters:
params – Iterator of model parameters
- Returns:
Updated state dictionary
- key: str
devinterp.optim.sgld module
- class devinterp.optim.sgld.SGLD(params: Iterable[Parameter], *, lr: float = 0.01, noise_level: float = 1.0, weight_decay: float = 0.0, localization: float = 0.0, nbeta: Callable[[], float] | float = 1.0, bounding_box_size: float | None = None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]
Bases: Optimizer
Implements the Stochastic Gradient Langevin Dynamics (SGLD) optimizer.
This optimizer blends Stochastic Gradient Descent (SGD) with Langevin Dynamics, introducing Gaussian noise to the gradient updates. This makes it sample weights from the posterior distribution, instead of optimizing weights.
This implementation follows Lau et al. (2023), a modification of Welling and Teh (2011) that omits the learning rate schedule and introduces a localization term that pulls the weights towards their initial values.
The equation for the update is as follows:
\[\Delta w_t = \frac{\epsilon}{2}\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, w_t\right)+\gamma\left(w_0-w_t\right) - \lambda w_t\right) + N(0, \epsilon\sigma^2)\]
where \(w_t\) is the weight at time \(t\), \(\epsilon\) is the learning rate, \((\beta n)\) is the inverse temperature (we're in the tempered Bayes paradigm), \(n\) is the number of training samples, \(m\) is the batch size, \(\gamma\) is the localization strength, \(\lambda\) is the weight decay strength, and \(\sigma\) is the noise term.
Example
>>> optimizer = SGLD(model.parameters(), lr=0.1, nbeta=utils.default_nbeta(dataloader))
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
Note
- localization is unique to this class and serves to guide the weights towards their original values. This is useful for estimating quantities over the local posterior.
- noise_level is not intended to be changed, except when testing! Doing so will raise a warning.
- Although this class is a subclass of torch.optim.Optimizer, this is a bit of a misnomer in this case. It's not used for optimizing in LLC estimation, but rather for sampling from the posterior distribution around a point.
- Hyperparameter optimization is more of an art than a science. Check out the calibration notebook for how to go about it in a simple case.
- Parameters:
params (Iterable) – Iterable of parameters to optimize or dicts defining parameter groups. Either model.parameters() or something more fancy, just like other torch.optim.Optimizer classes.
lr (float, optional) – Learning rate \(\epsilon\). Default is 0.01
noise_level (float, optional) – Amount of Gaussian noise \(\sigma\) introduced into gradient updates. Don’t change this unless you know very well what you’re doing! Default is 1
weight_decay (float, optional) – L2 regularization term \(\lambda\), applied as weight decay. Default is 0
localization (float, optional) – Strength of the force \(\gamma\) pulling weights back to their initial values. Default is 0
nbeta (float or Callable, optional) – Inverse reparameterized temperature (also written n*beta or ~beta). Default is 1.0; typically set to utils.default_nbeta(dataloader)
bounding_box_size (float, optional) – the size of the bounding box enclosing our trajectory in parameter space. Default is None, in which case no bounding box is used.
save_metrics (bool, optional) – Whether to track metrics (scaled_grad, localization, weight_decay, noise norms) during optimization. Use get_metrics() to retrieve. Default is False
- Raises:
Warning – if
noise_levelis set to anything other than 1Warning – if
nbetais set to 1
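A usage sketch for metric tracking, assuming get_metrics() returns the aggregated Metrics object described in devinterp.optim.metrics (model, dataloader, loss_fn, input and target are placeholders as in the example above):

optimizer = SGLD(model.parameters(), lr=1e-4, localization=100.0,
                 nbeta=utils.default_nbeta(dataloader), save_metrics=True)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
metrics = optimizer.get_metrics()        # combined across param groups, on CPU
print(metrics.noise, metrics.localization)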
- add_param_group(param_group: dict[str, Any]) None
Add a param group to the Optimizer's param_groups.
This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.
- Parameters:
param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
- get_sketches() SketchBuffer[source]
Return a CPU snapshot of the current sketch buffer.
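A hedged sketch of how the sketch buffer might be consumed, assuming sketch_dim was passed at construction and that the SketchBuffer fields are the k-dimensional count sketches described in devinterp.optim.sketch:

optimizer = SGLD(model.parameters(), lr=1e-4,
                 nbeta=utils.default_nbeta(dataloader),
                 sketch_dim=4096, sketch_seed=0)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
sketches = optimizer.get_sketches()      # CPU snapshot of the last step's buffers
# Sketched inner products approximate full-dimensional ones in expectation.
approx_dot = (sketches.scaled_grad * sketches.noise).sum()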
- load_state_dict(state_dict: dict[str, Any]) None
Load the optimizer state.
- Parameters:
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
Note
The names of the parameters (if they exist under the "param_names" key of each param group in state_dict()) will not affect the loading process. To use the parameters' names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in loaded state dict param_groups they will be saved and override the current names, if present, in the optimizer state. If they do not exist in loaded state dict, the optimizer param_names will remain unchanged.
- register_load_state_dict_post_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle
Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used.
The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_load_state_dict_pre_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle
Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict inplace or optionally return a new one. If a state_dict is returned, it will be used to be loaded into the optimizer.
The hook will be called with arguments self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_state_dict_post_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle
Register a state dict post-hook which will be called after state_dict() is called.
It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict inplace or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_state_dict_pre_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle
Register a state dict pre-hook which will be called before state_dict() is called.
It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_step_post_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], None]) RemovableHandle
Register an optimizer step post hook which will be called after optimizer step.
It should have the following signature:
hook(optimizer, args, kwargs) -> None
The optimizer argument is the optimizer instance being used.
- Parameters:
hook (Callable) – The user defined hook to be registered.
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_step_pre_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], tuple[tuple[Any, ...], dict[str, Any]] | None]) RemovableHandle
Register an optimizer step pre hook which will be called before optimizer step.
It should have the following signature:
hook(optimizer, args, kwargs) -> None or modified args and kwargs
The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.
- Parameters:
hook (Callable) – The user defined hook to be registered.
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- state_dict() dict[str, Any]
Return the state of the optimizer as a dict.
It contains two entries:
- state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.
- param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.
NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.
A returned state dict might look something like:
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
            'param_names': ['param0'] (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
- step(noise_generator: Generator | None = None) None[source]
Perform a single SGLD optimization step.
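A reproducibility sketch: passing an explicit generator makes the injected Gaussian noise deterministic (the generator device is an assumption and should match the parameters' device):

gen = torch.Generator(device="cpu").manual_seed(42)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step(noise_generator=gen)      # noise is drawn from `gen`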
- zero_grad(set_to_none: bool = True) None
Reset the gradients of all optimized torch.Tensors.
- Parameters:
set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
devinterp.optim.sgmcmc module
- class devinterp.optim.sgmcmc.SGMCMC(params, *, lr: float = 0.01, noise_level: float = 1.0, nbeta: float = 1.0, prior: Prior | Literal['initial'] | Iterable[Tensor] | float | None = None, preconditioner: Preconditioner | Literal['identity', 'rmsprop'] | None = 'identity', preconditioner_kwargs: dict | None = None, bounding_box_size: float | None = None, mask: Tensor | None = None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]
Bases: Optimizer
Unified Stochastic Gradient Markov Chain Monte Carlo (SGMCMC) optimizer.
This optimizer implements a general SGMCMC framework that unifies several common variants like SGLD, SGHMC, and SGNHT. It supports custom priors and preconditioners for flexible posterior sampling.
The general update rule follows:
\[\Delta\theta_t = \frac{\epsilon}{2} G(\theta_t)\left(\frac{n\beta}{m} \sum \nabla \log p(y \mid x, \theta_t) + \nabla \log p(\theta_t)\right) + \sqrt{\epsilon\, G(\theta_t)}\, N(0, \sigma^2)\]
where:
ε is learning rate
nβ is inverse temperature
m is batch size
G(θ) is the preconditioner
p(θ) is the prior distribution
σ is noise level
- Parameters:
params (Iterable) – Iterable of parameters to optimize or dicts defining parameter groups
lr (float) – Learning rate ε (default: 0.01)
noise_level (float) – Standard deviation σ of the noise (default: 1.0)
nbeta (float) – Inverse temperature nβ (default: 1.0)
prior (Optional[Union[Prior, Literal["initial"], Iterable[torch.Tensor], float]]) – Prior distribution specification. Can be a Prior instance (e.g. GaussianPrior), a string specifying the center type ("initial"), an iterable of tensor centers, or a float specifying the precision (default: None). In SGMCMC, weight_decay and localization should be implemented as a prior; see the rmsprop_sgld() method for an example.
preconditioner (Optional[Union[Preconditioner, str]]) – Preconditioner specification. Can be "identity" for no preconditioning, "rmsprop" for RMSprop-style preconditioning, or a Preconditioner instance (default: None, equivalent to "identity")
preconditioner_kwargs (Optional[dict]) – Additional keyword arguments for preconditioner
bounding_box_size (Optional[float]) – Size of bounding box around initial parameters
mask (Optional[torch.Tensor]) – Boolean mask for restricting updatable parameters
save_metrics (bool) – Whether to track metrics during training (default: False)
Example:
from devinterp.utils import default_nbeta

# Basic SGLD-style usage
optimizer = SGMCMC.sgld(
    model.parameters(),
    lr=0.1,
    nbeta=default_nbeta(dataloader)
)

# RMSprop-preconditioned with prior
optimizer = SGMCMC.rmsprop_sgld(
    model.parameters(),
    lr=0.01,
    localization=0.1,
    nbeta=default_nbeta(dataloader)
)

# SGNHT-style with thermostat
optimizer = SGMCMC.sgnht(
    model.parameters(),
    lr=0.01,
    diffusion_factor=0.01,
    nbeta=default_nbeta(dataloader)
)

# Training loop
for data, target in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
Notes
Use the factory methods (sgld, rmsprop_sgld, sgnht) for easier initialization
nbeta should typically be set using devinterp.utils.default_nbeta() rather than manually
The prior helps explore the local posterior by pulling toward initialization
Use save_metrics=True and call get_metrics() to access tracked metrics
References
Welling & Teh (2011) - Original SGLD paper
Li et al. (2015) - RMSprop-SGLD
Ding et al. (2014) - SGNHT
Lau et al. (2023) - Implementation with localization term
- add_param_group(param_group: dict[str, Any]) None
Add a param group to the Optimizer's param_groups.
This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.
- Parameters:
param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
- get_sketches() SketchBuffer[source]
Return a CPU snapshot of the current sketch buffer.
- load_state_dict(state_dict: dict[str, Any]) None
Load the optimizer state.
- Parameters:
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
Note
The names of the parameters (if they exist under the "param_names" key of each param group in state_dict()) will not affect the loading process. To use the parameters' names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in loaded state dict param_groups they will be saved and override the current names, if present, in the optimizer state. If they do not exist in loaded state dict, the optimizer param_names will remain unchanged.
- register_load_state_dict_post_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle
Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used.
The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_load_state_dict_pre_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle
Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict inplace or optionally return a new one. If a state_dict is returned, it will be used to be loaded into the optimizer.
The hook will be called with arguments self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_state_dict_post_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle
Register a state dict post-hook which will be called after state_dict() is called.
It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict inplace or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_state_dict_pre_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle
Register a state dict pre-hook which will be called before state_dict() is called.
It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_step_post_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], None]) RemovableHandle
Register an optimizer step post hook which will be called after optimizer step.
It should have the following signature:
hook(optimizer, args, kwargs) -> None
The optimizer argument is the optimizer instance being used.
- Parameters:
hook (Callable) – The user defined hook to be registered.
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_step_pre_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], tuple[tuple[Any, ...], dict[str, Any]] | None]) RemovableHandle
Register an optimizer step pre hook which will be called before optimizer step.
It should have the following signature:
hook(optimizer, args, kwargs) -> None or modified args and kwargs
The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.
- Parameters:
hook (Callable) – The user defined hook to be registered.
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- classmethod rmsprop_sgld(params, lr=0.01, noise_level=1.0, weight_decay=0.0, localization=0.0, nbeta=1.0, alpha=0.99, eps=0.1, add_grad_correction=False, bounding_box_size=None, mask=None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]
Factory method to create an SGMCMC instance that wraps RMSprop’s adaptive preconditioning with SGLD to perform Bayesian sampling of neural network weights.
The update rule with preconditioning follows:
\[V(\theta_t) = \alpha V(\theta_{t-1}) + (1-\alpha)\,\bar{g}(\theta_t)\,\bar{g}(\theta_t)\]
\[G(\theta_t) = \mathrm{diag}\!\left(\frac{1}{\lambda \mathbf{1} + \sqrt{V(\theta_t)}}\right)\]
\[\Delta\theta_t = \frac{\epsilon}{2} G(\theta_t)\left(\frac{n\beta}{m} \sum \nabla \log p(y \mid x, \theta_t) + \gamma(\theta_0 - \theta_t) - \lambda\theta_t\right) + \sqrt{\epsilon\, G(\theta_t)}\, N(0, \sigma^2)\]
where:
ε is learning rate
nβ is effective dataset size (=dataset size * inverse temperature)
m is batch size
γ is localization strength
λ is weight decay
σ is noise level
G(θ) is the RMSprop preconditioner
V(θ) tracks squared gradient moving average
α is the exponential decay rate
Key differences from standard SGLD:
Uses RMSprop preconditioner to adapt to local geometry and curvature
Scales both the gradients and noise by the preconditioner
Handles pathological curvature through adaptive step sizes
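A standalone numerical sketch of the diagonal preconditioner in the update rule above (illustrative only; the library's own implementation lives in RMSpropPreconditioner):

import torch

alpha, eps = 0.99, 0.1                       # defaults of this factory method
grad = torch.randn(10)
V = torch.zeros(10)

V = alpha * V + (1 - alpha) * grad * grad    # moving average of squared gradients
G = 1.0 / (eps + V.sqrt())                   # diagonal entries of G(θ)
preconditioned_grad = G * grad               # the same G also scales the noise term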
- Parameters:
params – Iterable of parameters to optimize
lr – Learning rate (default: 0.01)
noise_level – Standard deviation of noise (default: 1.0)
weight_decay – Weight decay factor. Applied with preconditioning (Adam-style). Creates a GaussianPrior centered at zero with localization=weight_decay. (default: 0.0)
localization – Strength of pull toward initial parameters. Creates a GaussianPrior centered at initialization with localization=localization. (default: 0.0)
nbeta – Inverse temperature (default: 1.0)
alpha – RMSprop moving average coefficient (default: 0.99)
eps – RMSprop stability constant (default: 0.1)
add_grad_correction – Whether to add gradient correction term (default: False)
bounding_box_size – Size of bounding box around initial parameters (default: None)
mask – Boolean mask for restricting updatable parameters (default: None)
save_metrics – Whether to track metrics during training (default: False)
- Returns:
SGMCMC optimizer instance
- classmethod sgld(params, lr=0.01, noise_level=1.0, weight_decay=0.0, localization=0.0, nbeta=1.0, bounding_box_size=None, mask=None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]
Factory method to create an SGMCMC instance that implements Stochastic Gradient Langevin Dynamics (SGLD) with a localization term (Lau et al. 2023).
This optimizer combines Stochastic Gradient Descent (SGD) with Langevin Dynamics, introducing Gaussian noise to the gradient updates. This makes it sample weights from the posterior distribution, instead of finding point estimates through optimization (Welling and Teh 2011).
The update rule follows:
\[\Delta\theta_t = \frac{\epsilon}{2}\left(\frac{n\beta}{m} \sum \nabla \log p(y \mid x, \theta_t) + \gamma(\theta_0 - \theta_t) - \lambda\theta_t\right) + N(0, \epsilon\sigma^2)\]
where:
ε is learning rate
nβ is inverse temperature
m is batch size
γ is localization strength
λ is weight decay
σ is noise level
This follows Lau et al.’s (2023) implementation, which modifies Welling and Teh (2011) by:
Omitting the learning rate schedule (this functionality could be recovered by using a separate learning rate scheduler).
Adding a localization term that pulls weights toward initialization
Using tempered Bayes paradigm with inverse temperature nβ
This allows SGMCMC to be used as a drop-in replacement for SGLD.
- Parameters:
params – Iterable of parameters to optimize
lr – Learning rate (default: 0.01)
noise_level – Standard deviation of noise (default: 1.0)
weight_decay – Weight decay factor. Applied with preconditioning (Adam-style). Creates a GaussianPrior centered at zero with localization=weight_decay. (default: 0.0)
localization – Strength of pull toward initial parameters. Creates a GaussianPrior centered at initialization with localization=localization. (default: 0.0)
nbeta – Inverse temperature (default: 1.0)
bounding_box_size – Size of bounding box around initial parameters (default: None)
mask – Boolean mask for restricting updatable parameters (default: None)
save_metrics – Whether to track metrics during training (default: False)
- Returns:
SGMCMC optimizer instance
- classmethod sgnht(params, lr=0.01, diffusion_factor=0.01, nbeta=1.0, bounding_box_size=None, save_metrics: bool = False, sketch_dim: int | None = None, sketch_seed: int = 0)[source]
Factory method to create an SGMCMC instance that matches SGNHT’s interface.
This allows SGMCMC to be used as a drop-in replacement for SGNHT.
- Parameters:
params – Iterable of parameters to optimize
lr – Learning rate (default: 0.01)
diffusion_factor – Diffusion factor (default: 0.01)
nbeta – Inverse temperature (default: 1.0)
bounding_box_size – Size of bounding box around initial parameters (default: None)
save_metrics – Whether to track metrics (default: False)
- Returns:
SGMCMC optimizer instance
- state_dict() dict[str, Any]
Return the state of the optimizer as a dict.
It contains two entries:
- state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.
- param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.
NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.
A returned state dict might look something like:
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
            'param_names': ['param0'] (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
- step(closure=None, noise_generator: Generator | None = None) None[source]
Perform a single optimization step.
- Parameters:
closure (callable, optional) – A closure that reevaluates the model and returns the loss.
noise_generator (torch.Generator, optional) – Generator for reproducible noise.
- Returns:
The loss value if closure is provided, else None.
- Return type:
Optional[float]
- zero_grad(set_to_none: bool = True) None
Reset the gradients of all optimized torch.Tensors.
- Parameters:
set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
devinterp.optim.sgnht module
- class devinterp.optim.sgnht.SGNHT(params, lr=0.01, diffusion_factor=0.01, bounding_box_size=None, save_noise=False, save_mala_vars=False, nbeta=1.0, metrics: list[str] | None = None)[source]
Bases: Optimizer
Implements the Stochastic Gradient Nosé-Hoover Thermostat (SGNHT) optimizer. This optimizer blends SGD with an adaptive thermostat variable to control the magnitude of the injected noise, maintaining the kinetic energy of the system.
It follows Ding et al.’s (2014) implementation.
The equations for the update are as follows:
\[\Delta w_t = \epsilon\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, w_t\right) - \xi_t w_t \right) + \sqrt{2A}\, N(0, \epsilon)\]
\[\Delta\xi_{t} = \epsilon \left( \frac{1}{n} \|w_t\|^2 - 1 \right)\]
where \(w_t\) is the weight at time \(t\), \(\epsilon\) is the learning rate, \((\beta n)\) is the inverse temperature (we're in the tempered Bayes paradigm), \(n\) is the number of samples, \(m\) is the batch size, \(\xi_t\) is the thermostat variable at time \(t\), \(A\) is the diffusion factor, and \(N(0, \epsilon)\) represents Gaussian noise with mean 0 and variance \(\epsilon\).
Note
- diffusion_factor is unique to this class, and functions as a way to allow for random parameter changes while keeping them from blowing up by guiding parameters back to a slowly-changing thermostat value using a friction term.
- This class does not have an explicit localization term like SGLD() does. If you want to constrain your sampling, use bounding_box_size.
- Although this class is a subclass of torch.optim.Optimizer, this is a bit of a misnomer in this case. It's not used for optimizing in LLC estimation, but rather for sampling from the posterior distribution around a point.
- Parameters:
params (Iterable) – Iterable of parameters to optimize or dicts defining parameter groups. Either model.parameters() or something more fancy, just like other torch.optim.Optimizer classes.
lr (float, optional) – Learning rate \(\epsilon\). Default is 0.01
diffusion_factor (float, optional) – The diffusion factor \(A\) of the thermostat. Default is 0.01
bounding_box_size (float, optional) – the size of the bounding box enclosing our trajectory. Default is None
nbeta (float, optional) – Effective inverse temperature. Default is 1.0; typically set to utils.default_nbeta(dataloader)
- Raises:
Warning – if
nbetais set to 1Warning – if
NoiseNormcallback is usedWarning – if
MALAcallback is used
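A minimal usage sketch, mirroring the SGLD example above (the default_nbeta import path and the placeholder names are assumptions):

from devinterp.utils import default_nbeta

optimizer = SGNHT(model.parameters(), lr=0.01, diffusion_factor=0.01,
                  nbeta=default_nbeta(dataloader))
optimizer.zero_grad()
loss_fn(model(data), target).backward()
optimizer.step()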
- add_param_group(param_group: dict[str, Any]) None
Add a param group to the Optimizer's param_groups.
This can be useful when fine tuning a pre-trained network as frozen layers can be made trainable and added to the Optimizer as training progresses.
- Parameters:
param_group (dict) – Specifies what Tensors should be optimized along with group specific optimization options.
- load_state_dict(state_dict: dict[str, Any]) None
Load the optimizer state.
- Parameters:
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
Note
The names of the parameters (if they exist under the "param_names" key of each param group in state_dict()) will not affect the loading process. To use the parameters' names for custom cases (such as when the parameters in the loaded state dict differ from those initialized in the optimizer), a custom register_load_state_dict_pre_hook should be implemented to adapt the loaded dict accordingly. If param_names exist in loaded state dict param_groups they will be saved and override the current names, if present, in the optimizer state. If they do not exist in loaded state dict, the optimizer param_names will remain unchanged.
- register_load_state_dict_post_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle
Register a load_state_dict post-hook which will be called after load_state_dict() is called. It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used.
The hook will be called with argument self after calling load_state_dict on self. The registered hook can be used to perform post-processing after load_state_dict has loaded the state_dict.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_load_state_dict_pre_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle
Register a load_state_dict pre-hook which will be called before load_state_dict() is called. It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The optimizer argument is the optimizer instance being used and the state_dict argument is a shallow copy of the state_dict the user passed in to load_state_dict. The hook may modify the state_dict inplace or optionally return a new one. If a state_dict is returned, it will be used to be loaded into the optimizer.
The hook will be called with arguments self and state_dict before calling load_state_dict on self. The registered hook can be used to perform pre-processing before the load_state_dict call is made.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on load_state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_state_dict_post_hook(hook: Callable[[Optimizer, dict[str, Any]], dict[str, Any] | None], prepend: bool = False) RemovableHandle
Register a state dict post-hook which will be called after state_dict() is called.
It should have the following signature:
hook(optimizer, state_dict) -> state_dict or None
The hook will be called with arguments self and state_dict after generating a state_dict on self. The hook may modify the state_dict inplace or optionally return a new one. The registered hook can be used to perform post-processing on the state_dict before it is returned.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided post hook will be fired before all the already registered post-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered post-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_state_dict_pre_hook(hook: Callable[[Optimizer], None], prepend: bool = False) RemovableHandle
Register a state dict pre-hook which will be called before state_dict() is called.
It should have the following signature:
hook(optimizer) -> None
The optimizer argument is the optimizer instance being used. The hook will be called with argument self before calling state_dict on self. The registered hook can be used to perform pre-processing before the state_dict call is made.
- Parameters:
hook (Callable) – The user defined hook to be registered.
prepend (bool) – If True, the provided pre hook will be fired before all the already registered pre-hooks on state_dict. Otherwise, the provided hook will be fired after all the already registered pre-hooks. (default: False)
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_step_post_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], None]) RemovableHandle
Register an optimizer step post hook which will be called after optimizer step.
It should have the following signature:
hook(optimizer, args, kwargs) -> None
The optimizer argument is the optimizer instance being used.
- Parameters:
hook (Callable) – The user defined hook to be registered.
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- register_step_pre_hook(hook: Callable[[Self, tuple[Any, ...], dict[str, Any]], tuple[tuple[Any, ...], dict[str, Any]] | None]) RemovableHandle
Register an optimizer step pre hook which will be called before optimizer step.
It should have the following signature:
hook(optimizer, args, kwargs) -> None or modified args and kwargs
The optimizer argument is the optimizer instance being used. If args and kwargs are modified by the pre-hook, then the transformed values are returned as a tuple containing the new_args and new_kwargs.
- Parameters:
hook (Callable) – The user defined hook to be registered.
- Returns:
a handle that can be used to remove the added hook by calling handle.remove()
- Return type:
torch.utils.hooks.RemovableHandle
- state_dict() dict[str, Any]
Return the state of the optimizer as a dict.
It contains two entries:
- state: a Dict holding current optimization state. Its content differs between optimizer classes, but some common characteristics hold. For example, state is saved per parameter, and the parameter itself is NOT saved. state is a Dictionary mapping parameter ids to a Dict with state corresponding to each parameter.
- param_groups: a List containing all parameter groups where each parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. If a param group was initialized with named_parameters() the names content will also be saved in the state dict.
NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, the optimizer will zip the param_group params (int IDs) and the optimizer param_groups (actual nn.Parameters) in order to match state WITHOUT additional verification.
A returned state dict might look something like:
{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
            'param_names': ['param0'] (optional)
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
            'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional)
        }
    ]
}
- zero_grad(set_to_none: bool = True) None
Reset the gradients of all optimized torch.Tensors.
- Parameters:
set_to_none (bool) – instead of setting to zero, set the grads to None. This will in general have lower memory footprint, and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute or a Tensor full of 0s will behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 3. torch.optim optimizers have a different behavior if the gradient is 0 or None (in one case it does the step with a gradient of 0 and in the other it skips the step altogether).
devinterp.optim.sketch module
- class devinterp.optim.sketch.CountSketch(hash_indices: Tensor, hash_signs: Tensor, _output_dim: int)[source]
Bases: object
Count sketch projection for dimensionality reduction.
Projects d-dimensional vectors to k-dimensional sketches while preserving inner products in expectation: E[<Sv, Sw>] = <v, w>.
The sketch is defined by a hash function h: [d] -> [k] (mapping each input coordinate to an output bucket) and a sign function s: [d] -> {-1, +1} (random per coordinate). The sketch of vector v is:
S(v)[j] = sum_{i : h(i) = j} s(i) * v[i]
Both h and s are generated deterministically from a seed. Two sketch vectors are only comparable when produced by the same CountSketch instance (same seed, same input_dim). When used with an optimizer, input_dim is the total trainable parameter count, so different weight restrictions yield incomparable sketches even with the same seed.
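An illustrative construction of the hash and sign tables, plus a check of the inner-product property (standalone code with hypothetical variable names, not the class's own constructor):

import torch

d, k, seed = 1000, 64, 0
g = torch.Generator().manual_seed(seed)
hash_indices = torch.randint(0, k, (d,), generator=g)          # h: [d] -> [k]
hash_signs = torch.randint(0, 2, (d,), generator=g) * 2 - 1    # s: [d] -> {-1, +1}

def sketch(v):
    out = torch.zeros(k, dtype=v.dtype)
    out.scatter_add_(0, hash_indices, hash_signs.to(v.dtype) * v)
    return out

v, w = torch.randn(d), torch.randn(d)
# <S(v), S(w)> approximates <v, w>; equality holds in expectation over (h, s).
print(sketch(v) @ sketch(w), v @ w)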
This single-row construction is equivalent to what Weinberger et al. call “feature hashing”. The inner product preservation property is proved there.
References
Weinberger, Dasgupta, Langford, Smola & Attenberg, “Feature Hashing for Large Scale Multitask Learning” (ICML 2009), https://doi.org/10.1145/1553374.1553516
Charikar, Chen & Farach-Colton, “Finding Frequent Items in Data Streams” (ICALP 2002), https://doi.org/10.1007/3-540-45465-9_59
Larsen, Pagh & Tetek, “CountSketches, Feature Hashing and the Median of Three” (ICML 2021), https://doi.org/10.48550/arXiv.2102.02193
- hash_indices: Tensor
- hash_signs: Tensor
- class devinterp.optim.sketch.SketchBuffer(scaled_grad: Tensor, unscaled_grad: Tensor, localization: Tensor, weight_decay: Tensor, noise: Tensor)[source]
Bases: object
Per-step accumulation buffers for count sketch metrics.
One buffer per tracked quantity. Lifecycle: zero_() at step start -> accumulate per-param via scatter_into_ -> read.
- QUANTITIES: ClassVar[tuple[str, ...]] = ('scaled_grad', 'unscaled_grad', 'localization', 'weight_decay', 'noise')
- localization: Tensor
- noise: Tensor
- scaled_grad: Tensor
- unscaled_grad: Tensor
- weight_decay: Tensor