From d35ee58ad596578fd90e0aff2191e2f5989c1d70 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Tue, 14 Nov 2023 13:06:19 +0100 Subject: [PATCH] Fix linalg.trace test, result dtype was incorrect As the spec says, the output dtype should be the default integer/float/complex dtype just like for `sum`. Given that the reference was implemented with `sum`, avoiding to pass an explicit dtype should be enough to obtain the correct results. --- array_api_tests/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 0974805e..a289a1f2 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -626,7 +626,7 @@ def true_trace(x_stack): x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)] else: x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)] - return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype), dtype=x.dtype) + return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype)) _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)