forked from mindspore-Ecosystem/mindspore
!15536 GPU update cuda version check
From: @VectorSL Reviewed-by: @wilfchen,@cristoval Signed-off-by: @cristoval
This commit is contained in:
commit
6f0456e274
|
@ -99,7 +99,7 @@ class GPUEnvChecker(EnvChecker):
|
|||
"""Get cudnn version by libcudnn.so."""
|
||||
cudnn_version = []
|
||||
for path in self.cudnn_lib_path:
|
||||
ls_cudnn = subprocess.run(["ls " + path + "/lib64/libcudnn.so.*.*"], timeout=10, text=True,
|
||||
ls_cudnn = subprocess.run(["ls " + path + "/lib*/libcudnn.so.*.*"], timeout=10, text=True,
|
||||
capture_output=True, check=False, shell=True)
|
||||
if ls_cudnn.returncode == 0:
|
||||
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
|
||||
|
@ -107,21 +107,26 @@ class GPUEnvChecker(EnvChecker):
|
|||
cudnn_version.append('0')
|
||||
break
|
||||
version_str = ''.join([n for n in cudnn_version])
|
||||
return version_str
|
||||
return version_str[0:3]
|
||||
|
||||
def _get_cudart_version(self):
|
||||
"""Get cuda runtime version by libcudart.so."""
|
||||
for path in self.cuda_lib_path:
|
||||
ls_cudart = subprocess.run(["ls " + path + "/lib*/libcudart.so.*.*.*"], timeout=10, text=True,
|
||||
capture_output=True, check=False, shell=True)
|
||||
if ls_cudart.returncode == 0:
|
||||
self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip()
|
||||
break
|
||||
return self.v
|
||||
|
||||
def check_version(self):
|
||||
"""Check cuda version."""
|
||||
version_match = False
|
||||
for path in self.cuda_lib_path:
|
||||
version_file = path + "/version.txt"
|
||||
if not Path(version_file).is_file():
|
||||
continue
|
||||
if self._check_version(version_file):
|
||||
version_match = True
|
||||
break
|
||||
if self._check_version():
|
||||
version_match = True
|
||||
if not version_match:
|
||||
if self.v == "0":
|
||||
logger.warning("Cuda version file version.txt is not found, please confirm that the correct "
|
||||
logger.warning("Can not found cuda libs, please confirm that the correct "
|
||||
"cuda version has been installed, you can refer to the "
|
||||
"installation guidelines: https://www.mindspore.cn/install")
|
||||
else:
|
||||
|
@ -145,9 +150,9 @@ class GPUEnvChecker(EnvChecker):
|
|||
"information: https://www.mindspore.cn/install. The recommended version is "
|
||||
"CUAD11.1 with cuDNN8.0.x")
|
||||
|
||||
def _check_version(self, version_file):
|
||||
"""Check cuda version by version.txt."""
|
||||
v = self._read_version(version_file)
|
||||
def _check_version(self):
|
||||
"""Check cuda version"""
|
||||
v = self._get_cudart_version()
|
||||
v = version.parse(v)
|
||||
v_str = str(v.major) + "." + str(v.minor)
|
||||
if v_str not in self.version:
|
||||
|
@ -383,10 +388,9 @@ def check_version_and_env_config():
|
|||
return
|
||||
|
||||
try:
|
||||
from . import _c_expression
|
||||
# check version of ascend site or cuda
|
||||
env_checker.check_version()
|
||||
|
||||
from . import _c_expression
|
||||
env_checker.set_env()
|
||||
except ImportError as e:
|
||||
env_checker.check_env(e)
|
||||
|
|
Loading…
Reference in New Issue