|
6 | 6 | from typing import (
|
7 | 7 | TYPE_CHECKING,
|
8 | 8 | Any,
|
| 9 | + Optional, |
9 | 10 | Protocol,
|
| 11 | + Tuple, |
10 | 12 | TypeVar,
|
11 | 13 | cast,
|
12 | 14 | )
|
@@ -323,6 +325,119 @@ def __ne__(self, other: Any) -> bool:
|
323 | 325 | # just to self.add_tag_trace
|
324 | 326 | add_tag_trace = staticmethod(add_tag_trace)
|
325 | 327 |
|
| 328 | + def linear_transpose( |
| 329 | + self, |
| 330 | + node: Apply, |
| 331 | + transposed_inputs: Sequence[Variable], |
| 332 | + linear_inputs: Sequence[int], |
| 333 | + linear_outputs: Sequence[int], |
| 334 | + ) -> Sequence[Variable]: |
| 335 | + """Transpose a linear function. |
| 336 | +
|
| 337 | + The function f mapping [node.inputs[i] for i in linear_inputs] to [node.outputs[i] for i in linear_outputs], |
| 338 | + with the remaining inputs treated as constants, must be linear. Implementations should |
| 339 | + return f^*(transposed_inputs), i.e. the transposed map applied to transposed_inputs. |
| 340 | +
|
| 341 | + Parameters |
| 342 | + ---------- |
| 343 | + node: Apply |
| 344 | + The point at which to compute the transpose. |
| 345 | + transposed_inputs: |
| 346 | + The inputs for the transposed function. |
| 347 | + linear_inputs: |
| 348 | + Indices of input arguments to consider. |
| 349 | + linear_outputs: |
| 350 | + Indices of output arguments to consider. |
| 351 | + """ |
| 352 | + raise NotImplementedError(f"Linear transpos of {self} is not defined or not implemented.") |
| 353 | + |
| 354 | + def push_forward( |
| 355 | + self, |
| 356 | + node: Apply, |
| 357 | + input_tangents: Sequence[Variable | None], |
| 358 | + result_nums: Sequence[int], |
| 359 | + ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]: |
| 360 | + """Compute the push_forward of tangent vectors at the specified point. |
| 361 | +
|
| 362 | + Parameters |
| 363 | + ---------- |
| 364 | + node: Apply |
| 365 | + The point at which to compute the push_forward (i.e. at x = node.inputs |
| 366 | + and f(x) = node.outputs). |
| 367 | + input_tangents: |
| 368 | + The values of the tangent vectors that we wish to map. Values that |
| 369 | + are set to None are assumed to be constants. |
| 370 | + result_nums: |
| 371 | + Compute only the output tangents of [node.outputs[i] for i in result_nums]. |
| 372 | +
|
| 373 | + Returns |
| 374 | + ------- |
| 375 | + alternative_outputs: |
| 376 | + Optionally a hint to the rewriter that the outputs of the op could |
| 377 | + also be computed with the provided values, if the tangents are also |
| 378 | + computed. |
| 379 | + output_tangents: |
| 380 | + The tangents of the outputs specified in result_nums. |
| 381 | + If a value is None, this indicates that the output does |
| 382 | + not depend on the inputs that had tangents provided. |
| 383 | + """ |
| 384 | + from pytensor.gradient import DisconnectedType |
| 385 | + from pytensor.graph.null_type import NullType |
| 386 | + from pytensor.tensor.basic import zeros_like |
| 387 | + |
| 388 | + tangents_filled = [ |
| 389 | + # TODO do the R_op methods also accept a disconnected_grad? |
| 390 | + tangent if tangent is not None else zeros_like(input) |
| 391 | + for tangent, input in zip(input_tangents, node.inputs, strict=True) |
| 392 | + ] |
| 393 | + output_tangents = self.R_op(node.inputs, tangents_filled) |
| 394 | + output_tangents = [output_tangents[i] for i in result_nums] |
| 395 | + |
| 396 | + mapped_output_tangents = [] |
| 397 | + for argnum, tangent in zip(result_nums, output_tangents): |
| 398 | + if isinstance(tangent.type, DisconnectedType): |
| 399 | + mapped_output_tangents.append(None) |
| 400 | + elif isinstance(tangent.type, NullType): |
| 401 | + raise NotImplementedError( |
| 402 | + f"The push_forward of argument {argnum} of op " |
| 403 | + f"{self} is not implemented or not defined." |
| 404 | + ) |
| 405 | + else: |
| 406 | + mapped_output_tangents.append(tangent) |
| 407 | + return (None, mapped_output_tangents) |
| 408 | + |
| 409 | + def pull_back( |
| 410 | + self, |
| 411 | + node: Apply, |
| 412 | + output_cotangents: Sequence[Variable | None], |
| 413 | + argnums: Sequence[int], |
| 414 | + ) -> Tuple[Sequence[Variable] | None, Sequence[Variable | None]]: |
| 415 | + from pytensor.gradient import DisconnectedType |
| 416 | + from pytensor.graph.null_type import NullType |
| 417 | + from pytensor.tensor.basic import zeros_like |
| 418 | + |
| 419 | + cotangents_filled = [ |
| 420 | + # TODO do the L_op methods also accept a disconnected_grad? |
| 421 | + cotangent if cotangent is not None else zeros_like(input) |
| 422 | + for cotangent, input in zip(output_cotangents, node.outputs, strict=True) |
| 423 | + ] |
| 424 | + |
| 425 | + input_cotangents = self.L_op(node.inputs, node.outputs, cotangents_filled) |
| 426 | + input_cotangents = [input_cotangents[i] for i in argnums] |
| 427 | + |
| 428 | + mapped_input_cotangents = [] |
| 429 | + for argnum, cotangent in zip(argnums, input_cotangents): |
| 430 | + if isinstance(cotangent.type, DisconnectedType): |
| 431 | + mapped_input_cotangents.append(None) |
| 432 | + elif isinstance(cotangent.type, NullType): |
| 433 | + raise NotImplementedError( |
| 434 | + f"The push_forward of argument {argnum} of op " |
| 435 | + f"{self} is not implemented or not defined." |
| 436 | + ) |
| 437 | + else: |
| 438 | + mapped_input_cotangents.append(cotangent) |
| 439 | + return (None, mapped_input_cotangents) |
| 440 | + |
326 | 441 | def grad(
|
327 | 442 | self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
|
328 | 443 | ) -> list[Variable]:
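
As a sketch of how the new `linear_transpose` hook could be overridden: the hypothetical `MatVec` Op below (not part of pytensor; the name, output type, and shapes are assumptions) computes `A @ x`, which is linear in `x` for a fixed `A`, so its transpose maps a vector `v` to `A.T @ v`.

```python
# Hypothetical Op used only to illustrate linear_transpose; it is not part of pytensor.
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class MatVec(Op):
    """Compute A @ x; linear in x when A is held constant."""

    def make_node(self, A, x):
        A = pt.as_tensor_variable(A)
        x = pt.as_tensor_variable(x)
        out = pt.tensor(dtype=x.dtype, shape=(None,))
        return Apply(self, [A, x], [out])

    def perform(self, node, inputs, output_storage):
        A, x = inputs
        output_storage[0][0] = A @ x

    def linear_transpose(self, node, transposed_inputs, linear_inputs, linear_outputs):
        # Only the map x -> A @ x (input index 1 to output index 0) is linear,
        # so this sketch handles just that case.
        assert list(linear_inputs) == [1] and list(linear_outputs) == [0]
        A = node.inputs[0]
        (v,) = transposed_inputs
        # The transpose of x -> A @ x is v -> A.T @ v.
        return [A.T @ v]
```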
|
|
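The default `push_forward` above delegates to `R_op`, so it can be exercised on any node whose Op already implements `R_op`. A minimal usage sketch (assuming the elementwise `exp` Op provides `R_op`; variable names are illustrative):

```python
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x)
node = y.owner  # the Apply node for the elementwise exp

x_tangent = pt.vector("x_tangent")

# A None entry in input_tangents marks the corresponding input as a constant.
# The first return value is an optional hint with alternative outputs.
_, output_tangents = node.op.push_forward(
    node, input_tangents=[x_tangent], result_nums=[0]
)
# output_tangents[0] is the directional derivative of exp at x along x_tangent,
# i.e. exp(x) * x_tangent.
```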
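Similarly, the default `pull_back` delegates to `L_op`: given cotangents for the outputs, it returns cotangents for the inputs selected by `argnums`. Another illustrative sketch, assuming `exp` provides `L_op`:

```python
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x)
node = y.owner

y_cotangent = pt.vector("y_cotangent")

# A None entry in output_cotangents is treated as a zero cotangent.
_, input_cotangents = node.op.pull_back(
    node, output_cotangents=[y_cotangent], argnums=[0]
)
# input_cotangents[0] is the usual reverse-mode contribution, exp(x) * y_cotangent.
```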