diff --git a/src/zarr/core/_transforms/__init__.py b/src/zarr/core/_transforms/__init__.py new file mode 100644 index 0000000000..6ac4d85343 --- /dev/null +++ b/src/zarr/core/_transforms/__init__.py @@ -0,0 +1,39 @@ +"""Composable, lazy coordinate transforms for zarr array indexing. + +This package implements TensorStore-inspired index transforms. The core idea: +every indexing operation (slicing, fancy indexing, etc.) produces a coordinate +mapping from user space to storage space. These mappings compose lazily - no +I/O until you explicitly read or write. + +Private package: this module is not part of the public zarr API. The leading +underscore in the package name signals this. Importers outside this package +must be limited to other private zarr modules. + +Key types: + +- `IndexDomain` -- a rectangular region of integer coordinates +- `IndexTransform` -- maps input coordinates to storage coordinates +- `ConstantMap`, `DimensionMap`, `ArrayMap` -- the three ways a single + output dimension can depend on the input (see `output_map.py`) +- `compose` -- chain two transforms into one +""" + +from zarr.core._transforms.composition import compose +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ( + ArrayMap, + ConstantMap, + DimensionMap, + OutputIndexMap, +) +from zarr.core._transforms.transform import IndexTransform + +__all__ = [ + "ArrayMap", + "ConstantMap", + "DimensionMap", + "IndexDomain", + "IndexTransform", + "OutputIndexMap", + "compose", +] diff --git a/src/zarr/core/_transforms/chunk_resolution.py b/src/zarr/core/_transforms/chunk_resolution.py new file mode 100644 index 0000000000..bed247fcaf --- /dev/null +++ b/src/zarr/core/_transforms/chunk_resolution.py @@ -0,0 +1,210 @@ +"""Chunk resolution — mapping transforms to chunk-level I/O. 
+ +Given an `IndexTransform` (which coordinates a user wants to access) and a +`ChunkGrid` (how storage is divided into chunks), chunk resolution answers: + + For each chunk, which storage coordinates does this transform touch, + and where do those values land in the output buffer? + +The algorithm is: + +1. **Enumerate candidate chunks** — determine which chunks could possibly + be touched by the transform's output coordinate ranges. + +2. **Intersect** — for each candidate chunk, call + `transform.intersect(chunk_domain)` to restrict the transform to + coordinates within that chunk. If the intersection is empty, skip it. + +3. **Translate** — shift the restricted transform to chunk-local coordinates + via `transform.translate(-chunk_origin)`. + +4. **Yield** — produce `(chunk_coords, local_transform, surviving_indices)` + triples that the codec pipeline consumes. + +`sub_transform_to_selections` bridges from the transform representation +back to the raw `(chunk_selection, out_selection, drop_axes)` tuples that +the current codec pipeline expects. This bridge will go away when the codec +pipeline accepts transforms natively. +""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING, Any + +import numpy as np + +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import IndexTransform + +if TYPE_CHECKING: + from collections.abc import Iterator + + from zarr.core.chunk_grids import ChunkGrid + +ChunkTransformResult = tuple[ + tuple[int, ...], + IndexTransform, + np.ndarray[Any, np.dtype[np.intp]] | None, +] + + +def iter_chunk_transforms( + transform: IndexTransform, + chunk_grid: ChunkGrid, +) -> Iterator[ChunkTransformResult]: + """Resolve a composed IndexTransform against a ChunkGrid. + + Yields `(chunk_coords, sub_transform, out_indices)` triples: + + - `chunk_coords`: which chunk to access. 
+ - `sub_transform`: maps output buffer coords to chunk-local coords. + - `out_indices`: for vectorized/array indexing, the output scatter + indices (integer array). `None` for basic/slice indexing. + """ + dim_grids = chunk_grid._dimensions + + # Enumerate all possible chunks via cartesian product of per-dim chunk ranges + # For each candidate chunk, intersect the transform with the chunk domain. + # The transform.intersect method handles both orthogonal and vectorized cases. + chunk_ranges: list[range] = [] + for out_dim, m in enumerate(transform.output): + dg = dim_grids[out_dim] + if isinstance(m, ConstantMap): + # Single chunk + c = dg.index_to_chunk(m.offset) + chunk_ranges.append(range(c, c + 1)) + elif isinstance(m, DimensionMap): + d = m.input_dimension + dim_lo = transform.domain.inclusive_min[d] + dim_hi = transform.domain.exclusive_max[d] + if dim_lo >= dim_hi: + return # empty domain + # DimensionMap.stride is always positive (enforced by __post_init__). + s_min = m.offset + m.stride * dim_lo + s_max = m.offset + m.stride * (dim_hi - 1) + first = dg.index_to_chunk(s_min) + last = dg.index_to_chunk(s_max) + chunk_ranges.append(range(first, last + 1)) + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union + storage = m.offset + m.stride * m.index_array + flat = storage.ravel().astype(np.intp) + chunk_ids = dg.indices_to_chunks(flat) + first = int(chunk_ids.min()) + last = int(chunk_ids.max()) + chunk_ranges.append(range(first, last + 1)) + + for chunk_coords_tuple in itertools.product(*chunk_ranges): + chunk_coords = tuple(int(c) for c in chunk_coords_tuple) + + # Build the chunk domain in storage space + chunk_min: list[int] = [] + chunk_max: list[int] = [] + chunk_shift: list[int] = [] + for out_dim, c in enumerate(chunk_coords): + dg = dim_grids[out_dim] + c_start = dg.chunk_offset(c) + c_size = dg.chunk_size(c) + chunk_min.append(c_start) + chunk_max.append(c_start + c_size) + chunk_shift.append(-c_start) + + 
chunk_domain = IndexDomain( + inclusive_min=tuple(chunk_min), + exclusive_max=tuple(chunk_max), + ) + + # Intersect transform with chunk domain + result = transform.intersect(chunk_domain) + if result is None: + continue + + restricted, surviving = result + + # Translate to chunk-local coordinates + local = restricted.translate(tuple(chunk_shift)) + + yield (chunk_coords, local, surviving) + + +def sub_transform_to_selections( + sub_transform: IndexTransform, + out_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None, +) -> tuple[ + tuple[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...], + tuple[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]], ...], + tuple[int, ...], +]: + """Convert a chunk-local sub-transform to raw selections for the codec pipeline. + + Parameters + ---------- + sub_transform + A chunk-local IndexTransform (output maps already translated to + chunk-local coordinates). + out_indices + For vectorized indexing: the output scatter indices for this chunk. + None for orthogonal/basic indexing. + + Returns + ------- + tuple + `(chunk_selection, out_selection, drop_axes)` + """ + chunk_sel: list[int | slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = [] + drop_axes: list[int] = [] + + for m in sub_transform.output: + if isinstance(m, ConstantMap): + chunk_sel.append(m.offset) + elif isinstance(m, DimensionMap): + # DimensionMap.stride is always positive (enforced by __post_init__). 
+ dim_lo = sub_transform.domain.inclusive_min[m.input_dimension] + dim_hi = sub_transform.domain.exclusive_max[m.input_dimension] + start = m.offset + m.stride * dim_lo + stop = m.offset + m.stride * dim_hi + chunk_sel.append(slice(start, stop, m.stride)) + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union + if m.offset == 0 and m.stride == 1: + chunk_sel.append(m.index_array) + else: + storage_coords = m.offset + m.stride * m.index_array + chunk_sel.append(storage_coords.astype(np.intp)) + + # Build out_sel: one entry per non-dropped output dim. + out_sel: list[slice | np.ndarray[tuple[int, ...], np.dtype[np.intp]]] = [] + + # Vectorized: 2+ ArrayMaps that share at least one input dimension are + # correlated; they all index into a single shared scatter array. + is_vectorized = False + if out_indices is not None: + seen_input_dims: set[int] = set() + for m in sub_transform.output: + if isinstance(m, ArrayMap): + if seen_input_dims & set(m.input_dimensions): + is_vectorized = True + break + seen_input_dims.update(m.input_dimensions) + + if is_vectorized: + assert out_indices is not None + out_sel.append(out_indices) + else: + for m in sub_transform.output: + if isinstance(m, ConstantMap): + continue + if isinstance(m, DimensionMap): + lo = sub_transform.domain.inclusive_min[m.input_dimension] + hi = sub_transform.domain.exclusive_max[m.input_dimension] + out_sel.append(slice(lo, hi)) + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union + if out_indices is not None: + # Orthogonal ArrayMap: out_indices has the surviving positions + out_sel.append(out_indices) + else: + out_sel.append(slice(0, len(m.index_array))) + + return tuple(chunk_sel), tuple(out_sel), tuple(drop_axes) diff --git a/src/zarr/core/_transforms/composition.py b/src/zarr/core/_transforms/composition.py new file mode 100644 index 0000000000..86c05503e2 --- /dev/null +++ b/src/zarr/core/_transforms/composition.py @@ 
-0,0 +1,137 @@ +from __future__ import annotations + +import numpy as np + +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap +from zarr.core._transforms.transform import IndexTransform + + +def compose(outer: IndexTransform, inner: IndexTransform) -> IndexTransform: + """Compose two IndexTransforms. + + `outer` maps user coords (rank m) to intermediate coords (rank n). + `inner` maps intermediate coords (rank n) to storage coords (rank p). + The result maps user coords (rank m) to storage coords (rank p). + + Precondition: `outer.output_rank == inner.domain.ndim`. + """ + if outer.output_rank != inner.domain.ndim: + raise ValueError( + f"outer output rank ({outer.output_rank}) must match inner input rank " + f"({inner.domain.ndim})" + ) + + result_output = [_compose_single(outer, inner_map) for inner_map in inner.output] + + return IndexTransform(domain=outer.domain, output=tuple(result_output)) + + +def _compose_single(outer: IndexTransform, inner_map: OutputIndexMap) -> OutputIndexMap: + """Compose a single inner output map with the full outer transform.""" + if isinstance(inner_map, ConstantMap): + return ConstantMap(offset=inner_map.offset) + + if isinstance(inner_map, DimensionMap): + return _compose_dimension(outer, inner_map) + + if isinstance(inner_map, ArrayMap): + return _compose_array(outer, inner_map) + + raise TypeError(f"Unknown output map type: {type(inner_map)}") # pragma: no cover + + +def _compose_dimension(outer: IndexTransform, inner_map: DimensionMap) -> OutputIndexMap: + """Compose when inner is a DimensionMap. 
+ + storage = offset_i + stride_i * intermediate[dim_i] + where intermediate[dim_i] = outer.output[dim_i](user_input) + """ + dim_i = inner_map.input_dimension + offset_i = inner_map.offset + stride_i = inner_map.stride + outer_map = outer.output[dim_i] + + if isinstance(outer_map, ConstantMap): + return ConstantMap(offset=offset_i + stride_i * outer_map.offset) + + if isinstance(outer_map, DimensionMap): + return DimensionMap( + input_dimension=outer_map.input_dimension, + offset=offset_i + stride_i * outer_map.offset, + stride=stride_i * outer_map.stride, + ) + + if isinstance(outer_map, ArrayMap): + return ArrayMap( + index_array=outer_map.index_array, + input_dimensions=outer_map.input_dimensions, + offset=offset_i + stride_i * outer_map.offset, + stride=stride_i * outer_map.stride, + ) + + raise TypeError(f"Unknown output map type: {type(outer_map)}") # pragma: no cover + + +def _compose_array(outer: IndexTransform, inner_map: ArrayMap) -> OutputIndexMap: + """Compose when inner is an ArrayMap. + + storage = offset_i + stride_i * arr_i[intermediate[input_dimensions[0]], + intermediate[input_dimensions[1]], ...] + + For each axis k of arr_i, the corresponding intermediate dim is + inner_map.input_dimensions[k] = d. We need to evaluate arr_i over the + product of `outer.output[d]` for each such d. + + All-constant outer: collapse to a single ConstantMap. + + Single 1-D inner array, single outer output: evaluate arr_i along the + one outer output's parameterization. + """ + arr_i = inner_map.index_array + offset_i = inner_map.offset + stride_i = inner_map.stride + in_dims_i = inner_map.input_dimensions + + # All-constant outer: arr_i is evaluated at a single fixed point. + if all(isinstance(m, ConstantMap) for m in outer.output): + idx = tuple(outer.output[d].offset for d in in_dims_i) + value = int(arr_i[idx]) + return ConstantMap(offset=offset_i + stride_i * value) + + # 1-D inner array, single referenced outer output. 
+ if len(in_dims_i) == 1: + dim_i = in_dims_i[0] + outer_map = outer.output[dim_i] + + if isinstance(outer_map, DimensionMap): + # Evaluate arr_i at the outer DimensionMap's range. + input_d = outer_map.input_dimension + input_lo = outer.domain.inclusive_min[input_d] + input_hi = outer.domain.exclusive_max[input_d] + user_indices = np.arange(input_lo, input_hi, dtype=np.intp) + intermediate_vals = outer_map.offset + outer_map.stride * user_indices + new_arr = arr_i[intermediate_vals] + return ArrayMap( + index_array=new_arr, + input_dimensions=(input_d,), + offset=offset_i, + stride=stride_i, + ) + + if isinstance(outer_map, ArrayMap): + # Evaluate arr_i at outer's array values; new array inherits outer's + # parameterization. + intermediate_vals = outer_map.offset + outer_map.stride * outer_map.index_array + new_arr = arr_i[intermediate_vals] + return ArrayMap( + index_array=new_arr, + input_dimensions=outer_map.input_dimensions, + offset=offset_i, + stride=stride_i, + ) + + # General multi-dim case: not yet implemented. + raise NotImplementedError( + "Composing a multi-dimensional inner array map with non-constant outer maps " + "is not yet supported." + ) diff --git a/src/zarr/core/_transforms/domain.py b/src/zarr/core/_transforms/domain.py new file mode 100644 index 0000000000..18a21ae4e4 --- /dev/null +++ b/src/zarr/core/_transforms/domain.py @@ -0,0 +1,182 @@ +"""Index domains — rectangular regions in N-dimensional integer space. + +An `IndexDomain` represents the set of valid coordinates for an array or +array view. It is the cartesian product of per-dimension integer ranges: + +```python +from zarr.core._transforms.domain import IndexDomain + +IndexDomain(inclusive_min=(2, 5), exclusive_max=(10, 20)) +# represents {(i, j) : 2 <= i < 10, 5 <= j < 20} +``` + +Unlike NumPy, domains can have **non-zero origins**. After slicing +`arr[5:10]`, the result has origin 5 and shape 5 — coordinates 5 through +9 are valid. This follows the TensorStore convention. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class IndexDomain: + """A rectangular region in N-dimensional index space. + + The valid coordinates are the integers in + `[inclusive_min[d], exclusive_max[d])` for each dimension `d`. + """ + + inclusive_min: tuple[int, ...] + exclusive_max: tuple[int, ...] + labels: tuple[str, ...] | None = None + + def __post_init__(self) -> None: + if len(self.inclusive_min) != len(self.exclusive_max): + raise ValueError( + f"inclusive_min and exclusive_max must have the same length. " + f"Got {len(self.inclusive_min)} and {len(self.exclusive_max)}." + ) + for i, (lo, hi) in enumerate(zip(self.inclusive_min, self.exclusive_max, strict=True)): + if lo > hi: + raise ValueError( + f"inclusive_min must be <= exclusive_max for all dimensions. " + f"Dimension {i}: {lo} > {hi}" + ) + if self.labels is not None and len(self.labels) != len(self.inclusive_min): + raise ValueError( + f"labels must have the same length as dimensions. " + f"Got {len(self.labels)} labels for {len(self.inclusive_min)} dimensions." 
+ ) + + @classmethod + def from_shape(cls, shape: tuple[int, ...]) -> IndexDomain: + """Create a domain with origin at zero.""" + return cls( + inclusive_min=(0,) * len(shape), + exclusive_max=shape, + ) + + @property + def ndim(self) -> int: + return len(self.inclusive_min) + + @property + def origin(self) -> tuple[int, ...]: + return self.inclusive_min + + @property + def shape(self) -> tuple[int, ...]: + return tuple(hi - lo for lo, hi in zip(self.inclusive_min, self.exclusive_max, strict=True)) + + def contains(self, index: tuple[int, ...]) -> bool: + if len(index) != self.ndim: + return False + return all( + lo <= idx < hi + for lo, hi, idx in zip(self.inclusive_min, self.exclusive_max, index, strict=True) + ) + + def contains_domain(self, other: IndexDomain) -> bool: + if other.ndim != self.ndim: + return False + return all( + self_lo <= other_lo and other_hi <= self_hi + for self_lo, self_hi, other_lo, other_hi in zip( + self.inclusive_min, + self.exclusive_max, + other.inclusive_min, + other.exclusive_max, + strict=True, + ) + ) + + def intersect(self, other: IndexDomain) -> IndexDomain | None: + if other.ndim != self.ndim: + raise ValueError( + f"Cannot intersect domains with different ranks: {self.ndim} vs {other.ndim}" + ) + new_min = tuple( + max(a, b) for a, b in zip(self.inclusive_min, other.inclusive_min, strict=True) + ) + new_max = tuple( + min(a, b) for a, b in zip(self.exclusive_max, other.exclusive_max, strict=True) + ) + if any(lo >= hi for lo, hi in zip(new_min, new_max, strict=True)): + return None + return IndexDomain(inclusive_min=new_min, exclusive_max=new_max) + + def translate(self, offset: tuple[int, ...]) -> IndexDomain: + if len(offset) != self.ndim: + raise ValueError( + f"Offset must have same length as domain dimensions. " + f"Domain has {self.ndim} dimensions, offset has {len(offset)}." 
+ ) + new_min = tuple(lo + off for lo, off in zip(self.inclusive_min, offset, strict=True)) + new_max = tuple(hi + off for hi, off in zip(self.exclusive_max, offset, strict=True)) + return IndexDomain(inclusive_min=new_min, exclusive_max=new_max) + + def narrow(self, selection: Any) -> IndexDomain: + """Apply a basic selection and return a narrowed domain. + Indices are absolute coordinates. Integer indices produce length-1 extent. + Strided slices are not supported — use IndexTransform for strides. + """ + normalized = _normalize_selection(selection, self.ndim) + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + for dim_idx, (sel, dim_lo, dim_hi) in enumerate( + zip(normalized, self.inclusive_min, self.exclusive_max, strict=True) + ): + if isinstance(sel, int): + if sel < dim_lo or sel >= dim_hi: + raise IndexError( + f"index {sel} is out of bounds for dimension {dim_idx} " + f"with domain [{dim_lo}, {dim_hi})" + ) + new_inclusive_min.append(sel) + new_exclusive_max.append(sel + 1) + else: + start, stop, step = sel.start, sel.stop, sel.step + if step is not None and step != 1: + raise IndexError( + "IndexDomain.narrow only supports step=1 slices. " + f"Got step={step}. Use IndexTransform for strided access." 
+ ) + abs_start = dim_lo if start is None else start + abs_stop = dim_hi if stop is None else stop + abs_start = max(abs_start, dim_lo) + abs_stop = min(abs_stop, dim_hi) + abs_stop = max(abs_stop, abs_start) + new_inclusive_min.append(abs_start) + new_exclusive_max.append(abs_stop) + return IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + +def _normalize_selection(selection: Any, ndim: int) -> tuple[int | slice, ...]: + """Normalize a basic selection to a tuple of ints/slices with length ndim.""" + if not isinstance(selection, tuple): + selection = (selection,) + result: list[int | slice] = [] + ellipsis_seen = False + for sel in selection: + if sel is Ellipsis: + if ellipsis_seen: + raise IndexError("an index can only have a single ellipsis ('...')") + ellipsis_seen = True + num_missing = ndim - (len(selection) - 1) + result.extend([slice(None)] * num_missing) + else: + result.append(sel) + while len(result) < ndim: + result.append(slice(None)) + if len(result) > ndim: + raise IndexError( + f"too many indices for array: array has {ndim} dimensions, " + f"but {len(result)} were indexed" + ) + return tuple(result) diff --git a/src/zarr/core/_transforms/output_map.py b/src/zarr/core/_transforms/output_map.py new file mode 100644 index 0000000000..c4b2dca6eb --- /dev/null +++ b/src/zarr/core/_transforms/output_map.py @@ -0,0 +1,104 @@ +"""Output index maps — three representations of a set of integer coordinates. + +An output index map describes, for one dimension of storage, which coordinates +an array access will touch. Conceptually it is a **set of integers** (1-D) +or a structured set of integers parameterized by some input dims. 
Three +representations cover the cases that arise in practice: + +- `ConstantMap(offset=5)` — a singleton set: `{5}` +- `DimensionMap(input_dimension=0, offset=3, stride=2)` over input `[0, 5)` + — an arithmetic progression: `{3, 5, 7, 9, 11}` +- `ArrayMap(index_array=[1, 5, 9], input_dimensions=(0,))` — an explicit + enumeration parameterized by input dim 0: `{1, 5, 9}` indexed by `i ∈ [0, 3)`. + +Every output map supports two set-theoretic operations (defined on +`IndexTransform`, which provides the input domain context these maps lack): + +- **intersect** — restrict to coordinates within a range (e.g., a chunk). + `{3, 5, 7, 9, 11} ∩ [4, 8) = {5, 7}` +- **translate** — shift every coordinate by a constant (e.g., make chunk-local). + `{5, 7} - 4 = {1, 3}` + +These two operations are the foundation of chunk resolution: for each chunk, +intersect the map with the chunk's range, then translate to chunk-local +coordinates. + +The three types exist because they trade off generality for efficiency: + +- `ConstantMap`: O(1) storage, O(1) intersection +- `DimensionMap`: O(1) storage, O(1) intersection (analytical) +- `ArrayMap`: O(n) storage, O(n) intersection (must scan the array) + +Collapsing everything to `ArrayMap` would be correct but wasteful — a +billion-element slice would materialize a billion coordinates just to group +them by chunk, when `DimensionMap` does it with three integers. + +Correlation between `ArrayMap`s is encoded by `input_dimensions`. Two +`ArrayMap`s in the same transform that share an input dim are correlated: +their values at the same input coordinate belong to the same storage point +(this is how vectorized indexing is represented). Two `ArrayMap`s with +disjoint `input_dimensions` are independent (orthogonal-style). The +type-level distinction prevents the older convention of inferring +correlation from array length and rank. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + + +@dataclass(frozen=True, slots=True) +class ConstantMap: + """A singleton set: one storage coordinate. + + Represents `{offset}`. Arises from integer indexing (e.g., `arr[5]` + fixes one dimension to coordinate 5). + """ + + offset: int = 0 + + +@dataclass(frozen=True, slots=True) +class DimensionMap: + """An arithmetic progression of storage coordinates. + + Represents `{offset + stride * i : i in input_range}`, where the input + range comes from the enclosing `IndexTransform`'s domain. Arises from + slice indexing (e.g., `arr[2:10:3]` gives offset=2, stride=3). + """ + + input_dimension: int + offset: int = 0 + stride: int = 1 + + +@dataclass(frozen=True, slots=True) +class ArrayMap: + """An explicit enumeration of storage coordinates parameterized by input dims. + + Represents `{offset + stride * index_array[i_d0, i_d1, ...]}` where + `(i_d0, i_d1, ...)` ranges over the input coordinates on the dimensions + listed in `input_dimensions`. + + Shape contract (enforced by the enclosing `IndexTransform.__post_init__`): + `index_array.shape` equals the input domain's extent on the dimensions + in `input_dimensions`, in order. For example, if `input_dimensions=(0, 2)` + and the enclosing transform's domain is `(5, 3, 4)`, then + `index_array.shape == (5, 4)`. + + Arises from fancy indexing (e.g., `arr.oindex[[1, 5, 9]]`, boolean masks + via vindex, etc.). + """ + + index_array: npt.NDArray[np.intp] + input_dimensions: tuple[int, ...] 
+ offset: int = 0 + stride: int = 1 + + +OutputIndexMap = ConstantMap | DimensionMap | ArrayMap diff --git a/src/zarr/core/_transforms/transform.py b/src/zarr/core/_transforms/transform.py new file mode 100644 index 0000000000..b67a5400e6 --- /dev/null +++ b/src/zarr/core/_transforms/transform.py @@ -0,0 +1,989 @@ +"""Index transforms — composable, lazy coordinate mappings. + +An `IndexTransform` pairs an **input domain** (the coordinates a user sees) +with a tuple of **output maps** (the storage coordinates those inputs map to). +One output map per storage dimension. See `output_map.py` for the three +output map types. + +Key operations: + +- **Indexing** (`transform[2:8]`, `.oindex[idx]`, `.vindex[idx]`) — + produces a new transform with a narrower input domain and adjusted output + maps. No I/O occurs. This is how lazy slicing works. + +- **intersect(output_domain)** — restrict to storage coordinates within a + region. This is chunk resolution: "which of my coordinates fall in this + chunk?" + +- **translate(shift)** — shift all output coordinates. This makes coordinates + chunk-local: "express my coordinates relative to the chunk origin." + +- **compose(outer, inner)** — chain two transforms. See `composition.py`. + +The transform is the atomic unit that connects user-facing indexing to +chunk-level I/O. Indexing into a lazy view composes a new transform; reading +resolves the transform against the chunk grid via intersect + translate. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, Literal + +import numpy as np + +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap, OutputIndexMap + + +@dataclass(frozen=True, slots=True) +class IndexTransform: + """A composable mapping from input coordinates to storage coordinates. 
+ + An `IndexTransform` has: + + - `domain`: an `IndexDomain` describing the valid input coordinates + (the user-facing shape, possibly with non-zero origin). + - `output`: a tuple of output maps (one per storage dimension), each + describing which storage coordinates the inputs touch. + + For a freshly opened array, the transform is the identity: input + coordinate `i` maps to storage coordinate `i`. Indexing operations + compose new transforms without I/O. + """ + + domain: IndexDomain + output: tuple[OutputIndexMap, ...] + + def __post_init__(self) -> None: + for i, m in enumerate(self.output): + if isinstance(m, DimensionMap): + if m.input_dimension < 0 or m.input_dimension >= self.domain.ndim: + raise ValueError( + f"output[{i}].input_dimension = {m.input_dimension} " + f"is out of range for input rank {self.domain.ndim}" + ) + if m.stride <= 0: + raise ValueError( + f"output[{i}].stride = {m.stride} must be positive. " + "Negative-stride DimensionMaps are not supported; the " + "Array layer normalizes negative strides upstream." + ) + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union + # Every input dim referenced must be in range. + for d in m.input_dimensions: + if d < 0 or d >= self.domain.ndim: + raise ValueError( + f"output[{i}].input_dimensions = {m.input_dimensions} " + f"references dimension {d}, out of range for input " + f"rank {self.domain.ndim}" + ) + # No duplicates allowed. + if len(set(m.input_dimensions)) != len(m.input_dimensions): + raise ValueError( + f"output[{i}].input_dimensions = {m.input_dimensions} " + f"contains duplicate dimensions" + ) + # index_array.shape must match the extents on input_dimensions. 
+ expected_shape = tuple(self.domain.shape[d] for d in m.input_dimensions) + if m.index_array.shape != expected_shape: + raise ValueError( + f"output[{i}].index_array.shape = {m.index_array.shape} " + f"does not match expected shape {expected_shape} for " + f"input_dimensions={m.input_dimensions}" + ) + + @property + def input_rank(self) -> int: + return self.domain.ndim + + @property + def output_rank(self) -> int: + return len(self.output) + + @classmethod + def identity(cls, domain: IndexDomain) -> IndexTransform: + output = tuple(DimensionMap(input_dimension=i) for i in range(domain.ndim)) + return cls(domain=domain, output=output) + + @classmethod + def from_shape(cls, shape: tuple[int, ...]) -> IndexTransform: + return cls.identity(IndexDomain.from_shape(shape)) + + @property + def selection_repr(self) -> str: + """Compact domain string, e.g. `'{ [2, 8), [0, 10) }'`. + + Follows TensorStore's IndexDomain notation: each dimension shown + as `[inclusive_min, exclusive_max)` with stride annotation if not 1. + Constant (integer-indexed) dimensions show as a single value. + Array-indexed dimensions show the set of selected coordinates. 
+ """ + parts: list[str] = [] + for m in self.output: + if isinstance(m, ConstantMap): + parts.append(str(m.offset)) + elif isinstance(m, DimensionMap): + d = m.input_dimension + lo = self.domain.inclusive_min[d] + hi = self.domain.exclusive_max[d] + start = m.offset + m.stride * lo + stop = m.offset + m.stride * hi + if m.stride == 1: + parts.append(f"[{start}, {stop})") + else: + parts.append(f"[{start}, {stop}) step {m.stride}") + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union # pragma: no branch - exhaustive over OutputIndexMap union + storage = m.offset + m.stride * m.index_array + n = len(storage) + if n <= 5: + vals = ", ".join(str(int(v)) for v in storage.ravel()) + parts.append("{" + vals + "}") + else: + parts.append("{" + f"array({n})" + "}") + return "{ " + ", ".join(parts) + " }" + + def __repr__(self) -> str: + maps: list[str] = [] + for i, m in enumerate(self.output): + if isinstance(m, ConstantMap): + maps.append(f"out[{i}] = {m.offset}") + elif isinstance(m, DimensionMap): + maps.append(f"out[{i}] = {m.offset} + {m.stride} * in[{m.input_dimension}]") + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union # pragma: no branch - exhaustive over OutputIndexMap union + in_dims = ",".join(f"in[{d}]" for d in m.input_dimensions) + maps.append( + f"out[{i}] = {m.offset} + {m.stride} * arr{m.index_array.shape}[{in_dims}]" + ) + maps_str = ", ".join(maps) + return f"IndexTransform(domain={self.domain}, {maps_str})" + + def intersect( + self, output_domain: IndexDomain + ) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Restrict this transform to storage coordinates within output_domain. + + Returns `(restricted_transform, surviving_indices)` or None if empty. 
+ + `surviving_indices` is an integer array of which input positions + survived the intersection (for ArrayMap dimensions), or None if all + positions survived (ConstantMap/DimensionMap only). + """ + return _intersect(self, output_domain) + + def translate(self, shift: tuple[int, ...]) -> IndexTransform: + """Shift all output coordinates by `shift`.""" + if len(shift) != self.output_rank: + raise ValueError(f"shift must have length {self.output_rank}, got {len(shift)}") + new_output: list[OutputIndexMap] = [] + for m, s in zip(self.output, shift, strict=True): + if isinstance(m, ConstantMap): + new_output.append(ConstantMap(offset=m.offset + s)) + elif isinstance(m, DimensionMap): + new_output.append( + DimensionMap( + input_dimension=m.input_dimension, + offset=m.offset + s, + stride=m.stride, + ) + ) + elif isinstance( + m, ArrayMap + ): # pragma: no branch - exhaustive over OutputIndexMap union + new_output.append( + ArrayMap( + index_array=m.index_array, + input_dimensions=m.input_dimensions, + offset=m.offset + s, + stride=m.stride, + ) + ) + return IndexTransform(domain=self.domain, output=tuple(new_output)) + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_basic_indexing(self, selection) + + @property + def oindex(self) -> _OIndexHelper: + return _OIndexHelper(self) + + @property + def vindex(self) -> _VIndexHelper: + return _VIndexHelper(self) + + +def _intersect( + transform: IndexTransform, output_domain: IndexDomain +) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Intersect a transform with an output domain (e.g., a chunk's bounds). + + For each output dimension, restrict to storage coordinates within + [output_domain.inclusive_min[d], output_domain.exclusive_max[d]). + + For orthogonal transforms (ConstantMap, DimensionMap, independent ArrayMaps), + each dimension is intersected independently and the input domain is narrowed. 
+ + For vectorized transforms (correlated ArrayMaps), all array dimensions + must be checked simultaneously — a point survives only if ALL its + coordinates fall within the output domain. + + Returns None if the intersection is empty. + """ + if output_domain.ndim != transform.output_rank: + raise ValueError( + f"output_domain rank ({output_domain.ndim}) != " + f"transform output rank ({transform.output_rank})" + ) + + # Check if we have correlated ArrayMaps (vectorized). + # Two ArrayMaps are correlated iff they share at least one input dimension. + array_output_dims = [i for i, m in enumerate(transform.output) if isinstance(m, ArrayMap)] + if len(array_output_dims) >= 2: + seen_input_dims: set[int] = set() + is_vectorized = False + for out_d in array_output_dims: + m = transform.output[out_d] + assert isinstance(m, ArrayMap) + if seen_input_dims & set(m.input_dimensions): + is_vectorized = True + break + seen_input_dims.update(m.input_dimensions) + if is_vectorized: + return _intersect_vectorized(transform, output_domain, array_output_dims) + + # Orthogonal: intersect each output dimension independently + new_min = list(transform.domain.inclusive_min) + new_max = list(transform.domain.exclusive_max) + new_output: list[OutputIndexMap] = [] + surviving_indices: np.ndarray[Any, np.dtype[np.intp]] | None = None + + for out_dim, m in enumerate(transform.output): + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + + if isinstance(m, ConstantMap): + if lo <= m.offset < hi: + new_output.append(m) + else: + return None + + elif isinstance(m, DimensionMap): + d = m.input_dimension + input_lo = new_min[d] + input_hi = new_max[d] + if input_lo >= input_hi: + return None + + # Find the input range that produces storage coords in [lo, hi). + # DimensionMap.stride is always positive (enforced by __post_init__). 
+ new_input_lo = max(input_lo, math.ceil((lo - m.offset) / m.stride)) + new_input_hi = min(input_hi, math.ceil((hi - m.offset) / m.stride)) + + if new_input_lo >= new_input_hi: + return None + + new_min[d] = new_input_lo + new_max[d] = new_input_hi + new_output.append(m) + + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union + storage = m.offset + m.stride * m.index_array + mask = (storage >= lo) & (storage < hi) + if not np.any(mask): + return None + # Orthogonal ArrayMap: filter the array and shrink the input dim + # it parameterizes. The 1-D case is the only one produced by the + # public oindex / vindex API; multi-dim orthogonal ArrayMaps would + # only arise via direct manual construction. Reject rather than + # silently produce an unsupported result. + if len(m.input_dimensions) != 1: # pragma: no cover - public API never produces this + raise NotImplementedError( + "intersect on a multi-dimensional orthogonal ArrayMap is not yet supported" + ) + (input_d,) = m.input_dimensions + surviving_indices = np.nonzero(mask.ravel())[0].astype(np.intp) + filtered = m.index_array.ravel()[surviving_indices] + new_min[input_d] = 0 + new_max[input_d] = len(filtered) + new_output.append( + ArrayMap( + index_array=filtered, + input_dimensions=(input_d,), + offset=m.offset, + stride=m.stride, + ) + ) + + new_domain = IndexDomain( + inclusive_min=tuple(new_min), + exclusive_max=tuple(new_max), + ) + result = IndexTransform(domain=new_domain, output=tuple(new_output)) + return (result, surviving_indices) + + +def _intersect_vectorized( + transform: IndexTransform, + output_domain: IndexDomain, + array_output_dims: list[int], +) -> tuple[IndexTransform, np.ndarray[Any, np.dtype[np.intp]] | None] | None: + """Intersect a vectorized transform with an output domain. + + All ArrayMap outputs in `array_output_dims` are correlated — a point + survives only if ALL its storage coordinates fall within the output + domain. 
The correlated ArrayMaps share `input_dimensions`; after + filtering, those shared input dims collapse into a single 1-D domain + `(len(surviving),)`. Any non-correlated input dims (used by + DimensionMap outputs on independent input dims) are preserved. + """ + # Compute storage coords per array dim and check bounds simultaneously. + masks: list[np.ndarray[Any, np.dtype[np.bool_]]] = [] + correlated_input_dims: set[int] = set() + + for out_dim in array_output_dims: + m = transform.output[out_dim] + assert isinstance(m, ArrayMap) + storage = m.offset + m.stride * m.index_array + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + masks.append((storage >= lo) & (storage < hi)) + correlated_input_dims.update(m.input_dimensions) + + # A point survives only if it's in-bounds on ALL array dims. + combined_mask = masks[0] + for mask in masks[1:]: + combined_mask = combined_mask & mask + + if not np.any(combined_mask): + return None + + surviving = np.nonzero(combined_mask.ravel())[0].astype(np.intp) + + # Build the new domain. The correlated input dims collapse into a single + # 1-D dim (the surviving-points dim, placed at index 0). Any input dims + # NOT consumed by the correlated ArrayMaps are preserved in their order. + preserved_input_dims = [ + d for d in range(transform.domain.ndim) if d not in correlated_input_dims + ] + new_inclusive_min = [0] + [transform.domain.inclusive_min[d] for d in preserved_input_dims] + new_exclusive_max = [len(surviving)] + [ + transform.domain.exclusive_max[d] for d in preserved_input_dims + ] + # old input dim -> new input dim + old_to_new: dict[int, int] = {d: i + 1 for i, d in enumerate(preserved_input_dims)} + + # Build new output maps. 
+ new_output: list[OutputIndexMap] = [] + for out_dim, m in enumerate(transform.output): + if isinstance(m, ArrayMap): + filtered = m.index_array.ravel()[surviving] + new_output.append( + ArrayMap( + index_array=filtered, + input_dimensions=(0,), + offset=m.offset, + stride=m.stride, + ) + ) + elif isinstance(m, ConstantMap): + lo = output_domain.inclusive_min[out_dim] + hi = output_domain.exclusive_max[out_dim] + if lo <= m.offset < hi: + new_output.append(m) + else: + return None + elif isinstance(m, DimensionMap): + # _apply_vindex never places a DimensionMap on a correlated (broadcast) + # input dim: the broadcast dims always become ArrayMap parameters and + # the slice dims become DimensionMaps on dims past the broadcast block. + # This guard is reachable only via direct manual transform construction. + if ( + m.input_dimension in correlated_input_dims + ): # pragma: no cover - public API never produces this + raise NotImplementedError( + "vectorized intersect with a DimensionMap on a correlated " + "input dim is not supported" + ) + new_output.append( + DimensionMap( + input_dimension=old_to_new[m.input_dimension], + offset=m.offset, + stride=m.stride, + ) + ) + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + result = IndexTransform(domain=new_domain, output=tuple(new_output)) + return (result, surviving) + + +def _normalize_basic_selection(selection: Any, ndim: int) -> tuple[int | slice | None, ...]: + """Normalize a selection to a tuple of int, slice, or None (newaxis), + expanding ellipsis and padding with slice(None) as needed. 
+ """ + if not isinstance(selection, tuple): + selection = (selection,) + + # Count non-newaxis, non-ellipsis entries to determine how many real dims are addressed + n_newaxis = sum(1 for s in selection if s is None) + has_ellipsis = any(s is Ellipsis for s in selection) + n_real = len(selection) - n_newaxis - (1 if has_ellipsis else 0) + + if n_real > ndim: + raise IndexError( + f"too many indices for array: array has {ndim} dimensions, but {n_real} were indexed" + ) + + result: list[int | slice | None] = [] + ellipsis_seen = False + for sel in selection: + if sel is Ellipsis: + if ellipsis_seen: + raise IndexError("an index can only have a single ellipsis ('...')") + ellipsis_seen = True + num_missing = ndim - n_real + result.extend([slice(None)] * num_missing) + elif isinstance(sel, (int, np.integer)): + result.append(int(sel)) + elif isinstance(sel, slice) or sel is None: + result.append(sel) + else: + raise IndexError(f"unsupported selection type for basic indexing: {type(sel)!r}") + + # Pad remaining dimensions with slice(None) + while sum(1 for s in result if s is not None) < ndim: + result.append(slice(None)) + + return tuple(result) + + +def _apply_basic_indexing(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply basic indexing (int, slice, ellipsis, newaxis) to an IndexTransform.""" + normalized = _normalize_basic_selection(selection, transform.domain.ndim) + + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + old_dim = 0 + new_dim_idx = 0 + old_to_new_dim: dict[int, int] = {} + dropped_dims: set[int] = set() + + # Per old-dim: the slice parameters (for computing new output maps) + dim_slice_params: dict[int, tuple[int, int, int]] = {} # old_dim -> (start, stop, step) + dim_int_val: dict[int, int] = {} # old_dim -> integer index value + + for sel in normalized: + if sel is None: + # newaxis: add a size-1 dimension + new_inclusive_min.append(0) + new_exclusive_max.append(1) + new_dim_idx += 1 + elif 
isinstance(sel, int): + # Integer index: drop this input dimension. + # Negative indices are literal coordinates (TensorStore convention), + # NOT "from the end" like NumPy. The Array layer handles conversion. + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + idx = sel + if idx < lo or idx >= hi: + raise IndexError( + f"index {sel} is out of bounds for dimension {old_dim} with domain [{lo}, {hi})" + ) + dropped_dims.add(old_dim) + dim_int_val[old_dim] = idx + old_dim += 1 + elif isinstance( + sel, slice + ): # pragma: no branch - exhaustive over normalized's element type + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + + # Resolve slice relative to the current domain (origin-based) + start, stop, step = sel.indices(dim_size) + # start, stop, step are now relative to a 0-based range of size dim_size + + if step <= 0: + raise IndexError("slice step must be positive") + + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + + # Absolute start in the original domain coordinates + abs_start = lo + start + dim_slice_params[old_dim] = (abs_start, stop, step) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + old_dim += 1 + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + # Now update output maps + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in dropped_dims: + # Integer index: this output becomes constant + new_offset = m.offset + m.stride * dim_int_val[d] + new_output.append(ConstantMap(offset=new_offset)) + elif d in old_to_new_dim: + # Slice: update offset and stride + abs_start, _, step = dim_slice_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride 
= m.stride * step + new_input_dim = old_to_new_dim[d] + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + else: + raise RuntimeError( # pragma: no cover - defensive; unreachable for validated transforms + f"unexpected: dimension {d} not handled" + ) + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union + # The array's axes are labeled by m.input_dimensions, in order. + # For each labeled axis: if the corresponding old input dim is + # dropped (int), select that one entry; if sliced, slice the axis; + # otherwise leave the axis intact. Newaxis insertions don't touch + # the array (they add new input dims not in input_dimensions). + arr_idx: list[Any] = [] + new_input_dims: list[int] = [] + for axis_dim in m.input_dimensions: + if axis_dim in dropped_dims: + array_idx = dim_int_val[axis_dim] - transform.domain.inclusive_min[axis_dim] + arr_idx.append(array_idx) + elif axis_dim in old_to_new_dim: + abs_start, _, step = dim_slice_params[axis_dim] + array_start = abs_start - transform.domain.inclusive_min[axis_dim] + new_size = ( + new_exclusive_max[old_to_new_dim[axis_dim]] + - new_inclusive_min[old_to_new_dim[axis_dim]] + ) + array_stop = array_start + step * new_size + arr_idx.append(slice(array_start, array_stop, step)) + new_input_dims.append(old_to_new_dim[axis_dim]) + else: # pragma: no cover - defensive; unreachable for validated transforms + raise RuntimeError( + f"unexpected: ArrayMap input_dim {axis_dim} not in " + "dropped_dims or old_to_new_dim" + ) + new_arr = m.index_array[tuple(arr_idx)] if arr_idx else m.index_array + new_arr = np.asarray(new_arr, dtype=np.intp) + new_output.append( + ArrayMap( + index_array=new_arr, + input_dimensions=tuple(new_input_dims), + offset=m.offset, + stride=m.stride, + ) + ) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +class _OIndexHelper: + """Helper that provides orthogonal (outer) indexing 
via `transform.oindex[...]`.""" + + def __init__(self, transform: IndexTransform) -> None: + self._transform = transform + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_oindex(self._transform, selection) + + +def _normalize_oindex_selection( + selection: Any, ndim: int +) -> tuple[np.ndarray[Any, np.dtype[np.intp]] | slice, ...]: + """Normalize an oindex selection: arrays, slices, booleans, integers.""" + if not isinstance(selection, tuple): + selection = (selection,) + + # Expand ellipsis + has_ellipsis = any(s is Ellipsis for s in selection) + n_ellipsis = 1 if has_ellipsis else 0 + n_real = len(selection) - n_ellipsis + + result: list[np.ndarray[Any, np.dtype[np.intp]] | slice] = [] + for sel in selection: + if sel is Ellipsis: + num_missing = ndim - n_real + result.extend([slice(None)] * num_missing) + elif isinstance(sel, np.ndarray) and sel.dtype == np.bool_: + # Boolean array -> integer indices + (indices,) = np.nonzero(sel) + result.append(indices.astype(np.intp)) + elif isinstance(sel, np.ndarray): + result.append(sel.astype(np.intp)) + elif isinstance(sel, slice): + result.append(sel) + elif isinstance(sel, (int, np.integer)): + # Convert integer scalars to 1-element arrays for orthogonal indexing + result.append(np.array([int(sel)], dtype=np.intp)) + elif isinstance(sel, (list, tuple)): + result.append(np.asarray(sel, dtype=np.intp)) + else: # pragma: no cover - upstream _validate_array_selection rejects other types + result.append(sel) + + # Pad with slice(None) + while len(result) < ndim: + result.append(slice(None)) + + return tuple(result) + + +def _apply_oindex(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply orthogonal indexing to an IndexTransform. + + Each index array is applied independently per dimension (outer product). 
+ """ + normalized = _normalize_oindex_selection(selection, transform.domain.ndim) + + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + new_dim_idx = 0 + old_to_new_dim: dict[int, int] = {} + + # Info per old dim + dim_array: dict[int, np.ndarray[Any, np.dtype[np.intp]]] = {} + dim_slice_params: dict[int, tuple[int, int, int]] = {} + + for old_dim, sel in enumerate(normalized): + if isinstance(sel, np.ndarray): + dim_array[old_dim] = sel + new_inclusive_min.append(0) + new_exclusive_max.append(len(sel)) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + elif isinstance( + sel, slice + ): # pragma: no branch - exhaustive over normalized's element type + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + start, stop, step = sel.indices(dim_size) + if step <= 0: + raise IndexError("slice step must be positive") + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + abs_start = lo + start + dim_slice_params[old_dim] = (abs_start, stop, step) + old_to_new_dim[old_dim] = new_dim_idx + new_dim_idx += 1 + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in dim_array: + new_output.append( + ArrayMap( + index_array=dim_array[d], + input_dimensions=(old_to_new_dim[d],), + offset=m.offset, + stride=m.stride, + ) + ) + elif d in dim_slice_params: + abs_start, _, step = dim_slice_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = old_to_new_dim[d] + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + else: + raise RuntimeError( # pragma: no 
cover - defensive; unreachable for validated transforms + f"unexpected: dimension {d} not handled" + ) + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union + # Each axis of m.index_array corresponds to one entry in + # m.input_dimensions. For each such old input dim, oindex either + # picks specific entries (dim_array[d]) or slices the axis + # (dim_slice_params[d]). + arr_idx: list[Any] = [] + n_array_axes = 0 + for axis_dim in m.input_dimensions: + if axis_dim in dim_array: + arr_idx.append(dim_array[axis_dim]) + n_array_axes += 1 + elif axis_dim in dim_slice_params: + abs_start, _, step = dim_slice_params[axis_dim] + array_start = abs_start - transform.domain.inclusive_min[axis_dim] + new_size = ( + new_exclusive_max[old_to_new_dim[axis_dim]] + - new_inclusive_min[old_to_new_dim[axis_dim]] + ) + array_stop = array_start + step * new_size + arr_idx.append(slice(array_start, array_stop, step)) + else: # pragma: no cover - defensive; unreachable for validated transforms + raise RuntimeError( + f"unexpected: ArrayMap input_dim {axis_dim} not in " + "dim_array or dim_slice_params" + ) + # Multi-dim ArrayMap with two or more axes selected by arrays needs + # `np.ix_`-style outer-product indexing to preserve oindex semantics + # (NumPy's `arr[a, b]` would broadcast a and b instead). Until that + # is implemented, refuse rather than silently produce wrong results. 
+ if n_array_axes >= 2: + raise NotImplementedError( + "oindex on a multi-dimensional ArrayMap with two or more " + "axes selected by integer/boolean arrays is not yet " + "supported" + ) + new_arr = m.index_array[tuple(arr_idx)] if arr_idx else m.index_array + new_arr = np.asarray(new_arr, dtype=np.intp) + new_input_dims = tuple(old_to_new_dim[d] for d in m.input_dimensions) + new_output.append( + ArrayMap( + index_array=new_arr, + input_dimensions=new_input_dims, + offset=m.offset, + stride=m.stride, + ) + ) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +class _VIndexHelper: + """Helper that provides vectorized (fancy) indexing via `transform.vindex[...]`.""" + + def __init__(self, transform: IndexTransform) -> None: + self._transform = transform + + def __getitem__(self, selection: Any) -> IndexTransform: + return _apply_vindex(self._transform, selection) + + +def _apply_vindex(transform: IndexTransform, selection: Any) -> IndexTransform: + """Apply vectorized indexing to an IndexTransform. + + All array indices are broadcast together. Broadcast dimensions are prepended, + followed by non-array (slice) dimensions. 
+ """ + if not isinstance(selection, tuple): + selection = (selection,) + + # Expand ellipsis and count consumed dimensions + # Boolean arrays with ndim > 1 consume ndim dims + n_consumed = 0 + for s in selection: + if s is Ellipsis: + continue + if isinstance(s, np.ndarray) and s.dtype == np.bool_ and s.ndim > 1: + n_consumed += s.ndim + else: + n_consumed += 1 + ndim = transform.domain.ndim + + expanded: list[Any] = [] + for sel in selection: + if sel is Ellipsis: + num_missing = ndim - n_consumed + expanded.extend([slice(None)] * num_missing) + else: + expanded.append(sel) + # Count dimensions already consumed by expanded entries + n_expanded_dims = 0 + for sel in expanded: + if isinstance(sel, np.ndarray) and sel.dtype == np.bool_ and sel.ndim > 1: + n_expanded_dims += sel.ndim + else: + n_expanded_dims += 1 + while n_expanded_dims < ndim: + expanded.append(slice(None)) + n_expanded_dims += 1 + + # Convert booleans, lists, ints to integer arrays + processed: list[np.ndarray[Any, np.dtype[np.intp]] | slice] = [] + for sel in expanded: + if isinstance(sel, np.ndarray) and sel.dtype == np.bool_: + indices_tuple = np.nonzero(sel) + processed.extend(indices.astype(np.intp) for indices in indices_tuple) + elif isinstance(sel, np.ndarray): + processed.append(sel.astype(np.intp)) + elif isinstance(sel, (list, tuple)): + processed.append(np.asarray(sel, dtype=np.intp)) + elif isinstance(sel, (int, np.integer)): + processed.append(np.array([int(sel)], dtype=np.intp)) + else: # pragma: no cover - upstream _validate_array_selection rejects other types + processed.append(sel) + + # Separate array dims and slice dims + array_dims: list[int] = [] + slice_dims: list[int] = [] + arrays: list[np.ndarray[Any, np.dtype[np.intp]]] = [] + + for i, sel in enumerate(processed): + if isinstance(sel, np.ndarray): + array_dims.append(i) + arrays.append(sel) + else: + slice_dims.append(i) + + # Broadcast all arrays together + broadcast_arrays: list[np.ndarray[Any, np.dtype[np.intp]]] + if 
arrays: + broadcast_arrays = list(np.broadcast_arrays(*arrays)) + broadcast_shape = broadcast_arrays[0].shape + else: + broadcast_arrays = [] + broadcast_shape = () + + # Build new domain: broadcast dims first, then slice dims + new_inclusive_min: list[int] = [] + new_exclusive_max: list[int] = [] + + # Broadcast dimensions + for s in broadcast_shape: + new_inclusive_min.append(0) + new_exclusive_max.append(s) + + # Slice dimensions + slice_dim_params: dict[int, tuple[int, int, int]] = {} + for old_dim in slice_dims: + sel = processed[old_dim] + assert isinstance(sel, slice) + lo = transform.domain.inclusive_min[old_dim] + hi = transform.domain.exclusive_max[old_dim] + dim_size = hi - lo + start, stop, step = sel.indices(dim_size) + if step <= 0: + raise IndexError("slice step must be positive") + new_size = max(0, math.ceil((stop - start) / step)) + new_inclusive_min.append(0) + new_exclusive_max.append(new_size) + abs_start = lo + start + slice_dim_params[old_dim] = (abs_start, stop, step) + + new_domain = IndexDomain( + inclusive_min=tuple(new_inclusive_min), + exclusive_max=tuple(new_exclusive_max), + ) + + # Build output maps + array_dim_to_broadcast: dict[int, np.ndarray[Any, np.dtype[np.intp]]] = {} + for i, d in enumerate(array_dims): + array_dim_to_broadcast[d] = broadcast_arrays[i] + + # New dim index for slice dims starts after broadcast dims + n_broadcast_dims = len(broadcast_shape) + + # Broadcast dims are placed at input_dim positions [0, n_broadcast_dims). 
+ broadcast_input_dims = tuple(range(n_broadcast_dims)) + + new_output: list[OutputIndexMap] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + new_output.append(m) + elif isinstance(m, DimensionMap): + d = m.input_dimension + if d in array_dim_to_broadcast: + new_output.append( + ArrayMap( + index_array=array_dim_to_broadcast[d], + input_dimensions=broadcast_input_dims, + offset=m.offset, + stride=m.stride, + ) + ) + else: + # Slice dim + abs_start, _, step = slice_dim_params[d] + new_offset = m.offset + m.stride * abs_start + new_stride = m.stride * step + new_input_dim = n_broadcast_dims + slice_dims.index(d) + new_output.append( + DimensionMap( + input_dimension=new_input_dim, offset=new_offset, stride=new_stride + ) + ) + elif isinstance(m, ArrayMap): # pragma: no branch - exhaustive over OutputIndexMap union + # vindex on a transform that already has an ArrayMap output is not + # currently exercised. The semantics are subtle (broadcasting can + # reshape the array's parameterization) and require careful design; + # raise rather than produce wrong results. + raise NotImplementedError( + "vindex on a transform whose output is already an ArrayMap is not yet supported" + ) + + return IndexTransform(domain=new_domain, output=tuple(new_output)) + + +def _validate_array_selection(selection: Any, shape: tuple[int, ...], mode: str) -> None: + """Validate array-based selections (orthogonal, vectorized). + + Rejects types that are not valid for coordinate/vectorized indexing. + Does not check bounds — the transform operations handle that. 
+ """ + items = selection if isinstance(selection, tuple) else (selection,) + for sel in items: + if sel is Ellipsis or isinstance(sel, (int, np.integer, slice)): + continue + if isinstance(sel, (list, np.ndarray)): + continue + raise IndexError(f"unsupported selection type for {mode} indexing: {type(sel)!r}") + + +def _validate_basic_selection(selection: Any) -> None: + """Validate that a selection only contains basic indexing types (int, slice, Ellipsis). + + Rejects None (newaxis), arrays, lists, floats, strings, etc. + """ + items = selection if isinstance(selection, tuple) else (selection,) + for s in items: + if s is Ellipsis or isinstance(s, (int, np.integer, slice)): + continue + raise IndexError(f"unsupported selection type for basic indexing: {type(s)!r}") + + +def selection_to_transform( + selection: Any, + transform: IndexTransform, + mode: Literal["basic", "orthogonal", "vectorized"], +) -> IndexTransform: + """Convert a user selection into a composed IndexTransform. + + Negative indices are treated as literal coordinates (TensorStore convention). + The caller (Array layer) is responsible for converting numpy-style negative + indices before calling this function. 
+ """ + if mode == "basic": + _validate_basic_selection(selection) + return transform[selection] + elif mode == "orthogonal": + _validate_array_selection(selection, transform.domain.shape, mode) + return transform.oindex[selection] + elif mode == "vectorized": + _validate_array_selection(selection, transform.domain.shape, mode) + return transform.vindex[selection] + else: + raise ValueError(f"Unknown mode: {mode!r}") diff --git a/tests/test_transforms/__init__.py b/tests/test_transforms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_transforms/conftest.py b/tests/test_transforms/conftest.py new file mode 100644 index 0000000000..12c9896c1c --- /dev/null +++ b/tests/test_transforms/conftest.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Expect[TIn, TOut]: + """Model an input and an expected output value for a test case.""" + + input: TIn + expected: TOut + id: str + + +@dataclass(frozen=True) +class ExpectErr[TIn]: + """Model an input and an expected error for a test case.""" + + input: TIn + msg: str + exception_cls: type[Exception] + id: str diff --git a/tests/test_transforms/test_chunk_resolution.py b/tests/test_transforms/test_chunk_resolution.py new file mode 100644 index 0000000000..f6af94daa4 --- /dev/null +++ b/tests/test_transforms/test_chunk_resolution.py @@ -0,0 +1,380 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from tests.test_transforms.conftest import Expect +from zarr.core._transforms.chunk_resolution import ( + iter_chunk_transforms, + sub_transform_to_selections, +) +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import IndexTransform +from zarr.core.chunk_grids import ChunkGrid, FixedDimension + +# 
--------------------------------------------------------------------------- +# iter_chunk_transforms — for a transform composed against a ChunkGrid, yield +# (chunk_coords, sub_transform, out_indices) for each touched chunk. +# --------------------------------------------------------------------------- + + +def _grid_1d(size: int, extent: int) -> ChunkGrid: + return ChunkGrid(dimensions=(FixedDimension(size=size, extent=extent),)) + + +def _grid_2d(size0: int, extent0: int, size1: int, extent1: int) -> ChunkGrid: + return ChunkGrid( + dimensions=( + FixedDimension(size=size0, extent=extent0), + FixedDimension(size=size1, extent=extent1), + ) + ) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((10,)), _grid_1d(10, 10)), + expected={"n_chunks": 1, "coords": [(0,)]}, + id="single-chunk-fits-array", + ), + Expect( + input=(IndexTransform.from_shape((30,)), _grid_1d(10, 30)), + expected={"n_chunks": 3, "coords": [(0,), (1,), (2,)]}, + id="three-chunks-1d", + ), + Expect( + input=(IndexTransform.from_shape((20, 30)), _grid_2d(10, 20, 10, 30)), + expected={ + "n_chunks": 6, + "coords": [(i, j) for i in (0, 1) for j in (0, 1, 2)], + }, + id="six-chunks-2x3", + ), + Expect( + input=(IndexTransform.from_shape((100,))[5:8], _grid_1d(10, 100)), + expected={"n_chunks": 1, "coords": [(0,)]}, + id="slice-within-chunk", + ), + Expect( + input=(IndexTransform.from_shape((100,))[8:15], _grid_1d(10, 100)), + expected={"n_chunks": 2, "coords": [(0,), (1,)]}, + id="slice-across-two-chunks", + ), + ], + ids=lambda c: c.id, +) +def test_iter_chunk_transforms_yields_expected_chunks( + case: Expect[tuple[IndexTransform, ChunkGrid], dict[str, Any]], +) -> None: + """iter_chunk_transforms enumerates all chunks intersected by the transform.""" + transform, grid = case.input + results = list(iter_chunk_transforms(transform, grid)) + assert len(results) == case.expected["n_chunks"] + coords_list = [r[0] for r in results] + for expected_coord in 
case.expected["coords"]: + assert expected_coord in coords_list + + +def test_iter_chunk_transforms_constant_map_picks_single_chunk_per_dim() -> None: + """An integer index produces a ConstantMap, fixing the chunk on that dim. + + arr[25, :] over a 10-element chunk grid: chunk index for storage 25 is 2, + so every chunk yielded has coords[0] == 2. The free dim (the slice) iterates.""" + t = IndexTransform.from_shape((100, 100))[25, :] + grid = _grid_2d(10, 100, 10, 100) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 10 + for coords, _, _ in results: + assert coords[0] == 2 + + +def test_iter_chunk_transforms_array_map_lists_chunks_for_array_entries() -> None: + """An ArrayMap yields chunks for each unique chunk-id of its index_array entries.""" + idx = np.array([5, 15, 25], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=idx, input_dimensions=(0,)),), + ) + results = list(iter_chunk_transforms(t, _grid_1d(10, 30))) + coords_list = [r[0] for r in results] + assert (0,) in coords_list + assert (1,) in coords_list + assert (2,) in coords_list + + +def test_iter_chunk_transforms_within_chunk_offset_is_local() -> None: + """The yielded sub-transform's output is in chunk-local coordinates, + so a slice arr[5:8] in chunk 0 yields offset=5 (the offset within the chunk).""" + t = IndexTransform.from_shape((100,))[5:8] + grid = _grid_1d(10, 100) + results = list(iter_chunk_transforms(t, grid)) + assert len(results) == 1 + _, sub_t, _ = results[0] + assert isinstance(sub_t.output[0], DimensionMap) + assert sub_t.output[0].offset == 5 + + +# --------------------------------------------------------------------------- +# sub_transform_to_selections — convert a chunk-local sub-transform into +# (chunk_selection, out_selection, drop_axes) tuples for the codec pipeline. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ), + expected={ + "chunk_sel": (5,), + "out_sel": (), + "drop_axes": (), + }, + id="constant-map-yields-int-selection-no-out", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=3, stride=1),), + ), + expected={ + "chunk_sel": (slice(3, 13, 1),), + "out_sel": (slice(0, 10),), + "drop_axes": (), + }, + id="dimension-map-stride-1-yields-contiguous-slice", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(DimensionMap(input_dimension=0, offset=2, stride=3),), + ), + expected={ + "chunk_sel": (slice(2, 17, 3),), + "out_sel": (slice(0, 5),), + "drop_axes": (), + }, + id="dimension-map-strided-yields-strided-slice", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ), + expected={ + "chunk_sel_kinds": (int, slice), + "chunk_sel_values": (5, slice(0, 10, 1)), + "drop_axes": (), + }, + id="mixed-2d-constant-and-dimension", + ), + ], + ids=lambda c: c.id, +) +def test_sub_transform_to_selections_basic(case: Expect[IndexTransform, dict[str, Any]]) -> None: + """sub_transform_to_selections produces the expected (chunk_sel, out_sel, drop_axes) for non-array maps.""" + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(case.input) + if "chunk_sel" in case.expected: + assert chunk_sel == case.expected["chunk_sel"] + if "chunk_sel_kinds" in case.expected: + for got, expected_kind in zip(chunk_sel, case.expected["chunk_sel_kinds"], strict=True): + assert isinstance(got, expected_kind) + if "chunk_sel_values" in case.expected: + for got, expected_val in zip(chunk_sel, case.expected["chunk_sel_values"], 
strict=True): + assert got == expected_val + if "out_sel" in case.expected: + assert out_sel == case.expected["out_sel"] + assert drop_axes == case.expected["drop_axes"] + + +def test_sub_transform_to_selections_array_map_no_offset() -> None: + """An ArrayMap with offset=0, stride=1 produces the index_array itself as chunk_sel.""" + arr = np.array([1, 5, 9], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,), offset=0, stride=1),), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], arr) + # Without out_indices, out_sel falls back to a domain-derived slice. + assert out_sel == (slice(0, 3),) + assert drop_axes == () + + +def test_sub_transform_to_selections_array_map_with_offset_stride() -> None: + """An ArrayMap with non-zero offset/stride is materialized into storage coords.""" + arr = np.array([0, 1, 2], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,), offset=10, stride=5),), + ) + chunk_sel, _out_sel, drop_axes = sub_transform_to_selections(t) + assert isinstance(chunk_sel[0], np.ndarray) + np.testing.assert_array_equal(chunk_sel[0], np.array([10, 15, 20])) + assert drop_axes == () + + +def test_sub_transform_to_selections_orthogonal_array_with_out_indices() -> None: + """When out_indices is supplied with a single ArrayMap (orthogonal mode), + out_sel uses the supplied scatter indices rather than a domain slice.""" + arr = np.array([1, 5, 9], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), + ) + out_indices = np.array([0, 2], dtype=np.intp) + _chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) + assert len(out_sel) == 1 + assert isinstance(out_sel[0], np.ndarray) + 
np.testing.assert_array_equal(out_sel[0], out_indices) + + +def test_sub_transform_to_selections_vectorized_with_out_indices() -> None: + """When out_indices is supplied with 2+ correlated ArrayMaps (vectorized mode), + out_sel collapses to a single shared scatter array.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp), input_dimensions=(0,)), + ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)), + ), + ) + out_indices = np.array([0, 1], dtype=np.intp) + _chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) + assert len(out_sel) == 1 + assert isinstance(out_sel[0], np.ndarray) + np.testing.assert_array_equal(out_sel[0], out_indices) + + +def test_iter_chunk_transforms_empty_domain() -> None: + """When the input domain is empty (some dim has zero extent), + iter_chunk_transforms yields nothing.""" + t = IndexTransform( + domain=IndexDomain(inclusive_min=(0,), exclusive_max=(0,)), + output=(DimensionMap(input_dimension=0, offset=0, stride=1),), + ) + grid = _grid_1d(10, 30) + results = list(iter_chunk_transforms(t, grid)) + assert results == [] + + +def test_iter_chunk_transforms_arraymap_followed_by_dimensionmap() -> None: + """An ArrayMap output followed by a DimensionMap output exercises the + ArrayMap branch's loop-continuation path in iter_chunk_transforms.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3, 5)), + output=( + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + grid = _grid_2d(10, 10, 10, 5) + results = list(iter_chunk_transforms(t, grid)) + # Sanity: at least one result is yielded. 
+ assert results + + +def test_sub_transform_to_selections_arraymap_followed_by_dimensionmap_orthogonal() -> None: + """An ArrayMap output followed by a DimensionMap output in non-vectorized + mode (out_indices=None) exercises the ArrayMap branch's loop-continuation + path in both the chunk_sel and out_sel construction loops.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3, 5)), + output=( + ArrayMap( + index_array=np.array([1, 5, 9], dtype=np.intp), + input_dimensions=(0,), + ), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + chunk_sel, out_sel, drop_axes = sub_transform_to_selections(t) + # Two output dims, both with selections. + assert len(chunk_sel) == 2 + assert isinstance(chunk_sel[0], np.ndarray) # ArrayMap → array selection + assert isinstance(chunk_sel[1], slice) # DimensionMap → slice + assert len(out_sel) == 2 + assert drop_axes == () + + +def test_sub_transform_to_selections_with_out_indices_skips_non_arraymap_in_correlation_check() -> ( + None +): + """When out_indices is supplied and an output is NOT an ArrayMap, the + correlation-detection loop skips it (covers the `if isinstance(m, ArrayMap)` + False branch in vectorized detection).""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + DimensionMap(input_dimension=0, offset=0, stride=1), + ArrayMap( + index_array=np.array([1, 2, 3], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ) + out_indices = np.array([0, 1], dtype=np.intp) + chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices) + # Single ArrayMap → not vectorized; falls through to the orthogonal path. 
+ assert len(chunk_sel) == 2
+ assert len(out_sel) == 2
+
+
+def test_sub_transform_to_selections_uncorrelated_arraymaps_with_out_indices() -> None:
+ """Two uncorrelated ArrayMaps (disjoint input_dimensions) plus out_indices
+ falls through to the non-vectorized branch (covers the for-loop early
+ exit when no correlation found)."""
+ t = IndexTransform(
+ domain=IndexDomain.from_shape((3, 4)),
+ output=(
+ ArrayMap(
+ index_array=np.array([0, 1, 2], dtype=np.intp),
+ input_dimensions=(0,),
+ ),
+ ArrayMap(
+ index_array=np.array([0, 1, 2, 3], dtype=np.intp),
+ input_dimensions=(1,),
+ ),
+ ),
+ )
+ out_indices = np.array([0, 1], dtype=np.intp)
+ chunk_sel, out_sel, _drop_axes = sub_transform_to_selections(t, out_indices)
+ # Non-vectorized: each ArrayMap contributes its own out_sel entry.
+ assert len(chunk_sel) == 2
+ assert len(out_sel) == 2
+
+
+def test_iter_chunk_transforms_skips_chunks_that_intersect_returns_none() -> None:
+ """A strided DimensionMap can produce a chunk-range overestimate that
+ includes chunks the transform doesn't actually touch. iter_chunk_transforms
+ must skip those (the `if result is None: continue` branch)."""
+ # arr[::5] over a domain of size 30 touches storage coords [0, 5, 10, 15, 20, 25].
+ # With chunk size 4 those land in chunks 0, 1, 2, 3, 5, and 6 (floor(coord / 4));
+ # chunk 4 (storage [16, 20)) contains none of them. The chunk range computed in
+ # iter_chunk_transforms is the overestimate range(0, 7), so chunk 4 IS iterated
+ # and intersect() returns None — iter_chunk_transforms must skip it.
+ t = IndexTransform.from_shape((30,))[::5] + grid = _grid_1d(4, 30) + results = list(iter_chunk_transforms(t, grid)) + coords = sorted(r[0][0] for r in results) + # Every storage coord is hit exactly once; chunk 4 is NOT in the result. + assert 4 not in coords + assert sorted(coords) == [0, 1, 2, 3, 5, 6] diff --git a/tests/test_transforms/test_composition.py b/tests/test_transforms/test_composition.py new file mode 100644 index 0000000000..3035c3a2a9 --- /dev/null +++ b/tests/test_transforms/test_composition.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from tests.test_transforms.conftest import Expect, ExpectErr +from zarr.core._transforms.composition import compose +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import IndexTransform + +# Inner = ConstantMap: result is always ConstantMap regardless of outer. +_constant_inner = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=42),), +) +_identity_outer_5 = IndexTransform.from_shape((5,)) + +# Inner = DimensionMap with various outers. +_dimension_inner_0_10_3 = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=3),), +) +_constant_outer_5 = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), +) +_dimension_outer_0_5_2 = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), +) +_array_outer_arr = np.array([0, 2, 4], dtype=np.intp) +_array_outer = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=_array_outer_arr, input_dimensions=(0,), offset=5, stride=2),), +) + +# Inner = ArrayMap with various outers. 
+_array_inner_arr = np.array([10, 20, 30], dtype=np.intp) +_array_inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=_array_inner_arr, input_dimensions=(0,), offset=0, stride=1),), +) +_constant_outer_1 = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ConstantMap(offset=1),), +) +_array_outer_for_array_inner = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap( + index_array=np.array([0, 2, 1], dtype=np.intp), + input_dimensions=(0,), + offset=0, + stride=1, + ), + ), +) + + +@pytest.mark.parametrize( + "case", + [ + # Inner = ConstantMap. Result is constant with the inner's offset, regardless of outer. + Expect( + input=(_identity_outer_5, _constant_inner), + expected={"kind": ConstantMap, "offset": 42}, + id="constant-inner-identity-outer", + ), + # Inner = DimensionMap. + Expect( + input=(_constant_outer_5, _dimension_inner_0_10_3), + expected={"kind": ConstantMap, "offset": 25}, + id="dimension-inner-constant-outer", + ), + Expect( + input=(_dimension_outer_0_5_2, _dimension_inner_0_10_3), + expected={ + "kind": DimensionMap, + "offset": 25, + "stride": 6, + "input_dimension": 0, + }, + id="dimension-inner-dimension-outer", + ), + Expect( + input=(_array_outer, _dimension_inner_0_10_3), + expected={ + "kind": ArrayMap, + "offset": 25, + "stride": 6, + "index_array": _array_outer_arr, + }, + id="dimension-inner-array-outer", + ), + # Inner = ArrayMap. + Expect( + input=(_constant_outer_1, _array_inner), + expected={"kind": ConstantMap, "offset": 20}, + id="array-inner-constant-outer", + ), + Expect( + input=( + # Outer: 1-D identity-ish, input domain (4,), DimensionMap with + # offset=1 stride=1. Intermediate produced: [1, 2, 3, 4]. + IndexTransform( + domain=IndexDomain.from_shape((4,)), + output=(DimensionMap(input_dimension=0, offset=1, stride=1),), + ), + # Inner: ArrayMap of length 5 on intermediate dim 0. + # arr[1..4] = [200, 300, 400, 500]. 
+ IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=( + ArrayMap( + index_array=np.array([100, 200, 300, 400, 500], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ), + ), + expected={ + "kind": ArrayMap, + "offset": 0, + "stride": 1, + "index_array": np.array([200, 300, 400, 500], dtype=np.intp), + }, + id="array-inner-dimension-outer", + ), + Expect( + input=(_array_outer_for_array_inner, _array_inner), + expected={ + "kind": ArrayMap, + "offset": 0, + "stride": 1, + "index_array": np.array([10, 30, 20], dtype=np.intp), + }, + id="array-inner-array-outer", + ), + ], + ids=lambda c: c.id, +) +def test_compose_success( + case: Expect[tuple[IndexTransform, IndexTransform], dict[str, Any]], +) -> None: + """compose dispatches over (inner_kind, outer_kind) pairs and produces the expected result map.""" + outer, inner = case.input + result = compose(outer, inner) + assert len(result.output) == 1 + out0 = result.output[0] + assert isinstance(out0, case.expected["kind"]) + if "offset" in case.expected: + assert out0.offset == case.expected["offset"] + if "stride" in case.expected: + assert isinstance(out0, (DimensionMap, ArrayMap)) + assert out0.stride == case.expected["stride"] + if "input_dimension" in case.expected: + assert isinstance(out0, DimensionMap) + assert out0.input_dimension == case.expected["input_dimension"] + if "index_array" in case.expected: + assert isinstance(out0, ArrayMap) + np.testing.assert_array_equal(out0.index_array, case.expected["index_array"]) + + +def test_compose_2d_identity() -> None: + """Composing two identity 2D transforms yields a 2D identity.""" + a = IndexTransform.from_shape((10, 20)) + b = IndexTransform.from_shape((10, 20)) + result = compose(a, b) + assert result.domain.shape == (10, 20) + for i, m in enumerate(result.output): + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + +def test_compose_mixed_map_types() -> None: + """Outer has 
heterogeneous output maps; each composes independently with its inner image.""" + outer = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=( + ConstantMap(offset=5), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ) + inner = IndexTransform( + domain=IndexDomain.from_shape((10, 10)), + output=( + DimensionMap(input_dimension=0, offset=2, stride=3), + DimensionMap(input_dimension=1, offset=0, stride=1), + ), + ) + result = compose(outer, inner) + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == 17 + assert isinstance(result.output[1], DimensionMap) + assert result.output[1].input_dimension == 0 + assert result.output[1].offset == 0 + assert result.output[1].stride == 1 + + +def test_compose_chains_associatively() -> None: + """compose(a, compose(b, c)) yields the same offsets/strides as composing in order.""" + a = IndexTransform.from_shape((100,)) + b = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=10, stride=1),), + ) + c = IndexTransform( + domain=IndexDomain.from_shape((100,)), + output=(DimensionMap(input_dimension=0, offset=5, stride=2),), + ) + abc = compose(a, compose(b, c)) + assert isinstance(abc.output[0], DimensionMap) + assert abc.output[0].offset == 25 + assert abc.output[0].stride == 2 + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), IndexTransform.from_shape((10, 20))), + msg="rank", + exception_cls=ValueError, + id="outer-output-rank-vs-inner-input-rank-mismatch", + ), + ExpectErr( + input=( + # Outer is a non-constant 2D identity transform. + IndexTransform.from_shape((3, 2)), + # Inner has a 2D ArrayMap. _compose_array's general multi-dim + # path raises NotImplementedError for this combination. 
+ IndexTransform( + domain=IndexDomain.from_shape((3, 2)), + output=( + ArrayMap( + index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp), + input_dimensions=(0, 1), + ), + ), + ), + ), + msg="not yet supported", + exception_cls=NotImplementedError, + id="multi-d-array-inner-non-constant-outer", + ), + ExpectErr( + input=( + # Outer with mixed types: ConstantMap on dim 0, DimensionMap on dim 1. + # Outer is NOT all-constant, so the early-return path is skipped. + IndexTransform( + domain=IndexDomain.from_shape((4,)), + output=( + ConstantMap(offset=2), + DimensionMap(input_dimension=0, offset=0, stride=1), + ), + ), + # Inner: 1-D ArrayMap referencing outer's dim 0 (the ConstantMap). + # _compose_array reaches the 1-D path; outer.output[0] is ConstantMap, + # which falls through both inner elifs to NotImplementedError. + IndexTransform( + domain=IndexDomain.from_shape((5, 4)), + output=( + ArrayMap( + index_array=np.array([10, 20, 30, 40, 50], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ), + ), + msg="not yet supported", + exception_cls=NotImplementedError, + id="single-input-dim-points-at-constantmap-with-mixed-outer", + ), + ], + ids=lambda c: c.id, +) +def test_compose_errors(case: ExpectErr[tuple[IndexTransform, IndexTransform]]) -> None: + """compose raises on rank mismatch and on the unsupported multi-d-array compose path.""" + outer, inner = case.input + with pytest.raises(case.exception_cls, match=case.msg): + compose(outer, inner) + + +# --------------------------------------------------------------------------- +# Associativity property test. +# +# An IndexTransform models a function from input coords to output coords; +# function composition is associative by definition. 
Verify the implementation +# preserves that algebraic property by sampling random affine triples +# `(a, b, c)` with compatible ranks and checking that +# compose(compose(a, b), c) +# evaluates the same as +# compose(a, compose(b, c)) +# at randomly-chosen points in `a`'s domain. +# +# Restricted to DimensionMap + ConstantMap outputs (the affine subset). +# ArrayMap composition has implementation-level branching that depends on +# outer structure, and would need a more careful generator to avoid the +# NotImplementedError path; saved for a follow-up. +# --------------------------------------------------------------------------- + +pytest.importorskip("hypothesis") + +from hypothesis import assume, given, settings # noqa: E402 +from hypothesis import strategies as st # noqa: E402 + + +def _evaluate(transform: IndexTransform, user_coord: tuple[int, ...]) -> tuple[int, ...]: + """Evaluate a transform at a single input coordinate. + + Restricted to DimensionMap + ConstantMap outputs; `ArrayMap` is unsupported + here because the property test only generates affine triples. + """ + storage: list[int] = [] + for m in transform.output: + if isinstance(m, ConstantMap): + storage.append(m.offset) + elif isinstance(m, DimensionMap): + storage.append(m.offset + m.stride * user_coord[m.input_dimension]) + else: + raise TypeError(f"property test should not generate {type(m).__name__}; got {m!r}") + return tuple(storage) + + +def _affine_output_map(input_rank: int, draw: st.DrawFn) -> ConstantMap | DimensionMap: + """Generate one ConstantMap or DimensionMap output map. + + DimensionMap requires input_rank >= 1; falls back to ConstantMap otherwise. + Offsets and strides are kept small to avoid integer overflow during + repeated composition. Strides are positive (DimensionMap rejects + non-positive strides at construction). 
+ """ + if input_rank == 0: + return ConstantMap(offset=draw(st.integers(min_value=-10, max_value=10))) + kind = draw(st.sampled_from(["constant", "dimension"])) + if kind == "constant": + return ConstantMap(offset=draw(st.integers(min_value=-10, max_value=10))) + return DimensionMap( + input_dimension=draw(st.integers(min_value=0, max_value=input_rank - 1)), + offset=draw(st.integers(min_value=-10, max_value=10)), + stride=draw(st.integers(min_value=1, max_value=3)), + ) + + +@st.composite +def _affine_transform(draw: st.DrawFn, input_rank: int, output_rank: int) -> IndexTransform: + """Generate an affine IndexTransform with the requested ranks.""" + domain_shape = tuple(draw(st.integers(min_value=1, max_value=8)) for _ in range(input_rank)) + domain = IndexDomain.from_shape(domain_shape) + output = tuple(_affine_output_map(input_rank, draw) for _ in range(output_rank)) + return IndexTransform(domain=domain, output=output) + + +@st.composite +def _affine_triple( + draw: st.DrawFn, +) -> tuple[IndexTransform, IndexTransform, IndexTransform]: + """Generate three rank-compatible affine transforms (a, b, c).""" + m = draw(st.integers(min_value=1, max_value=3)) # a's input rank + n = draw(st.integers(min_value=1, max_value=3)) # a's output / b's input rank + p = draw(st.integers(min_value=1, max_value=3)) # b's output / c's input rank + q = draw(st.integers(min_value=1, max_value=3)) # c's output rank + a = draw(_affine_transform(input_rank=m, output_rank=n)) + b = draw(_affine_transform(input_rank=n, output_rank=p)) + c = draw(_affine_transform(input_rank=p, output_rank=q)) + return a, b, c + + +@settings(max_examples=200, deadline=None) +@given(triple=_affine_triple(), data=st.data()) +def test_compose_is_associative( + triple: tuple[IndexTransform, IndexTransform, IndexTransform], + data: st.DataObject, +) -> None: + """For affine transforms, compose(compose(a,b),c) and compose(a,compose(b,c)) + evaluate identically at every point in a's domain.""" + a, b, c = 
triple + left = compose(compose(a, b), c) + right = compose(a, compose(b, c)) + + # Sanity: both compositions agree on rank and domain. + assert left.input_rank == right.input_rank + assert left.output_rank == right.output_rank + assert left.domain == right.domain + + # Sample several points from a's domain and compare evaluations at each. + # 5 coords per triple raises probabilistic coverage at negligible cost. + for _ in range(5): + if a.input_rank == 0: + coord: tuple[int, ...] = () + else: + coord = tuple( + data.draw( + st.integers( + min_value=a.domain.inclusive_min[d], + max_value=a.domain.exclusive_max[d] - 1, + ) + ) + for d in range(a.input_rank) + ) + assume(a.domain.contains(coord)) + assert _evaluate(left, coord) == _evaluate(right, coord) diff --git a/tests/test_transforms/test_domain.py b/tests/test_transforms/test_domain.py new file mode 100644 index 0000000000..0a168eadd5 --- /dev/null +++ b/tests/test_transforms/test_domain.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from tests.test_transforms.conftest import Expect, ExpectErr +from zarr.core._transforms.domain import IndexDomain, _normalize_selection + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input={"inclusive_min": (0, 0), "exclusive_max": (10, 20)}, + expected={"ndim": 2, "origin": (0, 0), "shape": (10, 20), "labels": None}, + id="2d-zero-origin", + ), + Expect( + input={"inclusive_min": (5, 10), "exclusive_max": (15, 30)}, + expected={"ndim": 2, "origin": (5, 10), "shape": (10, 20), "labels": None}, + id="2d-non-zero-origin", + ), + Expect( + input={"inclusive_min": (5,), "exclusive_max": (5,)}, + expected={"ndim": 1, "origin": (5,), "shape": (0,), "labels": None}, + id="1d-empty", + ), + Expect( + input={"inclusive_min": (), "exclusive_max": ()}, + expected={"ndim": 0, "origin": (), "shape": (), "labels": None}, + id="0d", + ), + Expect( + input={"inclusive_min": (0, 0), "exclusive_max": (10, 20), "labels": ("x", "y")}, 
+ expected={"ndim": 2, "origin": (0, 0), "shape": (10, 20), "labels": ("x", "y")}, + id="2d-with-labels", + ), + ], + ids=lambda c: c.id, +) +def test_construction_success(case: Expect[dict[str, Any], dict[str, Any]]) -> None: + """IndexDomain construction yields the expected shape, origin, ndim, and labels.""" + d = IndexDomain(**case.input) + for prop, expected in case.expected.items(): + assert getattr(d, prop) == expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input={"inclusive_min": (0,), "exclusive_max": (10, 20)}, + msg="same length", + exception_cls=ValueError, + id="mismatched-min-max-lengths", + ), + ExpectErr( + input={"inclusive_min": (10,), "exclusive_max": (5,)}, + msg="inclusive_min must be <=", + exception_cls=ValueError, + id="min-greater-than-max", + ), + ExpectErr( + input={"inclusive_min": (0, 0), "exclusive_max": (10, 20), "labels": ("x",)}, + msg="labels must have the same length as dimensions", + exception_cls=ValueError, + id="labels-wrong-length", + ), + ], + ids=lambda c: c.id, +) +def test_construction_errors(case: ExpectErr[dict[str, Any]]) -> None: + """IndexDomain construction with invalid inputs raises ValueError.""" + with pytest.raises(case.exception_cls, match=case.msg): + IndexDomain(**case.input) + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=(10, 20), expected=(2, (0, 0), (10, 20)), id="2d"), + Expect(input=(10,), expected=(1, (0,), (10,)), id="1d"), + Expect(input=(), expected=(0, (), ()), id="0d"), + ], + ids=lambda c: c.id, +) +def test_from_shape_success( + case: Expect[tuple[int, ...], tuple[int, tuple[int, ...], tuple[int, ...]]], +) -> None: + """IndexDomain.from_shape produces a zero-origin domain with the requested shape.""" + d = IndexDomain.from_shape(case.input) + expected_ndim, expected_origin, expected_shape = case.expected + assert d.ndim == expected_ndim + assert d.origin == expected_origin + assert d.shape == expected_shape + + +@pytest.mark.parametrize( + "case", + [ + Expect( 
+ input=(IndexDomain.from_shape((10, 20)), (0, 0)), + expected=True, + id="2d-corner-low", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (9, 19)), + expected=True, + id="2d-corner-high", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (5, 10)), + expected=True, + id="2d-interior", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (10, 0)), + expected=False, + id="2d-outside-high", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (-1, 0)), + expected=False, + id="2d-outside-low", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (5,)), + expected=False, + id="wrong-ndim", + ), + Expect( + input=(IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), (5,)), + expected=True, + id="non-zero-origin-low", + ), + Expect( + input=(IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), (4,)), + expected=False, + id="non-zero-origin-below", + ), + ], + ids=lambda c: c.id, +) +def test_contains_success(case: Expect[tuple[IndexDomain, tuple[int, ...]], bool]) -> None: + """IndexDomain.contains returns True iff the index is within the domain.""" + domain, index = case.input + assert domain.contains(index) is case.expected + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=( + IndexDomain.from_shape((10, 20)), + IndexDomain(inclusive_min=(2, 3), exclusive_max=(8, 15)), + ), + expected=True, + id="strict-subset", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), IndexDomain.from_shape((10, 20))), + expected=True, + id="equal-domains", + ), + Expect( + input=( + IndexDomain.from_shape((10, 20)), + IndexDomain(inclusive_min=(2, 3), exclusive_max=(11, 15)), + ), + expected=False, + id="extends-past-max", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), IndexDomain.from_shape((5,))), + expected=False, + id="wrong-ndim", + ), + ], + ids=lambda c: c.id, +) +def test_contains_domain_success(case: Expect[tuple[IndexDomain, IndexDomain], bool]) -> None: + """IndexDomain.contains_domain returns True iff `other` is 
fully contained.""" + outer, inner = case.input + assert outer.contains_domain(inner) is case.expected + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=( + IndexDomain(inclusive_min=(0, 0), exclusive_max=(10, 10)), + IndexDomain(inclusive_min=(5, 5), exclusive_max=(15, 15)), + ), + expected=IndexDomain(inclusive_min=(5, 5), exclusive_max=(10, 10)), + id="overlapping-2d", + ), + Expect( + input=( + IndexDomain.from_shape((20,)), + IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), + ), + expected=IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), + id="contained", + ), + Expect( + input=( + IndexDomain(inclusive_min=(0,), exclusive_max=(5,)), + IndexDomain(inclusive_min=(10,), exclusive_max=(15,)), + ), + expected=None, + id="disjoint", + ), + Expect( + input=( + IndexDomain(inclusive_min=(0,), exclusive_max=(5,)), + IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), + ), + expected=None, + id="touching-boundary", + ), + ], + ids=lambda c: c.id, +) +def test_intersect_success( + case: Expect[tuple[IndexDomain, IndexDomain], IndexDomain | None], +) -> None: + """IndexDomain.intersect returns the intersection, or None when disjoint.""" + a, b = case.input + assert a.intersect(b) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexDomain.from_shape((10,)), IndexDomain.from_shape((10, 20))), + msg="different ranks", + exception_cls=ValueError, + id="rank-mismatch", + ), + ], + ids=lambda c: c.id, +) +def test_intersect_errors(case: ExpectErr[tuple[IndexDomain, IndexDomain]]) -> None: + """IndexDomain.intersect raises ValueError on rank mismatch.""" + a, b = case.input + with pytest.raises(case.exception_cls, match=case.msg): + a.intersect(b) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexDomain.from_shape((10, 20)), (5, 10)), + expected=IndexDomain(inclusive_min=(5, 10), exclusive_max=(15, 30)), + id="positive-offset", + ), + Expect( + input=(IndexDomain(inclusive_min=(10, 20), 
exclusive_max=(30, 40)), (-10, -20)), + expected=IndexDomain(inclusive_min=(0, 0), exclusive_max=(20, 20)), + id="negative-offset", + ), + Expect( + input=(IndexDomain.from_shape((10,)), (0,)), + expected=IndexDomain.from_shape((10,)), + id="zero-offset", + ), + ], + ids=lambda c: c.id, +) +def test_translate_success( + case: Expect[tuple[IndexDomain, tuple[int, ...]], IndexDomain], +) -> None: + """IndexDomain.translate shifts every coordinate by the offset.""" + domain, offset = case.input + assert domain.translate(offset) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexDomain.from_shape((10,)), (1, 2)), + msg="same length", + exception_cls=ValueError, + id="offset-too-long", + ), + ExpectErr( + input=(IndexDomain.from_shape((10, 20)), (1,)), + msg="same length", + exception_cls=ValueError, + id="offset-too-short", + ), + ], + ids=lambda c: c.id, +) +def test_translate_errors(case: ExpectErr[tuple[IndexDomain, tuple[int, ...]]]) -> None: + """IndexDomain.translate raises when offset length differs from ndim.""" + domain, offset = case.input + with pytest.raises(case.exception_cls, match=case.msg): + domain.translate(offset) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexDomain.from_shape((10, 20)), (slice(2, 8), slice(5, 15))), + expected=IndexDomain(inclusive_min=(2, 5), exclusive_max=(8, 15)), + id="2d-slices", + ), + Expect( + input=(IndexDomain.from_shape((10, 20)), (3, slice(None))), + expected=IndexDomain(inclusive_min=(3, 0), exclusive_max=(4, 20)), + id="int-and-slice", + ), + Expect( + input=(IndexDomain.from_shape((10, 20, 30)), (slice(1, 5), ...)), + expected=IndexDomain(inclusive_min=(1, 0, 0), exclusive_max=(5, 20, 30)), + id="ellipsis-fills-trailing", + ), + Expect( + input=(IndexDomain.from_shape((10,)), (slice(None),)), + expected=IndexDomain.from_shape((10,)), + id="slice-none-is-noop", + ), + Expect( + input=(IndexDomain(inclusive_min=(10,), exclusive_max=(20,)), (slice(12, 18),)), + 
expected=IndexDomain(inclusive_min=(12,), exclusive_max=(18,)), + id="non-zero-origin", + ), + Expect( + input=(IndexDomain.from_shape((10,)), (slice(-5, 100),)), + expected=IndexDomain(inclusive_min=(0,), exclusive_max=(10,)), + id="clamps-to-domain", + ), + Expect( + input=(IndexDomain.from_shape((10,)), slice(2, 8)), + expected=IndexDomain(inclusive_min=(2,), exclusive_max=(8,)), + id="bare-slice-is-wrapped", + ), + ], + ids=lambda c: c.id, +) +def test_narrow_success(case: Expect[tuple[IndexDomain, Any], IndexDomain]) -> None: + """IndexDomain.narrow applies basic indexing to produce a sub-domain.""" + domain, selection = case.input + assert domain.narrow(selection) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexDomain.from_shape((10,)), (10,)), + msg="out of bounds", + exception_cls=IndexError, + id="int-at-upper-bound", + ), + ExpectErr( + input=(IndexDomain(inclusive_min=(5,), exclusive_max=(10,)), (4,)), + msg="out of bounds", + exception_cls=IndexError, + id="int-below-origin", + ), + ExpectErr( + input=(IndexDomain.from_shape((10,)), (1, 2)), + msg="too many indices", + exception_cls=IndexError, + id="too-many-indices", + ), + ExpectErr( + input=(IndexDomain.from_shape((10,)), (slice(0, 10, 2),)), + msg="step=1", + exception_cls=IndexError, + id="non-unit-step", + ), + ], + ids=lambda c: c.id, +) +def test_narrow_errors(case: ExpectErr[tuple[IndexDomain, Any]]) -> None: + """IndexDomain.narrow raises IndexError on invalid selections.""" + domain, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + domain.narrow(selection) + + +# --------------------------------------------------------------------------- +# Direct tests for the non-trivial private helper _normalize_selection. +# Public callers (`IndexDomain.narrow` and `selection_to_transform`) exercise +# most branches transitively, but the double-ellipsis guard only triggers on +# inputs no public caller currently constructs. 
Test it directly. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=((slice(2, 8), slice(5, 15)), 2), + expected=(slice(2, 8), slice(5, 15)), + id="explicit-slices", + ), + Expect( + input=((3, slice(None)), 2), + expected=(3, slice(None)), + id="int-and-slice", + ), + Expect( + input=((..., slice(0, 5)), 3), + expected=(slice(None), slice(None), slice(0, 5)), + id="leading-ellipsis-fills", + ), + Expect( + input=((slice(0, 5), ...), 3), + expected=(slice(0, 5), slice(None), slice(None)), + id="trailing-ellipsis-fills", + ), + Expect( + input=((slice(2, 8),), 3), + expected=(slice(2, 8), slice(None), slice(None)), + id="implicit-trailing-fills", + ), + Expect( + input=(slice(2, 8), 1), + expected=(slice(2, 8),), + id="bare-slice-is-wrapped", + ), + Expect( + input=(5, 1), + expected=(5,), + id="bare-int-is-wrapped", + ), + ], + ids=lambda c: c.id, +) +def test_normalize_selection_success( + case: Expect[tuple[Any, int], tuple[int | slice, ...]], +) -> None: + """_normalize_selection produces a length-ndim tuple of ints/slices.""" + selection, ndim = case.input + assert _normalize_selection(selection, ndim) == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=((..., ..., slice(0, 5)), 3), + msg="single ellipsis", + exception_cls=IndexError, + id="double-ellipsis", + ), + ExpectErr( + input=((1, 2, 3), 2), + msg="too many indices", + exception_cls=IndexError, + id="too-many-indices", + ), + ], + ids=lambda c: c.id, +) +def test_normalize_selection_errors(case: ExpectErr[tuple[Any, int]]) -> None: + """_normalize_selection rejects double ellipsis and over-long selections.""" + selection, ndim = case.input + with pytest.raises(case.exception_cls, match=case.msg): + _normalize_selection(selection, ndim) diff --git a/tests/test_transforms/test_output_map.py b/tests/test_transforms/test_output_map.py new file mode 100644 index 0000000000..0e62017b0a 
--- /dev/null +++ b/tests/test_transforms/test_output_map.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from dataclasses import FrozenInstanceError, asdict +from typing import Any + +import numpy as np +import pytest + +from tests.test_transforms.conftest import Expect, ExpectErr +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=ConstantMap(offset=42), + expected={"offset": 42}, + id="ConstantMap-explicit-offset", + ), + Expect( + input=ConstantMap(), + expected={"offset": 0}, + id="ConstantMap-default-offset", + ), + Expect( + input=DimensionMap(input_dimension=3, offset=5, stride=2), + expected={"input_dimension": 3, "offset": 5, "stride": 2}, + id="DimensionMap-all-fields", + ), + Expect( + input=DimensionMap(input_dimension=0), + expected={"input_dimension": 0, "offset": 0, "stride": 1}, + id="DimensionMap-defaults", + ), + Expect( + input=ArrayMap( + index_array=np.array([1, 3, 5], dtype=np.intp), + input_dimensions=(0,), + offset=10, + stride=2, + ), + expected={ + "index_array": np.array([1, 3, 5], dtype=np.intp), + "input_dimensions": (0,), + "offset": 10, + "stride": 2, + }, + id="ArrayMap-all-fields", + ), + Expect( + input=ArrayMap( + index_array=np.array([0, 1], dtype=np.intp), + input_dimensions=(0,), + ), + expected={ + "index_array": np.array([0, 1], dtype=np.intp), + "input_dimensions": (0,), + "offset": 0, + "stride": 1, + }, + id="ArrayMap-defaults", + ), + ], + ids=lambda c: c.id, +) +def test_construction_success(case: Expect[Any, dict[str, Any]]) -> None: + """Constructing each map type with explicit and default values yields the expected fields.""" + actual = asdict(case.input) + assert set(actual) == set(case.expected) + for field, expected_value in case.expected.items(): + if isinstance(expected_value, np.ndarray): + np.testing.assert_array_equal(actual[field], expected_value) + else: + assert actual[field] == expected_value + + 
+@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(ConstantMap(offset=5), "offset", 99), + msg="cannot assign to field 'offset'", + exception_cls=FrozenInstanceError, + id="ConstantMap-frozen", + ), + ExpectErr( + input=(DimensionMap(input_dimension=0), "stride", 7), + msg="cannot assign to field 'stride'", + exception_cls=FrozenInstanceError, + id="DimensionMap-frozen", + ), + ExpectErr( + input=( + ArrayMap(index_array=np.array([0], dtype=np.intp), input_dimensions=(0,)), + "offset", + 1, + ), + msg="cannot assign to field 'offset'", + exception_cls=FrozenInstanceError, + id="ArrayMap-frozen", + ), + ], + ids=lambda c: c.id, +) +def test_mutation_errors(case: ExpectErr[tuple[Any, str, Any]]) -> None: + """Attempting to mutate a frozen output map raises FrozenInstanceError.""" + obj, field, new_value = case.input + with pytest.raises(case.exception_cls, match=case.msg): + setattr(obj, field, new_value) diff --git a/tests/test_transforms/test_transform.py b/tests/test_transforms/test_transform.py new file mode 100644 index 0000000000..b615048d4a --- /dev/null +++ b/tests/test_transforms/test_transform.py @@ -0,0 +1,1472 @@ +from __future__ import annotations + +from typing import Any, Literal + +import numpy as np +import pytest + +from tests.test_transforms.conftest import Expect, ExpectErr +from zarr.core._transforms.domain import IndexDomain +from zarr.core._transforms.output_map import ArrayMap, ConstantMap, DimensionMap +from zarr.core._transforms.transform import ( + IndexTransform, + _intersect_vectorized, + selection_to_transform, +) + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=IndexTransform.from_shape((10, 20)), + expected={"input_rank": 2, "output_rank": 2, "domain_shape": (10, 20)}, + id="from_shape-2d", + ), + Expect( + 
input=IndexTransform.from_shape(()), + expected={"input_rank": 0, "output_rank": 0, "domain_shape": ()}, + id="from_shape-0d", + ), + Expect( + input=IndexTransform.identity(IndexDomain(inclusive_min=(5,), exclusive_max=(15,))), + expected={"input_rank": 1, "output_rank": 1, "domain_shape": (10,)}, + id="identity-non-zero-origin", + ), + Expect( + input=IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=42), DimensionMap(input_dimension=0)), + ), + expected={"input_rank": 1, "output_rank": 2, "domain_shape": (10,)}, + id="custom-output-maps", + ), + ], + ids=lambda c: c.id, +) +def test_construction_success(case: Expect[IndexTransform, dict[str, Any]]) -> None: + """IndexTransform constructors yield the expected ranks and domain shape.""" + t = case.input + assert t.input_rank == case.expected["input_rank"] + assert t.output_rank == case.expected["output_rank"] + assert t.domain.shape == case.expected["domain_shape"] + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input={ + "domain": IndexDomain.from_shape((10,)), + "output": (DimensionMap(input_dimension=5),), + }, + msg="input_dimension", + exception_cls=ValueError, + id="dimension-map-out-of-range", + ), + ExpectErr( + input={ + "domain": IndexDomain.from_shape((10,)), + "output": (DimensionMap(input_dimension=0, stride=0),), + }, + msg="must be positive", + exception_cls=ValueError, + id="dimension-map-zero-stride", + ), + ExpectErr( + input={ + "domain": IndexDomain.from_shape((10,)), + "output": (DimensionMap(input_dimension=0, stride=-1),), + }, + msg="must be positive", + exception_cls=ValueError, + id="dimension-map-negative-stride", + ), + ExpectErr( + input={ + "domain": IndexDomain.from_shape((5, 3)), + "output": ( + ArrayMap( + index_array=np.array([0, 1, 2], dtype=np.intp), + input_dimensions=(7,), + ), + ), + }, + msg="out of range", + exception_cls=ValueError, + id="array-map-input-dim-out-of-range", + ), + ExpectErr( + input={ + "domain": 
IndexDomain.from_shape((5, 3)), + "output": ( + ArrayMap( + index_array=np.zeros((5, 5), dtype=np.intp), + input_dimensions=(0, 0), + ), + ), + }, + msg="duplicate dimensions", + exception_cls=ValueError, + id="array-map-input-dims-duplicate", + ), + ], + ids=lambda c: c.id, +) +def test_construction_errors(case: ExpectErr[dict[str, Any]]) -> None: + """IndexTransform construction with invalid output maps raises ValueError.""" + with pytest.raises(case.exception_cls, match=case.msg): + IndexTransform(**case.input) + + +# --------------------------------------------------------------------------- +# from_shape produces an identity transform whose output maps are DimensionMaps +# pointing at the corresponding input dim with offset=0, stride=1. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect(input=(10, 20), expected=2, id="2d"), + Expect(input=(7,), expected=1, id="1d"), + Expect(input=(), expected=0, id="0d"), + ], + ids=lambda c: c.id, +) +def test_from_shape_produces_identity_dimension_maps( + case: Expect[tuple[int, ...], int], +) -> None: + """IndexTransform.from_shape produces DimensionMaps that map each output dim + back to the corresponding input dim, with no offset and unit stride.""" + t = IndexTransform.from_shape(case.input) + assert len(t.output) == case.expected + for i, m in enumerate(t.output): + assert isinstance(m, DimensionMap) + assert m.input_dimension == i + assert m.offset == 0 + assert m.stride == 1 + + +# --------------------------------------------------------------------------- +# __getitem__ (basic indexing) +# +# Most successful branches are covered by selection_to_transform tests below; +# this set focuses on cases unique to the __getitem__ surface (composition, +# bare-int / bare-slice, ArrayMap interactions). 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((10, 20)), (slice(None), slice(None))), + expected={"shape": (10, 20), "input_rank": 2, "output_rank": 2}, + id="identity-slice", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), (slice(2, 8), slice(5, 15))), + expected={"shape": (6, 10), "input_rank": 2, "output_rank": 2}, + id="2d-narrowing-slices", + ), + Expect( + input=(IndexTransform.from_shape((10,)), slice(None, None, 2)), + expected={"shape": (5,), "input_rank": 1, "output_rank": 1}, + id="strided-slice", + ), + Expect( + input=(IndexTransform.from_shape((10,)), slice(1, 9, 3)), + expected={"shape": (3,), "input_rank": 1, "output_rank": 1}, + id="strided-slice-with-start", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), 3), + expected={"shape": (20,), "input_rank": 1, "output_rank": 2}, + id="bare-int-drops-leading-dim", + ), + Expect( + input=(IndexTransform.from_shape((10, 20, 30)), (slice(None), 5, slice(None))), + expected={"shape": (10, 30), "input_rank": 2, "output_rank": 3}, + id="int-drops-middle-dim", + ), + Expect( + input=(IndexTransform.from_shape((10, 20, 30)), (slice(2, 8), ...)), + expected={"shape": (6, 20, 30), "input_rank": 3, "output_rank": 3}, + id="ellipsis-fills-trailing", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), (np.newaxis, slice(None), slice(None))), + expected={"shape": (1, 10, 20), "input_rank": 3, "output_rank": 2}, + id="newaxis-prepends-axis", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), slice(2, 8)), + expected={"shape": (6, 20), "input_rank": 2, "output_rank": 2}, + id="bare-slice-implicitly-fills-trailing", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_basic_success( + case: Expect[tuple[IndexTransform, Any], dict[str, Any]], +) -> None: + """IndexTransform.__getitem__ produces a sub-transform with the expected shape and rank.""" + 
transform, selection = case.input + result = transform[selection] + assert result.domain.shape == case.expected["shape"] + assert result.input_rank == case.expected["input_rank"] + assert result.output_rank == case.expected["output_rank"] + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), 10), + msg="out of bounds", + exception_cls=IndexError, + id="int-at-upper-bound", + ), + ExpectErr( + input=(IndexTransform.from_shape((10,)), -1), + msg="out of bounds", + exception_cls=IndexError, + id="negative-int-out-of-domain", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_basic_errors(case: ExpectErr[tuple[IndexTransform, Any]]) -> None: + """IndexTransform.__getitem__ rejects out-of-domain integer indices. + + Note: negative indices are LITERAL coordinates per TensorStore convention, + not wrap-around. arr[-1] on a domain [0, 10) is out of bounds, not arr[9]. + """ + transform, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform[selection] + + +def test_getitem_negative_int_valid_with_negative_origin() -> None: + """A negative integer index is valid when the domain's origin is negative. + + Stand-alone test (not parametrized) because verifying the *literal-coordinate* + semantics is the whole point — the assertion on the resulting ConstantMap + offset is the load-bearing check, not the shape. 
+ """ + domain = IndexDomain(inclusive_min=(-5,), exclusive_max=(5,)) + t = IndexTransform.identity(domain) + result = t[-3] + assert isinstance(result.output[0], ConstantMap) + assert result.output[0].offset == -3 + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((100,))[10:50], slice(5, 20)), + expected={"shape": (15,), "offset": 15, "stride": 1}, + id="composed-slices", + ), + Expect( + input=(IndexTransform.from_shape((100,))[::2], slice(None, None, 3)), + expected={"shape": (17,), "offset": 0, "stride": 6}, + id="composed-strides", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_composition( + case: Expect[tuple[IndexTransform, Any], dict[str, Any]], +) -> None: + """Indexing a sliced transform composes offsets and strides on the DimensionMap.""" + transform, selection = case.input + result = transform[selection] + assert result.domain.shape == case.expected["shape"] + assert isinstance(result.output[0], DimensionMap) + assert result.output[0].offset == case.expected["offset"] + assert result.output[0].stride == case.expected["stride"] + + +# Indexing into a transform whose output is already an ArrayMap — basic +# operations (int/slice/stride/newaxis) must transform the index_array itself +# rather than building a new map. 
+_array_map_1d = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=( + ArrayMap( + index_array=np.array([10, 20, 30, 40, 50], dtype=np.intp), + input_dimensions=(0,), + ), + ), +) +_array_map_2d_3x2 = IndexTransform( + domain=IndexDomain.from_shape((3, 2)), + output=( + ArrayMap( + index_array=np.array([[10, 20], [30, 40], [50, 60]], dtype=np.intp), + input_dimensions=(0, 1), + ), + ), +) +_array_map_2d_2x3 = IndexTransform( + domain=IndexDomain.from_shape((2, 3)), + output=( + ArrayMap( + index_array=np.array([[10, 20, 30], [40, 50, 60]], dtype=np.intp), + input_dimensions=(0, 1), + ), + ), +) + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(_array_map_2d_3x2, 1), + expected=np.array([30, 40], dtype=np.intp), + id="int-on-array-map-drops-axis", + ), + Expect( + input=(_array_map_1d, slice(1, 4)), + expected=np.array([20, 30, 40], dtype=np.intp), + id="slice-on-array-map", + ), + Expect( + input=(_array_map_1d, slice(None, None, 2)), + expected=np.array([10, 30, 50], dtype=np.intp), + id="strided-slice-on-array-map", + ), + Expect( + input=(_array_map_2d_2x3, (0, slice(1, 3))), + expected=np.array([20, 30], dtype=np.intp), + id="int-then-slice-on-2d-array-map", + ), + ], + ids=lambda c: c.id, +) +def test_getitem_on_array_map( + case: Expect[tuple[IndexTransform, Any], np.ndarray[Any, np.dtype[np.intp]]], +) -> None: + """Basic indexing on a transform whose output is an ArrayMap reshapes the index array.""" + transform, selection = case.input + result = transform[selection] + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, case.expected) + + +def test_getitem_newaxis_on_array_map() -> None: + """np.newaxis on an ArrayMap inserts a new input dim into the domain but + leaves the array's parameterization unchanged. 
The array's input_dimensions + just shifts to point at the new index of the old dim.""" + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=( + ArrayMap( + index_array=np.array([10, 20, 30], dtype=np.intp), + input_dimensions=(0,), + ), + ), + ) + result = t[np.newaxis, :] + assert result.input_rank == 2 + assert result.domain.shape == (1, 3) + assert isinstance(result.output[0], ArrayMap) + # newaxis is at new dim 0; old dim 0 shifts to new dim 1. + assert result.output[0].input_dimensions == (1,) + assert result.output[0].index_array.shape == (3,) + np.testing.assert_array_equal(result.output[0].index_array, np.array([10, 20, 30])) + + +# --------------------------------------------------------------------------- +# oindex (orthogonal indexing) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 3, 5], dtype=np.intp), slice(None)), + ), + expected={"shape": (3, 20), "out0_kind": ArrayMap, "out1_kind": DimensionMap}, + id="int-array-and-slice", + ), + Expect( + input=(IndexTransform.from_shape((5,)), (np.array([True, False, True, False, True]),)), + expected={"shape": (3,), "out0_kind": ArrayMap, "out1_kind": None}, + id="bool-mask", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([2, 4], dtype=np.intp), slice(5, 15)), + ), + expected={"shape": (2, 10), "out0_kind": ArrayMap, "out1_kind": DimensionMap}, + id="array-and-narrowing-slice", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20, 30)), + ( + np.array([1, 3], dtype=np.intp), + slice(None), + np.array([5, 10, 15], dtype=np.intp), + ), + ), + expected={"shape": (2, 20, 3), "out0_kind": ArrayMap, "out1_kind": DimensionMap}, + id="three-dims-mixed", + ), + ], + ids=lambda c: c.id, +) +def test_oindex_success(case: Expect[tuple[IndexTransform, Any], dict[str, Any]]) -> None: + """IndexTransform.oindex combines 
array indices independently per dimension.""" + transform, selection = case.input + result = transform.oindex[selection] + assert result.domain.shape == case.expected["shape"] + assert isinstance(result.output[0], case.expected["out0_kind"]) + if case.expected["out1_kind"] is not None: + assert isinstance(result.output[1], case.expected["out1_kind"]) + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), (slice(None, None, -1),)), + msg="slice step must be positive", + exception_cls=IndexError, + id="negative-slice-step", + ), + ], + ids=lambda c: c.id, +) +def test_oindex_errors(case: ExpectErr[tuple[IndexTransform, Any]]) -> None: + """IndexTransform.oindex rejects non-positive slice steps.""" + transform, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform.oindex[selection] + + +def test_oindex_on_1d_array_map_with_int_array() -> None: + """oindex on a transform with a 1-D ArrayMap output indexes that ArrayMap's + array along its single parameterizing input dim.""" + arr = np.array([10, 20, 30, 40, 50], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), + ) + result = t.oindex[np.array([0, 2, 4], dtype=np.intp)] + assert result.input_rank == 1 + assert result.domain.shape == (3,) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, np.array([10, 30, 50])) + + +def test_oindex_on_2d_array_map_all_slices() -> None: + """oindex on a 2-D ArrayMap with slices on every axis is well-defined + (no axes selected by integer arrays).""" + arr = np.arange(12, dtype=np.intp).reshape(3, 4) + t = IndexTransform( + domain=IndexDomain.from_shape((3, 4)), + output=(ArrayMap(index_array=arr, input_dimensions=(0, 1)),), + ) + # Both axes sliced; no array indices. 
+ result = t.oindex[1:3, 0:3] + assert result.input_rank == 2 + assert result.domain.shape == (2, 3) + assert isinstance(result.output[0], ArrayMap) + np.testing.assert_array_equal(result.output[0].index_array, arr[1:3, 0:3]) + + +def test_oindex_on_multi_dim_array_map_with_two_array_axes_errors() -> None: + """oindex on a multi-dim ArrayMap with two or more axes selected by + integer arrays needs np.ix_-style outer-product semantics. Until that + is implemented, raise NotImplementedError.""" + arr = np.arange(12, dtype=np.intp).reshape(3, 4) + t = IndexTransform( + domain=IndexDomain.from_shape((3, 4)), + output=(ArrayMap(index_array=arr, input_dimensions=(0, 1)),), + ) + with pytest.raises(NotImplementedError, match="multi-dimensional ArrayMap"): + t.oindex[np.array([0, 2], dtype=np.intp), np.array([1, 3], dtype=np.intp)] + + +# --------------------------------------------------------------------------- +# vindex (vectorized indexing) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((10,)), np.array([1, 3, 5], dtype=np.intp)), + expected=(3,), + id="single-1d-array", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + ( + np.array([[1, 2], [3, 4]], dtype=np.intp), + np.array([[10, 11], [12, 13]], dtype=np.intp), + ), + ), + expected=(2, 2), + id="two-2d-arrays-broadcast", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20, 30)), + (np.array([1, 3, 5], dtype=np.intp), slice(None), slice(None)), + ), + expected=(3, 20, 30), + id="array-with-trailing-slices", + ), + Expect( + input=(IndexTransform.from_shape((5,)), np.array([True, False, True, False, True])), + expected=(3,), + id="bool-mask", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 2, 3], dtype=np.intp), np.array([[10], [11]], dtype=np.intp)), + ), + expected=(2, 3), + id="broadcast-different-shapes", + ), + ], + ids=lambda c: c.id, 
+) +def test_vindex_success(case: Expect[tuple[IndexTransform, Any], tuple[int, ...]]) -> None: + """IndexTransform.vindex broadcasts array indices and produces correlated ArrayMaps.""" + transform, selection = case.input + result = transform.vindex[selection] + assert result.domain.shape == case.expected + + +@pytest.mark.parametrize( + "case", + [ + ExpectErr( + input=(IndexTransform.from_shape((10,)), (slice(None, None, -1),)), + msg="slice step must be positive", + exception_cls=IndexError, + id="negative-slice-step", + ), + ], + ids=lambda c: c.id, +) +def test_vindex_errors(case: ExpectErr[tuple[IndexTransform, Any]]) -> None: + """IndexTransform.vindex rejects non-positive slice steps.""" + transform, selection = case.input + with pytest.raises(case.exception_cls, match=case.msg): + transform.vindex[selection] + + +# --------------------------------------------------------------------------- +# selection_to_transform — the public dispatch front door for all three modes. +# Sanity check that each mode produces the expected output kind. 
+# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "case", + [ + Expect( + input=(IndexTransform.from_shape((10, 20)), (slice(2, 8), slice(5, 15)), "basic"), + expected={"shape": (6, 10), "out0_kind": DimensionMap}, + id="basic-slices", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), (3, slice(None)), "basic"), + expected={"shape": (20,), "out0_kind": ConstantMap}, + id="basic-int-and-slice", + ), + Expect( + input=(IndexTransform.from_shape((10, 20)), Ellipsis, "basic"), + expected={"shape": (10, 20), "out0_kind": DimensionMap}, + id="basic-bare-ellipsis", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 3, 5], dtype=np.intp), slice(None)), + "orthogonal", + ), + expected={"shape": (3, 20), "out0_kind": ArrayMap}, + id="orthogonal", + ), + Expect( + input=( + IndexTransform.from_shape((10, 20)), + (np.array([1, 3], dtype=np.intp), np.array([5, 7], dtype=np.intp)), + "vectorized", + ), + expected={"shape": (2,), "out0_kind": ArrayMap}, + id="vectorized", + ), + Expect( + input=(IndexTransform.from_shape((100,))[10:50], slice(5, 20), "basic"), + expected={"shape": (15,), "out0_kind": DimensionMap}, + id="composes-with-non-identity-base", + ), + ], + ids=lambda c: c.id, +) +def test_selection_to_transform_success( + case: Expect[ + tuple[IndexTransform, Any, Literal["basic", "orthogonal", "vectorized"]], dict[str, Any] + ], +) -> None: + """selection_to_transform dispatches to basic/orthogonal/vectorized correctly.""" + transform, selection, mode = case.input + result = selection_to_transform(selection, transform, mode) + assert result.domain.shape == case.expected["shape"] + assert isinstance(result.output[0], case.expected["out0_kind"]) + + +def test_selection_to_transform_unknown_mode_errors() -> None: + """selection_to_transform rejects unknown indexing modes. 
+ + The `mode` parameter is typed as `Literal["basic", "orthogonal", "vectorized"]`, + so this test bypasses static type checking to exercise the runtime guard. + """ + t = IndexTransform.from_shape((10,)) + with pytest.raises(ValueError, match="Unknown mode"): + selection_to_transform(slice(None), t, "diagonal") # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# intersect — restrict an output domain. Returns (sub_transform, surviving) +# or None when the intersection is empty. +# --------------------------------------------------------------------------- + + +def test_intersect_constant_inside() -> None: + """A ConstantMap whose offset is inside the chunk survives unchanged.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + result = t.intersect(IndexDomain(inclusive_min=(0,), exclusive_max=(10,))) + assert result is not None + restricted, surviving = result + assert isinstance(restricted.output[0], ConstantMap) + assert restricted.output[0].offset == 5 + assert surviving is None + + +def test_intersect_constant_outside() -> None: + """A ConstantMap whose offset is outside the chunk yields None.""" + t = IndexTransform( + domain=IndexDomain.from_shape((10,)), + output=(ConstantMap(offset=5),), + ) + assert t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None + + +def test_intersect_dimension_partial() -> None: + """A DimensionMap whose storage-coord range partially overlaps the chunk + narrows the input domain to the surviving slice.""" + t = IndexTransform.from_shape((10,)) + result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(15,))) + assert result is not None + restricted, surviving = result + assert restricted.domain.inclusive_min == (5,) + assert restricted.domain.exclusive_max == (10,) + assert surviving is None + + +def test_intersect_dimension_no_overlap() -> None: + """A DimensionMap whose storage-coord 
range does not overlap the chunk yields None.""" + t = IndexTransform.from_shape((10,)) + assert t.intersect(IndexDomain(inclusive_min=(20,), exclusive_max=(30,))) is None + + +def test_intersect_dimension_strided() -> None: + """Strided DimensionMap: storage = offset + stride * input. Only inputs that land + in the chunk survive.""" + # offset=1, stride=2, input [0,5): storage = {1, 3, 5, 7, 9}. Chunk [4, 8) -> {5, 7}. + t = IndexTransform( + domain=IndexDomain.from_shape((5,)), + output=(DimensionMap(input_dimension=0, offset=1, stride=2),), + ) + result = t.intersect(IndexDomain(inclusive_min=(4,), exclusive_max=(8,))) + assert result is not None + restricted, _ = result + assert restricted.domain.inclusive_min == (2,) + assert restricted.domain.exclusive_max == (4,) + + +def test_intersect_array_partial() -> None: + """An ArrayMap whose storage coords partially overlap the chunk yields a filtered ArrayMap + plus a `surviving` mask of the input indices that survived.""" + arr = np.array([3, 8, 15, 22], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((4,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), + ) + result = t.intersect(IndexDomain(inclusive_min=(5,), exclusive_max=(20,))) + assert result is not None + restricted, surviving = result + assert isinstance(restricted.output[0], ArrayMap) + np.testing.assert_array_equal(restricted.output[0].index_array, np.array([8, 15])) + assert surviving is not None + np.testing.assert_array_equal(surviving, np.array([1, 2])) + + +def test_intersect_array_disjoint() -> None: + """An ArrayMap whose storage coords are entirely outside the chunk yields None.""" + arr = np.array([1, 2, 3], dtype=np.intp) + t = IndexTransform( + domain=IndexDomain.from_shape((3,)), + output=(ArrayMap(index_array=arr, input_dimensions=(0,)),), + ) + assert t.intersect(IndexDomain(inclusive_min=(10,), exclusive_max=(20,))) is None + + +def test_intersect_2d_mixed_constant_and_dimension() -> None: + """2D 
    output: ConstantMap on dim 0 (inside chunk), DimensionMap on dim 1 (overlaps chunk)."""
    t = IndexTransform(
        domain=IndexDomain.from_shape((10,)),
        output=(
            ConstantMap(offset=5),
            DimensionMap(input_dimension=0, offset=0, stride=1),
        ),
    )
    chunk = IndexDomain(inclusive_min=(0, 5), exclusive_max=(10, 15))
    result = t.intersect(chunk)
    assert result is not None
    restricted, _ = result
    # ConstantMap(5) lies inside the chunk's dim-0 range [0, 10), so it survives
    # unchanged; the input domain is clipped to the overlap [5, 10) on dim 1.
    assert isinstance(restricted.output[0], ConstantMap)
    assert restricted.output[0].offset == 5
    assert isinstance(restricted.output[1], DimensionMap)
    assert restricted.domain.inclusive_min == (5,)
    assert restricted.domain.exclusive_max == (10,)


def test_intersect_rank_mismatch_errors() -> None:
    """intersect rejects an output_domain whose rank differs from the transform's output rank."""
    t = IndexTransform.from_shape((10,))  # output rank 1
    chunk = IndexDomain.from_shape((10, 20))  # rank 2
    with pytest.raises(ValueError, match="output rank"):
        t.intersect(chunk)


