from collections.abc import Callable, Iterable
from itertools import chain, groupby
from textwrap import dedent
+ from typing import cast, overload

import numpy as np

@@ -19,13 +20,19 @@
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence
- from pytensor.scalar.basic import ScalarConstant
- from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
+ from pytensor.scalar.basic import ScalarConstant, ScalarVariable
+ from pytensor.tensor import (
+     TensorLike,
+     _get_vector_length,
+     as_tensor_variable,
+     get_vector_length,
+ )
from pytensor.tensor.basic import (
    ScalarFromTensor,
    alloc,
    get_underlying_scalar_constant_value,
    nonzero,
+     scalar_from_tensor,
)
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
@@ -51,8 +58,14 @@
    wscalar,
    zscalar,
)
- from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType, make_slice
- from pytensor.tensor.variable import TensorVariable
+ from pytensor.tensor.type_other import (
+     NoneConst,
+     NoneTypeT,
+     SliceConstant,
+     SliceType,
+     make_slice,
+ )
+ from pytensor.tensor.variable import TensorConstant, TensorVariable


_logger = logging.getLogger("pytensor.tensor.subtensor")
@@ -134,7 +147,7 @@ def convert_indices(indices, entry):


def as_index_constant(
-     a: slice | int | np.integer | Variable | None,
+     a: slice | int | np.integer | Variable | None | TensorLike,
) -> Variable | slice | None:
    r"""Convert Python literals to PyTensor constants--when possible--in `Subtensor` arguments.

@@ -150,15 +163,41 @@ def as_index_constant(
        )
    elif isinstance(a, int | np.integer):
        return ps.ScalarConstant(ps.int64, a)
-     elif not isinstance(a, Variable):
-         return as_tensor_variable(a)
-     else:
+     elif isinstance(a, Variable):
        return a
+     return as_tensor_variable(a)
+
+
+ @overload
+ def as_index_literal(idx: int | np.integer) -> int | np.integer: ...
+
+
+ @overload
+ def as_index_literal(idx: None) -> None: ...
+
+
+ @overload
+ def as_index_literal(idx: slice | SliceConstant) -> slice: ...
+
+
+ @overload
+ def as_index_literal(idx: ScalarConstant | TensorConstant) -> int | np.integer: ...
+
+
+ @overload
+ def as_index_literal(idx: Variable): ...


def as_index_literal(
-     idx: Variable | slice | None,
- ) -> int | slice | None:
+     idx: None
+     | int
+     | np.integer
+     | slice
+     | SliceConstant
+     | ScalarConstant
+     | TensorConstant
+     | Variable,
+ ) -> int | np.integer | slice | None:
    """Convert a symbolic index element to its Python equivalent.

    This is like the inverse of `as_index_constant`
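As a reading aid rather than part of the diff: a minimal sketch of the round trip the overloads above describe. It assumes `ps` is `pytensor.scalar`, as in this module; the concrete values are only illustrative.

```python
import pytensor.scalar as ps
from pytensor.tensor.subtensor import as_index_constant, as_index_literal

idx = as_index_constant(3)            # Python int -> ps.ScalarConstant(ps.int64, 3)
assert isinstance(idx, ps.ScalarConstant)
assert as_index_literal(idx) == 3     # and back to a plain integer

assert as_index_literal(None) is None                            # None / np.newaxis passes through
assert as_index_literal(slice(1, None)) == slice(1, None, None)  # slices are converted elementwise
```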
@@ -167,22 +206,8 @@ def as_index_literal(
    ------
    NotScalarConstantError
    """
-     if idx == np.newaxis or isinstance(getattr(idx, "type", None), NoneTypeT):
-         return np.newaxis
-
-     if isinstance(idx, Constant):
-         return idx.data.item() if isinstance(idx, np.ndarray) else idx.data
-
-     if isinstance(idx, Variable):
-         if (
-             isinstance(idx.type, ps.ScalarType)
-             and idx.owner
-             and isinstance(idx.owner.op, ScalarFromTensor)
-         ):
-             return as_index_literal(idx.owner.inputs[0])
-
-         if isinstance(idx.type, SliceType):
-             idx = slice(*idx.owner.inputs)
+     if idx is None or isinstance(idx, int | np.integer):
+         return idx

    if isinstance(idx, slice):
        return slice(
@@ -191,17 +216,64 @@ def as_index_literal(
            as_index_literal(idx.step),
        )

+     if not isinstance(idx, Variable):
+         raise TypeError(f"Not an index element: {idx}")
+
+     if isinstance(idx.type, NoneTypeT):
+         return None
+
+     if isinstance(idx, ScalarConstant):
+         return cast(int, idx.data)
+
+     if (
+         isinstance(idx.type, ps.ScalarType)
+         and idx.owner
+         and isinstance(idx.owner.op, ScalarFromTensor)
+     ):
+         return cast(int | np.integer, as_index_literal(idx.owner.inputs[0]))
+
+     if isinstance(idx, TensorConstant):
+         return cast(int, idx.data.item())
+
+     if isinstance(idx, SliceConstant):
+         return cast(slice, idx.data)
+
+     if isinstance(idx.type, SliceType):
+         assert idx.owner is not None
+         return slice(*map(as_index_literal, idx.owner.inputs))
+
+     # Other kinds of variables are not supported
    raise NotScalarConstantError()


def get_idx_list(inputs, idx_list):
    return indices_from_subtensor(inputs[1:], idx_list)


+ @overload
+ def get_canonical_form_slice(
+     theslice: slice,
+     length: int | np.integer | ScalarVariable | TensorVariable,
+ ) -> tuple[slice, int | ScalarConstant]: ...
+
+
+ @overload
+ def get_canonical_form_slice(
+     theslice: int | np.integer | ScalarVariable | TensorVariable,
+     length: int | np.integer | ScalarVariable | TensorVariable,
+ ) -> tuple[ScalarVariable, int]: ...
+
+
def get_canonical_form_slice(
-     theslice: slice | Variable, length: Variable
- ) -> tuple[Variable, int]:
-     """Convert slices to canonical form.
+     theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
+     length: int | np.integer | ScalarVariable | TensorVariable,
+ ) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
+     """Convert indices or slices to canonical form.
+
+     Scalar integer indices or python Slices with Scalar/None attributes
+     used in basic Subtensor Ops are supported.
+     Symbolic slices (of SliceType) or vector indices
+     used in advanced Subtensor Ops are not supported.

    Given a slice [start:stop:step] transform it into a canonical form
    that respects the conventions imposed by python and numpy.
@@ -210,18 +282,28 @@ def get_canonical_form_slice(
    in which 0 <= start <= stop <= length and step > 0, and a flag which says
    if the resulting set of numbers needs to be reversed or not.

+     Given a scalar index `idx` that may or may not be negative, convert it to
+     a certainly positive form `idx if idx >= 0 else length + idx`.
+
+     Returns
+     -------
+     slc
+         Canonical form slice or scalar variable.
+     direction
+         Direction to iterate the resulting elements in (-1 or 1). May be symbolic.
    """
    from pytensor.tensor import ge, lt, sign, switch

+     # Other non-slice types are the scalar indexing case
    if not isinstance(theslice, slice):
-         try:
-             value = as_index_literal(theslice)
-         except NotScalarConstantError:
-             value = theslice
-
-         value = switch(lt(value, 0), (value + length), value)
+         if isinstance(theslice, int | np.integer | ScalarVariable) or (
+             isinstance(theslice, TensorVariable) and theslice.ndim == 0
+         ):
+             cano = switch(lt(theslice, 0), (theslice + length), theslice)
+             return scalar_from_tensor(cano), 1
+         raise ValueError(f"Slice {theslice} is not a supported slice type.")

-         return value, 1
+     # At this point we have a slice object, possibly with symbolic inputs.

    def analyze(x):
        try:
@@ -243,6 +325,7 @@ def analyze(x):
        and is_step_constant
        and is_length_constant
    ):
+         assert isinstance(length, int)
        _start, _stop, _step = slice(start, stop, step).indices(length)
        if _start <= _stop and _step >= 1:
            return slice(_start, _stop, _step), 1
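Again as a reading aid, not part of the change: a small sketch of the two behaviours the new docstring describes, the all-constant slice branch shown just above and the new scalar-index branch. The concrete numbers are my own example.

```python
from pytensor.tensor.subtensor import get_canonical_form_slice

# All-constant slice: resolved eagerly through slice.indices(length).
canonical, direction = get_canonical_form_slice(slice(-3, None), 10)
assert canonical == slice(7, 10, 1) and direction == 1

# Scalar index: a negative index is rewritten as `length + idx` via `switch`.
idx, direction = get_canonical_form_slice(-2, 10)
assert direction == 1
print(idx.eval())  # expected: 8, returned as a symbolic scalar
```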
@@ -2917,7 +3000,7 @@ def take(a, indices, axis=None, mode="raise"):
    return a[full_indices]


- @_get_vector_length.register(Subtensor)
+ @_get_vector_length.register(Subtensor)  # type: ignore
def _get_vector_length_Subtensor(op, var):
    # If we take a slice, we know how many elements it will result in
    # TODO: We can cover more `*Subtensor` cases.
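One last aside, not from the commit: with this registration `get_vector_length` can see through a basic slice with constant bounds. A rough usage sketch; if the static length cannot be recovered, `get_vector_length` raises instead.

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor import get_vector_length

x = pt.as_tensor_variable(np.arange(10))
print(get_vector_length(x[2:7]))  # expected: 5, since the input length and slice bounds are static
```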