Skip to content

Commit 2ca9704

Browse files
Implement suggestions from @leofang
1. Narrow fixture scope 2. Rename fixture to cuda12_prerequisite_check that provide a boolean rather than a pair of versions
1 parent 917a386 commit 2ca9704

File tree

1 file changed

+14
-19
lines changed

1 file changed

+14
-19
lines changed

cuda_core/tests/test_module.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,6 @@
66
# this software and related documentation outside the terms of the EULA
77
# is strictly prohibited.
88

9-
try:
10-
from cuda.bindings import driver
11-
except ImportError:
12-
from cuda import cuda as driver
139

1410
import ctypes
1511
import warnings
@@ -19,7 +15,7 @@
1915

2016
import cuda.core.experimental
2117
from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system
22-
from cuda.core.experimental._utils import cuda_utils
18+
from cuda.core.experimental._utils.cuda_utils import CUDAError, driver, get_binding_version, handle_return
2319

2420
SAXPY_KERNEL = r"""
2521
template<typename T>
@@ -37,11 +33,12 @@
3733

3834

3935
@pytest.fixture(scope="module")
40-
def cuda_version():
36+
def cuda12_prerequisite_check():
4137
# binding availability depends on cuda-python version
42-
_py_major_ver, _ = cuda_utils.get_binding_version()
43-
_driver_ver = cuda_utils.handle_return(driver.cuDriverGetVersion())
44-
return _py_major_ver, _driver_ver
38+
# and version of underlying CUDA toolkit
39+
_py_major_ver, _ = get_binding_version()
40+
_driver_ver = handle_return(driver.cuDriverGetVersion())
41+
return _py_major_ver >= 12 and _driver_ver >= 12000
4542

4643

4744
def test_kernel_attributes_init_disabled():
@@ -180,9 +177,8 @@ def test_object_code_handle(get_saxpy_object_code):
180177

181178

182179
@skipif_testing_with_compute_sanitizer
183-
def test_saxpy_arguments(get_saxpy_kernel, cuda_version):
184-
_, dr_ver = cuda_version
185-
if dr_ver < 12:
180+
def test_saxpy_arguments(get_saxpy_kernel, cuda12_prerequisite_check):
181+
if not cuda12_prerequisite_check:
186182
pytest.skip("Test requires CUDA 12")
187183
krn, _ = get_saxpy_kernel
188184

@@ -213,9 +209,8 @@ class ExpectedStruct(ctypes.Structure):
213209
@skipif_testing_with_compute_sanitizer
214210
@pytest.mark.parametrize("nargs", [0, 1, 2, 3, 16])
215211
@pytest.mark.parametrize("c_type_name,c_type", [("int", ctypes.c_int), ("short", ctypes.c_short)], ids=["int", "short"])
216-
def test_num_arguments(init_cuda, nargs, c_type_name, c_type, cuda_version):
217-
_, dr_ver = cuda_version
218-
if dr_ver < 12:
212+
def test_num_arguments(init_cuda, nargs, c_type_name, c_type, cuda12_prerequisite_check):
213+
if not cuda12_prerequisite_check:
219214
pytest.skip("Test requires CUDA 12")
220215
args_str = ", ".join([f"{c_type_name} p_{i}" for i in range(nargs)])
221216
src = f"__global__ void foo{nargs}({args_str}) {{ }}"
@@ -238,9 +233,8 @@ class ExpectedStruct(ctypes.Structure):
238233

239234

240235
@skipif_testing_with_compute_sanitizer
241-
def check_num_args_error_handling(deinit_cuda, cuda_version):
242-
_, dr_ver = cuda_version
243-
if dr_ver < 12:
236+
def test_num_args_error_handling(deinit_cuda, cuda12_prerequisite_check):
237+
if not cuda12_prerequisite_check:
244238
pytest.skip("Test requires CUDA 12")
245239
src = "__global__ void foo(int a) { }"
246240
prog = Program(src, code_type="c++")
@@ -249,5 +243,6 @@ def check_num_args_error_handling(deinit_cuda, cuda_version):
249243
name_expressions=("foo",),
250244
)
251245
krn = mod.get_kernel("foo")
252-
with pytest.raises(cuda_utils.CUDAError):
246+
with pytest.raises(CUDAError):
247+
# assignment resolves linter error "B018: useless expression"
253248
_ = krn.num_arguments

0 commit comments

Comments
 (0)