66# this software and related documentation outside the terms of the EULA
77# is strictly prohibited.
88
9- try :
10- from cuda .bindings import driver
11- except ImportError :
12- from cuda import cuda as driver
139
1410import ctypes
1511import warnings
1915
2016import cuda .core .experimental
2117from cuda .core .experimental import ObjectCode , Program , ProgramOptions , system
22- from cuda .core .experimental ._utils import cuda_utils
18+ from cuda .core .experimental ._utils . cuda_utils import CUDAError , driver , get_binding_version , handle_return
2319
2420SAXPY_KERNEL = r"""
2521template<typename T>
3733
3834
3935@pytest .fixture (scope = "module" )
40- def cuda_version ():
36+ def cuda12_prerequisite_check ():
4137 # binding availability depends on cuda-python version
42- _py_major_ver , _ = cuda_utils .get_binding_version ()
43- _driver_ver = cuda_utils .handle_return (driver .cuDriverGetVersion ())
44- return _py_major_ver , _driver_ver
38+ # and version of underlying CUDA toolkit
39+ _py_major_ver , _ = get_binding_version ()
40+ _driver_ver = handle_return (driver .cuDriverGetVersion ())
41+ return _py_major_ver >= 12 and _driver_ver >= 12000
4542
4643
4744def test_kernel_attributes_init_disabled ():
@@ -180,9 +177,8 @@ def test_object_code_handle(get_saxpy_object_code):
180177
181178
182179@skipif_testing_with_compute_sanitizer
183- def test_saxpy_arguments (get_saxpy_kernel , cuda_version ):
184- _ , dr_ver = cuda_version
185- if dr_ver < 12 :
180+ def test_saxpy_arguments (get_saxpy_kernel , cuda12_prerequisite_check ):
181+ if not cuda12_prerequisite_check :
186182 pytest .skip ("Test requires CUDA 12" )
187183 krn , _ = get_saxpy_kernel
188184
@@ -213,9 +209,8 @@ class ExpectedStruct(ctypes.Structure):
213209@skipif_testing_with_compute_sanitizer
214210@pytest .mark .parametrize ("nargs" , [0 , 1 , 2 , 3 , 16 ])
215211@pytest .mark .parametrize ("c_type_name,c_type" , [("int" , ctypes .c_int ), ("short" , ctypes .c_short )], ids = ["int" , "short" ])
216- def test_num_arguments (init_cuda , nargs , c_type_name , c_type , cuda_version ):
217- _ , dr_ver = cuda_version
218- if dr_ver < 12 :
212+ def test_num_arguments (init_cuda , nargs , c_type_name , c_type , cuda12_prerequisite_check ):
213+ if not cuda12_prerequisite_check :
219214 pytest .skip ("Test requires CUDA 12" )
220215 args_str = ", " .join ([f"{ c_type_name } p_{ i } " for i in range (nargs )])
221216 src = f"__global__ void foo{ nargs } ({ args_str } ) {{ }}"
@@ -238,9 +233,8 @@ class ExpectedStruct(ctypes.Structure):
238233
239234
240235@skipif_testing_with_compute_sanitizer
241- def check_num_args_error_handling (deinit_cuda , cuda_version ):
242- _ , dr_ver = cuda_version
243- if dr_ver < 12 :
236+ def test_num_args_error_handling (deinit_cuda , cuda12_prerequisite_check ):
237+ if not cuda12_prerequisite_check :
244238 pytest .skip ("Test requires CUDA 12" )
245239 src = "__global__ void foo(int a) { }"
246240 prog = Program (src , code_type = "c++" )
@@ -249,5 +243,6 @@ def check_num_args_error_handling(deinit_cuda, cuda_version):
249243 name_expressions = ("foo" ,),
250244 )
251245 krn = mod .get_kernel ("foo" )
252- with pytest .raises (cuda_utils .CUDAError ):
246+ with pytest .raises (CUDAError ):
247+ # assignment resolves linter error "B018: useless expression"
253248 _ = krn .num_arguments
0 commit comments