"""Source code for devinterp.optim.sgmcmc."""

import warnings
from typing import Iterable, Iterator, Literal, NamedTuple, Optional, Union

import torch
from torch.optim import Optimizer

from devinterp.optim.metrics import Metrics
from devinterp.optim.preconditioner import (
    CompositePreconditioner,
    IdentityPreconditioner,
    MaskPreconditioner,
    NHTPreconditioning,
    Preconditioner,
    PreconditionerCoefs,
    RMSpropPreconditioner,
)
from devinterp.optim.prior import CompositePrior, GaussianPrior, Prior
from devinterp.optim.sketch import CountSketch, SketchBuffer

SamplingMethodLiteral = Literal["sgld", "rmsprop_sgld", "sgnht"]


class _ComponentVectors(NamedTuple):
    """Post-preconditioned component vectors from a single parameter update."""

    # Structurally identical to the copy in sgld.py. Kept separate because
    # SGLD is deprecated — sharing a private type would couple the old
    # module to this one and complicate its eventual removal.

    # Preconditioned gradient term, scaled by lr/2.
    scaled_grad: torch.Tensor
    # Raw (un-preconditioned) gradient term, scaled by lr/2.
    unscaled_grad: torch.Tensor
    # Localization-prior contribution (pull toward a center), scaled and preconditioned.
    localization: torch.Tensor
    # Weight-decay prior contribution (zero-centered Gaussian), scaled and preconditioned.
    weight_decay: torch.Tensor
    # Injected Langevin noise term, including lr**0.5 and noise_coef scaling.
    noise: torch.Tensor


