!2616 Decide whether to collect data by dataset sink mode and current step in SummaryCollector

Merge pull request !2616 from ougongchang/fix_collect_freq
This commit is contained in:
mindspore-ci-bot 2020-06-28 10:11:35 +08:00 committed by Gitee
commit 19f79cd744
1 changed files with 20 additions and 4 deletions

View File

@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False
self._is_parse_loss_success = True
self._first_step = True
self._dataset_sink_mode = True
def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir)
return self
@ -279,15 +282,15 @@ class SummaryCollector(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)
if cb_params.mode == ModeEnum.TRAIN.value:
# Make sure the first step data is recorded
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
if not self._is_collect_this_step(cb_params):
return
self._first_step = False
if not self._has_saved_train_network:
self._collect_graphs(cb_params)
@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params)
self._collect_histogram(cb_params)
self._first_step = False
self._record.record(cb_params.cur_step_num)
def end(self, run_context):
@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.")
def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True
@staticmethod
def _package_custom_lineage_data(custom_lineage_data):
"""