Skip to content

Commit f9dfe70

Browse files
michaelosthegericardoV94
authored andcommitted
Refactor get_canonical_form_slice to fix subtensor typing
1 parent 906e142 commit f9dfe70

File tree

3 files changed

+170
-42
lines changed

3 files changed

+170
-42
lines changed

pytensor/tensor/subtensor.py

Lines changed: 120 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Callable, Iterable
44
from itertools import chain, groupby
55
from textwrap import dedent
6+
from typing import cast, overload
67

78
import numpy as np
89

@@ -19,13 +20,19 @@
1920
from pytensor.link.c.params_type import ParamsType
2021
from pytensor.misc.safe_asarray import _asarray
2122
from pytensor.printing import Printer, pprint, set_precedence
22-
from pytensor.scalar.basic import ScalarConstant
23-
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
23+
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
24+
from pytensor.tensor import (
25+
TensorLike,
26+
_get_vector_length,
27+
as_tensor_variable,
28+
get_vector_length,
29+
)
2430
from pytensor.tensor.basic import (
2531
ScalarFromTensor,
2632
alloc,
2733
get_underlying_scalar_constant_value,
2834
nonzero,
35+
scalar_from_tensor,
2936
)
3037
from pytensor.tensor.blockwise import vectorize_node_fallback
3138
from pytensor.tensor.elemwise import DimShuffle
@@ -51,8 +58,14 @@
5158
wscalar,
5259
zscalar,
5360
)
54-
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
55-
from pytensor.tensor.variable import TensorVariable
61+
from pytensor.tensor.type_other import (
62+
NoneConst,
63+
NoneTypeT,
64+
SliceConstant,
65+
SliceType,
66+
make_slice,
67+
)
68+
from pytensor.tensor.variable import TensorConstant, TensorVariable
5669

5770

5871
_logger = logging.getLogger("pytensor.tensor.subtensor")
@@ -134,7 +147,7 @@ def convert_indices(indices, entry):
134147

135148

136149
def as_index_constant(
137-
a: slice | int | np.integer | Variable | None,
150+
a: slice | int | np.integer | Variable | None | TensorLike,
138151
) -> Variable | slice | None:
139152
r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.
140153
@@ -150,15 +163,41 @@ def as_index_constant(
150163
)
151164
elif isinstance(a, int | np.integer):
152165
return ps.ScalarConstant(ps.int64, a)
153-
elif not isinstance(a, Variable):
154-
return as_tensor_variable(a)
155-
else:
166+
elif isinstance(a, Variable):
156167
return a
168+
return as_tensor_variable(a)
169+
170+
171+
@overload
172+
def as_index_literal(idx: int | np.integer) -> int | np.integer: ...
173+
174+
175+
@overload
176+
def as_index_literal(idx: None) -> None: ...
177+
178+
179+
@overload
180+
def as_index_literal(idx: slice | SliceConstant) -> slice: ...
181+
182+
183+
@overload
184+
def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ...
185+
186+
187+
@overload
188+
def as_index_literal(idx: Variable): ...
157189

158190

159191
def as_index_literal(
160-
idx: Variable | slice | None,
161-
) -> int | slice | None:
192+
idx: None
193+
| int
194+
| np.integer
195+
| slice
196+
| SliceConstant
197+
| ScalarConstant
198+
| TensorConstant
199+
| Variable,
200+
) -> int | np.integer | slice | None:
162201
"""Convert a symbolic index element to its Python equivalent.
163202
164203
This is like the inverse of `as_index_constant`
@@ -167,22 +206,8 @@ def as_index_literal(
167206
------
168207
NotScalarConstantError
169208
"""
170-
if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT):
171-
return np.newaxis
172-
173-
if isinstance(idx, Constant):
174-
return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
175-
176-
if isinstance(idx, Variable):
177-
if (
178-
isinstance(idx.type, ps.ScalarType)
179-
and idx.owner
180-
and isinstance(idx.owner.op, ScalarFromTensor)
181-
):
182-
return as_index_literal(idx.owner.inputs[0])
183-
184-
if isinstance(idx.type, SliceType):
185-
idx = slice(*idx.owner.inputs)
209+
if idx is None or isinstance(idx, int | np.integer):
210+
return idx
186211

