Skip to content

Commit 2654568

Browse files
ENH: add diag_indices, tril_indices, triu_indices
Resolves #686. Adds the three index-generating functions that numpy, jax, and cupy all have but that are missing from the array-api standard and (so far) from this library. Signatures follow array-api conventions: parameter `offset` (matching `xp.linalg.diagonal`) instead of numpy's `k`; keyword-only arguments for everything except `n`; `xp` is required (these functions have no input array to infer from, following the `default_dtype` precedent). Delegation: - numpy/cupy/jax: forward directly (signatures match verbatim). - dask: has tril/triu_indices but no diag_indices. - torch: has tril/triu_indices but with (row, col, *, offset) signature returning a 2xN tensor rather than a tuple; delegation translates. No torch.diag_indices exists; falls through to generic. - sparse, array-api-strictest: fall through to generic; marked xfail on those backends (no nonzero / data-dependent shapes). Generic implementation uses `xp.arange` + broadcasting + `xp.nonzero` for the triangle variants. Validation (n >= 0, ndim >= 1, m >= 0) happens in the delegation layer so all backends produce consistent ValueErrors. Also fixes a pre-existing bug in tests/conftest.py's NumPyReadOnly wrapper: `type(o)(*gen)` worked for namedtuples but failed for plain tuples of length >= 2. Exposed here because these are the first functions in the library that return a tuple of arrays.
1 parent 1e1f1ab commit 2654568

5 files changed

Lines changed: 319 additions & 2 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
atleast_nd,
66
cov,
77
create_diagonal,
8+
diag_indices,
89
expand_dims,
910
isclose,
1011
isin,
@@ -15,6 +16,8 @@
1516
searchsorted,
1617
setdiff1d,
1718
sinc,
19+
tril_indices,
20+
triu_indices,
1821
union1d,
1922
)
2023
from ._lib._at import at
@@ -40,6 +43,7 @@
4043
"cov",
4144
"create_diagonal",
4245
"default_dtype",
46+
"diag_indices",
4347
"expand_dims",
4448
"isclose",
4549
"isin",
@@ -53,5 +57,7 @@
5357
"searchsorted",
5458
"setdiff1d",
5559
"sinc",
60+
"tril_indices",
61+
"triu_indices",
5662
"union1d",
5763
]

src/array_api_extra/_delegation.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222
"atleast_nd",
2323
"cov",
2424
"create_diagonal",
25+
"diag_indices",
2526
"expand_dims",
2627
"isclose",
2728
"nan_to_num",
2829
"one_hot",
2930
"pad",
3031
"searchsorted",
3132
"sinc",
33+
"tril_indices",
34+
"triu_indices",
3235
]
3336

3437

