diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 2cec476c4a..59961e7c2f 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -325,7 +325,7 @@ def analyze(x): and is_step_constant and is_length_constant ): - assert isinstance(length, int) + assert isinstance(length, int | np.integer) _start, _stop, _step = slice(start, stop, step).indices(length) if _start <= _stop and _step >= 1: return slice(_start, _stop, _step), 1 diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index f4ba58e26a..9cb730aafd 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -154,8 +154,11 @@ def test_symbolic_tensor(self): assert isinstance(res[0].owner.op, ptb.ScalarFromTensor) assert res[1] == 1 - def test_all_integer(self): - res = get_canonical_form_slice(slice(1, 5, 2), 7) + @pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar]) + def test_all_integer(self, int_fn): + res = get_canonical_form_slice( + slice(int_fn(1), int_fn(5), int_fn(2)), int_fn(7) + ) assert isinstance(res[0], slice) assert res[1] == 1