!10137 fix GeneratorDataset multiprocessing hangs

From: @heleiwang
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2020-12-21 19:04:27 +08:00 committed by Gitee
commit ea2cabcfec
1 changed file with 42 additions and 20 deletions

View File

@ -22,10 +22,12 @@ import glob
import json
import math
import os
import signal
import uuid
import multiprocessing
import queue
from enum import Enum
from functools import partial
from importlib import import_module
import sys
import threading
@ -3447,6 +3449,7 @@ class SamplerFn:
self.workers = []
self.num_worker = num_worker
self.multi_process = multi_process
self.joined = False
# Event for end of epoch
if multi_process is True:
self.eof = multiprocessing.Event()
@ -3485,29 +3488,47 @@ class SamplerFn:
# Fetch results
for i in range(len(indices)):
if self.eof.is_set():
self._stop_subprocess()
return
# Fetch result and put index
try:
result = self.workers[i % self.num_worker].get()
except queue.Empty:
self._stop_subprocess()
raise Exception("Generator worker process timeout.")
except KeyboardInterrupt:
self.eof.set()
for w in self.workers:
w.terminate()
w.join()
self._stop_subprocess()
raise Exception("Generator worker receives KeyboardInterrupt.")
if self.eof.is_set():
self._stop_subprocess()
return
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
yield tuple([np.array(x, copy=False) for x in result])
def __del__(self):
def _stop_subprocess(self):
    """
    Signal end-of-epoch to the workers and wait for them to finish.

    The shared `eof` event is set on every call. The workers are joined only
    the first time; `self.joined` records that the join has already happened,
    so repeated calls do not join again.
    """
    self.eof.set()
    if not self.joined:
        for w in self.workers:
            w.join()
        self.joined = True

def __del__(self):
    """
    Stop the workers when the object is being destroyed.
    """
    # __del__ also runs for an instance whose __init__ raised part-way through.
    # Without `eof` there is nothing to signal, and calling _stop_subprocess()
    # would only raise AttributeError inside the finalizer.
    if getattr(self, "eof", None) is not None:
        self._stop_subprocess()
def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
def _subprocess_handle(eof, signum, frame):
    """
    Signal handler body: log that a termination signal arrived and set `eof`.

    Args:
        eof: Event-like object whose `set()` is called.
        signum: Signal number passed in by the signal machinery (not used).
        frame: Stack frame passed in by the signal machinery (not used).
    """
    termination_notice = "The subprocess receives a termination signal."
    logger.info(termination_notice)
    # Raising the flag is all the handler does; the code that polls `eof`
    # decides when to return.
    eof.set()
def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing):
"""
Multithread or multiprocess generator worker process loop.
"""
if is_multiprocessing:
signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
while True:
# Fetch index, block
try:
@ -3516,6 +3537,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
raise Exception("Generator worker receives KeyboardInterrupt.")
except queue.Empty:
if eof.is_set():
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
return
# If end-of-file (eof) is not set, continue to get data from idx_queue
continue
@ -3525,6 +3549,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
assert eof.is_set(), ""
return
if eof.is_set():
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx]
@ -3536,6 +3563,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
raise Exception("Generator worker receives KeyboardInterrupt.")
except queue.Full:
if eof.is_set():
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
return
# If eof is not set, continue to put data to result_queue
continue
@ -3551,7 +3581,7 @@ class _GeneratorWorkerMt(threading.Thread):
def __init__(self, dataset, eof):
    """
    Create the index/result queues and bind the worker loop as the thread target.

    Args:
        dataset: Object passed through to `_generator_worker_loop`.
        eof: End-of-epoch event passed through to `_generator_worker_loop`.
    """
    # Both queues are bounded to 16 entries.
    self.idx_queue = queue.Queue(16)
    self.res_queue = queue.Queue(16)
    # Initialise the thread exactly once. The trailing False is the
    # `is_multiprocessing` argument of the five-parameter worker loop; a
    # four-argument tuple would no longer match that signature.
    super().__init__(
        target=_generator_worker_loop,
        args=(dataset, self.idx_queue, self.res_queue, eof, False),
    )
def put(self, item):
"""
@ -3567,10 +3597,10 @@ class _GeneratorWorkerMt(threading.Thread):
def queue_empty(self):
    """
    Check whether both the index queue and the result queue are empty.

    Returns:
        bool, True when both queues are empty. False as soon as one of them
        still holds items; a single warning naming that queue is logged.
    """
    if not self.idx_queue.empty():
        logger.warning("idx_queue is not empty")
        return False
    if not self.res_queue.empty():
        logger.warning("res_queue is not empty")
        return False
    return True
@ -3583,7 +3613,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
def __init__(self, dataset, eof):
    """
    Create the index/result queues and bind the worker loop as the process target.

    Args:
        dataset: Object passed through to `_generator_worker_loop`.
        eof: End-of-epoch event passed through to `_generator_worker_loop`.
    """
    # Both queues are bounded to 16 entries.
    self.idx_queue = multiprocessing.Queue(16)
    self.res_queue = multiprocessing.Queue(16)
    # Initialise the process exactly once. The trailing True is the
    # `is_multiprocessing` argument of the five-parameter worker loop; a
    # four-argument tuple would no longer match that signature.
    super().__init__(
        target=_generator_worker_loop,
        args=(dataset, self.idx_queue, self.res_queue, eof, True),
    )
def put(self, item):
"""
@ -3601,21 +3631,13 @@ class _GeneratorWorkerMp(multiprocessing.Process):
def queue_empty(self):
    """
    Check whether both the index queue and the result queue are empty.

    Returns:
        bool, True when both queues are empty. False as soon as one of them
        still holds items; a single warning naming that queue is logged.
    """
    if not self.idx_queue.empty():
        logger.warning("idx_queue is not empty.")
        return False
    if not self.res_queue.empty():
        logger.warning("res_queue is not empty.")
        return False
    return True
def __del__(self):
    """
    Best-effort termination of the worker when the object is destroyed.
    """
    # NOTE(review): the enclosing hunk shrinks from 21 to 13 lines, which
    # suggests this method is what the commit removes - confirm before
    # relying on it.
    # Destruction order is not guaranteed: by the time this runs, "self" may
    # already be torn down (the original authors report it showing up as a
    # NoneType), so a missing attribute is deliberately ignored.
    try:
        self.terminate()
    except AttributeError:
        return
class GeneratorDataset(MappableDataset):
"""