
Commit 3076890

Update on "Modernize extension-cpp; refactor code"

This PR:
- creates a single unified build for extension-cpp (instead of having separate cpu/cuda setup.pys)
- updates the build system to use pyproject.toml (instead of only setup.py)
- uses TORCH_LIBRARY to bind operators (instead of using PyBind)

Future work will add further improvements (e.g. torch.compile support) and fix up the corresponding C++ extensions tutorial.

Test Plan:
- Refactored all of the tests under test/

[ghstack-poisoned]
1 parent 735149e commit 3076890
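The move to pyproject.toml mentioned in the commit message typically amounts to declaring the build backend and build-time requirements there. A minimal hypothetical sketch (the PR's actual pyproject.toml is not shown in this diff; the exact `requires` list is an assumption):

```toml
# Hypothetical pyproject.toml sketch: torch is listed in build-time
# requires because torch.utils.cpp_extension must be importable at build time.
[build-system]
requires = ["setuptools", "torch"]
build-backend = "setuptools.build_meta"
```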

File tree

3 files changed (+10, -5 lines)

README.md — 2 additions & 0 deletions

@@ -2,6 +2,8 @@
 
 An example of writing a C++/CUDA extension for PyTorch. See
 [here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial.
+This repo demonstrates how to write an example `extension_cpp.ops.lltm`
+custom op that has both custom CPU and CUDA kernels.
 
 To build:
 ```

setup.py — 2 additions & 1 deletion

@@ -21,10 +21,11 @@
 
 def get_extensions():
     debug_mode = os.getenv("DEBUG", "0") == "1"
+    use_cuda = os.getenv("USE_CUDA", "1") == "1"
     if debug_mode:
         print("Compiling in debug mode")
 
-    use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
+    use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None
     extension = CUDAExtension if use_cuda else CppExtension
 
     extra_link_args = []
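The env-var gating added to setup.py can be sketched as a standalone helper. `resolve_use_cuda`, `cuda_available`, and `cuda_home` are hypothetical names standing in for the diff's inline logic, `torch.cuda.is_available()`, and `CUDA_HOME`:

```python
def resolve_use_cuda(environ, cuda_available, cuda_home):
    # USE_CUDA=0 lets a user force a CPU-only build even on a CUDA machine;
    # it defaults to "1" so CUDA is used whenever the toolchain is present.
    requested = environ.get("USE_CUDA", "1") == "1"
    return requested and cuda_available and cuda_home is not None

# USE_CUDA unset: falls back to toolchain detection.
print(resolve_use_cuda({}, True, "/usr/local/cuda"))
# USE_CUDA=0: opt out regardless of available hardware.
print(resolve_use_cuda({"USE_CUDA": "0"}, True, "/usr/local/cuda"))
# No CUDA toolkit found: CPU build even if a GPU is visible.
print(resolve_use_cuda({}, True, None))
```

The design point of the diff is that user intent (`USE_CUDA`) only ever narrows the automatic detection; it cannot force a CUDA build on a machine without the toolkit.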

test/test_extension.py — 6 additions & 4 deletions

@@ -34,22 +34,24 @@ def _test_correctness(self, device):
         torch.testing.assert_close(result, expected)
 
     def test_correctness_cpu(self):
-        self._test_lltm_correctness("cpu")
+        self._test_correctness("cpu")
 
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_correctness_cuda(self):
-        self._test_lltm_correctness("cuda")
+        self._test_correctness("cuda")
 
     def _test_gradients(self, device):
         args = sample_inputs(device)
         torch.autograd.gradcheck(extension_cpp.ops.lltm, args)
 
     def test_gradients_cpu(self):
-        self._test_lltm_grad("cpu")
+        self._test_gradients("cpu")
 
     # This is supposed to succeed, there's probably a bug in the CUDA kernel.
     @unittest.expectedFailure
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_gradients_cuda(self):
-        self._test_lltm_grad("cuda")
+        self._test_gradients("cuda")
 
 
 if __name__ == "__main__":
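The `@unittest.skipIf` decorators added in this diff follow standard `unittest` behavior: the condition is evaluated at class-definition time and the test is reported as skipped, not failed. A minimal self-contained sketch, with a hypothetical `HAS_CUDA` flag standing in for `torch.cuda.is_available()`:

```python
import unittest

HAS_CUDA = False  # stand-in for torch.cuda.is_available()

class ExampleTests(unittest.TestCase):
    def test_cpu(self):
        self.assertTrue(True)  # always runs

    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_cuda(self):
        # Never reached on a CPU-only machine: the decorator skips it first.
        self.fail("would only run when CUDA is available")

suite = unittest.TestLoader().loadTestsFromTestCase(ExampleTests)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(len(result.skipped))    # 1
print(result.wasSuccessful()) # True
```

Skipped tests do not count against `wasSuccessful()`, which is why gating CUDA tests this way keeps the suite green on CPU-only CI machines.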
