11# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22#
33# SPDX-License-Identifier: Apache-2.0
4- from cuda.core.experimental._launch_config cimport LaunchConfig, _to_native_launch_config
5- from cuda.core.experimental._stream cimport Stream_accept
64
5+ from libc.stdint cimport uintptr_t
6+
7+ from cuda.bindings cimport cydriver
8+
9+ from cuda.core.experimental._launch_config cimport LaunchConfig
10+ from cuda.core.experimental._kernel_arg_handler cimport ParamHolder
11+ from cuda.core.experimental._stream cimport Stream_accept, Stream
12+ from cuda.core.experimental._utils.cuda_utils cimport (
13+ check_or_create_options,
14+ HANDLE_RETURN,
15+ )
16+
17+ import threading
718
8- from cuda.core.experimental._kernel_arg_handler import ParamHolder
919from cuda.core.experimental._module import Kernel
1020from cuda.core.experimental._stream import Stream
11- from cuda.core.experimental._utils.clear_error_support import assert_type
1221from cuda.core.experimental._utils.cuda_utils import (
1322 _reduce_3_tuple,
14- check_or_create_options,
15- driver,
1623 get_binding_version,
17- handle_return,
1824)
1925
20- # TODO: revisit this treatment for py313t builds
21- _inited = False
22- _use_ex = None
26+
27+ cdef bint _inited = False
28+ cdef bint _use_ex = False
29+ cdef object _lock = threading.Lock()
2330
2431
25- def _lazy_init ():
26- global _inited
32+ cdef int _lazy_init() except ? - 1 :
33+ global _inited, _use_ex
2734 if _inited:
28- return
35+ return 0
36+
37+ cdef int _driver_ver
38+ with _lock:
39+ if _inited:
40+ return 0
2941
30- global _use_ex
31- # binding availability depends on cuda-python version
32- _py_major_minor = get_binding_version()
33- _driver_ver = handle_return(driver.cuDriverGetVersion())
34- _use_ex = (_driver_ver >= 11080 ) and (_py_major_minor >= (11 , 8 ))
35- _inited = True
42+ # binding availability depends on cuda-python version
43+ _py_major_minor = get_binding_version()
44+ HANDLE_RETURN(cydriver.cuDriverGetVersion(& _driver_ver))
45+ _use_ex = (_driver_ver >= 11080 ) and (_py_major_minor >= (11 , 8 ))
46+ _inited = True
47+
48+ return 0
3649
3750
3851def launch (stream: Stream | GraphBuilder | IsStreamT , config: LaunchConfig , kernel: Kernel , *kernel_args ):
@@ -54,32 +67,39 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern
5467 launching kernel.
5568
5669 """
57- stream = Stream_accept(stream, allow_stream_protocol = True )
58- assert_type(kernel, Kernel)
70+ cdef Stream s = Stream_accept(stream, allow_stream_protocol = True )
5971 _lazy_init()
60- config = check_or_create_options(LaunchConfig, config, " launch config" )
72+ cdef LaunchConfig conf = check_or_create_options(LaunchConfig, config, " launch config" )
6173
6274 # TODO: can we ensure kernel_args is valid/safe to use here?
6375 # TODO: merge with HelperKernelParams?
64- kernel_args = ParamHolder(kernel_args)
65- args_ptr = kernel_args.ptr
76+ cdef ParamHolder ker_args = ParamHolder(kernel_args)
77+ cdef void ** args_ptr = < void ** >< uintptr_t> (ker_args.ptr)
78+
79+ # TODO: cythonize Module/Kernel/...
80+ # Note: CUfunction and CUkernel are interchangeable
81+ cdef cydriver.CUfunction func_handle = < cydriver.CUfunction> (< uintptr_t> (kernel._handle))
6682
6783 # Note: CUkernel can still be launched via the old cuLaunchKernel and we do not care
6884 # about the CUfunction/CUkernel difference (which depends on whether the "old" or
6985 # "new" module loading APIs are in use). We check both binding & driver versions here
7086 # mainly to see if the "Ex" API is available and if so we use it, as it's more feature
7187 # rich.
7288 if _use_ex:
73- drv_cfg = _to_native_launch_config(config)
74- drv_cfg.hStream = stream.handle
75- if config.cooperative_launch:
76- _check_cooperative_launch(kernel, config, stream)
77- handle_return(driver.cuLaunchKernelEx(drv_cfg, int (kernel._handle), args_ptr, 0 ))
89+ drv_cfg = conf._to_native_launch_config()
90+ drv_cfg.hStream = s._handle
91+ if conf.cooperative_launch:
92+ _check_cooperative_launch(kernel, conf, s)
93+ with nogil:
94+ HANDLE_RETURN(cydriver.cuLaunchKernelEx(& drv_cfg, func_handle, args_ptr, NULL ))
7895 else :
7996 # TODO: check if config has any unsupported attrs
80- handle_return(
81- driver.cuLaunchKernel(
82- int (kernel._handle), * config.grid, * config.block, config.shmem_size, stream.handle, args_ptr, 0
97+ HANDLE_RETURN(
98+ cydriver.cuLaunchKernel(
99+ func_handle,
100+ conf.grid[0 ], conf.grid[1 ], conf.grid[2 ],
101+ conf.block[0 ], conf.block[1 ], conf.block[2 ],
102+ conf.shmem_size, s._handle, args_ptr, NULL
83103 )
84104 )
85105
0 commit comments