Skip to content

Commit ab304cb

Browse files
committed
Deprecate use of "default" and Variable as OpFromGraph overrides
1 parent 6dfc811 commit ab304cb

File tree

2 files changed

+52
-50
lines changed

2 files changed

+52
-50
lines changed

pytensor/compile/builders.py

Lines changed: 36 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import warnings
44
from collections import OrderedDict
5-
from collections.abc import Sequence
5+
from collections.abc import Callable, Sequence
66
from copy import copy
77
from functools import partial
8-
from typing import cast
8+
from typing import Union, cast
99

1010
import pytensor.tensor as pt
1111
from pytensor.compile.function import function
@@ -225,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
225225
e2 = op(x, y, z) + op(z, y, x)
226226
fn = function([x, y, z], [e2])
227227
228-
Example 3 override L_op
228+
Example 3 override second output of L_op
229229
230230
.. code-block:: python
231231
@@ -241,7 +241,7 @@ def rescale_dy(inps, outputs, out_grads):
241241
op = OpFromGraph(
242242
[x, y, z],
243243
[e],
244-
lop_overrides=['default', rescale_dy, 'default'],
244+
lop_overrides=[None, rescale_dy, None],
245245
)
246246
e2 = op(x, y, z)
247247
dx, dy, dz = grad(e2, [x, y, z])
@@ -253,7 +253,7 @@ def rescale_dy(inps, outputs, out_grads):
253253

