diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 78862de7e1..9e25d4f77f 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -4,7 +4,7 @@ import warnings from collections.abc import Callable, Mapping, MutableSequence, Sequence from functools import partial, reduce -from typing import TYPE_CHECKING, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union import numpy as np @@ -12,9 +12,9 @@ from pytensor.compile.ops import ViewOp from pytensor.configdefaults import config from pytensor.graph import utils -from pytensor.graph.basic import Apply, NominalVariable, Variable +from pytensor.graph.basic import Apply, NominalVariable, Variable, io_toposort from pytensor.graph.null_type import NullType, null_type -from pytensor.graph.op import get_test_values +from pytensor.graph.op import Op, OutputStorageType, get_test_values from pytensor.graph.type import Type @@ -2292,3 +2292,227 @@ def grad_scale(x, multiplier): 0.416... """ return GradScale(multiplier)(x) + + +# =========================================== +# The following is more or less pseudocode... +# =========================================== + + +# Use transpose and forward mode autodiff to get reverse mode autodiff +# Ops that only define push_forward (Rop) could use this, which is nice +# because push_forward is usually easier to derive and think about. +def pull_back_through_transpose(outputs, inputs, output_cotangents): + tangents = [input.type() for input in inputs] + output_tangents = push_forward(outputs, inputs, tangents) + return linear_transpose(output_tangents, tangents, output_cotangents) + + +# Ops that only define pull_back (Lop) could use this to derive push_forward. +def push_forward_through_pull_back(outputs, inputs, tangents): + cotangents = [out.type("u") for out in outputs] + input_cotangents = pull_back(outputs, inputs, cotangents) + return pull_back(input_cotangents, cotangents, tangents) + + +def _push_forward_impl(outputs, inputs, input_tangents): + # Get the nodes in topological order and precompute + # a set of values that are used in the graph. + nodes = io_toposort(inputs, outputs) + used_values = set(outputs) + for node in reversed(nodes): + if any(output in used_values for output in node.outputs): + used_values.update(node.inputs) + + # Maybe a lazy gradient op could use this during rewrite time? + recorded_rewrites = {} + known_tangents = dict(zip(inputs, input_tangents, strict=True)) + for node in nodes: + tangents = [known_tangents.get(input, None) for input in node.inputs] + result_nums = [ + i for i in range(len(node.outputs)) if node.outputs[i] in used_values + ] + new_outputs, output_tangents = node.op.push_forward(node, tangents, result_nums) + if new_outputs is not None: + recorded_rewrites[node] = new_outputs + + for i, tangent in zip(result_nums, output_tangents, strict=True): + known_tangents[node.outputs[i]] = tangent + + return [known_tangents[output] for output in outputs] + + +def _pull_back_impl(outputs, inputs, output_cotangents): + known_cotangents = dict(zip(outputs, output_cotangents, strict=True)) + + nodes = io_toposort(inputs, outputs) + used_values = set(outputs) + for node in reversed(nodes): + if any(output in used_values for output in node.outputs): + used_values.update(node.inputs) + + # Maybe a lazy gradient op could use this during rewrite time? 
+    recorded_rewrites = {}
+    for node in reversed(nodes):
+        cotangents = [known_cotangents.get(output, None) for output in node.outputs]
+        argnums = [i for i in range(len(node.inputs)) if node.inputs[i] in used_values]
+        new_outputs, input_cotangents = node.op.pull_back(node, cotangents, argnums)
+        if new_outputs is not None:
+            recorded_rewrites[node] = new_outputs
+
+        for i, cotangent in zip(argnums, input_cotangents, strict=True):
+            if cotangent is None:
+                continue
+            input = node.inputs[i]
+            if input not in known_cotangents:
+                known_cotangents[input] = cotangent
+            else:
+                # TODO check that we are not broadcasting?
+                known_cotangents[input] += cotangent
+
+    return [known_cotangents[input] for input in inputs]
+
+
+def pullback_grad(cost, wrt):
+    """A new pt.grad that uses the pull_back function.
+
+    At some point we might want to replace pt.grad with this?
+    """
+    from pytensor.tensor import as_tensor_variable
+
+    # TODO: error checking, and allow a non-list wrt...
+    return pull_back([cost], wrt, [as_tensor_variable(1.0)])
+
+
+def linear_transpose(outputs, inputs, transposed_inputs):
+    """Given a linear function from inputs to outputs, return the transposed function."""
+    # Some loop over io_toposort in reverse order...
+    # Should look similar to pull_back?
+    raise NotImplementedError("Graph-level linear_transpose is not implemented yet.")
+
+
+class PullBackOp(Op):
+    __props__ = ("n_outputs", "n_inputs")
+
+    def __init__(self, n_outputs, n_inputs):
+        self.n_outputs = n_outputs
+        self.n_inputs = n_inputs
+        super().__init__()
+
+    def make_node(self, *all_inputs) -> Apply:
+        # all_inputs is [*outputs, *inputs, *output_cotangents]
+        if len(all_inputs) != 2 * self.n_outputs + self.n_inputs:
+            raise ValueError("Incorrect number of inputs")
+
+        inputs_output_cotangents = all_inputs[self.n_outputs :]
+        inputs = inputs_output_cotangents[: self.n_inputs]
+
+        input_cotangents = [input.type() for input in inputs]
+
+        # TODO proper check for continuous (differentiable) dtypes
+        continuous_dtypes = ["float64", "float32", "float16"]
+        for input in inputs:
+            if input.type.dtype not in continuous_dtypes:
+                raise ValueError(
+                    f"Cannot compute pull_back for non-continuous value {input}"
+                )
+
+        return Apply(self, all_inputs, input_cotangents)
+
+    def _get_pullback_primal_outputs(self, node):
+        return node.inputs[: self.n_outputs]
+
+    def _get_pullback_primal_inputs(self, node):
+        return node.inputs[self.n_outputs : self.n_outputs + self.n_inputs]
+
+    def _get_pullback_output_cotangents(self, node):
+        return node.inputs[self.n_outputs + self.n_inputs :]
+
+    def _get_pullback_input_cotangents(self, node):
+        return node.outputs
+
+    def _pullback_split_args(self, node):
+        return (
+            self._get_pullback_primal_outputs(node),
+            self._get_pullback_primal_inputs(node),
+            self._get_pullback_output_cotangents(node),
+        )
+
+    def perform(
+        self, node: Apply, inputs: Sequence[Any], output_storage: OutputStorageType
+    ) -> None:
+        raise NotImplementedError(
+            "PullBackOp cannot be executed; it needs to be removed in rewrites"
+        )
+
+    def infer_shape(self, fgraph, node, shapes):
+        return shapes[self.n_outputs + self.n_inputs :]
+
+
+class PushForwardOp(Op):
+    __props__ = ("n_outputs", "n_inputs")
+
+    def __init__(self, n_outputs, n_inputs):
+        self.n_outputs = n_outputs
+        self.n_inputs = n_inputs
+        super().__init__()
+
+    def make_node(self, *all_inputs) -> Apply:
+        # all_inputs is [*outputs, *inputs, *input_tangents]
+        if len(all_inputs) != self.n_outputs + 2 * self.n_inputs:
+            raise ValueError("Incorrect number of inputs")
+
+        outputs = all_inputs[: self.n_outputs]
+        inputs_input_tangents = all_inputs[self.n_outputs :]
+
+        inputs = inputs_input_tangents[: self.n_inputs]
+
+        output_tangents = [output.type() for output in outputs]
+
+        # TODO proper check for continuous (differentiable) dtypes
+        continuous_dtypes = ["float64", "float32", "float16"]
+        for input in inputs:
+            if input.type.dtype not in continuous_dtypes:
+                raise ValueError(
+                    f"Cannot compute push_forward for non-continuous value {input}"
+                )
+
+        return Apply(self, all_inputs, output_tangents)
+
+    def _get_push_forward_primal_outputs(self, node):
+        return node.inputs[: self.n_outputs]
+
+    def _get_push_forward_primal_inputs(self, node):
+        return node.inputs[self.n_outputs : self.n_outputs + self.n_inputs]
+
+    def _get_push_forward_output_tangents(self, node):
+        return node.outputs
+
+    def _get_push_forward_input_tangents(self, node):
+        return node.inputs[self.n_outputs + self.n_inputs :]
+
+    def _push_forward_split_args(self, node):
+        return (
+            self._get_push_forward_primal_outputs(node),
+            self._get_push_forward_primal_inputs(node),
+            self._get_push_forward_input_tangents(node),
+        )
+
+    def perform(
+        self, node: Apply, inputs: Sequence[Any], output_storage: OutputStorageType
+    ) -> None:
+        raise NotImplementedError(
+            "PushForwardOp cannot be executed; it needs to be removed in rewrites"
+        )
+
+    def infer_shape(self, fgraph, node, shapes):
+        return shapes[: self.n_outputs]
+
+
+def pull_back(outputs, inputs, output_cotangents):
+    op = PullBackOp(len(outputs), len(inputs))
+    return op(*outputs, *inputs, *output_cotangents, return_list=True)
+
+
+def push_forward(outputs, inputs, input_tangents):
+    op = PushForwardOp(len(outputs), len(inputs))
+    return op(*outputs, *inputs, *input_tangents, return_list=True)
diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 160a65dd7a..acd7dec11f 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -6,7 +6,9 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Optional,
     Protocol,
+    Tuple,
     TypeVar,
     cast,
 )
@@ -323,6 +325,119 @@ def __ne__(self, other: Any) -> bool:
     # just to self.add_tag_trace
     add_tag_trace = staticmethod(add_tag_trace)
 
+    def linear_transpose(
+        self,
+        node: Apply,
+        transposed_inputs: Sequence[Variable],
+        linear_inputs: Sequence[int],
+        linear_outputs: Sequence[int],
+    ) -> Sequence[Variable]:
+        """Transpose a linear function.
+
+        The map f from [node.inputs[i] for i in linear_inputs] to
+        [node.outputs[i] for i in linear_outputs], with the remaining inputs
+        held constant, must be linear. An Op can then implement this method
+        and return f^*(transposed_inputs).
+
+        Parameters
+        ----------
+        node: Apply
+            The point at which to do the transpose.
+        transposed_inputs:
+            The inputs for the transposed function.
+        linear_inputs:
+            Indices of input arguments to consider.
+        linear_outputs:
+            Indices of output arguments to consider.
+        """
+        raise NotImplementedError(
+            f"Linear transpose of {self} is not defined or not implemented."
+        )
+
+    def push_forward(
+        self,
+        node: Apply,
+        input_tangents: Sequence[Variable | None],
+        result_nums: Sequence[int],
+    ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
+        """Compute the push_forward of tangent vectors at the specified point.
+
+        Parameters
+        ----------
+        node: Apply
+            The point at which to compute the push_forward (i.e. at x = node.inputs
+            and f(x) = node.outputs).
+        input_tangents:
+            The values of the tangent vectors that we wish to map. Values that
+            are set to None are assumed to be constants.
+        result_nums:
+            Only compute the tangents of [node.outputs[i] for i in result_nums].
+
+        Returns
+        -------
+        alternative_outputs:
+            Optionally, a hint to the rewriter: alternative expressions for
+            the outputs of the op that may be used if the tangents are
+            computed as well.
+        output_tangents:
+            The tangents of the outputs specified in result_nums. If a value
+            is None, the corresponding output does not depend on the inputs
+            that had tangents provided.
+        """
+        from pytensor.gradient import DisconnectedType
+        from pytensor.graph.null_type import NullType
+        from pytensor.tensor.basic import zeros_like
+
+        tangents_filled = [
+            # TODO do the R_op methods also accept a disconnected_grad?
+            tangent if tangent is not None else zeros_like(input)
+            for tangent, input in zip(input_tangents, node.inputs, strict=True)
+        ]
+        output_tangents = self.R_op(node.inputs, tangents_filled)
+        output_tangents = [output_tangents[i] for i in result_nums]
+
+        mapped_output_tangents = []
+        for argnum, tangent in zip(result_nums, output_tangents, strict=True):
+            if isinstance(tangent.type, DisconnectedType):
+                mapped_output_tangents.append(None)
+            elif isinstance(tangent.type, NullType):
+                raise NotImplementedError(
+                    f"The push_forward of argument {argnum} of op "
+                    f"{self} is not implemented or not defined."
+                )
+            else:
+                mapped_output_tangents.append(tangent)
+        return (None, mapped_output_tangents)
+
+    def pull_back(
+        self,
+        node: Apply,
+        output_cotangents: Sequence[Variable | None],
+        argnums: Sequence[int],
+    ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
+        """Compute the pull_back of cotangent vectors at the specified point.
+
+        output_cotangents are the cotangents of node.outputs (None meaning
+        constant or unused) and argnums selects which input cotangents to
+        return. The return value mirrors push_forward.
+        """
+        from pytensor.gradient import DisconnectedType
+        from pytensor.graph.null_type import NullType
+        from pytensor.tensor.basic import zeros_like
+
+        cotangents_filled = [
+            # TODO do the L_op methods also accept a disconnected_grad?
+            cotangent if cotangent is not None else zeros_like(output)
+            for cotangent, output in zip(output_cotangents, node.outputs, strict=True)
+        ]
+
+        input_cotangents = self.L_op(node.inputs, node.outputs, cotangents_filled)
+        input_cotangents = [input_cotangents[i] for i in argnums]
+
+        mapped_input_cotangents = []
+        for argnum, cotangent in zip(argnums, input_cotangents, strict=True):
+            if isinstance(cotangent.type, DisconnectedType):
+                mapped_input_cotangents.append(None)
+            elif isinstance(cotangent.type, NullType):
+                raise NotImplementedError(
+                    f"The pull_back of argument {argnum} of op "
+                    f"{self} is not implemented or not defined."
+                )
+            else:
+                mapped_input_cotangents.append(cotangent)
+        return (None, mapped_input_cotangents)
+
     def grad(
         self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
     ) -> list[Variable]:
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 06d023d780..1fbb180853 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -2366,27 +2366,32 @@ def local_log_add_exp(fgraph, node):
 
     TODO: in canonicalize, change log10 and log2 -> log
     """
+    z = node.inputs[0]
+    if not z.owner or z.owner.op != add:
+        return
 
-    if node.op == log:
-        z = node.inputs[0]
-        if z.owner and z.owner.op == add:
-            zi = z.owner.inputs
-            pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
-            # all arguments to add are exp()
-            if len(pre_exp) == len(zi):
-                # Do not offset when max_pre = -np.inf, to avoid nan in the output
-                # Switch statement is placed directly inside add to break the self-symmetry
-                # of the returned output (otherwise the rewrite would not stabilize)
-                max_pre = reduce(maximum, pre_exp)
-                ret = max_pre + log(
-                    add(
-                        *[
-                            switch(isinf(max_pre), exp(max_pre), exp(p - max_pre))
-                            for p in pre_exp
-                        ]
-                    )
-                )
-                return [ret]
+    zi = z.owner.inputs
+    pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp]
+
+    # all arguments to add are exp()
+    if len(pre_exp) != len(zi):
+        return
+
+    if len(zi) == 2:
+        a, b = pre_exp
+        replace_val = switch(a > b, a + log1p(exp(b - a)), b + log1p(exp(a - b)))
+        # Handle a == b explicitly (covers +/-inf, where b - a would be nan)
+        replace_val = switch(eq(a, b), a + log(2), replace_val)
+        return [replace_val]
+
+    # Do not offset when max_pre = -np.inf, to avoid nan in the output
+    # Switch statement is placed directly inside add to break the self-symmetry
+    # of the returned output (otherwise the rewrite would not stabilize)
+    max_pre = reduce(maximum, pre_exp)
+    ret = max_pre + log(
+        add(*[switch(isinf(max_pre), exp(max_pre), exp(p - max_pre)) for p in pre_exp])
+    )
+    return [ret]
 
 
 @register_stabilize
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index c89049105f..43315b46c0 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -9,7 +9,7 @@
 import pytensor
 from pytensor import scalar as ps
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType
+from pytensor.gradient import DisconnectedType, linear_transpose, push_forward
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
@@ -837,6 +837,27 @@ def infer_shape(self, fgraph, node, shapes):
         assert len(outshp) == node.outputs[0].ndim
         return [outshp]
 
+    def linear_transpose(self, node, transposed_inputs, linear_inputs, linear_outputs):
+        assert linear_inputs == [0]
+        assert linear_outputs == [0]
+        (transposed_input,) = transposed_inputs
+
+        x, *others = node.inputs
+        return [IncSubtensor(self.idx_list)(x.zeros_like(), transposed_input, *others)]
+
+    def push_forward(self, node, input_tangents, result_nums):
+        if len(result_nums) == 0:
+            return None, []
+
+        assert result_nums[0] == 0
+
+        value_tangent, *_ = input_tangents
+        if value_tangent is None:
+            return None, [None]
+
+        _, *others = node.inputs
+        return None, [self(value_tangent, *others)]
+
     def grad(self, inputs, grads):
         (gz,) = grads
         x = inputs[0]
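
Usage sketch (not part of the patch): assuming the pull_back and push_forward
helpers land in pytensor.gradient as written above, calling code could look like
the snippet below. The variables x, y, y_bar and x_dot are purely illustrative.

    import pytensor.tensor as pt
    from pytensor.gradient import pull_back, push_forward

    x = pt.dvector("x")
    y = pt.sin(x) * 2.0

    # Reverse mode: pull a cotangent of y back to a cotangent of x (a VJP).
    y_bar = pt.dvector("y_bar")
    (x_bar,) = pull_back([y], [x], [y_bar])

    # Forward mode: push a tangent of x forward to a tangent of y (a JVP).
    x_dot = pt.dvector("x_dot")
    (y_dot,) = push_forward([y], [x], [x_dot])

Both calls only insert lazy PullBackOp / PushForwardOp nodes; since their perform
methods raise, these graphs can be compiled only after a rewrite (for example one
built on _pull_back_impl / _push_forward_impl) has replaced the lazy nodes with
the actual derivative graphs obtained from each op's pull_back / push_forward
method.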