forked from mindspore-Ecosystem/mindspore
!2334 remove dataset send from data exec for r0.3
Merge pull request !2334 from wangnan39/do_not_send_data_duriing_model_init
Commit 91c856e5ee
@@ -15,6 +15,7 @@
 """Dataset help for minddata dataset"""
 from mindspore._checkparam import check_bool
 from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
+from mindspore.train.dataset_helper import _send_data
 from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
     _to_full_shapes
 from mindspore.train.parallel_utils import ParallelMode
@@ -67,7 +68,13 @@ class _DatasetIter:
                 self.loop_size = dataset.get_dataset_size()
             else:
                 self.loop_size = dataset.__loop_size__
-            dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
+            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
+            dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
+
+            if not hasattr(dataset, '__no_send__'):
+                _send_data(dataset)
+        else:
+            _send_data(dataset)
 
         self.ind = 0
         self.dataset = dataset
@@ -16,11 +16,10 @@
 import os
 import numpy as np
 from mindspore.common.tensor import Tensor
-from mindspore.common.dtype import dtype_to_nptype
+from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
 from mindspore.common import dtype as mstype
 from mindspore import log as logger
 from mindspore.common.api import _executor
-from mindspore.common.dtype import pytype_to_dtype
 
 
 def _convert_type(types):
@@ -63,9 +62,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
                            dataset_shapes,
                            input_indexs,
                            phase=phase)
-
-    # engine dataset to write data to tdt queue
-    exec_dataset.send()
     return exec_dataset
 
 
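With this hunk, _exec_datagraph in mindspore/train/_utils.py only initializes the dataset graph and returns the transfer dataset; the eager exec_dataset.send() call is removed here and reappears as the _send_data helper in mindspore/train/dataset_helper.py (next hunk). A minimal, self-contained sketch of that split, with toy stand-in names (TransferHandle, init_datagraph) rather than the real MindSpore internals:

# Illustrative toy only: initialization returns a handle instead of pushing
# data as a side effect, so the caller decides when (or whether) to send.

class TransferHandle:
    """Stand-in for the transfer dataset that _exec_datagraph now returns."""

    def __init__(self, queue_name):
        self.queue_name = queue_name

    def send(self):
        print("writing batches to queue", self.queue_name)


def init_datagraph(queue_name):
    # Init only: build/register the data graph and hand back the queue handle.
    return TransferHandle(queue_name)


handle = init_datagraph("dataset_q0")   # nothing is sent here any more
handle.send()                           # the transfer is now an explicit step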
@@ -24,6 +24,14 @@ from ..nn.wrap import GetNextSingleOp
 from ..parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
 
 
+def _send_data(dataset):
+    """Engine dataset to write data to tdt queue."""
+    if not hasattr(dataset, '__has_sent__'):
+        exec_dataset = dataset.__TRANSFER_DATASET__
+        exec_dataset.send()
+        dataset.__has_sent__ = True
+
+
 class DatasetHelper:
     """
     Help function to use the Minddata dataset.
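The new _send_data helper is idempotent: a dataset that goes through several iterator constructions in one process writes to the TDT queue only once. A self-contained sketch of the same guard pattern, with plain objects standing in for a MindData dataset (FakeDataset and FakeTransferDataset are made-up names):

# Toy illustration of the __has_sent__ guard; these classes are stand-ins,
# not MindSpore API. Only the dunder attribute names come from the diff.

class FakeTransferDataset:
    def __init__(self):
        self.send_calls = 0

    def send(self):
        self.send_calls += 1


class FakeDataset:
    pass


def send_data(dataset):
    """Push data once; later calls are no-ops because of __has_sent__."""
    if not hasattr(dataset, '__has_sent__'):
        dataset.__TRANSFER_DATASET__.send()
        dataset.__has_sent__ = True


ds = FakeDataset()
ds.__TRANSFER_DATASET__ = FakeTransferDataset()
send_data(ds)
send_data(ds)                                   # second call is ignored
assert ds.__TRANSFER_DATASET__.send_calls == 1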
@@ -82,7 +90,13 @@ class _DatasetIter:
                 self.loop_size = dataset.get_dataset_size()
             else:
                 self.loop_size = dataset.__loop_size__
-            dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
+            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
+            dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
+
+            if not hasattr(dataset, '__no_send__'):
+                _send_data(dataset)
+        else:
+            _send_data(dataset)
 
         self.ind = 0
         self.dataset = dataset
@@ -278,7 +278,7 @@ class Model:
 
             if self._parameter_broadcast:
                 self._train_network.set_broadcast_flag()
-
+            train_dataset.__no_send__ = True
             train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                                         is_train=True,
                                                                         phase='train',
@@ -295,6 +295,7 @@ class Model:
 
             self._eval_network.set_train(False)
             self._eval_network.phase = 'eval'
+            valid_dataset.__no_send__ = True
            valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                                        is_train=False,
                                                                        phase='eval',
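Judging from these hunks, Model.init pre-compiles the train and eval graphs, and before this change building the dataset iterator during that pre-compilation already pushed batches to the device queue. Marking the dataset with __no_send__ keeps the pre-init pass side-effect free; the one-time send then happens when the dataset is iterated again during the real train or eval run, where the __has_sent__ guard prevents duplicates. The toy below mimics that flow end to end; FakeModel, FakeDataset, FakeTransfer and make_iter are invented stand-ins, and the outer hasattr(dataset, '__ME_INITED__') check is inferred from the surrounding code rather than shown in the hunks:

# End-to-end toy of the new control flow. Only the __no_send__, __has_sent__,
# __ME_INITED__ and __TRANSFER_DATASET__ attribute names come from the diff;
# everything else is an illustrative stand-in.

class FakeTransfer:
    def __init__(self):
        self.queue_name = "q0"
        self.send_calls = 0

    def send(self):
        self.send_calls += 1


def send_data(dataset):
    # Counterpart of _send_data: send once, then do nothing.
    if not hasattr(dataset, '__has_sent__'):
        dataset.__TRANSFER_DATASET__.send()
        dataset.__has_sent__ = True


def make_iter(dataset):
    # Simplified counterpart of _DatasetIter.__init__ after this commit.
    if not hasattr(dataset, '__ME_INITED__'):
        dataset.__TRANSFER_DATASET__ = FakeTransfer()    # "_exec_datagraph"
        dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
        if not hasattr(dataset, '__no_send__'):
            send_data(dataset)        # normal path: init and send together
    else:
        send_data(dataset)            # dataset seen before: send if needed


class FakeModel:
    def init(self, dataset):
        dataset.__no_send__ = True    # what Model.init now does per the diff
        make_iter(dataset)            # pre-compile pass: queue inited, no send

    def train(self, dataset):
        make_iter(dataset)            # training pass: data goes out exactly once


class FakeDataset:
    pass


ds = FakeDataset()
model = FakeModel()
model.init(ds)                        # no data pushed during init
model.train(ds)
model.train(ds)
assert ds.__TRANSFER_DATASET__.send_calls == 1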