!5908 delete DataWrapper in Model to unify the expression of dataset sinking

Merge pull request !5908 from wangnan39/optim_datawrapper_in_model
This commit is contained in:
mindspore-ci-bot 2020-09-10 19:33:26 +08:00 committed by Gitee
commit 55372b41fd
8 changed files with 102 additions and 85 deletions

View File

@ -17,11 +17,12 @@ import math
import os import os
from mindspore._checkparam import check_bool, check_int from mindspore._checkparam import check_bool, check_int
from .. import context from .. import context, nn
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \ from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
_construct_tensor_list _construct_tensor_list
from ..nn.wrap import GetNextSingleOp from ..nn.wrap import GetNextSingleOp
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
from ..ops import operations as P
def _send_data(dataset, epoch_num): def _send_data(dataset, epoch_num):
@ -37,9 +38,70 @@ def _send_data_no_flag(dataset, epoch_num):
exec_dataset.send(epoch_num) exec_dataset.send(epoch_num)
def connect_network_with_dataset(network, dataset_helper):
"""
Connect the `network` with dataset in `dataset_helper`.
This function wraps the input network with 'GetNext' so that the data can be fetched automatically from the
data channel corresponding to the 'queue_name' and passed to the input network during forward computation.
Note:
In the case of running the network on Ascend in graph mode, this function will wrap the input network with
'GetNext', in other cases, the input network will be returned with no change.
The 'GetNext' is required to get data only in sink mode, so this function is not applicable to no-sink mode.
Args:
network (Cell): The training network for dataset.
dataset_helper(DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue
name of the dataset to wrap the `GetNext`.
Outputs:
Cell, a new network wrapped with 'GetNext' in the case of running the task on Ascend in graph mode, otherwise
it is the input network.
Examples:
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
>>> train_dataset = create_dataset()
>>> dataset_helper = mindspore.DatasetHelper(train_dataset)
>>> net = Net()
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
"""
class _DataWrapper(nn.Cell):
"""
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
dataset channel 'queue_name' and performs the forward computation.
"""
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network
def construct(self):
outputs = self.get_next()
return self.network(*outputs)
dataset_iter = dataset_helper.iter
dataset = dataset_iter.dataset
if isinstance(dataset_iter, _DatasetIterNormal):
raise RuntimeError("Dataset should be connected with network only in sink mode.")
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.GRAPH_MODE and not context.get_context("enable_ge"):
dataset.__ME_INITED__ = True
dataset_types, dataset_shapes = dataset_helper.types_shapes()
queue_name = dataset.__TRANSFER_DATASET__.queue_name
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
return network
class DatasetHelper: class DatasetHelper:
""" """
Help function to use the MindData dataset. DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.
According to different contexts, change the iterations of dataset and use the same iteration for loop in different According to different contexts, change the iterations of dataset and use the same iteration for loop in different
contexts. contexts.
@ -114,7 +176,6 @@ class _DatasetIter:
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__ self.sink_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num) _send_data(dataset, epoch_num)
@ -207,7 +268,7 @@ class _DatasetIterMS(_DatasetIter):
else: else:
self.sink_count = dataset.get_dataset_size() self.sink_count = dataset.get_dataset_size()
queue_name = dataset.__ME_INITED__ queue_name = dataset.__TRANSFER_DATASET__.queue_name
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)

View File

