Skip to content

Commit 59d25ad

Browse files
committed
remove pybind usage, make example python agnostic
1 parent 1882282 commit 59d25ad

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

extension_cpp/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,10 @@
11
import torch
2-
from . import _C, ops
2+
from pathlib import Path
3+
4+
so_files = list(Path(__file__).parent.glob("_C*.so"))
5+
assert (
6+
len(so_files) == 1
7+
), f"Expected one _C*.so file, found {len(so_files)}"
8+
torch.ops.load_library(so_files[0])
9+
10+
from . import ops

extension_cpp/csrc/muladd.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ void myadd_out_cpu(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
6161
}
6262
}
6363

64-
// Registers _C as a Python extension module.
65-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
66-
6764
// Defines the operators
6865
TORCH_LIBRARY(extension_cpp, m) {
6966
m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");

setup.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818

1919
library_name = "extension_cpp"
2020

21+
if torch.__version__ >= "2.6.0":
22+
py_limited_api = True
23+
else:
24+
py_limited_api = False
25+
2126

2227
def get_extensions():
2328
debug_mode = os.getenv("DEBUG", "0") == "1"
@@ -59,6 +64,7 @@ def get_extensions():
5964
sources,
6065
extra_compile_args=extra_compile_args,
6166
extra_link_args=extra_link_args,
67+
py_limited_api=py_limited_api,
6268
)
6369
]
6470

@@ -71,9 +77,10 @@ def get_extensions():
7177
packages=find_packages(),
7278
ext_modules=get_extensions(),
7379
install_requires=["torch"],
74-
description="Example of PyTorch cpp and CUDA extensions",
80+
description="Example of PyTorch C++ and CUDA extensions",
7581
long_description=open("README.md").read(),
7682
long_description_content_type="text/markdown",
7783
url="https://github.com/pytorch/extension-cpp",
7884
cmdclass={"build_ext": BuildExtension},
85+
options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
7986
)

0 commit comments

Comments
 (0)