Source code for devinterp.optim.preconditioner

from abc import ABC, abstractmethod
from functools import reduce
from typing import NamedTuple, Optional, Union

import torch


class PreconditionerCoefs(NamedTuple):
    """Coefficients returned by preconditioners"""

    grad_coef: Union[float, torch.Tensor]  # Coefficient for gradient terms
    prior_coef: Union[float, torch.Tensor]  # Coefficient for prior terms
    noise_coef: Union[float, torch.Tensor]  # Coefficient for noise terms
    overall_coef: Union[float, torch.Tensor]  # Overall scaling coefficient
    grad_correction: Optional[Union[float, torch.Tensor]] = (
        None  # Optional momentum-based correction; note that this is additive when multiplying PreconditionerCoefs
    )

    def combine_with(self, other: "PreconditionerCoefs") -> "PreconditionerCoefs":
        grad_correction = (
            self.grad_correction
            if self.grad_correction is not None
            else other.grad_correction
        )
        if self.grad_correction is not None and other.grad_correction is not None:
            grad_correction = self.grad_correction + other.grad_correction

        return PreconditionerCoefs(
            grad_coef=self.grad_coef * other.grad_coef,
            prior_coef=self.prior_coef * other.prior_coef,
            noise_coef=self.noise_coef * other.noise_coef,
            overall_coef=self.overall_coef * other.overall_coef,
            grad_correction=grad_correction,
        )

    @classmethod
    def combine(cls, *coefs: "PreconditionerCoefs") -> "PreconditionerCoefs":
        return reduce(lambda a, b: a.combine_with(b), coefs)

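For illustration, a minimal usage sketch of combining two coefficient sets (assuming this module is importable as devinterp.optim.preconditioner; the numeric values are arbitrary): the multiplicative fields are multiplied elementwise, while grad_correction is added.

    import torch
    from devinterp.optim.preconditioner import PreconditionerCoefs

    a = PreconditionerCoefs(grad_coef=2.0, prior_coef=1.0, noise_coef=0.5, overall_coef=1.0)
    b = PreconditionerCoefs(
        grad_coef=torch.tensor([1.0, 0.5]),
        prior_coef=1.0,
        noise_coef=1.0,
        overall_coef=0.1,
        grad_correction=torch.tensor([0.0, -0.2]),
    )
    combined = PreconditionerCoefs.combine(a, b)
    # combined.grad_coef == tensor([2.0, 1.0]), combined.overall_coef == 0.1,
    # combined.grad_correction == tensor([0.0, -0.2])  (the only non-None correction)
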
class Preconditioner(ABC):
    """Base class for preconditioners that generate coefficients for MCMC terms"""

    @abstractmethod
    def get_coefficients(
        self, param: torch.Tensor, grad: torch.Tensor, state: dict
    ) -> PreconditionerCoefs:
        """
        Compute coefficients for gradient, prior, and noise terms.

        Returns PreconditionerCoefs containing all coefficients.
        Each coefficient can be a scalar or a tensor of shape matching param.
        """
        pass

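Concrete preconditioners only need to override get_coefficients. As an illustration, a hypothetical subclass (ConstantScalePreconditioner is not part of the library) that uniformly rescales the whole update could look like this:

    from devinterp.optim.preconditioner import Preconditioner, PreconditionerCoefs

    class ConstantScalePreconditioner(Preconditioner):
        """Hypothetical example subclass: scales every term by a fixed constant."""

        def __init__(self, scale: float):
            self.scale = scale

        def get_coefficients(self, param, grad, state):
            # Scalars here; tensors broadcastable to param's shape are equally valid.
            return PreconditionerCoefs(1.0, 1.0, 1.0, self.scale, None)
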
class IdentityPreconditioner(Preconditioner):
    """Identity preconditioning (i.e., no preconditioning)"""

    def get_coefficients(
        self, param: torch.Tensor, grad: torch.Tensor, state: dict
    ) -> PreconditionerCoefs:
        return PreconditionerCoefs(1.0, 1.0, 1.0, 1.0, None)

    def __repr__(self) -> str:
        return "IdentityPreconditioner()"

class MaskPreconditioner(Preconditioner):
    """Applies masks to the overall coefficient while keeping other coefficients at 1.0.

    Stores one mask per parameter in the parameter group.
    """

    def __init__(self, masks: list[Union[torch.Tensor, float]]):
        """
        Args:
            masks: List of masks, one per parameter in the parameter group.
                Each mask can be either a tensor matching the parameter shape
                or a float that will be broadcast.
        """
        self.masks = masks

    def get_coefficients(
        self, param: torch.Tensor, grad: torch.Tensor, state: dict
    ) -> PreconditionerCoefs:
        # Get the parameter index from the state dict
        if "param_idx" not in state:
            raise ValueError("param_idx not found in state dict")
        param_idx = state["param_idx"]
        mask = self.masks[param_idx]

        return PreconditionerCoefs(
            grad_coef=1.0,
            prior_coef=1.0,
            noise_coef=1.0,
            overall_coef=mask,
            grad_correction=None,
        )

    def __repr__(self) -> str:
        return f"MaskPreconditioner(masks={self.masks})"

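A minimal usage sketch (imports assumed from this module): the caller puts the parameter's index into the state dict so that the matching mask is selected.

    import torch
    from devinterp.optim.preconditioner import MaskPreconditioner

    # One mask per parameter in the group: a tensor mask for the first parameter,
    # a broadcast float for the second.
    masks = [torch.tensor([1.0, 0.0, 1.0]), 0.0]
    precond = MaskPreconditioner(masks)

    param = torch.zeros(3)
    grad = torch.ones(3)
    coefs = precond.get_coefficients(param, grad, state={"param_idx": 0})
    # coefs.overall_coef is masks[0]; all other coefficients are 1.0
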
class CompositePreconditioner(Preconditioner):
    """Combines multiple preconditioners by multiplying their coefficients"""

    def __new__(cls, preconditioners: list[Preconditioner]):
        # Filter out identity preconditioners
        non_identity_preconditioners = [
            p for p in preconditioners if not isinstance(p, IdentityPreconditioner)
        ]

        # If no preconditioners are left, return identity
        if not non_identity_preconditioners:
            return IdentityPreconditioner()

        # If only one preconditioner is left, return it directly
        if len(non_identity_preconditioners) == 1:
            instance = non_identity_preconditioners[0]
        else:
            instance = super().__new__(cls)
            instance.preconditioners = non_identity_preconditioners

        return instance

    def __init__(self, preconditioners: list[Preconditioner]):
        self.preconditioners = preconditioners

    def get_coefficients(
        self, param: torch.Tensor, grad: torch.Tensor, state: dict
    ) -> PreconditionerCoefs:
        # Initialize with the first preconditioner's coefficients
        coefs = self.preconditioners[0].get_coefficients(param, grad, state)

        # Multiply in coefficients from the remaining preconditioners
        for precond in self.preconditioners[1:]:
            coefs = PreconditionerCoefs.combine(
                coefs, precond.get_coefficients(param, grad, state)
            )

        return coefs

    def __repr__(self) -> str:
        return f"CompositePreconditioner({', '.join(repr(precond) for precond in self.preconditioners)})"

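Because of the __new__ logic above, the constructor may not return a CompositePreconditioner at all. A small sketch of this collapsing behavior:

    from devinterp.optim.preconditioner import (
        CompositePreconditioner,
        IdentityPreconditioner,
        MaskPreconditioner,
    )

    # Identity preconditioners are filtered out, so a single remaining
    # preconditioner is returned directly rather than wrapped.
    precond = CompositePreconditioner([IdentityPreconditioner(), MaskPreconditioner([0.5])])
    assert isinstance(precond, MaskPreconditioner)

    # With no non-identity preconditioners, an IdentityPreconditioner is returned.
    precond = CompositePreconditioner([IdentityPreconditioner()])
    assert isinstance(precond, IdentityPreconditioner)
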
class RMSpropPreconditioner(Preconditioner):
    """RMSprop-style diagonal preconditioning"""

    def __init__(
        self, alpha: float = 0.99, eps: float = 1e-1, add_grad_correction=False
    ):
        self.alpha = alpha
        self.eps = eps
        self.add_grad_correction = add_grad_correction
        if self.add_grad_correction:
            raise NotImplementedError(
                "Gradient correction not yet implemented for RMSprop"
            )

    def get_coefficients(
        self, param: torch.Tensor, grad: torch.Tensor, state: dict
    ) -> PreconditionerCoefs:
        # Update running average of squared gradient
        if "square_avg" not in state:
            state["square_avg"] = torch.zeros_like(grad)
        state["square_avg"].mul_(self.alpha).addcmul_(grad, grad, value=1 - self.alpha)

        # Compute preconditioner (1/sqrt(v))
        preconditioner = 1.0 / (torch.sqrt(state["square_avg"]) + self.eps)

        return PreconditionerCoefs(
            grad_coef=preconditioner,
            prior_coef=preconditioner,
            noise_coef=torch.sqrt(preconditioner),
            overall_coef=1.0,
            grad_correction=None,
        )

    def __repr__(self) -> str:
        return f"RMSpropPreconditioner(alpha={self.alpha}, eps={self.eps}, add_grad_correction={self.add_grad_correction})"

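A minimal usage sketch showing that the squared-gradient running average lives in the state dict owned by the caller (e.g. the optimizer) and persists across calls:

    import torch
    from devinterp.optim.preconditioner import RMSpropPreconditioner

    precond = RMSpropPreconditioner(alpha=0.99, eps=1e-1)
    param = torch.zeros(4)
    state = {}  # per-parameter state maintained by the caller

    for _ in range(10):
        grad = torch.randn(4)
        coefs = precond.get_coefficients(param, grad, state)

    # state["square_avg"] now holds the running average of squared gradients;
    # grad_coef and prior_coef equal 1/(sqrt(square_avg) + eps), and noise_coef
    # is the square root of that value.
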
class NHTPreconditioning(Preconditioner):
    """Nosé-Hoover thermostat preconditioning"""

    def __init__(self, diffusion_factor: float = 0.01, eps: float = 1e-8):
        self.diffusion_factor = diffusion_factor
        self.eps = eps

    def get_coefficients(
        self, param: torch.Tensor, grad: torch.Tensor, state: dict
    ) -> PreconditionerCoefs:
        # Initialize or get thermostat variable
        state["thermostat"] = thermostat = state.get(
            "thermostat", torch.tensor(self.diffusion_factor, device=param.device)
        )

        # Initialize or get momentum
        state["momentum"] = momentum = state.get(
            "momentum", torch.randn_like(param.data)
        )

        grad_correction = -thermostat * momentum

        # Update thermostat based on kinetic energy
        kinetic_energy = (
            torch.sum(momentum * momentum, dtype=torch.float32) / momentum.numel()
        )
        state["thermostat"] += kinetic_energy - 1.0

        return PreconditionerCoefs(
            grad_coef=1.0,
            prior_coef=1.0,
            noise_coef=torch.sqrt(torch.tensor(2.0 * self.diffusion_factor)),
            overall_coef=1.0,
            grad_correction=grad_correction,
        )

    def __repr__(self) -> str:
        return f"NHTPreconditioning(diffusion_factor={self.diffusion_factor}, eps={self.eps})"

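A minimal usage sketch: the thermostat and momentum are stored in the caller's state dict, and the momentum-based correction is returned additively via grad_correction.

    import torch
    from devinterp.optim.preconditioner import NHTPreconditioning

    precond = NHTPreconditioning(diffusion_factor=0.01)
    param = torch.zeros(3)
    state = {}

    coefs = precond.get_coefficients(param, torch.randn(3), state)
    # state now holds "thermostat" and "momentum";
    # coefs.grad_correction is -thermostat * momentum (using the thermostat value
    # before its kinetic-energy update), and coefs.noise_coef == sqrt(2 * diffusion_factor).
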