@ -35,7 +35,7 @@ from ..context import ParallelMode
from ..parallel._utils import _need_to_full, _to_full_tensor from ..parallel._utils import _need_to_full, _to_full_tensor
from ..parallel._cost_model_context import _set_multi_subgraphs from ..parallel._cost_model_context import _set_multi_subgraphs
from ..common import dtype as mstype from ..common import dtype as mstype
from .dataset_helper import DatasetHelper from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp from . import amp
@ -249,21 +249,13 @@ class Model:
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1): def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
"""Initializes dataset.""" """Initializes dataset."""
need_wrap = False if dataset_sink_mode and not is_train:
if dataset_sink_mode:
# remove later to deal with loop sink
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"):
need_wrap = True
if not is_train:
dataset.__loop_size__ = 1 dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num) dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
# remove later to deal with loop sink if dataset_sink_mode:
if need_wrap: network = connect_network_with_dataset(network, dataset_helper)
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
network.set_train(is_train) network.set_train(is_train)
network.phase = phase network.phase = phase
@ -306,11 +298,9 @@ class Model:
if train_dataset: if train_dataset:
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
self._train_network.set_train()
self._train_network.phase = 'train'
if self._parameter_broadcast: if self._parameter_broadcast:
self._train_network.set_broadcast_flag() self._train_network.set_broadcast_flag()
train_dataset.__no_send__ = True train_dataset.__no_send__ = True
train_dataset_helper, train_network = self._exec_preprocess(self._train_network, train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True, is_train=True,
@ -326,8 +316,6 @@ class Model:
if not self._metric_fns: if not self._metric_fns:
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.') raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
self._eval_network.set_train(False)
self._eval_network.phase = 'eval'
valid_dataset.__no_send__ = True valid_dataset.__no_send__ = True
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network, valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False, is_train=False,
@ -358,8 +346,6 @@ class Model:
sink_size (int): Control the amount of data in each sink. Default: -1. sink_size (int): Control the amount of data in each sink. Default: -1.
""" """
epoch = check_int_positive(epoch) epoch = check_int_positive(epoch)
self._train_network.set_train()
if self._parameter_broadcast: if self._parameter_broadcast:
self._train_network.set_broadcast_flag() self._train_network.set_broadcast_flag()
@ -701,9 +687,6 @@ class Model:
cb_params.list_callback = self._transform_callbacks(callbacks) cb_params.list_callback = self._transform_callbacks(callbacks)
cb_params.network = self._network cb_params.network = self._network
self._eval_network.set_train(mode=False)
self._eval_network.phase = 'eval'
self._clear_metrics() self._clear_metrics()
if context.get_context("device_target") == "CPU": if context.get_context("device_target") == "CPU":

View File

@ -104,7 +104,6 @@ class _DatasetIter:
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__ self.sink_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num) _send_data(dataset, epoch_num)
@ -188,5 +187,5 @@ class _DatasetIterMS(_DatasetIter):
else: else:
self.sink_count = dataset.get_dataset_size() self.sink_count = dataset.get_dataset_size()
queue_name = dataset.__ME_INITED__ queue_name = dataset.__TRANSFER_DATASET__.queue_name
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)

View File

@ -17,9 +17,9 @@
import math import math
from mindspore.train.callback import RunContext from mindspore.train.callback import RunContext
from mindspore import context from mindspore import context
from mindspore import nn
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.parallel._utils import _need_to_full, _to_full_tensor from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype from mindspore.common.dtype import pytype_to_dtype
from mindspore._c_expression import init_exec_dataset from mindspore._c_expression import init_exec_dataset
@ -57,7 +57,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
# transform data format # transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
init_exec_dataset(exec_dataset.__ME_INITED__, init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
dataset_size, dataset_size,
batch_size, batch_size,
dataset_types, dataset_types,
@ -114,21 +114,12 @@ class Model_Thor(Model):
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
epoch_num=1, iter_first_order=1): epoch_num=1, iter_first_order=1):
"""Initializes dataset.""" """Initializes dataset."""
need_wrap = False if dataset_sink_mode and not is_train:
if dataset_sink_mode:
# remove later to deal with loop sink
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"):
need_wrap = True
if not is_train:
dataset.__loop_size__ = 1 dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order) dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
# remove later to deal with loop sink if dataset_sink_mode:
if need_wrap: network = connect_network_with_dataset(network, dataset_helper)
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
network.set_train(is_train) network.set_train(is_train)
network.phase = phase network.phase = phase

View File

