Skip to content

Commit befc177

Browse files
ferrinericardoV94
authored andcommitted
add graph_replace function
1 parent 1b67356 commit befc177

File tree

4 files changed

+174
-4
lines changed

4 files changed

+174
-4
lines changed

pytensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler)
7474

7575
# isort: off
7676
from pytensor.graph.basic import Variable
77-
from pytensor.graph.replace import clone_replace
77+
from pytensor.graph.replace import clone_replace, graph_replace
7878

7979
# isort: on
8080

pytensor/graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
clone,
1010
ancestors,
1111
)
12-
from pytensor.graph.replace import clone_replace
12+
from pytensor.graph.replace import clone_replace, graph_replace
1313
from pytensor.graph.op import Op
1414
from pytensor.graph.type import Type
1515
from pytensor.graph.fg import FunctionGraph

pytensor/graph/replace.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
from typing import (
23
Collection,
34
Dict,
@@ -10,7 +11,8 @@
1011
cast,
1112
)
1213

13-
from pytensor.graph.basic import Constant, Variable
14+
from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs
15+
from pytensor.graph.fg import FunctionGraph
1416

1517

1618
def clone_replace(
@@ -58,3 +60,92 @@ def clone_replace(
5860
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
5961

6062
return cast(List[Variable], outs)
63+
64+
65+
def graph_replace(
66+
outputs: Sequence[Variable],
67+
replace: Dict[Variable, Variable],
68+
*,
69+
strict=True,
70+
) -> List[Variable]:
71+
"""Replace variables in ``outputs`` by ``replace``.
72+
73+
Parameters
74+
----------
75+
outputs: Sequence[Variable]
76+
Output graph
77+
replace: Dict[Variable, Variable]
78+
Replace mapping
79+
strict: bool
80+
Raise an error if some replacements were not used
81+
return_unused: bool
82+
Return replacements that were not used
83+
84+
Returns
85+
-------
86+
List[Variable]
87+
Output graph with subgraphs replaced
88+
89+
Raises
90+
------
91+
ValueError
92+
If some replacemens could not be applied and strict is True
93+
"""
94+
# collect minimum graph inputs which is required to compute outputs
95+
# and depend on replacements
96+
# additionally remove constants, they do not matter in clone get equiv
97+
conditions = [
98+
c
99+
for c in truncated_graph_inputs(outputs, replace)
100+
if not isinstance(c, Constant)
101+
]
102+
# for the function graph we need the clean graph where
103+
# inputs do not have owners
104+
# this is exactly the reason to clone conditions
105+
equiv = {c: c.clone(name=f"i-{i}") for i, c in enumerate(conditions)}
106+
# some replace keys may dissapear
107+
# the reason is they are outside the graph
108+
# clone the graph but preserve the equiv mapping
109+
fg = FunctionGraph(
110+
conditions,
111+
outputs,
112+
# clone_get_equiv kwargs
113+
copy_orphans=False,
114+
copy_inputs=False,
115+
memo=equiv,
116+
)
117+
# replace the conditions back
118+
fg_replace = {equiv[c]: c for c in conditions}
119+
# add the replacements on top of input mappings
120+
fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv})
121+
# replacements have to be done in reverse topological order so that nested
122+
# expressions get recursively replaced correctly
123+
124+
# some replacements may be initially outside the graph
125+
# but later introduced by a replacement
126+
# So far FunctionGraph does these replacements inplace it is thus unsafe
127+
# apply them using fg.replace, it may change the original graph
128+
if strict:
129+
non_fg_replace = {r: v for r, v in replace.items() if r not in equiv}
130+
if non_fg_replace:
131+
raise ValueError(f"Some replacements were not used: {non_fg_replace}")
132+
toposort = fg.toposort()
133+
134+
def toposort_key(fg: FunctionGraph, ts, pair):
135+
key, _ = pair
136+
if key.owner is not None:
137+
return ts.index(key.owner)
138+
else:
139+
if key in fg.variables:
140+
return -1
141+
else:
142+
raise ValueError(f"{key} is not a part of graph")
143+
144+
sorted_replacements = sorted(
145+
tuple(fg_replace.items()),
146+
# sort based on the fg toposort, if a variable has no owner, it goes first
147+
key=partial(toposort_key, fg, toposort),
148+
reverse=True,
149+
)
150+
fg.replace_all(sorted_replacements, import_missing=True)
151+
return list(fg.outputs)

