Skip to content

Commit 3db127e

Browse files
committed
Avoid recreating Ops in local_uint_constant_indices
This prevents undoing the rewrite that introduces AdvancedSubtensor1
1 parent a14cb2b commit 3db127e

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,7 +1805,8 @@ def local_join_subtensors(fgraph, node):
18051805
def local_uint_constant_indices(fgraph, node):
18061806
"""Convert constant indices to unsigned dtypes."""
18071807

1808-
if isinstance(node.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
1808+
op = node.op
1809+
if isinstance(op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
18091810
x, y, *indices = node.inputs
18101811
else:
18111812
x, *indices = node.inputs
@@ -1864,21 +1865,18 @@ def local_uint_constant_indices(fgraph, node):
18641865
if not has_new_index:
18651866
return False
18661867

1867-
new_out = x[tuple(new_indices)]
1868-
1869-
if y is not None:
1870-
new_out = inc_subtensor(
1871-
new_out,
1872-
y,
1873-
inplace=node.op.inplace,
1874-
set_instead_of_inc=node.op.set_instead_of_inc,
1875-
ignore_duplicates=getattr(node.op, "ignore_duplicates", False),
1876-
)
1877-
1878-
new_outs = new_out.owner.outputs
1879-
copy_stack_trace(node.outputs, new_outs)
1880-
1881-
return new_outs
1868+
if isinstance(op, Subtensor | IncSubtensor):
1869+
# Basic index Ops contain information about the dtype of the indices, so wee have to recreate them
1870+
props = op._props_dict()
1871+
props["idx_list"] = new_indices
1872+
op = type(op)(**props)
1873+
# Basic index Ops don't expect slices, but the respective start/step/stop
1874+
new_indices = get_slice_elements(new_indices)
1875+
1876+
new_args = (x, *new_indices) if y is None else (x, y, *new_indices)
1877+
new_out = op(*new_args)
1878+
copy_stack_trace(node.outputs[0], new_out)
1879+
return [new_out]
18821880

18831881

18841882
@register_canonicalize("shape_unsafe")

pytensor/tensor/subtensor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,10 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
550550
return res_shape
551551

552552

553-
def get_slice_elements(idxs: list, cond: Callable) -> list:
553+
def get_slice_elements(
554+
idxs: list,
555+
cond: Callable = lambda x: isinstance(x, Variable),
556+
) -> list:
554557
"""Extract slice elements conditional on a given predicate function.
555558
556559
Parameters

0 commit comments

Comments
 (0)