forked from mindspore-Ecosystem/mindspore
do not send the data during the model init
This commit is contained in:
parent
fe797aaf10
commit
887838d452
|
@ -15,6 +15,7 @@
|
||||||
"""Dataset help for minddata dataset"""
|
"""Dataset help for minddata dataset"""
|
||||||
from mindspore._checkparam import check_bool
|
from mindspore._checkparam import check_bool
|
||||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
||||||
|
from mindspore.train.dataset_helper import _send_data
|
||||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
||||||
_to_full_shapes
|
_to_full_shapes
|
||||||
from mindspore.train.parallel_utils import ParallelMode
|
from mindspore.train.parallel_utils import ParallelMode
|
||||||
|
@ -67,7 +68,13 @@ class _DatasetIter:
|
||||||
self.loop_size = dataset.get_dataset_size()
|
self.loop_size = dataset.get_dataset_size()
|
||||||
else:
|
else:
|
||||||
self.loop_size = dataset.__loop_size__
|
self.loop_size = dataset.__loop_size__
|
||||||
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||||
|
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||||
|
|
||||||
|
if not hasattr(dataset, '__no_send__'):
|
||||||
|
_send_data(dataset)
|
||||||
|
else:
|
||||||
|
_send_data(dataset)
|
||||||
|
|
||||||
self.ind = 0
|
self.ind = 0
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
|
|
@ -16,11 +16,10 @@
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.common.dtype import dtype_to_nptype
|
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore import log as logger
|
from mindspore import log as logger
|
||||||
from mindspore.common.api import _executor
|
from mindspore.common.api import _executor
|
||||||
from mindspore.common.dtype import pytype_to_dtype
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_type(types):
|
def _convert_type(types):
|
||||||
|
@ -64,8 +63,6 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
||||||
input_indexs,
|
input_indexs,
|
||||||
phase=phase)
|
phase=phase)
|
||||||
|
|
||||||
# engine dataset to write data to tdt queue
|
|
||||||
exec_dataset.send()
|
|
||||||
return exec_dataset
|
return exec_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,14 @@ from ..nn.wrap import GetNextSingleOp
|
||||||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
|
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
|
||||||
|
|
||||||
|
|
||||||
|
def _send_data(dataset):
|
||||||
|
"""Engine dataset to write data to tdt queue."""
|
||||||
|
if not hasattr(dataset, '__has_sent__'):
|
||||||
|
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||||
|
exec_dataset.send()
|
||||||
|
dataset.__has_sent__ = True
|
||||||
|
|
||||||
|
|
||||||
class DatasetHelper:
|
class DatasetHelper:
|
||||||
"""
|
"""
|
||||||
Help function to use the Minddata dataset.
|
Help function to use the Minddata dataset.
|
||||||
|
@ -81,7 +89,13 @@ class _DatasetIter:
|
||||||
self.loop_size = dataset.get_dataset_size()
|
self.loop_size = dataset.get_dataset_size()
|
||||||
else:
|
else:
|
||||||
self.loop_size = dataset.__loop_size__
|
self.loop_size = dataset.__loop_size__
|
||||||
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||||
|
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||||
|
|
||||||
|
if not hasattr(dataset, '__no_send__'):
|
||||||
|
_send_data(dataset)
|
||||||
|
else:
|
||||||
|
_send_data(dataset)
|
||||||
|
|
||||||
self.ind = 0
|
self.ind = 0
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
|
|
@ -285,7 +285,7 @@ class Model:
|
||||||
|
|
||||||
if self._parameter_broadcast:
|
if self._parameter_broadcast:
|
||||||
self._train_network.set_broadcast_flag()
|
self._train_network.set_broadcast_flag()
|
||||||
|
train_dataset.__no_send__ = True
|
||||||
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||||
is_train=True,
|
is_train=True,
|
||||||
phase='train',
|
phase='train',
|
||||||
|
@ -302,6 +302,7 @@ class Model:
|
||||||
|
|
||||||
self._eval_network.set_train(False)
|
self._eval_network.set_train(False)
|
||||||
self._eval_network.phase = 'eval'
|
self._eval_network.phase = 'eval'
|
||||||
|
valid_dataset.__no_send__ = True
|
||||||
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||||
is_train=False,
|
is_train=False,
|
||||||
phase='eval',
|
phase='eval',
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
"""Dataset help for minddata dataset"""
|
"""Dataset help for minddata dataset"""
|
||||||
from mindspore._checkparam import check_bool
|
from mindspore._checkparam import check_bool
|
||||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
||||||
|
from mindspore.train.dataset_helper import _send_data
|
||||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
||||||
_to_full_shapes
|
_to_full_shapes
|
||||||
from mindspore.train.parallel_utils import ParallelMode
|
from mindspore.train.parallel_utils import ParallelMode
|
||||||
|
@ -69,7 +70,13 @@ class _DatasetIter:
|
||||||
self.loop_size = dataset.get_dataset_size()
|
self.loop_size = dataset.get_dataset_size()
|
||||||
else:
|
else:
|
||||||
self.loop_size = dataset.__loop_size__
|
self.loop_size = dataset.__loop_size__
|
||||||
dataset.__ME_INITED__ = _exec_datagraph(dataset, self.loop_size).queue_name
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||||
|
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||||
|
|
||||||
|
if not hasattr(dataset, '__no_send__'):
|
||||||
|
_send_data(dataset)
|
||||||
|
else:
|
||||||
|
_send_data(dataset)
|
||||||
|
|
||||||
self.ind = 0
|
self.ind = 0
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
|
Loading…
Reference in New Issue