From d9c10d35481440faae56b8d3c265cbef04c0ae68 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Thu, 22 Apr 2021 19:43:09 +0800 Subject: [PATCH] update cuda versin check --- mindspore/_check_version.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/mindspore/_check_version.py b/mindspore/_check_version.py index 978d748c2be..1715c5be536 100644 --- a/mindspore/_check_version.py +++ b/mindspore/_check_version.py @@ -99,7 +99,7 @@ class GPUEnvChecker(EnvChecker): """Get cudnn version by libcudnn.so.""" cudnn_version = [] for path in self.cudnn_lib_path: - ls_cudnn = subprocess.run(["ls " + path + "/lib64/libcudnn.so.*.*"], timeout=10, text=True, + ls_cudnn = subprocess.run(["ls " + path + "/lib*/libcudnn.so.*.*"], timeout=10, text=True, capture_output=True, check=False, shell=True) if ls_cudnn.returncode == 0: cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.') @@ -107,21 +107,26 @@ class GPUEnvChecker(EnvChecker): cudnn_version.append('0') break version_str = ''.join([n for n in cudnn_version]) - return version_str + return version_str[0:3] + + def _get_cudart_version(self): + """Get cuda runtime version by libcudart.so.""" + for path in self.cuda_lib_path: + ls_cudart = subprocess.run(["ls " + path + "/lib*/libcudart.so.*.*.*"], timeout=10, text=True, + capture_output=True, check=False, shell=True) + if ls_cudart.returncode == 0: + self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip() + break + return self.v def check_version(self): """Check cuda version.""" version_match = False - for path in self.cuda_lib_path: - version_file = path + "/version.txt" - if not Path(version_file).is_file(): - continue - if self._check_version(version_file): - version_match = True - break + if self._check_version(): + version_match = True if not version_match: if self.v == "0": - logger.warning("Cuda version file version.txt is not found, please confirm that the correct " + logger.warning("Can not found cuda libs, please confirm that the correct " "cuda version has been installed, you can refer to the " "installation guidelines: https://www.mindspore.cn/install") else: @@ -145,9 +150,9 @@ class GPUEnvChecker(EnvChecker): "information: https://www.mindspore.cn/install. The recommended version is " "CUAD11.1 with cuDNN8.0.x") - def _check_version(self, version_file): - """Check cuda version by version.txt.""" - v = self._read_version(version_file) + def _check_version(self): + """Check cuda version""" + v = self._get_cudart_version() v = version.parse(v) v_str = str(v.major) + "." + str(v.minor) if v_str not in self.version: @@ -383,10 +388,9 @@ def check_version_and_env_config(): return try: - from . import _c_expression # check version of ascend site or cuda env_checker.check_version() - + from . import _c_expression env_checker.set_env() except ImportError as e: env_checker.check_env(e)