!5908 delete DataWrapper in Model to unify the expression of dataset sinking
Merge pull request !5908 from wangnan39/optim_datawrapper_in_model
Commit: 55372b41fd
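In short: this patch deletes the internal `nn.DataWrapper` step in `Model._exec_preprocess` and introduces a single public helper, `connect_network_with_dataset`, so every sink path reads the transfer queue name from `dataset.__TRANSFER_DATASET__.queue_name`. A minimal before/after sketch of the wrapping step (`create_dataset` and `Net` are illustrative placeholders, as in the docstring example below):

    # Before: Model wrapped the network internally with nn.DataWrapper
    # network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)

    # After: one public entry point performs the sink-mode wrapping
    train_dataset = create_dataset()
    dataset_helper = DatasetHelper(train_dataset)
    net_with_get_next = connect_network_with_dataset(Net(), dataset_helper)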
@@ -17,11 +17,12 @@ import math
 import os

 from mindspore._checkparam import check_bool, check_int
-from .. import context
+from .. import context, nn
 from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
     _construct_tensor_list
 from ..nn.wrap import GetNextSingleOp
 from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
+from ..ops import operations as P


 def _send_data(dataset, epoch_num):
@@ -37,9 +38,70 @@ def _send_data_no_flag(dataset, epoch_num):
     exec_dataset.send(epoch_num)


+def connect_network_with_dataset(network, dataset_helper):
+    """
+    Connect the `network` with the dataset in `dataset_helper`.
+
+    This function wraps the input network with 'GetNext' so that the data can be fetched automatically from the
+    data channel corresponding to the 'queue_name' and passed to the input network during forward computation.
+
+    Note:
+        When the network runs on Ascend in graph mode, this function wraps the input network with
+        'GetNext'; in other cases, the input network is returned unchanged.
+        'GetNext' is required to get data only in sink mode, so this function is not applicable to no-sink mode.
+
+    Args:
+        network (Cell): The training network for the dataset.
+        dataset_helper (DatasetHelper): A class that processes the MindData dataset and provides the type, shape
+            and queue name of the dataset, used to wrap `GetNext`.
+
+    Outputs:
+        Cell, a new network wrapped with 'GetNext' when the task runs on Ascend in graph mode, otherwise
+        the input network.
+
+    Examples:
+        >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
+        >>> train_dataset = create_dataset()
+        >>> dataset_helper = mindspore.DatasetHelper(train_dataset)
+        >>> net = Net()
+        >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
+    """
+    class _DataWrapper(nn.Cell):
+        """
+        Wraps the input network with a dataset which automatically fetches data with the 'GetNext' function from
+        the dataset channel 'queue_name' and performs the forward computation.
+        """
+        def __init__(self, network, dataset_types, dataset_shapes, queue_name):
+            super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
+            # Also copy the flags set on `network.construct`
+            flags = getattr(network.__class__.construct, "_mindspore_flags", {})
+            self.add_flags(**flags)
+            self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
+            self.network = network

+        def construct(self):
+            outputs = self.get_next()
+            return self.network(*outputs)
+
+    dataset_iter = dataset_helper.iter
+    dataset = dataset_iter.dataset
+
+    if isinstance(dataset_iter, _DatasetIterNormal):
+        raise RuntimeError("Dataset should be connected with network only in sink mode.")
+
+    if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" and \
+            context.get_context("mode") == context.GRAPH_MODE and not context.get_context("enable_ge"):
+        dataset.__ME_INITED__ = True
+        dataset_types, dataset_shapes = dataset_helper.types_shapes()
+        queue_name = dataset.__TRANSFER_DATASET__.queue_name
+
+        network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
+    return network
+
+
 class DatasetHelper:
     """
     Help function to use the MindData dataset.
     DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.

     According to different contexts, change the iterations of dataset and use the same iteration for loop in different
     contexts.
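Reviewer note: the wrapped Cell's `construct` takes no arguments, because each step pops a batch from the device queue through `GetNext`. A hedged sketch of one sink-mode step, reusing the names from the docstring example above:

    net_with_get_next = connect_network_with_dataset(net, dataset_helper)
    output = net_with_get_next()   # equivalent to net(*get_next()); inputs come from the device queue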
@@ -114,7 +176,6 @@ class _DatasetIter:
         if hasattr(dataset, '__loop_size__'):
             self.sink_size = dataset.__loop_size__
         dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
-        dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name

         if not hasattr(dataset, '__no_send__'):
             _send_data(dataset, epoch_num)
@@ -207,7 +268,7 @@ class _DatasetIterMS(_DatasetIter):
         else:
             self.sink_count = dataset.get_dataset_size()

-        queue_name = dataset.__ME_INITED__
+        queue_name = dataset.__TRANSFER_DATASET__.queue_name
         self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
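This hunk makes the MindData sink path derive the queue name from the transfer dataset handle rather than from `__ME_INITED__`, which `connect_network_with_dataset` now repurposes as a plain already-wrapped marker. A hedged sketch of the single-op fetch path, assuming `dataset_types` and `dataset_shapes` come from `dataset_helper.types_shapes()`:

    queue_name = dataset.__TRANSFER_DATASET__.queue_name
    get_next = GetNextSingleOp(dataset_types, dataset_shapes, queue_name)
    batch = get_next()   # pops one batch of tensors from the device queue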
@@ -35,7 +35,7 @@ from ..context import ParallelMode
 from ..parallel._utils import _need_to_full, _to_full_tensor
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from ..common import dtype as mstype
-from .dataset_helper import DatasetHelper
+from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp


@@ -249,23 +249,15 @@ class Model:

     def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
         """Initializes dataset."""
-        need_wrap = False
-        if dataset_sink_mode:
-            # remove later to deal with loop sink
-            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                    and not context.get_context("enable_ge"):
-                need_wrap = True
-
-            if not is_train:
-                dataset.__loop_size__ = 1
-
+        if dataset_sink_mode and not is_train:
+            dataset.__loop_size__ = 1
         dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)

-        # remove later to deal with loop sink
-        if need_wrap:
-            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
-        network.set_train(is_train)
-        network.phase = phase
+        if dataset_sink_mode:
+            network = connect_network_with_dataset(network, dataset_helper)
+
+        network.set_train(is_train)
+        network.phase = phase

         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
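The user-facing path that exercises this code is unchanged; a hedged end-to-end sketch (`net`, `loss`, `opt` and `ds` are illustrative placeholders):

    model = Model(net, loss_fn=loss, optimizer=opt)
    # dataset_sink_mode=True now routes through connect_network_with_dataset
    model.train(epoch=1, train_dataset=ds, dataset_sink_mode=True)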
@@ -306,11 +298,9 @@ class Model:

         if train_dataset:
             _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
-            self._train_network.set_train()
-            self._train_network.phase = 'train'

             if self._parameter_broadcast:
                 self._train_network.set_broadcast_flag()

             train_dataset.__no_send__ = True
             train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
                                                                         is_train=True,
@@ -326,8 +316,6 @@ class Model:
             if not self._metric_fns:
                 raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')

-            self._eval_network.set_train(False)
-            self._eval_network.phase = 'eval'
             valid_dataset.__no_send__ = True
             valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                                        is_train=False,
@@ -358,8 +346,6 @@ class Model:
             sink_size (int): Control the amount of data in each sink. Default: -1.
         """
         epoch = check_int_positive(epoch)
-        self._train_network.set_train()
-
         if self._parameter_broadcast:
             self._train_network.set_broadcast_flag()

@@ -701,9 +687,6 @@ class Model:
         cb_params.list_callback = self._transform_callbacks(callbacks)
         cb_params.network = self._network

-        self._eval_network.set_train(mode=False)
-        self._eval_network.phase = 'eval'
-
         self._clear_metrics()

         if context.get_context("device_target") == "CPU":
@@ -104,7 +104,6 @@ class _DatasetIter:
         if hasattr(dataset, '__loop_size__'):
             self.sink_size = dataset.__loop_size__
         dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
-        dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name

         if not hasattr(dataset, '__no_send__'):
             _send_data(dataset, epoch_num)
@@ -188,5 +187,5 @@ class _DatasetIterMS(_DatasetIter):
         else:
             self.sink_count = dataset.get_dataset_size()

-        queue_name = dataset.__ME_INITED__
+        queue_name = dataset.__TRANSFER_DATASET__.queue_name
         self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
@@ -17,9 +17,9 @@
 import math
 from mindspore.train.callback import RunContext
 from mindspore import context
-from mindspore import nn
 from mindspore.context import ParallelMode
 from mindspore.train.model import Model
+from mindspore.train.dataset_helper import connect_network_with_dataset
 from mindspore.parallel._utils import _need_to_full, _to_full_tensor
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore._c_expression import init_exec_dataset
@@ -57,7 +57,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):

     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
-    init_exec_dataset(exec_dataset.__ME_INITED__,
+    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
@@ -114,23 +114,14 @@ class Model_Thor(Model):
     def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
                          epoch_num=1, iter_first_order=1):
         """Initializes dataset."""
-        need_wrap = False
-        if dataset_sink_mode:
-            # remove later to deal with loop sink
-            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                    and not context.get_context("enable_ge"):
-                need_wrap = True
-
-            if not is_train:
-                dataset.__loop_size__ = 1
-
+        if dataset_sink_mode and not is_train:
+            dataset.__loop_size__ = 1
         dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)

-        # remove later to deal with loop sink
-        if need_wrap:
-            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
-        network.set_train(is_train)
-        network.phase = phase
+        if dataset_sink_mode:
+            network = connect_network_with_dataset(network, dataset_helper)
+        network.set_train(is_train)
+        network.phase = phase

         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
@@ -111,7 +111,6 @@ class _DatasetIter:
         if hasattr(dataset, '__loop_size__'):
             self.sink_size = dataset.__loop_size__
         dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
-        dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name

         if not hasattr(dataset, '__no_send__'):
             _send_data(dataset, epoch_num)
@@ -28,6 +28,7 @@ from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.tensor import Tensor
 from mindspore.nn.metrics import Loss
 from mindspore.nn.metrics import get_metrics
+from mindspore.train.dataset_helper import connect_network_with_dataset
 from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
 from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
     _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
@@ -70,7 +71,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):

     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
-    init_exec_dataset(exec_dataset.__ME_INITED__,
+    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
@@ -275,23 +276,14 @@ class Model:
     def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1,
                          iter_first_order=9):
         """Initializes dataset."""
-        need_wrap = False
-        if dataset_sink_mode:
-            # remove later to deal with loop sink
-            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                    and not context.get_context("enable_ge"):
-                need_wrap = True
-
-            if not is_train:
-                dataset.__loop_size__ = 1
-
+        if dataset_sink_mode and not is_train:
+            dataset.__loop_size__ = 1
         dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)

-        # remove later to deal with loop sink
-        if need_wrap:
-            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
-        network.set_train(is_train)
-        network.phase = phase
+        if dataset_sink_mode:
+            network = connect_network_with_dataset(network, dataset_helper)
+        network.set_train(is_train)
+        network.phase = phase

         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             network.set_auto_parallel()
@@ -18,6 +18,7 @@ from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_f
 from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
 from mindspore.context import ParallelMode

+
 def _send_data(dataset):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
@@ -25,6 +26,7 @@ def _send_data(dataset):
         exec_dataset.send()
         dataset.__has_sent__ = True

+
 class DatasetHelper:
     """
     Help function to use the Minddata dataset.
@@ -69,13 +71,12 @@ class _DatasetIter:

     def __init__(self, dataset):
         self.loop_size = 1
-        if not hasattr(dataset, '__ME_INITED__'):
+        if not hasattr(dataset, '__TRANSFER_DATASET__'):
             if not hasattr(dataset, '__loop_size__'):
                 self.loop_size = dataset.get_dataset_size()
             else:
                 self.loop_size = dataset.__loop_size__
             dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
-            dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name

             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset)
@@ -23,6 +23,7 @@ from mindspore._checkparam import check_input_data, check_output_data, check_int
 from mindspore.common import dtype as mstype
 from mindspore.common.dtype import pytype_to_dtype
 from mindspore.common.tensor import Tensor
+from mindspore.train.dataset_helper import connect_network_with_dataset
 from mindspore.nn.metrics import Loss
 from mindspore.nn.metrics import get_metrics
 from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -66,7 +67,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):

     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
-    init_exec_dataset(exec_dataset.__ME_INITED__,
+    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
@@ -266,23 +267,14 @@ class Model:

     def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1):
         """Initializes dataset."""
-        need_wrap = False
-        if dataset_sink_mode:
-            # remove later to deal with loop sink
-            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
-                    and not context.get_context("enable_ge"):
-                need_wrap = True
-
-            if not is_train:
-                dataset.__loop_size__ = 1
-
+        if dataset_sink_mode and not is_train:
+            dataset.__loop_size__ = 1
         dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)

-        # remove later to deal with loop sink
-        if need_wrap:
-            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
-        network.set_train(is_train)
-        network.phase = phase
+        if dataset_sink_mode:
+            network = connect_network_with_dataset(network, dataset_helper)
+        network.set_train(is_train)
+        network.phase = phase

         return dataset_helper, network

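All `_exec_preprocess` variants return the same `(dataset_helper, network)` pair, and callers consume it with the same loop shape; a hedged sketch of that contract (`_exec_preprocess` is a private method, shown only for illustration):

    dataset_helper, train_network = model._exec_preprocess(net, is_train=True, phase='train',
                                                           dataset=ds, dataset_sink_mode=True)
    for inputs in dataset_helper:
        outputs = train_network(*inputs)   # in sink mode `inputs` is empty; GetNext feeds the network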
@@ -605,7 +597,6 @@ class Model:
             Dict, returns the loss value & metrics values for the model in test mode.
         """
         run_context = RunContext(cb_params)
-
         dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
                                                              is_train=False,
                                                              phase='eval',