diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 0974805e..a289a1f2 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -626,7 +626,7 @@ def true_trace(x_stack): x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)] else: x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)] - return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype), dtype=x.dtype) + return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype)) _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)