devinterp.backends.default.slt package
Submodules
devinterp.backends.default.slt.sampler module
- devinterp.backends.default.slt.sampler.sample(model: ~torch.nn.modules.module.Module, loader: ~torch.utils.data.dataloader.DataLoader, callbacks: ~typing.List[~devinterp.slt.callback.SamplerCallback], evaluate: ~typing.Callable[[~torch.nn.modules.module.Module, ~torch.Tensor], ~devinterp.utils.Outputs | ~typing.Dict[str, ~torch.Tensor] | ~typing.Tuple[~torch.Tensor, ...] | ~torch.Tensor] | None = None, sampling_method: ~typing.Type[~torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: ~typing.Dict[str, float | ~typing.Literal['adaptive']] | None = None, num_draws: int = 100, num_chains: int = 10, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, init_loss: float | None = None, gradient_accumulation_steps: int = 1, cores: int = 1, seed: int | ~typing.List[int] | None = None, device: ~torch.device | str = device(type='cpu'), verbose: bool = True, optimize_over_per_model_param: ~typing.Dict[str, ~typing.List[bool]] | None = None, batch_size: bool = 1, **kwargs)
- Sample model weights using a given sampling_method, supporting multiple chains and cores, and compute the observables (loss, LLC, etc.) of each callback passed in. The update, finalize and sample methods of each SamplerCallback are called during sampling, after sampling, and at sampler_callback_object.get_results(), respectively. After this function returns, the statistics of interest live in the callback objects.
- Parameters:
- model (torch.nn.Module) – The neural network model. 
- loader (DataLoader) – DataLoader for input data. 
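- evaluate (EvaluateFn, optional) – Maps a model and a batch of data to its outputs: an object with a loss attribute (such as devinterp.utils.Outputs), a dict or tuple of tensors, or a single tensor. Default is None 
- callbacks (list[SamplerCallback]) – List of SamplerCallback instances whose observables are computed during sampling 
- sampling_method (type[torch.optim.Optimizer], optional) – Sampler class to use (a PyTorch optimizer subclass under the hood). Default is SGLD 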
- optimizer_kwargs (dict, optional) – Keyword arguments for the PyTorch optimizer (used as sampler here). Default is None (using standard SGLD parameters as defined in the SGLD class) 
- num_draws (int, optional) – Number of samples to draw. Default is 100 
- num_chains (int, optional) – Number of chains to run. Default is 10 
- num_burnin_steps (int, optional) – Number of burn-in steps before sampling. Default is 0 
- num_steps_bw_draws (int, optional) – Number of steps between each draw. Default is 1 
- init_loss (float, optional) – Initial loss for use in LLCEstimator and OnlineLLCEstimator. Default is None 
- cores (int, optional) – Number of cores for parallel execution. Default is 1 
- seed (int or list[int], optional) – Random seed(s) for sampling. If passed, each chain gets a different, deterministic seed. Default is None 
- device (str or torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’ 
- verbose (bool, optional) – Whether to print sampling progress for the chains. Default is True 
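
The way num_burnin_steps, num_draws, num_steps_bw_draws and seed interact can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the library's code: it assumes each chain runs num_burnin_steps + num_draws * num_steps_bw_draws steps and records one draw every num_steps_bw_draws steps once burn-in ends, and that an integer seed is offset by the chain index to give each chain its own seed. The helpers draw_steps and chain_seeds are hypothetical.

```python
from typing import List, Optional, Sequence, Union


def draw_steps(num_draws: int, num_burnin_steps: int, num_steps_bw_draws: int) -> List[int]:
    """Return the step indices at which draws are recorded (hypothetical helper).

    Assumption: a chain runs num_burnin_steps + num_draws * num_steps_bw_draws steps,
    discards the burn-in steps, then records one draw every num_steps_bw_draws steps.
    """
    total_steps = num_burnin_steps + num_draws * num_steps_bw_draws
    return [
        step
        for step in range(total_steps)
        if step >= num_burnin_steps
        and (step - num_burnin_steps) % num_steps_bw_draws == 0
    ]


def chain_seeds(seed: Union[int, Sequence[int], None], num_chains: int) -> List[Optional[int]]:
    """Give each chain its own deterministic seed (hypothetical helper).

    Assumption: an int seed is offset by the chain index; a list needs one entry per chain.
    """
    if seed is None:
        return [None] * num_chains
    if isinstance(seed, int):
        return [seed + chain for chain in range(num_chains)]
    if len(seed) != num_chains:
        raise ValueError("seed list must have one entry per chain")
    return list(seed)


print(draw_steps(num_draws=3, num_burnin_steps=2, num_steps_bw_draws=2))
print(chain_seeds(42, num_chains=3))
```

Under these assumptions, num_draws=3, num_burnin_steps=2 and num_steps_bw_draws=2 place the draws on steps 2, 4 and 6, so each chain runs 8 steps in total.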
 
- Raises:
- ValueError – if derivative callbacks (e.g. OnlineLossStatistics()) are passed before the base callbacks they depend on (e.g. OnlineLLCEstimator())
- Warning – if num_burnin_steps < num_draws 
- Warning – if num_draws > len(loader) 
- Warning – if using seeded runs 
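
The checks listed under Raises could look roughly like the following sketch. BaseCallback, DerivedCallback, the base_callback attribute and validate_sampling_args are hypothetical stand-ins; the real OnlineLLCEstimator and OnlineLossStatistics may record their dependency differently. In this sketch only the callback ordering raises, while the other three conditions are reported through warnings.warn.

```python
import warnings
from typing import List, Optional


class BaseCallback:
    """Hypothetical stand-in for a base callback such as OnlineLLCEstimator."""


class DerivedCallback:
    """Hypothetical stand-in for a derivative callback such as OnlineLossStatistics.

    It reads results from the base callback it wraps, so it must come after it.
    """

    def __init__(self, base_callback: BaseCallback) -> None:
        self.base_callback = base_callback  # assumed attribute name


def validate_sampling_args(callbacks: List[object], num_draws: int, num_burnin_steps: int,
                           num_batches: int, seed: Optional[int]) -> None:
    # Each derivative callback's base must appear earlier in the callbacks list.
    for index, callback in enumerate(callbacks):
        base = getattr(callback, "base_callback", None)
        if base is not None and base not in callbacks[:index]:
            raise ValueError(f"{type(callback).__name__} is passed before its base callback")
    # The remaining conditions only warn.
    if num_burnin_steps < num_draws:
        warnings.warn("num_burnin_steps < num_draws")
    if num_draws > num_batches:  # num_batches plays the role of len(loader)
        warnings.warn("num_draws > len(loader)")
    if seed is not None:
        warnings.warn("using a seeded run")


base = BaseCallback()
derived = DerivedCallback(base)
# Base first, then derivative: passes without raising or warning.
validate_sampling_args([base, derived], num_draws=100, num_burnin_steps=100,
                       num_batches=500, seed=None)
```

Passing [derived, base] instead would raise ValueError, because the derivative callback would be processed before the callback it reads from.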
 
- Returns:
- None (access LLCs or other observables through callback_object.get_results())
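
A toy callback helps illustrate the lifecycle described above: update is called for each draw while sampling, finalize once afterwards, and the results are then read back through get_results(). ToyLLCCallback, its nbeta parameter and its result keys are hypothetical, and the driver loop feeds made-up deterministic losses instead of running a model. finalize applies the usual local learning coefficient estimate, nbeta times (mean sampled loss minus init_loss), which is where an argument like init_loss comes in.

```python
from typing import Dict, List


class ToyLLCCallback:
    """Hypothetical callback mirroring the update / finalize / get_results lifecycle.

    Not devinterp's LLCEstimator: a toy that stores losses per chain and applies
    llc = nbeta * (mean sampled loss - init_loss) to each chain.
    """

    def __init__(self, num_chains: int, num_draws: int, nbeta: float, init_loss: float) -> None:
        self.losses: List[List[float]] = [[0.0] * num_draws for _ in range(num_chains)]
        self.nbeta = nbeta
        self.init_loss = init_loss
        self.llc_per_chain: List[float] = []

    def update(self, chain: int, draw: int, loss: float) -> None:
        # Called once per recorded draw while sampling.
        self.losses[chain][draw] = loss

    def finalize(self) -> None:
        # Called once after all chains have finished.
        self.llc_per_chain = [
            self.nbeta * (sum(chain_losses) / len(chain_losses) - self.init_loss)
            for chain_losses in self.losses
        ]

    def get_results(self) -> Dict[str, object]:
        # Read the statistics back out of the callback object.
        mean_llc = sum(self.llc_per_chain) / len(self.llc_per_chain)
        return {"llc_mean": mean_llc, "llc_per_chain": self.llc_per_chain}


# Toy driver standing in for sample(): fixed fake losses, no model or optimizer.
callback = ToyLLCCallback(num_chains=2, num_draws=3, nbeta=10.0, init_loss=1.0)
for chain in range(2):
    for draw in range(3):
        callback.update(chain, draw, loss=1.0 + 0.1 * (chain + 1))
callback.finalize()
print(callback.get_results())
```

Keeping the statistics inside the callback object, rather than returning them from the sampling loop, is what lets the sampling function itself return None.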