Skip to content

Commit f83eff2

Browse files
authored
Cythonize launch & LaunchConfig more (#1390)
* cythonize launch & launch config
* nits
1 parent 87b49b7 commit f83eff2

File tree

5 files changed

+135
-56
lines changed

5 files changed

+135
-56
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from libc.stdint cimport intptr_t
6+
from libcpp cimport vector
7+
8+
9+
cdef class ParamHolder:
10+
11+
cdef:
12+
vector.vector[void*] data
13+
vector.vector[void*] data_addresses
14+
object kernel_args
15+
readonly intptr_t ptr

cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,6 @@ cdef inline int prepare_numpy_arg(
250250

251251
cdef class ParamHolder:
252252

253-
cdef:
254-
vector.vector[void*] data
255-
vector.vector[void*] data_addresses
256-
object kernel_args
257-
readonly intptr_t ptr
258-
259253
def __init__(self, kernel_args):
260254
if len(kernel_args) == 0:
261255
self.ptr = 0

cuda_core/cuda/core/experimental/_launch_config.pxd

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,23 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from libcpp.vector cimport vector
56

6-
cdef bint _inited
7-
cdef bint _use_ex
7+
from cuda.bindings cimport cydriver
88

9-
cdef void _lazy_init() except *
109

1110
cdef class LaunchConfig:
1211
"""Customizable launch options."""
13-
cdef public tuple grid
14-
cdef public tuple cluster
15-
cdef public tuple block
16-
cdef public int shmem_size
17-
cdef public bint cooperative_launch
12+
cdef:
13+
public tuple grid
14+
public tuple cluster
15+
public tuple block
16+
public int shmem_size
17+
public bint cooperative_launch
18+
19+
vector[cydriver.CUlaunchAttribute] _attrs
20+
21+
cdef cydriver.CUlaunchConfig _to_native_launch_config(self)
22+
1823

1924
cpdef object _to_native_launch_config(LaunchConfig config)

cuda_core/cuda/core/experimental/_launch_config.pyx

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,44 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from cuda.core.experimental._utils.cuda_utils cimport (
6+
HANDLE_RETURN,
7+
)
8+
9+
import threading
10+
511
from cuda.core.experimental._device import Device
612
from cuda.core.experimental._utils.cuda_utils import (
713
CUDAError,
814
cast_to_3_tuple,
915
driver,
1016
get_binding_version,
11-
handle_return,
1217
)
1318

14-
# TODO: revisit this treatment for py313t builds
19+
1520
cdef bint _inited = False
1621
cdef bint _use_ex = False
22+
cdef object _lock = threading.Lock()
1723

1824

19-
cdef void _lazy_init() except *:
20-
"""Initialize module-level globals for driver version checks."""
25+
cdef int _lazy_init() except?-1:
2126
global _inited, _use_ex
2227
if _inited:
23-
return
28+
return 0
2429

2530
cdef tuple _py_major_minor
2631
cdef int _driver_ver
32+
with _lock:
33+
if _inited:
34+
return 0
35+
36+
# binding availability depends on cuda-python version
37+
_py_major_minor = get_binding_version()
38+
HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver))
39+
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
40+
_inited = True
2741

28-
# binding availability depends on cuda-python version
29-
_py_major_minor = get_binding_version()
30-
_driver_ver = handle_return(driver.cuDriverGetVersion())
31-
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
32-
_inited = True
42+
return 0
3343

3444

3545
cdef class LaunchConfig:
@@ -127,7 +137,42 @@ cdef class LaunchConfig:
127137
f"block={self.block}, shmem_size={self.shmem_size}, "
128138
f"cooperative_launch={self.cooperative_launch})")
129139

140+
cdef cydriver.CUlaunchConfig _to_native_launch_config(self):
141+
_lazy_init()
142+
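# TODO: zero-initialize drv_cfg and attr (e.g. via memset). Both are declared without an initializer, and this method never sets drv_cfg.hStream -- the caller must assign it before launching.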
143+
cdef cydriver.CUlaunchConfig drv_cfg
144+
cdef cydriver.CUlaunchAttribute attr
145+
self._attrs.resize(0)
146+
147+
# Handle grid dimensions and cluster configuration
148+
if self.cluster is not None:
149+
# Convert grid from cluster units to block units
150+
drv_cfg.gridDimX = self.grid[0] * self.cluster[0]
151+
drv_cfg.gridDimY = self.grid[1] * self.cluster[1]
152+
drv_cfg.gridDimZ = self.grid[2] * self.cluster[2]
153+
154+
# Set up cluster attribute
155+
attr.id = cydriver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
156+
attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.cluster
157+
self._attrs.push_back(attr)
158+
else:
159+
drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = self.grid
160+
161+
drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = self.block
162+
drv_cfg.sharedMemBytes = self.shmem_size
163+
164+
if self.cooperative_launch:
165+
attr.id = cydriver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_COOPERATIVE
166+
attr.value.cooperative = 1
167+
self._attrs.push_back(attr)
168+
169+
drv_cfg.numAttrs = self._attrs.size()
170+
drv_cfg.attrs = self._attrs.data()
171+
172+
return drv_cfg
173+
130174

175+
# TODO: once all modules are cythonized, this function can be dropped in favor of the cdef method above
131176
cpdef object _to_native_launch_config(LaunchConfig config):
132177
"""Convert LaunchConfig to native driver CUlaunchConfig.
133178

cuda_core/cuda/core/experimental/_launcher.pyx

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,51 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from cuda.core.experimental._launch_config cimport LaunchConfig, _to_native_launch_config
5-
from cuda.core.experimental._stream cimport Stream_accept
64

5+
from libc.stdint cimport uintptr_t
6+
7+
from cuda.bindings cimport cydriver
8+
9+
from cuda.core.experimental._launch_config cimport LaunchConfig
10+
from cuda.core.experimental._kernel_arg_handler cimport ParamHolder
11+
from cuda.core.experimental._stream cimport Stream_accept, Stream
12+
from cuda.core.experimental._utils.cuda_utils cimport (
13+
check_or_create_options,
14+
HANDLE_RETURN,
15+
)
16+
17+
import threading
718

8-
from cuda.core.experimental._kernel_arg_handler import ParamHolder
919
from cuda.core.experimental._module import Kernel
1020
from cuda.core.experimental._stream import Stream
11-
from cuda.core.experimental._utils.clear_error_support import assert_type
1221
from cuda.core.experimental._utils.cuda_utils import (
1322
_reduce_3_tuple,
14-
check_or_create_options,
15-
driver,
1623
get_binding_version,
17-
handle_return,
1824
)
1925

20-
# TODO: revisit this treatment for py313t builds
21-
_inited = False
22-
_use_ex = None
26+
27+
cdef bint _inited = False
28+
cdef bint _use_ex = False
29+
cdef object _lock = threading.Lock()
2330

2431

25-
def _lazy_init():
26-
global _inited
32+
cdef int _lazy_init() except?-1:
33+
global _inited, _use_ex
2734
if _inited:
28-
return
35+
return 0
36+
37+
cdef int _driver_ver
38+
with _lock:
39+
if _inited:
40+
return 0
2941

30-
global _use_ex
31-
# binding availability depends on cuda-python version
32-
_py_major_minor = get_binding_version()
33-
_driver_ver = handle_return(driver.cuDriverGetVersion())
34-
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
35-
_inited = True
42+
# binding availability depends on cuda-python version
43+
_py_major_minor = get_binding_version()
44+
HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver))
45+
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
46+
_inited = True
47+
48+
return 0
3649

3750

3851
def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kernel: Kernel, *kernel_args):
@@ -54,32 +67,39 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern
5467
launching kernel.
5568
5669
"""
57-
stream = Stream_accept(stream, allow_stream_protocol=True)
58-
assert_type(kernel, Kernel)
70+
cdef Stream s = Stream_accept(stream, allow_stream_protocol=True)
5971
_lazy_init()
60-
config = check_or_create_options(LaunchConfig, config, "launch config")
72+
cdef LaunchConfig conf = check_or_create_options(LaunchConfig, config, "launch config")
6173

6274
# TODO: can we ensure kernel_args is valid/safe to use here?
6375
# TODO: merge with HelperKernelParams?
64-
kernel_args = ParamHolder(kernel_args)
65-
args_ptr = kernel_args.ptr
76+
cdef ParamHolder ker_args = ParamHolder(kernel_args)
77+
cdef void** args_ptr = <void**><uintptr_t>(ker_args.ptr)
78+
79+
# TODO: cythonize Module/Kernel/...
80+
# Note: CUfunction and CUkernel are interchangeable
81+
cdef cydriver.CUfunction func_handle = <cydriver.CUfunction>(<uintptr_t>(kernel._handle))
6682

6783
# Note: CUkernel can still be launched via the old cuLaunchKernel and we do not care
6884
# about the CUfunction/CUkernel difference (which depends on whether the "old" or
6985
# "new" module loading APIs are in use). We check both binding & driver versions here
7086
# mainly to see if the "Ex" API is available and if so we use it, as it's more feature
7187
# rich.
7288
if _use_ex:
73-
drv_cfg = _to_native_launch_config(config)
74-
drv_cfg.hStream = stream.handle
75-
if config.cooperative_launch:
76-
_check_cooperative_launch(kernel, config, stream)
77-
handle_return(driver.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
89+
drv_cfg = conf._to_native_launch_config()
90+
drv_cfg.hStream = s._handle
91+
if conf.cooperative_launch:
92+
_check_cooperative_launch(kernel, conf, s)
93+
with nogil:
94+
HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL))
7895
else:
7996
# TODO: check if config has any unsupported attrs
80-
handle_return(
81-
driver.cuLaunchKernel(
82-
int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream.handle, args_ptr, 0
97+
HANDLE_RETURN(
98+
cydriver.cuLaunchKernel(
99+
func_handle,
100+
conf.grid[0], conf.grid[1], conf.grid[2],
101+
conf.block[0], conf.block[1], conf.block[2],
102+
conf.shmem_size, s._handle, args_ptr, NULL
83103
)
84104
)
85105

0 commit comments

Comments
 (0)