Skip to content

Commit 6155e4a

Browse files
committed
Use ndindex.iter_indices in _test_stacks in the linalg tests
This adds ndindex >= 1.6 as a dependency of the test suite.
1 parent 520e685 commit 6155e4a

File tree

2 files changed

+38
-12
lines changed

2 files changed

+38
-12
lines changed

array_api_tests/test_linalg.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import pytest
1717
from hypothesis import assume, given
1818
from hypothesis.strategies import (booleans, composite, none, tuples, integers,
19-
shared, sampled_from, data, just)
19+
shared, sampled_from, one_of, data, just)
20+
from ndindex import iter_indices
2021

2122
from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity
2223
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
@@ -43,25 +44,49 @@
4344
# Standin strategy for not yet implemented tests
4445
todo = none()
4546

46-
def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw):
47+
def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), **kw):
4748
"""
4849
Test that f(*args, **kw) maps across stacks of matrices
4950
50-
dims is the number of dimensions f should have for a single n x m matrix
51-
stack.
51+
dims is the number of dimensions f(*args) should have for a single n x m
52+
matrix stack.
53+
54+
matrix_axes are the axes along which matrices (or vectors) are stacked in
55+
the input.
56+
57+
true_val may be a function such that true_val(*x_stacks, **kw) gives the
58+
true value for f on a stack.
59+
60+
res should be the result of f(*args, **kw). It is computed if not passed
61+
in.
5262
53-
true_val may be a function such that true_val(*x_stacks) gives the true
54-
value for f on a stack
5563
"""
5664
if res is None:
5765
res = f(*args, **kw)
5866

59-
shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape
60-
for x in args])
61-
for _idx in sh.ndindex(shape[:-2]):
62-
idx = _idx + (slice(None),)*dims
63-
res_stack = res[idx]
64-
x_stacks = [x[_idx + (...,)] for x in args]
67+
shapes = [x.shape for x in args]
68+
69+
for (x_idxes, (res_idx,)) in zip(
70+
iter_indices(*shapes, skip_axes=matrix_axes),
71+
iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))):
72+
x_idxes = [x_idx.raw for x_idx in x_idxes]
73+
res_idx = res_idx.raw
74+
# res should have `dims` slices in it. Cases where there are more than
75+
# `dims` slices are ambiguous, but that should only occur in cases
76+
# where axes = (-2, -1).
77+
# res_idx2 = []
78+
# d = dims
79+
# for i in res_idx:
80+
# if isinstance(i, slice):
81+
# if d:
82+
# res_idx2.append(i)
83+
# d -= 1
84+
# else:
85+
# res_idx2.append(i)
86+
# res_idx2 = tuple(res_idx2)
87+
88+
res_stack = res[res_idx]
89+
x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)]
6590
decomp_res_stack = f(*x_stacks, **kw)
6691
assert_exactly_equal(res_stack, decomp_res_stack)
6792
if true_val:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pytest
22
hypothesis>=6.31.1
3+
ndindex>=1.6
34
regex
45
removestar

0 commit comments

Comments
 (0)