forked from mindspore-Ecosystem/mindspore
!10137 fix GeneratorDataset multiprocessing hangs
From: @heleiwang Reviewed-by: Signed-off-by:
commit ea2cabcfec
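The hang this commit addresses only occurs on the multiprocessing path of GeneratorDataset, i.e. when python_multiprocessing is enabled and several parallel workers are spawned. A minimal sketch of that path for context (the RandomAccessSource class, column name and shapes below are illustrative, not part of this commit):

import numpy as np
import mindspore.dataset as ds

class RandomAccessSource:
    """Illustrative user dataset: anything with __getitem__ and __len__ works."""
    def __init__(self, length=64):
        self._data = np.random.randn(length, 32).astype(np.float32)

    def __getitem__(self, index):
        return (self._data[index],)

    def __len__(self):
        return len(self._data)

# python_multiprocessing=True is the code path touched here: SamplerFn spawns
# one _GeneratorWorkerMp process per parallel worker and feeds it indices.
data = ds.GeneratorDataset(RandomAccessSource(), column_names=["data"],
                           num_parallel_workers=4, python_multiprocessing=True)

for _ in data.create_tuple_iterator():
    pass  # finishing or interrupting iteration here could previously leave workers hanging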
@@ -22,10 +22,12 @@ import glob
 import json
 import math
 import os
+import signal
 import uuid
 import multiprocessing
 import queue
 from enum import Enum
+from functools import partial
 from importlib import import_module
 import sys
 import threading
@@ -3447,6 +3449,7 @@ class SamplerFn:
         self.workers = []
         self.num_worker = num_worker
         self.multi_process = multi_process
+        self.joined = False
         # Event for end of epoch
         if multi_process is True:
             self.eof = multiprocessing.Event()
@@ -3485,29 +3488,47 @@ class SamplerFn:
         # Fetch results
         for i in range(len(indices)):
+            if self.eof.is_set():
+                self._stop_subprocess()
+                return
             # Fetch result and put index
             try:
                 result = self.workers[i % self.num_worker].get()
             except queue.Empty:
+                self._stop_subprocess()
                 raise Exception("Generator worker process timeout.")
             except KeyboardInterrupt:
                 self.eof.set()
                 for w in self.workers:
                     w.terminate()
                     w.join()
+                self._stop_subprocess()
                 raise Exception("Generator worker receives KeyboardInterrupt.")
             if self.eof.is_set():
+                self._stop_subprocess()
                 return
             if idx_cursor < len(indices):
                 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
             yield tuple([np.array(x, copy=False) for x in result])

-    def __del__(self):
+    def _stop_subprocess(self):
         self.eof.set()
+        if self.joined is False:
+            for w in self.workers:
+                w.join()
+            self.joined = True
+
+    def __del__(self):
+        self._stop_subprocess()


-def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
+def _subprocess_handle(eof, signum, frame):
+    logger.info("The subprocess receives a termination signal.")
+    eof.set()
+
+
+def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing):
     """
     Multithread or multiprocess generator worker process loop.
     """
+    if is_multiprocessing:
+        signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
     while True:
         # Fetch index, block
         try:
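Read together, the pieces above form the shutdown handshake: the parent's _stop_subprocess() sets the shared eof event and joins every worker exactly once (guarded by self.joined), while each child installs _subprocess_handle as its SIGTERM handler, so that on Unix a terminate() also just sets eof and lets the loop drain out instead of killing the process mid-queue-operation. A standalone sketch of the same event-plus-signal pattern (illustrative only, not MindSpore code):

import signal
import queue
import multiprocessing
from functools import partial

def _handle_sigterm(eof, signum, frame):
    # Same idea as _subprocess_handle: turn SIGTERM into "please exit the loop".
    eof.set()

def worker(eof, task_queue):
    signal.signal(signal.SIGTERM, partial(_handle_sigterm, eof))
    while not eof.is_set():
        try:
            task_queue.get(timeout=1)
        except queue.Empty:
            continue  # re-check eof instead of blocking forever

if __name__ == "__main__":
    eof = multiprocessing.Event()
    tasks = multiprocessing.Queue()
    child = multiprocessing.Process(target=worker, args=(eof, tasks))
    child.start()
    eof.set()      # orderly stop: the worker observes eof and returns
    child.join()   # joins promptly, no hang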
@@ -3516,6 +3537,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
             raise Exception("Generator worker receives KeyboardInterrupt.")
         except queue.Empty:
             if eof.is_set():
+                if is_multiprocessing:
+                    idx_queue.cancel_join_thread()
+                    result_queue.cancel_join_thread()
                 return
             # If end-of-file (eof) is not set, continue to get data from idx_queue
             continue
@@ -3525,6 +3549,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
             assert eof.is_set(), ""
             return
         if eof.is_set():
+            if is_multiprocessing:
+                idx_queue.cancel_join_thread()
+                result_queue.cancel_join_thread()
             return
         # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
         result = dataset[idx]
@@ -3536,6 +3563,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
                 raise Exception("Generator worker receives KeyboardInterrupt.")
             except queue.Full:
                 if eof.is_set():
+                    if is_multiprocessing:
+                        idx_queue.cancel_join_thread()
+                        result_queue.cancel_join_thread()
                     return
                 # If eof is not set, continue to put data to result_queue
                 continue
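The cancel_join_thread() calls in the three hunks above are what actually breaks the reported hang: a process that has put data on a multiprocessing.Queue normally blocks at exit until the queue's feeder thread has flushed everything into the pipe, and if the parent is no longer reading, that flush never completes. Cancelling the join lets the worker exit once eof is set, at the cost of discarding whatever is still buffered. A minimal illustration of that standard-library behaviour (not code from this commit):

import multiprocessing

def worker(q):
    # Put more data than the underlying pipe can hold, so some of it stays
    # buffered in the queue's feeder thread.
    q.put(b"x" * (1 << 20))
    # Without this call the child would block at exit, waiting for a reader
    # that never comes; with it, the buffered data is abandoned and exit is clean.
    q.cancel_join_thread()

if __name__ == "__main__":
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=worker, args=(q,))
    p.start()
    p.join()  # returns promptly because the feeder-thread join was cancelled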
@@ -3551,7 +3581,7 @@ class _GeneratorWorkerMt(threading.Thread):
     def __init__(self, dataset, eof):
         self.idx_queue = queue.Queue(16)
         self.res_queue = queue.Queue(16)
-        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
+        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False))

     def put(self, item):
         """
@@ -3567,10 +3597,10 @@ class _GeneratorWorkerMt(threading.Thread):
     def queue_empty(self):
         if not self.idx_queue.empty():
-            logger.error("idx_queue is not empty")
+            logger.warning("idx_queue is not empty")
             return False
         if not self.res_queue.empty():
-            logger.error("res_queue is not empty")
+            logger.warning("res_queue is not empty")
             return False
         return True
@@ -3583,7 +3613,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
     def __init__(self, dataset, eof):
         self.idx_queue = multiprocessing.Queue(16)
         self.res_queue = multiprocessing.Queue(16)
-        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof))
+        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))

     def put(self, item):
         """
@@ -3601,21 +3631,13 @@ class _GeneratorWorkerMp(multiprocessing.Process):
     def queue_empty(self):
         if not self.idx_queue.empty():
-            logger.error("idx_queue is not empty.")
+            logger.warning("idx_queue is not empty.")
             return False
         if not self.res_queue.empty():
-            logger.error("res_queue is not empty.")
+            logger.warning("res_queue is not empty.")
             return False
         return True
-
-    def __del__(self):
-        # Try to destruct here, sometimes the class itself will be destructed in advance,
-        # so "self" will be a NoneType
-        try:
-            self.terminate()
-        except AttributeError:
-            pass


 class GeneratorDataset(MappableDataset):
     """