add cudnn version check

This commit is contained in:
VectorSL 2021-04-12 16:19:07 +08:00
parent 0df1ef53bf
commit 8048b52edd
1 changed files with 26 additions and 0 deletions

View File

@ -55,6 +55,7 @@ class GPUEnvChecker(EnvChecker):
self.v = "0"
self.cuda_lib_path = self._get_lib_path("libcu")
self.cuda_bin_path = self._get_bin_path("cuda")
self.cudnn_lib_path = self._get_lib_path("libcudnn")
def check_env(self, e):
raise e
@ -94,6 +95,20 @@ class GPUEnvChecker(EnvChecker):
return line.strip().split("release")[1].split(",")[0].strip()
return ""
def _get_cudnn_version(self):
"""Get cudnn version by libcudnn.so."""
cudnn_version = []
for path in self.cudnn_lib_path:
ls_cudnn = subprocess.run(["ls " + path + "/lib64/libcudnn.so.*.*"], timeout=10, text=True,
capture_output=True, check=False, shell=True)
if ls_cudnn.returncode == 0:
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
if len(cudnn_version) == 2:
cudnn_version.append('0')
break
version_str = ''.join([n for n in cudnn_version])
return version_str
def check_version(self):
"""Check cuda version."""
version_match = False
@ -118,6 +133,17 @@ class GPUEnvChecker(EnvChecker):
logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} "
"does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install")
cudnn_version = self._get_cudnn_version()
if cudnn_version and int(cudnn_version) < 760:
logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install. The recommended version is "
"CUDA10.1 with cuDNN7.6.x and CUAD11.1 with cuDNN8.0.x")
if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10:
logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install. The recommended version is "
"CUAD11.1 with cuDNN8.0.x")
def _check_version(self, version_file):
"""Check cuda version by version.txt."""