Skip to content

Commit c1e453f

Browse files
ricardoV94jessegrabowski
authored andcommitted
Fix type of numba Argmax special case
1 parent fd70495 commit c1e453f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def numba_funcify_Argmax(op, node, **kwargs):
561561

562562
@numba_basic.numba_njit(inline="always")
563563
def argmax(x):
564-
return 0
564+
return np.array(0, dtype="int64")
565565

566566
else:
567567
axes = tuple(int(ax) for ax in axis)

0 commit comments

Comments
 (0)