Make lltm pt2-compliant #88

Merged · 1 commit · Apr 24, 2024
1 change: 1 addition & 0 deletions extension_cpp/csrc/lltm.cpp
@@ -89,6 +89,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}

// Defines the operators
TORCH_LIBRARY(extension_cpp, m) {
m.impl_abstract_pystub("extension_cpp.ops");
m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
}
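Note: m.impl_abstract_pystub("extension_cpp.ops") points the C++ operator registrations at the Python module that provides their abstract (FakeTensor/meta) implementations. A minimal sketch of what this enables, assuming the extension is built and that importing extension_cpp loads both the compiled library and extension_cpp.ops (the step function below is illustrative, not part of this PR):

import torch
import extension_cpp  # assumed to load the compiled library and extension_cpp.ops

# With an abstract impl registered, torch.compile can trace through the custom
# op without running the real C++/CUDA kernel at trace time.
@torch.compile(fullgraph=True)
def step(input, weights, bias, old_h, old_cell):
    return torch.ops.extension_cpp.lltm_forward(input, weights, bias, old_h, old_cell)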
17 changes: 17 additions & 0 deletions extension_cpp/ops.py
@@ -38,6 +38,23 @@ def backward(ctx, grad_h, grad_cell):
return d_input, d_weights, d_bias, d_old_h, d_old_cell


@torch.library.impl_abstract("extension_cpp::lltm_forward")
def _(input, weights, bias, old_h, old_cell):
X = torch.cat([old_h, input], dim=1)
gate_weights = torch.nn.functional.linear(X, weights, bias)
gates = gate_weights.chunk(3, dim=1)
input_gate = torch.empty_like(gates[0])
output_gate = torch.empty_like(gates[1])
candidate_cell = torch.empty_like(gates[2])
new_cell = torch.empty_like(old_cell)
new_h = torch.empty_like(old_h)
if input.device.type == "cuda":
batch_size = old_cell.shape[0]
state_size = old_cell.shape[1]
gate_weights = gate_weights.reshape(batch_size, 3, state_size)
return new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights


def reference_lltm(
input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
) -> Tuple[Tensor, Tensor]:
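The abstract implementation only mirrors the output shapes and dtypes of the real kernels (hence the torch.empty_like calls), including the CUDA-only reshape of gate_weights to (batch_size, 3, state_size); no real computation happens. This is what FakeTensor and meta tensors dispatch to during PT2 tracing. A rough sketch of exercising it directly with meta tensors, assuming importing the package loads the compiled library (shapes borrowed from sample_inputs in the tests):

import torch
import extension_cpp  # assumed to load the compiled library and extension_cpp.ops

batch_size, features, state_size = 3, 17, 5
with torch.device("meta"):
    input = torch.randn(batch_size, features)
    weights = torch.randn(3 * state_size, features + state_size)
    bias = torch.randn(3 * state_size)
    old_h = torch.randn(batch_size, state_size)
    old_cell = torch.randn(batch_size, state_size)

# Dispatches to the abstract impl above: only shapes/dtypes are produced.
outs = torch.ops.extension_cpp.lltm_forward(input, weights, bias, old_h, old_cell)
print([tuple(o.shape) for o in outs])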
18 changes: 15 additions & 3 deletions test/test_extension.py
@@ -8,11 +8,11 @@
import torch.nn.functional as F


def sample_inputs(device):
def sample_inputs(device, *, requires_grad=False):
batch_size = 3
features = 17
state_size = 5
kwargs = {"dtype": torch.float64, "device": device, "requires_grad": True}
kwargs = {"dtype": torch.float64, "device": device, "requires_grad": requires_grad}
X = torch.randn(
batch_size, # E: No overload variant of "randn" matches argument
features,
@@ -41,7 +41,8 @@ def test_correctness_cuda(self):
self._test_correctness("cuda")

def _test_gradients(self, device):
args = sample_inputs(device)
args = sample_inputs(device, requires_grad=True)
# Use torch.autograd.gradcheck to check that gradients are OK
torch.autograd.gradcheck(extension_cpp.ops.lltm, args)

def test_gradients_cpu(self):
@@ -53,6 +54,17 @@ def test_gradients_cpu(self):
def test_gradients_cuda(self):
self._test_gradients("cuda")

def _opcheck(self, device):
args = sample_inputs(device)
# Use opcheck to test that the operator was written correctly.
opcheck(torch.ops.extension_cpp.lltm_forward.default, args)

def test_opcheck_cpu(self):
self._opcheck("cpu")

def test_opcheck_cuda(self):
self._opcheck("cuda")


if __name__ == "__main__":
unittest.main()
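For reference, opcheck tests the operator registration itself rather than numerical results: by default it exercises the schema, the autograd registration, the abstract/FakeTensor impl, and AOT dispatch. opcheck is imported at the top of this test file (outside this hunk); its import path has moved between PyTorch releases, so the sketch below assumes a build that exposes it as torch.library.opcheck and reuses the sample_inputs helper defined earlier in the file:

import torch
import extension_cpp  # assumed to load the compiled library and extension_cpp.ops

args = sample_inputs("cpu")
torch.library.opcheck(torch.ops.extension_cpp.lltm_forward.default, args)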