!15536 GPU update cuda version check

From: @VectorSL
Reviewed-by: @wilfchen,@cristoval
Signed-off-by: @cristoval
This commit is contained in:
mindspore-ci-bot 2021-05-11 09:26:24 +08:00 committed by Gitee
commit 6f0456e274
1 changed files with 19 additions and 15 deletions

View File

@ -99,7 +99,7 @@ class GPUEnvChecker(EnvChecker):
"""Get cudnn version by libcudnn.so."""
cudnn_version = []
for path in self.cudnn_lib_path:
ls_cudnn = subprocess.run(["ls " + path + "/lib64/libcudnn.so.*.*"], timeout=10, text=True,
ls_cudnn = subprocess.run(["ls " + path + "/lib*/libcudnn.so.*.*"], timeout=10, text=True,
capture_output=True, check=False, shell=True)
if ls_cudnn.returncode == 0:
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
@ -107,21 +107,26 @@ class GPUEnvChecker(EnvChecker):
cudnn_version.append('0')
break
version_str = ''.join([n for n in cudnn_version])
return version_str
return version_str[0:3]
def _get_cudart_version(self):
"""Get cuda runtime version by libcudart.so."""
for path in self.cuda_lib_path:
ls_cudart = subprocess.run(["ls " + path + "/lib*/libcudart.so.*.*.*"], timeout=10, text=True,
capture_output=True, check=False, shell=True)
if ls_cudart.returncode == 0:
self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip()
break
return self.v
def check_version(self):
"""Check cuda version."""
version_match = False
for path in self.cuda_lib_path:
version_file = path + "/version.txt"
if not Path(version_file).is_file():
continue
if self._check_version(version_file):
version_match = True
break
if self._check_version():
version_match = True
if not version_match:
if self.v == "0":
logger.warning("Cuda version file version.txt is not found, please confirm that the correct "
logger.warning("Can not found cuda libs, please confirm that the correct "
"cuda version has been installed, you can refer to the "
"installation guidelines: https://www.mindspore.cn/install")
else:
@ -145,9 +150,9 @@ class GPUEnvChecker(EnvChecker):
"information: https://www.mindspore.cn/install. The recommended version is "
"CUAD11.1 with cuDNN8.0.x")
def _check_version(self, version_file):
"""Check cuda version by version.txt."""
v = self._read_version(version_file)
def _check_version(self):
"""Check cuda version"""
v = self._get_cudart_version()
v = version.parse(v)
v_str = str(v.major) + "." + str(v.minor)
if v_str not in self.version:
@ -383,10 +388,9 @@ def check_version_and_env_config():
return
try:
from . import _c_expression
# check version of ascend site or cuda
env_checker.check_version()
from . import _c_expression
env_checker.set_env()
except ImportError as e:
env_checker.check_env(e)