diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py
index 8ceb8d6c1ab..bed44552fc4 100644
--- a/mindspore/train/dataset_helper.py
+++ b/mindspore/train/dataset_helper.py
@@ -17,11 +17,12 @@ import math
 import os
 
 from mindspore._checkparam import check_bool, check_int
-from .. import context
+from .. import context, nn
 from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
     _construct_tensor_list
 from ..nn.wrap import GetNextSingleOp
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
+from ..ops import operations as P
 
 
 def _send_data(dataset, epoch_num):
@@ -37,9 +38,70 @@ def _send_data_no_flag(dataset, epoch_num):
     exec_dataset.send(epoch_num)
 
 
+def connect_network_with_dataset(network, dataset_helper):
+    """
+    Connect the `network` with the dataset in `dataset_helper`.
+
+    This function wraps the input network with 'GetNext' so that data is fetched automatically from the
+    data channel corresponding to the 'queue_name' and passed to the input network during forward computation.
+
+    Note:
+        When the network runs on Ascend in graph mode, this function wraps the input network with
+        'GetNext'; in all other cases, the input network is returned unchanged.
+        'GetNext' is only needed to fetch data in sink mode, so this function is not applicable when sink
+        mode is disabled.
+
+    Args:
+        network (Cell): The training network for the dataset.
+        dataset_helper (DatasetHelper): A class to process the MindData dataset; it provides the type, shape and
+            queue name of the dataset needed to wrap 'GetNext'.
+
+    Outputs:
+        Cell, a new network wrapped with 'GetNext' when the task runs on Ascend in graph mode; otherwise,
+        the input network.
+
+    Examples:
+        >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
+        >>> train_dataset = create_dataset()
+        >>> dataset_helper = mindspore.DatasetHelper(train_dataset)
+        >>> net = Net()
+        >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
+    """
+    class _DataWrapper(nn.Cell):
+        """
+        Wraps the input network with a 'GetNext' operator that automatically fetches data from the dataset
+        channel 'queue_name' before performing the forward computation.
+        """
+        def __init__(self, network, dataset_types, dataset_shapes, queue_name):
+            super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
+            # Also copy the flags attached to `network.construct`
+            flags = getattr(network.__class__.construct, "_mindspore_flags", {})
+            self.add_flags(**flags)
+            self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
+            self.network = network
+
+        def construct(self):
+            outputs = self.get_next()
+            return self.network(*outputs)
+
+    dataset_iter = dataset_helper.iter
+    dataset = dataset_iter.dataset
+
+    if isinstance(dataset_iter, _DatasetIterNormal):
+        raise RuntimeError("Dataset should be connected with network only in sink mode.")
+
+    if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" and \
+            context.get_context("mode") == context.GRAPH_MODE and not context.get_context("enable_ge"):
+        dataset.__ME_INITED__ = True
+
+        dataset_types, dataset_shapes = dataset_helper.types_shapes()
+        queue_name = dataset.__TRANSFER_DATASET__.queue_name
+
+        network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
+    return network
+
+
 class DatasetHelper:
     """
-    Help function to use the MindData dataset.
+    DatasetHelper is a class to process the MindData dataset and provide information about the dataset.
 
     According to different contexts, change the iterations of dataset and use the same iteration for loop in different
     contexts.
@@ -114,7 +176,6 @@ class _DatasetIter:
         if hasattr(dataset, '__loop_size__'):
             self.sink_size = dataset.__loop_size__
 
         dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
-        dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
 
         if not hasattr(dataset, '__no_send__'):
             _send_data(dataset, epoch_num)
@@ -207,7 +268,7 @@ class _DatasetIterMS(_DatasetIter):
         else:
             self.sink_count = dataset.get_dataset_size()
 
-        queue_name = dataset.__ME_INITED__
+        queue_name = dataset.__TRANSFER_DATASET__.queue_name
         self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index b3a9b5fd32d..1b18a424278 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -35,7 +35,7 @@ from ..context import ParallelMode
 from ..parallel._utils import _need_to_full, _to_full_tensor
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from ..common import dtype as mstype
-from .dataset_helper import DatasetHelper
+from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
 
 
@@ -249,23 +249,15 @@ class Model:
     def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
         """Initializes dataset."""
-        need_wrap = False
-        if dataset_sink_mode:
-            # remove later to deal with loop sink
-            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                    and not context.get_context("enable_ge"):
-                need_wrap = True
-
-            if not is_train:
-                dataset.__loop_size__ = 1
-
+        if dataset_sink_mode and not is_train:
+            dataset.__loop_size__ = 1
         dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
 
-        # remove later to deal with loop sink
-        if need_wrap:
-            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
-            network.set_train(is_train)
-            network.phase = phase
+        if dataset_sink_mode:
+            network = connect_network_with_dataset(network, dataset_helper)
+
+        network.set_train(is_train)
+        network.phase = phase
 
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
@@ -306,11 +298,9 @@ class Model:
 
         if train_dataset:
             _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
-            self._train_network.set_train()
-            self._train_network.phase = 'train'
-
             if self._parameter_broadcast:
                 self._train_network.set_broadcast_flag()
+            train_dataset.__no_send__ = True
 
             train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                                         is_train=True,
@@ -326,8 +316,6 @@ class Model:
             if not self._metric_fns:
                 raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
 
-            self._eval_network.set_train(False)
-            self._eval_network.phase = 'eval'
             valid_dataset.__no_send__ = True
             valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                                        is_train=False,
@@ -358,8 +346,6 @@ class Model:
             sink_size (int): Control the amount of data in each sink. Default: -1.
""" epoch = check_int_positive(epoch) - self._train_network.set_train() - if self._parameter_broadcast: self._train_network.set_broadcast_flag() @@ -701,9 +687,6 @@ class Model: cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.network = self._network - self._eval_network.set_train(mode=False) - self._eval_network.phase = 'eval' - self._clear_metrics() if context.get_context("device_target") == "CPU": diff --git a/model_zoo/official/cv/resnet_thor/src/dataset_helper.py b/model_zoo/official/cv/resnet_thor/src/dataset_helper.py index d5982aa7ec1..bb4109c8e09 100644 --- a/model_zoo/official/cv/resnet_thor/src/dataset_helper.py +++ b/model_zoo/official/cv/resnet_thor/src/dataset_helper.py @@ -104,7 +104,6 @@ class _DatasetIter: if hasattr(dataset, '__loop_size__'): self.sink_size = dataset.__loop_size__ dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) - dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name if not hasattr(dataset, '__no_send__'): _send_data(dataset, epoch_num) @@ -188,5 +187,5 @@ class _DatasetIterMS(_DatasetIter): else: self.sink_count = dataset.get_dataset_size() - queue_name = dataset.__ME_INITED__ + queue_name = dataset.__TRANSFER_DATASET__.queue_name self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) diff --git a/model_zoo/official/cv/resnet_thor/src/model_thor.py b/model_zoo/official/cv/resnet_thor/src/model_thor.py index 1b86acf51f7..7c806b81ff6 100644 --- a/model_zoo/official/cv/resnet_thor/src/model_thor.py +++ b/model_zoo/official/cv/resnet_thor/src/model_thor.py @@ -17,9 +17,9 @@ import math from mindspore.train.callback import RunContext from mindspore import context -from mindspore import nn from mindspore.context import ParallelMode from mindspore.train.model import Model +from mindspore.train.dataset_helper import connect_network_with_dataset from mindspore.parallel._utils import _need_to_full, _to_full_tensor from mindspore.common.dtype import pytype_to_dtype from mindspore._c_expression import init_exec_dataset @@ -57,7 +57,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'): # transform data format dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) - init_exec_dataset(exec_dataset.__ME_INITED__, + init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name, dataset_size, batch_size, dataset_types, @@ -114,23 +114,14 @@ class Model_Thor(Model): def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, iter_first_order=1): """Initializes dataset.""" - need_wrap = False - if dataset_sink_mode: - # remove later to deal with loop sink - if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \ - and not context.get_context("enable_ge"): - need_wrap = True - - if not is_train: - dataset.__loop_size__ = 1 - + if dataset_sink_mode and not is_train: + dataset.__loop_size__ = 1 dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order) - # remove later to deal with loop sink - if need_wrap: - network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__) - network.set_train(is_train) - network.phase = phase + if dataset_sink_mode: + network = connect_network_with_dataset(network, dataset_helper) + network.set_train(is_train) + network.phase = phase if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): network.set_auto_parallel() diff --git 
diff --git a/model_zoo/official/nlp/bert_thor/src/dataset_helper.py b/model_zoo/official/nlp/bert_thor/src/dataset_helper.py
index 3a72cbcda29..9a011450c9e 100644
--- a/model_zoo/official/nlp/bert_thor/src/dataset_helper.py
+++ b/model_zoo/official/nlp/bert_thor/src/dataset_helper.py
@@ -111,7 +111,6 @@ class _DatasetIter:
         if hasattr(dataset, '__loop_size__'):
             self.sink_size = dataset.__loop_size__
 
         dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
-        dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
 
         if not hasattr(dataset, '__no_send__'):
             _send_data(dataset, epoch_num)
diff --git a/model_zoo/official/nlp/bert_thor/src/model_thor.py b/model_zoo/official/nlp/bert_thor/src/model_thor.py
index 710bd7a9f0b..c6073a07f17 100644
--- a/model_zoo/official/nlp/bert_thor/src/model_thor.py
+++ b/model_zoo/official/nlp/bert_thor/src/model_thor.py
@@ -28,6 +28,7 @@ from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.tensor import Tensor
 from mindspore.nn.metrics import Loss
 from mindspore.nn.metrics import get_metrics
+from mindspore.train.dataset_helper import connect_network_with_dataset
 from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@@ -70,7 +71,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
 
-    init_exec_dataset(exec_dataset.__ME_INITED__,
+    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
@@ -275,23 +276,14 @@ class Model:
     def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
                          epoch_num=1, iter_first_order=9):
         """Initializes dataset."""
-        need_wrap = False
-        if dataset_sink_mode:
-            # remove later to deal with loop sink
-            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                    and not context.get_context("enable_ge"):
-                need_wrap = True
-
-            if not is_train:
-                dataset.__loop_size__ = 1
-
+        if dataset_sink_mode and not is_train:
+            dataset.__loop_size__ = 1
         dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
 
-        # remove later to deal with loop sink
-        if need_wrap:
-            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
-            network.set_train(is_train)
-            network.phase = phase
+        if dataset_sink_mode:
+            network = connect_network_with_dataset(network, dataset_helper)
+        network.set_train(is_train)
+        network.phase = phase
 
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
diff --git a/tests/st/networks/models/resnet50/src_thor/dataset_helper.py b/tests/st/networks/models/resnet50/src_thor/dataset_helper.py
index 448589b6f7e..e4cfc87c8a3 100644
--- a/tests/st/networks/models/resnet50/src_thor/dataset_helper.py
+++ b/tests/st/networks/models/resnet50/src_thor/dataset_helper.py
@@ -18,6 +18,7 @@ from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_f
 from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
 from mindspore.context import ParallelMode
 
+
 def _send_data(dataset):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
@@ -25,6 +26,7 @@ def _send_data(dataset):
         exec_dataset.send()
         dataset.__has_sent__ = True
 
+
 class DatasetHelper:
     """
     Help function to use the Minddata dataset.
@@ -69,13 +71,12 @@ class _DatasetIter:
     def __init__(self, dataset):
         self.loop_size = 1
-        if not hasattr(dataset, '__ME_INITED__'):
+        if not hasattr(dataset, '__TRANSFER_DATASET__'):
             if not hasattr(dataset, '__loop_size__'):
                 self.loop_size = dataset.get_dataset_size()
             else:
                 self.loop_size = dataset.__loop_size__
             dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
-            dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
 
             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset)
diff --git a/tests/st/networks/models/resnet50/src_thor/model_thor.py b/tests/st/networks/models/resnet50/src_thor/model_thor.py
index c99a41cb6ca..2bcbdb50bbe 100644
--- a/tests/st/networks/models/resnet50/src_thor/model_thor.py
+++ b/tests/st/networks/models/resnet50/src_thor/model_thor.py
@@ -23,6 +23,7 @@ from mindspore._checkparam import check_input_data, check_output_data, check_int
 from mindspore.common import dtype as mstype
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.tensor import Tensor
+from mindspore.train.dataset_helper import connect_network_with_dataset
 from mindspore.nn.metrics import Loss
 from mindspore.nn.metrics import get_metrics
 from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -66,7 +67,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
 
-    init_exec_dataset(exec_dataset.__ME_INITED__,
+    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
@@ -266,23 +267,14 @@ class Model:
     def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1):
         """Initializes dataset."""
-        need_wrap = False
-        if dataset_sink_mode:
-            # remove later to deal with loop sink
-            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                    and not context.get_context("enable_ge"):
-                need_wrap = True
-
-            if not is_train:
-                dataset.__loop_size__ = 1
-
+        if dataset_sink_mode and not is_train:
+            dataset.__loop_size__ = 1
         dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)
 
-        # remove later to deal with loop sink
-        if need_wrap:
-            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
-            network.set_train(is_train)
-            network.phase = phase
+        if dataset_sink_mode:
+            network = connect_network_with_dataset(network, dataset_helper)
+        network.set_train(is_train)
+        network.phase = phase
 
         return dataset_helper, network
@@ -605,7 +597,6 @@ class Model:
             Dict, returns the loss value & metrics values for the model in test mode.
         """
         run_context = RunContext(cb_params)
-
         dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                              is_train=False,
                                                              phase='eval',
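
Taken together, these changes replace the per-Model `nn.DataWrapper` handshake over `dataset.__ME_INITED__` with one public entry point; `__ME_INITED__` shrinks to an already-wrapped marker, and the queue name is always read from `dataset.__TRANSFER_DATASET__`. A minimal usage sketch of the new API, assuming a user-defined `create_dataset()` pipeline and `Net()` cell (both hypothetical placeholders, following the docstring example above):

    import mindspore.context as context
    from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    train_dataset = create_dataset()   # hypothetical user-defined dataset pipeline
    net = Net()                        # hypothetical user-defined nn.Cell

    # Sink mode is required: with a non-sink DatasetHelper the function raises
    # "Dataset should be connected with network only in sink mode."
    dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=True)

    # On Ascend in graph mode (and with GE disabled), `net` is wrapped with a
    # GetNext operator bound to dataset.__TRANSFER_DATASET__.queue_name; on any
    # other target the original network is returned unchanged.
    net_with_get_next = connect_network_with_dataset(net, dataset_helper)
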