forked from mindspore-Ecosystem/mindspore
!30741 [dataset][dfx] Usability: unified rework of Python multiprocessing
Merge pull request !30741 from xiefangqi/md_add_multiprocessing_optimization
commit a2303a92bd
@@ -68,6 +68,10 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
                     .def("get_enable_autotune", &ConfigManager::enable_autotune)
                     .def("set_autotune_interval", &ConfigManager::set_autotune_interval)
                     .def("get_autotune_interval", &ConfigManager::autotune_interval)
+                    .def("set_enable_watchdog", &ConfigManager::set_enable_watchdog)
+                    .def("get_enable_watchdog", &ConfigManager::enable_watchdog)
+                    .def("set_multiprocessing_timeout_interval", &ConfigManager::set_multiprocessing_timeout_interval)
+                    .def("get_multiprocessing_timeout_interval", &ConfigManager::multiprocessing_timeout_interval)
                     .def("load", [](ConfigManager &c, const std::string &s) { THROW_IF_ERROR(c.LoadFile(s)); });
                   }));
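For orientation, these ConfigManager bindings back the Python wrappers added in core/config.py further down in this diff. A minimal sketch of driving them directly from Python; the access path through GlobalContext.config_manager() is an assumption based on how the existing config module obtains its handle, it is not shown in this diff:

# Hypothetical low-level access; normally you would go through mindspore.dataset.config.
import mindspore._c_dataengine as cde

_config = cde.GlobalContext.config_manager()              # assumed access path
_config.set_enable_watchdog(True)
print(_config.get_enable_watchdog())                      # True
_config.set_multiprocessing_timeout_interval(300)
print(_config.get_multiprocessing_timeout_interval())     # 300
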
@@ -56,7 +56,9 @@ ConfigManager::ConfigManager()
       auto_offload_(false),
       enable_autotune_(false),
       save_autoconfig_(false),
-      autotune_interval_(kCfgAutoTuneInterval) {
+      autotune_interval_(kCfgAutoTuneInterval),
+      enable_watchdog_(true),
+      multiprocessing_timeout_interval_(kCfgMultiprocessingTimeoutInterval) {
   autotune_json_filepath_ = kEmptyString;
   num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits<uint16_t>::max();
   num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_;
@@ -259,6 +259,22 @@ class ConfigManager {
   // @param interval - autotune interval in steps
   void set_autotune_interval(int64_t interval) { autotune_interval_ = interval; }
 
+  // setter function
+  // @param enable - To enable watchdog python thread
+  void set_enable_watchdog(bool enable) { enable_watchdog_ = enable; }
+
+  // getter function
+  // @return - Flag to indicate whether watchdog python thread is enabled
+  bool enable_watchdog() const { return enable_watchdog_; }
+
+  // getter function
+  // @return - multiprocessing timeout interval in seconds
+  uint32_t multiprocessing_timeout_interval() const { return multiprocessing_timeout_interval_; }
+
+  // setter function
+  // @param interval - multiprocessing timeout interval in seconds
+  void set_multiprocessing_timeout_interval(uint32_t interval) { multiprocessing_timeout_interval_ = interval; }
+
  private:
   int32_t num_parallel_workers_;
   int32_t worker_connector_size_;
@@ -286,6 +302,8 @@ class ConfigManager {
   bool save_autoconfig_;                // True if should save AutoTune configuration
   std::string autotune_json_filepath_;  // Filepath name of the final AutoTune Configuration JSON file
   int64_t autotune_interval_;
+  bool enable_watchdog_;                       // Watchdog python thread enabled flag
+  uint32_t multiprocessing_timeout_interval_;  // Multiprocessing timeout interval in seconds
   // Private helper function that takes a nlohmann json format and populates the settings
   // @param j - The json nlohmann json info
   Status FromJson(const nlohmann::json &j);
@@ -302,8 +302,10 @@ constexpr uint32_t kCfgOpConnectorSize = 16;
 constexpr uint32_t kCfgSendingBatch = 0;
 constexpr int32_t kCfgDefaultRankId = -1;
 constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
-constexpr uint32_t kCfgMonitorSamplingInterval = 1000;  // timeout value for sampling interval in milliseconds
-constexpr uint32_t kCfgCallbackTimeout = 60;            // timeout value for callback in seconds
+constexpr uint32_t kCfgMonitorSamplingInterval = 1000;        // timeout value for monitor sampling interval in
+                                                              // milliseconds
+constexpr uint32_t kCfgCallbackTimeout = 60;                  // timeout value for callback in seconds
+constexpr uint32_t kCfgMultiprocessingTimeoutInterval = 300;  // timeout value for multiprocessing interval in seconds
 constexpr int32_t kCfgDefaultCachePort = 50052;
 constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
 constexpr int32_t kDftCachePrefetchSize = 20;
@@ -30,12 +30,19 @@ import mindspore._c_dataengine as cde
 from mindspore import log as logger
 from .validator_helpers import replace_none
 
-__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
-           'get_num_parallel_workers', 'set_numa_enable', 'get_numa_enable', 'set_monitor_sampling_interval',
-           'get_monitor_sampling_interval', 'set_callback_timeout', 'get_callback_timeout',
-           'set_auto_num_workers', 'get_auto_num_workers', 'set_enable_shared_mem', 'get_enable_shared_mem',
-           'set_sending_batches', 'load', '_init_device_info', 'set_enable_autotune', 'get_enable_autotune',
-           'set_autotune_interval', 'get_autotune_interval']
+__all__ = ['set_sending_batches', 'load', '_init_device_info',
+           'set_seed', 'get_seed',
+           'set_prefetch_size', 'get_prefetch_size',
+           'set_num_parallel_workers', 'get_num_parallel_workers',
+           'set_numa_enable', 'get_numa_enable',
+           'set_monitor_sampling_interval', 'get_monitor_sampling_interval',
+           'set_callback_timeout', 'get_callback_timeout',
+           'set_auto_num_workers', 'get_auto_num_workers',
+           'set_enable_shared_mem', 'get_enable_shared_mem',
+           'set_enable_autotune', 'get_enable_autotune',
+           'set_autotune_interval', 'get_autotune_interval',
+           'set_enable_watchdog', 'get_enable_watchdog',
+           'set_multiprocessing_timeout_interval', 'get_multiprocessing_timeout_interval']
 
 INT32_MAX = 2147483647
 UINT32_MAX = 4294967295
@@ -633,3 +640,76 @@ def get_auto_offload():
         >>> auto_offload = ds.config.get_auto_offload()
     """
     return _config.get_auto_offload()
+
+
+def set_enable_watchdog(enable):
+    """
+    Set whether the watchdog Python thread is enabled; it is enabled by default.
+    The watchdog is a thread which cleans up hanging subprocesses.
+
+    Args:
+        enable (bool): Whether to launch a watchdog Python thread. System default: True.
+
+    Raises:
+        TypeError: If enable is not a boolean data type.
+
+    Examples:
+        >>> # Set a new global configuration value for the state of the watchdog Python thread.
+        >>> ds.config.set_enable_watchdog(True)
+    """
+    if not isinstance(enable, bool):
+        raise TypeError("enable must be a boolean dtype.")
+    _config.set_enable_watchdog(enable)
+
+
+def get_enable_watchdog():
+    """
+    Get the state of the watchdog Python thread (enabled or disabled).
+    This is the default watchdog Python thread state used for all processes.
+
+    Returns:
+        bool, the default state of the watchdog Python thread.
+
+    Examples:
+        >>> # Get the global configuration of the watchdog Python thread.
+        >>> watchdog_state = ds.config.get_enable_watchdog()
+    """
+    return _config.get_enable_watchdog()
+
+
+def set_multiprocessing_timeout_interval(interval):
+    """
+    Set the default interval (in seconds) for multiprocessing timeout when the main process gets data from subprocesses.
+
+    Args:
+        interval (int): Interval (in seconds) to be used for multiprocessing timeout when the main process gets data
+            from a subprocess. System default: 300s.
+
+    Raises:
+        ValueError: If interval is invalid, that is, interval <= 0 or interval > INT32_MAX.
+
+    Examples:
+        >>> # Set a new global configuration value for multiprocessing timeout when getting data.
+        >>> ds.config.set_multiprocessing_timeout_interval(300)
+    """
+    if not isinstance(interval, int):
+        raise ValueError("interval isn't of type int.")
+    if interval <= 0 or interval > INT32_MAX:
+        raise ValueError("Interval given is not within the required range (0, INT32_MAX).")
+    _config.set_multiprocessing_timeout_interval(interval)
+
+
+def get_multiprocessing_timeout_interval():
+    """
+    Get the global configuration of multiprocessing timeout when the main process gets data from subprocesses.
+
+    Returns:
+        int, interval (in seconds) for multiprocessing timeout when the main process gets data from subprocesses
+        (default is 300s).
+
+    Examples:
+        >>> # Get the global configuration of multiprocessing timeout when the main process gets data from subprocesses.
+        >>> # If set_multiprocessing_timeout_interval() was never called before, the default value (300) is returned.
+        >>> multiprocessing_timeout_interval = ds.config.get_multiprocessing_timeout_interval()
+    """
+    return _config.get_multiprocessing_timeout_interval()
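Taken together, a minimal usage sketch of the four interfaces added above (the concrete values are only illustrative):

import mindspore.dataset as ds

# Watchdog thread: enabled by default; it cleans up hanging subprocesses.
ds.config.set_enable_watchdog(False)
assert ds.config.get_enable_watchdog() is False
ds.config.set_enable_watchdog(True)

# Interval (seconds) used when the main process waits on data from worker
# subprocesses; the system default is 300.
ds.config.set_multiprocessing_timeout_interval(600)
assert ds.config.get_multiprocessing_timeout_interval() == 600
ds.config.set_multiprocessing_timeout_interval(300)
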
@@ -65,7 +65,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
     check_rename, check_device_send, check_take, check_project, \
     check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
     check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send
-from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers
+from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
+    get_enable_watchdog
 from ..core.datatypes import mstype_to_detype
 from ..core.validator_helpers import replace_none
 from ..core.py_util_helpers import ExceptionHandler
@@ -227,8 +228,11 @@ def _get_operator_process():
     keys = process_info.keys()
     fetched_all = True
     for key in keys:
-        op_process[key] = list(process_info[key][1])
-        item_full = (len(process_info[key][1]) == process_info[key][0])
+        try:
+            op_process[key] = list(process_info[key][1])
+            item_full = (len(process_info[key][1]) == process_info[key][0])
+        except KeyError as err:
+            raise err
         fetched_all = fetched_all and item_full
     return op_process, fetched_all
@@ -1535,7 +1539,7 @@ class Dataset:
         if self._col_names is None:
             runtime_getter = self._init_tree_getters()
             self._col_names = runtime_getter[0].GetColumnNames()
             self.close_pool()
-            runtime_getter[2].close_pool()
+            runtime_getter[2].notify_watchdog()
         return self._col_names

@@ -1554,7 +1558,7 @@ class Dataset:
             runtime_getter = self._init_tree_getters()
             self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
             self.saved_output_types = runtime_getter[0].GetOutputTypes()
             self.close_pool()
-            runtime_getter[2].close_pool()
+            runtime_getter[2].notify_watchdog()
         if self.dynamic_setting[0]:
             self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()

@@ -1575,7 +1579,7 @@ class Dataset:
             runtime_getter = self._init_tree_getters()
             self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
             self.saved_output_types = runtime_getter[0].GetOutputTypes()
             self.close_pool()
-            runtime_getter[2].close_pool()
+            runtime_getter[2].notify_watchdog()
         if self.dynamic_setting[0]:
             self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()

@@ -1595,7 +1599,7 @@ class Dataset:
         if self.dataset_size is None:
             runtime_getter = self.__init_size_getter()
             self.dataset_size = runtime_getter[0].GetDatasetSize(False)
             self.close_pool()
-            runtime_getter[2].close_pool()
+            runtime_getter[2].notify_watchdog()
         return self.dataset_size

@@ -1743,7 +1747,7 @@ class Dataset:
         if self._num_classes is None:
             runtime_getter = self._init_tree_getters()
             self._num_classes = runtime_getter[0].GetNumClasses()
             self.close_pool()
-            runtime_getter[2].close_pool()
+            runtime_getter[2].notify_watchdog()
         if self._num_classes == -1:
             return None
@@ -2656,6 +2660,9 @@ _LOCK = threading.Lock()
 # Python multiprocessing library forbid sending lambda function through pipe.
 # This init function allow us to add all Python function to a global collection and then fork afterwards.
 def _pyfunc_worker_init(pyfunc_list, args_queue, ret_queue):
+    # Some threads in multiprocess.pool can't process sigint signal,
+    # and will occur hang problem, so ctrl+c will pass to parent process.
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
     global _GLOBAL_PYFUNC_LIST
     global _ARGS_QUEUE
     global _RET_QUEUE
@@ -2765,6 +2772,7 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
         self.eot = None
         self.watch_dog = None
         self.workers = []
+        self.ppid = os.getpid()
         self.hook = None
 
     def Launch(self, op_id=-1):
@@ -2974,19 +2982,20 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
             raise TypeError("[Internal Error] The 3rd parameter of watch dog thread should be multiprocessing.Pool, " \
                             "but got {}".format(type(pool)))
         while not eot.is_set():
-            subprocess_exit_num = 0
+            clear_subprocess_timeout = 0
             # Monitoring and count how many subprocesses already exit
-            subprocess_exit_num = _PythonMultiprocessing._monitor_subprocess_exit(workers)
+            clear_subprocess_timeout = _PythonMultiprocessing._monitor_subprocess_exit(workers)
             # If find subprocess exit, we will wait for 30s and do some waitpid operations
-            if subprocess_exit_num > 0:
+            if clear_subprocess_timeout > 0:
                 if pool is not None:
                     # Python multiprocessing.pool has a bug, if sub process of pool is killed, pool will launch
                     # a new sub process, so we have to set worker_handler._state to TERMINATE to stop relaunching.
                     if pool._state == RUN:  # pylint: disable=W0212
                         pool._state = TERMINATE  # pylint: disable=W0212
                         pool._worker_handler._state = TERMINATE  # pylint: disable=W0212
                         pool._worker_handler.join()  # pylint: disable=W0212
                 start = time.time()
-                while time.time() - start < 30:
+                while time.time() - start < clear_subprocess_timeout:
                     # We need to distinguishing get_dataset_size or train finished normally and hang scenario.
                     # If get_dataset_size or train finished normally, _stop_subprocess can be execute and
                     # self.need_abort can be set to True. If main process is hang in get(), self.need_abort
@@ -3002,7 +3011,8 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
                 else:
                     _PythonMultiprocessing._terminate_process(workers)
                 logger.critical("The subprocess of dataset may exit unexpected or be killed, "
-                                "main process will exit.")
+                                "main process will exit. If this is not an artificial operation, you can use "
+                                "ds.config.set_enable_watchdog(False) to block this error.")
                 os.kill(os.getpid(), signal.SIGTERM)
 
     @staticmethod
@@ -3013,23 +3023,92 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
                 w.terminate()
         for w in workers:
             if w._closed is False:  # pylint: disable=W0212
-                w.join()
+                # We don't use w.join because join can only be used in the main process, or join will raise an error.
+                w._popen.wait()  # pylint: disable=W0212
 
     # Monitor the exit number of subprocesses
     @staticmethod
     def _monitor_subprocess_exit(workers):
-        subprocess_exit_num = 0
+        """
+        Monitor whether any worker process has exited.
+
+        Args:
+            workers (list of multiprocessing.Process): the worker subprocesses.
+
+        Returns:
+            int, the timeout (in seconds) to use once a process has exited.
+        """
         for w in workers:
-            if w.exitcode is not None:
-                subprocess_exit_num += 1
-        return subprocess_exit_num
+            exit_code = w.exitcode
+            if exit_code is not None:
+                # For kill -9, we can exit quickly
+                if exit_code == -9:
+                    return 1
+                # For kill -15, we still exit after 30s
+                if exit_code == -15:
+                    return 30
+        return 0
+
+    # Monitor the exit status of main process
+    @staticmethod
+    def process_still_alive(ppid):
+        """
+        We always hit dead lock when we use psutil or w.exitcode to check whether a process is still alive. So we use
+        os.kill(ppid, 0) as the best solution when we want to check whether a process is still alive.
+        """
+        try:
+            os.kill(ppid, 0)
+        except OSError:
+            return False
+        return True
+
+    # When the main process exits, subprocesses will be terminated
+    @staticmethod
+    def _clean_process(ppid, workers, pool=None):
+        """
+        This is the execute function of the cleaning process: once the main process has exited, clean up the
+        subprocesses.
+
+        :param ppid: The process id of main process.
+        :param workers: The list of subprocesses.
+        :param pool: multiprocessing.Pool object, we can get list of subprocesses from _pool.
+        """
+        signal.signal(signal.SIGINT, signal.SIG_IGN)
+        while _PythonMultiprocessing.process_still_alive(ppid):
+            time.sleep(0.1)
+        if pool is not None:
+            # Python multiprocessing.pool has a bug, if sub process of pool is killed, pool will launch
+            # a new sub process, so we have to set worker_handler._state to TERMINATE to stop relaunching.
+            # But this pool is not the same object as it in main process, so we don't support kill main process then
+            # kill subprocess.
+            if pool._state == RUN:  # pylint: disable=W0212
+                pool._state = TERMINATE  # pylint: disable=W0212
+                pool._worker_handler._state = TERMINATE  # pylint: disable=W0212
+                pool._worker_handler.join()  # pylint: disable=W0212
+        if pool is not None:
+            _PythonMultiprocessing._terminate_process(pool._pool)  # pylint: disable=W0212
+        else:
+            _PythonMultiprocessing._terminate_process(workers)
+        os.kill(os.getpid(), signal.SIGTERM)
 
     def _launch_watch_dog(self):
+        """
+        Launch a watchdog thread and a cleaning subprocess so that stray processes are cleaned up when something is
+        killed. The watchdog thread cleans up the subprocesses and the main process when one of the subprocesses is
+        killed. The cleaning subprocess cleans up the subprocesses when the main process is killed.
+        """
         if platform.system().lower() != 'windows':
-            self.eot = threading.Event()
-            self.watch_dog = threading.Thread(target=self._watch_dog, args=(self.eot, self.workers, self.process_pool))
-            self.watch_dog.daemon = True
-            self.watch_dog.start()
+            self.cleaning_process = multiprocessing.Process(target=self._clean_process,
+                                                            args=(self.ppid, self.workers, self.process_pool))
+            self.cleaning_process.daemon = True
+            self.cleaning_process.start()
+
+            if get_enable_watchdog():
+                self.eot = threading.Event()
+                self.watch_dog = threading.Thread(target=self._watch_dog,
+                                                  args=(self.eot, self.workers + [self.cleaning_process],
+                                                        self.process_pool))
+                self.watch_dog.daemon = True
+                self.watch_dog.start()
 
     def _abort_watchdog(self):
         if not self.eot.is_set():
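The two mechanisms added above amount to: (1) a daemonic cleaning subprocess that polls the parent pid with os.kill(ppid, 0), which sends no signal and only checks that the process exists, and terminates orphaned workers once the parent is gone; (2) an optional watchdog daemon thread in the parent that aborts the main process when a worker was killed by a signal. A condensed, self-contained sketch of this pattern under POSIX (matching the platform guard above); names are illustrative, not the MindSpore internals:

import os
import signal
import threading
import time
import multiprocessing


def cleaner(ppid, worker_pids):
    """Child process: once the parent disappears, terminate the workers by pid."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    while True:
        try:
            os.kill(ppid, 0)                      # POSIX liveness probe; sends no signal
        except OSError:
            for pid in worker_pids:
                try:
                    os.kill(pid, signal.SIGTERM)  # terminate orphaned workers
                except OSError:
                    pass
            break
        time.sleep(0.1)


def watch_dog(eot, procs):
    """Daemon thread in the parent: abort if any watched process was killed by a signal."""
    while not eot.is_set():
        if any(p.exitcode is not None and p.exitcode < 0 for p in procs):
            os.kill(os.getpid(), signal.SIGTERM)
        eot.wait(0.1)


def launch(workers, enable_watchdog=True):
    """Start the cleaning subprocess and, optionally, the watchdog thread."""
    ppid = os.getpid()
    cleaning = multiprocessing.Process(target=cleaner,
                                       args=(ppid, [w.pid for w in workers]),
                                       daemon=True)
    cleaning.start()
    eot, dog = threading.Event(), None
    if enable_watchdog:
        dog = threading.Thread(target=watch_dog, args=(eot, workers + [cleaning]), daemon=True)
        dog.start()
    return eot, cleaning, dog


if __name__ == '__main__':
    procs = [multiprocessing.Process(target=time.sleep, args=(2,)) for _ in range(2)]
    for p in procs:
        p.start()
    eot, cleaning, dog = launch(procs)
    for p in procs:
        p.join()
    eot.set()   # normal shutdown: stop the watchdog; the daemonic cleaner exits with the interpreter
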
@@ -3038,6 +3117,8 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
     def abort_watchdog(self):
         if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None:
             self._abort_watchdog()
+        if hasattr(self, 'cleaning_process') and self.cleaning_process is not None:
+            _PythonMultiprocessing._terminate_process([self.cleaning_process])
 
     def is_running(self):
         # note here: the RUN state of python3.7 and python3.8 is different:
@@ -45,7 +45,8 @@ from .datasets import UnionBaseDataset, MappableDataset, Schema, to_list, _Pytho
 from . import samplers
 from .queue import _SharedQueue
 from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset
-from ..core.config import get_enable_shared_mem, get_prefetch_size
+from ..core.config import get_enable_shared_mem, get_prefetch_size, get_multiprocessing_timeout_interval, \
+    get_enable_watchdog
 from ..core.datatypes import mstypelist_to_detypelist
 from ..core.py_util_helpers import ExceptionHandler
@@ -161,7 +162,7 @@ class SamplerFn:
         self.need_join = False
         self.ppid = os.getpid()
         self.pids = []
-        self.check_interval = 300  # the interval of check queue's size
+        self.check_interval = get_multiprocessing_timeout_interval()  # the interval of check queue's size
         self._final_join = True
 
         # Event for end of epoch
@@ -185,7 +186,7 @@ class SamplerFn:
         for _ in range(num_worker):
             if multi_process is True:
                 try:
-                    worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size)
+                    worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size, self.ppid)
                 except Exception:
                     raise RuntimeError("Init multiprocessing.Queue() failed, This might be caused by insufficient shm,"
                                        + " and the recommended shm size is at least 5 GB.")
@@ -200,19 +201,7 @@ class SamplerFn:
                 worker = _GeneratorWorkerMt(dataset, self.eof)
                 worker.daemon = True
             self.workers.append(worker)
-        if multi_process is True and platform.system().lower() != 'windows':
-            self.eot = threading.Event()
-            self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog,  # pylint: disable=W0212
-                                              args=(self.eot, self.workers))
-            self.watch_dog.daemon = True
-            self.watch_dog.start()
-
-            if self._final_join is True:
-                self._jointhread = Finalize(
-                    self.watch_dog, self._finalize_join,
-                    args=(weakref.ref(self.watch_dog), self.eot),
-                    exitpriority=-5
-                )
+        self._launch_cleanup_worker(multi_process=multi_process)
 
     def process(self, indices):
         """
@@ -250,7 +239,11 @@ class SamplerFn:
                 if cost_time / self.check_interval >= wait_count:
                     wait_count += 1
                     logger.warning("It has been waiting for " + str(cost_time) + "s because the multi "
-                                   "thread/process of the generator generates data had been hung by gil lock.")
+                                   "thread/process of the generator generates data had been hung by gil lock. "
+                                   "Check whether the source of generator has an infinite loop operation or the "
+                                   "output data is too large. You can also set the timeout interval by "
+                                   "ds.config.set_multiprocessing_timeout_interval to adjust the output frequency "
+                                   "of this log.")
 
             result = self.workers[i % self.num_worker].get()
             if isinstance(result, ExceptionHandler):
@@ -268,6 +261,33 @@ class SamplerFn:
                 idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
             yield _convert_row(result)
 
+    def _launch_cleanup_worker(self, multi_process):
+        """
+        We need an extra thread and process in case the main process or a subprocess is killed.
+
+        Args:
+            multi_process: Whether to use multiprocessing.
+        """
+        if multi_process is True and platform.system().lower() != 'windows':
+            _clean_worker_func = _PythonMultiprocessing._clean_process  # pylint: disable=W0212
+            self.cleaning_process = multiprocessing.Process(target=_clean_worker_func, args=(self.ppid, self.workers))
+            self.cleaning_process.daemon = True
+            self.cleaning_process.start()
+
+            if get_enable_watchdog():
+                self.eot = threading.Event()
+                self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog,  # pylint: disable=W0212
+                                                  args=(self.eot, self.workers + [self.cleaning_process]))
+                self.watch_dog.daemon = True
+                self.watch_dog.start()
+
+                if self._final_join is True:
+                    self._jointhread = Finalize(
+                        self.watch_dog, self._finalize_join,
+                        args=(weakref.ref(self.watch_dog), self.eot),
+                        exitpriority=-5
+                    )
+
     def _stop_subprocess(self):
         """Only the main process can call join."""
         if self.need_join is True and self.ppid == os.getpid():
@@ -281,6 +301,8 @@ class SamplerFn:
     def _abort_watchdog(self):
         if hasattr(self, 'eot') and self.eot is not None and not self.eot.is_set():
             self.eot.set()
+        if hasattr(self, 'cleaning_process') and self.cleaning_process is not None:
+            _PythonMultiprocessing._terminate_process([self.cleaning_process])  # pylint: disable=W0212
 
     @classmethod
     def _finalize_join(cls, twr, eot):
@@ -306,7 +328,7 @@ def _ignore_sigint(is_multiprocessing):
         signal.signal(signal.SIGINT, signal.SIG_IGN)
 
 
-def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing):
+def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, ppid=-1):
     """
     Multithread or multiprocess generator worker process loop.
     """
@@ -318,14 +340,8 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiproces
         # Fetch index, block
         try:
             idx = idx_queue.get(timeout=1)
-        except KeyboardInterrupt:
-            if is_multiprocessing:
-                eof.set()
-                idx_queue.cancel_join_thread()
-                result_queue.cancel_join_thread()
-            raise Exception("Generator worker receives KeyboardInterrupt.")
         except queue.Empty:
-            if eof.is_set():
+            if eof.is_set() or (is_multiprocessing and not _PythonMultiprocessing.process_still_alive(ppid)):
                 if is_multiprocessing:
                     idx_queue.cancel_join_thread()
                     result_queue.cancel_join_thread()
@@ -352,14 +368,8 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiproces
         while True:
             try:
                 result_queue.put(result, timeout=5)
-            except KeyboardInterrupt:
-                if is_multiprocessing:
-                    eof.set()
-                    idx_queue.cancel_join_thread()
-                    result_queue.cancel_join_thread()
-                raise Exception("Generator worker receives KeyboardInterrupt.")
             except queue.Full:
-                if eof.is_set():
+                if eof.is_set() or (is_multiprocessing and not _PythonMultiprocessing.process_still_alive(ppid)):
                     if is_multiprocessing:
                         idx_queue.cancel_join_thread()
                         result_queue.cancel_join_thread()
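On the worker side, the new ppid parameter lets the loop notice that its parent has vanished: whenever a queue operation times out, the worker probes the parent pid and drains out instead of blocking forever. A stripped-down sketch of that pattern (illustrative only, not the actual _generator_worker_loop; POSIX semantics for os.kill are assumed):

import os
import queue
import multiprocessing


def parent_alive(ppid):
    """Return True if the parent process still exists (signal 0 sends nothing)."""
    try:
        os.kill(ppid, 0)
    except OSError:
        return False
    return True


def worker_loop(idx_queue, result_queue, ppid):
    while True:
        try:
            idx = idx_queue.get(timeout=1)
        except queue.Empty:
            if not parent_alive(ppid):          # orphaned: stop instead of waiting forever
                idx_queue.cancel_join_thread()
                result_queue.cancel_join_thread()
                return
            continue
        if idx is None:                         # sentinel: normal shutdown
            return
        result_queue.put(idx * idx)


if __name__ == '__main__':
    idx_q, res_q = multiprocessing.Queue(), multiprocessing.Queue()
    w = multiprocessing.Process(target=worker_loop, args=(idx_q, res_q, os.getpid()))
    w.start()
    idx_q.put(3)
    print(res_q.get())   # 9
    idx_q.put(None)      # normal shutdown
    w.join()
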
@@ -407,7 +417,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
     Worker process for multiprocess Generator.
     """
 
-    def __init__(self, dataset, eof, max_rowsize, queue_size):
+    def __init__(self, dataset, eof, max_rowsize, queue_size, ppid):
         self.idx_queue = multiprocessing.Queue(queue_size)
         if get_enable_shared_mem():
             self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
@@ -415,7 +425,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
             self.res_queue = multiprocessing.Queue(queue_size)
         self.idx_queue._joincancelled = True  # pylint: disable=W0212
         self.res_queue._joincancelled = True  # pylint: disable=W0212
-        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))
+        super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid))
 
     def put(self, item):
         """
@@ -394,6 +394,41 @@ def test_auto_num_workers():
     assert saved_config == ds.config.get_auto_num_workers()
 
 
+def test_enable_watchdog():
+    """
+    Feature: Test the function of get_enable_watchdog and set_enable_watchdog.
+    Description: We add this new interface so we can close the watchdog thread.
+    Expectation: The default state is True; when set_enable_watchdog is executed, the state is updated.
+    """
+    saved_config = ds.config.get_enable_watchdog()
+    assert isinstance(saved_config, bool)
+    assert saved_config is True
+    # change to a different config
+    flipped_config = not saved_config
+    ds.config.set_enable_watchdog(flipped_config)
+    assert flipped_config == ds.config.get_enable_watchdog()
+    # now flip this back
+    ds.config.set_enable_watchdog(saved_config)
+    assert saved_config == ds.config.get_enable_watchdog()
+
+
+def test_multiprocessing_timeout_interval():
+    """
+    Feature: Test the function of get_multiprocessing_timeout_interval and set_multiprocessing_timeout_interval.
+    Description: We add this new interface so we can adjust the timeout of the multiprocessing get function.
+    Expectation: The default value is 300s; when set_multiprocessing_timeout_interval is executed, the value is updated.
+    """
+    saved_config = ds.config.get_multiprocessing_timeout_interval()
+    assert saved_config == 300
+    # change to a different config
+    flipped_config = 1000
+    ds.config.set_multiprocessing_timeout_interval(flipped_config)
+    assert flipped_config == ds.config.get_multiprocessing_timeout_interval()
+    # now flip this back
+    ds.config.set_multiprocessing_timeout_interval(saved_config)
+    assert saved_config == ds.config.get_multiprocessing_timeout_interval()
+
+
 if __name__ == '__main__':
     test_basic()
     test_get_seed()
@@ -406,3 +441,5 @@ if __name__ == '__main__':
     test_deterministic_python_seed_multi_thread()
     test_auto_num_workers_error()
     test_auto_num_workers()
+    test_enable_watchdog()
+    test_multiprocessing_timeout_interval()