Skip to content

Commit acb486e

Browse files
committed
Cleanup Scan symbolic buffer size graph
Graph was being broken by Scalar/Tensor conversions that prevented fusion
1 parent cddefef commit acb486e

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

pytensor/tensor/subtensor.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
alloc,
3434
get_scalar_constant_value,
3535
nonzero,
36-
scalar_from_tensor,
36+
)
37+
from pytensor.tensor.basic import (
38+
constant as tensor_constant,
3739
)
3840
from pytensor.tensor.blockwise import vectorize_node_fallback
3941
from pytensor.tensor.elemwise import DimShuffle
@@ -296,13 +298,30 @@ def get_canonical_form_slice(
296298
"""
297299
from pytensor.tensor import ge, lt, sign, switch
298300

301+
def undo_scalarization(x):
302+
"""Undo scalarization of a variable.
303+
304+
PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
305+
But reasoning symbolically about the result of multiple indexing operations, we usually
306+
want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
307+
308+
This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
309+
"""
310+
if isinstance(x, ScalarVariable):
311+
if isinstance(x, ScalarConstant):
312+
return tensor_constant(x.data, dtype=x.dtype)
313+
elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
314+
return x.owner.inputs[0]
315+
return x
316+
299317
# Other non-slice types are the scalar indexing case
300318
if not isinstance(theslice, slice):
319+
theslice = undo_scalarization(theslice)
301320
if isinstance(theslice, int | np.integer | ScalarVariable) or (
302321
isinstance(theslice, TensorVariable) and theslice.ndim == 0
303322
):
304323
cano = switch(lt(theslice, 0), (theslice + length), theslice)
305-
return scalar_from_tensor(cano), 1
324+
return cano, 1
306325
raise ValueError(f"Slice {theslice} is not a supported slice type.")
307326

308327
# At this point we have a slice object. Possibly with symbolic inputs.
@@ -312,7 +331,7 @@ def analyze(x):
312331
x_constant = as_index_literal(x)
313332
is_constant = True
314333
except NotScalarConstantError:
315-
x_constant = x
334+
x_constant = undo_scalarization(x)
316335
is_constant = False
317336
return x_constant, is_constant
318337

0 commit comments

Comments
 (0)