devinterp.backends.default.slt package

Submodules

devinterp.backends.default.slt.sampler module

devinterp.backends.default.slt.sampler.sample(model: torch.nn.modules.module.Module, loader: torch.utils.data.dataloader.DataLoader, callbacks: typing.List[devinterp.slt.callback.SamplerCallback], evaluate: typing.Callable[[torch.nn.modules.module.Module, torch.Tensor], devinterp.utils.Outputs | typing.Dict[str, torch.Tensor] | typing.Tuple[torch.Tensor, ...] | torch.Tensor] | None = None, sampling_method: typing.Type[torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: typing.Dict[str, float | typing.Literal['adaptive']] | None = None, num_draws: int = 100, num_chains: int = 10, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, init_loss: float | None = None, gradient_accumulation_steps: int = 1, cores: int = 1, seed: int | typing.List[int] | None = None, device: torch.device | str = device(type='cpu'), verbose: bool = True, optimize_over_per_model_param: typing.Dict[str, typing.List[bool]] | None = None, batch_size: bool = 1, **kwargs)

Sample model weights using the given sampling_method, optionally across multiple chains and cores, and compute the observables (loss, LLC, etc.) for each callback passed in. For each SamplerCallback, the update method is called during sampling, the finalize method after sampling, and the sample method when sampler_callback_object.get_results() is called.

This function returns nothing; after it has run, the statistics of interest are stored in the callback objects.
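The callback lifecycle described above can be sketched without the library. The example below is illustrative only: ToyMeanLoss and toy_sample are hypothetical names, the losses are fabricated stand-ins for real sampler steps, and the method signatures are assumptions rather than the actual SamplerCallback interface.

```python
from collections import defaultdict
from statistics import mean


class ToyMeanLoss:
    """Hypothetical callback following the update/finalize/sample pattern."""

    def __init__(self):
        self.losses = defaultdict(list)  # chain index -> recorded losses
        self.means = {}

    def update(self, chain, draw, loss):
        # Called once per draw, while sampling is running.
        self.losses[chain].append(loss)

    def finalize(self):
        # Called once, after all chains have finished.
        self.means = {c: mean(v) for c, v in self.losses.items()}

    def sample(self):
        # Packages the finished statistics for the caller.
        return {"loss/mean_per_chain": dict(self.means)}

    def get_results(self):
        return self.sample()


def toy_sample(callbacks, num_chains=2, num_draws=3):
    """Stand-in driver: uses fake losses instead of real sampler steps."""
    for chain in range(num_chains):
        for draw in range(num_draws):
            fake_loss = float(chain + draw)  # placeholder for a real loss
            for cb in callbacks:
                cb.update(chain=chain, draw=draw, loss=fake_loss)
    for cb in callbacks:
        cb.finalize()
    # Returns None on purpose: results are read from the callbacks.


cb = ToyMeanLoss()
toy_sample([cb])
results = cb.get_results()
print(results)
```

The driver returns None and the caller reads results from the callback afterwards, which mirrors the calling convention documented for sample().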

Parameters:
  • model (torch.nn.Module) – The neural network model.

  • loader (DataLoader) – DataLoader for input data.

  • evaluate (EvaluateFn, optional) – Maps a model and a batch of data to an output carrying the loss, e.g. an object with a loss attribute. Default is None

  • callbacks (list[SamplerCallback]) – Callbacks that compute the observables during sampling.

  • sampling_method (type[torch.optim.Optimizer], optional) – Sampling method to use, passed as a PyTorch optimizer class rather than an instance. Default is SGLD

  • optimizer_kwargs (dict, optional) – Keyword arguments for the PyTorch optimizer (used as sampler here). Default is None (using standard SGLD parameters as defined in the SGLD class)

  • num_draws (int, optional) – Number of samples to draw. Default is 100

  • num_chains (int, optional) – Number of chains to run. Default is 10

  • num_burnin_steps (int, optional) – Number of burn-in steps before sampling. Default is 0

  • num_steps_bw_draws (int, optional) – Number of steps between each draw. Default is 1

  • init_loss (float, optional) – Initial loss for use in LLCEstimator and OnlineLLCEstimator

  • cores (int, optional) – Number of cores for parallel execution. Default is 1

  • seed (int or list[int], optional) – Random seed(s) for sampling. If passed, each chain gets a different, deterministic seed. Default is None

  • device (str or torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

  • verbose (bool, optional) – Whether to print sample chain progress. Default is True
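How num_draws, num_burnin_steps and num_steps_bw_draws interact can be illustrated with a short sketch. It assumes each chain runs num_burnin_steps + num_draws * num_steps_bw_draws sampler steps and records a draw every num_steps_bw_draws steps once burn-in is over. This is one reading of the parameter descriptions above, not a confirmed account of the implementation, and draw_steps is a hypothetical helper.

```python
def draw_steps(num_draws=100, num_burnin_steps=0, num_steps_bw_draws=1):
    """Return the step indices at which a draw is recorded (assumed schedule)."""
    total_steps = num_burnin_steps + num_draws * num_steps_bw_draws
    return [
        step
        for step in range(num_burnin_steps, total_steps)
        # Keep every num_steps_bw_draws-th step after burn-in.
        if (step - num_burnin_steps) % num_steps_bw_draws == 0
    ]


# Two burn-in steps, then three draws spaced two steps apart.
steps = draw_steps(num_draws=3, num_burnin_steps=2, num_steps_bw_draws=2)
print(steps)

# With the documented defaults, every step is a draw.
default_steps = draw_steps()
print(len(default_steps))
```

Under this assumption the first call records draws at steps 2, 4 and 6, and the defaults record one draw at each of 100 steps per chain.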

Raises:
  • ValueError – if derivative callbacks (e.g. OnlineLossStatistics()) are passed before base callbacks (e.g. OnlineLLCEstimator())

  • Warning – if num_burnin_steps < num_draws

  • Warning – if num_draws > len(loader)

  • Warning – if using seeded runs

Returns:

None (access LLCs or other observables through callback_object.get_results())

Module contents