diff --git a/example/resnet50_imagenet2012_THOR/model/dataset_helper.py b/example/resnet50_imagenet2012_THOR/model/dataset_helper.py index 474bccf42f4..77f67344c2b 100644 --- a/example/resnet50_imagenet2012_THOR/model/dataset_helper.py +++ b/example/resnet50_imagenet2012_THOR/model/dataset_helper.py @@ -15,6 +15,7 @@ """Dataset help for minddata dataset""" from mindspore._checkparam import check_bool from mindspore.parallel._utils import _get_device_num, _get_parallel_mode +from mindspore.train.dataset_helper import _send_data from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \ _to_full_shapes from mindspore.train.parallel_utils import ParallelMode @@ -67,7 +68,13 @@ class _DatasetIter: self.loop_size = dataset.get_dataset_size() else: self.loop_size = dataset.__loop_size__ - dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) + dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name + + if not hasattr(dataset, '__no_send__'): + _send_data(dataset) + else: + _send_data(dataset) self.ind = 0 self.dataset = dataset diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 7bc07b126ef..3072771b294 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -16,11 +16,10 @@ import os import numpy as np from mindspore.common.tensor import Tensor -from mindspore.common.dtype import dtype_to_nptype +from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype from mindspore.common import dtype as mstype from mindspore import log as logger from mindspore.common.api import _executor -from mindspore.common.dtype import pytype_to_dtype def _convert_type(types): @@ -63,9 +62,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): dataset_shapes, input_indexs, phase=phase) - - # engine dataset to write data to tdt queue - exec_dataset.send() return exec_dataset diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 083349e5a1c..28d65349ee7 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -24,6 +24,14 @@ from ..nn.wrap import GetNextSingleOp from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode +def _send_data(dataset): + """Engine dataset to write data to tdt queue.""" + if not hasattr(dataset, '__has_sent__'): + exec_dataset = dataset.__TRANSFER_DATASET__ + exec_dataset.send() + dataset.__has_sent__ = True + + class DatasetHelper: """ Help function to use the Minddata dataset. @@ -82,7 +90,13 @@ class _DatasetIter: self.loop_size = dataset.get_dataset_size() else: self.loop_size = dataset.__loop_size__ - dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name + dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) + dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name + + if not hasattr(dataset, '__no_send__'): + _send_data(dataset) + else: + _send_data(dataset) self.ind = 0 self.dataset = dataset diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 427e5a29ce7..27d448b8e79 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -278,7 +278,7 @@ class Model: if self._parameter_broadcast: self._train_network.set_broadcast_flag() - + train_dataset.__no_send__ = True train_dataset_helper, train_network = self._exec_preprocess(self._train_network, is_train=True, phase='train', @@ -295,6 +295,7 @@ class Model: self._eval_network.set_train(False) self._eval_network.phase = 'eval' + valid_dataset.__no_send__ = True valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, is_train=False, phase='eval',