@@ -417,7 +417,10 @@ def __init__(
                 FutureWarning,
             )
             self._lop_op_interface = False
-        self._lop_op_cache: Callable | None = None
+        # Dictionary where we cache OpFromGraph that represent the L_op
+        # A distinct OpFromGraph is needed to represent each pattern of output_grads connection
+        # It also returns a tuple that indicates which input_gradients are disconnected
+        self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}
         self._rop_op_cache: Callable | None = None
 
         self._connection_pattern = connection_pattern
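The cached L_op is now looked up by the tuple of disconnection flags, which is hashable and therefore usable directly as a dict key. A minimal standalone sketch of that lookup-or-build pattern (the class and the _build_lop_for_pattern helper below are illustrative placeholders, not the actual OpFromGraph code):

from collections.abc import Callable


class _LopCacheSketch:
    """Toy illustration: one cached L_op callable per pattern of disconnected output grads."""

    def __init__(self) -> None:
        # Keyed by a tuple of bools; True means that output gradient is disconnected
        self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}

    def get_lop(self, disconnected_output_grads: tuple[bool, ...]) -> Callable:
        try:
            # Reuse the callable already built for this exact pattern, if any
            return self._lop_op_cache[disconnected_output_grads]
        except KeyError:
            pass
        # Hypothetical builder; the real code constructs an OpFromGraph here
        fn = self._build_lop_for_pattern(disconnected_output_grads)
        self._lop_op_cache[disconnected_output_grads] = fn
        return fn

    def _build_lop_for_pattern(self, pattern: tuple[bool, ...]) -> Callable:
        # Placeholder so the sketch runs on its own
        return lambda *args, **kwargs: pattern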
@@ -480,24 +483,30 @@ def _call_custom_override(self, op_overrides, callable_args, nout):
         return outputs
 
     @config.change_flags(compute_test_value="off")
-    def _build_and_cache_lop_op(self) -> Callable:
-        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
+    def _build_and_cache_lop_op(
+        self, disconnected_output_grads: tuple[bool, ...]
+    ) -> Callable:
+        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance,
+        specialized for the pattern of disconnected_output_grads
 
         Results are cached in self._lop_op_cache
         """
-        if self._lop_op_cache is not None:
-            return self._lop_op_cache
+        try:
+            return self._lop_op_cache[disconnected_output_grads]
+        except KeyError:
+            pass
 
         inner_inputs = self.inner_inputs
         inner_outputs = self.inner_outputs
         nin = len(inner_inputs)
+        nout = len(inner_outputs)
         lop_overrides = (
             self.lop_overrides if self._lop_op_interface else self.grad_overrides
         )
 
         if isinstance(lop_overrides, OpFromGraph):
             if self._lop_op_interface:
-                self._lop_op_cache = lop_overrides
+                self._lop_op_cache[disconnected_output_grads] = lop_overrides
                 lop_overrides.kwargs["on_unused_input"] = "ignore"
                 return lop_overrides
 
@@ -507,20 +516,42 @@ def _build_and_cache_lop_op(self) -> Callable:
             def lop_overrides(inps, grads):
                 return self.grad_overrides(*inps, *grads)
 
-        output_grads = [out_t() for out_t in self.output_types]
+        # We try to compute the gradient with respect to connected outputs only
+        connected_inner_outputs = [
+            # We add an identity operation(copy) so that we don't override indirect
+            # gradient contributions to an inner output coming from other inner outputs
+            inner_out.copy()
+            for inner_out, disconnected in zip(
+                inner_outputs, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
+        connected_output_grads = [
+            out_t()
+            for out_t, disconnected in zip(
+                self.output_types, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
         fn_grad = partial(
             grad,
             cost=None,
             disconnected_inputs="ignore",
             return_disconnected="disconnected",
             null_gradients="return",
-            known_grads=dict(zip(inner_outputs, output_grads)),
+            known_grads=dict(
+                zip(connected_inner_outputs, connected_output_grads, strict=True)
+            ),
         )
 
         if self._lop_op_interface:
-            callable_args = (inner_inputs, inner_outputs, output_grads)
+            callable_args = (
+                inner_inputs,
+                connected_inner_outputs,
+                connected_output_grads,
+            )
         else:
-            callable_args = (inner_inputs, output_grads)
+            callable_args = (inner_inputs, connected_output_grads)
 
         # we need to convert _lop_op into an OfG instance
         if lop_overrides is None:
@@ -544,32 +575,51 @@ def lop_overrides(inps, grads):
         else:
             input_grads = self._call_custom_override(lop_overrides, callable_args, nin)
 
-        # Filter out disconnected input and output gradients
+        # Filter out disconnected/null input generated from the inner graph grad
+        # We append them in the outer wrapper function below
         connected_input_grads = [
             inp_grad
             for inp_grad in input_grads
             if not isinstance(inp_grad.type, DisconnectedType | NullType)
         ]
         lop_op = type(self)(
-            inputs=inner_inputs + inner_outputs + output_grads,
+            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
             outputs=connected_input_grads,
             inline=self.is_inline,
             name=(None if self.name is None else f"{self.name}_LOp"),
             # TODO: We can be eager here and exclude unused inputs in the OFG
             on_unused_input="ignore",
         )
 
-        # Return a wrapper that combines connected and disconnected input gradients
+        # Return a wrapper that combines connected and disconnected/null input gradients
+        # And also filters out disconnected/null output gradients
         def wrapper(*inputs: Variable, **kwargs) -> list[Variable]:
-            connected_input_grads = iter(lop_op(*inputs, **kwargs))
+            inputs, outputs, output_grads = (
+                inputs[: -nout * 2],
+                inputs[-nout * 2 : -nout],
+                inputs[-nout:],
+            )
+            connected_outputs = [
+                output
+                for output, output_grad in zip(outputs, output_grads, strict=True)
+                if not isinstance(output_grad.type, DisconnectedType | NullType)
+            ]
+            connected_output_grads = [
+                output_grad
+                for output_grad in output_grads
+                if not isinstance(output_grad.type, DisconnectedType)
+            ]
+            connected_input_grads = iter(
+                lop_op(*inputs, *connected_outputs, *connected_output_grads, **kwargs)
+            )
             return [
                 input_grad
                 if isinstance(input_grad.type, DisconnectedType | NullType)
                 else next(connected_input_grads)
                 for input_grad in input_grads
             ]
 
-        self._lop_op_cache = wrapper
+        self._lop_op_cache[disconnected_output_grads] = wrapper
         return wrapper
 
     @config.change_flags(compute_test_value="off")
@@ -652,7 +702,10 @@ def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
         return wrapper
 
     def L_op(self, inputs, outputs, output_grads):
-        lop_op = self._build_and_cache_lop_op()
+        disconnected_output_grads = tuple(
+            isinstance(og.type, DisconnectedType) for og in output_grads
+        )
+        lop_op = self._build_and_cache_lop_op(disconnected_output_grads)
         return lop_op(*inputs, *outputs, *output_grads, return_list=True)
 
     def R_op(self, inputs, eval_points):
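A rough usage sketch of the scenario the change targets (illustrative only; the variable names below are made up, not taken from the PR): the same OpFromGraph node may be differentiated with different subsets of its output gradients connected, so a single cached L_op callable no longer fits every call and the cache is keyed by the disconnection pattern instead.

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.scalar("x")
ofg = OpFromGraph([x], [x**2, pt.exp(x)])

a = pt.scalar("a")
y, z = ofg(a)

# Both output gradients connected: L_op is built for the pattern (False, False)
g_both = pytensor.grad(y + z, wrt=a)

# Only the first output feeds the cost: the gradient of the second output is a
# DisconnectedType, so L_op is built for the pattern (False, True)
g_first = pytensor.grad(y, wrt=a)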