tests/graph/test_replace.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytensor.tensor as pt
55
from pytensor import config, function, shared
66
from pytensor.graph.basic import graph_inputs
7-
from pytensor.graph.replace import clone_replace
7+
from pytensor.graph.replace import clone_replace, graph_replace
88
from pytensor.tensor import dvector, fvector, vector
99
from tests import unittest_tools as utt
1010
from tests.graph.utils import MyOp, MyVariable
@@ -133,3 +133,82 @@ def test(x, y, mention_y):
133133
utt.assert_allclose(
134134
test(x, pt.sum((x + 1) ** 2), mention_y=True), 1.21000003815
135135
)
136+
137+
138+
class TestGraphReplace:
139+
def test_graph_replace(self):
140+
x = MyVariable("x")
141+
y = MyVariable("y")
142+
z = MyVariable("z")
143+
w = MyVariable("w")
144+
MyOp("zop")(z)
145+
x2 = MyOp("xop")(x, w)
146+
x2.name = "x2"
147+
y2 = MyOp("yop")(y)
148+
y2.name = "y2"
149+
150+
yc = graph_replace([x2], {x: y2})[0]
151+
assert yc.owner.inputs[0] is y2
152+
# the old reference is kept
153+
assert yc.owner.inputs[1] is w
154+
155+
# test replace itself
156+
yc = graph_replace([x2], {x2: y2})[0]
157+
assert yc is y2
158+
assert yc.owner.inputs[0] is y
159+
assert len(yc.owner.inputs) == 1
160+
161+
# the case where inputs have to be replaced in reverse topological order
162+
o = MyOp("xyop")(x2, y2)
163+
new_x = x.clone(name="x_new")
164+
new_y2 = y2.clone(name="y2_new")
165+
166+
oc = graph_replace([o], {x: new_x, y2: new_y2})[0]
167+
assert oc.owner.inputs[1] is new_y2
168+
assert oc.owner.inputs[0].owner.inputs[0] is new_x
169+
# the old reference is still kept
170+
assert oc.owner.inputs[0].owner.inputs[1] is w
171+
172+
def test_graph_replace_advanced(self):
173+
x = MyVariable("x")
174+
y = MyVariable("y")
175+
z = MyVariable("z")
176+
w = MyVariable("w")
177+
z2 = MyOp("zop")(z)
178+
x2 = MyOp("xop")(x, w)
179+
x2.name = "x2"
180+
y2 = MyOp("yop")(y)
181+
y2.name = "y2"
182+
o = MyOp("xyop")(x2, y2)
183+
new_x = x.clone(name="x_new")
184+
new_y2 = y2.clone(name="y2_new")
185+
new_y21 = MyOp("ny2op")(new_y2)
186+
# now yet another replacement that could only appear after new_y2: z
187+
# show we can do that after the prev clone
188+
# the case where new variable is referenced during the replacements
189+
new_y21 = MyOp("ny2op")(new_y2)
190+
# the reference new_y2: z2 is not a part of the original graph so the replacement is unsafe
191+
oc = graph_replace([o], {x: new_x, y2: new_y21})
192+
oc = graph_replace(oc, {new_y2: z2})[0]
193+
assert oc.owner.inputs[1].owner.inputs[0] is z2
194+
assert oc.owner.inputs[0].owner.inputs[0] is new_x
195+
# the old reference is still kept
196+
assert oc.owner.inputs[0].owner.inputs[1] is w
197+
198+
new_z = z.clone(name="z_new")
199+
oc = graph_replace([oc], {z: new_z})[0]
200+
# new reference appear
201+
assert oc.owner.inputs[1].owner.inputs[0] is not z2
202+
assert oc.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
203+
# the old reference is still kept
204+
assert oc.owner.inputs[0].owner.inputs[0] is new_x
205+
assert oc.owner.inputs[0].owner.inputs[1] is w
206+
207+
def test_graph_replace_disconnected(self):
208+
x = MyVariable("x")
209+
fake = MyOp("fake")(x)
210+
o = MyOp("o")(x)
211+
oc = graph_replace([o], {fake: x.clone()}, strict=False)
212+
assert oc[0] is o
213+
with pytest.raises(ValueError, match="Some replacements were not used"):
214+
oc = graph_replace([o], {fake: x.clone()}, strict=True)

0 commit comments

Comments
 (0)