forked from mindspore-Ecosystem/mindspore
add cudnn version check
This commit is contained in:
parent
0df1ef53bf
commit
8048b52edd
|
@ -55,6 +55,7 @@ class GPUEnvChecker(EnvChecker):
|
|||
self.v = "0"
|
||||
self.cuda_lib_path = self._get_lib_path("libcu")
|
||||
self.cuda_bin_path = self._get_bin_path("cuda")
|
||||
self.cudnn_lib_path = self._get_lib_path("libcudnn")
|
||||
|
||||
def check_env(self, e):
|
||||
raise e
|
||||
|
@ -94,6 +95,20 @@ class GPUEnvChecker(EnvChecker):
|
|||
return line.strip().split("release")[1].split(",")[0].strip()
|
||||
return ""
|
||||
|
||||
def _get_cudnn_version(self):
|
||||
"""Get cudnn version by libcudnn.so."""
|
||||
cudnn_version = []
|
||||
for path in self.cudnn_lib_path:
|
||||
ls_cudnn = subprocess.run(["ls " + path + "/lib64/libcudnn.so.*.*"], timeout=10, text=True,
|
||||
capture_output=True, check=False, shell=True)
|
||||
if ls_cudnn.returncode == 0:
|
||||
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
|
||||
if len(cudnn_version) == 2:
|
||||
cudnn_version.append('0')
|
||||
break
|
||||
version_str = ''.join([n for n in cudnn_version])
|
||||
return version_str
|
||||
|
||||
def check_version(self):
|
||||
"""Check cuda version."""
|
||||
version_match = False
|
||||
|
@ -118,6 +133,17 @@ class GPUEnvChecker(EnvChecker):
|
|||
logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} "
|
||||
"does not match, please refer to the installation guide for version matching "
|
||||
"information: https://www.mindspore.cn/install")
|
||||
cudnn_version = self._get_cudnn_version()
|
||||
if cudnn_version and int(cudnn_version) < 760:
|
||||
logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} "
|
||||
"does not match, please refer to the installation guide for version matching "
|
||||
"information: https://www.mindspore.cn/install. The recommended version is "
|
||||
"CUDA10.1 with cuDNN7.6.x and CUAD11.1 with cuDNN8.0.x")
|
||||
if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10:
|
||||
logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} "
|
||||
"does not match, please refer to the installation guide for version matching "
|
||||
"information: https://www.mindspore.cn/install. The recommended version is "
|
||||
"CUAD11.1 with cuDNN8.0.x")
|
||||
|
||||
def _check_version(self, version_file):
|
||||
"""Check cuda version by version.txt."""
|
||||
|
|
Loading…
Reference in New Issue