Add rough version of an autodiff refactor #788

Open · wants to merge 3 commits into base: main

230 changes: 227 additions & 3 deletions pytensor/gradient.py
@@ -4,17 +4,17 @@
import warnings
from collections.abc import Callable, Mapping, MutableSequence, Sequence
from functools import partial, reduce
from typing import TYPE_CHECKING, Literal, TypeVar, Union
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union

import numpy as np

import pytensor
from pytensor.compile.ops import ViewOp
from pytensor.configdefaults import config
from pytensor.graph import utils
from pytensor.graph.basic import Apply, NominalVariable, Variable
from pytensor.graph.basic import Apply, NominalVariable, Variable, io_toposort
from pytensor.graph.null_type import NullType, null_type
from pytensor.graph.op import get_test_values
from pytensor.graph.op import Op, OutputStorageType, get_test_values
from pytensor.graph.type import Type


@@ -2292,3 +2292,227 @@ def grad_scale(x, multiplier):
0.416...
"""
return GradScale(multiplier)(x)


# ===========================================
# The following is more or less pseudocode...
# ===========================================


# Use linear transposition and forward-mode autodiff to get reverse-mode autodiff.
# Ops that only define push_forward (Rop) could use this, which is nice
# because push_forward is usually easier to derive and think about.
def pull_back_through_transpose(outputs, inputs, output_cotangents):
tangents = [input.type() for input in inputs]
output_tangents = push_forward(outputs, inputs, tangents)
return linear_transpose(output_tangents, tangents, output_cotangents)
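
# Illustrative sketch (not part of the proposed changes; assumes
# `import pytensor.tensor as pt`): an op that only defines push_forward
# could still be pulled back, e.g.
#
#   x = pt.dvector("x")
#   y = pt.exp(x)
#   (x_cotangent,) = pull_back_through_transpose([y], [x], [pt.ones_like(y)])
#
# The forward-mode map t -> J @ t is linear in t, so transposing it yields
# the reverse-mode map u -> J.T @ u.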


# Ops that only define pull_back (Lop) could use this to derive push_forward.
def push_forward_through_pull_back(outputs, inputs, tangents):
cotangents = [out.type("u") for out in outputs]
input_cotangents = pull_back(outputs, inputs, cotangents)
return pull_back(input_cotangents, cotangents, tangents)
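
# Why pulling back twice works (illustrative note): for fixed inputs, the map
# u -> pull_back(outputs, inputs, u) computes J.T @ u and is linear in u.
# Pulling that map back with a seed t therefore yields (J.T).T @ t = J @ t,
# which is exactly the push_forward of t. Reusing x and y from the sketch above:
#
#   (y_tangent,) = push_forward_through_pull_back([y], [x], [pt.ones_like(x)])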


def _push_forward_impl(outputs, inputs, input_tangents):
# Get the nodes in topological order and precompute
# a set of values that are used in the graph.
nodes = io_toposort(inputs, outputs)
used_values = set(outputs)
for node in reversed(nodes):
if any(output in used_values for output in node.outputs):
used_values.update(node.inputs)

# Maybe a lazy gradient op could use this during rewrite time?
recorded_rewrites = {}
known_tangents = dict(zip(inputs, input_tangents, strict=True))
for node in nodes:
tangents = [known_tangents.get(input, None) for input in node.inputs]
result_nums = [
i for i in range(len(node.outputs)) if node.outputs[i] in used_values
]
new_outputs, output_tangents = node.op.push_forward(node, tangents, result_nums)
if new_outputs is not None:
recorded_rewrites[node] = new_outputs

for i, tangent in zip(result_nums, output_tangents, strict=True):
known_tangents[node.outputs[i]] = tangent

return [known_tangents[output] for output in outputs]


def _pull_back_impl(outputs, inputs, output_cotangents):
known_cotangents = dict(zip(outputs, output_cotangents, strict=True))

nodes = io_toposort(inputs, outputs)
used_values = set(outputs)
for node in reversed(nodes):
if any(output in used_values for output in node.outputs):
used_values.update(node.inputs)

# Maybe a lazy gradient op could use this during rewrite time?
recorded_rewrites = {}
for node in reversed(nodes):
cotangents = [known_cotangents.get(output, None) for output in node.outputs]
argnums = [i for i in range(len(node.inputs)) if node.inputs[i] in used_values]
new_outputs, input_cotangents = node.op.pull_back(node, cotangents, argnums)
if new_outputs is not None:
recorded_rewrites[node] = new_outputs

for i, cotangent in zip(argnums, input_cotangents, strict=True):
if cotangent is None:
continue
input = node.inputs[i]
if input not in known_cotangents:
known_cotangents[input] = cotangent
else:
# TODO check that we are not broadcasting?
known_cotangents[input] += cotangent

return [known_cotangents[input] for input in inputs]


def pullback_grad(cost, wrt):
"""A new pt.grad that uses the pull_back function.

At some point we might want to replace pt.grad with this?
"""
from pytensor.tensor import as_tensor_variable

# TODO: add error checking and allow a non-list wrt.
return pull_back([cost], wrt, [as_tensor_variable(1.0)])
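
# Illustrative usage (hypothetical variable names, not part of the diff;
# assumes `import pytensor.tensor as pt`):
#
#   w = pt.dvector("w")
#   cost = (w ** 2).sum()
#   (w_grad,) = pullback_grad(cost, [w])
#
# This builds a symbolic PullBackOp node seeded with a cotangent of 1.0;
# later rewrites expand it into the actual gradient graph.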


def linear_transpose(outputs, inputs, transposed_inputs):
"""Given a linear function from inputs to outputs, return the transposed function."""
# Should loop over the nodes in reverse topological order, similar to
# pull_back, calling each op's linear_transpose.
raise NotImplementedError("linear_transpose is not implemented yet.")
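
# What "transposed" means here (illustrative): if f(x) = A @ x for a fixed
# matrix A, the transposed function maps u to A.T @ u; in general it is
# characterised by <f(x), u> == <x, f_transposed(u)> for all x and u.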


class PullBackOp(Op):
__props__ = ("n_outputs", "n_inputs")

def __init__(self, n_outputs, n_inputs):
self.n_outputs = n_outputs
self.n_inputs = n_inputs
super().__init__()

def make_node(self, *all_inputs) -> Apply:
# all_inputs is [*outputs, *inputs, *output_cotangents]
if len(all_inputs) != 2 * self.n_outputs + self.n_inputs:
raise ValueError("Incorrect number of inputs")

inputs_output_cotangents = all_inputs[self.n_outputs :]
inputs = inputs_output_cotangents[: self.n_inputs]

input_cotangents = [input.type() for input in inputs]

# TODO
continuous_dtypes = ["float64", "float32", "float16"]
for input in inputs:
if input.type.dtype not in continuous_dtypes:
raise ValueError(
f"Cannot compute pull_back for non-continuous value {input}"
)

return Apply(self, all_inputs, input_cotangents)

def _get_pullback_primal_outputs(self, node):
return node.inputs[: self.n_outputs]

def _get_pullback_primal_inputs(self, node):
return node.inputs[self.n_outputs : self.n_outputs + self.n_inputs]

def _get_pullback_output_cotangents(self, node):
return node.inputs[self.n_outputs + self.n_inputs :]

def _get_pullback_input_cotangents(self, node):
return node.outputs

def _pullback_split_args(self, node):
return (
self._get_pullback_primal_outputs(node),
self._get_pullback_primal_inputs(node),
self._get_pullback_output_cotangents(node),
)

def perform(
self, node: Apply, inputs: Sequence[Any], output_storage: OutputStorageType
) -> None:
raise NotImplementedError(
"PullBackOp cannot be executed; it must be removed by rewrites."
)

def infer_shape(self, fgraph, node, shapes):
return shapes[self.n_outputs + self.n_inputs :]


class PushForwardOp(Op):
__props__ = ("n_outputs", "n_inputs")

def __init__(self, n_outputs, n_inputs):
self.n_outputs = n_outputs
self.n_inputs = n_inputs
super().__init__()

def make_node(self, *all_inputs) -> Apply:
# all_inputs is [*outputs, *inputs, *input_tangents]
if len(all_inputs) != self.n_outputs + 2 * self.n_inputs:
raise ValueError("Incorrect number of inputs")

outputs = all_inputs[: self.n_outputs]
inputs_input_tangents = all_inputs[self.n_outputs :]

inputs = inputs_input_tangents[: self.n_inputs]

output_tangents = [output.type() for output in outputs]

# TODO
continuous_dtypes = ["float64", "float32", "float16"]
for input in inputs:
if input.type.dtype not in continuous_dtypes:
raise ValueError(
f"Cannot compute push_forward for non-continuous value {input}"
)

return Apply(self, all_inputs, output_tangents)

def _get_push_forward_primal_outputs(self, node):
return node.inputs[: self.n_outputs]

def _get_push_forward_primal_inputs(self, node):
return node.inputs[self.n_outputs : self.n_outputs + self.n_inputs]

def _get_push_forward_output_tangents(self, node):
return node.outputs

def _get_push_forward_input_tangents(self, node):
return node.inputs[self.n_outputs + self.n_inputs :]

def _push_forward_split_args(self, node):
return (
self._get_push_forward_primal_outputs(node),
self._get_push_forward_primal_inputs(node),
self._get_push_forward_input_tangents(node),
)

def perform(
self, node: Apply, inputs: Sequence[Any], output_storage: OutputStorageType
) -> None:
raise NotImplementedError(
"PushForwardOp cannot be executed; it must be removed by rewrites."
)

def infer_shape(self, fgraph, node, shapes):
return shapes[: self.n_outputs]


def pull_back(outputs, inputs, output_cotangents):
op = PullBackOp(len(outputs), len(inputs))
return op(*outputs, *inputs, *output_cotangents, return_list=True)


def push_forward(outputs, inputs, input_tangents):
op = PushForwardOp(len(outputs), len(inputs))
return op(*outputs, *inputs, *input_tangents, return_list=True)
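
# A minimal usage sketch of the wrappers above (not part of the proposed
# changes; assumes `import pytensor.tensor as pt`). The calls only build
# symbolic PullBackOp / PushForwardOp nodes, which rewrites must remove
# before the graph can be compiled:
#
#   x = pt.dvector("x")
#   y = pt.tanh(x).sum()
#   (x_cotangent,) = pull_back([y], [x], [pt.constant(1.0)])
#   (y_tangent,) = push_forward([y], [x], [pt.ones_like(x)])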
115 changes: 115 additions & 0 deletions pytensor/graph/op.py
@@ -6,7 +6,9 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
Protocol,
Tuple,
TypeVar,
cast,
)
@@ -323,6 +325,119 @@ def __ne__(self, other: Any) -> bool:
# just to self.add_tag_trace
add_tag_trace = staticmethod(add_tag_trace)

def linear_transpose(
self,
node: Apply,
transposed_inputs: Sequence[Variable],
linear_inputs: Sequence[int],
linear_outputs: Sequence[int],
) -> Sequence[Variable]:
"""Transpose a linear function.

The map f from [node.inputs[i] for i in linear_inputs] to
[node.outputs[i] for i in linear_outputs], with the remaining inputs held
constant, must be linear. An Op implementing this method should return
f^*(transposed_inputs).

Parameters
----------
node: Apply
The point at which to do the transpose
transposed_inputs:
The inputs for the transposed function.
linear_inputs:
Indices of input arguments to consider.
linear_outputs:
Indices of output arguments to consider.
"""
raise NotImplementedError(f"Linear transpos of {self} is not defined or not implemented.")

def push_forward(
self,
node: Apply,
input_tangents: Sequence[Variable | None],
result_nums: Sequence[int],
) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
"""Compute the push_forward of tangent vectors at the specified point.

Parameters
----------
node: Apply
The point at which to compute the push_forward (i.e. at x = node.inputs
and f(x) = node.outputs).
input_tangents:
The values of the tangent vectors that we wish to map. Values that
are set to None are assumed to be constants.
result_nums:
Compute only the output tangents of [node.outputs[i] for i in result_nums].

Returns
-------
alternative_outputs:
Optionally a hint to the rewriter that the outputs of the op could
also be computed with the provided values, if the tangents are also
computed.
output_tangents:
The tangents of the outputs specified in result_nums.
If a value is None, the corresponding output does not depend on the
inputs for which tangents were provided.
"""
from pytensor.gradient import DisconnectedType
from pytensor.graph.null_type import NullType
from pytensor.tensor.basic import zeros_like

tangents_filled = [
# TODO do the R_op methods also accept a disconnected_grad?
# (Review note: no, but we should.)
tangent if tangent is not None else zeros_like(input)
for tangent, input in zip(input_tangents, node.inputs, strict=True)
]
output_tangents = self.R_op(node.inputs, tangents_filled)
output_tangents = [output_tangents[i] for i in result_nums]

mapped_output_tangents = []
for argnum, tangent in zip(result_nums, output_tangents, strict=True):
if isinstance(tangent.type, DisconnectedType):
mapped_output_tangents.append(None)
elif isinstance(tangent.type, NullType):
raise NotImplementedError(
f"The push_forward of argument {argnum} of op "
f"{self} is not implemented or not defined."
)
else:
mapped_output_tangents.append(tangent)
return (None, mapped_output_tangents)

def pull_back(
self,
node: Apply,
output_cotangents: Sequence[Variable | None],
argnums: Sequence[int],
) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
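"""Compute the pull_back (cotangent map) at the specified point.

This mirrors ``push_forward``: entries of ``output_cotangents`` may be
None when no cotangent is provided for that output, and only the input
cotangents selected by ``argnums`` (i.e. those of
[node.inputs[i] for i in argnums]) are returned, together with an
optional hint of alternative outputs for the rewriter.
"""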
from pytensor.gradient import DisconnectedType
from pytensor.graph.null_type import NullType
from pytensor.tensor.basic import zeros_like

cotangents_filled = [
# TODO do the L_op methods also accept a disconnected_grad?
cotangent if cotangent is not None else zeros_like(input)
for cotangent, input in zip(output_cotangents, node.outputs, strict=True)
]

input_cotangents = self.L_op(node.inputs, node.outputs, cotangents_filled)
input_cotangents = [input_cotangents[i] for i in argnums]

mapped_input_cotangents = []
for argnum, cotangent in zip(argnums, input_cotangents, strict=True):
if isinstance(cotangent.type, DisconnectedType):
mapped_input_cotangents.append(None)
elif isinstance(cotangent.type, NullType):
raise NotImplementedError(
f"The push_forward of argument {argnum} of op "
f"{self} is not implemented or not defined."
)
else:
mapped_input_cotangents.append(cotangent)
return (None, mapped_input_cotangents)

def grad(
self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
) -> list[Variable]:
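
A rough sketch of how an individual Op might override the new hooks directly
instead of falling back on the R_op/L_op based defaults above. This is purely
illustrative; the Double op and its methods are hypothetical and not taken
from this PR:

import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class Double(Op):
    """Hypothetical op computing 2 * x with explicit push_forward/pull_back."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = 2 * inputs[0]

    def push_forward(self, node, input_tangents, result_nums):
        # d(2x) = 2 dx; a None tangent marks a constant input.
        (dx,) = input_tangents
        tangent = None if dx is None else 2 * dx
        # No alternative outputs to suggest to the rewriter.
        return None, [tangent for _ in result_nums]

    def pull_back(self, node, output_cotangents, argnums):
        # The transpose of x -> 2 * x is u -> 2 * u.
        (dy,) = output_cotangents
        cotangent = None if dy is None else 2 * dy
        return None, [cotangent for _ in argnums]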