diff --git a/api/common/launch.py b/api/common/launch.py index 623de5caf6..cdab268dac 100644 --- a/api/common/launch.py +++ b/api/common/launch.py @@ -23,14 +23,28 @@ def is_ampere_gpu(): - stdout, exit_code = system.run_command("nvidia-smi -L") - if exit_code == 0: - gpu_list = stdout.split("\n") - if len(gpu_list) >= 1: + try: + from pynvml import nvmlInit + from pynvml import nvmlDeviceGetHandleByIndex + from pynvml import nvmlDeviceGetCudaComputeCapability + nvmlInit() + handle = nvmlDeviceGetHandleByIndex(0) + cc_major,cc_minor = nvmlDeviceGetCudaComputeCapability(handle) + #print(str(cc_major)+"."+str(cc_minor)) + # 8.0 or 8.6 + return str(cc_major)+"."+str(cc_minor)>"8.0" + except ImportError: + print("Warning: pynvml package is not installed, please install it as follow \"pip install pynvml\"") + stdout, exit_code = system.run_command("nvidia-smi -L") + if exit_code == 0: + gpu_list = stdout.split("\n") #print(gpu_list[0]) # GPU 0: NVIDIA A100-SXM4-40GB (UUID: xxxx) - return gpu_list[0].find("A100") > 0 - return False + else: + print("Error: Failed to run sys command \"nvidia-smi -L\"") + return False + return gpu_list[0].find("A100") > 0 + class NvprofRunner(object):