class SGMCMC(Optimizer):
    """Unified Stochastic Gradient Markov Chain Monte Carlo (SGMCMC) optimizer.

    This optimizer implements a general SGMCMC framework that unifies several
    common variants like SGLD, SGHMC, and SGNHT. It supports custom priors and
    preconditioners for flexible posterior sampling.

    The general update rule follows:

    .. math::

        Δθ_t = (ε/2)G(θ_t)(nβ/m ∑∇logp(y|x,θ_t) + ∇logp(θ_t)) + √(εG(θ_t))N(0,σ²)

    where:

    * ε is learning rate
    * nβ is inverse temperature
    * m is batch size
    * G(θ) is the preconditioner
    * p(θ) is the prior distribution
    * σ is noise level

    :param params: Iterable of parameters to optimize or dicts defining
        parameter groups
    :param lr: Learning rate ε (default: 0.01)
    :param noise_level: Standard deviation σ of the noise (default: 1.0)
    :param nbeta: Inverse temperature nβ (default: 1.0)
    :param prior: Prior distribution specification. In SGMCMC, weight_decay
        and localization should be implemented as a prior — see the
        `rmsprop_sgld()` method for an example. Can be:

        - GaussianPrior instance
        - string specifying center type
        - iterable of tensor centers
        - float specifying precision

        (default: None)
    :param preconditioner: Preconditioner specification. Can be:

        - "identity" for no preconditioning
        - "rmsprop" for RMSprop-style preconditioning
        - Preconditioner instance

        (default: "identity")
    :param preconditioner_kwargs: Additional keyword arguments for preconditioner
    :param bounding_box_size: Size of bounding box around initial parameters
    :param mask: Boolean mask for restricting updatable parameters
    :param save_metrics: Whether to track metrics during training (default: False)
    :param sketch_dim: Dimension of the count-sketch buffers used to compress
        per-step update components; None disables sketching (default: None)
    :param sketch_seed: Seed for the count-sketch hash/sign functions (default: 0)
    :type params: Iterable
    :type lr: float
    :type noise_level: float
    :type nbeta: float
    :type prior: Optional[Union[Prior, Literal["initial"], Iterable[torch.Tensor], float]]
    :type preconditioner: Optional[Union[Preconditioner, str]]
    :type preconditioner_kwargs: Optional[dict]
    :type bounding_box_size: Optional[float]
    :type mask: Optional[torch.Tensor]
    :type save_metrics: bool
    :type sketch_dim: Optional[int]
    :type sketch_seed: int

    Example::

        from devinterp.utils import default_nbeta

        # Basic SGLD-style usage
        optimizer = SGMCMC.sgld(
            model.parameters(),
            lr=0.1,
            nbeta=default_nbeta(dataloader)
        )

        # RMSprop-preconditioned with prior
        optimizer = SGMCMC.rmsprop_sgld(
            model.parameters(),
            lr=0.01,
            localization=0.1,
            nbeta=default_nbeta(dataloader)
        )

        # SGNHT-style with thermostat
        optimizer = SGMCMC.sgnht(
            model.parameters(),
            lr=0.01,
            diffusion_factor=0.01,
            nbeta=default_nbeta(dataloader)
        )

        # Training loop
        for data, target in dataloader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

    Notes:
        * Use the factory methods (sgld, rmsprop_sgld, sgnht) for easier initialization
        * nbeta should typically be set using devinterp.utils.default_nbeta()
          rather than manually
        * The prior helps explore the local posterior by pulling toward initialization
        * Use save_metrics=True and call get_metrics() to access tracked metrics

    References:
        * Welling & Teh (2011) - Original SGLD paper
        * Li et al. (2015) - RMSprop-SGLD
        * Ding et al. (2014) - SGNHT
        * Lau et al. (2023) - Implementation with localization term
    """

    def __init__(
        self,
        params,
        *,
        lr: float = 0.01,
        noise_level: float = 1.0,
        nbeta: float = 1.0,
        prior: Optional[
            Union[Prior, Literal["initial"], Iterable[torch.Tensor], float]
        ] = None,
        preconditioner: Optional[
            Union[Preconditioner, Literal["identity", "rmsprop"]]
        ] = "identity",
        preconditioner_kwargs: Optional[dict] = None,
        bounding_box_size: Optional[float] = None,
        mask: Optional[torch.Tensor] = None,
        save_metrics: bool = False,
        sketch_dim: Optional[int] = None,
        sketch_seed: int = 0,
    ):
        # Handle single parameter case
        if isinstance(params, (torch.Tensor, dict)):
            params = [params]

        # Define per-group parameters
        defaults = dict(
            lr=lr,
            noise_level=noise_level,
            nbeta=nbeta,
            prior=prior,
            preconditioner=preconditioner,
            preconditioner_kwargs=preconditioner_kwargs or {},
            bounding_box_size=bounding_box_size,
            mask=mask,
        )
        super().__init__(params, defaults)

        self.save_metrics = save_metrics

        if sketch_dim is not None:
            self._total_numel = sum(
                p.numel() for group in self.param_groups for p in group["params"]
            )
            # Device of the first parameter; sketch lives alongside the params.
            device = next(iter(self.param_groups[0]["params"])).device
            self._sketch = CountSketch.create(
                self._total_numel, sketch_dim, sketch_seed
            ).to(device)
            # Single buffer shared across all param groups. Sketch accumulation
            # is purely additive (linear), so per-group buffers are unnecessary.
            # N.B. With FSDP each rank only sees a parameter shard; this buffer
            # would hold only the local contribution and would need an all-reduce
            # across ranks before consumption. See the FSDP guard in sampler.py.
            self._sketch_buf = SketchBuffer.create(sketch_dim, device)
            self.save_sketches = True
        else:
            self._sketch = None
            self._sketch_buf = None
            self.save_sketches = False

        # Initialize each parameter group
        for group in self.param_groups:
            self._init_group(group)

    def _init_group(self, group: dict) -> None:
        """Initialize all group-specific settings.

        Prior initialization supports several formats:

        1. Prior object: Use any existing Prior instance directly —
           localization and weight_decay should be implemented via a Prior.
           To see how to do this, look at e.g. `SGMCMC.sgld()`.
        2. String ("initial"): Creates GaussianPrior centered at parameter
           initialization
        3. Tensor centers: Creates GaussianPrior centered at provided tensor
           values
        4. Number (float/int): Creates GaussianPrior with specified
           localization strength

        Args:
            group: Parameter group dictionary containing optimizer settings
        """
        # Initialize prior. "localization" is only present if a caller put it
        # in the group dict directly; otherwise it defaults to 0.0 here.
        prior = group["prior"]
        localization = group.get("localization", 0.0)
        if prior is not None or localization:
            prior = prior if prior is not None else "initial"
            localization = localization or 1.0
            if isinstance(prior, Prior):
                group["prior"] = prior
            elif isinstance(prior, str):
                group["prior"] = GaussianPrior(localization=localization, center=prior)
            elif isinstance(prior, Iterable) and not isinstance(prior, str):
                group["prior"] = GaussianPrior(localization=localization, center=prior)
            elif isinstance(prior, (int, float)):
                group["prior"] = GaussianPrior(localization=float(prior))
            else:
                raise ValueError(f"Unsupported prior type: {type(prior)}")

        # Initialize preconditioner. Note: pop() removes the kwargs so this
        # method must only run once per group.
        preconditioner = group["preconditioner"]
        preconditioner_kwargs = group.pop("preconditioner_kwargs")
        if preconditioner is None or preconditioner == "identity":
            group["preconditioner"] = IdentityPreconditioner(**preconditioner_kwargs)
        elif preconditioner == "rmsprop":
            group["preconditioner"] = RMSpropPreconditioner(**preconditioner_kwargs)
        elif isinstance(preconditioner, Preconditioner):
            group["preconditioner"] = preconditioner
        else:
            raise ValueError(f"Unsupported preconditioner type: {preconditioner}")

        mask = group.get("mask", None)
        if mask is not None:
            # Convert mask to masks (1.0 where True, 0.0 where False)
            if isinstance(mask, torch.Tensor):
                mask = [mask]

            def _process_mask(m, p):
                # Coerce non-tensor masks (lists, numpy arrays, ...) onto the
                # optimizer's device before validation.
                if not isinstance(m, torch.Tensor):
                    m = torch.tensor(m).to(self.device)
                # Validate mask shape matches parameter shape
                if m.shape != p.shape:
                    raise ValueError(
                        f"Mask shape {m.shape} does not match parameter shape {p.shape}. "
                    )
                return m.float()

            params = list(group["params"])
            masks = [_process_mask(_mask, p) for _mask, p in zip(mask, params)]
            mask_preconditioner = MaskPreconditioner(masks=masks)
            # Masking composes with whatever preconditioner was configured above.
            if group["preconditioner"] is not None:
                group["preconditioner"] = CompositePreconditioner(
                    [group["preconditioner"], mask_preconditioner]
                )
            else:
                group["preconditioner"] = mask_preconditioner

        pstates = {}
        # Initialize prior state if needed
        if group["prior"] is not None:
            pstates = group["prior"].initialize(list(group["params"]))

        # Initialize states for each parameter in the group.
        # initial_param is only stored when it will be read later (bounding box
        # enforcement or distance metrics) to avoid duplicating model memory.
        store_initial = group["bounding_box_size"] is not None or self.save_metrics
        for i, p in enumerate(group["params"]):
            pstate = pstates.get(p, {})
            self.state[p] = {
                **self.state[p],
                **pstate,
                "param_idx": i,
                "initial_param": (p.data.clone().detach() if store_initial else None),
            }

        # Initialize per-group metrics
        if self.save_metrics:
            device = next(iter(group["params"])).device
            group["metrics"] = Metrics().to(device)

    def _decompose_update(
        self,
        group: dict,
        p: torch.Tensor,
        loc_grad: Optional[torch.Tensor],
        wd_grad: Optional[torch.Tensor],
        noise: torch.Tensor,
        preconditioning: PreconditionerCoefs,
        d_p: torch.Tensor,
    ) -> _ComponentVectors:
        """Decompose a parameter update into its post-preconditioned component vectors.

        Shared by both metrics accumulation and sketch scattering. Debug
        assertions verify the decomposition matches the actual update.

        Args:
            loc_grad: Pre-computed localization prior gradient, or None.
            wd_grad: Pre-computed weight_decay prior gradient, or None.
            noise: Noise tensor as computed in step() (already multiplied by
                overall_coef).
            d_p: The actual deterministic update vector (for debug assertion).
        """
        _lr = group["lr"]
        _precond = preconditioning.overall_coef
        _half_lr = 0.5 * _lr
        _grad_pre = preconditioning.grad_coef * p.grad.mul(group["nbeta"])
        _loc_pre = loc_grad if loc_grad is not None else torch.zeros_like(p)
        _wd_pre = wd_grad if wd_grad is not None else torch.zeros_like(p)
        _scaled_grad = _half_lr * (_precond * _grad_pre)
        _unscaled_grad = _half_lr * p.grad.mul(group["nbeta"])
        _prior_scale = _half_lr * preconditioning.prior_coef * _precond
        _loc = _prior_scale * _loc_pre
        _wd = _prior_scale * _wd_pre
        # noise already includes overall_coef from step()
        _noise = (group["lr"] ** 0.5) * preconditioning.noise_coef * noise

        if __debug__:
            _combined_prior = (
                _half_lr * preconditioning.prior_coef * _precond * (_loc_pre + _wd_pre)
            )
            # (1) grad + prior must reconstruct the full step update.
            # step() computes d_p at _grad_pre's precision before
            # overall_coef promotion, so rounding error is at that scale.
            _step_eps = torch.finfo(_grad_pre.dtype).eps
            _decomp_scale = float((_scaled_grad.abs() + _combined_prior.abs()).max())
            _decomp_atol = max(_decomp_scale * _step_eps * 16, 1e-12)
            torch.testing.assert_close(
                (_scaled_grad + _combined_prior).float(),
                (_half_lr * d_p).float(),
                atol=_decomp_atol,
                rtol=0,
                msg=lambda s: f"Decomposition components don't match gradient update:\n{s}",
            )
            # (2) loc + wd must match the combined prior.
            # _loc_pre + _wd_pre sums at input precision regardless
            # of preconditioner promotion, so use p.grad.dtype.
            _input_eps = torch.finfo(p.grad.dtype).eps
            _dist_scale = float((_loc.abs() + _wd.abs()).max())
            _dist_atol = max(_dist_scale * _input_eps * 16, 1e-12)
            torch.testing.assert_close(
                (_loc + _wd).float(),
                _combined_prior.float(),
                atol=_dist_atol,
                rtol=0,
                msg=lambda s: f"Prior distribution error exceeds bound:\n{s}",
            )

        return _ComponentVectors(
            scaled_grad=_scaled_grad,
            unscaled_grad=_unscaled_grad,
            localization=_loc,
            weight_decay=_wd,
            noise=_noise,
        )

    def _accumulate_metrics(
        self,
        group: dict,
        components: _ComponentVectors,
        distance: torch.Tensor,
        preconditioning: PreconditionerCoefs,
        p: torch.Tensor,
    ) -> None:
        """Accumulate decomposed component vectors into per-group metrics."""
        group["metrics"].add_sum_squared_(
            scaled_grad=components.scaled_grad,
            unscaled_grad=components.unscaled_grad,
            localization=components.localization,
            weight_decay=components.weight_decay,
            noise=components.noise,
            distance=distance,
        )
        group["metrics"].add_dot_products_(
            scaled_grad=components.scaled_grad,
            prior=components.localization + components.weight_decay,
            noise=components.noise,
        )
        # Count only effectively-updated entries: a tensor-valued overall_coef
        # (e.g. a mask) zeroes out some coordinates, which shouldn't inflate numel.
        _precond = preconditioning.overall_coef
        if isinstance(_precond, torch.Tensor):
            group["metrics"].numel += int(_precond.count_nonzero())
        else:
            group["metrics"].numel += p.numel()

    def _scatter_sketches(self, components: _ComponentVectors, offset: int) -> None:
        """Scatter decomposed component vectors into the sketch buffer.

        Args:
            components: Per-parameter post-preconditioned update components.
            offset: This parameter's starting index in the flattened
                all-parameters index space of the sketch.
        """
        buf = self._sketch_buf
        sketch = self._sketch
        assert buf is not None and sketch is not None
        sketch.scatter_into_(buf.scaled_grad, components.scaled_grad, offset)
        sketch.scatter_into_(buf.unscaled_grad, components.unscaled_grad, offset)
        sketch.scatter_into_(buf.localization, components.localization, offset)
        sketch.scatter_into_(buf.weight_decay, components.weight_decay, offset)
        sketch.scatter_into_(buf.noise, components.noise, offset)
