Skip to content

Commit b634d70

Browse files
committed
try to defer driver loading
1 parent 7fd8ccb commit b634d70

File tree

2 files changed

+52
-22
lines changed

2 files changed

+52
-22
lines changed

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,25 @@
1616
from cuda.core.experimental._utils import CUDAError, check_or_create_options, handle_return
1717

1818

19+
# TODO: revisit this treatment for py313t builds
20+
_inited = False
21+
_use_ex = None
22+
23+
24+
def _lazy_init():
25+
global _inited
26+
if _inited:
27+
return
28+
29+
global _use_ex
30+
# binding availability depends on cuda-python version
31+
_py_major_minor = tuple(int(v) for v in (
32+
importlib.metadata.version("cuda-python").split(".")[:2]))
33+
_driver_ver = handle_return(cuda.cuDriverGetVersion())
34+
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
35+
_inited = True
36+
37+
1938
@dataclass
2039
class LaunchConfig:
2140
"""
@@ -41,6 +60,8 @@ def __post_init__(self):
4160
if self.shmem_size is None:
4261
self.shmem_size = 0
4362

63+
_lazy_init()
64+
4465
def _cast_to_3_tuple(self, cfg):
4566
if isinstance(cfg, int):
4667
if cfg < 1:
@@ -65,13 +86,6 @@ def _cast_to_3_tuple(self, cfg):
6586
raise ValueError
6687

6788

68-
# binding availability depends on cuda-python version
69-
py_major_minor = tuple(int(v) for v in (
70-
importlib.metadata.version("cuda-python").split(".")[:2]))
71-
driver_ver = handle_return(cuda.cuDriverGetVersion())
72-
use_ex = (driver_ver >= 11080) and (py_major_minor >= (11, 8))
73-
74-
7589
def launch(kernel, config, *kernel_args):
7690
if not isinstance(kernel, Kernel):
7791
raise ValueError
@@ -89,7 +103,7 @@ def launch(kernel, config, *kernel_args):
89103
# "new" module loading APIs are in use). We check both binding & driver versions here
90104
# mainly to see if the "Ex" API is available and if so we use it, as it's more feature
91105
# rich.
92-
if use_ex:
106+
if _use_ex:
93107
drv_cfg = cuda.CUlaunchConfig()
94108
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
95109
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block

cuda_core/cuda/core/experimental/_module.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,33 @@
1616
},
1717
}
1818

19-
# binding availability depends on cuda-python version
20-
py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
21-
if py_major_ver >= 12:
22-
_backend["new"] = {
23-
"file": cuda.cuLibraryLoadFromFile,
24-
"data": cuda.cuLibraryLoadData,
25-
"kernel": cuda.cuLibraryGetKernel,
26-
}
27-
_kernel_ctypes = (cuda.CUfunction, cuda.CUkernel)
28-
else:
29-
_kernel_ctypes = (cuda.CUfunction,)
30-
driver_ver = handle_return(cuda.cuDriverGetVersion())
19+
20+
# TODO: revisit this treatment for py313t builds
21+
_inited = False
22+
_py_major_ver = None
23+
_driver_ver = None
24+
_kernel_ctypes = None
25+
26+
27+
def _lazy_init():
28+
global _inited
29+
if _inited:
30+
return
31+
32+
global _py_major_ver, _driver_ver, _kernel_ctypes
33+
# binding availability depends on cuda-python version
34+
_py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0])
35+
if _py_major_ver >= 12:
36+
_backend["new"] = {
37+
"file": cuda.cuLibraryLoadFromFile,
38+
"data": cuda.cuLibraryLoadData,
39+
"kernel": cuda.cuLibraryGetKernel,
40+
}
41+
_kernel_ctypes = (cuda.CUfunction, cuda.CUkernel)
42+
else:
43+
_kernel_ctypes = (cuda.CUfunction,)
44+
_driver_ver = handle_return(cuda.cuDriverGetVersion())
45+
_inited = True
3146

3247

3348
class Kernel:
@@ -58,13 +73,14 @@ def __init__(self, module, code_type, jit_options=None, *,
5873
symbol_mapping=None):
5974
if code_type not in self._supported_code_type:
6075
raise ValueError
76+
_lazy_init()
6177
self._handle = None
6278

63-
backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old"
79+
backend = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
6480
self._loader = _backend[backend]
6581

6682
if isinstance(module, str):
67-
if driver_ver < 12000 and jit_options is not None:
83+
if _driver_ver < 12000 and jit_options is not None:
6884
raise ValueError
6985
module = module.encode()
7086
self._handle = handle_return(self._loader["file"](module))

0 commit comments

Comments
 (0)