Commit aa25c4d

Add rewrite for Blockwise with Alloc inputs

Also prevent Alloc from being constant-folded when it is used by Elemwise or Blockwise, to avoid creating uselessly large arrays.

1 parent df1eaed commit aa25c4d
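
In rough terms, the new rewrite turns a Blockwise over a broadcasted (Alloc) input into an Alloc of a smaller Blockwise. A minimal sketch of the effect, reusing the `vector_add` helper that the new tests below construct (the rewrite pipeline flags are also taken from those tests):

    from pytensor.graph import rewrite_graph
    from pytensor.tensor import add, alloc, tensor
    from pytensor.tensor.blockwise import Blockwise

    # Same helper the new tests construct below
    vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(5,))

    # Before this commit, the Blockwise batched over a materialized (7, 5)
    # array; with the rewrite, the Alloc is pushed to the output and the
    # graph becomes equivalent to alloc(vector_add(x, y), 7, 5).
    out = vector_add(x, alloc(y, 7, 5))
    rewritten = rewrite_graph(out, include=("ShapeOpt", "specialize"))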

File tree

4 files changed: +221 -9 lines changed

pytensor/graph/basic.py

Lines changed: 5 additions & 1 deletion
@@ -1777,6 +1777,7 @@ def equal_computations(
     ys: list[Union[np.ndarray, Variable]],
     in_xs: Optional[list[Variable]] = None,
     in_ys: Optional[list[Variable]] = None,
+    strict_dtype=True,
 ) -> bool:
     """Checks if PyTensor graphs represent the same computations.

@@ -1908,7 +1909,10 @@ def compare_nodes(nd_x, nd_y, common, different):
         if dx != dy:
             if isinstance(dx, Constant) and isinstance(dy, Constant):
                 if not dx.equals(dy):
-                    return False
+                    if strict_dtype:
+                        return False
+                    elif not np.array_equal(dx.data, dy.data):
+                        return False
             else:
                 return False
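
A small usage sketch of the new flag (the graphs below are hypothetical; the behavior follows the diff above): two graphs whose constants hold equal values but different dtypes compare unequal by default, and equal with `strict_dtype=False`:

    import numpy as np

    from pytensor.graph.basic import equal_computations
    from pytensor.tensor import vector

    x = vector("x")
    g1 = x + np.array(1, dtype="int32")  # constant with dtype int32
    g2 = x + np.array(1, dtype="int64")  # same value, dtype int64

    assert not equal_computations([g1], [g2])                  # dtypes differ
    assert equal_computations([g1], [g2], strict_dtype=False)  # values match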

pytensor/tensor/basic.py

Lines changed: 11 additions & 4 deletions
@@ -42,6 +42,7 @@
     as_tensor_variable,
     get_vector_length,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import (
@@ -1658,16 +1659,22 @@ def do_constant_folding(self, fgraph, node):
         if not clients:
             return False

-        for client in clients:
-            if client[0] == "output":
+        for client, idx in clients:
+            if client == "output":
                 # If the output is a constant, it will have to be deepcopied
                 # each time the function is called. So we do not fold.
                 return False
+            # Allow alloc to be lifted out of Elemwise before constant folding it
+            elif isinstance(client.op, Elemwise):
+                return None
+            # Same for Blockwise, unless it has no batch_dims
+            elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
+                return None
             elif (
                 # The following ops work inplace of their input id 0.
-                client[1] == 0
+                idx == 0
                 and isinstance(
-                    client[0].op,
+                    client.op,
                     (
                         # Ops that will work inplace on the Alloc. So if they
                         # get constant_folded, they would copy the
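
The practical effect, as a sketch of the intent rather than code from the commit: an Alloc whose inputs are all constants used to be constant-folded into a materialized array even when its only client was an Elemwise; deferring the fold (returning None) gives the alloc-lifting rewrites a chance to move the Alloc out first, so the scalar computation runs once and is broadcast afterwards:

    import pytensor.tensor as pt
    from pytensor.graph import rewrite_graph

    # All inputs constant, so the Alloc is eligible for folding; but its only
    # client is an Elemwise, so do_constant_folding now defers.
    out = pt.exp(pt.alloc(0.0, 1000, 1000))

    # After rewriting, the expected graph (an assumption consistent with the
    # commit's intent) is alloc(exp(0.0), 1000, 1000) rather than a folded
    # 1000 x 1000 constant array.
    rewritten = rewrite_graph(out, include=("ShapeOpt", "specialize"))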

pytensor/tensor/rewriting/blockwise.py

Lines changed: 121 additions & 2 deletions
@@ -1,8 +1,10 @@
+from typing import Optional
+
 from pytensor.compile.mode import optdb
-from pytensor.graph import node_rewriter
+from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
-from pytensor.tensor.basic import Alloc, ARange, shape_padleft
+from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import Dot
 from pytensor.tensor.rewriting.basic import (
@@ -80,3 +82,120 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
         ),
     ):
         return local_useless_unbatched_blockwise.fn(fgraph, node)
+
+
+def _squeeze_left(x, stop_at_dim: Optional[int] = None):
+    """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
+    x_dims = x.type.broadcastable
+    squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False)
+    if stop_at_dim is not None:
+        squeeze_ndim = min(squeeze_ndim, stop_at_dim)
+    if squeeze_ndim == 0:
+        return x
+    return x.squeeze(axis=tuple(range(squeeze_ndim)))
+
+
+@register_specialize("shape_unsafe")
+@node_rewriter([Blockwise])
+def local_blockwise_alloc(fgraph, node):
+    """Push Allocs from the inputs to the output of Blockwise Ops.
+
+    BOp = Blockwise(Op, signature="(x),(x)->(x)")
+    BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp(vector, vector), 10, 5)
+    BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp(vector, alloc(scalar, 5)), 10, 5)
+    BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
+    """
+
+    if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
+        return None
+
+    op: Blockwise = node.op  # type: ignore
+
+    batch_ndim = op.batch_ndim(node)
+    if not batch_ndim:
+        return None
+
+    new_inputs = []
+    batch_shapes = []
+    can_push_any_alloc = False
+    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
+        if inp.owner and isinstance(inp.owner.op, Alloc):
+            # Push batch dims from Alloc
+            value, *shape = inp.owner.inputs
+
+            # Check what to do with the value of the Alloc
+            squeezed_value = _squeeze_left(value, batch_ndim)
+            missing_ndim = len(shape) - value.type.ndim
+            if (
+                ((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]
+            ) != inp.type.broadcastable[batch_ndim:]:
+                # We still need an Alloc for the core dims
+                core_shape = shape[batch_ndim:]
+                # And the batch dims of the squeezed value
+                squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape)
+                batch_shape = [
+                    1 if broadcastable else dim
+                    for broadcastable, dim in zip(
+                        squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
+                        tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
+                    )
+                ]
+                squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
+                if squeezed_value.type.broadcastable == inp.type.broadcastable:
+                    # We can't change anything about this Alloc input
+                    new_inputs.append(inp)
+                    continue
+
+            # We can push batch dims of this Alloc input
+            batch_shapes.append(
+                tuple(
+                    1 if broadcastable else dim
+                    for broadcastable, dim in zip(
+                        inp.type.broadcastable, shape[:batch_ndim]
+                    )
+                )
+            )
+            new_inputs.append(squeezed_value)
+            can_push_any_alloc = True
+
+        else:
+            # Nothing to do with this input other than removing dummy batch dims
+            new_inputs.append(_squeeze_left(inp, batch_ndim))
+
+    if not can_push_any_alloc:
+        return None
+
+    new_outs = node.op.make_node(*new_inputs).outputs
+
+    new_out_type = new_outs[0].type
+    old_out_type = node.outputs[0].type
+    if new_out_type.broadcastable != old_out_type.broadcastable:
+        # An Alloc is still needed to broadcast the new output to the original shape
+        # We pick the most parsimonious batch dim from the pushed Alloc
+        missing_ndim = old_out_type.ndim - new_out_type.ndim
+        batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
+        for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
+            for batch_dim in batch_dims:
+                if batch_dim == 1:
+                    continue
+                if isinstance(batch_dim, Constant):
+                    # Give preference to Constants
+                    batch_shape[i] = batch_dim
+                    break
+                elif old_out_type.broadcastable[i]:
+                    # Only use non-Constant shapes if absolutely necessary
+                    # Otherwise, we use the shape of the non-alloc output
+                    batch_shape[i] = batch_dim
+
+        copy_stack_trace(node.outputs, new_outs)
+        new_outs = [
+            alloc(
+                new_out,
+                *batch_shape,
+                *tuple(new_out.shape)[batch_ndim - missing_ndim :],
+            )
+            for new_out in new_outs
+        ]
+    assert new_outs[0].type.broadcastable == old_out_type.broadcastable
+    copy_stack_trace(node.outputs, new_outs)
+    return new_outs
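
For reference, a quick illustration of what the `_squeeze_left` helper does to static shapes (a sketch, assuming the private helper is imported from pytensor.tensor.rewriting.blockwise):

    from pytensor.tensor import tensor
    from pytensor.tensor.rewriting.blockwise import _squeeze_left

    x = tensor("x", shape=(1, 1, 5))
    assert _squeeze_left(x).type.shape == (5,)                   # all leading 1s dropped
    assert _squeeze_left(x, stop_at_dim=1).type.shape == (1, 5)  # at most one dim squeezed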

tests/tensor/rewriting/test_blockwise.py

Lines changed: 84 additions & 2 deletions
@@ -1,7 +1,10 @@
+from functools import partial
+
 from pytensor import function
-from pytensor.graph import FunctionGraph
+from pytensor.graph import FunctionGraph, rewrite_graph
+from pytensor.graph.basic import equal_computations
 from pytensor.scalar import log as scalar_log
-from pytensor.tensor import matrix, tensor3
+from pytensor.tensor import add, alloc, matrix, tensor, tensor3
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.nlinalg import MatrixPinv
@@ -36,3 +39,82 @@ def test_useless_unbatched_blockwise():
     fn = function([x], out, mode="FAST_COMPILE")
     assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
     assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
+
+
+def test_blockwise_alloc():
+    rewrite = partial(
+        rewrite_graph,
+        include=("ShapeOpt", "specialize"),
+        exclude=("local_useless_unbatched_blockwise",),
+    )
+
+    vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")
+
+    # Depending on the rewrites, the Alloc shape may or may not be upcast to int64.
+    # We do not care about that for the purposes of this test.
+    equal = partial(equal_computations, strict_dtype=False)
+
+    # Case where Alloc is not necessary
+    x = tensor("x", shape=(7, 5))
+    y = tensor("y", shape=(5,))
+    out = vector_add(x, alloc(y, 7, 5))
+    expected_out = vector_add(x, y)
+    assert equal([rewrite(out)], [expected_out])
+
+    # Cases where Alloc can be fully pushed
+    x = tensor("x", shape=(5,))
+    y = tensor("y", shape=(5,))
+    out = vector_add(x, alloc(y, 7, 5))
+    expected_out = alloc(vector_add(x, y), 7, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    x = tensor("x", shape=(1, 5))
+    y = tensor("y", shape=(5,))
+    out = vector_add(x, alloc(y, 7, 5))
+    expected_out = alloc(vector_add(x.squeeze(0), y), 7, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    x = tensor("x", shape=(7, 5))
+    y = tensor("y", shape=(7, 5))
+    out = vector_add(x, alloc(y, 3, 7, 5))
+    expected_out = alloc(vector_add(x, y), 3, 7, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    x = tensor("x", shape=(5,))
+    y = tensor("y", shape=(7, 1, 5))
+    out = vector_add(x, alloc(y, 7, 2, 5))
+    expected_out = alloc(vector_add(x, y), 7, 2, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    # Cases where Alloc can be partially pushed
+    x = tensor("x", shape=(5,))
+    y = tensor("y", shape=())
+    out = vector_add(x, alloc(y, 7, 5))
+    expected_out = alloc(vector_add(x, alloc(y, 5)), 7, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    x = tensor("x", shape=(5,))
+    y = tensor("y", shape=(7, 1, 1))
+    out = vector_add(x, alloc(y, 7, 2, 5))
+    expected_out = alloc(vector_add(x, alloc(y, 7, 1, 5)), 7, 2, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    # Cases involving multiple Allocs being pushed
+    x = tensor("x", shape=())
+    y = tensor("y", shape=())
+    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
+    expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    x = tensor("x", shape=(5,))
+    y = tensor("y", shape=())
+    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
+    expected_out = alloc(vector_add(x, alloc(y, 5)), 3, 7, 5)
+    assert equal([rewrite(out)], [expected_out])
+
+    # Case where Alloc cannot be pushed
+    x = tensor("x", shape=(5,))
+    y = tensor("y", shape=(1,))
+    out = vector_add(x, alloc(y, 5))
+    expected_out = out
+    assert equal([rewrite(out)], [expected_out])
