Commit 7ac96a9

Update to use the new Python custom op APIs
Won't land this until 2.4 comes out.

ghstack-source-id: a9f4f75
Pull Request resolved: #90
1 parent a5ed0b0 commit 7ac96a9
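
For context, "the new Python custom op APIs" referenced above are the torch.library entry points introduced for the 2.4 release: torch.library.custom_op, register_fake, and register_autograd. Below is a minimal, self-contained sketch of the pattern this commit migrates to; the mylib::mysin op and its math are illustrative only, not part of this repo.

import torch
from torch import Tensor

# A pure-Python custom op. The same register_fake / register_autograd
# hooks used for it here also attach to ops defined in C++ via
# TORCH_LIBRARY, which is what the diff below does for lltm_forward.
@torch.library.custom_op("mylib::mysin", mutates_args=())
def mysin(x: Tensor) -> Tensor:
    return torch.sin(x)

@mysin.register_fake
def _(x):
    # Fake kernel: describes output metadata only, no real compute.
    return torch.empty_like(x)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * x.cos()

mysin.register_autograd(backward, setup_context=setup_context)

x = torch.randn(3, requires_grad=True)
mysin(x).sum().backward()  # gradients flow through the registered backward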

File tree

2 files changed: +35 −32

extension_cpp/csrc/lltm.cpp

Lines changed: 0 additions & 1 deletion
@@ -89,7 +89,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
 
 // Defines the operators
 TORCH_LIBRARY(extension_cpp, m) {
-  m.impl_abstract_pystub("extension_cpp.ops");
   m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
 }
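
Why the one-line deletion above: m.impl_abstract_pystub declared which Python module provides the op's abstract (fake) implementation, so the dispatcher could find it from C++. With torch.library.register_fake, the fake kernel is attached from the Python side by qualified op name, and no back-pointer is needed in TORCH_LIBRARY. A self-contained sketch of the same def/fake split done purely in Python; the demo_cpp library and twice op are hypothetical stand-ins:

import torch
from torch import Tensor

# The schema definition plays the role of m.def(...) in TORCH_LIBRARY.
lib = torch.library.Library("demo_cpp", "DEF")  # keep a reference alive
lib.define("twice(Tensor x) -> Tensor")

# Fake kernel, attached by name: this replaces impl_abstract_pystub.
@torch.library.register_fake("demo_cpp::twice")
def _(x: Tensor) -> Tensor:
    return torch.empty_like(x)

# A real CPU kernel; in this repo the analogue is compiled in lltm.cpp.
@torch.library.impl("demo_cpp::twice", "CPU")
def _(x: Tensor) -> Tensor:
    return x * 2

print(torch.ops.demo_cpp.twice(torch.ones(3)))  # tensor([2., 2., 2.])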

extension_cpp/ops.py

Lines changed: 35 additions & 31 deletions
@@ -8,37 +8,41 @@
 def lltm(
     input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
 ) -> Tuple[Tensor, Tensor]:
-    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
-
-
-class LLTMFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weights, bias, old_h, old_cell):
-        outputs = torch.ops.extension_cpp.lltm_forward.default(
-            input, weights, bias, old_h, old_cell
-        )
-        new_h, new_cell = outputs[:2]
-        variables = list(outputs[1:]) + [weights]
-        ctx.save_for_backward(*variables)
-
-        return new_h, new_cell
-
-    @staticmethod
-    @torch.autograd.function.once_differentiable
-    def backward(ctx, grad_h, grad_cell):
-        (
-            d_old_h,
-            d_input,
-            d_weights,
-            d_bias,
-            d_old_cell,
-        ) = torch.ops.extension_cpp.lltm_backward.default(
-            grad_h, grad_cell, *ctx.saved_tensors
-        )
-        return d_input, d_weights, d_bias, d_old_h, d_old_cell
-
-
-@torch.library.impl_abstract("extension_cpp::lltm_forward")
+    """The lltm API"""
+    outputs = torch.ops.extension_cpp.lltm_forward.default(
+        input, weights, bias, old_h, old_cell
+    )
+    new_h, new_cell = outputs[:2]
+    return new_h, new_cell
+
+
+# This is the backward for lltm_forward.
+# lltm_forward has 7 returns so they all get gradients.
+def backward(ctx, grad_h, grad_cell, _0, _1, _2, _3, _4):
+    (
+        d_old_h,
+        d_input,
+        d_weights,
+        d_bias,
+        d_old_cell,
+    ) = torch.ops.extension_cpp.lltm_backward.default(
+        grad_h, grad_cell, *ctx.saved_tensors
+    )
+    return d_input, d_weights, d_bias, d_old_h, d_old_cell
+
+
+def setup_context(ctx, inputs, output):
+    weights = inputs[1]
+    new_h, new_cell = output[:2]
+    variables = list(output[1:]) + [weights]
+    ctx.save_for_backward(*variables)
+
+
+torch.library.register_autograd(
+    "extension_cpp::lltm_forward", backward, setup_context=setup_context)
+
+
+@torch.library.register_fake("extension_cpp::lltm_forward")
 def _(input, weights, bias, old_h, old_cell):
     X = torch.cat([old_h, input], dim=1)
     gate_weights = torch.nn.functional.linear(X, weights, bias)
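
A hypothetical smoke test for the migrated registrations, assuming the built extension imports as extension_cpp and using the LLTM shape convention (weights is (3 * state_size, features + state_size)); torch.library.opcheck is itself part of the 2.4 API surface:

import torch
import extension_cpp  # loads the C++ ops plus the ops.py registrations

batch, features, state = 3, 17, 5
args = (
    torch.randn(batch, features, requires_grad=True),              # input
    torch.randn(3 * state, features + state, requires_grad=True),  # weights
    torch.randn(3 * state, requires_grad=True),                    # bias
    torch.randn(batch, state, requires_grad=True),                 # old_h
    torch.randn(batch, state, requires_grad=True),                 # old_cell
)

# Exercises the schema, the register_fake kernel, and the
# register_autograd hookup from this commit in one call.
torch.library.opcheck(torch.ops.extension_cpp.lltm_forward.default, args)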
