From ef9363fd21e7208bdb85b30ed0eecbe4a2dc03f3 Mon Sep 17 00:00:00 2001 From: heleiwang Date: Thu, 17 Dec 2020 19:42:11 +0800 Subject: [PATCH] fix GeneratorDataset multiprocessing hangs --- mindspore/dataset/engine/datasets.py | 62 +++++++++++++++++++--------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 2eb232d6b7d..3fc53a7f5f8 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -22,10 +22,12 @@ import glob import json import math import os +import signal import uuid import multiprocessing import queue from enum import Enum +from functools import partial from importlib import import_module import sys import threading @@ -3443,6 +3445,7 @@ class SamplerFn: self.workers = [] self.num_worker = num_worker self.multi_process = multi_process + self.joined = False # Event for end of epoch if multi_process is True: self.eof = multiprocessing.Event() @@ -3481,29 +3484,47 @@ class SamplerFn: # Fetch results for i in range(len(indices)): + if self.eof.is_set(): + self._stop_subprocess() + return # Fetch result and put index try: result = self.workers[i % self.num_worker].get() except queue.Empty: + self._stop_subprocess() raise Exception("Generator worker process timeout.") except KeyboardInterrupt: - self.eof.set() - for w in self.workers: - w.terminate() - w.join() + self._stop_subprocess() raise Exception("Generator worker receives KeyboardInterrupt.") + if self.eof.is_set(): + self._stop_subprocess() + return if idx_cursor < len(indices): idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) yield tuple([np.array(x, copy=False) for x in result]) - def __del__(self): + def _stop_subprocess(self): self.eof.set() + if self.joined is False: + for w in self.workers: + w.join() + self.joined = True + + def __del__(self): + self._stop_subprocess() -def _generator_worker_loop(dataset, idx_queue, result_queue, eof): +def _subprocess_handle(eof, signum, frame): + logger.info("The subprocess receives a termination signal.") + eof.set() + + +def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing): """ Multithread or multiprocess generator worker process loop. """ + if is_multiprocessing: + signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof)) while True: # Fetch index, block try: @@ -3512,6 +3533,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof): raise Exception("Generator worker receives KeyboardInterrupt.") except queue.Empty: if eof.is_set(): + if is_multiprocessing: + idx_queue.cancel_join_thread() + result_queue.cancel_join_thread() return # If end-of-file (eof) is not set, continue to get data from idx_queue continue @@ -3521,6 +3545,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof): assert eof.is_set(), "" return if eof.is_set(): + if is_multiprocessing: + idx_queue.cancel_join_thread() + result_queue.cancel_join_thread() return # Fetch data, any exception from __getitem__ will terminate worker and timeout master process result = dataset[idx] @@ -3532,6 +3559,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof): raise Exception("Generator worker receives KeyboardInterrupt.") except queue.Full: if eof.is_set(): + if is_multiprocessing: + idx_queue.cancel_join_thread() + result_queue.cancel_join_thread() return # If eof is not set, continue to put data to result_queue continue @@ -3547,7 +3577,7 @@ class _GeneratorWorkerMt(threading.Thread): def __init__(self, dataset, eof): self.idx_queue = queue.Queue(16) self.res_queue = queue.Queue(16) - super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False)) def put(self, item): """ @@ -3563,10 +3593,10 @@ class _GeneratorWorkerMt(threading.Thread): def queue_empty(self): if not self.idx_queue.empty(): - logger.error("idx_queue is not empty") + logger.warning("idx_queue is not empty") return False if not self.res_queue.empty(): - logger.error("res_queue is not empty") + logger.warning("res_queue is not empty") return False return True @@ -3579,7 +3609,7 @@ class _GeneratorWorkerMp(multiprocessing.Process): def __init__(self, dataset, eof): self.idx_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16) - super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True)) def put(self, item): """ @@ -3597,21 +3627,13 @@ class _GeneratorWorkerMp(multiprocessing.Process): def queue_empty(self): if not self.idx_queue.empty(): - logger.error("idx_queue is not empty.") + logger.warning("idx_queue is not empty.") return False if not self.res_queue.empty(): - logger.error("res_queue is not empty.") + logger.warning("res_queue is not empty.") return False return True - def __del__(self): - # Try to destruct here, sometimes the class itself will be destructed in advance, - # so "self" will be a NoneType - try: - self.terminate() - except AttributeError: - pass - class GeneratorDataset(MappableDataset): """