@@ -331,10 +331,9 @@ def test_inv(x):
 
     # TODO: Test that the result is actually the inverse
 
-@given(
-    *two_mutual_arrays(dh.real_dtypes)
-)
-def test_matmul(x1, x2):
+def _test_matmul(namespace, x1, x2):
+    matmul = namespace.matmul
+
     # TODO: Make this also test the @ operator
     if (x1.shape == () or x2.shape == ()
         or len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape
@@ -347,7 +346,7 @@ def test_matmul(x1, x2):
                   "matmul did not raise an exception for invalid shapes")
         return
     else:
-        res = _array_module.matmul(x1, x2)
+        res = matmul(x1, x2)
 
     ph.assert_dtype("matmul", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
 
@@ -358,19 +357,32 @@ def test_matmul(x1, x2):
         ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
                                out_shape=res.shape,
                                expected=x2.shape[:-2] + x2.shape[-1:])
-        _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
+        _test_stacks(matmul, x1, x2, res=res, dims=1,
                      matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
     elif len(x2.shape) == 1:
         ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
                                out_shape=res.shape, expected=x1.shape[:-1])
-        _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
+        _test_stacks(matmul, x1, x2, res=res, dims=1,
                      matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
     else:
         stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
         ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
                                out_shape=res.shape,
                                expected=stack_shape + (x1.shape[-2], x2.shape[-1]))
-        _test_stacks(_array_module.matmul, x1, x2, res=res)
+        _test_stacks(matmul, x1, x2, res=res)
+
+@pytest.mark.xp_extension('linalg')
+@given(
+    *two_mutual_arrays(dh.real_dtypes)
+)
+def test_linalg_matmul(x1, x2):
+    return _test_matmul(linalg, x1, x2)
+
+@given(
+    *two_mutual_arrays(dh.real_dtypes)
+)
+def test_matmul(x1, x2):
+    return _test_matmul(_array_module, x1, x2)
 
 @pytest.mark.xp_extension('linalg')
 @given(
@@ -428,11 +440,9 @@ def test_matrix_power(x, n):
 def test_matrix_rank(x, kw):
     linalg.matrix_rank(x, **kw)
 
-@given(
-    x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
-)
-def test_matrix_transpose(x):
-    res = _array_module.matrix_transpose(x)
+def _test_matrix_transpose(namespace, x):
+    matrix_transpose = namespace.matrix_transpose
+    res = matrix_transpose(x)
     true_val = lambda a: _array_module.asarray([[a[i, j] for i in
                                                  range(a.shape[0])] for j in
                                                  range(a.shape[1])],
@@ -444,7 +454,20 @@ def test_matrix_transpose(x):
     ph.assert_result_shape("matrix_transpose", in_shapes=[x.shape],
                            out_shape=res.shape, expected=shape)
 
-    _test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val)
+    _test_stacks(matrix_transpose, x, res=res, true_val=true_val)
+
+@pytest.mark.xp_extension('linalg')
+@given(
+    x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
+)
+def test_linalg_matrix_transpose(x):
+    return _test_matrix_transpose(linalg, x)
+
+@given(
+    x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
+)
+def test_matrix_transpose(x):
+    return _test_matrix_transpose(_array_module, x)
 
 @pytest.mark.xp_extension('linalg')
 @given(
@@ -759,12 +782,9 @@ def _test_tensordot_stacks(x1, x2, kw, res):
         decomp_res_stack = xp.tensordot(x1_stack, x2_stack, axes=res_axes)
         assert_equal(res_stack, decomp_res_stack)
 
-@given(
-    *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
-    tensordot_kw,
-)
-def test_tensordot(x1, x2, kw):
-    res = xp.tensordot(x1, x2, **kw)
+def _test_tensordot(namespace, x1, x2, kw):
+    tensordot = namespace.tensordot
+    res = tensordot(x1, x2, **kw)
 
     ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype],
                     out_dtype=res.dtype)
@@ -785,6 +805,21 @@ def test_tensordot(x1, x2, kw):
                            expected=result_shape)
     _test_tensordot_stacks(x1, x2, kw, res)
 
+@pytest.mark.xp_extension('linalg')
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
+    tensordot_kw,
+)
+def test_linalg_tensordot(x1, x2, kw):
+    _test_tensordot(linalg, x1, x2, kw)
+
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
+    tensordot_kw,
+)
+def test_tensordot(x1, x2, kw):
+    _test_tensordot(_array_module, x1, x2, kw)
+
 @pytest.mark.xp_extension('linalg')
 @given(
     x=arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()),
@@ -828,12 +863,8 @@ def true_trace(x_stack, offset=0):
 
     _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
 
-
-@given(
-    *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
-    data(),
-)
-def test_vecdot(x1, x2, data):
+def _test_vecdot(namespace, x1, x2, data):
+    vecdot = namespace.vecdot
     broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
     min_ndim = min(x1.ndim, x2.ndim)
     ndim = len(broadcasted_shape)
@@ -842,14 +873,14 @@ def test_vecdot(x1, x2, data):
     x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
     x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
     if x1_shape[axis] != x2_shape[axis]:
-        ph.raises(Exception, lambda: xp.vecdot(x1, x2, **kw),
+        ph.raises(Exception, lambda: vecdot(x1, x2, **kw),
                   "vecdot did not raise an exception for invalid shapes")
         return
     expected_shape = list(broadcasted_shape)
     expected_shape.pop(axis)
     expected_shape = tuple(expected_shape)
 
-    res = xp.vecdot(x1, x2, **kw)
+    res = vecdot(x1, x2, **kw)
 
     ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype],
                     out_dtype=res.dtype)
@@ -862,9 +893,25 @@ def true_val(x, y, axis=-1):
     else:
         true_val = None
 
-    _test_stacks(linalg.vecdot, x1, x2, res=res, dims=0,
+    _test_stacks(vecdot, x1, x2, res=res, dims=0,
                  matrix_axes=(axis,), true_val=true_val)
 
+
+@pytest.mark.xp_extension('linalg')
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
+    data(),
+)
+def test_linalg_vecdot(x1, x2, data):
+    _test_vecdot(linalg, x1, x2, data)
+
+@given(
+    *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
+    data(),
+)
+def test_vecdot(x1, x2, data):
+    _test_vecdot(_array_module, x1, x2, data)
+
 # Insanely large orders might not work. There isn't a limit specified in the
 # spec, so we just limit to reasonable values here.
 max_ord = 100
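
Every hunk above applies the same refactoring: the body of a spec test moves into a `_test_*` helper that takes the namespace under test as its first argument, and two thin wrappers instantiate it once for the `linalg` extension (marked `xp_extension('linalg')`) and once for the top-level array module. A minimal sketch of the dispatch, using a hypothetical `double` function and stand-in namespaces rather than anything from the diff itself:

```python
import pytest
from types import SimpleNamespace

# Stand-ins for the suite's `linalg` and `_array_module` namespaces;
# `double` is a hypothetical function used only for illustration.
linalg = SimpleNamespace(double=lambda x: 2 * x)
_array_module = SimpleNamespace(double=lambda x: 2 * x)

def _test_double(namespace, x):
    # Resolve the function on whichever namespace was passed in, so a
    # single test body can exercise both the extension and the main module.
    double = namespace.double
    assert double(x) == 2 * x

@pytest.mark.xp_extension('linalg')  # custom marker registered by the suite
def test_linalg_double():
    _test_double(linalg, 3)

def test_double():
    _test_double(_array_module, 3)
```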