Skip to content

Commit 55f6d31

Browse files
Use ParamInfo instead of NamedTuple is annotation
This required moving ParamInfo definition from class scope to module scope, since referencing Kernel.ParamInfo from annotations of methods of the Kernel class results in error that Kernel class does not yet exist.
1 parent 2ca9704 commit 55f6d31

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
from collections import namedtuple
6-
from typing import NamedTuple, Optional, Union
6+
from typing import Optional, Union
77
from warnings import warn
88

99
from cuda.core.experimental._utils.clear_error_support import (
@@ -184,6 +184,9 @@ def cluster_scheduling_policy_preference(self, device_id: int = None) -> int:
184184
)
185185

186186

187+
ParamInfo = namedtuple("ParamInfo", ["offset", "size"])
188+
189+
187190
class Kernel:
188191
"""Represent a compiled kernel that had been loaded onto the device.
189192
@@ -196,7 +199,6 @@ class Kernel:
196199
"""
197200

198201
__slots__ = ("_handle", "_module", "_attributes")
199-
ParamInfo = namedtuple("ParamInfo", ["offset", "size"])
200202

201203
def __new__(self, *args, **kwargs):
202204
raise RuntimeError("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs.")
@@ -218,7 +220,7 @@ def attributes(self) -> KernelAttributes:
218220
self._attributes = KernelAttributes._init(self._handle)
219221
return self._attributes
220222

221-
def _get_arguments_info(self, param_info=False) -> tuple[int, list[NamedTuple]]:
223+
def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]:
222224
attr_impl = self.attributes
223225
if attr_impl._backend_version != "new":
224226
raise NotImplementedError("New backend is required")
@@ -229,7 +231,7 @@ def _get_arguments_info(self, param_info=False) -> tuple[int, list[NamedTuple]]:
229231
if result[0] != driver.CUresult.CUDA_SUCCESS:
230232
break
231233
if param_info:
232-
p_info = Kernel.ParamInfo(offset=result[1], size=result[2])
234+
p_info = ParamInfo(offset=result[1], size=result[2])
233235
param_info_data.append(p_info)
234236
arg_pos = arg_pos + 1
235237
if result[0] != driver.CUresult.CUDA_ERROR_INVALID_VALUE:
@@ -243,8 +245,8 @@ def num_arguments(self) -> int:
243245
return num_args
244246

245247
@property
246-
def arguments_info(self) -> list[NamedTuple]:
247-
"""list[NamedTuple[int, int]]: (offset, size) for each argument of this function"""
248+
def arguments_info(self) -> list[ParamInfo]:
249+
"""list[ParamInfo]: (offset, size) for each argument of this function"""
248250
_, param_info = self._get_arguments_info(param_info=True)
249251
return param_info
250252

0 commit comments

Comments
 (0)