Skip to content

Commit 5b44e47

Browse files
committed
Implement test_outer()
1 parent db2f231 commit 5b44e47

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
110110
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
111111
)
112112

113+
one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
114+
113115
# Matrix shapes assume stacks of matrices
114116
matrix_shapes = xps.array_shapes(min_dims=2, min_side=1).filter(
115117
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE

array_api_tests/test_linalg.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
"""
1515

1616
from hypothesis import assume, given
17-
from hypothesis.strategies import booleans, composite, none, integers, shared
17+
from hypothesis.strategies import booleans, composite, none, tuples, integers, shared
1818

1919
from .array_helpers import (assert_exactly_equal, ndindex, asarray,
2020
numeric_dtype_objects, promote_dtypes)
2121
from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
2222
square_matrix_shapes, symmetric_matrices,
2323
positive_definite_matrices, MAX_ARRAY_SIZE,
2424
invertible_matrices, two_mutual_arrays,
25-
mutually_promotable_dtypes)
25+
mutually_promotable_dtypes, one_d_shapes)
2626
from .pytest_helpers import raises
2727

2828
from .test_broadcasting import broadcast_shapes
@@ -339,12 +339,27 @@ def test_matrix_transpose(x):
339339
_test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val)
340340

341341
@given(
342-
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
343-
x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
342+
*two_mutual_arrays(dtype_objects=numeric_dtype_objects,
343+
two_shapes=tuples(one_d_shapes, one_d_shapes))
344344
)
345345
def test_outer(x1, x2):
346-
# res = linalg.outer(x1, x2)
347-
pass
346+
# outer does not work on stacks. See
347+
# https://github.com/data-apis/array-api/issues/242.
348+
res = linalg.outer(x1, x2)
349+
350+
shape = (x1.shape[0], x2.shape[0])
351+
assert res.shape == shape, "outer() did not return the correct shape"
352+
assert res.dtype == promote_dtypes(x1, x2), "outer() did not return the correct dtype"
353+
354+
if 0 in shape:
355+
true_res = _array_module.empty(shape, dtype=res.dtype)
356+
else:
357+
true_res = _array_module.asarray([[x1[i]*x2[j]
358+
for j in range(x2.shape[0])]
359+
for i in range(x1.shape[0])],
360+
dtype=res.dtype)
361+
362+
assert_exactly_equal(res, true_res)
348363

349364
@given(
350365
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)