
Commit 015c15d

Make lltm pt2-compliant
ghstack-source-id: e44bec1
Pull Request resolved: #88

1 parent ec66d01

File tree

3 files changed: +33 -3 lines changed

extension_cpp/csrc/lltm.cpp

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

 // Defines the operators
 TORCH_LIBRARY(extension_cpp, m) {
+  m.impl_abstract_pystub("extension_cpp.ops");
   m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
 }

extension_cpp/ops.py

Lines changed: 17 additions & 0 deletions
@@ -38,6 +38,23 @@ def backward(ctx, grad_h, grad_cell):
         return d_input, d_weights, d_bias, d_old_h, d_old_cell


+@torch.library.impl_abstract("extension_cpp::lltm_forward")
+def _(input, weights, bias, old_h, old_cell):
+    X = torch.cat([old_h, input], dim=1)
+    gate_weights = torch.nn.functional.linear(X, weights, bias)
+    gates = gate_weights.chunk(3, dim=1)
+    input_gate = torch.empty_like(gates[0])
+    output_gate = torch.empty_like(gates[1])
+    candidate_cell = torch.empty_like(gates[2])
+    new_cell = torch.empty_like(old_cell)
+    new_h = torch.empty_like(old_h)
+    if input.device.type == "cuda":
+        batch_size = old_cell.shape[0]
+        state_size = old_cell.shape[1]
+        gate_weights = gate_weights.reshape(batch_size, 3, state_size)
+    return new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights
+
+
 def reference_lltm(
     input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
 ) -> Tuple[Tensor, Tensor]:
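The abstract (meta) implementation above mirrors the output shapes and dtypes of the C++ kernel without doing any real computation, which is what PT2 tracing needs. Below is a minimal sketch, not part of the commit, of exercising it: running the operator on "meta" tensors should resolve to this abstract impl and return shape-only outputs. The import of extension_cpp as the module that loads the compiled library, and the sizes (taken from sample_inputs in the test file), are assumptions.

# Sketch only: assumes importing extension_cpp loads the compiled operators
# together with the abstract impl registered above.
import torch
import extension_cpp  # noqa: F401

batch_size, features, state_size = 3, 17, 5
device = "meta"  # meta tensors carry shapes and dtypes but no data
input = torch.randn(batch_size, features, device=device)
weights = torch.randn(3 * state_size, features + state_size, device=device)
bias = torch.randn(3 * state_size, device=device)
old_h = torch.randn(batch_size, state_size, device=device)
old_cell = torch.randn(batch_size, state_size, device=device)

outs = torch.ops.extension_cpp.lltm_forward(input, weights, bias, old_h, old_cell)
print([tuple(o.shape) for o in outs])  # shapes computed by the abstract impl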

test/test_extension.py

Lines changed: 15 additions & 3 deletions
@@ -8,11 +8,11 @@
 import torch.nn.functional as F


-def sample_inputs(device):
+def sample_inputs(device, *, requires_grad=False):
     batch_size = 3
     features = 17
     state_size = 5
-    kwargs = {"dtype": torch.float64, "device": device, "requires_grad": True}
+    kwargs = {"dtype": torch.float64, "device": device, "requires_grad": requires_grad}
     X = torch.randn(
         batch_size,  # E: No overload variant of "randn" matches argument
         features,
@@ -41,7 +41,8 @@ def test_correctness_cuda(self):
         self._test_correctness("cuda")

     def _test_gradients(self, device):
-        args = sample_inputs(device)
+        args = sample_inputs(device, requires_grad=True)
+        # Use torch.autograd.gradcheck to check that gradients are OK
         torch.autograd.gradcheck(extension_cpp.ops.lltm, args)

     def test_gradients_cpu(self):
@@ -53,6 +54,17 @@ def test_gradients_cpu(self):
     def test_gradients_cuda(self):
         self._test_gradients("cuda")

+    def _opcheck(self, device):
+        args = sample_inputs(device)
+        # Use opcheck to test that the operator was written correctly.
+        opcheck(torch.ops.extension_cpp.lltm_forward.default, args)
+
+    def test_opcheck_cpu(self):
+        self._opcheck("cpu")
+
+    def test_opcheck_cuda(self):
+        self._opcheck("cuda")
+

 if __name__ == "__main__":
     unittest.main()
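The opcheck helper used above is imported elsewhere in the test file and not shown in this hunk; in PyTorch of this era it presumably comes from torch.testing._internal.optests (newer releases expose it as torch.library.opcheck). Beyond the tests, the practical payoff of PT2-compliance is that the custom op can be traced by torch.compile; the following is a rough smoke test sketch, not part of the commit, with the step function name, the sizes, and the importability of extension_cpp all assumed.

# Sketch only: assumes the built extension is importable as extension_cpp.
import torch
import extension_cpp  # noqa: F401

@torch.compile(fullgraph=True)
def step(input, weights, bias, old_h, old_cell):
    # With the abstract impl registered, Dynamo can trace through the custom op
    # instead of graph-breaking on it.
    new_h, new_cell, *_ = torch.ops.extension_cpp.lltm_forward(
        input, weights, bias, old_h, old_cell
    )
    return new_h, new_cell

batch_size, features, state_size = 3, 17, 5
new_h, new_cell = step(
    torch.randn(batch_size, features, dtype=torch.float64),
    torch.randn(3 * state_size, features + state_size, dtype=torch.float64),
    torch.randn(3 * state_size, dtype=torch.float64),
    torch.randn(batch_size, state_size, dtype=torch.float64),
    torch.randn(batch_size, state_size, dtype=torch.float64),
)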
