|
12 | 12 | from pytensor.compile.ops import ViewOp
|
13 | 13 | from pytensor.configdefaults import config
|
14 | 14 | from pytensor.graph import utils
|
15 |
| -from pytensor.graph.basic import Apply, NominalVariable, Variable |
| 15 | +from pytensor.graph.basic import Apply, NominalVariable, Variable, io_toposort |
16 | 16 | from pytensor.graph.null_type import NullType, null_type
|
17 | 17 | from pytensor.graph.op import get_test_values
|
18 | 18 | from pytensor.graph.type import Type
|
@@ -2292,3 +2292,90 @@ def grad_scale(x, multiplier):
|
2292 | 2292 | 0.416...
|
2293 | 2293 | """
|
2294 | 2294 | return GradScale(multiplier)(x)
|
| 2295 | + |
| 2296 | + |
| 2297 | +# =========================================== |
| 2298 | +# The following is more or less pseudocode... |
| 2299 | +# =========================================== |
| 2300 | + |
| 2301 | +# Use transpose and forward mode autodiff to get reverse mode autodiff |
| 2302 | +# Ops that only define push_forward (Rop) could use this, which is nice |
| 2303 | +# because push_forward is usually easier to derive and think about. |
| 2304 | +def pull_back_through_transpose(outputs, inputs, output_cotangents): |
| 2305 | + tangents = [input.type() for input in inputs] |
| 2306 | + output_tangents = push_forward(outputs, inputs, tangents) |
| 2307 | + return linear_transpose(output_tangents, tangents, output_cotangents) |
| 2308 | + |
| 2309 | + |
| 2310 | +# Ops that only define pull_back (Lop) could use this to derive push_forward. |
| 2311 | +def push_forward_through_pull_back(outputs, inputs, tangents): |
| 2312 | + cotangents = [out.type("u") for out in outputs] |
| 2313 | + input_cotangents = pull_back(outputs, inputs, cotangents) |
| 2314 | + return pull_back(input_cotangents, cotangents, tangents) |
| 2315 | + |
| 2316 | + |
| 2317 | +def push_forward(outputs, inputs, input_tangents): |
| 2318 | + # Get the nodes in topological order and precompute |
| 2319 | + # a set of values that are used in the graph. |
| 2320 | + nodes = io_toposort(inputs, outputs) |
| 2321 | + used_values = set(outputs) |
| 2322 | + for node in reversed(nodes): |
| 2323 | + if any(output in used_values for output in node.outputs): |
| 2324 | + used_values.update(node.inputs) |
| 2325 | + |
| 2326 | + # Maybe a lazy gradient op could use this during rewrite time? |
| 2327 | + recorded_rewrites = {} |
| 2328 | + known_tangents = dict(zip(inputs, input_tangents, strict=True)) |
| 2329 | + for node in nodes: |
| 2330 | + tangents = [known_tangents.get(input, None) for input in node.inputs] |
| 2331 | + result_nums = [i for i in range(len(node.outputs)) if node.outputs[i] in used_values] |
| 2332 | + new_outputs, output_tangents = node.op.push_forward(node, tangents, result_nums) |
| 2333 | + if new_outputs is not None: |
| 2334 | + recorded_rewrites[node] = new_outputs |
| 2335 | + |
| 2336 | + for i, tangent in zip(result_nums, output_tangents, strict=True): |
| 2337 | + known_tangents[node.outputs[i]] = tangent |
| 2338 | + |
| 2339 | + return [known_tangents[output] for output in outputs] |
| 2340 | + |
| 2341 | + |
| 2342 | +def pull_back(outputs, inputs, output_cotangents): |
| 2343 | + known_cotangents = dict(zip(outputs, output_cotangents, strict=True)) |
| 2344 | + |
| 2345 | + nodes = io_toposort(inputs, outputs) |
| 2346 | + used_values = set(outputs) |
| 2347 | + for node in reversed(nodes): |
| 2348 | + if any(output in used_values for output in node.outputs): |
| 2349 | + used_values.update(node.inputs) |
| 2350 | + |
| 2351 | + # Maybe a lazy gradient op could use this during rewrite time? |
| 2352 | + recorded_rewrites = {} |
| 2353 | + for node in reversed(nodes): |
| 2354 | + cotangents = [known_cotangents.get(output, None) for output in node.outputs] |
| 2355 | + argnums = [i for i in range(len(node.inputs)) if node.inputs[i] in used_values] |
| 2356 | + new_outputs, input_cotangents = node.op.pull_back(node, cotangents, argnums) |
| 2357 | + if new_outputs is not None: |
| 2358 | + recorded_rewrites[node] = new_outputs |
| 2359 | + |
| 2360 | + for i, cotangent in zip(argnums, input_cotangents, strict=True): |
| 2361 | + input = node.inputs[i] |
| 2362 | + if input not in known_cotangents: |
| 2363 | + known_cotangents[input] = cotangent |
| 2364 | + else: |
| 2365 | + # TODO check that we are not broadcasting? |
| 2366 | + known_cotangents[input] += cotangent |
| 2367 | + |
| 2368 | + return [known_cotangents[input] for input in inputs] |
| 2369 | + |
| 2370 | +def pullback_grad(cost, wrt): |
| 2371 | + """A new pt.grad that uses the pull_back function. |
| 2372 | +
|
| 2373 | + At some point we might want to replace pt.grad with this? |
| 2374 | + """ |
| 2375 | + # Error checking and allow non-list wrt... |
| 2376 | + return pull_back([cost], wrt, [1.]) |
| 2377 | + |
| 2378 | +def linear_transpose(outputs, inputs, transposed_inputs): |
| 2379 | + """Given a linear function from inputs to outputs, return the transposed function.""" |
| 2380 | + # some loop over inv_toposort... |
| 2381 | + # Should look similar to pull_back? |
0 commit comments