Modify the logic that judges multi-card training
This commit is contained in:
parent
bcabb7d5a9
commit
3a5b1f83f3
|
@ -208,16 +208,7 @@ class Profiler:
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
self._output_path, _ = os.path.split(job_dir)
|
self._output_path, _ = os.path.split(job_dir)
|
||||||
|
|
||||||
env_rank_id = os.getenv("RANK_ID")
|
self._profile_communication = kwargs.pop("profile_communication", False)
|
||||||
env_table_file = os.getenv("RANK_TABLE_FILE")
|
|
||||||
env_hccl_path = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
|
|
||||||
# Determine whether it is multi card training.
|
|
||||||
if env_rank_id and (env_table_file or env_hccl_path):
|
|
||||||
self._profile_communication = kwargs.pop("profile_communication", False)
|
|
||||||
if "profile_communication" in kwargs:
|
|
||||||
kwargs.pop("profile_communication")
|
|
||||||
logger.warning("The profile_communication parameter is invalid in single device training "
|
|
||||||
" which doesn't work.")
|
|
||||||
if not isinstance(self._profile_communication, bool):
|
if not isinstance(self._profile_communication, bool):
|
||||||
raise TypeError("The parameter profile_communication must be bool.")
|
raise TypeError("The parameter profile_communication must be bool.")
|
||||||
if self._profile_communication:
|
if self._profile_communication:
|
||||||
|
@ -244,8 +235,11 @@ class Profiler:
|
||||||
def _ascend_analyse(self):
|
def _ascend_analyse(self):
|
||||||
"""Collect and analyse ascend performance data"""
|
"""Collect and analyse ascend performance data"""
|
||||||
self._rank_size = 1
|
self._rank_size = 1
|
||||||
|
self._profile_communication = False
|
||||||
if GlobalComm.INITED:
|
if GlobalComm.INITED:
|
||||||
self._rank_size = get_group_size()
|
self._rank_size = get_group_size()
|
||||||
|
self._profile_communication = True
|
||||||
|
|
||||||
release()
|
release()
|
||||||
|
|
||||||
job_id = self._get_profiling_job_id()
|
job_id = self._get_profiling_job_id()
|
||||||
|
|
Loading…
Reference in New Issue