@@ -181,8 +181,9 @@ def go(inps, gs):
181
181
dedz = vector ("dedz" )
182
182
op_mul_grad = cls_ofg ([x , y , dedz ], go ([x , y ], [dedz ]))
183
183
184
- op_mul = cls_ofg ([x , y ], [x * y ], grad_overrides = go )
185
- op_mul2 = cls_ofg ([x , y ], [x * y ], grad_overrides = op_mul_grad )
184
+ with pytest .warns (FutureWarning , match = "grad_overrides is deprecated" ):
185
+ op_mul = cls_ofg ([x , y ], [x * y ], grad_overrides = go )
186
+ op_mul2 = cls_ofg ([x , y ], [x * y ], grad_overrides = op_mul_grad )
186
187
187
188
# single override case (function or OfG instance)
188
189
xx , yy = vector ("xx" ), vector ("yy" )
@@ -209,9 +210,10 @@ def go2(inps, gs):
209
210
210
211
w , b = vectors ("wb" )
211
212
# we make the 3rd gradient default (no override)
212
- op_linear = cls_ofg (
213
- [x , w , b ], [x * w + b ], grad_overrides = [go1 , go2 , "default" ]
214
- )
213
+ with pytest .warns (FutureWarning , match = "grad_overrides is deprecated" ):
214
+ op_linear = cls_ofg (
215
+ [x , w , b ], [x * w + b ], grad_overrides = [go1 , go2 , "default" ]
216
+ )
215
217
xx , ww , bb = vector ("xx" ), vector ("yy" ), vector ("bb" )
216
218
zz = pt_sum (op_linear (xx , ww , bb ))
217
219
dx , dw , db = grad (zz , [xx , ww , bb ])
@@ -225,11 +227,12 @@ def go2(inps, gs):
225
227
np .testing .assert_array_almost_equal (np .ones (16 , dtype = config .floatX ), dbv , 4 )
226
228
227
229
# NullType and DisconnectedType
228
- op_linear2 = cls_ofg (
229
- [x , w , b ],
230
- [x * w + b ],
231
- grad_overrides = [go1 , NullType ()(), DisconnectedType ()()],
232
- )
230
+ with pytest .warns (FutureWarning , match = "grad_overrides is deprecated" ):
231
+ op_linear2 = cls_ofg (
232
+ [x , w , b ],
233
+ [x * w + b ],
234
+ grad_overrides = [go1 , NullType ()(), DisconnectedType ()()],
235
+ )
233
236
zz2 = pt_sum (op_linear2 (xx , ww , bb ))
234
237
dx2 , dw2 , db2 = grad (
235
238
zz2 ,
@@ -339,13 +342,14 @@ def f1(x, y):
339
342
def f1_back (inputs , output_gradients ):
340
343
return [output_gradients [0 ], disconnected_type ()]
341
344
342
- op = cls_ofg (
343
- inputs = [x , y ],
344
- outputs = [f1 (x , y )],
345
- grad_overrides = f1_back ,
346
- connection_pattern = [[True ], [False ]], # This is new
347
- on_unused_input = "ignore" ,
348
- ) # This is new
345
+ with pytest .warns (FutureWarning , match = "grad_overrides is deprecated" ):
346
+ op = cls_ofg (
347
+ inputs = [x , y ],
348
+ outputs = [f1 (x , y )],
349
+ grad_overrides = f1_back ,
350
+ connection_pattern = [[True ], [False ]],
351
+ on_unused_input = "ignore" ,
352
+ )
349
353
350
354
c = op (x , y )
351
355
0 commit comments