3
3
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
4
5
5
from dataclasses import dataclass
6
+ import importlib .metadata
6
7
from typing import Optional , Union
7
8
8
9
import numpy as np
@@ -64,6 +65,13 @@ def _cast_to_3_tuple(self, cfg):
64
65
raise ValueError
65
66
66
67
68
+ # binding availability depends on cuda-python version
69
+ py_major_minor = tuple (int (v ) for v in (
70
+ importlib .metadata .version ("cuda-python" ).split ("." )[:2 ]))
71
+ driver_ver = handle_return (cuda .cuDriverGetVersion ())
72
+ use_ex = (driver_ver >= 11080 ) and (py_major_minor >= (11 , 8 ))
73
+
74
+
67
75
def launch (kernel , config , * kernel_args ):
68
76
if not isinstance (kernel , Kernel ):
69
77
raise ValueError
@@ -76,11 +84,12 @@ def launch(kernel, config, *kernel_args):
76
84
kernel_args = ParamHolder (kernel_args )
77
85
args_ptr = kernel_args .ptr
78
86
79
- # Note: CUkernel can still be launched via the old cuLaunchKernel. We check ._backend
80
- # here not because of the CUfunction/CUkernel difference (which depends on whether the
81
- # "old" or "new" module loading APIs are in use), but only as a proxy to check if
82
- # both binding & driver versions support the "Ex" API, which is more feature rich.
83
- if kernel ._backend == "new" :
87
+ # Note: CUkernel can still be launched via the old cuLaunchKernel and we do not care
88
+ # about the CUfunction/CUkernel difference (which depends on whether the "old" or
89
+ # "new" module loading APIs are in use). We check both binding & driver versions here
90
+ # mainly to see if the "Ex" API is available and if so we use it, as it's more feature
91
+ # rich.
92
+ if use_ex :
84
93
drv_cfg = cuda .CUlaunchConfig ()
85
94
drv_cfg .gridDimX , drv_cfg .gridDimY , drv_cfg .gridDimZ = config .grid
86
95
drv_cfg .blockDimX , drv_cfg .blockDimY , drv_cfg .blockDimZ = config .block
@@ -89,7 +98,7 @@ def launch(kernel, config, *kernel_args):
89
98
drv_cfg .numAttrs = 0 # TODO
90
99
handle_return (cuda .cuLaunchKernelEx (
91
100
drv_cfg , int (kernel ._handle ), args_ptr , 0 ))
92
- else : # "old" backend
101
+ else :
93
102
# TODO: check if config has any unsupported attrs
94
103
handle_return (cuda .cuLaunchKernel (
95
104
int (kernel ._handle ),
0 commit comments