do not send the data during the model init

This commit is contained in:
wangnan39@huawei.com 2020-06-18 11:45:35 +08:00
parent fe797aaf10
commit 887838d452
5 changed files with 34 additions and 8 deletions

View File

@ -15,6 +15,7 @@
"""Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train.dataset_helper import _send_data
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes
from mindspore.train.parallel_utils import ParallelMode
@ -67,7 +68,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0
self.dataset = dataset

View File

@ -16,11 +16,10 @@
import os
import numpy as np
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore.common.api import _executor
from mindspore.common.dtype import pytype_to_dtype
def _convert_type(types):
@ -64,8 +63,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
input_indexs,
phase=phase)
# engine dataset to write data to tdt queue
exec_dataset.send()
return exec_dataset

View File

@ -23,6 +23,14 @@ from ..nn.wrap import GetNextSingleOp
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
def _send_data(dataset):
"""Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'):
exec_dataset = dataset.__TRANSFER_DATASET__
exec_dataset.send()
dataset.__has_sent__ = True
class DatasetHelper:
"""
Help function to use the Minddata dataset.
@ -81,7 +89,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0
self.dataset = dataset

View File

@ -285,7 +285,7 @@ class Model:
if self._parameter_broadcast:
self._train_network.set_broadcast_flag()
train_dataset.__no_send__ = True
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
@ -302,6 +302,7 @@ class Model:
self._eval_network.set_train(False)
self._eval_network.phase = 'eval'
valid_dataset.__no_send__ = True
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False,
phase='eval',

View File

@ -15,6 +15,7 @@
"""Dataset help for minddata dataset"""
from mindspore._checkparam import check_bool
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
from mindspore.train.dataset_helper import _send_data
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
_to_full_shapes
from mindspore.train.parallel_utils import ParallelMode
@ -69,7 +70,13 @@ class _DatasetIter:
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
_send_data(dataset)
else:
_send_data(dataset)
self.ind = 0
self.dataset = dataset