Commit 7fd8ccb

nit: cleaner treatment
1 parent 7587684 commit 7fd8ccb

2 files changed: +19 -12 lines changed

cuda_core/cuda/core/experimental/_launcher.py

Lines changed: 15 additions & 6 deletions
@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
 
 from dataclasses import dataclass
+import importlib.metadata
 from typing import Optional, Union
 
 import numpy as np
@@ -64,6 +65,13 @@ def _cast_to_3_tuple(self, cfg):
         raise ValueError
 
 
+# binding availability depends on cuda-python version
+py_major_minor = tuple(int(v) for v in (
+    importlib.metadata.version("cuda-python").split(".")[:2]))
+driver_ver = handle_return(cuda.cuDriverGetVersion())
+use_ex = (driver_ver >= 11080) and (py_major_minor >= (11, 8))
+
+
 def launch(kernel, config, *kernel_args):
     if not isinstance(kernel, Kernel):
         raise ValueError
@@ -76,11 +84,12 @@ def launch(kernel, config, *kernel_args):
     kernel_args = ParamHolder(kernel_args)
     args_ptr = kernel_args.ptr
 
-    # Note: CUkernel can still be launched via the old cuLaunchKernel. We check ._backend
-    # here not because of the CUfunction/CUkernel difference (which depends on whether the
-    # "old" or "new" module loading APIs are in use), but only as a proxy to check if
-    # both binding & driver versions support the "Ex" API, which is more feature rich.
-    if kernel._backend == "new":
+    # Note: CUkernel can still be launched via the old cuLaunchKernel and we do not care
+    # about the CUfunction/CUkernel difference (which depends on whether the "old" or
+    # "new" module loading APIs are in use). We check both binding & driver versions here
+    # mainly to see if the "Ex" API is available and if so we use it, as it's more feature
+    # rich.
+    if use_ex:
         drv_cfg = cuda.CUlaunchConfig()
         drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
         drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
@@ -89,7 +98,7 @@ def launch(kernel, config, *kernel_args):
         drv_cfg.numAttrs = 0  # TODO
         handle_return(cuda.cuLaunchKernelEx(
             drv_cfg, int(kernel._handle), args_ptr, 0))
-    else:  # "old" backend
+    else:
         # TODO: check if config has any unsupported attrs
         handle_return(cuda.cuLaunchKernel(
             int(kernel._handle),
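
For reference, a minimal standalone sketch of the same version gate (not part of the commit): cuLaunchKernelEx is only used when both the installed driver and the cuda-python binding are at least 11.8. The helper name _can_use_launch_ex is hypothetical; the commit itself computes the flag once at module import time as use_ex.

import importlib.metadata

from cuda import cuda


def _can_use_launch_ex():
    # cuda-python binding version, e.g. "12.6.1" -> (12, 6)
    py_major_minor = tuple(
        int(v) for v in importlib.metadata.version("cuda-python").split(".")[:2])
    # driver version as an int, e.g. 11080 for 11.8; the binding returns (err, value)
    err, driver_ver = cuda.cuDriverGetVersion()
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuDriverGetVersion failed: {err}")
    # the "Ex" launch API needs driver >= 11.8 and a binding that exposes it (>= 11.8)
    return (driver_ver >= 11080) and (py_major_minor >= (11, 8))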

cuda_core/cuda/core/experimental/_module.py

Lines changed: 4 additions & 6 deletions
@@ -32,27 +32,26 @@
 
 class Kernel:
 
-    __slots__ = ("_handle", "_module", "_backend")
+    __slots__ = ("_handle", "_module",)
 
     def __init__(self):
         raise NotImplementedError("directly constructing a Kernel instance is not supported")
 
     @staticmethod
-    def _from_obj(obj, mod, backend):
+    def _from_obj(obj, mod):
         assert isinstance(obj, _kernel_ctypes)
         assert isinstance(mod, ObjectCode)
         ker = Kernel.__new__(Kernel)
         ker._handle = obj
         ker._module = mod
-        ker._backend = backend
         return ker
 
     # TODO: implement from_handle()
 
 
 class ObjectCode:
 
-    __slots__ = ("_handle", "_code_type", "_module", "_loader", "_loader_backend", "_sym_map")
+    __slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map")
     _supported_code_type = ("cubin", "ptx", "fatbin")
 
     def __init__(self, module, code_type, jit_options=None, *,
@@ -63,7 +62,6 @@ def __init__(self, module, code_type, jit_options=None, *,
 
         backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old"
         self._loader = _backend[backend]
-        self._loader_backend = backend
 
         if isinstance(module, str):
             if driver_ver < 12000 and jit_options is not None:
@@ -96,6 +94,6 @@ def get_kernel(self, name):
         except KeyError:
             name = name.encode()
         data = handle_return(self._loader["kernel"](self._handle, name))
-        return Kernel._from_obj(data, self, self._loader_backend)
+        return Kernel._from_obj(data, self)
 
     # TODO: implement from_handle()
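
A hypothetical usage sketch (not from this commit) of why Kernel no longer needs to remember a backend: the choice between cuLaunchKernel and cuLaunchKernelEx is now made inside launch() from the module-level use_ex flag, not from anything stored on the Kernel object. The import paths, the LaunchConfig(grid=..., block=...) call, and cubin_bytes are assumptions based only on the files touched here; a current CUDA device/context is assumed to be set up already.

# assumption: LaunchConfig is the dataclass defined alongside launch() in _launcher.py
from cuda.core.experimental._launcher import LaunchConfig, launch
from cuda.core.experimental._module import ObjectCode

mod = ObjectCode(cubin_bytes, "cubin")   # cubin_bytes: a compiled cubin produced elsewhere (assumption)
ker = mod.get_kernel("my_kernel")        # Kernel now carries only _handle and _module
cfg = LaunchConfig(grid=1, block=32)     # grid/block as ints or 3-tuples (assumption)
launch(ker, cfg)                         # uses cuLaunchKernelEx only if use_ex is True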
