!7600 fix GeneratorDataset timeout
Merge pull request !7600 from heleiwang/fix_generator
commit 0218f8a06f
@@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset):
         yield tuple([np.array(x, copy=False) for x in val])
 
 
-def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
+def _cpp_sampler_fn_mp(sampler, sample_fn):
     """
     Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
     """
     indices = sampler.get_indices()
-    sample_fn = SamplerFn(dataset, num_worker, multi_process)
     return sample_fn.process(indices)
 
 
-def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
+def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
     """
     Multiprocessing generator function wrapper for mappable dataset with Python sampler.
     """
     indices = _fetch_py_sampler_indices(sampler, num_samples)
-    sample_fn = SamplerFn(dataset, num_worker, multi_process)
     return sample_fn.process(indices)
 
 
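
With this change, the SamplerFn worker pool is built once by the caller and only the per-epoch indices flow through the wrapper functions. Below is a minimal, self-contained sketch of that call pattern using toy stand-ins (not MindSpore code; _ToySamplerFn, _ToyDataset and _ToySampler are invented for illustration):

import numpy as np

class _ToySamplerFn:
    """Toy stand-in for SamplerFn: owns the fetch logic, here single-threaded."""
    def __init__(self, dataset):
        self.dataset = dataset
    def process(self, indices):
        for i in indices:
            yield tuple(np.array(x, copy=False) for x in self.dataset[i])

def _toy_cpp_sampler_fn_mp(sampler, sample_fn):
    # Mirrors the new signature: the wrapper only resolves indices and delegates.
    indices = sampler.get_indices()
    return sample_fn.process(indices)

class _ToyDataset:
    def __getitem__(self, idx):
        return (np.array([idx], dtype=np.int32),)

class _ToySampler:
    def get_indices(self):
        return list(range(5))

if __name__ == "__main__":
    sample_fn = _ToySamplerFn(_ToyDataset())   # built once by the caller
    for _ in range(2):                         # two "epochs" reuse the same sample_fn
        print(list(_toy_cpp_sampler_fn_mp(_ToySampler(), sample_fn)))
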
@@ -3299,17 +3297,21 @@ class SamplerFn:
         self.multi_process = multi_process
         # Event for end of epoch
         if multi_process is True:
-            self.eoe = multiprocessing.Event()
             self.eof = multiprocessing.Event()
         else:
-            self.eoe = threading.Event()
             self.eof = threading.Event()
         # Create workers
         for _ in range(num_worker):
             if multi_process is True:
-                worker = _GeneratorWorkerMp(dataset, self.eoe)
+                worker = _GeneratorWorkerMp(dataset, self.eof)
+                worker.daemon = True
+                # When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
+                # which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
+                # In this phase, the main process is not locked.
+                worker.start()
             else:
-                worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
-            worker.daemon = True
+                worker = _GeneratorWorkerMt(dataset, self.eof)
+                worker.daemon = True
             self.workers.append(worker)
 
     def process(self, indices):
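
The comment added above explains why the multiprocess workers are now started inside __init__: a fork performed while the main process holds a lock copies the held lock into the child and can deadlock. A minimal sketch of the same start-early pattern with plain multiprocessing (toy code assuming a fork-based start method; ToyPool and _worker are invented names):

import multiprocessing
import queue
import time

def _worker(eof, work_queue):
    # Drain items until the end-of-file event is set.
    while not eof.is_set():
        try:
            item = work_queue.get(timeout=1)
        except queue.Empty:
            continue
        print("processed", item)

class ToyPool:
    def __init__(self, num_worker):
        self.eof = multiprocessing.Event()
        self.work_queue = multiprocessing.Queue()
        self.workers = []
        for _ in range(num_worker):
            w = multiprocessing.Process(target=_worker, args=(self.eof, self.work_queue), daemon=True)
            # Start right away, while the parent holds no locks, so a fork cannot
            # copy a held lock into the child (the deadlock described in the diff comment).
            w.start()
            self.workers.append(w)

if __name__ == "__main__":
    pool = ToyPool(num_worker=2)
    for i in range(4):
        pool.work_queue.put(i)
    time.sleep(2)   # give the workers time to drain the queue
    pool.eof.set()
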
@@ -3317,14 +3319,18 @@ class SamplerFn:
         The main process, start the child process or child thread, and fill the index queue.
         Get the result and return.
         """
+        for w in self.workers:
+            # Check whether the queue of the subprocess is empty.
+            if not w.queue_empty():
+                raise Exception("The queue of the subprocess is not empty.")
+            # Start all workers
+            if not w.is_alive():
+                w.start()
+
         # Fill initial index queues
         idx_cursor = 0
         idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
 
-        # Start all workers
-        for w in self.workers:
-            w.start()
-
         # Fetch results
         for i in range(len(indices)):
             # Fetch result and put index
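
process() now refuses to run on workers whose queues still hold items from a previous epoch, and it starts a worker only if it is not already alive (the multiprocess workers are already running after __init__). A small stand-alone sketch of that guard (toy code; ToyWorker and start_epoch are invented names):

import queue
import threading

class ToyWorker(threading.Thread):
    """Minimal stand-in for _GeneratorWorkerMt: index/result queues plus queue_empty()."""
    def __init__(self, eof):
        super().__init__(daemon=True)
        self.idx_queue = queue.Queue(16)
        self.res_queue = queue.Queue(16)
        self.eof = eof

    def run(self):
        self.eof.wait()  # placeholder body; the real workers run _generator_worker_loop

    def queue_empty(self):
        return self.idx_queue.empty() and self.res_queue.empty()

def start_epoch(workers):
    # Same guard as the new SamplerFn.process(): refuse to reuse workers with stale
    # queue items, and only start threads that are not already running.
    for w in workers:
        if not w.queue_empty():
            raise Exception("The queue of the subprocess is not empty.")
        if not w.is_alive():
            w.start()

if __name__ == "__main__":
    eof = threading.Event()
    workers = [ToyWorker(eof) for _ in range(2)]
    start_epoch(workers)   # first call starts the threads
    start_epoch(workers)   # second call is a no-op: queues empty, threads alive
    eof.set()
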
@@ -3340,64 +3346,31 @@ class SamplerFn:
                 raise Exception("Generator worker receives KeyboardInterrupt")
             if idx_cursor < len(indices):
                 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
-            # Set end-of-epoch (eoe) event once all indices are sent
-            if idx_cursor == len(indices) and not self.eoe.is_set():
-                self.eoe.set()
             yield tuple([np.array(x, copy=False) for x in result])
 
     def __del__(self):
-        self.eoe.set()
-        if self.multi_process is False:
-            self.eof.set()
-        for w in self.workers:
-            w.join()
+        self.eof.set()
 
 
-def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
+def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
     """
-    Multiprocessing generator worker process loop
+    Multithread or multiprocess generator worker process loop.
     """
     while True:
         # Fetch index, block
         try:
-            idx = idx_queue.get()
-        except KeyboardInterrupt:
-            raise Exception("Generator worker receives KeyboardInterrupt")
-        if idx is None:
-            # When the queue is out of scope from master process, a None item can be fetched from the queue.
-            # Upon receiving None, worker process should check if EOE is set.
-            assert eoe.is_set(), ""
-            return
-        # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
-        result = dataset[idx]
-        # Send data, block
-        try:
-            result_queue.put(result)
-        except KeyboardInterrupt:
-            raise Exception("Generator worker receives KeyboardInterrupt")
-        del result, idx
-
-
-def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
-    """
-    Multithread generator worker process loop.
-    """
-    while True:
-        # Fetch index, block
-        try:
-            # Index is generated very fast, so the timeout is very short
-            idx = idx_queue.get(timeout=0.01)
+            idx = idx_queue.get(timeout=1)
         except KeyboardInterrupt:
             raise Exception("Generator worker receives KeyboardInterrupt")
         except queue.Empty:
-            if eof.is_set() or eoe.is_set():
+            if eof.is_set():
                 return
-            # If end-of-epoch (eoe) or end-of-file (eof) is not set, continue to get data from idx_queue
+            # If end-of-file (eof) is not set, continue to get data from idx_queue
             continue
         if idx is None:
             # When the queue is out of scope from master process, a None item can be fetched from the queue.
-            # Upon receiving None, worker process should check if EOE is set.
-            assert eoe.is_set(), ""
+            # Upon receiving None, worker process should check if eof is set.
+            assert eof.is_set(), ""
            return
         if eof.is_set():
             return
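
The two worker loops are merged into a single _generator_worker_loop driven only by the eof event: the index queue is now polled with a 1-second timeout instead of 0.01 s, and on queue.Empty the worker keeps waiting unless eof is set, which is what removes the premature exits and master-side timeouts this pull request fixes. A minimal runnable sketch of the same loop shape (toy code, not the MindSpore implementation):

import queue
import threading
import numpy as np

def toy_worker_loop(dataset, idx_queue, result_queue, eof):
    """Follows the merged loop: poll with a 1 s timeout, exit only on eof or None."""
    while True:
        try:
            idx = idx_queue.get(timeout=1)
        except queue.Empty:
            if eof.is_set():
                return
            continue  # eof not set yet: keep waiting for indices
        if idx is None or eof.is_set():
            return
        result_queue.put(dataset[idx])

if __name__ == "__main__":
    data = [np.array([i]) for i in range(3)]
    idx_q, res_q, eof = queue.Queue(), queue.Queue(), threading.Event()
    t = threading.Thread(target=toy_worker_loop, args=(data, idx_q, res_q, eof), daemon=True)
    t.start()
    for i in range(3):
        idx_q.put(i)
    print([res_q.get(timeout=30) for _ in range(3)])
    eof.set()
    t.join()
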
@@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
                 continue
             break
         del result, idx
-        if eoe.is_set() and idx_queue.empty():
-            return
 
 
 class _GeneratorWorkerMt(threading.Thread):
@@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread):
     Worker process for multithread Generator.
     """
 
-    def __init__(self, dataset, eoe, eof):
+    def __init__(self, dataset, eof):
         self.idx_queue = queue.Queue(16)
         self.res_queue = queue.Queue(16)
-        super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
+        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
 
     def put(self, item):
         """
@@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread):
         """
         return self.res_queue.get(timeout=30)
 
+    def queue_empty(self):
+        if not self.idx_queue.empty():
+            logger.error("idx_queue is not empty")
+            return False
+        if not self.res_queue.empty():
+            logger.error("res_queue is not empty")
+            return False
+        return True
+
 
 class _GeneratorWorkerMp(multiprocessing.Process):
     """
     Worker process for multiprocess Generator.
     """
 
-    def __init__(self, dataset, eoe):
+    def __init__(self, dataset, eof):
         self.idx_queue = multiprocessing.Queue(16)
         self.res_queue = multiprocessing.Queue(16)
-        super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe))
+        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
 
     def put(self, item):
         """
@@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process):
         # when we run too many iterators with infinite epoch(num_epoch=-1)
         return self.res_queue.get(timeout=30)
 
+    def queue_empty(self):
+        if not self.idx_queue.empty():
+            logger.error("idx_queue is not empty")
+            return False
+        if not self.res_queue.empty():
+            logger.error("res_queue is not empty")
+            return False
+        return True
+
     def __del__(self):
         # Try to destruct here, sometimes the class itself will be destructed in advance,
         # so "self" will be a NoneType
@@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset):
                 sampler_instance.set_num_rows(len(self.source))
                 sampler_instance.initialize()
                 if new_op.num_parallel_workers > 1:
-                    new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
-                                                                new_op.num_parallel_workers,
-                                                                self.python_multiprocessing))
+                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                    new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, sample_fn))
                 else:
                     new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
             else:
                 if new_op.num_parallel_workers > 1:
-                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
-                                                               new_op.num_parallel_workers,
-                                                               self.python_multiprocessing))
+                    sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing)
+                    new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, sample_fn))
                 else:
                     new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
         else:
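
For reference, a hypothetical end-to-end usage that exercises the multi-worker path wired up above; the GeneratorDataset parameters (column_names, num_parallel_workers, python_multiprocessing) follow the API referenced in this file, and MyDataset is an invented example source:

import numpy as np
import mindspore.dataset as ds

class MyDataset:
    """User-defined random-accessible source for GeneratorDataset."""
    def __init__(self, n=100):
        self.data = [np.ones((2, 2), dtype=np.float32) * i for i in range(n)]
    def __getitem__(self, index):
        return (self.data[index],)
    def __len__(self):
        return len(self.data)

dataset = ds.GeneratorDataset(MyDataset(), column_names=["data"],
                              num_parallel_workers=4,       # > 1 selects the SamplerFn path
                              python_multiprocessing=True)  # workers are processes, not threads
for item in dataset.create_tuple_iterator():
    pass  # each item is fetched through SamplerFn.process()
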