Decide whether to collect data by dataset sink mode and current step in SummaryCollector.

Before, we only decide whether to collect data by current step,
it will not work well in dataset sink mode, so we check to see
if it's a dataset sink mode, and decide whether to collect data.
This commit is contained in:
ougongchang 2020-06-27 16:07:45 +08:00
parent f067c209c6
commit 33b5cda1da
1 changed files with 20 additions and 4 deletions

View File

@ -166,8 +166,11 @@ class SummaryCollector(Callback):
self._has_saved_custom_data = False
self._is_parse_loss_success = True
self._first_step = True
self._dataset_sink_mode = True
def __enter__(self):
self._first_step = True
self._dataset_sink_mode = True
self._record = SummaryRecord(log_dir=self._summary_dir)
return self
@ -279,15 +282,15 @@ class SummaryCollector(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
if self._first_step:
# Notice: This way of determining whether dataset sink mode is True does not work in the eval scenario
self._dataset_sink_mode = bool(cb_params.cur_step_num == cb_params.batch_num)
if cb_params.mode == ModeEnum.TRAIN.value:
# Make sure the first step data is recorded
if not self._first_step and cb_params.cur_step_num % self._collect_freq:
if not self._is_collect_this_step(cb_params):
return
self._first_step = False
if not self._has_saved_train_network:
self._collect_graphs(cb_params)
@ -295,6 +298,7 @@ class SummaryCollector(Callback):
self._collect_metric(cb_params)
self._collect_histogram(cb_params)
self._first_step = False
self._record.record(cb_params.cur_step_num)
def end(self, run_context):
@ -320,6 +324,18 @@ class SummaryCollector(Callback):
raise ValueError(f"There are more than one {self.__class__.__name__} instance in callback list,"
f"but expected only one {self.__class__.__name__} instance.")
def _is_collect_this_step(self, cb_params):
"""Decide whether to collect data for the current step."""
# Make sure the first step data is recorded
if not self._first_step:
if self._dataset_sink_mode:
if cb_params.cur_epoch_num % self._collect_freq:
return False
else:
if cb_params.cur_step_num % self._collect_freq:
return False
return True
@staticmethod
def _package_custom_lineage_data(custom_lineage_data):
"""