devinterp.backends.default.slt package¶
Submodules¶
devinterp.backends.default.slt.sampler module¶
- devinterp.backends.default.slt.sampler.sample(model: ~torch.nn.modules.module.Module, loader: ~torch.utils.data.dataloader.DataLoader, callbacks: ~typing.List[~devinterp.slt.callback.SamplerCallback], evaluate: ~typing.Callable[[~torch.nn.modules.module.Module, ~torch.Tensor], ~devinterp.utils.Outputs | ~typing.Dict[str, ~torch.Tensor] | ~typing.Tuple[~torch.Tensor, ...] | ~torch.Tensor] | None = None, sampling_method: ~typing.Type[~torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: ~typing.Dict[str, float | ~typing.Literal['adaptive']] | None = None, num_draws: int = 100, num_chains: int = 10, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, init_loss: float | None = None, gradient_accumulation_steps: int = 1, cores: int = 1, seed: int | ~typing.List[int] | None = None, device: ~torch.device | str = device(type='cpu'), verbose: bool = True, optimize_over_per_model_param: ~typing.Dict[str, ~typing.List[bool]] | None = None, batch_size: bool = 1, **kwargs)¶
Sample model weights using a given sampling_method, supporting multiple chains/cores, and calculate the observables (loss, llc, etc.) for each callback passed along. The update, finalize and sample methods of each SamplerCallback() are called during sampling, after sampling, and at sampler_callback_object.get_results() respectively. After calling this function, the stats of interest live in the callback object.
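The callback lifecycle described above can be illustrated with a small custom callback. This is only a sketch: the hypothetical LossTraceCallback below, the argument names it receives during sampling (chain, draw, loss), and the bare super().__init__() call are assumptions about the SamplerCallback interface rather than guarantees; consult the built-in estimators (e.g. OnlineLLCEstimator) for the exact convention.

    from devinterp.slt.callback import SamplerCallback


    class LossTraceCallback(SamplerCallback):
        """Hypothetical callback that records the loss at every draw, per chain."""

        def __init__(self):
            super().__init__()   # assumed to require no arguments
            self.losses = {}     # chain index -> list of recorded losses
            self.mean_loss = {}  # filled in by finalize()

        def __call__(self, chain: int, draw: int, loss):
            # The sampler invokes the callback during sampling; forward to update().
            self.update(chain, draw, loss)

        def update(self, chain: int, draw: int, loss):
            # Called during sampling, once per draw.
            self.losses.setdefault(chain, []).append(float(loss))

        def finalize(self):
            # Called once after sampling; compute summary statistics here.
            self.mean_loss = {c: sum(ls) / len(ls) for c, ls in self.losses.items()}

        def sample(self):
            # Called when get_results() is invoked on the callback object.
            return {"loss/trace": self.losses, "loss/mean": self.mean_loss}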
- Parameters:
model (torch.nn.Module) – The neural network model.
loader (DataLoader) – DataLoader for input data.
evaluate (EvaluateFn) – Maps a model and batch of data to an object with a loss attribute (see the usage sketch at the end of this entry).
callbacks (list[SamplerCallback]) – List of callbacks, each of type SamplerCallback.
sampling_method (torch.optim.Optimizer, optional) – Sampling method to use (a PyTorch optimizer under the hood). Default is SGLD
optimizer_kwargs (dict, optional) – Keyword arguments for the PyTorch optimizer (used as sampler here). Default is None (using standard SGLD parameters as defined in the SGLD class)
num_draws (int, optional) – Number of samples to draw. Default is 100
num_chains (int, optional) – Number of chains to run. Default is 10
num_burnin_steps (int, optional) – Number of burn-in steps before sampling. Default is 0
num_steps_bw_draws (int, optional) – Number of steps between each draw. Default is 1
init_loss (float, optional) – Initial loss for use in LLCEstimator and OnlineLLCEstimator. Default is None
cores (int, optional) – Number of cores for parallel execution. Default is 1
seed (int or list[int], optional) – Random seed(s) for sampling. Each chain gets a different (deterministic) seed if this is passed. Default is None
device (str or torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
verbose (bool, optional) – Whether to print sample chain progress. Default is True
- Raises:
ValueError – if derivative callbacks (e.g. OnlineLossStatistics()) are passed before base callbacks (e.g. OnlineLLCEstimator())
Warning – if num_burnin_steps < num_draws
Warning – if num_draws > len(loader)
Warning – if using seeded runs
- Returns:
None (access LLCs or other observables through callback_object.get_results())
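A minimal end-to-end sketch of calling sample(), reusing the hypothetical LossTraceCallback from the sketch above (in practice one would typically pass the library's own estimators such as LLCEstimator or OnlineLLCEstimator). Everything here is illustrative: the toy data and model, the choice to return a bare loss tensor from evaluate (one of the return types the signature above allows), and the optimizer_kwargs keys, which depend on the SGLD class.

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    from devinterp.backends.default.slt.sampler import sample
    from devinterp.optim.sgld import SGLD

    # Toy regression problem and model (illustrative only).
    xs = torch.randn(3200, 4)
    ys = xs @ torch.randn(4, 1)
    loader = DataLoader(TensorDataset(xs, ys), batch_size=32, shuffle=True)
    model = torch.nn.Linear(4, 1)

    def evaluate(model, batch):
        # Map (model, batch) to a loss; returning the loss tensor directly
        # is one of the return types the signature above accepts.
        inputs, targets = batch
        return F.mse_loss(model(inputs), targets)

    # LossTraceCallback is the hypothetical callback sketched earlier on this page.
    callback = LossTraceCallback()

    sample(
        model,
        loader,
        evaluate=evaluate,
        callbacks=[callback],
        sampling_method=SGLD,
        optimizer_kwargs=dict(lr=1e-4),  # illustrative; see the SGLD class for its keyword set
        num_chains=4,
        num_draws=100,
        device="cpu",
    )

    # After sampling, the stats of interest live in the callback object.
    results = callback.get_results()
    print(results["loss/mean"])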