@@ -561,12 +561,43 @@ def test_tensordot(x1, x2, kw):
561
561
562
562
@pytest .mark .xp_extension ('linalg' )
563
563
@given (
564
- x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
565
- kw = kwargs (offset = todo )
564
+ x = xps .arrays (dtype = xps .numeric_dtypes (), shape = matrix_shapes ()),
565
+ # offset may produce an overflow if it is too large. Supporting offsets
566
+ # that are way larger than the array shape isn't very important.
567
+ kw = kwargs (offset = integers (- MAX_ARRAY_SIZE , MAX_ARRAY_SIZE ))
566
568
)
567
569
def test_trace (x , kw ):
568
- # res = linalg.trace(x, **kw)
569
- pass
570
+ res = linalg .trace (x , ** kw )
571
+
572
+ # TODO: trace() should promote in some cases. See
573
+ # https://github.com/data-apis/array-api/issues/202. See also the dtype
574
+ # argument to sum() below.
575
+
576
+ # assert res.dtype == x.dtype, "trace() returned the wrong dtype"
577
+
578
+ n , m = x .shape [- 2 :]
579
+ offset = kw .get ('offset' , 0 )
580
+ assert res .shape == x .shape [:- 2 ], "trace() returned the wrong shape"
581
+
582
+ def true_trace (x_stack ):
583
+ # Note: the spec does not specify that offset must be within the
584
+ # bounds of the matrix. A large offset should just produce a size 0
585
+ # diagonal in the last dimension (trace 0). See test_diagonal().
586
+ if offset < 0 :
587
+ diag_size = min (n , m , max (n + offset , 0 ))
588
+ elif offset == 0 :
589
+ diag_size = min (n , m )
590
+ else :
591
+ diag_size = min (n , m , max (m - offset , 0 ))
592
+
593
+ if offset >= 0 :
594
+ x_stack_diag = [x_stack [i , i + offset ] for i in range (diag_size )]
595
+ else :
596
+ x_stack_diag = [x_stack [i - offset , i ] for i in range (diag_size )]
597
+ return _array_module .sum (asarray (x_stack_diag , dtype = x .dtype ), dtype = x .dtype )
598
+
599
+ _test_stacks (linalg .trace , x , ** kw , res = res , dims = 0 , true_val = true_trace )
600
+
570
601
571
602
@given (
572
603
x1 = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
0 commit comments