forked from mindspore-Ecosystem/mindspore
!5599 Collect input data when `dataset_sink_mode` set on GPU
Merge pull request !5599 from LiHongzhang/dataset_sink_mode
This commit is contained in:
commit
ad186e79d8
|
@ -414,11 +414,11 @@ class SummaryCollector(Callback):
|
|||
logger.info("The 'train_dataset_element' in cb_params is None, maybe there is dataset sink mode.")
|
||||
return
|
||||
|
||||
if isinstance(input_data, (list, tuple)):
|
||||
if isinstance(input_data, (list, tuple)) and input_data:
|
||||
input_data = input_data[0]
|
||||
try:
|
||||
self._record.add_value(PluginEnum.IMAGE.value, 'input_data/auto', input_data)
|
||||
except ValueError:
|
||||
except (TypeError, ValueError):
|
||||
logger.warning('The input data of network are not image, so will not collect by SummaryCollector.')
|
||||
self._collect_specified_data['collect_input_data'] = False
|
||||
return
|
||||
|
|
|
@ -448,6 +448,7 @@ class Model:
|
|||
for inputs in dataset_helper:
|
||||
if _need_to_full() and context.get_context("device_target") == "GPU":
|
||||
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
||||
cb_params.train_dataset_element = inputs
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.cur_step_num += dataset_helper.sink_size()
|
||||
|
@ -499,7 +500,6 @@ class Model:
|
|||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
"return two elements, but got {}".format(len_element))
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
overflow = False
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
|
@ -507,6 +507,7 @@ class Model:
|
|||
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
|
||||
|
||||
cb_params.train_dataset_element = next_element
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
|
|
|
@ -482,6 +482,7 @@ class Model:
|
|||
for inputs in dataset_helper:
|
||||
if _need_to_full():
|
||||
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
||||
cb_params.train_dataset_element = inputs
|
||||
list_callback.step_begin(run_context)
|
||||
if switch_branch_one:
|
||||
cb_params.cur_step_num += dataset_helper.sink_size()
|
||||
|
@ -546,7 +547,6 @@ class Model:
|
|||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
"return two elements, but got {}".format(len_element))
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
overflow = False
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
|
@ -554,6 +554,7 @@ class Model:
|
|||
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
|
||||
|
||||
cb_params.train_dataset_element = next_element
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
|
|
|
@ -454,7 +454,6 @@ class Model:
|
|||
|
||||
# For data sink mode, dataset_helper only iterates once; otherwise it iterates epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
list_callback.step_begin(run_context)
|
||||
if switch_branch_one:
|
||||
cb_params.cur_step_num += loop_size
|
||||
self._train_network.add_flags_recursive(thor=True)
|
||||
|
@ -467,6 +466,8 @@ class Model:
|
|||
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
|
||||
self._has_do_dataset_init = True
|
||||
switch_branch_one = not switch_branch_one
|
||||
cb_params.train_dataset_element = inputs
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*inputs)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
|
@ -514,13 +515,14 @@ class Model:
|
|||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
"return two elements, but got {}".format(len_element))
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
overflow = False
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
scaling_sens = self._get_scaling_sens()
|
||||
next_element = tuple(next_element) + (Tensor(scaling_sens, mstype.float32),)
|
||||
|
||||
cb_params.train_dataset_element = next_element
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
|
||||
|
|
|
@ -242,13 +242,14 @@ class TestSummaryCollector:
|
|||
SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))._check_callbacks(cb_params)
|
||||
assert f"more than one SummaryCollector instance in callback list" in str(exc.value)
|
||||
|
||||
def test_collect_input_data_with_train_dataset_element_none(self):
|
||||
"""Test the param 'train_dataset_element' in cb_params is none."""
|
||||
def test_collect_input_data_with_train_dataset_element_invalid(self):
|
||||
"""Test the param 'train_dataset_element' in cb_params is invalid."""
|
||||
cb_params = _InternalCallbackParam()
|
||||
cb_params.train_dataset_element = None
|
||||
summary_collector = SummaryCollector((tempfile.mkdtemp(dir=self.base_summary_dir)))
|
||||
summary_collector._collect_input_data(cb_params)
|
||||
assert not summary_collector._collect_specified_data['collect_input_data']
|
||||
for invalid in (), [], None, [None]:
|
||||
cb_params.train_dataset_element = invalid
|
||||
with SummaryCollector(tempfile.mkdtemp(dir=self.base_summary_dir)) as summary_collector:
|
||||
summary_collector._collect_input_data(cb_params)
|
||||
assert not summary_collector._collect_specified_data['collect_input_data']
|
||||
|
||||
@mock.patch.object(SummaryRecord, 'add_value')
|
||||
def test_collect_input_data_success(self, mock_add_value):
|
||||
|
|
Loading…
Reference in New Issue