forked from mindspore-Ecosystem/mindspore
dataset_attr_name
This commit is contained in: parent 9a8d4da5dc, commit b4e020d081
@@ -27,13 +27,13 @@ from ..ops import operations as P
 def _send_data(dataset, epoch_num):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
-        exec_dataset = dataset.__TRANSFER_DATASET__
+        exec_dataset = dataset.__transfer_dataset__
         exec_dataset.send(epoch_num)
         dataset.__has_sent__ = True


 def _send_data_no_flag(dataset, epoch_num):
     """Engine dataset to write data to tdt queue directly."""
-    exec_dataset = dataset.__TRANSFER_DATASET__
+    exec_dataset = dataset.__transfer_dataset__
     exec_dataset.send(epoch_num)
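For readers skimming the rename: the two helpers above implement a one-shot guard. _send_data pushes to the device queue at most once per dataset object and marks it with __has_sent__, while _send_data_no_flag pushes unconditionally. A minimal runnable sketch of that pattern, with toy classes standing in for the real engine objects (all names here are illustrative, not MindSpore APIs):

    class ToyTransferDataset:
        """Stand-in for the object cached as dataset.__transfer_dataset__."""
        def send(self, epoch_num):
            print(f"enqueue {epoch_num} epoch(s)")

    class ToyDataset:
        def __init__(self):
            self.__transfer_dataset__ = ToyTransferDataset()

    ds = ToyDataset()
    for _ in range(3):
        # mirrors _send_data: the __has_sent__ flag makes repeat calls no-ops
        if not hasattr(ds, '__has_sent__'):
            ds.__transfer_dataset__.send(1)
            ds.__has_sent__ = True
    # prints "enqueue 1 epoch(s)" exactly once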
@@ -88,11 +88,13 @@ def connect_network_with_dataset(network, dataset_helper):
     if isinstance(dataset_iter, _DatasetIterNormal):
         raise RuntimeError("Dataset should be connected with network only in sink mode.")

-    if not hasattr(dataset, '__ME_INITED__') and (context.get_context("device_target") == "Ascend" \
-            or context.get_context("device_target") == "GPU") and not context.get_context("enable_ge"):
-        dataset.__ME_INITED__ = True
+    if not hasattr(dataset, '__me_inited__') and (context.get_context("device_target") == "Ascend"
+                                                  or context.get_context("device_target") == "GPU") and not \
+            context.get_context("enable_ge"):
+        dataset.__me_inited__ = True

         dataset_types, dataset_shapes = dataset_helper.types_shapes()
-        queue_name = dataset.__TRANSFER_DATASET__.queue_name
+        queue_name = dataset.__transfer_dataset__.queue_name

         network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
     return network
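The guard being renamed here is a one-time-init check on specific backends. A self-contained paraphrase of the new condition (needs_me_init and the stub config are our names; only context.get_context is the real MindSpore call, and `target in (...)` is our compression of the diff's chained `or`):

    def needs_me_init(dataset, get_context):
        # get_context is assumed to behave like mindspore.context.get_context
        target = get_context("device_target")
        return (not hasattr(dataset, '__me_inited__')
                and target in ("Ascend", "GPU")
                and not get_context("enable_ge"))

    # toy check with a dict standing in for the context store:
    cfg = {"device_target": "GPU", "enable_ge": False}
    class _DS: pass
    assert needs_me_init(_DS(), cfg.get)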
@@ -175,18 +177,18 @@ class _DatasetIter:
         self.sink_size = sink_size
         self.sink_count = 1

-        if not hasattr(dataset, '__TRANSFER_DATASET__'):
+        if not hasattr(dataset, '__transfer_dataset__'):
             if hasattr(dataset, '__loop_size__'):
                 self.sink_size = dataset.__loop_size__
-            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
+            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)

             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset, epoch_num)
         else:
             _send_data_no_flag(dataset, epoch_num)

-        self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
-        self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
+        self.stop_send = dataset.__transfer_dataset__.stop_send
+        self.continue_send = dataset.__transfer_dataset__.continue_send
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

     def __iter__(self):
@@ -273,7 +275,7 @@ class _DatasetIterMS(_DatasetIter):
         else:
             self.sink_count = dataset.get_dataset_size()

-        queue_name = dataset.__TRANSFER_DATASET__.queue_name
+        queue_name = dataset.__transfer_dataset__.queue_name
         self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
@@ -25,14 +25,14 @@ from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_s
 def _send_data(dataset, epoch_num):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
-        exec_dataset = dataset.__TRANSFER_DATASET__
+        exec_dataset = dataset.__transfer_dataset__
         exec_dataset.send(epoch_num)
         dataset.__has_sent__ = True


 def _send_data_no_flag(dataset, epoch_num):
     """Engine dataset to write data to tdt queue directly."""
-    exec_dataset = dataset.__TRANSFER_DATASET__
+    exec_dataset = dataset.__transfer_dataset__
     exec_dataset.send(epoch_num)
@@ -100,17 +100,17 @@ class _DatasetIter:
         self.sink_size = sink_size
         self.sink_count = 1

-        if not hasattr(dataset, '__TRANSFER_DATASET__'):
+        if not hasattr(dataset, '__transfer_dataset__'):
             if hasattr(dataset, '__loop_size__'):
                 self.sink_size = dataset.__loop_size__
-            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
+            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)

             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset, epoch_num)
         else:
             _send_data_no_flag(dataset, epoch_num)

-        self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
+        self.stop_send = dataset.__transfer_dataset__.stop_send
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

     def __iter__(self):
@@ -187,5 +187,5 @@ class _DatasetIterMS(_DatasetIter):
         else:
             self.sink_count = dataset.get_dataset_size()

-        queue_name = dataset.__TRANSFER_DATASET__.queue_name
+        queue_name = dataset.__transfer_dataset__.queue_name
         self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
@@ -57,7 +57,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):

     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
-    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
+    init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
@@ -24,14 +24,14 @@ from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
 def _send_data(dataset, epoch_num):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
-        exec_dataset = dataset.__TRANSFER_DATASET__
+        exec_dataset = dataset.__transfer_dataset__
         exec_dataset.send(epoch_num)
         dataset.__has_sent__ = True


 def _send_data_no_flag(dataset, epoch_num):
     """Engine dataset to write data to tdt queue directly."""
-    exec_dataset = dataset.__TRANSFER_DATASET__
+    exec_dataset = dataset.__transfer_dataset__
     exec_dataset.send(epoch_num)
@@ -107,17 +107,17 @@ class _DatasetIter:
         self.sink_size = sink_size
         self.sink_count = 1

-        if not hasattr(dataset, '__TRANSFER_DATASET__'):
+        if not hasattr(dataset, '__transfer_dataset__'):
             if hasattr(dataset, '__loop_size__'):
                 self.sink_size = dataset.__loop_size__
-            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
+            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)

             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset, epoch_num)
         else:
             _send_data_no_flag(dataset, epoch_num)

-        self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
+        self.stop_send = dataset.__transfer_dataset__.stop_send
         self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

     def __iter__(self):
@@ -71,7 +71,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):

     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
-    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
+    init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
@@ -22,7 +22,7 @@ from mindspore.context import ParallelMode
 def _send_data(dataset):
     """Engine dataset to write data to tdt queue."""
     if not hasattr(dataset, '__has_sent__'):
-        exec_dataset = dataset.__TRANSFER_DATASET__
+        exec_dataset = dataset.__transfer_dataset__
         exec_dataset.send()
         dataset.__has_sent__ = True
@@ -71,12 +71,12 @@ class _DatasetIter:

     def __init__(self, dataset):
         self.loop_size = 1
-        if not hasattr(dataset, '__TRANSFER_DATASET__'):
+        if not hasattr(dataset, '__transfer_dataset__'):
             if not hasattr(dataset, '__loop_size__'):
                 self.loop_size = dataset.get_dataset_size()
             else:
                 self.loop_size = dataset.__loop_size__
-            dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
+            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.loop_size)

             if not hasattr(dataset, '__no_send__'):
                 _send_data(dataset)
@@ -67,7 +67,7 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):

     # transform data format
     dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
-    init_exec_dataset(exec_dataset.__TRANSFER_DATASET__.queue_name,
+    init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
                       dataset_size,
                       batch_size,
                       dataset_types,
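Why the rename is purely mechanical: Python name mangling applies only to identifiers with two leading underscores and at most one trailing underscore, so neither __TRANSFER_DATASET__ nor __transfer_dataset__ is mangled. The flags live on the instance under exactly the spelled name, which is why every hasattr/getattr/attribute site must change in lockstep, as the diff does. A quick runnable check (Probe is a hypothetical class for illustration):

    class Probe:
        pass

    p = Probe()
    p.__transfer_dataset__ = object()          # dunder on both sides: stored verbatim, not mangled
    assert hasattr(p, '__transfer_dataset__')
    assert not hasattr(p, '__TRANSFER_DATASET__')   # the old spelling no longer matches anywhere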