@@ -8,37 +8,41 @@
 def lltm(
     input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
 ) -> Tuple[Tensor, Tensor]:
-    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
-
-
-class LLTMFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weights, bias, old_h, old_cell):
-        outputs = torch.ops.extension_cpp.lltm_forward.default(
-            input, weights, bias, old_h, old_cell
-        )
-        new_h, new_cell = outputs[:2]
-        variables = list(outputs[1:]) + [weights]
-        ctx.save_for_backward(*variables)
-
-        return new_h, new_cell
-
-    @staticmethod
-    @torch.autograd.function.once_differentiable
-    def backward(ctx, grad_h, grad_cell):
-        (
-            d_old_h,
-            d_input,
-            d_weights,
-            d_bias,
-            d_old_cell,
-        ) = torch.ops.extension_cpp.lltm_backward.default(
-            grad_h, grad_cell, *ctx.saved_tensors
-        )
-        return d_input, d_weights, d_bias, d_old_h, d_old_cell
-
-
-@torch.library.impl_abstract("extension_cpp::lltm_forward")
+    """The lltm API"""
+    outputs = torch.ops.extension_cpp.lltm_forward.default(
+        input, weights, bias, old_h, old_cell
+    )
+    new_h, new_cell = outputs[:2]
+    return new_h, new_cell
+
+
+# This is the backward for lltm_forward.
+# lltm_forward has 7 returns so they all get gradients.
+def backward(ctx, grad_h, grad_cell, _0, _1, _2, _3, _4):
+    (
+        d_old_h,
+        d_input,
+        d_weights,
+        d_bias,
+        d_old_cell,
+    ) = torch.ops.extension_cpp.lltm_backward.default(
+        grad_h, grad_cell, *ctx.saved_tensors
+    )
+    return d_input, d_weights, d_bias, d_old_h, d_old_cell
+
+
+def setup_context(ctx, inputs, output):
+    weights = inputs[1]
+    new_h, new_cell = output[:2]
+    variables = list(output[1:]) + [weights]
+    ctx.save_for_backward(*variables)
+
+
+torch.library.register_autograd(
+    "extension_cpp::lltm_forward", backward, setup_context=setup_context)
+
+
+@torch.library.register_fake("extension_cpp::lltm_forward")
 def _(input, weights, bias, old_h, old_cell):
     X = torch.cat([old_h, input], dim=1)
     gate_weights = torch.nn.functional.linear(X, weights, bias)
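Not part of the diff: a minimal sketch of how the re-registered op might be exercised after this change. It assumes the compiled extension_cpp op library is importable, that the `lltm` wrapper above lives at `extension_cpp.ops`, and uses made-up tensor sizes following the usual LLTM shape convention.

# Illustrative sketch only; import path and sizes are assumptions, not part of the diff.
import torch
import extension_cpp  # assumed: importing this loads the C++ kernels and op registrations
from extension_cpp.ops import lltm  # assumed module path for the wrapper shown above

batch, input_features, state_size = 2, 4, 3
input = torch.randn(batch, input_features, requires_grad=True)
# weights: (3 * state_size, input_features + state_size), matching the fake impl's
# linear(X, weights, bias) with X = cat([old_h, input], dim=1)
weights = torch.randn(3 * state_size, input_features + state_size, requires_grad=True)
bias = torch.randn(3 * state_size, requires_grad=True)
old_h = torch.randn(batch, state_size, requires_grad=True)
old_cell = torch.randn(batch, state_size, requires_grad=True)

new_h, new_cell = lltm(input, weights, bias, old_h, old_cell)
# Gradients now flow through the rule installed by torch.library.register_autograd
# rather than through an autograd.Function subclass.
(new_h.sum() + new_cell.sum()).backward()

# torch.library.opcheck (available in recent PyTorch releases) tests the schema,
# the register_fake kernel, and the autograd registration of the custom op.
torch.library.opcheck(
    torch.ops.extension_cpp.lltm_forward.default,
    (input, weights, bias, old_h, old_cell),
)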