Closed
Description
torch.linalg.trace should work on stacks of matrices, but this snippet:
import array_api_compat.torch as torch
a = torch.reshape(torch.arange(10*5*5), (10,5,5))
torch.linalg.trace(a)
raises RuntimeError: trace: expected a matrix, but got tensor with dim 3.
PyTorch's own trace function only accepts a single 2-D matrix and doesn't handle stacks: see the docs.
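For reference, here is a minimal sketch of one possible way to get a per-matrix trace for a stack, using plain torch. The helper name batched_trace is hypothetical (it is not part of torch or array_api_compat), and this is only an illustration of the expected behaviour, not necessarily how the wrapper should be fixed.

```python
import torch


def batched_trace(x, offset=0):
    """Sum the chosen diagonal of each matrix in a (..., M, N) stack.

    Hypothetical helper for illustration only.
    """
    # torch.diagonal over the last two dims appends the diagonal as the
    # final dimension, so summing over dim=-1 gives one value per matrix.
    return torch.diagonal(x, offset=offset, dim1=-2, dim2=-1).sum(dim=-1)


a = torch.reshape(torch.arange(10 * 5 * 5), (10, 5, 5))
traces = batched_trace(a)
print(traces.shape)  # one trace per matrix in the stack
```

With the input from the report this returns a tensor of shape (10,), one entry per 5x5 matrix, instead of raising.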