Modify the logic for determining multi-card training

This commit is contained in:
zhangyihui 2021-09-07 14:58:23 +08:00
parent bcabb7d5a9
commit 3a5b1f83f3
1 changed file with 4 additions and 10 deletions

View File

@@ -208,16 +208,7 @@ class Profiler:
raise ValueError(msg) raise ValueError(msg)
self._output_path, _ = os.path.split(job_dir) self._output_path, _ = os.path.split(job_dir)
env_rank_id = os.getenv("RANK_ID") self._profile_communication = kwargs.pop("profile_communication", False)
env_table_file = os.getenv("RANK_TABLE_FILE")
env_hccl_path = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
# Determine whether it is multi card training.
if env_rank_id and (env_table_file or env_hccl_path):
self._profile_communication = kwargs.pop("profile_communication", False)
if "profile_communication" in kwargs:
kwargs.pop("profile_communication")
logger.warning("The profile_communication parameter is invalid in single device training "
" which doesn't work.")
if not isinstance(self._profile_communication, bool): if not isinstance(self._profile_communication, bool):
raise TypeError("The parameter profile_communication must be bool.") raise TypeError("The parameter profile_communication must be bool.")
if self._profile_communication: if self._profile_communication:
@@ -244,8 +235,11 @@ class Profiler:
def _ascend_analyse(self): def _ascend_analyse(self):
"""Collect and analyse ascend performance data""" """Collect and analyse ascend performance data"""
self._rank_size = 1 self._rank_size = 1
self._profile_communication = False
if GlobalComm.INITED: if GlobalComm.INITED:
self._rank_size = get_group_size() self._rank_size = get_group_size()
self._profile_communication = True
release() release()
job_id = self._get_profiling_job_id() job_id = self._get_profiling_job_id()