# ---------------------------------------------------------------------------
# Direct tests for _intersect_vectorized.
#
# Public `intersect` only calls _intersect_vectorized when the transform has
# 2+ ArrayMap outputs (correlated indices). All public test cases use exactly
# one ArrayMap, so this branch is unreachable from public-surface tests.
# ---------------------------------------------------------------------------


def _vectorized_2d_array_map() -> IndexTransform:
    """Helper: a vectorized transform over a (3,) input domain with two
    correlated ArrayMaps. Storage coords: (1,10), (5,11), (9,12).

    Both ArrayMaps share input_dimensions=(0,) — that's what makes them
    correlated under the new design."""
    return IndexTransform(
        domain=IndexDomain.from_shape((3,)),
        output=(
            ArrayMap(index_array=np.array([1, 5, 9], dtype=np.intp), input_dimensions=(0,)),
            ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)),
        ),
    )


def test_intersect_vectorized_partial_survival() -> None:
    """Two correlated ArrayMaps; only points where ALL coords are in-chunk survive."""
    t = _vectorized_2d_array_map()
    chunk = IndexDomain(inclusive_min=(0, 10), exclusive_max=(8, 12))
    # Storage points (1,10), (5,11), (9,12). In-chunk: (1,10), (5,11).
    # (9,12) is rejected on both dims: 9 is outside [0, 8) and 12 is outside [10, 12).
    result = _intersect_vectorized(t, chunk, [0, 1])
    assert result is not None
    restricted, surviving = result
    assert isinstance(restricted.output[0], ArrayMap)
    assert isinstance(restricted.output[1], ArrayMap)
    np.testing.assert_array_equal(restricted.output[0].index_array, np.array([1, 5]))
    np.testing.assert_array_equal(restricted.output[1].index_array, np.array([10, 11]))
    assert surviving is not None
    np.testing.assert_array_equal(surviving, np.array([0, 1]))


