33
33
alloc ,
34
34
get_scalar_constant_value ,
35
35
nonzero ,
36
- scalar_from_tensor ,
36
+ )
37
+ from pytensor .tensor .basic import (
38
+ constant as tensor_constant ,
37
39
)
38
40
from pytensor .tensor .blockwise import vectorize_node_fallback
39
41
from pytensor .tensor .elemwise import DimShuffle
@@ -296,13 +298,30 @@ def get_canonical_form_slice(
296
298
"""
297
299
from pytensor .tensor import ge , lt , sign , switch
298
300
301
+ def undo_scalarization (x ):
302
+ """Undo scalarization of a variable.
303
+
304
+ PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
305
+ But reasoning symbolically about the result of multiple indexing operations, we usually
306
+ want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
307
+
308
+ This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
309
+ """
310
+ if isinstance (x , ScalarVariable ):
311
+ if isinstance (x , ScalarConstant ):
312
+ return tensor_constant (x .data , dtype = x .dtype )
313
+ elif x .owner is not None and isinstance (x .owner .op , ScalarFromTensor ):
314
+ return x .owner .inputs [0 ]
315
+ return x
316
+
299
317
# Other non-slice types are the scalar indexing case
300
318
if not isinstance (theslice , slice ):
319
+ theslice = undo_scalarization (theslice )
301
320
if isinstance (theslice , int | np .integer | ScalarVariable ) or (
302
321
isinstance (theslice , TensorVariable ) and theslice .ndim == 0
303
322
):
304
323
cano = switch (lt (theslice , 0 ), (theslice + length ), theslice )
305
- return scalar_from_tensor ( cano ) , 1
324
+ return cano , 1
306
325
raise ValueError (f"Slice { theslice } is not a supported slice type." )
307
326
308
327
# At this point we have a slice object. Possibly with symbolic inputs.
@@ -312,7 +331,7 @@ def analyze(x):
312
331
x_constant = as_index_literal (x )
313
332
is_constant = True
314
333
except NotScalarConstantError :
315
- x_constant = x
334
+ x_constant = undo_scalarization ( x )
316
335
is_constant = False
317
336
return x_constant , is_constant
318
337
0 commit comments