Commit 9299e7b

Add rewrite for Blockwise with Alloc inputs
Also prevent Alloc from being constant-folded when it is used by an Elemwise or Blockwise, to avoid creating uselessly large arrays.
1 parent 7e8e2f6 commit 9299e7b
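
In rough terms, the new rewrite performs a transformation like the following (an illustrative sketch of my own, not taken from the commit; the variable names and shapes are made up):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 3, 7))  # batched input: one batch dim of size 5
y = pt.tensor("y", shape=(7, 2))     # purely core input

# pt.alloc(y, 5, 7, 2) materializes a batched copy of y, so the Blockwise
# matmul (gufunc signature (m,k),(k,n)->(m,n)) receives an Alloc input.
out = x @ pt.alloc(y, 5, 7, 2)

# After the rewrite, the graph computes x @ y directly: matmul broadcasts
# the core y on its own, and the (5, 7, 2) copy of y is never materialized.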

File tree: 2 files changed (+97, -1 lines)


pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@
     as_tensor_variable,
     get_vector_length,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import (
@@ -1663,6 +1664,9 @@ def do_constant_folding(self, fgraph, node):
                 # If the output is a constant, it will have to be deepcopied
                 # each time the function is called. So we do not fold.
                 return False
+            # Allow alloc to be lifted out of Elemwise and Blockwise, before constant folding it
+            elif isinstance(client[0].op, (Elemwise, Blockwise)):
+                return None
             elif (
                 # The following ops work inplace of their input id 0.
                 client[1] == 0
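
Why the constant-folding change matters (again a sketch of my own, not code from the commit): eagerly folding an Alloc that feeds an Elemwise would bake a large constant array into the graph, even though the Elemwise can broadcast the underlying value by itself.

import pytensor.tensor as pt

x = pt.vector("x")

# pt.zeros(100_000) builds Alloc(0.0, 100_000). If it were constant-folded
# here, the compiled graph would carry a 100_000-element constant array.
# With the change above, do_constant_folding returns None (falsy, so no
# folding) whenever the client is an Elemwise or Blockwise, leaving the
# Alloc symbolic so rewrites can lift it out of the addition first.
out = x + pt.zeros(100_000)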

pytensor/tensor/rewriting/blockwise.py

Lines changed: 93 additions & 1 deletion
@@ -2,8 +2,9 @@
 from pytensor.graph import node_rewriter
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
-from pytensor.tensor.basic import Alloc, ARange, shape_padleft
+from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.extra_ops import broadcast_shape_iter
 from pytensor.tensor.math import _matrix_matrix_matmul
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
@@ -75,3 +76,94 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
         )
     ):
         return local_useless_unbatched_blockwise.fn(fgraph, node)
+
+
+@register_specialize("shape_unsafe")
+@node_rewriter([Blockwise])
+def local_blockwise_alloc(fgraph, node):
+    """Push Allocs from the inputs to the output of Blockwise Ops."""
+
+    op: Blockwise = node.op  # type: ignore
+
+    batch_ndim = node.inputs[0].type.ndim - len(op.inputs_sig[0])
+    batch_axes = tuple(range(batch_ndim))
+    new_inputs = []
+    batch_shapes = []
+    can_lift_alloc = False
+    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
+        if all(inp.type.broadcastable[:batch_ndim]):
+            # The input only has dummy batch dims (if it has any)
+            inp = inp.squeeze(batch_axes)
+            new_inputs.append(inp)
+            continue
+
+        core_ndim = len(inp_sig)
+        if inp.owner and isinstance(inp.owner.op, Alloc):
+            value, *shape = inp.owner.inputs
+            value_ndim = value.type.ndim
+            value_batch_ndim = value_ndim - core_ndim
+            if value_batch_ndim:
+                # The original value already has batch dims, let's see if it's just dummy ones
+                if all(value.type.broadcastable[:value_batch_ndim]):
+                    value = value.squeeze(axis=tuple(range(value_batch_ndim)))
+                else:
+                    # The original value has batch dims that are not dummy
+                    # We cannot lift this Alloc
+                    new_inputs.append(inp)
+                    continue
+
+            alloc_ndim = len(shape)
+            if alloc_ndim > core_ndim:
+                # The Alloc adds all the batch dims
+                batch_shape = shape[:batch_ndim]
+                core_shape = shape[batch_ndim:]
+                if any(value.type.broadcastable[batch_ndim:]):
+                    # We still need an Alloc for the core dims
+                    value = alloc(value, *core_shape)
+                new_inputs.append(value)
+                batch_shapes.append(
+                    [
+                        dim if not bcast else 1
+                        for dim, bcast in zip(
+                            batch_shape, inp.type.broadcastable[:batch_ndim]
+                        )
+                    ]
+                )
+                can_lift_alloc = True
+                continue
+
+        # Nothing to do with this input
+        new_inputs.append(inp)
+
+    if not can_lift_alloc:
+        return None
+
+    new_outs = node.op.make_node(*new_inputs).outputs
+
+    # Pushed Allocs are still needed
+    if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable:
+        out = new_outs[0]
+        batch_ndim = out.type.ndim - len(op.outputs_sig[0])
+        if batch_ndim:
+            # The new output already has batch dims, we need to consider this when broadcasting
+            bcast_shape = tuple(out.shape)[:batch_ndim]
+            batch_shapes.append(
+                [
+                    dim if not bcast else 1
+                    for dim, bcast in zip(
+                        bcast_shape, out.type.broadcastable[:batch_ndim]
+                    )
+                ]
+            )
+        if len(batch_shapes) == 1:
+            [batch_shape] = batch_shapes
+        else:
+            batch_shape = broadcast_shape_iter(batch_shapes, arrays_are_shapes=True)
+        core_shapes = [out.shape[batch_ndim:] for out in new_outs]
+        new_outs = [
+            alloc(new_out, *batch_shape, *core_shape)
+            for new_out, core_shape in zip(new_outs, core_shapes)
+        ]
+
+    copy_stack_trace(node.outputs, new_outs)
+    return new_outs
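
One way to see the new rewrite fire (hypothetical usage added for illustration; the compiled function and shapes are not part of the commit): compile a graph whose Blockwise input is an Alloc and print the optimized graph.

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 3, 7))
y = pt.tensor("y", shape=(7, 2))

# The Alloc only adds a batch dimension that the batched matmul can
# broadcast on its own.
out = x @ pt.alloc(y, 5, 7, 2)

fn = pytensor.function([x, y], out)

# With local_blockwise_alloc registered under "specialize" ("shape_unsafe"),
# the optimized graph should feed y to the matmul directly, with no
# (5, 7, 2) Alloc left in between.
pytensor.dprint(fn)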
