
Commit 394b355

Fix OpFromGraph L_op with related and/or disconnected outputs
1 parent 2eb8fca commit 394b355
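For context, a minimal sketch of the situation this commit fixes, adapted from the test added below: an OpFromGraph whose outputs depend on each other, with the gradient taken through only a subset of them, so the remaining output gradients reach L_op as DisconnectedType variables. The variables and graph here are illustrative, not part of the commit itself.

    import pytensor.tensor as pt
    from pytensor.compile.builders import OpFromGraph
    from pytensor.gradient import grad

    x, y = pt.dscalars("x", "y")
    out1 = x + y
    out2 = x * y
    out3 = out1 * out2  # out3 depends on the other two outputs
    op = OpFromGraph([x, y], [out1, out2, out3])

    # Only out1 reaches the cost, so the output gradients for out2 and out3
    # passed to OpFromGraph.L_op are DisconnectedType variables.
    cost = op(x, y)[0]
    gx, gy = grad(cost, wrt=[x, y])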

2 files changed: +113 -18 lines changed

2 files changed

+113
-18
lines changed

pytensor/compile/builders.py

Lines changed: 69 additions & 16 deletions
@@ -417,7 +417,10 @@ def __init__(
                 FutureWarning,
             )
             self._lop_op_interface = False
-        self._lop_op_cache: Callable | None = None
+        # Dictionary where we cache OpFromGraph that represent the L_op
+        # A distinct OpFromGraph is needed to represent each pattern of output_grads connection
+        # It also returns a tuple that indicates which input_gradients are disconnected
+        self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}
         self._rop_op_cache: Callable | None = None
 
         self._connection_pattern = connection_pattern
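As a sketch of how the new cache is keyed (the three output gradients below are illustrative; the names mirror those used in L_op further down in this diff):

    import pytensor.tensor as pt
    from pytensor.gradient import DisconnectedType, disconnected_type

    # Illustrative: gradient is requested through the first output only
    output_grads = [pt.dscalar("g0"), disconnected_type(), disconnected_type()]
    disconnected_output_grads = tuple(
        isinstance(og.type, DisconnectedType) for og in output_grads
    )
    # (False, True, True): one specialized L_op OpFromGraph is built and cached per pattern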
@@ -480,24 +483,30 @@ def _call_custom_override(self, op_overrides, callable_args, nout):
         return outputs
 
     @config.change_flags(compute_test_value="off")
-    def _build_and_cache_lop_op(self) -> Callable:
-        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
+    def _build_and_cache_lop_op(
+        self, disconnected_output_grads: tuple[bool, ...]
+    ) -> Callable:
+        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance,
+        specialized for the pattern of disconnected_output_grads
 
         Results are cached in self._lop_op_cache
         """
-        if self._lop_op_cache is not None:
-            return self._lop_op_cache
+        try:
+            return self._lop_op_cache[disconnected_output_grads]
+        except KeyError:
+            pass
 
         inner_inputs = self.inner_inputs
         inner_outputs = self.inner_outputs
         nin = len(inner_inputs)
+        nout = len(inner_outputs)
         lop_overrides = (
             self.lop_overrides if self._lop_op_interface else self.grad_overrides
         )
 
         if isinstance(lop_overrides, OpFromGraph):
             if self._lop_op_interface:
-                self._lop_op_cache = lop_overrides
+                self._lop_op_cache[disconnected_output_grads] = lop_overrides
             lop_overrides.kwargs["on_unused_input"] = "ignore"
             return lop_overrides
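The try/except lookup above is a plain memoization keyed on the connection pattern. A standalone sketch of that behaviour, with a string standing in for the built OpFromGraph:

    _lop_op_cache: dict[tuple[bool, ...], str] = {}

    def get_lop(pattern: tuple[bool, ...]) -> str:
        # Stand-in for _build_and_cache_lop_op: build once per pattern, then reuse
        try:
            return _lop_op_cache[pattern]
        except KeyError:
            lop = f"L_op specialized for {pattern}"  # placeholder for the real OpFromGraph
            _lop_op_cache[pattern] = lop
            return lop

    assert get_lop((False, True)) is get_lop((False, True))      # second call hits the cache
    assert get_lop((False, True)) is not get_lop((True, False))  # distinct pattern, distinct entry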

@@ -507,20 +516,42 @@ def _build_and_cache_lop_op(self) -> Callable:
             def lop_overrides(inps, grads):
                 return self.grad_overrides(*inps, *grads)
 
-        output_grads = [out_t() for out_t in self.output_types]
+        # We try to compute the gradient with respect to connected outputs only
+        connected_inner_outputs = [
+            # We add an identity operation(copy) so that we don't override indirect
+            # gradient contributions to an inner output coming from other inner outputs
+            inner_out.copy()
+            for inner_out, disconnected in zip(
+                inner_outputs, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
+        connected_output_grads = [
+            out_t()
+            for out_t, disconnected in zip(
+                self.output_types, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
         fn_grad = partial(
             grad,
             cost=None,
             disconnected_inputs="ignore",
             return_disconnected="disconnected",
             null_gradients="return",
-            known_grads=dict(zip(inner_outputs, output_grads)),
+            known_grads=dict(
+                zip(connected_inner_outputs, connected_output_grads, strict=True)
+            ),
         )
 
         if self._lop_op_interface:
-            callable_args = (inner_inputs, inner_outputs, output_grads)
+            callable_args = (
+                inner_inputs,
+                connected_inner_outputs,
+                connected_output_grads,
+            )
         else:
-            callable_args = (inner_inputs, output_grads)
+            callable_args = (inner_inputs, connected_output_grads)
 
         # we need to convert _lop_op into an OfG instance
         if lop_overrides is None:
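The identity (copy) trick above matters when one inner output feeds another: per the diff's comment, seeding known_grads on a copy keeps the seed from overriding the gradient contribution that flows into the original output from downstream outputs. A hedged, standalone sketch of the mechanism (the inner graph here is illustrative, not the Op's actual internals, and grad is called directly rather than through the fn_grad partial):

    import pytensor.tensor as pt
    from pytensor.gradient import grad

    x, y = pt.dscalars("x", "y")
    out1 = x + y
    out3 = out1 * (x * y)             # out3 depends on out1, as in the new test case
    g1, g3 = pt.dscalars("g1", "g3")  # symbolic gradients for the connected outputs

    # Seed the copies, so out1 still accumulates both the seed g1 and the
    # contribution arriving through out3 during backpropagation.
    input_grads = grad(
        cost=None,
        wrt=[x, y],
        known_grads={out1.copy(): g1, out3.copy(): g3},
        disconnected_inputs="ignore",
    )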
@@ -544,32 +575,51 @@ def lop_overrides(inps, grads):
         else:
             input_grads = self._call_custom_override(lop_overrides, callable_args, nin)
 
-        # Filter out disconnected input and output gradients
+        # Filter out disconnected/null input generated from the inner graph grad
+        # We append them in the outer wrapper function below
         connected_input_grads = [
             inp_grad
             for inp_grad in input_grads
             if not isinstance(inp_grad.type, DisconnectedType | NullType)
         ]
         lop_op = type(self)(
-            inputs=inner_inputs + inner_outputs + output_grads,
+            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
             outputs=connected_input_grads,
             inline=self.is_inline,
             name=(None if self.name is None else f"{self.name}_LOp"),
             # TODO: We can be eager here and exclude unused inputs in the OFG
             on_unused_input="ignore",
         )
 
-        # Return a wrapper that combines connected and disconnected input gradients
+        # Return a wrapper that combines connected and disconnected/null input gradients
+        # And also filters out disconnected/null output gradients
         def wrapper(*inputs: Variable, **kwargs) -> list[Variable]:
-            connected_input_grads = iter(lop_op(*inputs, **kwargs))
+            inputs, outputs, output_grads = (
+                inputs[: -nout * 2],
+                inputs[-nout * 2 : -nout],
+                inputs[-nout:],
+            )
+            connected_outputs = [
+                output
+                for output, output_grad in zip(outputs, output_grads, strict=True)
+                if not isinstance(output_grad.type, DisconnectedType | NullType)
+            ]
+            connected_output_grads = [
+                output_grad
+                for output_grad in output_grads
+                if not isinstance(output_grad.type, DisconnectedType)
+            ]
+            connected_input_grads = iter(
+                lop_op(*inputs, *connected_outputs, *connected_output_grads, **kwargs)
+            )
             return [
                 input_grad
                 if isinstance(input_grad.type, DisconnectedType | NullType)
                 else next(connected_input_grads)
                 for input_grad in input_grads
             ]
 
-        self._lop_op_cache = wrapper
+        self._lop_op_cache[disconnected_output_grads] = wrapper
         return wrapper
 
     @config.change_flags(compute_test_value="off")
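The wrapper above receives one flat argument list (inputs, then outputs, then output gradients) and splits it purely by position. A plain-Python sketch of that slicing, assuming nin = 2 and nout = 3:

    nin, nout = 2, 3
    flat = ["x", "y", "out1", "out2", "out3", "g1", "g2", "g3"]
    inputs, outputs, output_grads = (
        flat[: -nout * 2],        # ["x", "y"]
        flat[-nout * 2 : -nout],  # ["out1", "out2", "out3"]
        flat[-nout:],             # ["g1", "g2", "g3"]
    )
    assert inputs == flat[:nin]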
@@ -652,7 +702,10 @@ def wrapper(*inputs: Variable, **kwargs) -> list[Variable | None]:
         return wrapper
 
     def L_op(self, inputs, outputs, output_grads):
-        lop_op = self._build_and_cache_lop_op()
+        disconnected_output_grads = tuple(
+            isinstance(og.type, DisconnectedType) for og in output_grads
+        )
+        lop_op = self._build_and_cache_lop_op(disconnected_output_grads)
         return lop_op(*inputs, *outputs, *output_grads, return_list=True)
 
     def R_op(self, inputs, eval_points):

tests/compile/test_builders.py

Lines changed: 44 additions & 2 deletions
@@ -8,7 +8,13 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
+from pytensor.gradient import (
+    DisconnectedType,
+    Rop,
+    disconnected_type,
+    grad,
+    verify_grad,
+)
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.null_type import NullType, null_type
@@ -22,7 +28,15 @@
 from pytensor.tensor.random.utils import RandomStream
 from pytensor.tensor.rewriting.shape import ShapeOptimizer
 from pytensor.tensor.shape import specify_shape
-from pytensor.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
+from pytensor.tensor.type import (
+    TensorType,
+    dscalars,
+    matrices,
+    matrix,
+    scalar,
+    vector,
+    vectors,
+)
 from tests import unittest_tools
 from tests.graph.utils import MyVariable
 
@@ -638,6 +652,34 @@ def test_explicit_input_from_shared(self):
         out = test_ofg(y, y)
         assert out.eval() == 4
 
+    def test_L_op_disconnected_output_grad(self):
+        x, y = dscalars("x", "y")
+        rng = np.random.default_rng(594)
+        point = list(rng.normal(size=(2,)))
+
+        out1 = x + y
+        out2 = x * y
+        out3 = out1 * out2  # Create dependency between outputs
+        op = OpFromGraph([x, y], [out1, out2, out3])
+        verify_grad(lambda x, y: pt.add(*op(x, y)), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[:-1]), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[1:]), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[::2]), point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[0], point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[1], point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)
+
+        # Test disconnected graphs are handled correctly
+        op = OpFromGraph([x, y], [x**2, y**3])
+        with pytest.warns(UserWarning):
+            grad_x_wrt_y = grad(
+                op(x, y)[0],
+                wrt=y,
+                return_disconnected="disconnected",
+                disconnected_inputs="warn",
+            )
+        assert isinstance(grad_x_wrt_y.type, DisconnectedType)
+
     def test_repeated_inputs(self):
         x = pt.dscalar("x")
         y = pt.dscalar("y")
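For reference, verify_grad compares the symbolic gradient of a callable against a finite-difference estimate at the given point and raises if they disagree. A minimal sketch of one of the new cases, runnable outside the test class:

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.compile.builders import OpFromGraph
    from pytensor.gradient import verify_grad

    x, y = pt.dscalars("x", "y")
    out1, out2 = x + y, x * y
    op = OpFromGraph([x, y], [out1, out2, out1 * out2])

    rng = np.random.default_rng(594)
    point = list(rng.normal(size=(2,)))
    # Gradient through a single output: the other output gradients are disconnected
    verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)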
