forked from mindspore-Ecosystem/mindspore
!5908 delete DataWrapper in Model to unify the expression of dataset sinking
Merge pull request !5908 from wangnan39/optim_datawrapper_in_model
This commit is contained in:
commit
55372b41fd
|
@ -17,11 +17,12 @@ import math
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from mindspore._checkparam import check_bool, check_int
|
from mindspore._checkparam import check_bool, check_int
|
||||||
from .. import context
|
from .. import context, nn
|
||||||
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
|
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
|
||||||
_construct_tensor_list
|
_construct_tensor_list
|
||||||
from ..nn.wrap import GetNextSingleOp
|
from ..nn.wrap import GetNextSingleOp
|
||||||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
|
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
|
||||||
|
from ..ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
def _send_data(dataset, epoch_num):
|
def _send_data(dataset, epoch_num):
|
||||||
|
@ -37,9 +38,70 @@ def _send_data_no_flag(dataset, epoch_num):
|
||||||
exec_dataset.send(epoch_num)
|
exec_dataset.send(epoch_num)
|
||||||
|
|
||||||
|
|
||||||
|
def connect_network_with_dataset(network, dataset_helper):
|
||||||
|
"""
|
||||||
|
Connect the `network` with dataset in `dataset_helper`.
|
||||||
|
|
||||||
|
This function wraps the input network with 'GetNext' so that the data can be fetched automatically from the
|
||||||
|
data channel corresponding to the 'queue_name' and passed to the input network during forward computation.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
In the case of running the network on Ascend in graph mode, this function will wrap the input network with
|
||||||
|
'GetNext', in other cases, the input network will be returned with no change.
|
||||||
|
The 'GetNext' is required to get data only in sink mode, so this function is not applicable to no-sink mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
network (Cell): The training network for dataset.
|
||||||
|
dataset_helper(DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue
|
||||||
|
name of the dataset to wrap the `GetNext`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Cell, a new network wrapped with 'GetNext' in the case of running the task on Ascend in graph mode, otherwise
|
||||||
|
it is the input network.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
|
||||||
|
>>> train_dataset = create_dataset()
|
||||||
|
>>> dataset_helper = mindspore.DatasetHelper(train_dataset)
|
||||||
|
>>> net = Net()
|
||||||
|
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
|
||||||
|
"""
|
||||||
|
class _DataWrapper(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
|
||||||
|
dataset channel 'queue_name' and performs the forward computation.
|
||||||
|
"""
|
||||||
|
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
||||||
|
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||||
|
# Also copy the flag in `network` construct
|
||||||
|
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||||
|
self.add_flags(**flags)
|
||||||
|
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
outputs = self.get_next()
|
||||||
|
return self.network(*outputs)
|
||||||
|
|
||||||
|
dataset_iter = dataset_helper.iter
|
||||||
|
dataset = dataset_iter.dataset
|
||||||
|
|
||||||
|
if isinstance(dataset_iter, _DatasetIterNormal):
|
||||||
|
raise RuntimeError("Dataset should be connected with network only in sink mode.")
|
||||||
|
|
||||||
|
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" and \
|
||||||
|
context.get_context("mode") == context.GRAPH_MODE and not context.get_context("enable_ge"):
|
||||||
|
dataset.__ME_INITED__ = True
|
||||||
|
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
||||||
|
queue_name = dataset.__TRANSFER_DATASET__.queue_name
|
||||||
|
|
||||||
|
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
class DatasetHelper:
|
class DatasetHelper:
|
||||||
"""
|
"""
|
||||||
Help function to use the MindData dataset.
|
DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.
|
||||||
|
|
||||||
According to different contexts, change the iterations of dataset and use the same iteration for loop in different
|
According to different contexts, change the iterations of dataset and use the same iteration for loop in different
|
||||||
contexts.
|
contexts.
|
||||||
|
@ -114,7 +176,6 @@ class _DatasetIter:
|
||||||
if hasattr(dataset, '__loop_size__'):
|
if hasattr(dataset, '__loop_size__'):
|
||||||
self.sink_size = dataset.__loop_size__
|
self.sink_size = dataset.__loop_size__
|
||||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
||||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
|
||||||
|
|
||||||
if not hasattr(dataset, '__no_send__'):
|
if not hasattr(dataset, '__no_send__'):
|
||||||
_send_data(dataset, epoch_num)
|
_send_data(dataset, epoch_num)
|
||||||
|
@ -207,7 +268,7 @@ class _DatasetIterMS(_DatasetIter):
|
||||||
else:
|
else:
|
||||||
self.sink_count = dataset.get_dataset_size()
|
self.sink_count = dataset.get_dataset_size()
|
||||||
|
|
||||||
queue_name = dataset.__ME_INITED__
|
queue_name = dataset.__TRANSFER_DATASET__.queue_name
|
||||||
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ from ..context import ParallelMode
|
||||||
from ..parallel._utils import _need_to_full, _to_full_tensor
|
from ..parallel._utils import _need_to_full, _to_full_tensor
|
||||||
from ..parallel._cost_model_context import _set_multi_subgraphs
|
from ..parallel._cost_model_context import _set_multi_subgraphs
|
||||||
from ..common import dtype as mstype
|
from ..common import dtype as mstype
|
||||||
from .dataset_helper import DatasetHelper
|
from .dataset_helper import DatasetHelper, connect_network_with_dataset
|
||||||
from . import amp
|
from . import amp
|
||||||
|
|
||||||
|
|
||||||
|
@ -249,21 +249,13 @@ class Model:
|
||||||
|
|
||||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
|
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
|
||||||
"""Initializes dataset."""
|
"""Initializes dataset."""
|
||||||
need_wrap = False
|
if dataset_sink_mode and not is_train:
|
||||||
if dataset_sink_mode:
|
|
||||||
# remove later to deal with loop sink
|
|
||||||
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
|
|
||||||
and not context.get_context("enable_ge"):
|
|
||||||
need_wrap = True
|
|
||||||
|
|
||||||
if not is_train:
|
|
||||||
dataset.__loop_size__ = 1
|
dataset.__loop_size__ = 1
|
||||||
|
|
||||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
|
||||||
|
|
||||||
# remove later to deal with loop sink
|
if dataset_sink_mode:
|
||||||
if need_wrap:
|
network = connect_network_with_dataset(network, dataset_helper)
|
||||||
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
|
|
||||||
network.set_train(is_train)
|
network.set_train(is_train)
|
||||||
network.phase = phase
|
network.phase = phase
|
||||||
|
|
||||||
|
@ -306,11 +298,9 @@ class Model:
|
||||||
|
|
||||||
if train_dataset:
|
if train_dataset:
|
||||||
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
||||||
self._train_network.set_train()
|
|
||||||
self._train_network.phase = 'train'
|
|
||||||
|
|
||||||
if self._parameter_broadcast:
|
if self._parameter_broadcast:
|
||||||
self._train_network.set_broadcast_flag()
|
self._train_network.set_broadcast_flag()
|
||||||
|
|
||||||
train_dataset.__no_send__ = True
|
train_dataset.__no_send__ = True
|
||||||
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||||
is_train=True,
|
is_train=True,
|
||||||
|
@ -326,8 +316,6 @@ class Model:
|
||||||
if not self._metric_fns:
|
if not self._metric_fns:
|
||||||
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
|
raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
|
||||||
|
|
||||||
self._eval_network.set_train(False)
|
|
||||||
self._eval_network.phase = 'eval'
|
|
||||||
valid_dataset.__no_send__ = True
|
valid_dataset.__no_send__ = True
|
||||||
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||||
is_train=False,
|
is_train=False,
|
||||||
|
@ -358,8 +346,6 @@ class Model:
|
||||||
sink_size (int): Control the amount of data in each sink. Default: -1.
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
||||||
"""
|
"""
|
||||||
epoch = check_int_positive(epoch)
|
epoch = check_int_positive(epoch)
|
||||||
self._train_network.set_train()
|
|
||||||
|
|
||||||
if self._parameter_broadcast:
|
if self._parameter_broadcast:
|
||||||
self._train_network.set_broadcast_flag()
|
self._train_network.set_broadcast_flag()
|
||||||
|
|
||||||
|
@ -701,9 +687,6 @@ class Model:
|
||||||
cb_params.list_callback = self._transform_callbacks(callbacks)
|
cb_params.list_callback = self._transform_callbacks(callbacks)
|
||||||
cb_params.network = self._network
|
cb_params.network = self._network
|
||||||
|
|
||||||
self._eval_network.set_train(mode=False)
|
|
||||||
self._eval_network.phase = 'eval'
|
|
||||||
|
|
||||||
self._clear_metrics()
|
self._clear_metrics()
|
||||||
|
|
||||||
if context.get_context("device_target") == "CPU":
|
if context.get_context("device_target") == "CPU":
|
||||||
|
|
|
@ -104,7 +104,6 @@ class _DatasetIter:
|
||||||
if hasattr(dataset, '__loop_size__'):
|
if hasattr(dataset, '__loop_size__'):
|
||||||
self.sink_size = dataset.__loop_size__
|
self.sink_size = dataset.__loop_size__
|
||||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
||||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
|
||||||
|
|
||||||
if not hasattr(dataset, '__no_send__'):
|
if not hasattr(dataset, '__no_send__'):
|
||||||
_send_data(dataset, epoch_num)
|
_send_data(dataset, epoch_num)
|
||||||
|
@ -188,5 +187,5 @@ class _DatasetIterMS(_DatasetIter):
|
||||||
else:
|
else:
|
||||||
self.sink_count = dataset.get_dataset_size()
|
self.sink_count = dataset.get_dataset_size()
|
||||||
|
|
||||||
queue_name = dataset.__ME_INITED__
|
queue_name = dataset.__TRANSFER_DATASET__.queue_name
|
||||||
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
|
||||||
|
|
|
@ -17,9 +17,9 @@
|
||||||
import math
|
import math
|
||||||
from mindspore.train.callback import RunContext
|
from mindspore.train.callback import RunContext
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from mindspore import nn
|
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
|
from mindspore.train.dataset_helper import connect_network_with_dataset
|
||||||
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
|
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
|
||||||
from mindspore.common.dtype import pytype_to_dtype
|
from mindspore.common.dtype import pytype_to_dtype
|
||||||
from mindspore._c_expression import init_exec_dataset
|
from mindspore._c_expression import init_exec_dataset
|
||||||
|
@ -57,7 +57,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
||||||
|
|
||||||
# transform data format
|
# transform data format
|
||||||
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
||||||
init_exec_dataset(exec_dataset.__ME_INITED__,
|
init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
|
||||||
dataset_size,
|
dataset_size,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_types,
|
dataset_types,
|
||||||
|
@ -114,21 +114,12 @@ class Model_Thor(Model):
|
||||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
|
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
|
||||||
epoch_num=1, iter_first_order=1):
|
epoch_num=1, iter_first_order=1):
|
||||||
"""Initializes dataset."""
|
"""Initializes dataset."""
|
||||||
need_wrap = False
|
if dataset_sink_mode and not is_train:
|
||||||
if dataset_sink_mode:
|
|
||||||
# remove later to deal with loop sink
|
|
||||||
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
|
|
||||||
and not context.get_context("enable_ge"):
|
|
||||||
need_wrap = True
|
|
||||||
|
|
||||||
if not is_train:
|
|
||||||
dataset.__loop_size__ = 1
|
dataset.__loop_size__ = 1
|
||||||
|
|
||||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
|
||||||
|
|
||||||
# remove later to deal with loop sink
|
if dataset_sink_mode:
|
||||||
if need_wrap:
|
network = connect_network_with_dataset(network, dataset_helper)
|
||||||
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
|
|
||||||
network.set_train(is_train)
|
network.set_train(is_train)
|
||||||
network.phase = phase
|
network.phase = phase
|
||||||
|
|
||||||
|
|
|
@ -111,7 +111,6 @@ class _DatasetIter:
|
||||||
if hasattr(dataset, '__loop_size__'):
|
if hasattr(dataset, '__loop_size__'):
|
||||||
self.sink_size = dataset.__loop_size__
|
self.sink_size = dataset.__loop_size__
|
||||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
|
||||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
|
||||||
|
|
||||||
if not hasattr(dataset, '__no_send__'):
|
if not hasattr(dataset, '__no_send__'):
|
||||||
_send_data(dataset, epoch_num)
|
_send_data(dataset, epoch_num)
|
||||||
|
|
|
@ -28,6 +28,7 @@ from mindspore.common.dtype import pytype_to_dtype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
from mindspore.nn.metrics import Loss
|
from mindspore.nn.metrics import Loss
|
||||||
from mindspore.nn.metrics import get_metrics
|
from mindspore.nn.metrics import get_metrics
|
||||||
|
from mindspore.train.dataset_helper import connect_network_with_dataset
|
||||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||||
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
|
||||||
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
|
||||||
|
@ -70,7 +71,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
||||||
|
|
||||||
# transform data format
|
# transform data format
|
||||||
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
||||||
init_exec_dataset(exec_dataset.__ME_INITED__,
|
init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
|
||||||
dataset_size,
|
dataset_size,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_types,
|
dataset_types,
|
||||||
|
@ -275,21 +276,12 @@ class Model:
|
||||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1,
|
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1,
|
||||||
iter_first_order=9):
|
iter_first_order=9):
|
||||||
"""Initializes dataset."""
|
"""Initializes dataset."""
|
||||||
need_wrap = False
|
if dataset_sink_mode and not is_train:
|
||||||
if dataset_sink_mode:
|
|
||||||
# remove later to deal with loop sink
|
|
||||||
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
|
|
||||||
and not context.get_context("enable_ge"):
|
|
||||||
need_wrap = True
|
|
||||||
|
|
||||||
if not is_train:
|
|
||||||
dataset.__loop_size__ = 1
|
dataset.__loop_size__ = 1
|
||||||
|
|
||||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)
|
||||||
|
|
||||||
# remove later to deal with loop sink
|
if dataset_sink_mode:
|
||||||
if need_wrap:
|
network = connect_network_with_dataset(network, dataset_helper)
|
||||||
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
|
|
||||||
network.set_train(is_train)
|
network.set_train(is_train)
|
||||||
network.phase = phase
|
network.phase = phase
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _to_f
|
||||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
||||||
from mindspore.context import ParallelMode
|
from mindspore.context import ParallelMode
|
||||||
|
|
||||||
|
|
||||||
def _send_data(dataset):
|
def _send_data(dataset):
|
||||||
"""Engine dataset to write data to tdt queue."""
|
"""Engine dataset to write data to tdt queue."""
|
||||||
if not hasattr(dataset, '__has_sent__'):
|
if not hasattr(dataset, '__has_sent__'):
|
||||||
|
@ -25,6 +26,7 @@ def _send_data(dataset):
|
||||||
exec_dataset.send()
|
exec_dataset.send()
|
||||||
dataset.__has_sent__ = True
|
dataset.__has_sent__ = True
|
||||||
|
|
||||||
|
|
||||||
class DatasetHelper:
|
class DatasetHelper:
|
||||||
"""
|
"""
|
||||||
Help function to use the Minddata dataset.
|
Help function to use the Minddata dataset.
|
||||||
|
@ -69,13 +71,12 @@ class _DatasetIter:
|
||||||
|
|
||||||
def __init__(self, dataset):
|
def __init__(self, dataset):
|
||||||
self.loop_size = 1
|
self.loop_size = 1
|
||||||
if not hasattr(dataset, '__ME_INITED__'):
|
if not hasattr(dataset, '__TRANSFER_DATASET__'):
|
||||||
if not hasattr(dataset, '__loop_size__'):
|
if not hasattr(dataset, '__loop_size__'):
|
||||||
self.loop_size = dataset.get_dataset_size()
|
self.loop_size = dataset.get_dataset_size()
|
||||||
else:
|
else:
|
||||||
self.loop_size = dataset.__loop_size__
|
self.loop_size = dataset.__loop_size__
|
||||||
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
|
||||||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
|
||||||
|
|
||||||
if not hasattr(dataset, '__no_send__'):
|
if not hasattr(dataset, '__no_send__'):
|
||||||
_send_data(dataset)
|
_send_data(dataset)
|
||||||
|
|
|
@ -23,6 +23,7 @@ from mindspore._checkparam import check_input_data, check_output_data, check_int
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
from mindspore.common.dtype import pytype_to_dtype
|
from mindspore.common.dtype import pytype_to_dtype
|
||||||
from mindspore.common.tensor import Tensor
|
from mindspore.common.tensor import Tensor
|
||||||
|
from mindspore.train.dataset_helper import connect_network_with_dataset
|
||||||
from mindspore.nn.metrics import Loss
|
from mindspore.nn.metrics import Loss
|
||||||
from mindspore.nn.metrics import get_metrics
|
from mindspore.nn.metrics import get_metrics
|
||||||
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||||
|
@ -66,7 +67,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
|
||||||
|
|
||||||
# transform data format
|
# transform data format
|
||||||
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
|
||||||
init_exec_dataset(exec_dataset.__ME_INITED__,
|
init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
|
||||||
dataset_size,
|
dataset_size,
|
||||||
batch_size,
|
batch_size,
|
||||||
dataset_types,
|
dataset_types,
|
||||||
|
@ -266,21 +267,12 @@ class Model:
|
||||||
|
|
||||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1):
|
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=1):
|
||||||
"""Initializes dataset."""
|
"""Initializes dataset."""
|
||||||
need_wrap = False
|
if dataset_sink_mode and not is_train:
|
||||||
if dataset_sink_mode:
|
|
||||||
# remove later to deal with loop sink
|
|
||||||
if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
|
|
||||||
and not context.get_context("enable_ge"):
|
|
||||||
need_wrap = True
|
|
||||||
|
|
||||||
if not is_train:
|
|
||||||
dataset.__loop_size__ = 1
|
dataset.__loop_size__ = 1
|
||||||
|
|
||||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)
|
||||||
|
|
||||||
# remove later to deal with loop sink
|
if dataset_sink_mode:
|
||||||
if need_wrap:
|
network = connect_network_with_dataset(network, dataset_helper)
|
||||||
network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
|
|
||||||
network.set_train(is_train)
|
network.set_train(is_train)
|
||||||
network.phase = phase
|
network.phase = phase
|
||||||
|
|
||||||
|
@ -605,7 +597,6 @@ class Model:
|
||||||
Dict, returns the loss value & metrics values for the model in test mode.
|
Dict, returns the loss value & metrics values for the model in test mode.
|
||||||
"""
|
"""
|
||||||
run_context = RunContext(cb_params)
|
run_context = RunContext(cb_params)
|
||||||
|
|
||||||
dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
|
||||||
is_train=False,
|
is_train=False,
|
||||||
phase='eval',
|
phase='eval',
|
||||||
|
|
Loading…
Reference in New Issue