[docs] def get_sketches(self) -> SketchBuffer: """Return a CPU snapshot of the current sketch buffer.""" if not self.save_sketches: raise RuntimeError("Sketches not enabled.") assert self._sketch_buf is not None return SketchBuffer( **{ q: getattr(self._sketch_buf, q).detach().cpu().clone() for q in SketchBuffer.QUANTITIES } )
    @torch.no_grad()
    def step(
        self, closure=None, noise_generator: Optional[torch.Generator] = None
    ) -> Optional[float]:
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            noise_generator (torch.Generator, optional): Generator for
                reproducible noise.

        Returns:
            Optional[float]: The loss value if closure is provided, else None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Component decomposition is only needed when we record it somewhere.
        _need_decomposition = self.save_metrics or self.save_sketches
        param_offset = 0
        if self.save_sketches:
            assert self._sketch_buf is not None
            self._sketch_buf.zero_()

        for group_idx, group in enumerate(self.param_groups):
            # Metrics lifecycle: zero → accumulate per-param → sqrt (see Metrics docstring)
            if self.save_metrics:
                group["metrics"].zero_()

            prior = group["prior"]
            preconditioner = group["preconditioner"]
            params = group["params"]

            for i, p in enumerate(params):
                if p.grad is None:
                    if self.save_sketches:
                        # Advance past this param's region in the sketch's
                        # hash/sign index space so subsequent params stay aligned.
                        param_offset += p.numel()
                    continue

                state = self.state[p]

                # Get preconditioner coefficients
                preconditioning = preconditioner.get_coefficients(p, p.grad, state)

                # Gradient computation
                d_p = preconditioning.grad_coef * p.grad.mul(group["nbeta"])

                # Prior contribution — decompose into sub-priors when
                # tracking metrics or sketches so we can record localization
                # and weight_decay separately. Must happen before p.data
                # is modified below.
                loc_grad = None
                wd_grad = None
                if prior is not None:
                    if _need_decomposition:
                        loc_grad = torch.zeros_like(p)
                        wd_grad = torch.zeros_like(p)
                        sub_priors = (
                            prior.priors
                            if isinstance(prior, CompositePrior)
                            else [prior]
                        )
                        for sub_prior in sub_priors:
                            sub_grad = sub_prior.grad(p.data, state)
                            # A zero-centered GaussianPrior is weight decay;
                            # anything else counts as localization.
                            if (
                                isinstance(sub_prior, GaussianPrior)
                                and sub_prior.center is None
                            ):
                                wd_grad += sub_grad
                            else:
                                loc_grad += sub_grad
                        prior_grad = loc_grad + wd_grad
                    else:
                        prior_grad = prior.grad(p.data, state)
                    d_p.add_(preconditioning.prior_coef * prior_grad)

                d_p = preconditioning.overall_coef * d_p

                # Distance must be measured before the parameter update below.
                if self.save_metrics:
                    _distance = p.data - state["initial_param"]

                # Deterministic half of the update: θ ← θ - (lr/2) d_p
                p.data.add_(d_p, alpha=-0.5 * group["lr"])

                # Noise addition
                noise = torch.normal(
                    mean=0.0,
                    std=group["noise_level"],
                    size=d_p.size(),
                    device=d_p.device,
                    generator=noise_generator,
                )
                noise = preconditioning.overall_coef * noise

                # Parameter updates: stochastic half, scaled by sqrt(lr)
                p.data.add_(
                    preconditioning.noise_coef * noise,
                    alpha=group["lr"] ** 0.5,
                )

                # Bounding box enforcement
                if group["bounding_box_size"] is not None:
                    initial_param = state["initial_param"]
                    torch.clamp_(
                        p.data,
                        min=initial_param - group["bounding_box_size"],
                        max=initial_param + group["bounding_box_size"],
                    )

                if _need_decomposition:
                    components = self._decompose_update(
                        group, p, loc_grad, wd_grad, noise, preconditioning, d_p
                    )
                    if self.save_metrics:
                        self._accumulate_metrics(
                            group, components, _distance, preconditioning, p
                        )
                    if self.save_sketches:
                        self._scatter_sketches(components, param_offset)

                if self.save_sketches:
                    param_offset += p.numel()

            if self.save_metrics:
                # All params accumulated; convert sum-of-squares to L2 norms
                group["metrics"].sqrt_norms_()

        return loss
