Skip to content

Commit 0c57e2b

Browse files
ENH: expose correction and weights parameters in cov
Resolves #688. Adds `axis`, `correction`, `frequency_weights`, and `weights` to `cov`, giving users control over the degrees-of-freedom correction and the observation-axis / weighted variants that `numpy.cov` and `torch.cov` already support. Naming follows array-api conventions (`axis`, `correction`) rather than numpy's (`rowvar`, `bias`, `ddof`); the docstring includes a one-to-one mapping. The delegation moves observations to the last axis via `xp.moveaxis`, collapsing `rowvar` out of the backend dispatch — only `ddof` vs `correction` differs between branches. Dask's native `cov` forces `.compute()` on a lazy scalar when any weights are given, so weighted dask inputs fall through to the generic implementation, which is fully lazy.
1 parent 1e1f1ab commit 0c57e2b

3 files changed

Lines changed: 246 additions & 31 deletions

File tree

src/array_api_extra/_delegation.py

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
8181
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
8282

8383

84-
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
84+
def cov(
85+
m: Array,
86+
/,
87+
*,
88+
axis: int = -1,
89+
correction: int | float = 1,
90+
frequency_weights: Array | None = None,
91+
weights: Array | None = None,
92+
xp: ModuleType | None = None,
93+
) -> Array:
8594
"""
8695
Estimate a covariance matrix (or a stack of covariance matrices).
8796
@@ -92,16 +101,37 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
92101
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
93102
of :math:`x_i`.
94103
95-
With the exception of supporting batch input, this provides a subset of
96-
the functionality of ``numpy.cov``.
104+
Extends ``numpy.cov`` with support for batch input and array-api
105+
backends. Naming follows the array-api conventions used elsewhere in
106+
this library (``axis``, ``correction``) rather than the numpy spellings
107+
(``rowvar``, ``bias``, ``ddof``); see Notes for the mapping.
97108
98109
Parameters
99110
----------
100111
m : array
101112
An array of shape ``(..., N, M)`` whose innermost two dimensions
102-
contain *M* observations of *N* variables. That is,
103-
each row of `m` represents a variable, and each column a single
104-
observation of all those variables.
113+
contain *M* observations of *N* variables by default. The axis of
114+
observations is controlled by `axis`.
115+
axis : int, optional
116+
Axis of `m` containing the observations. Default: ``-1`` (the last
117+
axis), matching the array-api convention. Use ``axis=-2`` (or ``0``
118+
for 2-D input) to treat each column as a variable, which
119+
corresponds to ``rowvar=False`` in ``numpy.cov``.
120+
correction : int or float, optional
121+
Degrees of freedom correction: normalization divides by
122+
``M - correction`` (for unweighted input). Default: ``1``, which
123+
gives the unbiased estimate (matches ``numpy.cov`` default of
124+
``bias=False``). Set to ``0`` for the biased estimate (``M``
125+
normalization). Corresponds to ``ddof`` in ``numpy.cov`` and to
126+
``correction`` in ``numpy.var``/``std`` and ``torch.cov``.
127+
frequency_weights : array, optional
128+
1-D array of integer frequency weights: the number of times each
129+
observation is repeated. Corresponds to ``fweights`` in
130+
``numpy.cov``/``torch.cov``.
131+
weights : array, optional
132+
1-D array of observation-vector weights (analytic weights). Larger
133+
values mark more important observations. Corresponds to
134+
``aweights`` in ``numpy.cov``/``torch.cov``.
105135
xp : array_namespace, optional
106136
The standard-compatible namespace for `m`. Default: infer.
107137
@@ -111,6 +141,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
111141
An array having shape (..., N, N) whose innermost two dimensions represent
112142
the covariance matrix of the variables.
113143
144+
Notes
145+
-----
146+
Mapping from ``numpy.cov`` to this function::
147+
148+
numpy.cov(m, rowvar=True) -> cov(m, axis=-1) # default
149+
numpy.cov(m, rowvar=False) -> cov(m, axis=-2)
150+
numpy.cov(m, bias=True) -> cov(m, correction=0)
151+
numpy.cov(m, ddof=k) -> cov(m, correction=k)
152+
numpy.cov(m, fweights=f) -> cov(m, frequency_weights=f)
153+
numpy.cov(m, aweights=a) -> cov(m, weights=a)
154+
155+
Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective
156+
degrees of freedom is only emitted on the unweighted path. The
157+
weighted path omits the check so that lazy backends (e.g. Dask) can
158+
stay lazy end-to-end; choose ``correction`` and weights such that the
159+
effective normalizer is positive.
160+
114161
Examples
115162
--------
116163
>>> import array_api_strict as xp
@@ -164,16 +211,57 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
164211
if xp is None:
165212
xp = array_namespace(m)
166213

167-
if (
168-
is_numpy_namespace(xp)
169-
or is_cupy_namespace(xp)
170-
or is_torch_namespace(xp)
171-
or is_dask_namespace(xp)
172-
or is_jax_namespace(xp)
173-
) and m.ndim <= 2:
174-
return xp.cov(m)
175-
176-
return _funcs.cov(m, xp=xp)
214+
# Validate axis against m.ndim.
215+
ndim = max(m.ndim, 1)
216+
if not -ndim <= axis < ndim:
217+
msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}"
218+
raise IndexError(msg)
219+
220+
# Normalize: observations on the last axis. After this, every backend
221+
# sees the same convention and we never need to deal with `rowvar`.
222+
if m.ndim >= 2 and axis not in (-1, m.ndim - 1):
223+
m = xp.moveaxis(m, axis, -1)
224+
225+
# `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov`
226+
# requires integer `correction`. For non-integer-valued `correction`,
227+
# fall through to the generic implementation.
228+
integer_correction = isinstance(correction, int) or correction.is_integer()
229+
has_weights = frequency_weights is not None or weights is not None
230+
231+
if m.ndim <= 2 and integer_correction:
232+
if is_torch_namespace(xp):
233+
device = get_device(m)
234+
fw = (
235+
None
236+
if frequency_weights is None
237+
else xp.asarray(frequency_weights, device=device)
238+
)
239+
aw = None if weights is None else xp.asarray(weights, device=device)
240+
return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw)
241+
# `dask.array.cov` forces `.compute()` whenever weights are given:
242+
# its internal `if fact <= 0` check on a lazy 0-D scalar triggers
243+
# materialization. Route to the generic impl, which is fully lazy
244+
# because it only does sum/matmul and skips that scalar check.
245+
if (
246+
is_numpy_namespace(xp)
247+
or is_cupy_namespace(xp)
248+
or is_jax_namespace(xp)
249+
or (is_dask_namespace(xp) and not has_weights)
250+
):
251+
return xp.cov(
252+
m,
253+
ddof=int(correction),
254+
fweights=frequency_weights,
255+
aweights=weights,
256+
)
257+
258+
return _funcs.cov(
259+
m,
260+
correction=correction,
261+
frequency_weights=frequency_weights,
262+
weights=weights,
263+
xp=xp,
264+
)
177265

178266

179267
def create_diagonal(

src/array_api_extra/_lib/_funcs.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -281,31 +281,67 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
281281
return tuple(out)
282282

283283

284-
def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
284+
def cov(
285+
m: Array,
286+
/,
287+
*,
288+
correction: int | float = 1,
289+
frequency_weights: Array | None = None,
290+
weights: Array | None = None,
291+
xp: ModuleType,
292+
) -> Array: # numpydoc ignore=PR01,RT01
285293
"""See docstring in array_api_extra._delegation."""
286-
m = xp.asarray(m, copy=True)
294+
m = xp.asarray(m)
287295
dtype = (
288296
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
289297
)
290298

291299
m = atleast_nd(m, ndim=2, xp=xp)
292300
m = xp.astype(m, dtype)
293301

294-
avg = xp.mean(m, axis=-1, keepdims=True)
302+
device = _compat.device(m)
303+
fw = (
304+
None
305+
if frequency_weights is None
306+
else xp.astype(xp.asarray(frequency_weights, device=device), dtype)
307+
)
308+
aw = (
309+
None
310+
if weights is None
311+
else xp.astype(xp.asarray(weights, device=device), dtype)
312+
)
313+
if fw is None and aw is None:
314+
w = None
315+
elif fw is None:
316+
w = aw
317+
elif aw is None:
318+
w = fw
319+
else:
320+
w = fw * aw
295321

296322
m_shape = eager_shape(m)
297-
fact = m_shape[-1] - 1
298-
299-
if fact <= 0:
300-
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
301-
fact = 0
302-
303-
m -= avg
304-
m_transpose = xp.matrix_transpose(m)
305-
if xp.isdtype(m_transpose.dtype, "complex floating"):
306-
m_transpose = xp.conj(m_transpose)
307-
c = xp.matmul(m, m_transpose)
308-
c /= fact
323+
if w is None:
324+
avg = xp.mean(m, axis=-1, keepdims=True)
325+
fact = m_shape[-1] - correction
326+
if fact <= 0:
327+
warnings.warn(
328+
"Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2
329+
)
330+
fact = 0
331+
else:
332+
v1 = xp.sum(w, axis=-1)
333+
avg = xp.sum(m * w, axis=-1, keepdims=True) / v1
334+
if aw is None:
335+
fact = v1 - correction
336+
else:
337+
fact = v1 - correction * xp.sum(w * aw, axis=-1) / v1
338+
339+
m_c = m - avg
340+
m_w = m_c if w is None else m_c * w
341+
m_cT = xp.matrix_transpose(m_c)
342+
if xp.isdtype(m_cT.dtype, "complex floating"):
343+
m_cT = xp.conj(m_cT)
344+
c = xp.matmul(m_w, m_cT) / fact
309345
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
310346
return xp.squeeze(c, axis=axes)
311347

tests/test_funcs.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,97 @@ def test_batch(self, xp: ModuleType):
608608
ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var))
609609
xp_assert_close(res, xp.asarray(ref))
610610

611+
def test_correction(self, xp: ModuleType):
612+
rng = np.random.default_rng(20260417)
613+
m = rng.random((3, 20))
614+
for correction in (0, 1, 2):
615+
ref = np.cov(m, ddof=correction)
616+
res = cov(xp.asarray(m), correction=correction)
617+
xp_assert_close(res, xp.asarray(ref))
618+
619+
def test_correction_float(self, xp: ModuleType):
620+
# Float correction: reference computed by hand (numpy.cov rejects
621+
# non-integer ddof; our generic path supports it).
622+
rng = np.random.default_rng(20260417)
623+
m = rng.random((3, 20))
624+
n = m.shape[-1]
625+
centered = m - m.mean(axis=-1, keepdims=True)
626+
ref = centered @ centered.T / (n - 1.5)
627+
res = cov(xp.asarray(m), correction=1.5)
628+
xp_assert_close(res, xp.asarray(ref))
629+
630+
def test_axis(self, xp: ModuleType):
631+
rng = np.random.default_rng(20260417)
632+
m = rng.random((20, 3)) # observations on axis 0
633+
ref = np.cov(m, rowvar=False)
634+
res = cov(xp.asarray(m), axis=0)
635+
xp_assert_close(res, xp.asarray(ref))
636+
res_neg = cov(xp.asarray(m), axis=-2)
637+
xp_assert_close(res_neg, xp.asarray(ref))
638+
639+
def test_frequency_weights(self, xp: ModuleType):
640+
rng = np.random.default_rng(20260417)
641+
m = rng.random((3, 10))
642+
fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64)
643+
ref = np.cov(m, fweights=fw)
644+
res = cov(xp.asarray(m), frequency_weights=xp.asarray(fw))
645+
xp_assert_close(res, xp.asarray(ref))
646+
647+
def test_weights(self, xp: ModuleType):
648+
rng = np.random.default_rng(20260417)
649+
m = rng.random((3, 10))
650+
aw = rng.random(10)
651+
ref = np.cov(m, aweights=aw)
652+
res = cov(xp.asarray(m), weights=xp.asarray(aw))
653+
xp_assert_close(res, xp.asarray(ref))
654+
655+
def test_both_weights(self, xp: ModuleType):
656+
rng = np.random.default_rng(20260417)
657+
m = rng.random((3, 10))
658+
fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64)
659+
aw = rng.random(10)
660+
for correction in (0, 1, 2):
661+
ref = np.cov(m, ddof=correction, fweights=fw, aweights=aw)
662+
res = cov(
663+
xp.asarray(m),
664+
correction=correction,
665+
frequency_weights=xp.asarray(fw),
666+
weights=xp.asarray(aw),
667+
)
668+
xp_assert_close(res, xp.asarray(ref))
669+
670+
def test_batch_with_weights(self, xp: ModuleType):
671+
rng = np.random.default_rng(20260417)
672+
batch_shape = (2, 3)
673+
n_var, n_obs = 3, 15
674+
m = rng.random((*batch_shape, n_var, n_obs))
675+
aw = rng.random(n_obs)
676+
res = cov(xp.asarray(m), weights=xp.asarray(aw))
677+
ref_list = [np.cov(m_, aweights=aw) for m_ in np.reshape(m, (-1, n_var, n_obs))]
678+
ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var))
679+
xp_assert_close(res, xp.asarray(ref))
680+
681+
def test_axis_with_weights(self, xp: ModuleType):
682+
# axis=-2 (observations on first of 2D) combined with weights:
683+
# verifies that moveaxis and weight alignment cooperate.
684+
rng = np.random.default_rng(20260417)
685+
m = rng.random((15, 3)) # observations on axis 0
686+
aw = rng.random(15)
687+
fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1], dtype=np.int64)
688+
ref = np.cov(m, rowvar=False, fweights=fw, aweights=aw)
689+
res = cov(
690+
xp.asarray(m),
691+
axis=-2,
692+
frequency_weights=xp.asarray(fw),
693+
weights=xp.asarray(aw),
694+
)
695+
xp_assert_close(res, xp.asarray(ref))
696+
697+
def test_axis_out_of_bounds(self, xp: ModuleType):
698+
m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
699+
with pytest.raises(IndexError):
700+
_ = cov(m, axis=5)
701+
611702

612703
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
613704
class TestOneHot:

0 commit comments

Comments (0)