Skip to content

Commit aa372e2

Browse files
ricardoV94jessegrabowski
authored andcommitted
Fix dtype of numba dispatch of ArgSort
1 parent c1e453f commit aa372e2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def argort_vec(X, axis):
466466
axis = axis.item()
467467

468468
Y = np.swapaxes(X, axis, 0)
469-
result = np.empty_like(Y)
469+
result = np.empty_like(Y, dtype="int64")
470470

471471
indices = list(np.ndindex(Y.shape[1:]))
472472

0 commit comments

Comments
 (0)