Commit 9221d2e

ricardoV94 and aseyboldt committed
Compute pushforward via double application of pullback
Also fixes bug in Scan L_op and Max R_op

Co-authored-by: Adrian Seyboldt <aseyboldt@users.noreply.github.com>
1 parent 2aecb95 commit 9221d2e

File tree

13 files changed: +509 -278 lines

doc/extending/op.rst

Lines changed: 0 additions & 1 deletion
@@ -506,4 +506,3 @@ These are the function required to work with :func:`pytensor.gradient.grad`.
 the outputs) back to their corresponding shapes and return them as the
 output of the :meth:`Op.R_op` method.
 
-:ref:`List of op with r op support <R_op_list>`.

doc/library/gradient.rst

Lines changed: 0 additions & 76 deletions
This file was deleted.

doc/library/tensor/basic.rst

Lines changed: 0 additions & 2 deletions
@@ -1791,5 +1791,3 @@ Gradient / Differentiation
 :members: grad
 :noindex:
 
-See the :ref:`gradient <libdoc_gradient>` page for complete documentation
-of the gradient module.

doc/tutorial/gradients.rst

Lines changed: 16 additions & 5 deletions
@@ -86,9 +86,7 @@ of symbolic differentiation).
 ``i`` of the output list is the gradient of the first argument of
 `pt.grad` with respect to the ``i``-th element of the list given as second argument.
 The first argument of `pt.grad` has to be a scalar (a tensor
-of size 1). For more information on the semantics of the arguments of
-`pt.grad` and details about the implementation, see
-:ref:`this<libdoc_gradient>` section of the library.
+of size 1).
 
 Additional information on the inner workings of differentiation may also be
 found in the more advanced tutorial :ref:`Extending PyTensor<extending>`.
@@ -204,7 +202,21 @@ you need to do something similar to this:
 >>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
 array([ 2., 2.])
 
-:ref:`List <R_op_list>` of Op that implement Rop.
+By default, the R-operator is implemented as a double application of the L-operator
+(see `reference <https://j-towns.github.io/2017/06/12/A-new-trick.html>`_).
+In most cases this should be as performant as a specialized implementation of the R-operator.
+However, PyTensor may sometimes fail to prune dead branches or fuse common expressions within composite operators,
+such as Scan and OpFromGraph, that would be more easily avoided in a direct implementation of the R-operator.
+
+When this is a concern, it is possible to force `Rop` to use the specialized `Op.R_op` methods by passing
+`use_op_rop_implementation=True`. Note that this will fail if the graph contains `Op`\s that don't implement this method.
+
+
+>>> JV = pytensor.gradient.Rop(y, W, V, use_op_rop_implementation=True)
+>>> f = pytensor.function([W, V, x], JV)
+>>> f([[1, 1], [1, 1]], [[2, 2], [2, 2]], [0,1])
+array([ 2., 2.])
+
 
 L-operator
 ----------
@@ -234,7 +246,6 @@ array([[ 0., 0.],
 as the input parameter, while the result of the R-operator has a shape similar
 to that of the output.
 
-:ref:`List of op with r op support <R_op_list>`.
 
 Hessian times a Vector
 ======================
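
The added tutorial text above describes the new default: the R-operator is built from two applications of the L-operator instead of dispatching to `Op.R_op`. The following is a minimal sketch of that double-pullback trick spelled out with `pt.grad`; the `tanh` example and the names `u`, `v` are illustrative assumptions, not code from this commit.

import pytensor
import pytensor.tensor as pt

# Sketch: obtain a Jacobian-vector product (pushforward) from two
# vector-Jacobian products (pullbacks).
x = pt.vector("x")
y = pt.tanh(x)      # y = f(x); we want J(x) @ v

v = pt.vector("v")  # tangent vector, same shape as x
u = pt.vector("u")  # auxiliary cotangent, same shape as y

# First pullback: g = u^T J, an expression that is linear in u.
g = pt.grad((y * u).sum(), wrt=x)

# Second pullback: differentiating (u^T J) v with respect to u recovers J v;
# u drops out of the resulting graph, so its numerical value is irrelevant.
jv = pt.grad((g * v).sum(), wrt=u)

# u is kept as an input only defensively; on_unused_input avoids an error
# once it no longer appears in jv.
f = pytensor.function([x, v, u], jv, on_unused_input="ignore")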

pytensor/compile/builders.py

Lines changed: 12 additions & 1 deletion
@@ -340,6 +340,12 @@ def __init__(
 ``None``, this will be used as the connection_pattern for this
 :class:`Op`.
 
+.. warning::
+
+    ``rop_overrides`` is ignored when `pytensor.gradient.Rop` is called with
+    `use_op_rop_implementation=False` (default). In this case the Lop
+    is used twice to obtain a mathematically equivalent Rop.
+
 strict: bool, default False
     If true, it raises when any variables needed to compute the inner graph
     are not provided as explicit inputs. This can only happen for graphs with
@@ -641,7 +647,12 @@ def _build_and_cache_rop_op(self):
         return rop_overrides
 
     eval_points = [inp_t() for inp_t in self.input_types]
-    fn_rop = partial(Rop, wrt=inner_inputs, eval_points=eval_points)
+    fn_rop = partial(
+        Rop,
+        wrt=inner_inputs,
+        eval_points=eval_points,
+        use_op_rop_implementation=True,
+    )
 
     callable_args = (inner_inputs, eval_points)
     if rop_overrides is None:
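
As a usage-level illustration of the warning added above, here is a hedged sketch (the inner graph and the override below are assumptions, not code from this commit) of how `rop_overrides` interacts with the new `use_op_rop_implementation` flag:

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import Rop

x = pt.vector("x")
out = pt.exp(x)

def custom_rop(inputs, eval_points):
    # Hand-written pushforward of exp: J @ v = exp(x) * v.
    (x_,) = inputs
    (v_,) = eval_points
    return [pt.exp(x_) * v_]

op = OpFromGraph([x], [out], rop_overrides=custom_rop)

z = pt.vector("z")
v = pt.vector("v")

# Default after this commit: the pushforward is derived from two applications
# of the pullback (L_op), and custom_rop is ignored.
jv_default = Rop(op(z), wrt=z, eval_points=v)

# Forcing the specialized path dispatches to Op.R_op / rop_overrides instead.
jv_override = Rop(op(z), wrt=z, eval_points=v, use_op_rop_implementation=True)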
