# Source code for devinterp.slt.writing

"""Zarr writing infrastructure for streaming sampling results.

ZarrWriter buffers per-chain data and writes to zarr arrays, optionally
using a thread pool for async I/O.
"""

from __future__ import annotations

import logging
from collections.abc import Iterator
from concurrent import futures
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field
from typing import assert_never

import numpy as np
import torch
import zarr
from zarr.core.metadata import ArrayV3Metadata

logger = logging.getLogger(__name__)


# ─── Dtype validation (used by zarr_schema.py) ──────────────────────────────

# Dtype names accepted for zarr arrays; each string resolves to both a numpy
# dtype and a same-named torch attribute (see _get_or_create_buffer).
SAFE_DTYPE_STRINGS = frozenset(
    "float16 float32 float64 "
    "int8 int16 int32 int64 "
    "uint8 uint16 uint32 uint64 "
    "bool complex64 complex128".split()
)


def validate_dtype_str_torch_numpy_compat(dtype_str: str) -> None:
    """Raise ValueError unless *dtype_str* is one of SAFE_DTYPE_STRINGS."""
    if dtype_str in SAFE_DTYPE_STRINGS:
        return
    raise ValueError(
        f"Unsupported dtype: '{dtype_str}'. Use one of: {sorted(SAFE_DTYPE_STRINGS)}"
    )


# ─── Zarr dimension helpers ──────────────────────────────────────────────────


def _zarr_dims(arr: zarr.Array) -> tuple[str, ...]:
    """Return the named dimensions of a v3 zarr array (all must be non-None)."""
    meta = arr.metadata
    assert isinstance(meta, ArrayV3Metadata)
    dims = meta.dimension_names or ()
    assert all(name is not None for name in dims), (
        f"All dimensions must be named, got {dims}"
    )
    return dims  # type: ignore[return-value]


def _indices_from_dims(
    arr: zarr.Array, **dim_indices: int | slice
) -> tuple[int | slice, ...]:
    """Build a positional selection tuple from keyword dim selections.

    Dimensions not mentioned in *dim_indices* default to slice(None);
    unknown dimension names raise ValueError.
    """
    dims = _zarr_dims(arr)
    selection: list[int | slice] = []
    for dim in dims:
        # pop() so leftovers identify names that don't exist on the array
        selection.append(dim_indices.pop(dim, slice(None)))
    if dim_indices:
        raise ValueError(
            f"Unexpected dimensions: {list(dim_indices)}. Valid: {list(dims)}"
        )
    return tuple(selection)


# ─── Write execution ────────────────────────────────────────────────────────


def _execute_write(
    arr: zarr.Array,
    data: np.ndarray | torch.Tensor,
    dim_indices: dict[str, int | slice],
    *,
    executor: ThreadPoolExecutor | None = None,
    write_futures: list[Future] | None = None,
) -> None:
    """Write *data* into *arr* at the selection named by *dim_indices*.

    When *executor* is given the write is submitted asynchronously and its
    Future appended to *write_futures*; otherwise it runs inline.
    """
    selection = _indices_from_dims(arr, **dim_indices)

    if isinstance(data, torch.Tensor):
        # Force to CPU numpy and copy so an async write never races with the
        # caller reusing the tensor's storage.
        payload = data.numpy(force=True).copy()
    elif isinstance(data, np.ndarray):
        payload = data
    else:
        assert_never(data)

    if executor is None:
        arr.set_basic_selection(selection, payload)
    else:
        assert write_futures is not None
        write_futures.append(
            executor.submit(arr.set_basic_selection, selection, payload)
        )


# ─── Buffer ──────────────────────────────────────────────────────────────────


class _Buffer:
    """Accumulates complete rows before flushing to zarr."""

    def __init__(self, buffer: torch.Tensor):
        self.buffer = buffer
        self.row_ptr = 0
        self.size = 0

    @property
    def capacity(self) -> int:
        return self.buffer.shape[0]

    def is_full(self) -> bool:
        return self.size >= self.capacity

    def push(self, data: torch.Tensor, row_index: int) -> None:
        if self.is_full():
            raise RuntimeError("Buffer full. Call flush before writing more data.")
        assert row_index == self.row_ptr + self.size, (
            f"Row index mismatch: expected {self.row_ptr + self.size}, got {row_index}"
        )
        self.buffer[self.size].copy_(data)
        self.size += 1

    def flush(self) -> tuple[np.ndarray, slice] | None:
        if self.size == 0:
            return None
        np_data = self.buffer[: self.size].numpy(force=True).copy()
        row_slice = slice(self.row_ptr, self.row_ptr + self.size)
        self.row_ptr += self.size
        self.size = 0
        return np_data, row_slice


# ─── ZarrWriter ──────────────────────────────────────────────────────────────