def test_intersect_vectorized_no_survival() -> None:
    """If no point is in-chunk on all dims, returns None."""
    t = _vectorized_2d_array_map()
    chunk = IndexDomain(inclusive_min=(20, 20), exclusive_max=(30, 30))
    assert _intersect_vectorized(t, chunk, [0, 1]) is None


def test_intersect_vectorized_with_constant_outside_drops_to_none() -> None:
    """When a ConstantMap output is outside the chunk, the entire transform fails."""
    t = IndexTransform(
        domain=IndexDomain.from_shape((3,)),
        output=(
            ArrayMap(index_array=np.array([1, 2, 3], dtype=np.intp), input_dimensions=(0,)),
            ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)),
            ConstantMap(offset=99),
        ),
    )
    # ConstantMap(99) is outside the chunk's dim-2 range [0, 5), so no point survives.
    chunk = IndexDomain(inclusive_min=(0, 0, 0), exclusive_max=(10, 20, 5))
    assert _intersect_vectorized(t, chunk, [0, 1]) is None


# ---------------------------------------------------------------------------
# translate — shift every coordinate by an offset.
# ---------------------------------------------------------------------------

# Module-level fixtures shared by the translate parametrizations below: one
# transform per OutputIndexMap kind.
_translate_dimension_t = IndexTransform.from_shape((10,))
_translate_array_t = IndexTransform(
    domain=IndexDomain.from_shape((2,)),
    output=(
        ArrayMap(
            index_array=np.array([5, 10], dtype=np.intp),
            input_dimensions=(0,),
            offset=3,
        ),
    ),
)
_translate_constant_t = IndexTransform(
    domain=IndexDomain.from_shape((10,)),
    output=(ConstantMap(offset=5),),
)


