# Source code for devinterp.optim.metrics

from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import ClassVar

import torch


def _zeros() -> torch.Tensor:
    return torch.zeros(1, dtype=torch.float32)


[docs] @dataclass class Metrics: """Norms and dot products of SGMCMC parameter update components. Each step, w += dw where dw = -(scaled_grad + prior) + noise. Norm fields store L2 norms of the post-preconditioned update components (i.e. actual magnitudes applied to parameters, not raw gradients). scaled_grad: (ε/2) · nβ · G · ∇L — preconditioned gradient unscaled_grad: (ε/2) · nβ · ∇L — raw gradient (no preconditioner) localization: (ε/2) · G · γ(w - w₀) — pull toward initial params weight_decay: (ε/2) · G · λw — L2 regularization noise: √ε · √G · η — stochastic exploration distance: w - w₀ — raw displacement from init Dot product fields store inner products between the three main component vectors (scaled_grad, combined prior, noise). dot_grad_prior: ⟨scaled_grad, localization + weight_decay⟩ dot_grad_noise: ⟨scaled_grad, noise⟩ dot_prior_noise: ⟨localization + weight_decay, noise⟩ Cosine similarities can be derived: cos = dot / (norm_a * norm_b). Lifecycle (one Metrics per optimizer param group, on the group's device): 1. __init__: group["metrics"] = Metrics().to(device) 2. step(), start: group["metrics"].zero_() 3. step(), per-p: group["metrics"].add_sum_squared_(...) group["metrics"].add_dot_products_(...) 4. step(), end: group["metrics"].sqrt_norms_() 5. get_metrics(): combine per-group metrics on CPU """ # TODO: Could these be annotations in the `field` declarations? 
NORM_FIELDS: ClassVar[tuple[str, ...]] = ( "scaled_grad", "unscaled_grad", "localization", "weight_decay", "noise", "distance", ) DOT_FIELDS: ClassVar[tuple[str, ...]] = ( "dot_grad_prior", "dot_grad_noise", "dot_prior_noise", ) scaled_grad: torch.Tensor = field(default_factory=_zeros) unscaled_grad: torch.Tensor = field(default_factory=_zeros) localization: torch.Tensor = field(default_factory=_zeros) weight_decay: torch.Tensor = field(default_factory=_zeros) noise: torch.Tensor = field(default_factory=_zeros) distance: torch.Tensor = field(default_factory=_zeros) dot_grad_prior: torch.Tensor = field(default_factory=_zeros) dot_grad_noise: torch.Tensor = field(default_factory=_zeros) dot_prior_noise: torch.Tensor = field(default_factory=_zeros) numel: int = 0 @property def prior(self) -> torch.Tensor: """Combined prior norm: ||[localization; weight_decay]||₂.""" return (self.localization.square() + self.weight_decay.square()).sqrt()
[docs] def to(self, device: str | torch.device | int) -> "Metrics": """Return a copy of these metrics on the specified device.""" return Metrics( scaled_grad=self.scaled_grad.to(device), unscaled_grad=self.unscaled_grad.to(device), localization=self.localization.to(device), weight_decay=self.weight_decay.to(device), noise=self.noise.to(device), distance=self.distance.to(device), dot_grad_prior=self.dot_grad_prior.to(device), dot_grad_noise=self.dot_grad_noise.to(device), dot_prior_noise=self.dot_prior_noise.to(device), numel=self.numel, )
[docs] def zero_(self) -> None: """Reset all metrics to zero in-place.""" self.scaled_grad.zero_() self.unscaled_grad.zero_() self.localization.zero_() self.weight_decay.zero_() self.noise.zero_() self.distance.zero_() self.dot_grad_prior.zero_() self.dot_grad_noise.zero_() self.dot_prior_noise.zero_() self.numel = 0
[docs] def add_sum_squared_( self, scaled_grad: torch.Tensor, unscaled_grad: torch.Tensor, localization: torch.Tensor, weight_decay: torch.Tensor, noise: torch.Tensor, distance: torch.Tensor, ) -> None: """Accumulate sum-of-squares for each norm component in-place. Casts to float32 before squaring to avoid precision loss with bf16/fp16 inputs (where squaring can overflow or underflow in the input dtype). """ self.scaled_grad += scaled_grad.float().square().sum() self.unscaled_grad += unscaled_grad.float().square().sum() self.localization += localization.float().square().sum() self.weight_decay += weight_decay.float().square().sum() self.noise += noise.float().square().sum() self.distance += distance.float().square().sum()
[docs] def add_dot_products_( self, scaled_grad: torch.Tensor, prior: torch.Tensor, noise: torch.Tensor, ) -> None: """Accumulate dot products between the three main component vectors. Args: scaled_grad: The preconditioned gradient vector. prior: Combined prior vector (localization + weight_decay). noise: The noise vector. """ sg = scaled_grad.float() p = prior.float() n = noise.float() self.dot_grad_prior += (sg * p).sum() self.dot_grad_noise += (sg * n).sum() self.dot_prior_noise += (p * n).sum()
[docs] def sqrt_norms_(self) -> None: """Convert norm fields from sum-of-squares to L2 norms in-place.""" self.scaled_grad = self.scaled_grad.sqrt() self.unscaled_grad = self.unscaled_grad.sqrt() self.localization = self.localization.sqrt() self.weight_decay = self.weight_decay.sqrt() self.noise = self.noise.sqrt() self.distance = self.distance.sqrt()
[docs] @staticmethod def aggregate(group_metrics: Iterable["Metrics"]) -> "Metrics": """Combine per-group metrics into a single Metrics on CPU. Norms: re-square to get sum-of-squares, accumulate, then sqrt: ||[a; b]|| = sqrt(||a||^2 + ||b||^2). Dot products: additive across disjoint parameter sets. """ result = Metrics() for m in group_metrics: m_cpu = m.to("cpu") result.scaled_grad += m_cpu.scaled_grad.square() result.unscaled_grad += m_cpu.unscaled_grad.square() result.localization += m_cpu.localization.square() result.weight_decay += m_cpu.weight_decay.square() result.noise += m_cpu.noise.square() result.distance += m_cpu.distance.square() result.dot_grad_prior += m_cpu.dot_grad_prior result.dot_grad_noise += m_cpu.dot_grad_noise result.dot_prior_noise += m_cpu.dot_prior_noise result.numel += m_cpu.numel result.sqrt_norms_() return result