
Refactor gradient related methods #182

Open
@ricardoV94

Description


Right now we have grad, L_op, R_op.

Deprecate grad in favor of L_op:

grad is exactly the same as L_op except it doesn't have access to the outputs of the node that is being differentiated.

```python
def L_op(
    self,
    inputs: Sequence[Variable],
    outputs: Sequence[Variable],
    output_grads: Sequence[Variable],
) -> List[Variable]:
    r"""Construct a graph for the L-operator.

    The L-operator computes a row vector times the Jacobian.

    This method dispatches to :meth:`Op.grad` by default. In one sense,
    this method provides the original outputs when they're needed to
    compute the return value, whereas `Op.grad` doesn't.

    See `Op.grad` for a mathematical explanation of the inputs and outputs
    of this method.

    Parameters
    ----------
    inputs
        The inputs of the `Apply` node using this `Op`.
    outputs
        The outputs of the `Apply` node using this `Op`.
    output_grads
        The gradients with respect to each `Variable` in `outputs`.

    """
    return self.grad(inputs, output_grads)
```

L_op allows one to reuse the same output when it's needed in the gradient, which means there is one less node to be merged during compilation. This is mostly relevant for nodes that are costly to merge, such as Scan (see 0f5a06d).
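
As a small illustration (a sketch using the public PyTensor API, assuming the `pytensor.grad` / `pytensor.dprint` frontend; the exact graphs depend on the registered rewrites):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x)
g = pytensor.grad(y.sum(), x)

# Before compilation the gradient graph contains its own Exp node
# (because Exp's gradient rebuilds exp(x)), which the merge rewrite
# later has to unify with the forward node.
pytensor.dprint([y, g])

f = pytensor.function([x], [y, g])
pytensor.dprint(f)  # after rewrites, exp(x) should appear only once
```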

It also saves time spent on make_node (e.g., inferring static type shapes). In the Scalar Ops it's used everywhere to quickly check whether the output types are discrete (see fd628c5). Some opportunities are still being missed, though; for example, the gradient of Exp:

```python
def L_op(self, inputs, outputs, gout):
    (x,) = inputs
    (gz,) = gout
    if x.type in complex_types:
        raise NotImplementedError()
    if outputs[0].type in discrete_types:
        if x.type in discrete_types:
            return [x.zeros_like(dtype=config.floatX)]
        else:
            return [x.zeros_like()]
    return (gz * exp(x),)
```

This could instead return `(gz * outputs[0],)`.
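
A sketch of what that change could look like (same method as above, relying on the same module-level names; only the return expression changes to reuse the node's output):

```python
def L_op(self, inputs, outputs, gout):
    (x,) = inputs
    (gz,) = gout
    if x.type in complex_types:
        raise NotImplementedError()
    if outputs[0].type in discrete_types:
        if x.type in discrete_types:
            return [x.zeros_like(dtype=config.floatX)]
        else:
            return [x.zeros_like()]
    # Reuse the already-computed exp(x) from the forward node instead
    # of building a second Exp
    return (gz * outputs[0],)
```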

More importantly for this issue, I think we should deprecate grad completely, since everything can be done equally well with L_op.

Rename L_op and R_op?

The names are pretty unintuitive, and I don't think they are used in any other autodiff libraries. The equivalents in JAX are vjp and jvp (you can find a direct translation in https://www.pymc-labs.io/blog-posts/jax-functions-in-pymc-3-quick-examples/).
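
For reference, a minimal JAX sketch of the correspondence (L_op computes a row vector times the Jacobian, i.e. a vjp; R_op computes the Jacobian times a column vector, i.e. a jvp):

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.exp(x)


x = jnp.array([0.0, 1.0, 2.0])
v = jnp.ones_like(x)

# vjp ~ L_op: given a cotangent for the outputs, return v^T J
y, f_vjp = jax.vjp(f, x)
(l_op_result,) = f_vjp(v)

# jvp ~ R_op: given a tangent for the inputs, return J v
y, r_op_result = jax.jvp(f, (x,), (v,))
```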

Other suggestions were discussed some time ago by Theano devs here: https://groups.google.com/g/theano-dev/c/8-z2C59rmQk/m/gm432ifVAg0J?pli=1

Remove R_op in favor of double application of L_op (or make it a default fallback)

There was some fanfare some time ago about R_op being completely redundant in a framework with dead code elimination: Theano/Theano#6035

That thread also suggests that the double L_op may generate more efficient graphs in some cases (because most of our rewrites target the type of graphs generated by L_op?).
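
A sketch of that construction in the same JAX notation (`jvp_via_double_vjp` is a hypothetical helper; in PyTensor the equivalent would be two applications of L_op followed by dead code elimination):

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.exp(x)


def jvp_via_double_vjp(f, x, v):
    # First L_op: build the linear map u -> J^T u as a closure over x
    y, f_vjp = jax.vjp(f, x)
    # Second L_op: take the vjp of that linear map; applying it to v
    # yields (J^T)^T v = J v, i.e. R_op applied to the tangent v
    _, g_vjp = jax.vjp(lambda u: f_vjp(u)[0], jnp.zeros_like(y))
    (jv,) = g_vjp(v)
    return jv


x = jnp.array([0.0, 1.0, 2.0])
v = jnp.ones_like(x)
print(jvp_via_double_vjp(f, x, v))
print(jax.jvp(f, (x,), (v,))[1])  # matches
```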

It probably makes sense to retain R_op for cases where we/users know it's the best approach, but perhaps default/fall back to the double L_op otherwise. Stale PRs that never quite made it into Theano:

Theano/Theano#6400
Theano/Theano#6037
