@@ -2,10 +2,10 @@
 
 import warnings
 from collections import OrderedDict
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from copy import copy
 from functools import partial
-from typing import cast
+from typing import Union, cast
 
 import pytensor.tensor as pt
 from pytensor.compile.function import function
@@ -225,7 +225,7 @@ class OpFromGraph(Op, HasInnerGraph):
         e2 = op(x, y, z) + op(z, y, x)
         fn = function([x, y, z], [e2])
 
-    Example 3 override L_op
+    Example 3 override second output of L_op
 
     .. code-block:: python
 
@@ -241,7 +241,7 @@ def rescale_dy(inps, outputs, out_grads):
         op = OpFromGraph(
             [x, y, z],
             [e],
-            lop_overrides=['default', rescale_dy, 'default'],
+            lop_overrides=[None, rescale_dy, None],
         )
         e2 = op(x, y, z)
         dx, dy, dz = grad(e2, [x, y, z])
@@ -253,7 +253,7 @@ def rescale_dy(inps, outputs, out_grads):
 
     TYPE_ERR_MSG = (
         "L_op/gradient override should be (single or list of)"
-        "'default' | OpFromGraph | callable | Variable "
+        "None | OpFromGraph | callable | Variable "
         "with NullType or DisconnectedType, got %s"
     )
     STYPE_ERR_MSG = (
@@ -308,9 +308,9 @@ def __init__(
         outputs: list[Variable],
         *,
         inline: bool = False,
-        lop_overrides: str = "default",
-        grad_overrides: str = "default",
-        rop_overrides: str = "default",
+        lop_overrides: Union[Callable, "OpFromGraph", None] = None,
+        grad_overrides: Union[Callable, "OpFromGraph", None] = None,
+        rop_overrides: Union[Callable, "OpFromGraph", None] = None,
         connection_pattern: list[list[bool]] | None = None,
         strict: bool = False,
         name: str | None = None,
@@ -333,10 +333,10 @@ def __init__(
 
             ``False``: will use a pre-compiled function inside.
         grad_overrides
-            Defaults to ``'default'``.
+            Defaults to ``None``.
             This argument is mutually exclusive with ``lop_overrides``.
 
-            ``'default'``: Do not override, use default grad() result
+            ``None``: Do not override, use default grad() result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
             accept inputs as the same order and types of ``inputs`` and ``output_grads``
@@ -346,14 +346,14 @@ def __init__(
             Each argument is expected to be a list of :class:`Variable`.
             Must return list of :class:`Variable`.
         lop_overrides
-            Defaults to ``'default'``.
+            Defaults to ``None``.
 
             This argument is mutually exclusive with ``grad_overrides``.
 
             These options are similar to the ``grad_overrides`` above, but for
             the :meth:`Op.L_op` method.
 
-            ``'default'``: Do not override, use the default :meth:`Op.L_op` result
+            ``None``: Do not override, use the default :meth:`Op.L_op` result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
             accept inputs as the same order and types of ``inputs``,
@@ -373,11 +373,11 @@ def __init__(
             a specific input, length of list must be equal to number of inputs.
 
         rop_overrides
-            One of ``{'default', OpFromGraph, callable, Variable}``.
+            One of ``{None, OpFromGraph, callable, Variable}``.
 
-            Defaults to ``'default'``.
+            Defaults to ``None``.
 
-            ``'default'``: Do not override, use the default :meth:`Op.R_op` result
+            ``None``: Do not override, use the default :meth:`Op.R_op` result
 
             `OpFromGraph`: Override with another `OpFromGraph`, should
             accept inputs as the same order and types of ``inputs`` and ``eval_points``
@@ -446,27 +446,37 @@ def __init__(
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
 
+        for override in (lop_overrides, grad_overrides, rop_overrides):
+            if override == "default":
+                raise ValueError(
+                    "'default' is no longer a valid value for overrides. Use None instead."
+                )
+            if isinstance(override, Variable):
+                raise TypeError(
+                    "Variables are no longer valid types for overrides. Return them in a list for each output instead"
+                )
+
         self.lop_overrides = lop_overrides
         self.grad_overrides = grad_overrides
         self.rop_overrides = rop_overrides
 
-        if lop_overrides != "default":
-            if grad_overrides != "default":
+        if lop_overrides is not None:
+            if grad_overrides is not None:
                 raise ValueError(
                     "lop_overrides and grad_overrides are mutually exclusive"
                 )
             else:
                 self.set_lop_overrides(lop_overrides)
                 self._lop_type = "lop"
-        elif grad_overrides != "default":
+        elif grad_overrides is not None:
             warnings.warn(
                 "grad_overrides is deprecated in favor of lop_overrides. Using it will lead to an error in the future.",
                 FutureWarning,
             )
             self.set_lop_overrides(grad_overrides)
             self._lop_type = "grad"
         else:
-            self.set_lop_overrides("default")
+            self.set_lop_overrides(None)
             self._lop_type = "lop"
 
         self.set_rop_overrides(rop_overrides)
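The validation block added above makes the two retired override forms fail fast in `__init__` instead of being silently treated as callable overrides. A minimal sketch of what now raises, reusing the docstring's illustrative `x`, `y`, `e` (the sketch itself is not part of this commit):

    from pytensor.gradient import DisconnectedType

    OpFromGraph([x, y], [e], lop_overrides="default")  # ValueError: use None instead
    OpFromGraph([x, y], [e], lop_overrides=DisconnectedType()())  # TypeError: return it in a per-output list instead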
@@ -546,7 +556,7 @@ def lop_op(inps, grads):
         callable_args = (local_inputs, output_grads)
 
         # we need to convert _lop_op into an OfG instance
-        if lop_op == "default":
+        if lop_op is None:
             gdefaults_l = fn_grad(wrt=local_inputs)
             all_grads_l, all_grads_ov_l = zip(
                 *[
@@ -556,12 +566,6 @@ def lop_op(inps, grads):
             )
             all_grads_l = list(all_grads_l)
             all_grads_ov_l = list(all_grads_ov_l)
-        elif isinstance(lop_op, Variable):
-            if isinstance(lop_op.type, DisconnectedType | NullType):
-                all_grads_l = [inp.zeros_like() for inp in local_inputs]
-                all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
-            else:
-                raise ValueError(self.STYPE_ERR_MSG % lop_op.type)
         elif isinstance(lop_op, list):
             goverrides_l = lop_op
             if len(goverrides_l) != inp_len:
@@ -571,15 +575,13 @@ def lop_op(inps, grads):
                 )
             # compute non-overriding downsteam grads from upstreams grads
             # it's normal some input may be disconnected, thus the 'ignore'
-            wrt_l = [
-                lin for lin, gov in zip(local_inputs, goverrides_l) if gov == "default"
-            ]
+            wrt_l = [lin for lin, gov in zip(local_inputs, goverrides_l) if gov is None]
             gdefaults = iter(fn_grad(wrt=wrt_l) if wrt_l else [])
             # combine overriding gradients
             all_grads_l = []
             all_grads_ov_l = []
             for inp, fn_gov in zip(local_inputs, goverrides_l):
-                if fn_gov == "default":
+                if fn_gov is None:
                     gnext, gnext_ov = OpFromGraph._filter_grad_var(next(gdefaults), inp)
                     all_grads_l.append(gnext)
                     all_grads_ov_l.append(gnext_ov)
@@ -652,13 +654,13 @@ def _recompute_rop_op(self):
         fn_rop = partial(Rop, wrt=local_inputs, eval_points=eval_points)
         TYPE_ERR_MSG = (
             "R_op overrides should be (single or list of)"
-            "OpFromGraph | 'default' | None | 0 | callable, got %s"
+            "OpFromGraph, None, a list or a callable, got %s"
         )
         STYPE_ERR_MSG = (
             "Overriding Variable instance can only have type"
             " of DisconnectedType or NullType, got %s"
         )
-        if rop_op == "default":
+        if rop_op is None:
             rdefaults_l = fn_rop(f=local_outputs)
             all_rops_l, all_rops_ov_l = zip(
                 *[
@@ -668,15 +670,6 @@ def _recompute_rop_op(self):
             )
             all_rops_l = list(all_rops_l)
             all_rops_ov_l = list(all_rops_ov_l)
-        elif isinstance(rop_op, Variable):
-            if isinstance(rop_op.type, NullType):
-                all_rops_l = [inp.zeros_like() for inp in local_inputs]
-                all_rops_ov_l = [rop_op.type() for _ in range(out_len)]
-            elif isinstance(rop_op.type, DisconnectedType):
-                all_rops_l = [inp.zeros_like() for inp in local_inputs]
-                all_rops_ov_l = [None] * out_len
-            else:
-                raise ValueError(STYPE_ERR_MSG % rop_op.type)
         elif isinstance(rop_op, list):
             roverrides_l = rop_op
             if len(roverrides_l) != out_len:
@@ -686,15 +679,15 @@ def _recompute_rop_op(self):
                 )
             # get outputs that does not have Rop override
             odefaults_l = [
-                lo for lo, rov in zip(local_outputs, roverrides_l) if rov == "default"
+                lo for lo, rov in zip(local_outputs, roverrides_l) if rov is None
             ]
             rdefaults_l = fn_rop(f=odefaults_l)
             rdefaults = iter(rdefaults_l if odefaults_l else [])
             # combine overriding Rops
             all_rops_l = []
             all_rops_ov_l = []
             for out, fn_rov in zip(local_outputs, roverrides_l):
-                if fn_rov == "default":
+                if fn_rov is None:
                     rnext, rnext_ov = OpFromGraph._filter_rop_var(next(rdefaults), out)
                     all_rops_l.append(rnext)
                     all_rops_ov_l.append(rnext_ov)
@@ -769,7 +762,6 @@ def set_grad_overrides(self, grad_overrides):
         self._lop_op = grad_overrides
         self._lop_op_is_cached = False
         self._lop_type = "grad"
-        self._lop_is_default = grad_overrides == "default"
 
     def set_lop_overrides(self, lop_overrides):
         """
@@ -780,7 +772,6 @@ def set_lop_overrides(self, lop_overrides):
         self._lop_op = lop_overrides
         self._lop_op_is_cached = False
         self._lop_type = "lop"
-        self._lop_is_default = lop_overrides == "default"
 
     def set_rop_overrides(self, rop_overrides):
         """
@@ -790,7 +781,6 @@ def set_rop_overrides(self, rop_overrides):
         """
         self._rop_op = rop_overrides
         self._rop_op_is_cached = False
-        self._rop_is_default = rop_overrides == "default"
 
     def L_op(self, inputs, outputs, output_grads):
         if not self._lop_op_is_cached:
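For reference, a minimal end-to-end sketch of the override API after this change, modeled on the docstring's Example 3. The names `x`, `y`, `e`, `double_dy` and the concrete graph are illustrative assumptions, not part of the commit; a `None` entry keeps the default `L_op` result for that input, while a callable entry replaces it:

    import pytensor.tensor as pt
    from pytensor import function
    from pytensor.compile.builders import OpFromGraph
    from pytensor.gradient import grad

    x, y = pt.scalars("x", "y")
    e = x * y

    def double_dy(inps, outputs, out_grads):
        # Override only the gradient wrt y; per-input callables return a
        # one-element list, mirroring the docstring's rescale_dy example.
        x_, y_ = inps
        return [grad(outputs[0], y_) * 2]

    # None -> default L_op result for x; callable -> custom result for y.
    op = OpFromGraph([x, y], [e], lop_overrides=[None, double_dy])
    z = op(x, y)
    dx, dy = grad(z, [x, y])
    fn = function([x, y], [dx, dy])
    print(fn(2.0, 3.0))  # expected: dx = y = 3.0, dy = 2 * x = 4.0 (doubled)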