@pytest.mark.parametrize(
    "case",
    [
        Expect(
            input=(_translate_constant_t, (-5,)),
            expected={"out_kind": ConstantMap, "offset": 0},
            id="constant",
        ),
        Expect(
            input=(_translate_dimension_t, (-3,)),
            expected={"out_kind": DimensionMap, "offset": -3, "stride": 1},
            id="dimension",
        ),
        Expect(
            input=(_translate_array_t, (-3,)),
            expected={"out_kind": ArrayMap, "offset": 0},
            id="array",
        ),
    ],
    ids=lambda c: c.id,
)
def test_translate_success(
    case: Expect[tuple[IndexTransform, tuple[int, ...]], dict[str, Any]],
) -> None:
    """IndexTransform.translate adjusts each output map's offset uniformly."""
    transform, shift = case.input
    result = transform.translate(shift)
    out0 = result.output[0]
    assert isinstance(out0, case.expected["out_kind"])
    assert out0.offset == case.expected["offset"]
    # Only the DimensionMap case carries a stride expectation; stride must be
    # unaffected by translation.
    if "stride" in case.expected:
        assert isinstance(out0, DimensionMap)
        assert out0.stride == case.expected["stride"]


def test_translate_2d() -> None:
    """A multi-dimensional translate shifts all output dims independently."""
    t = IndexTransform.from_shape((10, 20))
    result = t.translate((-5, -10))
    out0, out1 = result.output
    assert isinstance(out0, DimensionMap)
    assert out0.offset == -5
    assert isinstance(out1, DimensionMap)
    assert out1.offset == -10


