|
14 | 14 | """
|
15 | 15 |
|
16 | 16 | from hypothesis import assume, given
|
17 |
| -from hypothesis.strategies import booleans, composite, none, integers, shared |
| 17 | +from hypothesis.strategies import booleans, composite, none, tuples, integers, shared |
18 | 18 |
|
19 | 19 | from .array_helpers import (assert_exactly_equal, ndindex, asarray,
|
20 | 20 | numeric_dtype_objects, promote_dtypes)
|
21 | 21 | from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes,
|
22 | 22 | square_matrix_shapes, symmetric_matrices,
|
23 | 23 | positive_definite_matrices, MAX_ARRAY_SIZE,
|
24 | 24 | invertible_matrices, two_mutual_arrays,
|
25 |
| - mutually_promotable_dtypes) |
| 25 | + mutually_promotable_dtypes, one_d_shapes) |
26 | 26 | from .pytest_helpers import raises
|
27 | 27 |
|
28 | 28 | from .test_broadcasting import broadcast_shapes
|
@@ -339,12 +339,27 @@ def test_matrix_transpose(x):
|
339 | 339 | _test_stacks(linalg.matrix_transpose, x, res=res, true_val=true_val)
|
340 | 340 |
|
341 | 341 | @given(
|
342 |
| - x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
343 |
| - x2=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes), |
| 342 | + *two_mutual_arrays(dtype_objects=numeric_dtype_objects, |
| 343 | + two_shapes=tuples(one_d_shapes, one_d_shapes)) |
344 | 344 | )
|
345 | 345 | def test_outer(x1, x2):
|
346 |
| - # res = linalg.outer(x1, x2) |
347 |
| - pass |
| 346 | + # outer does not work on stacks. See |
| 347 | + # https://github.com/data-apis/array-api/issues/242. |
| 348 | + res = linalg.outer(x1, x2) |
| 349 | + |
| 350 | + shape = (x1.shape[0], x2.shape[0]) |
| 351 | + assert res.shape == shape, "outer() did not return the correct shape" |
| 352 | + assert res.dtype == promote_dtypes(x1, x2), "outer() did not return the correct dtype" |
| 353 | + |
| 354 | + if 0 in shape: |
| 355 | + true_res = _array_module.empty(shape, dtype=res.dtype) |
| 356 | + else: |
| 357 | + true_res = _array_module.asarray([[x1[i]*x2[j] |
| 358 | + for j in range(x2.shape[0])] |
| 359 | + for i in range(x1.shape[0])], |
| 360 | + dtype=res.dtype) |
| 361 | + |
| 362 | + assert_exactly_equal(res, true_res) |
348 | 363 |
|
349 | 364 | @given(
|
350 | 365 | x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
|
|
0 commit comments