!6916 [MD] Fix timeout of GeneratorDataset multiprocessing
Merge pull request !6916 from luoyang/son_r1.0
This commit is contained in:
commit
40568790e5
|
@ -3256,14 +3256,13 @@ class SamplerFn:
|
|||
# Event for end of epoch
|
||||
if multi_process is True:
|
||||
self.eoe = multiprocessing.Event()
|
||||
self.eof = multiprocessing.Event()
|
||||
else:
|
||||
self.eoe = threading.Event()
|
||||
self.eof = threading.Event()
|
||||
# Create workers
|
||||
for _ in range(num_worker):
|
||||
if multi_process is True:
|
||||
worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof)
|
||||
worker = _GeneratorWorkerMp(dataset, self.eoe)
|
||||
else:
|
||||
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
|
||||
worker.daemon = True
|
||||
|
@ -3304,15 +3303,40 @@ class SamplerFn:
|
|||
|
||||
def __del__(self):
|
||||
self.eoe.set()
|
||||
self.eof.set()
|
||||
if self.multi_process is False:
|
||||
self.eof.set()
|
||||
for w in self.workers:
|
||||
w.join()
|
||||
|
||||
|
||||
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
|
||||
def _generator_worker_loop_mp(dataset, idx_queue, result_queue, eoe):
|
||||
"""
|
||||
Multiprocessing or multithread generator worker process loop.
|
||||
Multiprocessing generator worker process loop
|
||||
"""
|
||||
while True:
|
||||
# Fetch index, block
|
||||
try:
|
||||
idx = idx_queue.get()
|
||||
except KeyboardInterrupt:
|
||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||
if idx is None:
|
||||
# When the queue is out of scope from master process, a None item can be fetched from the queue.
|
||||
# Upon receiving None, worker process should check if EOE is set.
|
||||
assert eoe.is_set(), ""
|
||||
return
|
||||
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
|
||||
result = dataset[idx]
|
||||
# Send data, block
|
||||
try:
|
||||
result_queue.put(result)
|
||||
except KeyboardInterrupt:
|
||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||
del result, idx
|
||||
|
||||
|
||||
def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
|
||||
"""
|
||||
Multithread generator worker process loop.
|
||||
"""
|
||||
while True:
|
||||
# Fetch index, block
|
||||
|
@ -3360,7 +3384,7 @@ class _GeneratorWorkerMt(threading.Thread):
|
|||
def __init__(self, dataset, eoe, eof):
|
||||
self.idx_queue = queue.Queue(16)
|
||||
self.res_queue = queue.Queue(16)
|
||||
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
||||
super().__init__(target=_generator_worker_loop_mt, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
||||
|
||||
def put(self, item):
|
||||
"""
|
||||
|
@ -3380,10 +3404,10 @@ class _GeneratorWorkerMp(multiprocessing.Process):
|
|||
Worker process for multiprocess Generator.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, eoe, eof):
|
||||
def __init__(self, dataset, eoe):
|
||||
self.idx_queue = multiprocessing.Queue(16)
|
||||
self.res_queue = multiprocessing.Queue(16)
|
||||
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
||||
super().__init__(target=_generator_worker_loop_mp, args=(dataset, self.idx_queue, self.res_queue, eoe))
|
||||
|
||||
def put(self, item):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue