Skip to content

Commit 8020878

Browse files
committed
Fix get_canonical_form_slice when lengths are numpy integers
Introduced in f9dfe70
1 parent df769f6 commit 8020878

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pytensor/tensor/subtensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,9 @@ def get_canonical_form_slice(
305305

306306
# At this point we have a slice object. Possibly with symbolic inputs.
307307

308-
def analyze(x):
308+
def analyze(x) -> tuple[int | Variable, bool]:
309309
try:
310-
x_constant = as_index_literal(x)
310+
x_constant = int(as_index_literal(x))
311311
is_constant = True
312312
except NotScalarConstantError:
313313
x_constant = x

tests/tensor/test_subtensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,11 @@ def test_symbolic_tensor(self):
154154
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
155155
assert res[1] == 1
156156

157-
def test_all_integer(self):
158-
res = get_canonical_form_slice(slice(1, 5, 2), 7)
157+
@pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar])
158+
def test_all_integer(self, int_fn):
159+
res = get_canonical_form_slice(
160+
slice(int_fn(1), int_fn(5), int_fn(2)), int_fn(7)
161+
)
159162
assert isinstance(res[0], slice)
160163
assert res[1] == 1
161164

0 commit comments

Comments
 (0)