From c04071fa47186429c0c7e5ff0e8e19eb3d1ecd10 Mon Sep 17 00:00:00 2001
From: Altair Liu <1580802568@qq.com>
Date: Wed, 5 Jan 2022 16:51:07 +0800
Subject: [PATCH] modify is_ampere_gpu in api/common/launch.py

---
 api/common/launch.py | 40 ++++++++++++++++++++++++++++++++--------
 1 file changed, 32 insertions(+), 8 deletions(-)

diff --git a/api/common/launch.py b/api/common/launch.py
index 623de5caf6..cdab268dac 100644
--- a/api/common/launch.py
+++ b/api/common/launch.py
@@ -23,14 +23,38 @@
 
 
 def is_ampere_gpu():
-    stdout, exit_code = system.run_command("nvidia-smi -L")
-    if exit_code == 0:
-        gpu_list = stdout.split("\n")
-        if len(gpu_list) >= 1:
-            #print(gpu_list[0])
-            # GPU 0: NVIDIA A100-SXM4-40GB (UUID: xxxx)
-            return gpu_list[0].find("A100") > 0
-    return False
+    """Report whether GPU 0 is Ampere (compute capability 8.0) or newer.
+
+    The compute capability is read through NVML when the optional pynvml
+    package is importable and the query succeeds. Otherwise the first line
+    of "nvidia-smi -L" is searched for "A100"; False is returned when that
+    command fails.
+    """
+    try:
+        import pynvml
+    except ImportError:
+        print("Warning: pynvml package is not installed, please install it as follow \"pip install pynvml\"")
+    else:
+        try:
+            pynvml.nvmlInit()
+            try:
+                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+                capability = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+            finally:
+                pynvml.nvmlShutdown()
+            # Compare numerically: as strings, "8.0" > "8.0" is False (rejecting
+            # the A100 itself) and "10.0" sorts below "8.0".
+            return tuple(capability) >= (8, 0)
+        except pynvml.NVMLError as err:
+            # No usable driver/device: degrade to the nvidia-smi probe below.
+            print("Warning: NVML query failed (%s), falling back to nvidia-smi" % err)
+
+    stdout, exit_code = system.run_command("nvidia-smi -L")
+    if exit_code != 0:
+        print("Error: Failed to run sys command \"nvidia-smi -L\"")
+        return False
+    # Expected first line: "GPU 0: NVIDIA A100-SXM4-40GB (UUID: xxxx)"
+    return "A100" in stdout.split("\n")[0]
 
 
 class NvprofRunner(object):