@@ -238,6 +241,49 @@ def create_diagonal(
238241
return _funcs.create_diagonal(x, offset=offset, xp=xp)
239242

240243

244+
def diag_indices(n: int, /, *, ndim: int = 2, xp: ModuleType) -> tuple[Array, ...]:
245+
"""
246+
Return the indices to access the main diagonal of an array.
247+
248+
Equivalent to ``numpy.diag_indices``.
249+
250+
Parameters
251+
----------
252+
n : int
253+
The size of each dimension of the (hyper-)cube ``(n, n, ..., n)``
254+
that the returned indices index into.
255+
ndim : int, optional
256+
The number of dimensions. Default: ``2``.
257+
xp : array_namespace
258+
The standard-compatible namespace to create the indices in.
259+
260+
Returns
261+
-------
262+
tuple of array
263+
``ndim`` 1-D integer arrays of length ``n`` that together index
264+
the main diagonal of an array of shape ``(n,) * ndim``.
265+
266+
Examples
267+
--------
268+
>>> import array_api_strict as xp
269+
>>> import array_api_extra as xpx
270+
>>> rows, cols = xpx.diag_indices(3, xp=xp)
271+
>>> rows
272+
Array([0, 1, 2], dtype=array_api_strict.int64)
273+
>>> cols
274+
Array([0, 1, 2], dtype=array_api_strict.int64)
275+
"""
276+
if n < 0:
277+
msg = f"`n` must be non-negative, got {n}"
278+
raise ValueError(msg)
279+
if ndim < 1:
280+
msg = f"`ndim` must be >= 1, got {ndim}"
281+
raise ValueError(msg)
282+
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
283+
return xp.diag_indices(n, ndim=ndim)
284+
return _funcs.diag_indices(n, ndim=ndim, xp=xp)
285+
286+
241287
def expand_dims(
242288
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
243289
) -> Array:
@@ -1150,3 +1196,119 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
11501196
return xp.union1d(a, b)
11511197

11521198
return _funcs.union1d(a, b, xp=xp)
1199+
1200+
1201+
def tril_indices(
1202+
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
1203+
) -> tuple[Array, Array]:
1204+
"""
1205+
Return the indices of the lower triangle of an ``(n, m)`` array.
1206+
1207+
Equivalent to ``numpy.tril_indices`` with parameter ``k`` renamed to
1208+
``offset`` to match ``xp.linalg.diagonal``'s naming.
1209+
1210+
Parameters
1211+
----------
1212+
n : int
1213+
The row dimension of the array.
1214+
offset : int, optional
1215+
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
1216+
to ``k`` in ``numpy.tril_indices``.
1217+
m : int, optional
1218+
The column dimension. If ``None`` (default), assumed equal to `n`.
1219+
xp : array_namespace
1220+
The standard-compatible namespace to create the indices in.
1221+
1222+
Returns
1223+
-------
1224+
tuple of array
1225+
Row and column indices ``(rows, cols)`` of the lower triangle of
1226+
the ``(n, m)`` matrix, shifted by `offset`.
1227+
1228+
Examples
1229+
--------
1230+
>>> import array_api_strict as xp
1231+
>>> import array_api_extra as xpx
1232+
>>> rows, cols = xpx.tril_indices(3, xp=xp)
1233+
>>> rows
1234+
Array([0, 1, 1, 2, 2, 2], dtype=array_api_strict.int64)
1235+
>>> cols
1236+
Array([0, 0, 1, 0, 1, 2], dtype=array_api_strict.int64)
1237+
"""
1238+
if n < 0:
1239+
msg = f"`n` must be non-negative, got {n}"
1240+
raise ValueError(msg)
1241+
if m is not None and m < 0:
1242+
msg = f"`m` must be non-negative, got {m}"
1243+
raise ValueError(msg)
1244+
if (
1245+
is_numpy_namespace(xp)
1246+
or is_cupy_namespace(xp)
1247+
or is_jax_namespace(xp)
1248+
or is_dask_namespace(xp)
1249+
):
1250+
return xp.tril_indices(n, k=offset, m=m)
1251+
if is_torch_namespace(xp):
1252+
# `torch.tril_indices` returns a 2xN tensor, not a tuple, and
1253+
# takes (row, col) rather than (n, *, m=None).
1254+
cols = n if m is None else m
1255+
idx = xp.tril_indices(n, cols, offset=offset)
1256+
return (idx[0], idx[1])
1257+
return _funcs.tril_indices(n, offset=offset, m=m, xp=xp)
1258+
1259+
1260+
def triu_indices(
1261+
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
1262+
) -> tuple[Array, Array]:
1263+
"""
1264+
Return the indices of the upper triangle of an ``(n, m)`` array.
1265+
1266+
Equivalent to ``numpy.triu_indices`` with parameter ``k`` renamed to
1267+
``offset`` to match ``xp.linalg.diagonal``'s naming.
1268+
1269+
Parameters
1270+
----------
1271+
n : int
1272+
The row dimension of the array.
1273+
offset : int, optional
1274+
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
1275+
to ``k`` in ``numpy.triu_indices``.
1276+
m : int, optional
1277+
The column dimension. If ``None`` (default), assumed equal to `n`.
1278+
xp : array_namespace
1279+
The standard-compatible namespace to create the indices in.
1280+
1281+
Returns
1282+
-------
1283+
tuple of array
1284+
Row and column indices ``(rows, cols)`` of the upper triangle of
1285+
the ``(n, m)`` matrix, shifted by `offset`.
1286+
1287+
Examples
1288+
--------
1289+
>>> import array_api_strict as xp
1290+
>>> import array_api_extra as xpx
1291+
>>> rows, cols = xpx.triu_indices(3, xp=xp)
1292+
>>> rows
1293+
Array([0, 0, 0, 1, 1, 2], dtype=array_api_strict.int64)
1294+
>>> cols
1295+
Array([0, 1, 2, 1, 2, 2], dtype=array_api_strict.int64)
1296+
"""
1297+
if n < 0:
1298+
msg = f"`n` must be non-negative, got {n}"
1299+
raise ValueError(msg)
1300+
if m is not None and m < 0:
1301+
msg = f"`m` must be non-negative, got {m}"
1302+
raise ValueError(msg)
1303+
if (
1304+
is_numpy_namespace(xp)
1305+
or is_cupy_namespace(xp)
1306+
or is_jax_namespace(xp)
1307+
or is_dask_namespace(xp)
1308+
):
1309+
return xp.triu_indices(n, k=offset, m=m)
1310+
if is_torch_namespace(xp):
1311+
cols = n if m is None else m
1312+
idx = xp.triu_indices(n, cols, offset=offset)
1313+
return (idx[0], idx[1])
1314+
return _funcs.triu_indices(n, offset=offset, m=m, xp=xp)

src/array_api_extra/_lib/_funcs.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,16 @@
2828
"broadcast_shapes",
2929
"cov",
3030
"create_diagonal",
31+
"diag_indices",
3132
"expand_dims",
3233
"kron",
3334
"nunique",
3435
"pad",
3536
"searchsorted",
3637
"setdiff1d",
3738
"sinc",
39+
"tril_indices",
40+
"triu_indices",
3841
]
3942

