diff --git a/extension_cpp/csrc/lltm.cpp b/extension_cpp/csrc/lltm.cpp
index c3bb5bc..c915dd9 100644
--- a/extension_cpp/csrc/lltm.cpp
+++ b/extension_cpp/csrc/lltm.cpp
@@ -89,6 +89,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
 
 // Defines the operators
 TORCH_LIBRARY(extension_cpp, m) {
+  m.impl_abstract_pystub("extension_cpp.ops");
   m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
 }
diff --git a/extension_cpp/ops.py b/extension_cpp/ops.py
index d0a2a9f..16c0311 100644
--- a/extension_cpp/ops.py
+++ b/extension_cpp/ops.py
@@ -38,6 +38,23 @@ def backward(ctx, grad_h, grad_cell):
         return d_input, d_weights, d_bias, d_old_h, d_old_cell
 
 
+@torch.library.impl_abstract("extension_cpp::lltm_forward")
+def _(input, weights, bias, old_h, old_cell):
+    X = torch.cat([old_h, input], dim=1)
+    gate_weights = torch.nn.functional.linear(X, weights, bias)
+    gates = gate_weights.chunk(3, dim=1)
+    input_gate = torch.empty_like(gates[0])
+    output_gate = torch.empty_like(gates[1])
+    candidate_cell = torch.empty_like(gates[2])
+    new_cell = torch.empty_like(old_cell)
+    new_h = torch.empty_like(old_h)
+    if input.device.type == "cuda":
+        batch_size = old_cell.shape[0]
+        state_size = old_cell.shape[1]
+        gate_weights = gate_weights.reshape(batch_size, 3, state_size)
+    return new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights
+
+
 def reference_lltm(
     input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
 ) -> Tuple[Tensor, Tensor]:
diff --git a/test/test_extension.py b/test/test_extension.py
index 3d4c81f..348ac9e 100644
--- a/test/test_extension.py
+++ b/test/test_extension.py
@@ -8,11 +8,11 @@
 import torch.nn.functional as F
 
 
-def sample_inputs(device):
+def sample_inputs(device, *, requires_grad=False):
     batch_size = 3
     features = 17
     state_size = 5
-    kwargs = {"dtype": torch.float64, "device": device, "requires_grad": True}
+    kwargs = {"dtype": torch.float64, "device": device, "requires_grad": requires_grad}
     X = torch.randn(
         batch_size,  # E: No overload variant of "randn" matches argument
         features,
@@ -41,7 +41,8 @@ def test_correctness_cuda(self):
         self._test_correctness("cuda")
 
     def _test_gradients(self, device):
-        args = sample_inputs(device)
+        args = sample_inputs(device, requires_grad=True)
+        # Use torch.autograd.gradcheck to check that gradients are OK
         torch.autograd.gradcheck(extension_cpp.ops.lltm, args)
 
     def test_gradients_cpu(self):
@@ -53,6 +54,17 @@ def test_gradients_cpu(self):
     def test_gradients_cuda(self):
         self._test_gradients("cuda")
 
+    def _opcheck(self, device):
+        args = sample_inputs(device)
+        # Use opcheck to test that the operator was written correctly.
+        opcheck(torch.ops.extension_cpp.lltm_forward.default, args)
+
+    def test_opcheck_cpu(self):
+        self._opcheck("cpu")
+
+    def test_opcheck_cuda(self):
+        self._opcheck("cuda")
+
 
 if __name__ == "__main__":
     unittest.main()