@ -111,7 +111,6 @@ class _DatasetIter:
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__ self.sink_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size) dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num) _send_data(dataset, epoch_num)

View File

@ -28,6 +28,7 @@ from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.nn.metrics import Loss from mindspore.nn.metrics import Loss
from mindspore.nn.metrics import get_metrics from mindspore.nn.metrics import get_metrics
from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@ -70,7 +71,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
# transform data format # transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
init_exec_dataset(exec_dataset.__ME_INITED__, init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
dataset_size, dataset_size,
batch_size, batch_size,
dataset_types, dataset_types,
@ -275,21 +276,12 @@ class Model:
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1,
iter_first_order=9): iter_first_order=9):
"""Initializes dataset.""" """Initializes dataset."""
need_wrap = False if dataset_sink_mode and not is_train:
if dataset_sink_mode:
# remove later to deal with loop sink
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"):
need_wrap = True
if not is_train:
dataset.__loop_size__ = 1 dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order) dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
# remove later to deal with loop sink if dataset_sink_mode:
if need_wrap: network = connect_network_with_dataset(network, dataset_helper)
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
network.set_train(is_train) network.set_train(is_train)
network.phase = phase network.phase = phase

View File

@ -18,6 +18,7 @@ from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_f
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
def _send_data(dataset): def _send_data(dataset):
"""Engine dataset to write data to tdt queue.""" """Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'): if not hasattr(dataset, '__has_sent__'):
@ -25,6 +26,7 @@ def _send_data(dataset):
exec_dataset.send() exec_dataset.send()
dataset.__has_sent__ = True dataset.__has_sent__ = True
class DatasetHelper: class DatasetHelper:
""" """
Help function to use the Minddata dataset. Help function to use the Minddata dataset.
@ -69,13 +71,12 @@ class _DatasetIter:
def __init__(self, dataset): def __init__(self, dataset):
self.loop_size = 1 self.loop_size = 1
if not hasattr(dataset, '__ME_INITED__'): if not hasattr(dataset, '__TRANSFER_DATASET__'):
if not hasattr(dataset, '__loop_size__'): if not hasattr(dataset, '__loop_size__'):
self.loop_size = dataset.get_dataset_size() self.loop_size = dataset.get_dataset_size()
else: else:
self.loop_size = dataset.__loop_size__ self.loop_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
_send_data(dataset) _send_data(dataset)

View File

@ -23,6 +23,7 @@ from mindspore._checkparam import check_input_data, check_output_data, check_int
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.common.dtype import pytype_to_dtype from mindspore.common.dtype import pytype_to_dtype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.nn.metrics import Loss from mindspore.nn.metrics import Loss
from mindspore.nn.metrics import get_metrics from mindspore.nn.metrics import get_metrics
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
@ -66,7 +67,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
# transform data format # transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset) dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
init_exec_dataset(exec_dataset.__ME_INITED__, init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
dataset_size, dataset_size,
batch_size, batch_size,
dataset_types, dataset_types,
@ -266,21 +267,12 @@ class Model:
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1): def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1):
"""Initializes dataset.""" """Initializes dataset."""
need_wrap = False if dataset_sink_mode and not is_train:
if dataset_sink_mode:
# remove later to deal with loop sink
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"):
need_wrap = True
if not is_train:
dataset.__loop_size__ = 1 dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order) dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)
# remove later to deal with loop sink if dataset_sink_mode:
if need_wrap: network = connect_network_with_dataset(network, dataset_helper)
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
network.set_train(is_train) network.set_train(is_train)
network.phase = phase network.phase = phase
@ -605,7 +597,6 @@ class Model:
Dict, returns the loss value & metrics values for the model in test mode. Dict, returns the loss value & metrics values for the model in test mode.
""" """
run_context = RunContext(cb_params) run_context = RunContext(cb_params)
dataset_helper, eval_network = self._exec_preprocess(self._eval_network, dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
is_train=False, is_train=False,
phase='eval', phase='eval',