@pytest.mark.parametrize(
    "case",
    [
        ExpectErr(
            input=(IndexTransform.from_shape((10, 20)), (1,)),
            msg="shift must have length",
            exception_cls=ValueError,
            id="shift-too-short",
        ),
        ExpectErr(
            input=(IndexTransform.from_shape((10,)), (1, 2)),
            msg="shift must have length",
            exception_cls=ValueError,
            id="shift-too-long",
        ),
    ],
    ids=lambda c: c.id,
)
def test_translate_errors(case: ExpectErr[tuple[IndexTransform, tuple[int, ...]]]) -> None:
    """IndexTransform.translate rejects shifts whose length doesn't match output_rank."""
    transform, shift = case.input
    with pytest.raises(case.exception_cls, match=case.msg):
        transform.translate(shift)


# ---------------------------------------------------------------------------
# selection_repr and __repr__: verify the human-readable strings cover each
# OutputIndexMap variant.
# ---------------------------------------------------------------------------


def test_selection_repr_covers_all_map_kinds() -> None:
    """selection_repr produces a TensorStore-style domain string with one
    entry per output dim, formatted differently for each map kind."""
    t = IndexTransform(
        domain=IndexDomain.from_shape((3,)),
        output=(
            ConstantMap(offset=5),
            DimensionMap(input_dimension=0, offset=2, stride=1),
            DimensionMap(input_dimension=0, offset=0, stride=3),
            ArrayMap(
                index_array=np.array([1, 5, 9], dtype=np.intp),
                input_dimensions=(0,),
            ),
        ),
    )
    repr_str = t.selection_repr
    assert "5" in repr_str  # ConstantMap
    assert "[2, 5)" in repr_str  # DimensionMap stride=1 over input [0, 3)
    assert "step 3" in repr_str  # DimensionMap stride=3
    assert "{1, 5, 9}" in repr_str  # ArrayMap (small)


