forked from mindspore-Ecosystem/mindspore
!5341 transfer tensor to tuple
Merge pull request !5341 from lijiaqi/new
This commit is contained in:
commit
e5780288e9
|
@ -37,6 +37,16 @@ from .dataset_helper import DatasetHelper
|
|||
from . import amp
|
||||
|
||||
|
||||
def _transfer_tensor_to_tuple(inputs):
|
||||
"""
|
||||
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
|
||||
"""
|
||||
if isinstance(inputs, Tensor):
|
||||
return (inputs,)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class Model:
|
||||
"""
|
||||
High-Level API for Training or Testing.
|
||||
|
@ -386,15 +396,6 @@ class Model:
|
|||
|
||||
return [callbacks]
|
||||
|
||||
def _transfer_tensor_to_tuple(self, inputs):
|
||||
"""
|
||||
If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
|
||||
"""
|
||||
if isinstance(inputs, Tensor):
|
||||
return (inputs,)
|
||||
|
||||
return inputs
|
||||
|
||||
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
|
||||
"""
|
||||
Training process. The data would be passed to network through dataset channel.
|
||||
|
@ -437,7 +438,6 @@ class Model:
|
|||
|
||||
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
inputs = self._transfer_tensor_to_tuple(inputs)
|
||||
if _need_to_full() and context.get_context("device_target") == "GPU":
|
||||
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
||||
list_callback.step_begin(run_context)
|
||||
|
@ -485,7 +485,7 @@ class Model:
|
|||
list_callback.epoch_begin(run_context)
|
||||
|
||||
for next_element in dataset_helper:
|
||||
next_element = self._transfer_tensor_to_tuple(next_element)
|
||||
next_element = _transfer_tensor_to_tuple(next_element)
|
||||
len_element = len(next_element)
|
||||
if self._loss_fn and len_element != 2:
|
||||
raise ValueError("when loss_fn is not None, train_dataset should"
|
||||
|
@ -603,7 +603,6 @@ class Model:
|
|||
list_callback.begin(run_context)
|
||||
|
||||
for inputs in dataset_helper:
|
||||
inputs = self._transfer_tensor_to_tuple(inputs)
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
|
||||
|
@ -642,7 +641,7 @@ class Model:
|
|||
for next_element in dataset_helper:
|
||||
cb_params.cur_step_num += 1
|
||||
list_callback.step_begin(run_context)
|
||||
next_element = self._transfer_tensor_to_tuple(next_element)
|
||||
next_element = _transfer_tensor_to_tuple(next_element)
|
||||
outputs = self._eval_network(*next_element)
|
||||
cb_params.net_outputs = outputs
|
||||
list_callback.step_end(run_context)
|
||||
|
|
Loading…
Reference in New Issue