forked from mindspore-Ecosystem/mindspore
!3396 fix get dataset size error
Merge pull request !3396 from panfengfeng/fix_getdataset_size_error
This commit is contained in:
commit
dc64a3878a
|
@ -212,12 +212,12 @@ Status DeviceQueueOp::SendDataToGPU() {
|
|||
RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle));
|
||||
total_batch++;
|
||||
}
|
||||
if (!TaskManager::FindMe()->Interrupted())
|
||||
if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed())
|
||||
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
|
||||
else
|
||||
is_break_loop = true;
|
||||
}
|
||||
if (!TaskManager::FindMe()->Interrupted())
|
||||
if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed())
|
||||
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
|
||||
else
|
||||
is_break_loop = true;
|
||||
|
|
|
@ -2401,7 +2401,7 @@ class TransferDataset(DatasetOp):
|
|||
# need to keep iterator alive so the executionTree is not destroyed
|
||||
if self._noop_mode():
|
||||
return
|
||||
self.iterator = TupleIterator(self, num_epochs=-1)
|
||||
self.iterator = TupleIterator(self, num_epochs=num_epochs)
|
||||
|
||||
def stop_send(self):
|
||||
self.iterator.depipeline.StopSend()
|
||||
|
|
|
@ -24,13 +24,18 @@ from ..nn.wrap import GetNextSingleOp
|
|||
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full
|
||||
|
||||
|
||||
def _send_data(dataset):
|
||||
def _send_data(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue."""
|
||||
if not hasattr(dataset, '__has_sent__'):
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send()
|
||||
exec_dataset.send(epoch_num)
|
||||
dataset.__has_sent__ = True
|
||||
|
||||
def _send_data_no_flag(dataset, epoch_num):
|
||||
"""Engine dataset to write data to tdt queue directly."""
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send(epoch_num)
|
||||
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
|
@ -54,7 +59,7 @@ class DatasetHelper:
|
|||
>>> outputs = network(*inputs)
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1):
|
||||
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
|
||||
check_bool(dataset_sink_mode)
|
||||
check_int(sink_size)
|
||||
if sink_size < -1 or sink_size == 0:
|
||||
|
@ -74,7 +79,7 @@ class DatasetHelper:
|
|||
iterclass = _DatasetIterMS
|
||||
elif context.get_context("device_target") == "CPU":
|
||||
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
|
||||
self.iter = iterclass(dataset, sink_size)
|
||||
self.iter = iterclass(dataset, sink_size, epoch_num)
|
||||
else:
|
||||
iterclass = _DatasetIterNormal
|
||||
self.iter = iterclass(dataset)
|
||||
|
@ -98,7 +103,7 @@ class DatasetHelper:
|
|||
|
||||
class _DatasetIter:
|
||||
"""Base iter for dataset helper"""
|
||||
def __init__(self, dataset, sink_size):
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
self.dataset = dataset
|
||||
self.sink_size = sink_size
|
||||
self.sink_count = 1
|
||||
|
@ -110,9 +115,9 @@ class _DatasetIter:
|
|||
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
|
||||
|
||||
if not hasattr(dataset, '__no_send__'):
|
||||
_send_data(dataset)
|
||||
_send_data(dataset, epoch_num)
|
||||
else:
|
||||
_send_data(dataset)
|
||||
_send_data_no_flag(dataset, epoch_num)
|
||||
|
||||
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
|
||||
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
|
||||
|
@ -156,8 +161,8 @@ class _DatasetIter:
|
|||
|
||||
class _DatasetIterGE(_DatasetIter):
|
||||
"""Iter for GE."""
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
self.sink_count = self.get_sink_count(dataset)
|
||||
batch_expand_num = 1
|
||||
if _need_to_full():
|
||||
|
@ -172,8 +177,8 @@ class _DatasetIterGE(_DatasetIter):
|
|||
|
||||
class _DatasetIterMSLoopSink(_DatasetIter):
|
||||
"""Iter for context (device_target=Ascend)"""
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
self.sink_count = self.get_sink_count(dataset)
|
||||
ms_role = os.getenv("MS_ROLE")
|
||||
if ms_role in ("MS_PSERVER", "MS_SCHED"):
|
||||
|
@ -193,8 +198,8 @@ class _DatasetIterMSLoopSink(_DatasetIter):
|
|||
|
||||
class _DatasetIterMS(_DatasetIter):
|
||||
"""Iter for MS(enable_loop_sink=False)."""
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
if sink_size > 0:
|
||||
self.sink_count = sink_size
|
||||
else:
|
||||
|
@ -206,8 +211,8 @@ class _DatasetIterMS(_DatasetIter):
|
|||
|
||||
class _DatasetIterPSLite(_DatasetIter):
|
||||
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
|
||||
def __init__(self, dataset, sink_size):
|
||||
super().__init__(dataset, sink_size)
|
||||
def __init__(self, dataset, sink_size, epoch_num):
|
||||
super().__init__(dataset, sink_size, epoch_num)
|
||||
self.sink_count = 1
|
||||
self.sink_size = 1
|
||||
self.op = None
|
||||
|
|
|
@ -227,7 +227,7 @@ class Model:
|
|||
scaling_sens /= self._device_number
|
||||
return scaling_sens
|
||||
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1):
|
||||
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1):
|
||||
"""Initializes dataset."""
|
||||
need_wrap = False
|
||||
if dataset_sink_mode:
|
||||
|
@ -239,7 +239,7 @@ class Model:
|
|||
if not is_train:
|
||||
dataset.__loop_size__ = 1
|
||||
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size)
|
||||
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
|
||||
|
||||
# remove later to deal with loop sink
|
||||
if need_wrap:
|
||||
|
@ -399,12 +399,18 @@ class Model:
|
|||
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
|
||||
sink_size (int): Control the amount of data each sink. Default: -1.
|
||||
"""
|
||||
if sink_size == -1:
|
||||
epoch_num = epoch
|
||||
else:
|
||||
epoch_num = epoch * sink_size // train_dataset.get_dataset_size()
|
||||
|
||||
dataset_helper, train_network = self._exec_preprocess(self._train_network,
|
||||
is_train=True,
|
||||
phase='train',
|
||||
dataset=train_dataset,
|
||||
dataset_sink_mode=True,
|
||||
sink_size=sink_size)
|
||||
sink_size=sink_size,
|
||||
epoch_num=epoch_num)
|
||||
self._train_network = train_network
|
||||
cb_params.train_network = self._train_network
|
||||
cb_params.cur_step_num = 0
|
||||
|
@ -621,6 +627,8 @@ class Model:
|
|||
list_callback.step_end(run_context)
|
||||
self._update_metrics(outputs)
|
||||
|
||||
valid_dataset.reset()
|
||||
|
||||
metrics = self._get_metrics()
|
||||
cb_params.metrics = metrics
|
||||
list_callback.end(run_context)
|
||||
|
|
|
@ -58,7 +58,7 @@ class MindData:
|
|||
def create_tuple_iterator(self):
|
||||
return self.__iter__()
|
||||
|
||||
def send(self):
|
||||
def send(self, num_epochs=-1):
|
||||
pass
|
||||
|
||||
def stop_send(self):
|
||||
|
|
|
@ -15,11 +15,16 @@
|
|||
"""Dataset help for minddata dataset"""
|
||||
from mindspore._checkparam import check_bool
|
||||
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
|
||||
from mindspore.train.dataset_helper import _send_data
|
||||
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
||||
_to_full_shapes
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
|
||||
def _send_data(dataset):
|
||||
"""Engine dataset to write data to tdt queue."""
|
||||
if not hasattr(dataset, '__has_sent__'):
|
||||
exec_dataset = dataset.__TRANSFER_DATASET__
|
||||
exec_dataset.send()
|
||||
dataset.__has_sent__ = True
|
||||
|
||||
class DatasetHelper:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue