Skip to content

Commit 80cccc6

Browse files
committed
Add rewrite to remove useless Blockwise AdvancedIncSubtensor
1 parent 9299e7b commit 80cccc6

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
register_infer_shape,
3030
switch,
3131
)
32+
from pytensor.tensor.blockwise import Blockwise
3233
from pytensor.tensor.elemwise import Elemwise
3334
from pytensor.tensor.exceptions import NotScalarConstantError
3435
from pytensor.tensor.math import Dot, add
@@ -1865,3 +1866,37 @@ def local_uint_constant_indices(fgraph, node):
18651866
copy_stack_trace(node.outputs, new_outs)
18661867

18671868
return new_outs
1869+
1870+
1871+
@register_canonicalize
1872+
@register_stabilize
1873+
@register_specialize
1874+
@node_rewriter([Blockwise])
1875+
def useless_blockwise_advanced_inc_subtensor(fgraph, node):
1876+
"""Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
1877+
if not isinstance(node.op.core_op, AdvancedIncSubtensor):
1878+
return None
1879+
1880+
op: Blockwise = node.op # type: ignore
1881+
1882+
x, y, *idxs = node.inputs
1883+
1884+
batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0])
1885+
1886+
new_idxs = []
1887+
for idx in idxs:
1888+
if all(idx.type.broadcastable[:batch_ndim]):
1889+
new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
1890+
else:
1891+
# Rewrite does not apply
1892+
return None
1893+
1894+
# Need to broadcast x to final shape
1895+
if any(x.type.broadcastable[:batch_ndim]):
1896+
x = alloc(x, *tuple(y.shape)[:batch_ndim], *tuple(x.shape)[batch_ndim:])
1897+
1898+
new_idxs = [slice(None)] * batch_ndim + new_idxs
1899+
symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
1900+
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
1901+
copy_stack_trace(node.outputs, new_out)
1902+
return new_out

0 commit comments

Comments
 (0)