def test_selection_repr_array_map_large() -> None:
    """ArrayMaps with more than 5 elements show as `array(N)` rather than spelled out."""
    arr = np.arange(10, dtype=np.intp)
    t = IndexTransform(
        domain=IndexDomain.from_shape((10,)),
        output=(ArrayMap(index_array=arr, input_dimensions=(0,)),),
    )
    assert "array(10)" in t.selection_repr


def test_repr_covers_all_map_kinds() -> None:
    """__repr__ formats each output map with its kind-specific shape."""
    t = IndexTransform(
        domain=IndexDomain.from_shape((10, 5)),
        output=(
            ConstantMap(offset=7),
            DimensionMap(input_dimension=0, offset=1, stride=2),
            ArrayMap(
                index_array=np.array([0, 1, 2, 3, 4], dtype=np.intp),
                input_dimensions=(1,),
            ),
        ),
    )
    s = repr(t)
    assert "out[0] = 7" in s
    assert "out[1] = 1 + 2 * in[0]" in s
    assert "out[2] = 0 + 1 * arr(5,)[in[1]]" in s


# ---------------------------------------------------------------------------
# intersect() public dispatch: prior tests call _intersect_vectorized directly;
# the public IndexTransform.intersect() vectorized path was untested.
# ---------------------------------------------------------------------------


