import warnings
from typing import Callable, Iterable, Iterator, NamedTuple, Optional, Union
import torch
from .metrics import Metrics
from .sketch import CountSketch, SketchBuffer
class _ComponentVectors(NamedTuple):
"""Post-masked component vectors from a single SGLD parameter update."""
# Duplicated from SGMCMC deliberately: SGLD is deprecated and will be
# removed, so coupling it to the replacement via a shared type would make
# that cleanup harder. Also the semantics differ slightly in the way masks
# are applied.
scaled_grad: torch.Tensor
unscaled_grad: torch.Tensor
localization: torch.Tensor
weight_decay: torch.Tensor
noise: torch.Tensor
[docs]
class SGLD(torch.optim.Optimizer):
r"""
Implements Stochastic Gradient Langevin Dynamics (SGLD) optimizer.
This optimizer blends Stochastic Gradient Descent (SGD) with Langevin Dynamics,
introducing Gaussian noise to the gradient updates. This makes it sample weights from the posterior distribution, instead of optimizing weights.
This implementation follows Lau et al.'s (2023) implementation, which is a modification of
Welling and Teh (2011) that omits the learning rate schedule and introduces
an localization term that pulls the weights towards their initial values.
The equation for the update is as follows:
.. math::
\Delta w_t = \frac{\epsilon}{2}\left(\frac{\beta n}{m} \sum_{i=1}^m \nabla \log p\left(y_{l_i} \mid x_{l_i}, w_t\right)+\gamma\left(w_0-w_t\right) - \lambda w_t\right) + N(0, \epsilon\sigma^2)
where :math:`w_t` is the weight at time :math:`t`, :math:`\epsilon` is the learning rate,
:math:`(\beta n)` is the inverse temperature (we're in the tempered Bayes paradigm),
:math:`n` is the number of training samples, :math:`m` is the batch size, :math:`\gamma` is
the localization strength, :math:`\lambda` is the weight decay strength,
and :math:`\sigma` is the noise term.
Example:
>>> optimizer = SGLD(model.parameters(), lr=0.1, nbeta=utils.default_nbeta(dataloader))
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
.. |colab6| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/timaeus-research/devinterp/blob/main/examples/sgld_calibration.ipynb
Note:
- ``localization`` is unique to this class and serves to guide the weights towards their original values. This is useful for estimating quantities over the local posterior.
- ``noise_level`` is not intended to be changed, except when testing! Doing so will raise a warning.
- Although this class is a subclass of ``torch.optim.Optimizer``, this is a bit of a misnomer in this case. It's not used for optimizing in LLC estimation, but rather for sampling from the posterior distribution around a point.
- Hyperparameter optimization is more of an art than a science. Check out `the calibration notebook <https://www.github.com/timaeus-research/devinterp/blob/main/examples/sgld_calibration.ipynb>`_ |colab6| for how to go about it in a simple case.
:param params: Iterable of parameters to optimize or dicts defining parameter groups. Either ``model.parameters()`` or something more fancy, just like other ``torch.optim.Optimizer`` classes.
:type params: Iterable
:param lr: Learning rate :math:`\epsilon`. Default is 0.01
:type lr: float, optional
:param noise_level: Amount of Gaussian noise :math:`\sigma` introduced into gradient updates. Don't change this unless you know very well what you're doing! Default is 1
:type noise_level: float, optional
:param weight_decay: L2 regularization term :math:`\lambda`, applied as weight decay. Default is 0
:type weight_decay: float, optional
:param localization: Strength of the force :math:`\gamma` pulling weights back to their initial values. Default is 0
:type localization: float, optional
:param nbeta: Inverse reparameterized temperature (otherwise known as n*beta or ~beta), float (default: 1., set to utils.default_nbeta(dataloader)=len(batch_size)/np.log(len(batch_size)))
:type nbeta: float or Callable, optional
:param bounding_box_size: the size of the bounding box enclosing our trajectory in parameter space. Default is None, in which case no bounding box is used.
:type bounding_box_size: float, optional
:param save_metrics: Whether to track metrics (scaled_grad, localization, weight_decay, noise norms) during optimization. Use :meth:`get_metrics` to retrieve. Default is False
:type save_metrics: bool, optional
:raises Warning: if ``noise_level`` is set to anything other than 1
:raises Warning: if ``nbeta`` is set to 1
"""
def __init__(
self,
params: Iterable[torch.nn.Parameter],
*,
lr: float = 0.01,
noise_level: float = 1.0,
weight_decay: float = 0.0,
localization: float = 0.0,
nbeta: Union[Callable[[], float], float] = 1.0,
bounding_box_size: Optional[float] = None,
save_metrics: bool = False,
sketch_dim: Optional[int] = None,
sketch_seed: int = 0,
):
warnings.warn(
"SGLD has been deprecated. Please use SGMCMC.sgld instead.",
DeprecationWarning,
)
self.save_metrics = save_metrics
self.save_sketches = sketch_dim is not None
if noise_level != 1.0:
warnings.warn(
"Warning: noise_level in SGLD is unequal to one, this removes SGLD posterior sampling guarantees."
)
if nbeta == 1.0:
warnings.warn(
"Warning: nbeta set to 1, LLC estimates will be off unless you know what you're doing. Use utils.default_nbeta(dataloader) instead"
)
defaults = dict(
lr=lr,
noise_level=noise_level,
weight_decay=weight_decay,
localization=localization,
nbeta=nbeta,
bounding_box_size=bounding_box_size,
)
# In torch.optim.Optimizer, the parameters are stored in a list of dictionaries.
# defaults holds the default values for the optimizer parameters.
super(SGLD, self).__init__(params, defaults)
# Save the initial parameters if the localization term is set
for group in self.param_groups:
group["num_el"] = 0
# Validate mask shape if present
if group.get("mask") is not None:
for p in group["params"]:
mask = group["mask"]
if isinstance(mask, torch.Tensor) and mask.shape != p.shape:
raise ValueError(
f"Mask shape {mask.shape} does not match parameter shape {p.shape}. "
"Scalar masks are not supported."
)
store_initial = (
group["localization"] != 0
or group["bounding_box_size"] != 0
or self.save_metrics
or self.save_sketches
)
if store_initial:
for p in group["params"]:
param_state = self.state[p]
param_state["initial_param"] = p.data.clone().detach()
group["num_el"] += p.numel()
if self.save_metrics:
device = next(iter(group["params"])).device
group["metrics"] = Metrics().to(device)
if self.save_sketches:
assert sketch_dim is not None
total_numel = sum(
p.numel() for group in self.param_groups for p in group["params"]
)
device = next(iter(self.param_groups[0]["params"])).device
self._sketch = CountSketch.create(total_numel, sketch_dim, sketch_seed).to(
device
)
self._sketch_buf = SketchBuffer.create(sketch_dim, device)
else:
self._sketch = None
self._sketch_buf = None
[docs]
def step(self, noise_generator: Optional[torch.Generator] = None) -> None:
"""
Perform a single SGLD optimization step.
"""
with torch.no_grad():
_need_decomposition = self.save_metrics or self.save_sketches
param_offset = 0
if self.save_sketches:
assert self._sketch_buf is not None
self._sketch_buf.zero_()
for group_idx, group in enumerate(self.param_groups):
# Metrics lifecycle: zero → accumulate per-param → sqrt (see Metrics docstring)
if self.save_metrics:
group["metrics"].zero_()
for p in group["params"]:
param_state = self.state[p]
# Gradients are None if the parameter is not trainable
# We'll denote the gradient of the loss with respect to this param group (p) as dw
if p.grad is None:
dw = torch.zeros_like(p.data)
else:
dw = p.grad.data * group["nbeta"]
# Weight decay
if group["weight_decay"] != 0:
dw.add_(
p.data, alpha=group["weight_decay"]
) # inplace addition. Effectively, dw = dw + p.data * group["weight_decay"]
# Here, group["localization"] is the localization strength $\gamma$ (a single float). If it's 0, we don't do anything.
initial_param = self.state[p]["initial_param"]
initial_param_distance = p.data - initial_param
if group["localization"] != 0:
dw.add_(initial_param_distance, alpha=group["localization"])
# Add Gaussian noise
noise = torch.normal(
mean=0.0,
std=group["noise_level"],
size=dw.size(),
device=dw.device,
generator=noise_generator,
)
if group.get("mask") is not None:
# Restrict the noise and gradient to the subset of parameters we're optimizing over.
dw = dw * group["mask"]
noise = noise * group["mask"]
if _need_decomposition:
components = self._decompose_update(
group, p, dw, initial_param_distance, noise
)
if self.save_metrics:
self._accumulate_metrics(
group, p, components, initial_param_distance
)
if self.save_sketches:
self._scatter_sketches(components, param_offset)
# Update parameters
p.data.add_(dw, alpha=-0.5 * group["lr"])
p.data.add_(
noise, alpha=group["lr"] ** 0.5
) # Scale noise by sqrt(lr)
if self.save_sketches:
param_offset += p.numel()
# Rebound if exceeded bounding box size
if group["bounding_box_size"]:
torch.clamp_(
p.data,
min=param_state["initial_param"]
- group["bounding_box_size"],
max=param_state["initial_param"]
+ group["bounding_box_size"],
)
if self.save_metrics:
# All params accumulated; convert sum-of-squares to L2 norms
group["metrics"].sqrt_norms_()
def _decompose_update(
self,
group: dict,
p: torch.Tensor,
dw: torch.Tensor,
initial_param_distance: torch.Tensor,
noise: torch.Tensor,
) -> _ComponentVectors:
"""Decompose an SGLD parameter update into post-masked component vectors.
Shared by both metrics accumulation and sketch scattering. SGLD has no
preconditioner, so unscaled_grad == scaled_grad. Debug assertions verify
the decomposition matches the actual update.
"""
_lr = group["lr"]
_mask = group.get("mask")
_half_lr = 0.5 * _lr
# Multiplication order must match step() to avoid bfloat16 associativity errors:
# step() builds dw as: (p.grad * nbeta) + (p.data * wd) + (dist * loc)
raw_grad = p.grad.data if p.grad is not None else torch.zeros_like(p.data)
_scaled_grad = _half_lr * raw_grad.mul(group["nbeta"])
_noise_vec = noise * (_lr**0.5)
_loc = _half_lr * initial_param_distance.mul(group["localization"])
_wd = _half_lr * p.data.mul(group["weight_decay"])
if _mask is not None:
_scaled_grad = _scaled_grad * _mask
_loc = _loc * _mask
_wd = _wd * _mask
_noise_vec = _noise_vec * _mask
# Sanity-check: reconstructed components must sum to the actual update.
# step() builds dw via in-place add_ (which may use FMA on GPU),
# then we multiply by half_lr once. Here we multiply half_lr into
# each component separately and sum — distributing the scalar
# breaks IEEE 754 associativity. The rounding gap scales with the
# component magnitudes and the dtype's epsilon (significant for
# bfloat16 models where eps ≈ 0.004). Using the sum of absolute
# values as the scale handles catastrophic cancellation, where
# large components nearly cancel leaving a small residual whose
# relative error would otherwise look enormous.
if __debug__:
_step_eps = torch.finfo(dw.dtype).eps
_decomp_scale = float((_scaled_grad.abs() + _loc.abs() + _wd.abs()).max())
_decomp_atol = max(_decomp_scale * _step_eps * 16, 1e-12)
torch.testing.assert_close(
(_scaled_grad + _loc + _wd).float(),
(_half_lr * dw).float(),
atol=_decomp_atol,
rtol=0,
msg=lambda s: f"Decomposition components don't match gradient update:\n{s}",
)
return _ComponentVectors(
scaled_grad=_scaled_grad,
unscaled_grad=_scaled_grad,
localization=_loc,
weight_decay=_wd,
noise=_noise_vec,
)
def _accumulate_metrics(
self,
group: dict,
p: torch.Tensor,
components: _ComponentVectors,
initial_param_distance: torch.Tensor,
) -> None:
"""Accumulate decomposed component vectors into per-group metrics."""
_mask = group.get("mask")
if _mask is not None:
initial_param_distance = initial_param_distance * _mask
numel = p[_mask.bool()].numel()
else:
numel = p.numel()
group["metrics"].add_sum_squared_(
scaled_grad=components.scaled_grad,
unscaled_grad=components.unscaled_grad,
localization=components.localization,
weight_decay=components.weight_decay,
noise=components.noise,
distance=initial_param_distance,
)
group["metrics"].add_dot_products_(
scaled_grad=components.scaled_grad,
prior=components.localization + components.weight_decay,
noise=components.noise,
)
group["metrics"].numel += numel
def _scatter_sketches(self, components: _ComponentVectors, offset: int) -> None:
"""Scatter decomposed component vectors into the sketch buffer."""
buf = self._sketch_buf
sketch = self._sketch
assert buf is not None and sketch is not None
sketch.scatter_into_(buf.scaled_grad, components.scaled_grad, offset)
sketch.scatter_into_(buf.unscaled_grad, components.unscaled_grad, offset)
sketch.scatter_into_(buf.localization, components.localization, offset)
sketch.scatter_into_(buf.weight_decay, components.weight_decay, offset)
sketch.scatter_into_(buf.noise, components.noise, offset)
[docs]
def get_sketches(self) -> SketchBuffer:
"""Return a CPU snapshot of the current sketch buffer."""
if not self.save_sketches:
raise RuntimeError("Sketches not enabled.")
assert self._sketch_buf is not None
return SketchBuffer(
**{
q: getattr(self._sketch_buf, q).detach().cpu().clone()
for q in SketchBuffer.QUANTITIES
}
)
[docs]
def iter_group_metrics(self) -> Iterator[Metrics]:
"""Yield metrics for each param group."""
if not self.save_metrics:
raise RuntimeError("Metrics not enabled. Set save_metrics=True.")
for group in self.param_groups:
yield group["metrics"]
[docs]
def get_metrics(self) -> Metrics:
"""Aggregate metrics across all param groups into a single CPU Metrics."""
return Metrics.aggregate(self.iter_group_metrics())