!13045 avoid insert getnext op twice when eval or train twice with data sink mode

From: @zhangbuxue
Reviewed-by: @ginfung
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-12 09:30:28 +08:00 committed by Gitee
commit 6a64644e66
1 changed files with 40 additions and 31 deletions

View File

@ -39,6 +39,32 @@ def _send_data_no_flag(dataset, epoch_num):
exec_dataset.send(epoch_num) exec_dataset.send(epoch_num)
class _DataWrapper(nn.Cell):
"""
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
dataset channel 'queue_name' and performs the forward computation.
"""
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.info = (dataset_types, dataset_shapes)
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network
def construct(self):
outputs = self.get_next()
return self.network(*outputs)
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
if not isinstance(network, _DataWrapper):
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
return network
def connect_network_with_dataset(network, dataset_helper): def connect_network_with_dataset(network, dataset_helper):
""" """
Connect the `network` with dataset in `dataset_helper`. Connect the `network` with dataset in `dataset_helper`.
@ -70,24 +96,6 @@ def connect_network_with_dataset(network, dataset_helper):
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
""" """
class _DataWrapper(nn.Cell):
"""
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
dataset channel 'queue_name' and performs the forward computation.
"""
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
# Also copy the flag in `network` construct
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
self.add_flags(**flags)
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
self.network = network
def construct(self):
outputs = self.get_next()
return self.network(*outputs)
dataset_iter = dataset_helper.iter dataset_iter = dataset_helper.iter
dataset = dataset_iter.dataset dataset = dataset_iter.dataset
@ -98,11 +106,14 @@ def connect_network_with_dataset(network, dataset_helper):
if ms_role in ("MS_PSERVER", "MS_SCHED"): if ms_role in ("MS_PSERVER", "MS_SCHED"):
return network return network
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ queue_name = dataset.__transfer_dataset__.queue_name
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ if hasattr(dataset_iter, "sink_size") and \
and context.get_context("device_target") == "Ascend" \ dataset_iter.sink_size == 1 and \
and context.get_context("mode") == context.GRAPH_MODE \ hasattr(dataset_iter, "sink_count") and \
and ms_role != "MS_WORKER": dataset_iter.sink_count == 1 and \
context.get_context("device_target") == "Ascend" and \
context.get_context("mode") == context.GRAPH_MODE and \
ms_role != "MS_WORKER":
if not hasattr(dataset_iter, '__network__'): if not hasattr(dataset_iter, '__network__'):
dataset_iter.__network__ = network dataset_iter.__network__ = network
@ -118,21 +129,19 @@ def connect_network_with_dataset(network, dataset_helper):
if _need_to_full(): if _need_to_full():
device_num = _get_device_num() device_num = _get_device_num()
dataset_shapes = _to_full_shapes(dataset_shapes, device_num) dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name)
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr( dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr(
dataset_iter, '__network_manage__') else dict() dataset_iter, '__network_manage__') else dict()
dataset_iter.__network_manage__[key] = network dataset_iter.__network_manage__[key] = network
return network return network
if not hasattr(dataset, '__me_inited__') and context.get_context("device_target") in ("Ascend", "GPU")\ if not hasattr(dataset, '__me_inited__') and \
and not context.get_context("enable_ge"): not context.get_context("enable_ge") and \
context.get_context("device_target") in ("Ascend", "GPU"):
dataset.__me_inited__ = True dataset.__me_inited__ = True
dataset_types, dataset_shapes = dataset_helper.types_shapes() dataset_types, dataset_shapes = dataset_helper.types_shapes()
queue_name = dataset.__transfer_dataset__.queue_name network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
return network return network