def test_intersect_dispatches_to_vectorized_when_arraymaps_correlated() -> None:
    """IndexTransform.intersect() uses the vectorized path when 2+ ArrayMaps
    share an input dimension. It uses the orthogonal path when ArrayMaps have
    disjoint input dimensions."""
    # Correlated: both ArrayMaps share input_dimensions=(0,) on a 1-D domain.
    t_correlated = IndexTransform(
        domain=IndexDomain.from_shape((3,)),
        output=(
            ArrayMap(
                index_array=np.array([1, 5, 9], dtype=np.intp),
                input_dimensions=(0,),
            ),
            ArrayMap(
                index_array=np.array([10, 11, 12], dtype=np.intp),
                input_dimensions=(0,),
            ),
        ),
    )
    chunk = IndexDomain(inclusive_min=(0, 10), exclusive_max=(8, 12))
    result = t_correlated.intersect(chunk)
    assert result is not None
    _, surviving = result
    # Points (1,10) and (5,11) survive; (9,12) is out of range on both dims
    # (9 outside [0, 8), 12 outside [10, 12)).
    assert surviving is not None
    np.testing.assert_array_equal(surviving, np.array([0, 1]))


# ---------------------------------------------------------------------------
# _intersect_vectorized with a DimensionMap output on a NON-correlated input
# dim: the post-refactor path that preserves the non-broadcast input dim.
# ---------------------------------------------------------------------------