@dataclass
class ZarrWriter:
    """Buffered writer for streaming per-chain data to zarr arrays.

    Use ZarrWriter.open() for a context manager that manages the thread pool.

    Attributes:
        arrays: zarr arrays keyed by path; dimension 0 of buffered arrays is "chain".
        chain_buffer_size: rows staged per (path, chain) before a flush writes them.
        buffer_device: device on which staging tensors are allocated.
        executor: optional thread pool; when None every write is synchronous.
    """

    arrays: dict[str, zarr.Array]
    chain_buffer_size: int
    buffer_device: torch.device
    executor: ThreadPoolExecutor | None = None
    _write_futures: list[Future] = field(default_factory=list, init=False)
    _chain_buffers: dict[tuple[str, int], _Buffer] = field(
        default_factory=dict, init=False
    )

    @classmethod
    @contextmanager
    def open(
        cls,
        arrays: dict[str, zarr.Array],
        chain_buffer_size: int,
        buffer_device: torch.device,
        max_write_threads: int = 4,
    ) -> Iterator[ZarrWriter]:
        """Yield a writer; on exit flush all buffers and wait for async writes.

        max_write_threads=0 disables the pool, making every write synchronous.
        """
        pool = (
            ThreadPoolExecutor(
                max_workers=max_write_threads, thread_name_prefix="zarr_write"
            )
            if max_write_threads > 0
            else nullcontext()
        )
        with pool as executor:
            if executor is not None:
                logger.info("zarr_write thread pool: %d max threads", max_write_threads)
            writer = cls(
                arrays=arrays,
                chain_buffer_size=chain_buffer_size,
                buffer_device=buffer_device,
                executor=executor,
            )
            try:
                yield writer
            finally:
                # Drain partially-filled buffers, then surface any async errors
                # before the pool shuts down.
                writer.flush_all_buffers()
                writer.wait_for_writes()

    def write(
        self, path: str, data: np.ndarray | torch.Tensor, /, **dim_indices: int | slice
    ) -> None:
        """Immediate write (for scalars, init_loss, fixed observables)."""
        _execute_write(
            self.arrays[path],
            data,
            dim_indices,
            executor=self.executor,
            write_futures=self._write_futures,
        )

    def push(
        self,
        path: str,
        data: torch.Tensor,
        /,
        *,
        chain: int,
        **dim_indices: int | slice,
    ) -> None:
        """Push a complete row into a chain buffer.

        The array's first dimension must be "chain" and dim_indices must carry
        an integer index for the second (row) dimension; rows must arrive in
        sequential order per (path, chain).
        """
        arr = self.arrays[path]
        dims = _zarr_dims(arr)
        assert dims[0] == "chain"
        row_dim = dims[1]
        assert row_dim in dim_indices
        row_index = dim_indices[row_dim]
        assert isinstance(row_index, int)
        buf = self._get_or_create_buffer(path, chain)
        buf.push(data, row_index=row_index)

    def flush_full_buffers(self) -> None:
        """Flush only buffers that have reached capacity."""
        self._flush_buffers(only_full=True)

    def flush_all_buffers(self) -> None:
        """Flush every non-empty buffer, full or not."""
        self._flush_buffers(only_full=False)

    def wait_for_writes(self) -> None:
        """Block until all pending async writes complete; re-raise failures.

        Raises the lone exception directly, or an ExceptionGroup when several
        writes failed. Clears the pending-futures list either way.
        """
        if not self._write_futures:
            return
        done, _ = futures.wait(self._write_futures, return_when=futures.ALL_COMPLETED)
        # Single exception() call per future: the result is cached on our side.
        exceptions = [e for f in done if (e := f.exception()) is not None]
        self._write_futures.clear()
        if exceptions:
            if len(exceptions) == 1:
                raise exceptions[0]
            raise ExceptionGroup(  # type: ignore[arg-type]
                f"{len(exceptions)} zarr write(s) failed", exceptions
            )
        logger.info("All %d zarr writes completed successfully", len(done))

    def _flush_buffers(self, only_full: bool) -> None:
        # Write each buffer's staged rows to {chain, row-slice} of its array.
        for (path, chain), buf in self._chain_buffers.items():
            if only_full and not buf.is_full():
                continue
            result = buf.flush()
            if result is None:
                continue
            np_data, row_slice = result
            arr = self.arrays[path]
            dims = _zarr_dims(arr)
            _execute_write(
                arr,
                np_data,
                {"chain": chain, dims[1]: row_slice},
                executor=self.executor,
                write_futures=self._write_futures,
            )

    def _get_or_create_buffer(self, path: str, chain: int) -> _Buffer:
        # Lazily allocate one staging tensor per (path, chain), shaped like a
        # chain_buffer_size-row slab of the array (dims beyond chain/row).
        key = (path, chain)
        if key not in self._chain_buffers:
            arr = self.arrays[path]
            self._chain_buffers[key] = _Buffer(
                buffer=torch.zeros(
                    self.chain_buffer_size,
                    *arr.shape[2:],
                    dtype=getattr(torch, str(arr.dtype)),
                    device=self.buffer_device,
                ),
            )
        return self._chain_buffers[key]