Skip to content

Commit 6dfc811

Browse files
committed
Deprecate grad_overrides in OpFromGraph
1 parent ca8d60a commit 6dfc811

File tree

2 files changed

+34
-22
lines changed

2 files changed

+34
-22
lines changed

pytensor/compile/builders.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Define new Ops from existing Ops"""
22

3+
import warnings
34
from collections import OrderedDict
45
from collections.abc import Sequence
56
from copy import copy
@@ -189,7 +190,7 @@ class OpFromGraph(Op, HasInnerGraph):
189190
- For overriding, it's recommended to provide pure functions (no side
190191
effects like setting global variable) as callable(s). The callable(s)
191192
supplied for overriding gradient/rop will be called only once at the
192-
first call to grad/R_op, and will be converted to OpFromGraph instances.
193+
first call to L_op/R_op, and will be converted to OpFromGraph instances.
193194
194195
Examples
195196
--------
@@ -224,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
224225
e2 = op(x, y, z) + op(z, y, x)
225226
fn = function([x, y, z], [e2])
226227
227-
Example 3 override gradient
228+
Example 3 override L_op
228229
229230
.. code-block:: python
230231
@@ -233,12 +234,15 @@ class OpFromGraph(Op, HasInnerGraph):
233234
234235
x, y, z = pt.scalars('xyz')
235236
e = x + y * z
236-
def rescale_dy(inps, grads):
237+
def rescale_dy(inps, outputs, out_grads):
237238
x, y, z = inps
238-
g, = grads
239+
g, = out_grads
239240
return z*2
240241
op = OpFromGraph(
241-
[x, y, z], [e], grad_overrides=['default', rescale_dy, 'default']
242+
[x, y, z],
243+
[e],
244+
lop_overrides=['default', rescale_dy, 'default'],
245+
)
242246
e2 = op(x, y, z)
243247
dx, dy, dz = grad(e2, [x, y, z])
244248
fn = function([x, y, z], [dx, dy, dz])
@@ -455,6 +459,10 @@ def __init__(
455459
self.set_lop_overrides(lop_overrides)
456460
self._lop_type = "lop"
457461
elif grad_overrides != "default":
462+
warnings.warn(
463+
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
464+
FutureWarning,
465+
)
458466
self.set_lop_overrides(grad_overrides)
459467
self._lop_type = "grad"
460468
else:

tests/compile/test_builders.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,9 @@ def go(inps, gs):
181181
dedz = vector("dedz")
182182
op_mul_grad = cls_ofg([x, y, dedz], go([x, y], [dedz]))
183183

184-
op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
185-
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
184+
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
185+
op_mul = cls_ofg([x, y], [x * y], grad_overrides=go)
186+
op_mul2 = cls_ofg([x, y], [x * y], grad_overrides=op_mul_grad)
186187

187188
# single override case (function or OfG instance)
188189
xx, yy = vector("xx"), vector("yy")
@@ -209,9 +210,10 @@ def go2(inps, gs):
209210

210211
w, b = vectors("wb")
211212
# we make the 3rd gradient default (no override)
212-
op_linear = cls_ofg(
213-
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
214-
)
213+
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
214+
op_linear = cls_ofg(
215+
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
216+
)
215217
xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
216218
zz = pt_sum(op_linear(xx, ww, bb))
217219
dx, dw, db = grad(zz, [xx, ww, bb])
@@ -225,11 +227,12 @@ def go2(inps, gs):
225227
np.testing.assert_array_almost_equal(np.ones(16, dtype=config.floatX), dbv, 4)
226228

227229
# NullType and DisconnectedType
228-
op_linear2 = cls_ofg(
229-
[x, w, b],
230-
[x * w + b],
231-
grad_overrides=[go1, NullType()(), DisconnectedType()()],
232-
)
230+
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
231+
op_linear2 = cls_ofg(
232+
[x, w, b],
233+
[x * w + b],
234+
grad_overrides=[go1, NullType()(), DisconnectedType()()],
235+
)
233236
zz2 = pt_sum(op_linear2(xx, ww, bb))
234237
dx2, dw2, db2 = grad(
235238
zz2,
@@ -339,13 +342,14 @@ def f1(x, y):
339342
def f1_back(inputs, output_gradients):
340343
return [output_gradients[0], disconnected_type()]
341344

342-
op = cls_ofg(
343-
inputs=[x, y],
344-
outputs=[f1(x, y)],
345-
grad_overrides=f1_back,
346-
connection_pattern=[[True], [False]], # This is new
347-
on_unused_input="ignore",
348-
) # This is new
345+
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
346+
op = cls_ofg(
347+
inputs=[x, y],
348+
outputs=[f1(x, y)],
349+
grad_overrides=f1_back,
350+
connection_pattern=[[True], [False]],
351+
on_unused_input="ignore",
352+
)
349353

350354
c = op(x, y)
351355

0 commit comments

Comments
 (0)