devinterp.backends.tpu.slt package¶
Submodules¶
devinterp.backends.tpu.slt.sampler module¶
- devinterp.backends.tpu.slt.sampler.sample(model: torch.nn.modules.module.Module, loader: torch.utils.data.dataloader.DataLoader, callbacks: List[devinterp.slt.callback.SamplerCallback] | Dict[str, devinterp.slt.callback.SamplerCallback], evaluate: Callable = <function <lambda>>, sampling_method: Type[torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: Dict[str, float | Literal['adaptive']] | None = None, scheduler_cls: Type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_kwargs: Dict[str, Any] | None = None, num_draws: int = 100, num_chains: int = 10, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, gradient_accumulation_steps: int = 1, cores: int = 1, seed: int | List[int] | None = None, device: torch.device | str = device(type='cpu'), verbose: bool = True, optimize_over_per_model_param: Dict[str, torch.Tensor] | None = None, batch_size: int = 32, init_noise: float | None = None, shuffle: bool = True, use_alternate_batching=False, init_loss=None, **kwargs)¶
Sample model weights using a given sampling_method, supporting multiple chains/cores, and calculate the observables (loss, llc, etc.) for each callback passed along. The update, finalize, and sample methods of each SamplerCallback are called during sampling, after sampling, and at sampler_callback_object.get_results(), respectively. After calling this function, the stats of interest live in the callback object.
- Parameters:
model (torch.nn.Module) – The neural network model.
loader (DataLoader) – DataLoader for input data.
evaluate (Callable, optional) – Function that evaluates the model on a batch and returns the loss (optionally alongside a dict of extra observables)
callbacks (list[SamplerCallback] or dict[str, SamplerCallback]) – Callbacks to compute observables during sampling, each of type SamplerCallback
sampling_method (torch.optim.Optimizer, optional) – Sampling method to use (a PyTorch optimizer under the hood). Default is SGLD
optimizer_kwargs (dict, optional) – Keyword arguments for the PyTorch optimizer (used as sampler here). Default is None (using standard SGLD parameters as defined in the SGLD class)
num_draws (int, optional) – Number of samples to draw. Default is 100
num_chains (int, optional) – Number of chains to run. Default is 10
num_burnin_steps (int, optional) – Number of burn-in steps before sampling. Default is 0
num_steps_bw_draws (int, optional) – Number of steps between each draw. Default is 1
cores (int, optional) – Number of cores for parallel execution. Default is 1
seed (int or list[int], optional) – Random seed(s) for sampling. Each chain gets a different (deterministic) seed if this is passed. Default is None
device (str or torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’
verbose (bool, optional) – Whether to print sample chain progress. Default is True
- Raises:
ValueError – if derivative callbacks (e.g. OnlineLossStatistics) are passed before base callbacks (e.g. OnlineLLCEstimator)
Warning – if num_burnin_steps < num_draws
Warning – if num_draws > len(loader)
Warning – if using seeded runs
- Returns:
None (access LLCs or other observables through callback_object.get_results())
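A minimal usage sketch (not part of the original docstring): it assumes devinterp's LLCEstimator callback with import path devinterp.slt.llc, and its constructor keywords (e.g. nbeta) vary between devinterp versions, so treat those names as illustrative:

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    from devinterp.optim.sgld import SGLD
    from devinterp.slt.llc import LLCEstimator  # assumed import path; check your version
    from devinterp.backends.tpu.slt.sampler import sample

    # Toy regression model and data; substitute your own.
    model = torch.nn.Linear(10, 1)
    dataset = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    def evaluate(model, batch):
        # Returns (loss, dict of extra observables), matching the documented signature.
        xs, ys = batch
        return F.mse_loss(model(xs), ys), {}

    # Base callback that accumulates the LLC estimate across chains and draws.
    # NOTE: keyword names below (e.g. nbeta) are version-dependent assumptions.
    llc_estimator = LLCEstimator(num_chains=4, num_draws=100, nbeta=10.0, device="cpu")

    sample(
        model,
        loader,
        callbacks=[llc_estimator],
        evaluate=evaluate,
        sampling_method=SGLD,
        optimizer_kwargs=dict(lr=1e-4),
        num_chains=4,
        num_draws=100,
        device="cpu",  # on TPU, pass an XLA device, e.g. torch_xla.core.xla_model.xla_device()
    )

    # After sampling, the stats of interest live on the callback:
    results = llc_estimator.get_results()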
- devinterp.backends.tpu.slt.sampler.sample_single_chain(ref_model: torch.nn.modules.module.Module, loader: torch.utils.data.dataloader.DataLoader, evaluate: Callable[[torch.nn.modules.module.Module, torch.Tensor], Tuple[torch.Tensor, Dict[str, Any]]], num_draws=100, num_burnin_steps=0, num_steps_bw_draws=1, gradient_accumulation_steps: int = 1, sampling_method: Type[torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: Dict | None = None, scheduler_cls: Type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_kwargs: Dict | None = None, chain: int = 0, seed: int | None = None, verbose=True, device: str | torch.device = device(type='xla'), callbacks: List[Callable] = [], optimize_over_per_model_param: Dict[str, torch.Tensor] | None = None, init_noise: float | None = None, use_alternate_batching=False, **kwargs)¶
Base function to sample a single chain. This function is called by the sample function in both single-core and multi-core setups.
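For reference, a hypothetical direct invocation (reusing the model, loader, evaluate, and llc_estimator from the sketch above); in normal use, sample dispatches to this function once per chain, so calling it directly is rarely necessary:

    from devinterp.backends.tpu.slt.sampler import sample_single_chain

    sample_single_chain(
        ref_model=model,
        loader=loader,
        evaluate=evaluate,  # same (loss, extras) contract as above
        num_draws=100,
        chain=0,
        seed=0,
        callbacks=[llc_estimator],
    )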