2
2
#
3
3
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
4
4
5
+ import importlib .metadata
6
+
5
7
from cuda import cuda , cudart
6
8
from cuda .core .experimental ._utils import handle_return
7
9
8
10
9
11
_backend = {
10
- "new" : {
11
- "file" : cuda .cuLibraryLoadFromFile ,
12
- "data" : cuda .cuLibraryLoadData ,
13
- "kernel" : cuda .cuLibraryGetKernel ,
14
- },
15
12
"old" : {
16
13
"file" : cuda .cuModuleLoad ,
17
14
"data" : cuda .cuModuleLoadDataEx ,
20
17
}
21
18
22
19
20
+ # TODO: revisit this treatment for py313t builds
21
+ _inited = False
22
+ _py_major_ver = None
23
+ _driver_ver = None
24
+ _kernel_ctypes = None
25
+
26
+
27
+ def _lazy_init ():
28
+ global _inited
29
+ if _inited :
30
+ return
31
+
32
+ global _py_major_ver , _driver_ver , _kernel_ctypes
33
+ # binding availability depends on cuda-python version
34
+ _py_major_ver = int (importlib .metadata .version ("cuda-python" ).split ("." )[0 ])
35
+ if _py_major_ver >= 12 :
36
+ _backend ["new" ] = {
37
+ "file" : cuda .cuLibraryLoadFromFile ,
38
+ "data" : cuda .cuLibraryLoadData ,
39
+ "kernel" : cuda .cuLibraryGetKernel ,
40
+ }
41
+ _kernel_ctypes = (cuda .CUfunction , cuda .CUkernel )
42
+ else :
43
+ _kernel_ctypes = (cuda .CUfunction ,)
44
+ _driver_ver = handle_return (cuda .cuDriverGetVersion ())
45
+ _inited = True
46
+
47
+
23
48
class Kernel :
24
49
25
50
__slots__ = ("_handle" , "_module" ,)
@@ -29,13 +54,15 @@ def __init__(self):
29
54
30
55
@staticmethod
31
56
def _from_obj (obj , mod ):
32
- assert isinstance (obj , ( cuda . CUkernel , cuda . CUfunction ) )
57
+ assert isinstance (obj , _kernel_ctypes )
33
58
assert isinstance (mod , ObjectCode )
34
59
ker = Kernel .__new__ (Kernel )
35
60
ker ._handle = obj
36
61
ker ._module = mod
37
62
return ker
38
63
64
+ # TODO: implement from_handle()
65
+
39
66
40
67
class ObjectCode :
41
68
@@ -46,26 +73,29 @@ def __init__(self, module, code_type, jit_options=None, *,
46
73
symbol_mapping = None ):
47
74
if code_type not in self ._supported_code_type :
48
75
raise ValueError
76
+ _lazy_init ()
49
77
self ._handle = None
50
78
51
- driver_ver = handle_return ( cuda . cuDriverGetVersion ())
52
- self ._loader = _backend ["new" ] if driver_ver >= 12000 else _backend [ "old" ]
79
+ backend = "new" if ( _py_major_ver >= 12 and _driver_ver >= 12000 ) else "old"
80
+ self ._loader = _backend [backend ]
53
81
54
82
if isinstance (module , str ):
55
- if driver_ver < 12000 and jit_options is not None :
83
+ # TODO: this option is only taken by the new library APIs, but we have
84
+ # a bug that we can't easily support it just yet (NVIDIA/cuda-python#73).
85
+ if jit_options is not None :
56
86
raise ValueError
57
87
module = module .encode ()
58
88
self ._handle = handle_return (self ._loader ["file" ](module ))
59
89
else :
60
90
assert isinstance (module , bytes )
61
91
if jit_options is None :
62
92
jit_options = {}
63
- if driver_ver >= 12000 :
93
+ if backend == "new" :
64
94
args = (module , list (jit_options .keys ()), list (jit_options .values ()), len (jit_options ),
65
95
# TODO: support library options
66
96
[], [], 0 )
67
- else :
68
- args = (module , len (jit_options ), jit_options .keys (), jit_options .values ())
97
+ else : # "old" backend
98
+ args = (module , len (jit_options ), list ( jit_options .keys ()), list ( jit_options .values () ))
69
99
self ._handle = handle_return (self ._loader ["data" ](* args ))
70
100
71
101
self ._code_type = code_type
@@ -83,3 +113,5 @@ def get_kernel(self, name):
83
113
name = name .encode ()
84
114
data = handle_return (self ._loader ["kernel" ](self ._handle , name ))
85
115
return Kernel ._from_obj (data , self )
116
+
117
+ # TODO: implement from_handle()
0 commit comments