def test_intersect_vectorized_preserves_non_correlated_dim() -> None:
    """A vindex transform with a non-broadcast input dim produces a
    DimensionMap on that dim. Intersecting must remap that DimensionMap's
    input_dimension to the new domain (where the broadcast dim has been
    collapsed to (len(surviving),) at index 0)."""
    # Construct the transform that vindex-with-trailing-slice would produce:
    # (broadcast_dim=3, slice_dim=20). output[0] = ArrayMap on broadcast,
    # output[1] = DimensionMap on slice_dim.
    t = IndexTransform(
        domain=IndexDomain.from_shape((3, 20)),
        output=(
            ArrayMap(
                index_array=np.array([1, 5, 9], dtype=np.intp),
                input_dimensions=(0,),
            ),
            DimensionMap(input_dimension=1, offset=0, stride=1),
        ),
    )
    # Two correlated outputs needed for vectorized path; add a second ArrayMap
    # on the same broadcast dim.
    t_with_two_arrays = IndexTransform(
        domain=IndexDomain.from_shape((3, 20)),
        output=(
            ArrayMap(
                index_array=np.array([1, 5, 9], dtype=np.intp),
                input_dimensions=(0,),
            ),
            ArrayMap(
                index_array=np.array([2, 6, 10], dtype=np.intp),
                input_dimensions=(0,),
            ),
            DimensionMap(input_dimension=1, offset=0, stride=1),
        ),
    )
    chunk = IndexDomain(inclusive_min=(0, 0, 0), exclusive_max=(20, 20, 20))
    result = t_with_two_arrays.intersect(chunk)
    assert result is not None
    restricted, surviving = result
    # Surviving points: all 3 (all storage coords in [0,20)).
    assert surviving is not None
    np.testing.assert_array_equal(surviving, np.array([0, 1, 2]))
    # New domain: (3 surviving, 20 from preserved slice dim).
    assert restricted.domain.shape == (3, 20)
    # output[2] (the DimensionMap) should have its input_dimension remapped
    # from old dim 1 to new dim 1 (broadcast dim is now new dim 0).
    out_dim_map = restricted.output[2]
    assert isinstance(out_dim_map, DimensionMap)
    assert out_dim_map.input_dimension == 1
    # silence unused-var: t was an intermediate construction reference
    # NOTE(review): `t` only documents the pre-vectorized shape; consider
    # dropping it and this scaffold assert in a follow-up.
    assert t.output_rank == 2


# ---------------------------------------------------------------------------
# _apply_basic_indexing rejects negative slice steps.
# ---------------------------------------------------------------------------


def test_basic_indexing_rejects_negative_slice_step() -> None:
    """Basic indexing raises IndexError for a negative slice step."""
    t = IndexTransform.from_shape((10,))
    with pytest.raises(IndexError, match="slice step must be positive"):
        t[slice(None, None, -1)]


# ---------------------------------------------------------------------------
# _apply_vindex on an existing ArrayMap output raises NotImplementedError.
# ---------------------------------------------------------------------------


def test_vindex_on_existing_arraymap_errors() -> None:
    """Applying vindex on top of an existing ArrayMap output is unsupported."""
    t = IndexTransform(
        domain=IndexDomain.from_shape((5,)),
        output=(
            ArrayMap(
                index_array=np.array([1, 2, 3, 4, 5], dtype=np.intp),
                input_dimensions=(0,),
            ),
        ),
    )
    with pytest.raises(NotImplementedError, match="ArrayMap"):
        t.vindex[np.array([0, 2], dtype=np.intp)]


# ---------------------------------------------------------------------------
# selection_to_transform validation: reject unsupported selection types.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "case",
    [
        # A float is not a legal selection in any mode; every validator must
        # reject it with the same IndexError.
        ExpectErr(
            input=(IndexTransform.from_shape((10,)), 1.5, mode),
            msg="unsupported selection type",
            exception_cls=IndexError,
            id=f"{mode}-rejects-float",
        )
        for mode in ("basic", "orthogonal", "vectorized")
    ],
    ids=lambda c: c.id,
)
def test_selection_to_transform_rejects_unsupported_types(
    case: ExpectErr[tuple[IndexTransform, Any, Literal["basic", "orthogonal", "vectorized"]]],
) -> None:
    """Each selection_to_transform validator raises on types like float."""
    base_transform, bad_selection, mode = case.input
    with pytest.raises(case.exception_cls, match=case.msg):
        selection_to_transform(bad_selection, base_transform, mode)


# ---------------------------------------------------------------------------
# _apply_oindex parsing branches: bare int, list selection.
# ---------------------------------------------------------------------------


def test_oindex_bare_int_becomes_singleton_array() -> None:
    """A bare integer under oindex is promoted to a one-element index array,
    so the result carries a length-1 ArrayMap rather than a ConstantMap."""
    base = IndexTransform.from_shape((10,))
    selected = base.oindex[3]
    assert selected.input_rank == 1
    assert selected.domain.shape == (1,)
    first_map = selected.output[0]
    assert isinstance(first_map, ArrayMap)
    np.testing.assert_array_equal(first_map.index_array, np.array([3]))


def test_oindex_list_selection() -> None:
    """A plain Python list under oindex is coerced to an integer index array."""
    base = IndexTransform.from_shape((10,))
    selected = base.oindex[[1, 3, 5]]
    assert selected.input_rank == 1
    assert selected.domain.shape == (3,)
    first_map = selected.output[0]
    assert isinstance(first_map, ArrayMap)
    np.testing.assert_array_equal(first_map.index_array, np.array([1, 3, 5]))


# ---------------------------------------------------------------------------
# _apply_vindex parsing branches: ellipsis, 2D bool, list, bare int.
# ---------------------------------------------------------------------------


def test_vindex_ellipsis() -> None:
    """An ellipsis under vindex is an identity: the domain is unchanged."""
    base = IndexTransform.from_shape((4, 5))
    assert base.vindex[...].domain.shape == (4, 5)


def test_vindex_2d_bool_mask_consumes_two_dims() -> None:
    """A 2-D boolean mask under vindex consumes both dims of a 2-D domain,
    expanding into two correlated 1-D ArrayMaps (one per consumed dim)."""
    base = IndexTransform.from_shape((3, 4))
    mask = np.array(
        [[True, False, True, False], [False, True, False, True], [True, True, False, False]]
    )
    selected = base.vindex[mask]
    # The mask has 6 True entries, so the broadcast domain is (6,).
    assert selected.domain.shape == (6,)
    assert isinstance(selected.output[0], ArrayMap)
    assert isinstance(selected.output[1], ArrayMap)


def test_vindex_list_selection() -> None:
    """A plain Python list is accepted by vindex just as it is by oindex."""
    base = IndexTransform.from_shape((10,))
    selected = base.vindex[[1, 3, 5]]
    assert selected.domain.shape == (3,)
    assert isinstance(selected.output[0], ArrayMap)


def test_vindex_bare_int_becomes_singleton_array() -> None:
    """A bare integer under vindex is promoted to a length-1 ArrayMap."""
    base = IndexTransform.from_shape((10,))
    selected = base.vindex[3]
    assert selected.domain.shape == (1,)
    assert isinstance(selected.output[0], ArrayMap)


def test_vindex_with_fewer_selections_than_dims_pads_with_slice() -> None:
    """Unselected trailing dims under vindex are implicitly filled with slices."""
    base = IndexTransform.from_shape((3, 5))
    selected = base.vindex[np.array([0, 1], dtype=np.intp)]
    # The broadcast dim (2,) is prepended; the untouched trailing dim (5,) is kept.
    assert selected.domain.shape == (2, 5)


# ---------------------------------------------------------------------------
# ConstantMap survives basic / oindex / vindex unchanged. The tests above
# exercise these paths for DimensionMap-only transforms; these cover the
# `output[i] is ConstantMap` branch in each of the three apply functions.
# ---------------------------------------------------------------------------


def _constant_plus_dimension() -> IndexTransform:
    """A 1-D transform whose outputs are ConstantMap(42) and a DimensionMap."""
    return IndexTransform(
        domain=IndexDomain.from_shape((10,)),
        output=(ConstantMap(offset=42), DimensionMap(input_dimension=0)),
    )


def test_basic_indexing_preserves_constant_map() -> None:
    """Basic indexing leaves a ConstantMap output untouched."""
    sliced = _constant_plus_dimension()[2:8]
    constant = sliced.output[0]
    assert isinstance(constant, ConstantMap)
    assert constant.offset == 42


def test_oindex_preserves_constant_map() -> None:
    """oindex leaves a ConstantMap output untouched."""
    selected = _constant_plus_dimension().oindex[np.array([1, 3, 5], dtype=np.intp)]
    constant = selected.output[0]
    assert isinstance(constant, ConstantMap)
    assert constant.offset == 42


def test_vindex_preserves_constant_map() -> None:
    """vindex leaves a ConstantMap output untouched."""
    selected = _constant_plus_dimension().vindex[np.array([1, 3, 5], dtype=np.intp)]
    constant = selected.output[0]
    assert isinstance(constant, ConstantMap)
    assert constant.offset == 42


def test_intersect_vectorized_constant_inside_chunk_passes() -> None:
    """A ConstantMap whose offset falls inside the chunk's range on its output
    dim passes through the vectorized intersect path unchanged. (The
    outside-chunk case, which yields None, is covered elsewhere.)"""
    correlated = IndexTransform(
        domain=IndexDomain.from_shape((3,)),
        output=(
            ArrayMap(index_array=np.array([1, 2, 3], dtype=np.intp), input_dimensions=(0,)),
            ArrayMap(index_array=np.array([10, 11, 12], dtype=np.intp), input_dimensions=(0,)),
            ConstantMap(offset=5),
        ),
    )
    chunk = IndexDomain(inclusive_min=(0, 0, 0), exclusive_max=(10, 20, 10))
    result = correlated.intersect(chunk)
    assert result is not None
    restricted, _ = result
    constant = restricted.output[2]
    assert isinstance(constant, ConstantMap)
    assert constant.offset == 5


# ---------------------------------------------------------------------------
# Domain-level edge cases: empty-domain intersect, oindex with ellipsis or
# trailing-dim implicit fill.
# ---------------------------------------------------------------------------


def test_intersect_dimension_map_on_empty_domain_returns_none() -> None:
    """Intersecting a transform whose input domain is already empty
    (inclusive_min >= exclusive_max) yields None."""
    empty = IndexTransform(
        domain=IndexDomain(inclusive_min=(0,), exclusive_max=(0,)),
        output=(DimensionMap(input_dimension=0, offset=0, stride=1),),
    )
    assert empty.intersect(IndexDomain.from_shape((10,))) is None


def test_oindex_with_ellipsis() -> None:
    """An ellipsis in an oindex selection expands to slice(None) on the
    remaining dims."""
    base = IndexTransform.from_shape((4, 5, 6))
    selected = base.oindex[np.array([0, 2], dtype=np.intp), ...]
    # Dims 1 and 2 are filled by the ellipsis, so the domain becomes (2, 5, 6).
    assert selected.domain.shape == (2, 5, 6)


def test_oindex_with_implicit_trailing_dim_fill() -> None:
    """An oindex selection shorter than ndim pads trailing dims with slice(None)."""
    base = IndexTransform.from_shape((4, 5, 6))
    selected = base.oindex[np.array([0, 2], dtype=np.intp)]
    # Only dim 0 is selected explicitly; dims 1 and 2 are kept whole.
    assert selected.domain.shape == (2, 5, 6)


# ---------------------------------------------------------------------------
# IndexTransform.__post_init__ shape validation: an ArrayMap whose index_array
# shape disagrees with the domain's extents on its input_dimensions must be
# rejected at construction time.
# ---------------------------------------------------------------------------


def test_construction_rejects_shape_mismatch() -> None:
    """ArrayMap.index_array.shape must match the input domain's extents on
    input_dimensions (in order)."""
    # A length-3 index array cannot parameterize a (10,) input dimension.
    with pytest.raises(ValueError, match="does not match expected shape"):
        IndexTransform(
            domain=IndexDomain.from_shape((10,)),
            output=(
                ArrayMap(
                    index_array=np.array([1, 2, 3], dtype=np.intp),
                    input_dimensions=(0,),
                ),
            ),
        )


# ---------------------------------------------------------------------------
# _normalize_basic_selection error paths.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "case",
    [
        ExpectErr(
            input=(IndexTransform.from_shape((2,)), (1, 2, 3)),
            msg="too many indices",
            exception_cls=IndexError,
            id="too-many-indices",
        ),
        ExpectErr(
            input=(IndexTransform.from_shape((3, 3, 3)), (..., 0, ...)),
            msg="single ellipsis",
            exception_cls=IndexError,
            id="double-ellipsis",
        ),
        ExpectErr(
            input=(IndexTransform.from_shape((10,)), 1.5),
            msg="unsupported selection type",
            exception_cls=IndexError,
            id="float-not-supported",
        ),
    ],
    ids=lambda c: c.id,
)
def test_basic_indexing_rejects_malformed_selections(
    case: ExpectErr[tuple[IndexTransform, Any]],
) -> None:
    """Malformed basic selections raise IndexError: more indices than dims,
    more than one ellipsis, and unsupported scalar types such as float."""
    base, selection = case.input
    with pytest.raises(case.exception_cls, match=case.msg):
        base[selection]


# ---------------------------------------------------------------------------
# Transforms with an ArrayMap that is NOT the last output.
#
# The `for m in self.output:` loops in selection_repr, __repr__, translate,
# the basic / oindex / vindex apply functions, and _intersect's orthogonal
# path each contain an `elif isinstance(m, ArrayMap):` branch. Full branch
# coverage needs that branch taken with a further loop iteration still to
# come, i.e. an ArrayMap followed by another output map. The helper below
# provides exactly that ordering (ArrayMap first, DimensionMap second).
# ---------------------------------------------------------------------------


def _arraymap_then_dimensionmap() -> IndexTransform:
    """A (3, 5)-input transform: output[0] is an ArrayMap on input dim 0,
    output[1] is a DimensionMap on input dim 1."""
    return IndexTransform(
        domain=IndexDomain.from_shape((3, 5)),
        output=(
            ArrayMap(
                index_array=np.array([1, 4, 9], dtype=np.intp),
                input_dimensions=(0,),
            ),
            DimensionMap(input_dimension=1, offset=0, stride=1),
        ),
    )


def test_selection_repr_with_arraymap_not_last() -> None:
    """selection_repr renders the ArrayMap entry and then keeps going."""
    rendered = _arraymap_then_dimensionmap().selection_repr
    assert "{1, 4, 9}" in rendered
    assert "[0, 5)" in rendered


def test_repr_with_arraymap_not_last() -> None:
    """__repr__ renders the ArrayMap line and then keeps going."""
    rendered = repr(_arraymap_then_dimensionmap())
    assert "out[0] = 0 + 1 * arr(3,)[in[0]]" in rendered
    assert "out[1] = 0 + 1 * in[1]" in rendered


def test_translate_with_arraymap_not_last() -> None:
    """translate shifts the ArrayMap and then the following DimensionMap.

    The shift applies to every output, so the result is a (translated
    ArrayMap, translated DimensionMap) pair."""
    shifted = _arraymap_then_dimensionmap().translate((10, 100))
    array_out, dim_out = shifted.output
    assert isinstance(array_out, ArrayMap)
    assert array_out.offset == 10
    assert isinstance(dim_out, DimensionMap)
    assert dim_out.offset == 100


def test_basic_indexing_with_arraymap_not_last() -> None:
    """_apply_basic_indexing visits the ArrayMap and then the DimensionMap."""
    sliced = _arraymap_then_dimensionmap()[:, 2:5]
    assert isinstance(sliced.output[0], ArrayMap)
    assert isinstance(sliced.output[1], DimensionMap)


def test_oindex_with_arraymap_not_last() -> None:
    """_apply_oindex visits the ArrayMap and then the DimensionMap."""
    selected = _arraymap_then_dimensionmap().oindex[:, np.array([0, 2, 4], dtype=np.intp)]
    # The original ArrayMap survives untouched on its parameterizing dim; the
    # DimensionMap becomes a new ArrayMap built from the integer selection.
    assert isinstance(selected.output[0], ArrayMap)
    assert isinstance(selected.output[1], ArrayMap)


def test_intersect_with_two_uncorrelated_arraymaps_uses_orthogonal_path() -> None:
    """2+ ArrayMaps with disjoint input_dimensions are uncorrelated, so
    intersect falls through to the orthogonal path rather than the vectorized
    one. This also exercises the orthogonal `for m in output` loop visiting an
    ArrayMap that is not the last output."""
    uncorrelated = IndexTransform(
        domain=IndexDomain.from_shape((3, 4)),
        output=(
            ArrayMap(
                index_array=np.array([0, 5, 10], dtype=np.intp),
                input_dimensions=(0,),
            ),
            ArrayMap(
                index_array=np.array([20, 30, 40, 50], dtype=np.intp),
                input_dimensions=(1,),
            ),
        ),
    )
    # A chunk containing every storage point: each ArrayMap is filtered
    # independently against its own output dim's range and survives whole.
    result = uncorrelated.intersect(IndexDomain.from_shape((100, 100)))
    assert result is not None
    restricted, _ = result
    assert isinstance(restricted.output[0], ArrayMap)
    assert isinstance(restricted.output[1], ArrayMap)