187212
if isinstance(idx, slice):
188213
return slice(
@@ -191,17 +216,64 @@ def as_index_literal(
191216
as_index_literal(idx.step),
192217
)
193218

219+
if not isinstance(idx, Variable):
220+
raise TypeError(f"Not an index element: {idx}")
221+
222+
if isinstance(idx.type, NoneTypeT):
223+
return None
224+
225+
if isinstance(idx, ScalarConstant):
226+
return cast(int, idx.data)
227+
228+
if (
229+
isinstance(idx.type, ps.ScalarType)
230+
and idx.owner
231+
and isinstance(idx.owner.op, ScalarFromTensor)
232+
):
233+
return cast(int | np.integer, as_index_literal(idx.owner.inputs[0]))
234+
235+
if isinstance(idx, TensorConstant):
236+
return cast(int, idx.data.item())
237+
238+
if isinstance(idx, SliceConstant):
239+
return cast(slice, idx.data)
240+
241+
if isinstance(idx.type, SliceType):
242+
assert idx.owner is not None
243+
return slice(*map(as_index_literal, idx.owner.inputs))
244+
245+
# Other kinds of variables are not supported
194246
raise NotScalarConstantError()
195247

196248

197249
def get_idx_list(inputs, idx_list):
198250
return indices_from_subtensor(inputs[1:], idx_list)
199251

200252

253+
@overload
254+
def get_canonical_form_slice(
255+
theslice: slice,
256+
length: int | np.integer | ScalarVariable | TensorVariable,
257+
) -> tuple[slice, int | ScalarConstant]: ...
258+
259+
260+
@overload
261+
def get_canonical_form_slice(
262+
theslice: int | np.integer | ScalarVariable | TensorVariable,
263+
length: int | np.integer | ScalarVariable | TensorVariable,
264+
) -> tuple[ScalarVariable, int]: ...
265+
266+
201267
def get_canonical_form_slice(
202-
theslice: slice | Variable, length: Variable
203-
) -> tuple[Variable, int]:
204-
"""Convert slices to canonical form.
268+
theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
269+
length: int | np.integer | ScalarVariable | TensorVariable,
270+
) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
271+
"""Convert indices or slices to canonical form.
272+
273+
Scalar integer indices or python Slices with Scalar/None attributes
274+
used in basic Subtensor Ops are supported.
275+
Symbolic slices (of SliceType) or vector indices
276+
used in advanced Subtensor Ops are not supported.
205277
206278
Given a slice [start:stop:step] transform it into a canonical form
207279
that respects the conventions imposed by python and numpy.
@@ -210,18 +282,28 @@ def get_canonical_form_slice(
210282
in which 0 <= start <= stop <= length and step > 0, and a flag which says
211283
if the resulting set of numbers needs to be reversed or not.
212284
285+
Given a scalar index `idx` that may or not be negative, convert it to
286+
a certainly positive form `idx if idx >= 0 else length + idx`.
287+
288+
Returns
289+
-------
290+
slc
291+
Canonical form slice or scalar variable.
292+
direction
293+
Direction to iterate the resulting elements in. (-1 or 1). May be symbolic.
213294
"""
214295
from pytensor.tensor import ge, lt, sign, switch
215296

297+
# Other non-slice types are the scalar indexing case
216298
if not isinstance(theslice, slice):
217-
try:
218-
value = as_index_literal(theslice)
219-
except NotScalarConstantError:
220-
value = theslice
221-
222-
value = switch(lt(value, 0), (value + length), value)
299+
if isinstance(theslice, int | np.integer | ScalarVariable) or (
300+
isinstance(theslice, TensorVariable) and theslice.ndim == 0
301+
):
302+
cano = switch(lt(theslice, 0), (theslice + length), theslice)
303+
return scalar_from_tensor(cano), 1
304+
raise ValueError(f"Slice {theslice} is not a supported slice type.")
223305

224-
return value, 1
306+
# At this point we have a slice object. Possibly with symbolic inputs.
225307

226308
def analyze(x):
227309
try:
@@ -243,6 +325,7 @@ def analyze(x):
243325
and is_step_constant
244326
and is_length_constant
245327
):
328+
assert isinstance(length, int)
246329
_start, _stop, _step = slice(start, stop, step).indices(length)
247330
if _start <= _stop and _step >= 1:
248331
return slice(_start, _stop, _step), 1
@@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"):
29173000
return a[full_indices]
29183001

29193002

2920-
@_get_vector_length.register(Subtensor)
3003+
@_get_vector_length.register(Subtensor) # type: ignore
29213004
def _get_vector_length_Subtensor(op, var):
29223005
# If we take a slice, we know how many elements it will result in
29233006
# TODO: We can cover more `*Subtensor` cases.

scripts/mypy-failing.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ pytensor/tensor/random/op.py
2525
pytensor/tensor/random/utils.py
2626
pytensor/tensor/rewriting/basic.py
2727
pytensor/tensor/slinalg.py
28-
pytensor/tensor/subtensor.py
2928
pytensor/tensor/type.py
3029
pytensor/tensor/type_other.py
3130
pytensor/tensor/variable.py

tests/tensor/test_subtensor.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from pytensor.graph.op import get_test_value
1717
from pytensor.graph.rewriting.utils import is_same_graph
1818
from pytensor.printing import pprint
19-
from pytensor.scalar.basic import as_scalar
20-
from pytensor.tensor import get_vector_length, vectorize
19+
from pytensor.scalar.basic import as_scalar, int16
20+
from pytensor.tensor import as_tensor, get_vector_length, vectorize
2121
from pytensor.tensor.blockwise import Blockwise
2222
from pytensor.tensor.elemwise import DimShuffle
2323
from pytensor.tensor.math import exp, isinf
@@ -69,7 +69,13 @@
6969
tensor5,
7070
vector,
7171
)
72-
from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
72+
from pytensor.tensor.type_other import (
73+
NoneConst,
74+
SliceConstant,
75+
as_symbolic_slice,
76+
make_slice,
77+
slicetype,
78+
)
7379
from tests import unittest_tools as utt
7480
from tests.tensor.utils import inplace_func, integers_ranged, random
7581

@@ -106,11 +112,51 @@ def test_as_index_literal():
106112

107113

108114
class TestGetCanonicalFormSlice:
115+
@pytest.mark.parametrize(
116+
"idx",
117+
[
118+
NoneConst,
119+
None,
120+
as_symbolic_slice(slice(3, 7, 2)),
121+
as_symbolic_slice(slice(3, int16(), 2)),
122+
vector(),
123+
],
124+
)
125+
def test_unsupported_inputs(self, idx):
126+
with pytest.raises(ValueError, match="not a supported slice"):
127+
get_canonical_form_slice(idx, 5)
128+
109129
def test_scalar_constant(self):
110130
a = as_scalar(0)
111131
length = lscalar()
112132
res = get_canonical_form_slice(a, length)
113-
assert res[0].owner.op == ptb.switch
133+
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
134+
assert res[1] == 1
135+
136+
def test_tensor_constant(self):
137+
a = as_tensor(0)
138+
length = lscalar()
139+
res = get_canonical_form_slice(a, length)
140+
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
141+
assert res[1] == 1
142+
143+
def test_symbolic_scalar(self):
144+
a = int16()
145+
length = lscalar()
146+
res = get_canonical_form_slice(a, length)
147+
assert res[0].owner.op, ptb.switch
148+
assert res[1] == 1
149+
150+
def test_symbolic_tensor(self):
151+
a = lscalar()
152+
length = lscalar()
153+
res = get_canonical_form_slice(a, length)
154+
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
155+
assert res[1] == 1
156+
157+
def test_all_integer(self):
158+
res = get_canonical_form_slice(slice(1, 5, 2), 7)
159+
assert isinstance(res[0], slice)
114160
assert res[1] == 1
115161

116162
def test_all_symbolic(self):

0 commit comments

Comments
 (0)