Skip to content

Commit 3092422

Browse files
committed
Test matmul, matrix_transpose, tensordot, and vecdot for the main and linalg namespaces separately
1 parent 3856b8f commit 3092422

File tree

1 file changed

+76
-29
lines changed

1 file changed

+76
-29
lines changed

array_api_tests/test_linalg.py

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -331,10 +331,9 @@ def test_inv(x):
331331

332332
# TODO: Test that the result is actually the inverse
333333

334-
@given(
335-
*two_mutual_arrays(dh.real_dtypes)
336-
)
337-
def test_matmul(x1, x2):
334+
def _test_matmul(namespace, x1, x2):
335+
matmul = namespace.matmul
336+
338337
# TODO: Make this also test the @ operator
339338
if (x1.shape == () or x2.shape == ()
340339
or len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape
@@ -347,7 +346,7 @@ def test_matmul(x1, x2):
347346
"matmul did not raise an exception for invalid shapes")
348347
return
349348
else:
350-
res = _array_module.matmul(x1, x2)
349+
res = matmul(x1, x2)
351350

352351
ph.assert_dtype("matmul", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
353352

@@ -358,19 +357,32 @@ def test_matmul(x1, x2):
358357
ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
359358
out_shape=res.shape,
360359
expected=x2.shape[:-2] + x2.shape[-1:])
361-
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
360+
_test_stacks(matmul, x1, x2, res=res, dims=1,
362361
matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
363362
elif len(x2.shape) == 1:
364363
ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
365364
out_shape=res.shape, expected=x1.shape[:-1])
366-
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
365+
_test_stacks(matmul, x1, x2, res=res, dims=1,
367366
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
368367
else:
369368
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
370369
ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
371370
out_shape=res.shape,
372371
expected=stack_shape + (x1.shape[-2], x2.shape[-1]))
373-
_test_stacks(_array_module.matmul, x1, x2, res=res)
372+
_test_stacks(matmul, x1, x2, res=res)
373+
374+
@pytest.mark.xp_extension('linalg')
375+
@given(
376+
*two_mutual_arrays(dh.real_dtypes)
377+
)
378+
def test_linalg_matmul(x1, x2):
379+
return _test_matmul(linalg, x1, x2)
380+
381+
@given(
382+
*two_mutual_arrays(dh.real_dtypes)
383+
)
384+
def test_matmul(x1, x2):
385+
return _test_matmul(_array_module, x1, x2)
374386

375387
@pytest.mark.xp_extension('linalg')
376388
@given(
@@ -428,11 +440,9 @@ def test_matrix_power(x, n):
428440
def test_matrix_rank(x, kw):
429441
linalg.matrix_rank(x, **kw)
430442

431-
@given(
432-
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
433-
)
434-
def test_matrix_transpose(x):
435-
res = _array_module.matrix_transpose(x)
443+
def _test_matrix_transpose(namespace, x):
444+
matrix_transpose = namespace.matrix_transpose
445+
res = matrix_transpose(x)
436446
true_val = lambda a: _array_module.asarray([[a[i, j] for i in
437447
range(a.shape[0])] for j in
438448
range(a.shape[1])],
@@ -444,7 +454,20 @@ def test_matrix_transpose(x):
444454
ph.assert_result_shape("matrix_transpose", in_shapes=[x.shape],
445455
out_shape=res.shape, expected=shape)
446456

447-
_test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val)
457+
_test_stacks(matrix_transpose, x, res=res, true_val=true_val)
458+
459+
@pytest.mark.xp_extension('linalg')
460+
@given(
461+
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
462+
)
463+
def test_linalg_matrix_transpose(x):
464+
return _test_matrix_transpose(linalg, x)
465+
466+
@given(
467+
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
468+
)
469+
def test_matrix_transpose(x):
470+
return _test_matrix_transpose(_array_module, x)
448471

449472
@pytest.mark.xp_extension('linalg')
450473
@given(
@@ -759,12 +782,9 @@ def _test_tensordot_stacks(x1, x2, kw, res):
759782
decomp_res_stack = xp.tensordot(x1_stack, x2_stack, axes=res_axes)
760783
assert_equal(res_stack, decomp_res_stack)
761784

762-
@given(
763-
*two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
764-
tensordot_kw,
765-
)
766-
def test_tensordot(x1, x2, kw):
767-
res = xp.tensordot(x1, x2, **kw)
785+
def _test_tensordot(namespace, x1, x2, kw):
786+
tensordot = namespace.tensordot
787+
res = tensordot(x1, x2, **kw)
768788

769789
ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype],
770790
out_dtype=res.dtype)
@@ -785,6 +805,21 @@ def test_tensordot(x1, x2, kw):
785805
expected=result_shape)
786806
_test_tensordot_stacks(x1, x2, kw, res)
787807

808+
@pytest.mark.xp_extension('linalg')
809+
@given(
810+
*two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
811+
tensordot_kw,
812+
)
813+
def test_linalg_tensordot(x1, x2, kw):
814+
_test_tensordot(linalg, x1, x2, kw)
815+
816+
@given(
817+
*two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
818+
tensordot_kw,
819+
)
820+
def test_tensordot(x1, x2, kw):
821+
_test_tensordot(_array_module, x1, x2, kw)
822+
788823
@pytest.mark.xp_extension('linalg')
789824
@given(
790825
x=arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()),
@@ -828,12 +863,8 @@ def true_trace(x_stack, offset=0):
828863

829864
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
830865

831-
832-
@given(
833-
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
834-
data(),
835-
)
836-
def test_vecdot(x1, x2, data):
866+
def _test_vecdot(namespace, x1, x2, data):
867+
vecdot = namespace.vecdot
837868
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
838869
min_ndim = min(x1.ndim, x2.ndim)
839870
ndim = len(broadcasted_shape)
@@ -842,14 +873,14 @@ def test_vecdot(x1, x2, data):
842873
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
843874
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
844875
if x1_shape[axis] != x2_shape[axis]:
845-
ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw),
876+
ph.raises(Exception, lambda: vecdot(x1, x2, **kw),
846877
"vecdot did not raise an exception for invalid shapes")
847878
return
848879
expected_shape = list(broadcasted_shape)
849880
expected_shape.pop(axis)
850881
expected_shape = tuple(expected_shape)
851882

852-
res = xp.vecdot(x1, x2, **kw)
883+
res = vecdot(x1, x2, **kw)
853884

854885
ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype],
855886
out_dtype=res.dtype)
@@ -862,9 +893,25 @@ def true_val(x, y, axis=-1):
862893
else:
863894
true_val = None
864895

865-
_test_stacks(linalg.vecdot, x1, x2, res=res, dims=0,
896+
_test_stacks(vecdot, x1, x2, res=res, dims=0,
866897
matrix_axes=(axis,), true_val=true_val)
867898

899+
900+
@pytest.mark.xp_extension('linalg')
901+
@given(
902+
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
903+
data(),
904+
)
905+
def test_linalg_vecdot(x1, x2, data):
906+
_test_vecdot(linalg, x1, x2, data)
907+
908+
@given(
909+
*two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
910+
data(),
911+
)
912+
def test_vecdot(x1, x2, data):
913+
_test_vecdot(_array_module, x1, x2, data)
914+
868915
# Insanely large orders might not work. There isn't a limit specified in the
869916
# spec, so we just limit to reasonable values here.
870917
max_ord = 100

0 commit comments

Comments
 (0)