!5236 transfer tensor to tuple

Merge pull request !5236 from lijiaqi/transfer_tensor_to_tuple
This commit is contained in:
mindspore-ci-bot 2020-08-26 19:27:27 +08:00 committed by Gitee
commit adeeda2fe1
1 changed file with 13 additions and 0 deletions

View File

@ -386,6 +386,15 @@ class Model:
return [callbacks]
def _transfer_tensor_to_tuple(self, inputs):
    """
    Normalize *inputs* for tuple-unpacking call sites.

    A bare Tensor is wrapped into a one-element tuple so callers can
    always use ``*inputs``; any other value is returned unchanged.
    """
    # Conditional expression keeps the single-Tensor wrap on one line.
    return (inputs,) if isinstance(inputs, Tensor) else inputs
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
"""
Training process. The data would be passed to network through dataset channel.
@ -428,6 +437,7 @@ class Model:
# with data sink, dataset_helper iterates only once; otherwise it iterates epoch_size times.
for inputs in dataset_helper:
inputs = self._transfer_tensor_to_tuple(inputs)
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
@ -475,6 +485,7 @@ class Model:
list_callback.epoch_begin(run_context)
for next_element in dataset_helper:
next_element = self._transfer_tensor_to_tuple(next_element)
len_element = len(next_element)
if self._loss_fn and len_element != 2:
raise ValueError("when loss_fn is not None, train_dataset should"
@ -592,6 +603,7 @@ class Model:
list_callback.begin(run_context)
for inputs in dataset_helper:
inputs = self._transfer_tensor_to_tuple(inputs)
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
@ -630,6 +642,7 @@ class Model:
for next_element in dataset_helper:
cb_params.cur_step_num += 1
list_callback.step_begin(run_context)
next_element = self._transfer_tensor_to_tuple(next_element)
outputs = self._eval_network(*next_element)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)