Skip to content

Commit ec45e25

Browse files
committed
Do not reject PatternNodeRewriter due unrelated multiple clients
1 parent 2143d85 commit ec45e25

File tree

2 files changed

+78
-17
lines changed

2 files changed

+78
-17
lines changed

pytensor/graph/rewriting/basic.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,14 +1616,6 @@ def transform(self, fgraph, node, get_nodes=True):
16161616
from etuples.core import ExpressionTuple
16171617
from unification import reify, unify
16181618

1619-
# TODO: We shouldn't need to iterate like this.
1620-
if not self.allow_multiple_clients and any(
1621-
len(fgraph.clients.get(v)) > 1
1622-
for v in vars_between(fgraph.inputs, node.outputs)
1623-
if v not in fgraph.inputs
1624-
):
1625-
return False
1626-
16271619
if get_nodes and self.get_nodes is not None:
16281620
for real_node in self.get_nodes(fgraph, node):
16291621
if real_node == "output":
@@ -1648,6 +1640,15 @@ def transform(self, fgraph, node, get_nodes=True):
16481640
if self.values_eq_approx:
16491641
ret.tag.values_eq_approx = self.values_eq_approx
16501642

1643+
if not self.allow_multiple_clients:
1644+
input_vars = list(s.values())
1645+
if any(
1646+
len(fgraph.clients[v]) > 1
1647+
for v in vars_between(input_vars, node.inputs)
1648+
if v not in input_vars
1649+
):
1650+
return False
1651+
16511652
if ret.owner:
16521653
if not (
16531654
len(node.outputs) == len(ret.owner.outputs)

tests/graph/rewriting/test_basic.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
5050
raise AssertionError()
5151

5252

53-
def OpKeyPatternNodeRewriter(p1, p2, ign=False):
54-
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
53+
def OpKeyPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False):
54+
return OpKeyGraphRewriter(
55+
PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients),
56+
ignore_newtrees=ign,
57+
)
5558

5659

5760
def WalkingPatternNodeRewriter(p1, p2, ign=True):
@@ -207,13 +210,70 @@ def constraint(r):
207210
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
208211

209212
def test_allow_multiple_clients(self):
210-
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
211-
e0 = op1(x, y)
212-
# `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
213-
e = op3(op4(e0), e0)
214-
g = FunctionGraph([x, y, z], [e])
215-
OpKeyPatternNodeRewriter((op4, (op1, "x", "y")), (op3, "x", "y")).rewrite(g)
216-
assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
213+
x, y, z = inputs = MyVariable("x"), MyVariable("y"), MyVariable("z")
214+
w = op1(x, y)
215+
# `w` has multiple clients (i.e. the `op4` and `op3` nodes)
216+
e = op3(op4(w), w)
217+
218+
# By default, allow_multiple_clients is False
219+
# So the replacement should fail
220+
outputs = [e]
221+
g = FunctionGraph(inputs, outputs, copy_inputs=False)
222+
OpKeyPatternNodeRewriter(
223+
(op4, (op1, "x", "y")),
224+
(op3, "x", "y"),
225+
).rewrite(g)
226+
assert equal_computations(g.outputs, outputs)
227+
228+
# Now it should be fine
229+
g = FunctionGraph(inputs, outputs, copy_inputs=False)
230+
OpKeyPatternNodeRewriter(
231+
(op4, (op1, "x", "y")),
232+
(op3, "x", "y"),
233+
allow_multiple_clients=True,
234+
).rewrite(g)
235+
assert equal_computations(g.outputs, [op3(op3(x, y), w)])
236+
237+
# The fact that the inputs of the pattern have multiple clients should not matter
238+
g = FunctionGraph(inputs, outputs, copy_inputs=False)
239+
OpKeyPatternNodeRewriter(
240+
(op3, (op4, "w"), "w"),
241+
(op3, "w", "w"),
242+
allow_multiple_clients=False,
243+
).rewrite(g)
244+
assert equal_computations(g.outputs, [op3(w, w)])
245+
246+
# The fact that are multiple clients above the inputs of the pattern should not matter
247+
v = op4(e)
248+
e1 = op4(v)
249+
e2 = op1(x, x) # Irrelevant reuse of x that should not block rewrite either
250+
e3 = op1(v, v) # Relevant reuse of v that should block rewrite
251+
252+
outputs = [e1, e2]
253+
g = FunctionGraph(inputs, outputs, copy_inputs=False)
254+
OpKeyPatternNodeRewriter(
255+
(op4, (op4, "e")),
256+
"e",
257+
allow_multiple_clients=False,
258+
).rewrite(g)
259+
assert equal_computations(g.outputs, [e, e2])
260+
261+
outputs = [e1, e3]
262+
g = FunctionGraph([x, y, z], outputs, copy_inputs=False)
263+
OpKeyPatternNodeRewriter(
264+
(op4, (op4, "e")),
265+
"e",
266+
allow_multiple_clients=False,
267+
).rewrite(g)
268+
assert equal_computations(g.outputs, outputs)
269+
270+
g = FunctionGraph(inputs, outputs, copy_inputs=False)
271+
OpKeyPatternNodeRewriter(
272+
(op4, (op4, "e")),
273+
"e",
274+
allow_multiple_clients=True,
275+
).rewrite(g)
276+
assert equal_computations(g.outputs, [e, e3])
217277

218278
def test_eq(self):
219279
# replacing the whole graph

0 commit comments

Comments
 (0)