forked from mindspore-Ecosystem/mindspore
!13045 avoid insert getnext op twice when eval or train twice with data sink mode
From: @zhangbuxue Reviewed-by: @ginfung Signed-off-by:
This commit is contained in:
commit
6a64644e66
|
@ -39,6 +39,32 @@ def _send_data_no_flag(dataset, epoch_num):
|
||||||
exec_dataset.send(epoch_num)
|
exec_dataset.send(epoch_num)
|
||||||
|
|
||||||
|
|
||||||
|
class _DataWrapper(nn.Cell):
|
||||||
|
"""
|
||||||
|
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
|
||||||
|
dataset channel 'queue_name' and performs the forward computation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
||||||
|
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
||||||
|
# Also copy the flag in `network` construct
|
||||||
|
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
||||||
|
self.info = (dataset_types, dataset_shapes)
|
||||||
|
self.add_flags(**flags)
|
||||||
|
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
outputs = self.get_next()
|
||||||
|
return self.network(*outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name):
|
||||||
|
if not isinstance(network, _DataWrapper):
|
||||||
|
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
def connect_network_with_dataset(network, dataset_helper):
|
def connect_network_with_dataset(network, dataset_helper):
|
||||||
"""
|
"""
|
||||||
Connect the `network` with dataset in `dataset_helper`.
|
Connect the `network` with dataset in `dataset_helper`.
|
||||||
|
@ -70,24 +96,6 @@ def connect_network_with_dataset(network, dataset_helper):
|
||||||
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
|
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class _DataWrapper(nn.Cell):
|
|
||||||
"""
|
|
||||||
Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
|
|
||||||
dataset channel 'queue_name' and performs the forward computation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, network, dataset_types, dataset_shapes, queue_name):
|
|
||||||
super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
|
|
||||||
# Also copy the flag in `network` construct
|
|
||||||
flags = getattr(network.__class__.construct, "_mindspore_flags", {})
|
|
||||||
self.add_flags(**flags)
|
|
||||||
self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
|
||||||
self.network = network
|
|
||||||
|
|
||||||
def construct(self):
|
|
||||||
outputs = self.get_next()
|
|
||||||
return self.network(*outputs)
|
|
||||||
|
|
||||||
dataset_iter = dataset_helper.iter
|
dataset_iter = dataset_helper.iter
|
||||||
dataset = dataset_iter.dataset
|
dataset = dataset_iter.dataset
|
||||||
|
|
||||||
|
@ -98,11 +106,14 @@ def connect_network_with_dataset(network, dataset_helper):
|
||||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||||
return network
|
return network
|
||||||
|
|
||||||
if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \
|
queue_name = dataset.__transfer_dataset__.queue_name
|
||||||
and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \
|
if hasattr(dataset_iter, "sink_size") and \
|
||||||
and context.get_context("device_target") == "Ascend" \
|
dataset_iter.sink_size == 1 and \
|
||||||
and context.get_context("mode") == context.GRAPH_MODE \
|
hasattr(dataset_iter, "sink_count") and \
|
||||||
and ms_role != "MS_WORKER":
|
dataset_iter.sink_count == 1 and \
|
||||||
|
context.get_context("device_target") == "Ascend" and \
|
||||||
|
context.get_context("mode") == context.GRAPH_MODE and \
|
||||||
|
ms_role != "MS_WORKER":
|
||||||
|
|
||||||
if not hasattr(dataset_iter, '__network__'):
|
if not hasattr(dataset_iter, '__network__'):
|
||||||
dataset_iter.__network__ = network
|
dataset_iter.__network__ = network
|
||||||
|
@ -118,21 +129,19 @@ def connect_network_with_dataset(network, dataset_helper):
|
||||||
if _need_to_full():
|
if _need_to_full():
|
||||||
device_num = _get_device_num()
|
device_num = _get_device_num()
|
||||||
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
||||||
network = _DataWrapper(network, dataset_types, dataset_shapes, dataset.__transfer_dataset__.queue_name)
|
|
||||||
|
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
|
||||||
dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr(
|
dataset_iter.__network_manage__ = dataset_iter.__network_manage__ if hasattr(
|
||||||
dataset_iter, '__network_manage__') else dict()
|
dataset_iter, '__network_manage__') else dict()
|
||||||
dataset_iter.__network_manage__[key] = network
|
dataset_iter.__network_manage__[key] = network
|
||||||
|
|
||||||
return network
|
return network
|
||||||
|
|
||||||
if not hasattr(dataset, '__me_inited__') and context.get_context("device_target") in ("Ascend", "GPU")\
|
if not hasattr(dataset, '__me_inited__') and \
|
||||||
and not context.get_context("enable_ge"):
|
not context.get_context("enable_ge") and \
|
||||||
|
context.get_context("device_target") in ("Ascend", "GPU"):
|
||||||
dataset.__me_inited__ = True
|
dataset.__me_inited__ = True
|
||||||
|
|
||||||
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
dataset_types, dataset_shapes = dataset_helper.types_shapes()
|
||||||
queue_name = dataset.__transfer_dataset__.queue_name
|
network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
|
||||||
|
|
||||||
network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
|
|
||||||
return network
|
return network
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue