@@ -50,8 +50,11 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
50
50
raise AssertionError ()
51
51
52
52
53
- def OpKeyPatternNodeRewriter (p1 , p2 , ign = False ):
54
- return OpKeyGraphRewriter (PatternNodeRewriter (p1 , p2 ), ignore_newtrees = ign )
53
+ def OpKeyPatternNodeRewriter (p1 , p2 , allow_multiple_clients = False , ign = False ):
54
+ return OpKeyGraphRewriter (
55
+ PatternNodeRewriter (p1 , p2 , allow_multiple_clients = allow_multiple_clients ),
56
+ ignore_newtrees = ign ,
57
+ )
55
58
56
59
57
60
def WalkingPatternNodeRewriter (p1 , p2 , ign = True ):
@@ -207,13 +210,70 @@ def constraint(r):
207
210
assert str (g ) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
208
211
209
212
def test_allow_multiple_clients (self ):
210
- x , y , z = MyVariable ("x" ), MyVariable ("y" ), MyVariable ("z" )
211
- e0 = op1 (x , y )
212
- # `e0` has multiple clients (i.e. the `op4` and `op3` nodes)
213
- e = op3 (op4 (e0 ), e0 )
214
- g = FunctionGraph ([x , y , z ], [e ])
215
- OpKeyPatternNodeRewriter ((op4 , (op1 , "x" , "y" )), (op3 , "x" , "y" )).rewrite (g )
216
- assert str (g ) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))"
213
+ x , y , z = inputs = MyVariable ("x" ), MyVariable ("y" ), MyVariable ("z" )
214
+ w = op1 (x , y )
215
+ # `w` has multiple clients (i.e. the `op4` and `op3` nodes)
216
+ e = op3 (op4 (w ), w )
217
+
218
+ # By default, allow_multiple_clients is False
219
+ # So the replacement should fail
220
+ outputs = [e ]
221
+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
222
+ OpKeyPatternNodeRewriter (
223
+ (op4 , (op1 , "x" , "y" )),
224
+ (op3 , "x" , "y" ),
225
+ ).rewrite (g )
226
+ assert equal_computations (g .outputs , outputs )
227
+
228
+ # Now it should be fine
229
+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
230
+ OpKeyPatternNodeRewriter (
231
+ (op4 , (op1 , "x" , "y" )),
232
+ (op3 , "x" , "y" ),
233
+ allow_multiple_clients = True ,
234
+ ).rewrite (g )
235
+ assert equal_computations (g .outputs , [op3 (op3 (x , y ), w )])
236
+
237
+ # The fact that the inputs of the pattern have multiple clients should not matter
238
+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
239
+ OpKeyPatternNodeRewriter (
240
+ (op3 , (op4 , "w" ), "w" ),
241
+ (op3 , "w" , "w" ),
242
+ allow_multiple_clients = False ,
243
+ ).rewrite (g )
244
+ assert equal_computations (g .outputs , [op3 (w , w )])
245
+
246
+ # The fact that are multiple clients above the inputs of the pattern should not matter
247
+ v = op4 (e )
248
+ e1 = op4 (v )
249
+ e2 = op1 (x , x ) # Irrelevant reuse of x that should not block rewrite either
250
+ e3 = op1 (v , v ) # Relevant reuse of v that should block rewrite
251
+
252
+ outputs = [e1 , e2 ]
253
+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
254
+ OpKeyPatternNodeRewriter (
255
+ (op4 , (op4 , "e" )),
256
+ "e" ,
257
+ allow_multiple_clients = False ,
258
+ ).rewrite (g )
259
+ assert equal_computations (g .outputs , [e , e2 ])
260
+
261
+ outputs = [e1 , e3 ]
262
+ g = FunctionGraph ([x , y , z ], outputs , copy_inputs = False )
263
+ OpKeyPatternNodeRewriter (
264
+ (op4 , (op4 , "e" )),
265
+ "e" ,
266
+ allow_multiple_clients = False ,
267
+ ).rewrite (g )
268
+ assert equal_computations (g .outputs , outputs )
269
+
270
+ g = FunctionGraph (inputs , outputs , copy_inputs = False )
271
+ OpKeyPatternNodeRewriter (
272
+ (op4 , (op4 , "e" )),
273
+ "e" ,
274
+ allow_multiple_clients = True ,
275
+ ).rewrite (g )
276
+ assert equal_computations (g .outputs , [e , e3 ])
217
277
218
278
def test_eq (self ):
219
279
# replacing the whole graph
0 commit comments