Skip to content

Commit f7ff9d1

Browse files
committed
More precise version checks
1 parent be61180 commit f7ff9d1

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

cuda_core/cuda/core/experimental/system/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
# We need both the existence of cuda.bindings._nvml and a sufficient version
2222
# with the APIs implemented as we need them.
2323

24-
if tuple(int(x) for x in cuda.bindings.__version__.split(".")) >= (13, 1, 1):
24+
_BINDINGS_VERSION = tuple(int(x) for x in cuda.bindings.__version__.split("."))
25+
26+
if (_BINDINGS_VERSION[0] == 13 and _BINDINGS_VERSION[1:3] >= (1, 1)) or (
27+
_BINDINGS_VERSION[0] == 12 and _BINDINGS_VERSION[1:3] >= (9, 5)
28+
):
2529
from cuda.bindings import _nvml
2630

2731
from ._nvml_context import initialize

cuda_core/tests/system/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
CUDA_BINDINGS_VERSION = tuple(int(x) for x in cuda.bindings.__version__.split("."))
1010

1111

12-
NVML_SUPPORTED = CUDA_BINDINGS_VERSION >= (13, 1, 1)
12+
NVML_SUPPORTED = (CUDA_BINDINGS_VERSION[0] == 13 and CUDA_BINDINGS_VERSION[1:3] >= (1, 1)) or (
13+
CUDA_BINDINGS_VERSION[0] == 12 and CUDA_BINDINGS_VERSION[1:3] >= (9, 5)
14+
)
1315

1416

1517
skip_if_nvml_unsupported = pytest.mark.skipif(

0 commit comments

Comments
 (0)