
Commit 78eac91

Add rough version of an autodiff refactor
1 parent fc21336


2 files changed: +186 -0 lines changed


pytensor/gradient.py

Lines changed: 71 additions & 0 deletions
@@ -2292,3 +2292,74 @@ def grad_scale(x, multiplier):
     0.416...
     """
     return GradScale(multiplier)(x)
+
+
+# ===========================================
+# The following is more or less pseudocode...
+# ===========================================
+
+# Use transpose and forward mode autodiff to get reverse mode autodiff.
+# Ops that only define push_forward (Rop) could use this, which is nice
+# because push_forward is usually easier to derive and think about.
+def pull_back_through_transpose(outputs, inputs, output_cotangents):
+    tangents = [input.type() for input in inputs]
+    output_tangents = push_forward(outputs, inputs, tangents)
+    return linear_transpose(output_tangents, tangents, output_cotangents)
+
+
+# Ops that only define pull_back (Lop) could use this to derive push_forward.
+def push_forward_through_pull_back(outputs, inputs, tangents):
+    cotangents = [out.type("u") for out in outputs]
+    input_cotangents = pull_back(outputs, inputs, cotangents)
+    return pull_back(input_cotangents, cotangents, tangents)
+
+
+def push_forward(outputs, inputs, input_tangents):
+    # Get the nodes in topological order and precompute
+    # a set of values that are used in the graph.
+    nodes, used_values = toposort_and_intermediate(outputs, inputs)
+    # Maybe a lazy gradient op could use this during rewrite time?
+    recorded_rewrites = {}
+    known_tangents = dict(zip(inputs, input_tangents, strict=True))
+    for node in nodes:
+        tangents = [known_tangents.get(input, None) for input in node.inputs]
+        result_nums = [i for i in range(len(node.outputs)) if node.outputs[i] in used_values]
+        new_outputs, output_tangents = node.op.push_forward(node, tangents, result_nums)
+        if new_outputs is not None:
+            recorded_rewrites[node] = new_outputs
+
+        for i, tangent in zip(result_nums, output_tangents, strict=True):
+            known_tangents[node.outputs[i]] = tangent
+
+    return [known_tangents[output] for output in outputs]
+
+
+def pull_back(outputs, inputs, output_cotangents):
+    known_cotangents = dict(zip(outputs, output_cotangents, strict=True))
+
+    nodes, used_values = toposort_and_intermediate(outputs, inputs)
+
+    # Maybe a lazy gradient op could use this during rewrite time?
+    recorded_rewrites = {}
+    for node in reversed(nodes):
+        cotangents = [known_cotangents.get(output, None) for output in node.outputs]
+        argnums = [i for i in range(len(node.inputs)) if node.inputs[i] in used_values]
+        new_outputs, input_cotangents = node.op.pull_back(node, cotangents, argnums)
+        if new_outputs is not None:
+            recorded_rewrites[node] = new_outputs
+
+        for i, cotangent in zip(argnums, input_cotangents, strict=True):
+            inp = node.inputs[i]
+            if inp not in known_cotangents:
+                known_cotangents[inp] = cotangent
+            else:
+                # TODO check that we are not broadcasting?
+                known_cotangents[inp] += cotangent
+
+    return [known_cotangents[input] for input in inputs]
+
+
+def linear_transpose(outputs, inputs, transposed_inputs):
+    """Given a linear function from inputs to outputs, return the transposed function."""
+    # some loop over inv_toposort...
+    # Should look similar to pull_back?
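
The helpers above only sketch the intended call pattern, so here is a minimal usage sketch. It is not part of this commit and assumes the pseudocode above has been filled in (including the hypothetical toposort_and_intermediate helper); the variables are invented for illustration.

# Hypothetical usage of the proposed graph-level helpers (not in this commit).
import pytensor.tensor as pt

x = pt.dscalar("x")
y = pt.dscalar("y")
z = pt.exp(x) * y

# Reverse mode: map a cotangent of z back to cotangents of x and y.
z_cot = pt.dscalar("z_cot")
x_cot, y_cot = pull_back([z], [x, y], [z_cot])
# Expected: x_cot = exp(x) * y * z_cot and y_cot = exp(x) * z_cot.

# Forward mode: map tangents of x and y to a tangent of z.
x_tan = pt.dscalar("x_tan")
y_tan = pt.dscalar("y_tan")
(z_tan,) = push_forward([z], [x, y], [x_tan, y_tan])
# Expected: z_tan = exp(x) * y * x_tan + exp(x) * y_tan.

In spirit this mirrors the existing Lop and Rop helpers in pytensor.gradient; the new functions just make the per-node bookkeeping (argnums/result_nums and the recorded_rewrites hints) explicit.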

pytensor/graph/op.py

Lines changed: 115 additions & 0 deletions
@@ -6,7 +6,9 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Optional,
     Protocol,
+    Tuple,
     TypeVar,
     cast,
 )
@@ -323,6 +325,119 @@ def __ne__(self, other: Any) -> bool:
     # just to self.add_tag_trace
     add_tag_trace = staticmethod(add_tag_trace)
 
+    def linear_transpose(
+        self,
+        node: Apply,
+        transposed_inputs: Sequence[Variable],
+        linear_inputs: Sequence[int],
+        linear_outputs: Sequence[int],
+    ) -> Sequence[Variable]:
+        """Transpose a linear function.
+
+        The function f: [node.inputs[i] for i in linear_inputs] -> [node.outputs[i] for i in linear_outputs],
+        with the remaining inputs held constant, must be linear. This method can then
+        be implemented by an Op, and should return f^*(transposed_inputs).
+
+        Parameters
+        ----------
+        node: Apply
+            The point at which to do the transpose.
+        transposed_inputs:
+            The inputs for the transposed function.
+        linear_inputs:
+            Indices of input arguments to consider.
+        linear_outputs:
+            Indices of output arguments to consider.
+        """
+        raise NotImplementedError(f"Linear transpose of {self} is not defined or not implemented.")
+
+    def push_forward(
+        self,
+        node: Apply,
+        input_tangents: Sequence[Variable | None],
+        result_nums: Sequence[int],
+    ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
+        """Compute the push_forward of tangent vectors at the specified point.
+
+        Parameters
+        ----------
+        node: Apply
+            The point at which to compute the push_forward (i.e. at x = node.inputs
+            and f(x) = node.outputs).
+        input_tangents:
+            The values of the tangent vectors that we wish to map. Values that
+            are set to None are assumed to be constants.
+        result_nums:
+            Compute only the output tangents of [node.outputs[i] for i in result_nums].
+
+        Returns
+        -------
+        alternative_outputs:
+            Optionally a hint to the rewriter that the outputs of the op could
+            also be computed with the provided values, if the tangents are also
+            computed.
+        output_tangents:
+            The tangents of the outputs specified in result_nums.
+            If the value is None, this indicates that the output did
+            not depend on the inputs that had tangents provided.
+        """
+        from pytensor.gradient import DisconnectedType
+        from pytensor.graph.null_type import NullType
+        from pytensor.tensor.basic import zeros_like
+
+        tangents_filled = [
+            # TODO do the R_op methods also accept a disconnected_grad?
+            tangent if tangent is not None else zeros_like(input)
+            for tangent, input in zip(input_tangents, node.inputs, strict=True)
+        ]
+        output_tangents = self.R_op(node.inputs, tangents_filled)
+        output_tangents = [output_tangents[i] for i in result_nums]
+
+        mapped_output_tangents = []
+        for result_num, tangent in zip(result_nums, output_tangents):
+            if isinstance(tangent.type, DisconnectedType):
+                mapped_output_tangents.append(None)
+            elif isinstance(tangent.type, NullType):
+                raise NotImplementedError(
+                    f"The push_forward of output {result_num} of op "
+                    f"{self} is not implemented or not defined."
+                )
+            else:
+                mapped_output_tangents.append(tangent)
+        return (None, mapped_output_tangents)
+
+    def pull_back(
+        self,
+        node: Apply,
+        output_cotangents: Sequence[Variable | None],
+        argnums: Sequence[int],
+    ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]:
+        from pytensor.gradient import DisconnectedType
+        from pytensor.graph.null_type import NullType
+        from pytensor.tensor.basic import zeros_like
+
+        cotangents_filled = [
+            # TODO do the L_op methods also accept a disconnected_grad?
+            cotangent if cotangent is not None else zeros_like(output)
+            for cotangent, output in zip(output_cotangents, node.outputs, strict=True)
+        ]
+
+        input_cotangents = self.L_op(node.inputs, node.outputs, cotangents_filled)
+        input_cotangents = [input_cotangents[i] for i in argnums]
+
+        mapped_input_cotangents = []
+        for argnum, cotangent in zip(argnums, input_cotangents):
+            if isinstance(cotangent.type, DisconnectedType):
+                mapped_input_cotangents.append(None)
+            elif isinstance(cotangent.type, NullType):
+                raise NotImplementedError(
+                    f"The pull_back of argument {argnum} of op "
+                    f"{self} is not implemented or not defined."
+                )
+            else:
+                mapped_input_cotangents.append(cotangent)
+        return (None, mapped_input_cotangents)
+
     def grad(
         self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
     ) -> list[Variable]:
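
For context, here is a hypothetical example (not part of the diff) of how a simple Op could implement the proposed hooks directly instead of relying on the R_op/L_op based defaults above. The Op ScaleByTwo and everything about it is made up for illustration; only the hook signatures are taken from the methods added in this commit.

# Made-up Op that multiplies its input by 2 and implements the proposed hooks.
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class ScaleByTwo(Op):
    def make_node(self, x):
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = 2 * inputs[0]

    def push_forward(self, node, input_tangents, result_nums):
        # The map x -> 2 * x is linear, so tangents map the same way: dy = 2 * dx.
        (x_tangent,) = input_tangents
        y_tangent = None if x_tangent is None else 2 * x_tangent
        # No alternative outputs to suggest to the rewriter, hence None.
        return None, [y_tangent for _ in result_nums]

    def pull_back(self, node, output_cotangents, argnums):
        # The transpose of "multiply by 2" is again "multiply by 2".
        (y_cotangent,) = output_cotangents
        x_cotangent = None if y_cotangent is None else 2 * y_cotangent
        return None, [x_cotangent for _ in argnums]

    def linear_transpose(self, node, transposed_inputs, linear_inputs, linear_outputs):
        # The Op is linear in its single input, so its transpose is the Op itself.
        (u,) = transposed_inputs
        return [2 * u]

Ops that only define R_op or L_op would not need any of this: they would get push_forward and pull_back from the default implementations above, and could in principle obtain the missing direction via push_forward_through_pull_back or pull_back_through_transpose from pytensor/gradient.py.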
