forked from mindspore-Ecosystem/mindspore
!13045 avoid insert getnext op twice when eval or train twice with data sink mode
From: @zhangbuxue Reviewed-by: @ginfung Signed-off-by:
This commit is contained in:
commit
6a64644e66
|
@ -39,6 +39,32 @@ def _send_data_no_flag(dataset, epoch_num):
|
|||
exec_dataset.send(epoch_num)
|
||||
|
||||
|
||||
class _DataWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
|
||||
dataset channel 'queue_name' and performs the forward computation.
|
||||
"""
|
||||
|
||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
||||
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||
# Also copy the flag in `network` construct
|
||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||
self.info = (dataset_types, dataset_shapes)
|
||||
self.add_flags(**flags)
|
||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||
self.network = network
|
||||
|
||||
def construct(self):
|
||||
outputs = self.get_next()
|
||||
return self.network(*outputs)
|
||||
|
||||
|
||||
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
|
||||
if not isinstance(network, _DataWrapper):
|
||||
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
||||
return network
|
||||
|
||||
|
||||
def connect_network_with_dataset(network, dataset_helper):
|
||||
"""
|
||||
Connect the `network` with dataset in `dataset_helper`.
|
||||
|
@ -70,24 +96,6 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
|
||||
"""
|
||||
|
||||
class _DataWrapper(nn.Cell):
|
||||
"""
|
||||
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
|
||||
dataset channel 'queue_name' and performs the forward computation.
|
||||
"""
|
||||
|
||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
||||
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||
# Also copy the flag in `network` construct
|
||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||
self.add_flags(**flags)
|
||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||
self.network = network
|
||||
|
||||
def construct(self):
|
||||
outputs = self.get_next()
|
||||
return self.network(*outputs)
|
||||
|
||||
dataset_iter = dataset_helper.iter
|
||||
dataset = dataset_iter.dataset
|
||||
|
||||
|
@ -98,11 +106,14 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
return network
|
||||
|
||||
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
|
||||
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
|
||||
and context.get_context("device_target") == "Ascend" \
|
||||
and context.get_context("mode") == context.GRAPH_MODE \
|
||||
and ms_role != "MS_WORKER":
|
||||
queue_name = dataset.__transfer_dataset__.queue_name
|
||||
if hasattr(dataset_iter, "sink_size") and \
|
||||
dataset_iter.sink_size == 1 and \
|
||||
hasattr(dataset_iter, "sink_count") and \
|
||||
dataset_iter.sink_count == 1 and \
|
||||
context.get_context("device_target") == "Ascend" and \
|
||||
context.get_context("mode") == context.GRAPH_MODE and \
|
||||
ms_role != "MS_WORKER":
|
||||
|
||||
if not hasattr(dataset_iter, '__network__'):
|
||||
dataset_iter.__network__ = network
|
||||
|
@ -118,21 +129,19 @@ def connect_network_with_dataset(network, dataset_helper):
|
|||
if _need_to_full():
|
||||
device_num = _get_device_num()
|
||||
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
||||
network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name)
|
||||
|
||||
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
|
||||
dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr(
|
||||
dataset_iter, '__network_manage__') else dict()
|
||||
dataset_iter.__network_manage__[key] = network
|
||||
|
||||
return network
|
||||
|
||||
if not hasattr(dataset, '__me_inited__') and context.get_context("device_target") in ("Ascend", "GPU")\
|
||||
and not context.get_context("enable_ge"):
|
||||
if not hasattr(dataset, '__me_inited__') and \
|
||||
not context.get_context("enable_ge") and \
|
||||
context.get_context("device_target") in ("Ascend", "GPU"):
|
||||
dataset.__me_inited__ = True
|
||||
|
||||
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
||||
queue_name = dataset.__transfer_dataset__.queue_name
|
||||
|
||||
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
||||
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
|
||||
return network
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue