forked from mindspore-Ecosystem/mindspore
!2272 remove dataset send from data exec
Merge pull request !2272 from wangnan39/remove_dataset_send_from_model_init
This commit is contained in:
commit
b7b4333d13
|
@ -15,6 +15,7 @@
|
|||
"""Dataset help for minddata dataset"""
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
||||
from mindspore.train.dataset_helper import _send_data
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
||||
_to_full_shapes
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
@ -67,7 +68,13 @@ class _DatasetIter:
|
|||
self.loop_size = dataset.get_dataset_size()
|
||||
else:
|
||||
self.loop_size = dataset.__loop_size__
|
||||
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset)
|
||||
else:
|
||||
_send_data(dataset)
|
||||
|
||||
self.ind = 0
|
||||
self.dataset = dataset
|
||||
|
|
|
@ -16,11 +16,10 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.dtype import dtype_to_nptype
|
||||
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.common.dtype import pytype_to_dtype
|
||||
|
||||
|
||||
def _convert_type(types):
|
||||
|
@ -64,8 +63,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
|||
input_indexs,
|
||||
phase=phase)
|
||||
|
||||
# engine dataset to write data to tdt queue
|
||||
exec_dataset.send()
|
||||
return exec_dataset
|
||||
|
||||
|
||||
|
|
|
@ -23,6 +23,14 @@ from ..nn.wrap import GetNextSingleOp
|
|||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
|
||||
|
||||
|
||||
def _send_data(dataset):
|
||||
"""Engine dataset to write data to tdt queue."""
|
||||
if not hasattr(dataset, '__has_sent__'):
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send()
|
||||
dataset.__has_sent__ = True
|
||||
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
Help function to use the Minddata dataset.
|
||||
|
@ -81,7 +89,13 @@ class _DatasetIter:
|
|||
self.loop_size = dataset.get_dataset_size()
|
||||
else:
|
||||
self.loop_size = dataset.__loop_size__
|
||||
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset)
|
||||
else:
|
||||
_send_data(dataset)
|
||||
|
||||
self.ind = 0
|
||||
self.dataset = dataset
|
||||
|
|
|
@ -285,7 +285,7 @@ class Model:
|
|||
|
||||
if self._parameter_broadcast:
|
||||
self._train_network.set_broadcast_flag()
|
||||
|
||||
train_dataset.__no_send__ = True
|
||||
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
|
@ -302,6 +302,7 @@ class Model:
|
|||
|
||||
self._eval_network.set_train(False)
|
||||
self._eval_network.phase = 'eval'
|
||||
valid_dataset.__no_send__ = True
|
||||
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||
is_train=False,
|
||||
phase='eval',
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""Dataset help for minddata dataset"""
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
||||
from mindspore.train.dataset_helper import _send_data
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
||||
_to_full_shapes
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
@ -69,7 +70,13 @@ class _DatasetIter:
|
|||
self.loop_size = dataset.get_dataset_size()
|
||||
else:
|
||||
self.loop_size = dataset.__loop_size__
|
||||
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset)
|
||||
else:
|
||||
_send_data(dataset)
|
||||
|
||||
self.ind = 0
|
||||
self.dataset = dataset
|
||||
|
|
Loading…
Reference in New Issue