devinterp.backends.tpu.slt package

Submodules

devinterp.backends.tpu.slt.sampler module

devinterp.backends.tpu.slt.sampler.sample(model: ~torch.nn.modules.module.Module, loader: ~torch.utils.data.dataloader.DataLoader, callbacks: ~typing.List[~devinterp.slt.callback.SamplerCallback] | ~typing.Dict[str, ~devinterp.slt.callback.SamplerCallback], evaluate: ~typing.Callable = <function <lambda>>, sampling_method: ~typing.Type[~torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: ~typing.Dict[str, float | ~typing.Literal['adaptive']] | None = None, scheduler_cls: ~typing.Type[~torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_kwargs: ~typing.Dict[str, ~typing.Any] | None = None, num_draws: int = 100, num_chains: int = 10, num_burnin_steps: int = 0, num_steps_bw_draws: int = 1, gradient_accumulation_steps: int = 1, cores: int = 1, seed: int | ~typing.List[int] | None = None, device: ~torch.device | str = device(type='cpu'), verbose: bool = True, optimize_over_per_model_param: ~typing.Dict[str, ~torch.Tensor] | None = None, batch_size: int = 32, init_noise: float | None = None, shuffle: bool = True, use_alternate_batching=False, init_loss=None, **kwargs)

Sample model weights using a given sampling_method, supporting multiple chains and cores, and calculate the observables (loss, LLC, etc.) for each callback passed in. The update, finalize and sample methods of each SamplerCallback are called during sampling, after sampling, and when sampler_callback_object.get_results() is called, respectively.

After this function returns, the statistics of interest are stored in the callback objects.
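The callback contract can be illustrated with a toy stand-in. This sketch is hypothetical: `RunningMeanLoss` and `toy_sample` are not part of devinterp, and the method names `update`, `finalize` and `get_results` are assumptions modeled on the description above. The real SamplerCallback interface may use different names and arguments.

```python
from typing import Dict, List, Sequence


class RunningMeanLoss:
    """Hypothetical callback that averages the loss seen by each chain."""

    def __init__(self, num_chains: int) -> None:
        self.sums: List[float] = [0.0] * num_chains
        self.counts: List[int] = [0] * num_chains
        self.means: List[float] = []

    def update(self, chain: int, draw: int, loss: float) -> None:
        # Called once per draw while sampling is in progress.
        self.sums[chain] += loss
        self.counts[chain] += 1

    def finalize(self) -> None:
        # Called once after all chains have finished.
        self.means = [s / c for s, c in zip(self.sums, self.counts)]

    def get_results(self) -> Dict[str, List[float]]:
        # Read by the user after sampling; the sampler itself returns None.
        return {"mean_loss": self.means}


def toy_sample(losses_per_chain: List[List[float]], callbacks: Sequence[RunningMeanLoss]) -> None:
    """Feed precomputed losses to the callbacks the way a sampler would."""
    for chain, losses in enumerate(losses_per_chain):
        for draw, loss in enumerate(losses):
            for cb in callbacks:
                cb.update(chain=chain, draw=draw, loss=loss)
    for cb in callbacks:
        cb.finalize()


cb = RunningMeanLoss(num_chains=2)
toy_sample([[1.0, 3.0], [2.0, 4.0]], [cb])
print(cb.get_results())  # {'mean_loss': [2.0, 3.0]}
```

Because `toy_sample` returns nothing, everything the caller needs is read back from the callback object, which mirrors how sample() returns None.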

Parameters:
  • model (torch.nn.Module) – The neural network model.

  • loader (DataLoader) – DataLoader for input data.

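  • evaluate (Callable, optional) – Function that takes the model and a batch of data and returns the loss together with a dictionary of auxiliary outputs (annotated as Callable[[Module, Tensor], Tuple[Tensor, Dict[str, Any]]] on sample_single_chain). Default is a built-in lambda.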

  • callbacks (list[SamplerCallback] or dict[str, SamplerCallback]) – Callbacks, each of type SamplerCallback, given either as a list or as a dict keyed by strings.

  • sampling_method (type[torch.optim.Optimizer], optional) – Sampling method to use, given as a PyTorch optimizer class. Default is SGLD

  • optimizer_kwargs (dict, optional) – Keyword arguments for the PyTorch optimizer used as the sampler; values are floats or the string ‘adaptive’. Default is None (the standard parameters defined in the SGLD class are used)

  • num_draws (int, optional) – Number of samples to draw per chain. Default is 100

  • num_chains (int, optional) – Number of chains to run. Default is 10

  • num_burnin_steps (int, optional) – Number of burn-in steps before sampling. Default is 0

  • num_steps_bw_draws (int, optional) – Number of sampler steps between consecutive draws. Default is 1

  • cores (int, optional) – Number of cores for parallel execution. Default is 1

  • seed (int or list[int], optional) – Random seed(s) for sampling. If passed, each chain gets a different, deterministic seed. Default is None

  • device (str or torch.device, optional) – Device to perform computations on, e.g., ‘cpu’ or ‘cuda’. Default is ‘cpu’

  • verbose (bool, optional) – Whether to display the progress of each sampling chain. Default is True
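The interaction of seed, num_burnin_steps, num_steps_bw_draws and num_draws can be sketched in plain Python. This is an illustration under stated assumptions, not the devinterp code: it assumes an integer seed is expanded to `seed + chain` for each chain, that a list must supply one seed per chain, and that a draw is recorded every num_steps_bw_draws steps once burn-in has finished. `chain_seeds` and `draw_step_indices` are hypothetical helpers.

```python
from typing import List, Optional, Union


def chain_seeds(seed: Optional[Union[int, List[int]]], num_chains: int) -> List[Optional[int]]:
    """Derive one seed per chain (assumed scheme: seed + chain index)."""
    if seed is None:
        return [None] * num_chains  # unseeded: chains are not reproducible
    if isinstance(seed, int):
        return [seed + chain for chain in range(num_chains)]
    if len(seed) != num_chains:
        raise ValueError("Pass exactly one seed per chain.")
    return list(seed)


def draw_step_indices(num_draws: int, num_burnin_steps: int, num_steps_bw_draws: int) -> List[int]:
    """Return the 0-based sampler steps at which a draw is recorded (assumed schedule)."""
    total_steps = num_burnin_steps + num_draws * num_steps_bw_draws
    return [
        step
        for step in range(total_steps)
        if step >= num_burnin_steps
        and (step - num_burnin_steps) % num_steps_bw_draws == 0
    ]


print(chain_seeds(42, 3))          # [42, 43, 44]
print(draw_step_indices(3, 2, 2))  # [2, 4, 6]
```

Under these assumptions each chain takes num_burnin_steps + num_draws * num_steps_bw_draws steps, so increasing num_steps_bw_draws (thinning) raises the cost of a run without changing the number of recorded draws.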

Raises:
  • ValueError – if derivative callbacks (e.g. OnlineLossStatistics()) are passed before the base callbacks they depend on (e.g. OnlineLLCEstimator())

  • Warning – if num_burnin_steps < num_draws

  • Warning – if num_draws > len(loader)

  • Warning – if using seeded runs

Returns:

None (access LLCs or other observables through callback_object.get_results())

devinterp.backends.tpu.slt.sampler.sample_single_chain(ref_model: ~torch.nn.modules.module.Module, loader: ~torch.utils.data.dataloader.DataLoader, evaluate: ~typing.Callable[[~torch.nn.modules.module.Module, ~torch.Tensor], ~typing.Tuple[~torch.Tensor, ~typing.Dict[str, ~typing.Any]]], num_draws=100, num_burnin_steps=0, num_steps_bw_draws=1, gradient_accumulation_steps: int = 1, sampling_method: ~typing.Type[~torch.optim.optimizer.Optimizer] = <class 'devinterp.optim.sgld.SGLD'>, optimizer_kwargs: ~typing.Dict | None = None, scheduler_cls: ~typing.Type[~torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_kwargs: ~typing.Dict | None = None, chain: int = 0, seed: int | None = None, verbose=True, device: str | ~torch.device = device(type='xla'), callbacks: ~typing.List[~typing.Callable] = [], optimize_over_per_model_param: ~typing.Dict[str, ~torch.Tensor] | None = None, init_noise: float | None = None, use_alternate_batching=False, **kwargs)

Base function to sample a single chain. This function is called by the sample function on both single and multi-core setups.
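A minimal single-chain loop can be sketched on a one-dimensional toy problem. The sketch replaces the model and DataLoader with a scalar weight and the quadratic loss L(w) = (w - 1)^2, and applies a textbook SGLD step with a localization term that pulls toward the initial weight w0. The names `sample_single_chain_toy`, `lr`, `nbeta` and `localization` are assumptions for illustration, and the update rule may not match devinterp.optim.sgld.SGLD exactly. Callbacks are plain callables, in line with the `callbacks: List[Callable]` annotation above.

```python
import math
import random
from typing import Callable, Sequence


def toy_loss(w: float) -> float:
    # Stand-in for evaluate(): a quadratic loss with its minimum at w = 1.
    return (w - 1.0) ** 2


def toy_grad(w: float) -> float:
    # Analytic gradient of toy_loss.
    return 2.0 * (w - 1.0)


def sample_single_chain_toy(
    w0: float,
    num_draws: int = 5,
    num_burnin_steps: int = 0,
    num_steps_bw_draws: int = 1,
    lr: float = 1e-3,
    nbeta: float = 10.0,
    localization: float = 1.0,
    seed: int = 0,
    callbacks: Sequence[Callable[[int, float], None]] = (),
) -> None:
    """Run one toy SGLD chain and report the loss at each draw to the callbacks."""
    rng = random.Random(seed)  # per-chain RNG, so a fixed seed reproduces the chain
    w = w0
    draw = 0
    total_steps = num_burnin_steps + num_draws * num_steps_bw_draws
    for step in range(total_steps):
        # SGLD step: half-step along the tempered gradient plus localization,
        # then Gaussian noise with variance lr.
        drift = nbeta * toy_grad(w) + localization * (w - w0)
        w = w - 0.5 * lr * drift + rng.gauss(0.0, math.sqrt(lr))
        past_burnin = step >= num_burnin_steps
        if past_burnin and (step - num_burnin_steps) % num_steps_bw_draws == 0:
            for cb in callbacks:
                cb(draw, toy_loss(w))
            draw += 1


recorded = []
sample_single_chain_toy(
    w0=1.0,
    num_draws=4,
    num_burnin_steps=3,
    num_steps_bw_draws=2,
    callbacks=[lambda d, loss: recorded.append((d, loss))],
)
print([d for d, _ in recorded])  # [0, 1, 2, 3]
```

As with sample(), the function returns None and its only output flows through the callbacks.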

Module contents