254254
TYPE_ERR_MSG = (
255255
"L_op/gradient override should be (single or list of)"
256-
"'default' | OpFromGraph | callable | Variable "
256+
"None | OpFromGraph | callable | Variable "
257257
"with NullType or DisconnectedType, got %s"
258258
)
259259
STYPE_ERR_MSG = (
@@ -308,9 +308,9 @@ def __init__(
308308
outputs: list[Variable],
309309
*,
310310
inline: bool = False,
311-
lop_overrides: str = "default",
312-
grad_overrides: str = "default",
313-
rop_overrides: str = "default",
311+
lop_overrides: Union[Callable, "OpFromGraph", None] = None,
312+
grad_overrides: Union[Callable, "OpFromGraph", None] = None,
313+
rop_overrides: Union[Callable, "OpFromGraph", None] = None,
314314
connection_pattern: list[list[bool]] | None = None,
315315
strict: bool = False,
316316
name: str | None = None,
@@ -333,10 +333,10 @@ def __init__(
333333
334334
``False`` : will use a pre-compiled function inside.
335335
grad_overrides
336-
Defaults to ``'default'``.
336+
Defaults to ``None``.
337337
This argument is mutually exclusive with ``lop_overrides``.
338338
339-
``'default'`` : Do not override, use default grad() result
339+
``None`` : Do not override, use default grad() result
340340
341341
`OpFromGraph`: Override with another `OpFromGraph`, should
342342
accept inputs as the same order and types of ``inputs`` and ``output_grads``
@@ -346,14 +346,14 @@ def __init__(
346346
Each argument is expected to be a list of :class:`Variable `.
347347
Must return list of :class:`Variable `.
348348
lop_overrides
349-
Defaults to ``'default'``.
349+
Defaults to ``None``.
350350
351351
This argument is mutually exclusive with ``grad_overrides``.
352352
353353
These options are similar to the ``grad_overrides`` above, but for
354354
the :meth:`Op.L_op` method.
355355
356-
``'default'``: Do not override, use the default :meth:`Op.L_op` result
356+
``None``: Do not override, use the default :meth:`Op.L_op` result
357357
358358
`OpFromGraph`: Override with another `OpFromGraph`, should
359359
accept inputs as the same order and types of ``inputs``,
@@ -373,11 +373,11 @@ def __init__(
373373
a specific input, length of list must be equal to number of inputs.
374374
375375
rop_overrides
376-
One of ``{'default', OpFromGraph, callable, Variable}``.
376+
One of ``{None, OpFromGraph, callable, Variable}``.
377377
378-
Defaults to ``'default'``.
378+
Defaults to ``None``.
379379
380-
``'default'``: Do not override, use the default :meth:`Op.R_op` result
380+
``None``: Do not override, use the default :meth:`Op.R_op` result
381381
382382
`OpFromGraph`: Override with another `OpFromGraph`, should
383383
accept inputs as the same order and types of ``inputs`` and ``eval_points``
@@ -446,27 +446,37 @@ def __init__(
446446
self.input_types = [inp.type for inp in inputs]
447447
self.output_types = [out.type for out in outputs]
448448

449+
for override in (lop_overrides, grad_overrides, rop_overrides):
450+
if override == "default":
451+
raise ValueError(
452+
"'default' is no longer a valid value for overrides. Use None instead."
453+
)
454+
if isinstance(override, Variable):
455+
raise TypeError(
456+
"Variables are no longer valid types for overrides. Return them in a list for each output instead"
457+
)
458+
449459
self.lop_overrides = lop_overrides
450460
self.grad_overrides = grad_overrides
451461
self.rop_overrides = rop_overrides
452462

453-
if lop_overrides != "default":
454-
if grad_overrides != "default":
463+
if lop_overrides is not None:
464+
if grad_overrides is not None:
455465
raise ValueError(
456466
"lop_overrides and grad_overrides are mutually exclusive"
457467
)
458468
else:
459469
self.set_lop_overrides(lop_overrides)
460470
self._lop_type = "lop"
461-
elif grad_overrides != "default":
471+
elif grad_overrides is not None:
462472
warnings.warn(
463473
"grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
464474
FutureWarning,
465475
)
466476
self.set_lop_overrides(grad_overrides)
467477
self._lop_type = "grad"
468478
else:
469-
self.set_lop_overrides("default")
479+
self.set_lop_overrides(None)
470480
self._lop_type = "lop"
471481

472482
self.set_rop_overrides(rop_overrides)
@@ -546,7 +556,7 @@ def lop_op(inps, grads):
546556
callable_args = (local_inputs, output_grads)
547557

548558
# we need to convert _lop_op into an OfG instance
549-
if lop_op == "default":
559+
if lop_op is None:
550560
gdefaults_l = fn_grad(wrt=local_inputs)
551561
all_grads_l, all_grads_ov_l = zip(
552562
*[
@@ -556,12 +566,6 @@ def lop_op(inps, grads):
556566
)
557567
all_grads_l = list(all_grads_l)
558568
all_grads_ov_l = list(all_grads_ov_l)
559-
elif isinstance(lop_op, Variable):
560-
if isinstance(lop_op.type, DisconnectedType | NullType):
561-
all_grads_l = [inp.zeros_like() for inp in local_inputs]
562-
all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
563-
else:
564-
raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
565569
elif isinstance(lop_op, list):
566570
goverrides_l = lop_op
567571
if len(goverrides_l) != inp_len:
@@ -571,15 +575,13 @@ def lop_op(inps, grads):
571575
)
572576
# compute non-overriding downstream grads from upstream grads
573577
# it's normal some input may be disconnected, thus the 'ignore'
574-
wrt_l = [
575-
lin for lin, gov in zip(local_inputs, goverrides_l) if gov == "default"
576-
]
578+
wrt_l = [lin for lin, gov in zip(local_inputs, goverrides_l) if gov is None]
577579
gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
578580
# combine overriding gradients
579581
all_grads_l = []
580582
all_grads_ov_l = []
581583
for inp, fn_gov in zip(local_inputs, goverrides_l):
582-
if fn_gov == "default":
584+
if fn_gov is None:
583585
gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
584586
all_grads_l.append(gnext)
585587
all_grads_ov_l.append(gnext_ov)
@@ -652,13 +654,13 @@ def _recompute_rop_op(self):
652654
fn_rop = partial(Rop, wrt=local_inputs, eval_points=eval_points)
653655
TYPE_ERR_MSG = (
654656
"R_op overrides should be (single or list of)"
655-
"OpFromGraph | 'default' | None | 0 | callable, got %s"
657+
"OpFromGraph, None, a list or a callable, got %s"
656658
)
657659
STYPE_ERR_MSG = (
658660
"Overriding Variable instance can only have type"
659661
" of DisconnectedType or NullType, got %s"
660662
)
661-
if rop_op == "default":
663+
if rop_op is None:
662664
rdefaults_l = fn_rop(f=local_outputs)
663665
all_rops_l, all_rops_ov_l = zip(
664666
*[
@@ -668,15 +670,6 @@ def _recompute_rop_op(self):
668670
)
669671
all_rops_l = list(all_rops_l)
670672
all_rops_ov_l = list(all_rops_ov_l)
671-
elif isinstance(rop_op, Variable):
672-
if isinstance(rop_op.type, NullType):
673-
all_rops_l = [inp.zeros_like() for inp in local_inputs]
674-
all_rops_ov_l = [rop_op.type() for _ in range(out_len)]
675-
elif isinstance(rop_op.type, DisconnectedType):
676-
all_rops_l = [inp.zeros_like() for inp in local_inputs]
677-
all_rops_ov_l = [None] * out_len
678-
else:
679-
raise ValueError(STYPE_ERR_MSG % rop_op.type)
680673
elif isinstance(rop_op, list):
681674
roverrides_l = rop_op
682675
if len(roverrides_l) != out_len:
@@ -686,15 +679,15 @@ def _recompute_rop_op(self):
686679
)
687680
# get outputs that do not have Rop override
688681
odefaults_l = [
689-
lo for lo, rov in zip(local_outputs, roverrides_l) if rov == "default"
682+
lo for lo, rov in zip(local_outputs, roverrides_l) if rov is None
690683
]
691684
rdefaults_l = fn_rop(f=odefaults_l)
692685
rdefaults = iter(rdefaults_l if odefaults_l else [])
693686
# combine overriding Rops
694687
all_rops_l = []
695688
all_rops_ov_l = []
696689
for out, fn_rov in zip(local_outputs, roverrides_l):
697-
if fn_rov == "default":
690+
if fn_rov is None:
698691
rnext, rnext_ov = OpFromGraph._filter_rop_var(next(rdefaults), out)
699692
all_rops_l.append(rnext)
700693
all_rops_ov_l.append(rnext_ov)
@@ -769,7 +762,6 @@ def set_grad_overrides(self, grad_overrides):
769762
self._lop_op = grad_overrides
770763
self._lop_op_is_cached = False
771764
self._lop_type = "grad"
772-
self._lop_is_default = grad_overrides == "default"
773765

774766
def set_lop_overrides(self, lop_overrides):
775767
"""
@@ -780,7 +772,6 @@ def set_lop_overrides(self, lop_overrides):
780772
self._lop_op = lop_overrides
781773
self._lop_op_is_cached = False
782774
self._lop_type = "lop"
783-
self._lop_is_default = lop_overrides == "default"
784775

785776
def set_rop_overrides(self, rop_overrides):
786777
"""
@@ -790,7 +781,6 @@ def set_rop_overrides(self, rop_overrides):
790781
"""
791782
self._rop_op = rop_overrides
792783
self._rop_op_is_cached = False
793-
self._rop_is_default = rop_overrides == "default"
794784

795785
def L_op(self, inputs, outputs, output_grads):
796786
if not self._lop_op_is_cached:

tests/compile/test_builders.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
1212
from pytensor.graph.basic import equal_computations
1313
from pytensor.graph.fg import FunctionGraph
14-
from pytensor.graph.null_type import NullType
14+
from pytensor.graph.null_type import NullType, null_type
1515
from pytensor.graph.rewriting.utils import rewrite_graph
1616
from pytensor.graph.utils import MissingInputError
1717
from pytensor.printing import debugprint
@@ -93,6 +93,20 @@ def test_size_changes(self, cls_ofg):
9393
assert res.shape == (2, 5)
9494
assert np.all(180.0 == res)
9595

96+
def test_overrides_deprecated_api(self):
97+
inp = scalar("x")
98+
out = inp + 1
99+
for kwarg in ("lop_overrides", "grad_overrides", "rop_overrides"):
100+
with pytest.raises(
101+
ValueError, match="'default' is no longer a valid value for overrides"
102+
):
103+
OpFromGraph([inp], [out], **{kwarg: "default"})
104+
105+
with pytest.raises(
106+
TypeError, match="Variables are no longer valid types for overrides"
107+
):
108+
OpFromGraph([inp], [out], **{kwarg: null_type()})
109+
96110
@pytest.mark.parametrize(
97111
"cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)]
98112
)
@@ -211,9 +225,7 @@ def go2(inps, gs):
211225
w, b = vectors("wb")
212226
# we make the 3rd gradient default (no override)
213227
with pytest.warns(FutureWarning, match="grad_overrides is deprecated"):
214-
op_linear = cls_ofg(
215-
[x, w, b], [x * w + b], grad_overrides=[go1, go2, "default"]
216-
)
228+
op_linear = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, go2, None])
217229
xx, ww, bb = vector("xx"), vector("yy"), vector("bb")
218230
zz = pt_sum(op_linear(xx, ww, bb))
219231
dx, dw, db = grad(zz, [xx, ww, bb])

0 commit comments

Comments
 (0)