Source code for devinterp.slt.covariance

"""Batched covariance and correlation computation.
Port of aether's covariance_utils.py — torch-based batched correlation
for BIF computation.
"""

from __future__ import annotations

import numpy as np
import torch
import xarray as xr


def batch_cov(batched_a: torch.Tensor, batched_b: torch.Tensor) -> torch.Tensor:
    """Joint sample covariance for every (a, b) pairing of the two batches.

    For element ``i`` of ``batched_a`` and element ``j`` of ``batched_b`` the
    two groups of series are stacked (a's series first, then b's) and the
    covariance of the stacked series is computed, normalised by
    ``observations - 1``.

    Args:
        batched_a: shape (n_a, series_a, observations)
        batched_b: shape (n_b, series_b, observations)

    Returns:
        shape (n_a, n_b, series_a + series_b, series_a + series_b)

    Raises:
        AssertionError: if either input is not 3-D, or the two inputs
            disagree on the number of observations.
    """
    assert batched_a.ndim == 3 and batched_b.ndim == 3
    assert batched_a.shape[-1] == batched_b.shape[-1]

    count_a, rows_a, num_obs = batched_a.shape
    count_b, rows_b = batched_b.shape[:2]

    # Remove the per-series mean along the observation axis.
    centered_a = batched_a - batched_a.mean(dim=-1, keepdim=True)
    centered_b = batched_b - batched_b.mean(dim=-1, keepdim=True)

    # Broadcast both operands onto the common (n_a, n_b) pairing grid;
    # expand() creates views, the copy only happens in torch.cat below.
    left = centered_a.unsqueeze(1).expand(count_a, count_b, rows_a, num_obs)
    right = centered_b.unsqueeze(0).expand(count_a, count_b, rows_b, num_obs)
    stacked = torch.cat((left, right), dim=-2)

    # Gram matrix of the centred series, scaled by (observations - 1).
    gram = torch.matmul(stacked, stacked.transpose(-2, -1))
    return gram / (num_obs - 1)
def batch_corrcoef(batched_a: torch.Tensor, batched_b: torch.Tensor) -> torch.Tensor:
    """Pearson correlation for every (a, b) pairing of the two batches.

    The joint covariance from ``batch_cov`` is divided by the outer product
    of the per-series standard deviations, after which the diagonal is reset
    to one.

    Args:
        batched_a: shape (n_a, series_a, observations), float32 or float64
        batched_b: shape (n_b, series_b, observations), float32 or float64

    Returns:
        shape (n_a, n_b, series_a + series_b, series_a + series_b)

    Raises:
        AssertionError: if either input is not float32 or float64.
    """
    for operand in (batched_a, batched_b):
        assert operand.dtype in (torch.float32, torch.float64)

    result = batch_cov(batched_a, batched_b)

    # The diagonal of a covariance matrix holds the variances.
    sigma = torch.diagonal(result, dim1=-2, dim2=-1).sqrt()

    # Normalise in place by sigma_i * sigma_j; the denominator is fully
    # evaluated before the in-place division touches `result`.
    result.div_(sigma[..., :, None] * sigma[..., None, :])

    # Zero the diagonal, then add the identity so each diagonal entry is 1
    # instead of the rounded value left by the division.
    identity = torch.eye(result.size(-1), dtype=result.dtype, device=result.device)
    result.mul_(1 - identity)
    result.add_(identity)
    return result
def xr_corrcoef_with_torch_backend(
    seq1: xr.DataArray,
    seq2: xr.DataArray,
    *,
    device: str | torch.device,
    compute_covariance: bool = False,
) -> xr.Dataset:
    """Correlation (or covariance) matrix over the rows of seq1 and seq2.

    The two arrays are concatenated along their first dimension, moved to
    ``device`` as a torch tensor, and reduced with ``torch.corrcoef`` (or
    ``torch.cov`` when ``compute_covariance`` is set).

    Args:
        seq1: 2-D DataArray (e.g. batch, chain_draw)
        seq2: 2-D DataArray with same dims as seq1
        device: torch device for computation
        compute_covariance: if True, compute covariance instead of correlation

    Returns:
        Dataset with a single "correlation" variable whose dims are
        (dim_1, dim_1_T), where dim_1 is the name of the first dimension of
        seq1. The variable keeps the name "correlation" even when
        ``compute_covariance`` is True. Both axes carry the first-dimension
        coordinate of the concatenated array.

    Raises:
        AssertionError: if either input is not 2-D or their dims differ.
    """
    assert seq1.ndim == 2 and seq2.ndim == 2
    assert seq1.dims == seq2.dims

    row_dim = str(seq1.dims[0])
    col_dim = f"{row_dim}_T"

    stacked = xr.concat([seq1, seq2], dim=row_dim)
    values = torch.as_tensor(stacked.data, device=device)

    # Select the reduction up front; both take the rows as variables.
    reducer = torch.cov if compute_covariance else torch.corrcoef
    matrix = reducer(values).cpu().numpy()

    labels = stacked[row_dim]
    return xr.Dataset(
        data_vars={"correlation": ((row_dim, col_dim), matrix)},
        coords={
            row_dim: np.asarray(labels),
            col_dim: np.asarray(labels),
        },
    )