From 392842b9ddb8f42009516db8f429c506b27ba946 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Thu, 8 Apr 2021 21:30:36 +0800 Subject: [PATCH] fix python3.8 multi-processing performance issue --- mindspore/dataset/engine/datasets.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 0e7a1ebe0a2..6ae65c7779d 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -27,6 +27,7 @@ import signal import time import uuid import multiprocessing +from multiprocessing.pool import RUN import queue from enum import Enum from functools import partial @@ -1997,6 +1998,9 @@ class BatchDataset(Dataset): self.per_batch_map = _PythonCallable(self.per_batch_map, idx, self.process_pool) self.hook = _ExceptHookHandler() atexit.register(_mp_pool_exit_preprocess) + # If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. + if sys.version_info >= (3, 8): + atexit.register(self.process_pool.close) def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: @@ -2230,7 +2234,11 @@ class _PythonCallable: self.idx = idx def __call__(self, *args): - if self.pool is not None and self.pool._state == 0 and check_iterator_cleanup() is False: # pylint: disable=W0212 + # note here: the RUN state of python3.7 and python3.8 is different: + # python3.7: RUN = 0 + # python3.8: RUN = "RUN" + # so we use self.pool._state == RUN instead and we can't use _state == 0 any more. + if self.pool is not None and self.pool._state == RUN and check_iterator_cleanup() is False: # pylint: disable=W0212 # This call will send the tensors along with Python callable index to the process pool. # Block, yield GIL. Current thread will reacquire GIL once result is returned. result = self.pool.apply_async(_pyfunc_worker_exec, [self.idx, *args]) @@ -2384,6 +2392,9 @@ class MapDataset(Dataset): self.operations = iter_specific_operations self.hook = _ExceptHookHandler() atexit.register(_mp_pool_exit_preprocess) + # If python version greater than 3.8, we need to close ThreadPool in atexit for unclean pool teardown. + if sys.version_info >= (3, 8): + atexit.register(self.process_pool.close) def __del__(self): if hasattr(self, 'process_pool') and self.process_pool is not None: