33# SPDX-License-Identifier: Apache-2.0
44
55from collections import namedtuple
6- from typing import NamedTuple , Optional , Union
6+ from typing import Optional , Union
77from warnings import warn
88
99from cuda .core .experimental ._utils .clear_error_support import (
@@ -184,6 +184,9 @@ def cluster_scheduling_policy_preference(self, device_id: int = None) -> int:
184184 )
185185
186186
187+ ParamInfo = namedtuple ("ParamInfo" , ["offset" , "size" ])
188+
189+
187190class Kernel :
188191 """Represent a compiled kernel that had been loaded onto the device.
189192
@@ -196,7 +199,6 @@ class Kernel:
196199 """
197200
198201 __slots__ = ("_handle" , "_module" , "_attributes" )
199- ParamInfo = namedtuple ("ParamInfo" , ["offset" , "size" ])
200202
201203 def __new__ (self , * args , ** kwargs ):
202204 raise RuntimeError ("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs." )
@@ -218,7 +220,7 @@ def attributes(self) -> KernelAttributes:
218220 self ._attributes = KernelAttributes ._init (self ._handle )
219221 return self ._attributes
220222
221- def _get_arguments_info (self , param_info = False ) -> tuple [int , list [NamedTuple ]]:
223+ def _get_arguments_info (self , param_info = False ) -> tuple [int , list [ParamInfo ]]:
222224 attr_impl = self .attributes
223225 if attr_impl ._backend_version != "new" :
224226 raise NotImplementedError ("New backend is required" )
@@ -229,7 +231,7 @@ def _get_arguments_info(self, param_info=False) -> tuple[int, list[NamedTuple]]:
229231 if result [0 ] != driver .CUresult .CUDA_SUCCESS :
230232 break
231233 if param_info :
232- p_info = Kernel . ParamInfo (offset = result [1 ], size = result [2 ])
234+ p_info = ParamInfo (offset = result [1 ], size = result [2 ])
233235 param_info_data .append (p_info )
234236 arg_pos = arg_pos + 1
235237 if result [0 ] != driver .CUresult .CUDA_ERROR_INVALID_VALUE :
@@ -243,8 +245,8 @@ def num_arguments(self) -> int:
243245 return num_args
244246
245247 @property
246- def arguments_info (self ) -> list [NamedTuple ]:
247- """list[NamedTuple[int, int] ]: (offset, size) for each argument of this function"""
248+ def arguments_info (self ) -> list [ParamInfo ]:
249+ """list[ParamInfo ]: (offset, size) for each argument of this function"""
248250 _ , param_info = self ._get_arguments_info (param_info = True )
249251 return param_info
250252
0 commit comments