4043

@@ -346,6 +349,41 @@ def create_diagonal(
346349
return xp.reshape(diag, (*batch_dims, n, n))
347350

348351

352+
def diag_indices(
353+
n: int, /, *, ndim: int = 2, xp: ModuleType
354+
) -> tuple[Array, ...]: # numpydoc ignore=PR01,RT01
355+
"""See docstring in array_api_extra._delegation."""
356+
idx = xp.arange(n)
357+
return (idx,) * ndim
358+
359+
360+
def _tri_indices(
361+
n: int, *, offset: int, m: int | None, upper: bool, xp: ModuleType
362+
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
363+
"""Shared implementation for `tril_indices` and `triu_indices`."""
364+
cols = n if m is None else m
365+
rows = xp.arange(n)[:, None]
366+
cols_a = xp.arange(cols)[None, :]
367+
delta = cols_a - rows
368+
mask = delta >= offset if upper else delta <= offset
369+
r, c = xp.nonzero(mask)
370+
return (r, c)
371+
372+
373+
def tril_indices(
374+
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
375+
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
376+
"""See docstring in array_api_extra._delegation."""
377+
return _tri_indices(n, offset=offset, m=m, upper=False, xp=xp)
378+
379+
380+
def triu_indices(
381+
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
382+
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
383+
"""See docstring in array_api_extra._delegation."""
384+
return _tri_indices(n, offset=offset, m=m, upper=True, xp=xp)
385+
386+
349387
def default_dtype(
350388
xp: ModuleType,
351389
kind: Literal[

tests/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,12 @@ def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01
9696
# Cannot interpret as a data type
9797
return o
9898

99-
# This works with namedtuples too
10099
if isinstance(o, tuple | list):
101-
return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
100+
# namedtuple wants positional args; plain tuple/list wants an iterable.
101+
items = (as_readonly(i) for i in o)
102+
if hasattr(o, "_fields"):
103+
return type(o)(*items) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
104+
return type(o)(items) # type: ignore[return-value]
102105

103106
return o
104107

0 commit comments

Comments
 (0)