Skip to content

Commit b178b5a

Browse files
committed
Implement test_trace()
The dtype of the output is still not specified correctly (it should work the same as sum()).
1 parent 733907d commit b178b5a

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

array_api_tests/test_linalg.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -561,12 +561,43 @@ def test_tensordot(x1, x2, kw):
561561

562562
@pytest.mark.xp_extension('linalg')
563563
@given(
564-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),
565-
kw=kwargs(offset=todo)
564+
x=xps.arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()),
565+
# offset may produce an overflow if it is too large. Supporting offsets
566+
# that are way larger than the array shape isn't very important.
567+
kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE))
566568
)
567569
def test_trace(x, kw):
568-
# res = linalg.trace(x, **kw)
569-
pass
570+
res = linalg.trace(x, **kw)
571+
572+
# TODO: trace() should promote in some cases. See
573+
# https://github.com/data-apis/array-api/issues/202. See also the dtype
574+
# argument to sum() below.
575+
576+
# assert res.dtype == x.dtype, "trace() returned the wrong dtype"
577+
578+
n, m = x.shape[-2:]
579+
offset = kw.get('offset', 0)
580+
assert res.shape == x.shape[:-2], "trace() returned the wrong shape"
581+
582+
def true_trace(x_stack):
583+
# Note: the spec does not specify that offset must be within the
584+
# bounds of the matrix. A large offset should just produce a size 0
585+
# diagonal in the last dimension (trace 0). See test_diagonal().
586+
if offset < 0:
587+
diag_size = min(n, m, max(n + offset, 0))
588+
elif offset == 0:
589+
diag_size = min(n, m)
590+
else:
591+
diag_size = min(n, m, max(m - offset, 0))
592+
593+
if offset >= 0:
594+
x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)]
595+
else:
596+
x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
597+
return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype), dtype=x.dtype)
598+
599+
_test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
600+
570601

571602
@given(
572603
x1=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes),

0 commit comments

Comments
 (0)