|
16 | 16 | },
|
17 | 17 | }
|
18 | 18 |
|
19 |
| -# binding availability depends on cuda-python version |
20 |
| -py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0]) |
21 |
| -if py_major_ver >= 12: |
22 |
| - _backend["new"] = { |
23 |
| - "file": cuda.cuLibraryLoadFromFile, |
24 |
| - "data": cuda.cuLibraryLoadData, |
25 |
| - "kernel": cuda.cuLibraryGetKernel, |
26 |
| - } |
27 |
| - _kernel_ctypes = (cuda.CUfunction, cuda.CUkernel) |
28 |
| -else: |
29 |
| - _kernel_ctypes = (cuda.CUfunction,) |
30 |
| -driver_ver = handle_return(cuda.cuDriverGetVersion()) |
| 19 | + |
| 20 | +# TODO: revisit this treatment for py313t builds |
| 21 | +_inited = False |
| 22 | +_py_major_ver = None |
| 23 | +_driver_ver = None |
| 24 | +_kernel_ctypes = None |
| 25 | + |
| 26 | + |
| 27 | +def _lazy_init(): |
| 28 | + global _inited |
| 29 | + if _inited: |
| 30 | + return |
| 31 | + |
| 32 | + global _py_major_ver, _driver_ver, _kernel_ctypes |
| 33 | + # binding availability depends on cuda-python version |
| 34 | + _py_major_ver = int(importlib.metadata.version("cuda-python").split(".")[0]) |
| 35 | + if _py_major_ver >= 12: |
| 36 | + _backend["new"] = { |
| 37 | + "file": cuda.cuLibraryLoadFromFile, |
| 38 | + "data": cuda.cuLibraryLoadData, |
| 39 | + "kernel": cuda.cuLibraryGetKernel, |
| 40 | + } |
| 41 | + _kernel_ctypes = (cuda.CUfunction, cuda.CUkernel) |
| 42 | + else: |
| 43 | + _kernel_ctypes = (cuda.CUfunction,) |
| 44 | + _driver_ver = handle_return(cuda.cuDriverGetVersion()) |
| 45 | + _inited = True |
31 | 46 |
|
32 | 47 |
|
33 | 48 | class Kernel:
|
@@ -58,13 +73,14 @@ def __init__(self, module, code_type, jit_options=None, *,
|
58 | 73 | symbol_mapping=None):
|
59 | 74 | if code_type not in self._supported_code_type:
|
60 | 75 | raise ValueError
|
| 76 | + _lazy_init() |
61 | 77 | self._handle = None
|
62 | 78 |
|
63 |
| - backend = "new" if (py_major_ver >= 12 and driver_ver >= 12000) else "old" |
| 79 | + backend = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old" |
64 | 80 | self._loader = _backend[backend]
|
65 | 81 |
|
66 | 82 | if isinstance(module, str):
|
67 |
| - if driver_ver < 12000 and jit_options is not None: |
| 83 | + if _driver_ver < 12000 and jit_options is not None: |
68 | 84 | raise ValueError
|
69 | 85 | module = module.encode()
|
70 | 86 | self._handle = handle_return(self._loader["file"](module))
|
|
0 commit comments