[docs] def get_params(self) -> Iterator[torch.Tensor]: """Helper to get all parameters""" for group in self.param_groups: for p in group["params"]: yield p
[docs] def iter_group_metrics(self) -> Iterator[Metrics]: """Yield metrics for each param group.""" if not self.save_metrics: raise RuntimeError("Metrics not enabled. Set save_metrics=True.") for group in self.param_groups: yield group["metrics"]
[docs] def get_metrics(self) -> Metrics: """Aggregate metrics across all param groups into a single CPU Metrics.""" return Metrics.aggregate(self.iter_group_metrics())
    @classmethod
    def sgld(
        cls,
        params,
        lr=0.01,
        noise_level=1.0,
        weight_decay=0.0,
        localization=0.0,
        nbeta=1.0,
        bounding_box_size=None,
        mask=None,
        save_metrics: bool = False,
        sketch_dim: Optional[int] = None,
        sketch_seed: int = 0,
    ):
        """Factory method to create an SGMCMC instance that implements
        Stochastic Gradient Langevin Dynamics (SGLD) with a localization term
        (Lau et al. 2023).

        This optimizer combines Stochastic Gradient Descent (SGD) with Langevin
        Dynamics, introducing Gaussian noise to the gradient updates. This
        makes it sample weights from the posterior distribution, instead of
        finding point estimates through optimization (Welling and Teh 2011).

        The update rule follows::

            Δθ_t = (ε/2)(nβ/m ∑∇logp(y|x,θ_t) + γ(θ_0-θ_t) - λθ_t) + N(0,εσ²)

        where:

        * ε is learning rate
        * nβ is inverse temperature
        * m is batch size
        * γ is localization strength
        * λ is weight decay
        * σ is noise level

        This follows Lau et al.'s (2023) implementation, which modifies
        Welling and Teh (2011) by:

        * Omitting the learning rate schedule (this functionality could be
          recoverd by using a separate learning rate scheduler).
        * Adding a localization term that pulls weights toward initialization
        * Using tempered Bayes paradigm with inverse temperature nβ

        This allows SGMCMC to be used as a drop-in replacement for SGLD.

        :param params: Iterable of parameters to optimize
        :param lr: Learning rate (default: 0.01)
        :param noise_level: Standard deviation of noise (default: 1.0)
        :param weight_decay: Weight decay factor. Applied with preconditioning
            (Adam-style). Creates a GaussianPrior centered at zero with
            localization=weight_decay. (default: 0.0)
        :param localization: Strength of pull toward initial parameters.
            Creates a GaussianPrior centered at initialization with
            localization=localization. (default: 0.0)
        :param nbeta: Inverse temperature (default: 1.0)
        :param bounding_box_size: Size of bounding box around initial
            parameters (default: None)
        :param mask: Boolean mask for restricting updatable parameters
            (default: None)
        :param save_metrics: Whether to track metrics during training
            (default: False)
        :param sketch_dim: Dimension of the count-sketch buffers; None
            disables sketching (default: None)
        :param sketch_seed: Seed for the count-sketch hash functions
            (default: 0)
        :return: SGMCMC optimizer instance
        """
        if noise_level != 1.0:
            warnings.warn(
                "noise_level in SGLD is unequal to one, this removes SGLD posterior sampling guarantees."
            )
        if nbeta == 1.0:
            warnings.warn(
                "nbeta set to 1, LLC estimates will be off unless you know what you're doing. Use utils.default_nbeta(dataloader) instead"
            )

        # weight_decay and localization are expressed as Gaussian priors.
        priors = []
        if weight_decay > 0:
            priors.append(GaussianPrior(localization=weight_decay, center=None))
        if localization > 0:
            priors.append(GaussianPrior(localization=localization, center="initial"))

        prior = CompositePrior(priors)

        instance = cls(
            params,
            lr=lr,
            noise_level=noise_level,
            nbeta=nbeta,
            prior=prior,
            bounding_box_size=bounding_box_size,
            mask=mask,
            save_metrics=save_metrics,
            sketch_dim=sketch_dim,
            sketch_seed=sketch_seed,
        )
        return instance
    @classmethod
    def sgnht(
        cls,
        params,
        lr=0.01,
        diffusion_factor=0.01,
        nbeta=1.0,
        bounding_box_size=None,
        save_metrics: bool = False,
        sketch_dim: Optional[int] = None,
        sketch_seed: int = 0,
    ):
        """Factory method to create an SGMCMC instance that matches SGNHT's
        interface.

        This allows SGMCMC to be used as a drop-in replacement for SGNHT;
        noise scaling is delegated to the NHT preconditioner.

        :param params: Iterable of parameters to optimize
        :param lr: Learning rate (default: 0.01)
        :param diffusion_factor: Diffusion factor (default: 0.01)
        :param nbeta: Inverse temperature (default: 1.0)
        :param bounding_box_size: Size of bounding box around initial
            parameters (default: None)
        :param save_metrics: Whether to track metrics (default: False)
        :param sketch_dim: Dimension of the count-sketch buffers; None
            disables sketching (default: None)
        :param sketch_seed: Seed for the count-sketch hash functions
            (default: 0)
        :return: SGMCMC optimizer instance
        """
        if nbeta == 1.0:
            warnings.warn(
                "nbeta set to 1, LLC estimates will be off unless you know what you're doing. Use utils.default_nbeta(dataloader) instead"
            )

        # Create NHT preconditioner
        preconditioner = NHTPreconditioning(diffusion_factor=diffusion_factor)

        instance = cls(
            params,
            lr=lr,
            noise_level=1.0,  # Noise scaling handled by preconditioner
            nbeta=nbeta,
            preconditioner=preconditioner,
            bounding_box_size=bounding_box_size,
            save_metrics=save_metrics,
            sketch_dim=sketch_dim,
            sketch_seed=sketch_seed,
        )
        return instance
    @classmethod
    def rmsprop_sgld(
        cls,
        params,
        lr=0.01,
        noise_level=1.0,
        weight_decay=0.0,
        localization=0.0,
        nbeta=1.0,
        alpha=0.99,
        eps=0.1,
        add_grad_correction=False,
        bounding_box_size=None,
        mask=None,
        save_metrics: bool = False,
        sketch_dim: Optional[int] = None,
        sketch_seed: int = 0,
    ):
        """Factory method to create an SGMCMC instance that wraps RMSprop's
        adaptive preconditioning with SGLD to perform Bayesian sampling of
        neural network weights.

        The update rule with preconditioning follows::

            V(θ_t) = αV(θ_{t-1}) + (1-α)g̅(θ_t)g̅(θ_t)
            G(θ_t) = diag(1/(λ1 + √V(θ_t)))
            Δθ_t = (ε/2)G(θ_t)(nβ/m ∑∇logp(y|x,θ_t) + γ(θ_0-θ_t) - λθ_t) + √(εG(θ_t))N(0,σ²)

        where:

        * ε is learning rate
        * nβ is effective dataset size (=dataset size * inverse temperature)
        * m is batch size
        * γ is localization strength
        * λ is weight decay
        * σ is noise level
        * G(θ) is the RMSprop preconditioner
        * V(θ) tracks squared gradient moving average
        * α is the exponential decay rate

        Key differences from standard SGLD:

        * Uses RMSprop preconditioner to adapt to local geometry and curvature
        * Scales both the gradients and noise by the preconditioner
        * Handles pathological curvature through adaptive step sizes

        :param params: Iterable of parameters to optimize
        :param lr: Learning rate (default: 0.01)
        :param noise_level: Standard deviation of noise (default: 1.0)
        :param weight_decay: Weight decay factor. Applied with preconditioning
            (Adam-style). Creates a GaussianPrior centered at zero with
            localization=weight_decay. (default: 0.0)
        :param localization: Strength of pull toward initial parameters.
            Creates a GaussianPrior centered at initialization with
            localization=localization. (default: 0.0)
        :param nbeta: Inverse temperature (default: 1.0)
        :param alpha: RMSprop moving average coefficient (default: 0.99)
        :param eps: RMSprop stability constant (default: 0.1)
        :param add_grad_correction: Whether to add gradient correction term
            (default: False)
        :param bounding_box_size: Size of bounding box around initial
            parameters (default: None)
        :param mask: Boolean mask for restricting updatable parameters
            (default: None)
        :param save_metrics: Whether to track metrics during training
            (default: False)
        :param sketch_dim: Dimension of the count-sketch buffers; None
            disables sketching (default: None)
        :param sketch_seed: Seed for the count-sketch hash functions
            (default: 0)
        :return: SGMCMC optimizer instance
        """
        if noise_level != 1.0:
            warnings.warn(
                "noise_level in RMSProp-SGLD is unequal to one, this removes SGLD posterior sampling guarantees."
            )
        if nbeta == 1.0:
            warnings.warn(
                "nbeta set to 1, LLC estimates will be off unless you know what you're doing. Use utils.default_nbeta(dataloader) instead"
            )

        # weight_decay and localization are expressed as Gaussian priors.
        priors = []
        if weight_decay > 0:
            priors.append(GaussianPrior(localization=weight_decay, center=None))
        if localization > 0:
            priors.append(GaussianPrior(localization=localization, center="initial"))

        prior = CompositePrior(priors)

        # Configure RMSprop preconditioner
        preconditioner_kwargs = {
            "alpha": alpha,
            "eps": eps,
            "add_grad_correction": add_grad_correction,
        }

        instance = cls(
            params,
            lr=lr,
            noise_level=noise_level,
            nbeta=nbeta,
            prior=prior,
            preconditioner="rmsprop",
            preconditioner_kwargs=preconditioner_kwargs,
            bounding_box_size=bounding_box_size,
            mask=mask,
            save_metrics=save_metrics,
            sketch_dim=sketch_dim,
            sketch_seed=sketch_seed,
        )
        return instance
@classmethod def get_method(cls, method: SamplingMethodLiteral): if method == "sgld": return cls.sgld elif method == "rmsprop_sgld": return cls.rmsprop_sgld elif method == "sgnht": return cls.sgnht else: raise ValueError( f"`method` should be one of 'sgld', 'rmsprop_sgld', or 'sgnht'. Got {method}" ) @property def device(self): return next(self.get_params()).device