!30741 [dataset][dfx] Usability: unified Python multiprocessing rework

Merge pull request !30741 from xiefangqi/md_add_multiprocessing_optimization
This commit is contained in:
i-robot 2022-03-07 01:29:21 +00:00 committed by Gitee
commit a2303a92bd
8 changed files with 299 additions and 65 deletions

View File

@ -68,6 +68,10 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
.def("get_enable_autotune", &ConfigManager::enable_autotune)
.def("set_autotune_interval", &ConfigManager::set_autotune_interval)
.def("get_autotune_interval", &ConfigManager::autotune_interval)
.def("set_enable_watchdog", &ConfigManager::set_enable_watchdog)
.def("get_enable_watchdog", &ConfigManager::enable_watchdog)
.def("set_multiprocessing_timeout_interval", &ConfigManager::set_multiprocessing_timeout_interval)
.def("get_multiprocessing_timeout_interval", &ConfigManager::multiprocessing_timeout_interval)
.def("load", [](ConfigManager &c, const std::string &s) { THROW_IF_ERROR(c.LoadFile(s)); });
}));
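These pybind11 registrations are what make the new watchdog and timeout settings reachable from Python. Below is a minimal sketch of driving the bindings directly; it assumes the GlobalContext.config_manager() accessor that mindspore.dataset.core.config uses for the other settings (check your build if it differs):

import mindspore._c_dataengine as cde

_config = cde.GlobalContext.config_manager()
_config.set_enable_watchdog(True)                   # binding registered above
print(_config.get_enable_watchdog())                # True
_config.set_multiprocessing_timeout_interval(300)   # seconds
print(_config.get_multiprocessing_timeout_interval())

In practice callers go through the ds.config wrappers added later in this change rather than the raw bindings.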

View File

@ -56,7 +56,9 @@ ConfigManager::ConfigManager()
auto_offload_(false),
enable_autotune_(false),
save_autoconfig_(false),
autotune_interval_(kCfgAutoTuneInterval) {
autotune_interval_(kCfgAutoTuneInterval),
enable_watchdog_(true),
multiprocessing_timeout_interval_(kCfgMultiprocessingTimeoutInterval) {
autotune_json_filepath_ = kEmptyString;
num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits<uint16_t>::max();
num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_;

View File

@ -259,6 +259,22 @@ class ConfigManager {
// @param interval - autotune interval in steps
void set_autotune_interval(int64_t interval) { autotune_interval_ = interval; }
// setter function
// @param enable - Whether to enable the watchdog Python thread
void set_enable_watchdog(bool enable) { enable_watchdog_ = enable; }
// getter function
// @return - Flag to indicate whether watchdog python thread is enabled
bool enable_watchdog() const { return enable_watchdog_; }
// getter function
// @return - multiprocessing timeout interval in seconds
uint32_t multiprocessing_timeout_interval() const { return multiprocessing_timeout_interval_; }
// setter function
// @param interval - multiprocessing timeout interval in seconds
void set_multiprocessing_timeout_interval(uint32_t interval) { multiprocessing_timeout_interval_ = interval; }
private:
int32_t num_parallel_workers_;
int32_t worker_connector_size_;
@ -286,6 +302,8 @@ class ConfigManager {
bool save_autoconfig_; // True if should save AutoTune configuration
std::string autotune_json_filepath_; // Filepath name of the final AutoTune Configuration JSON file
int64_t autotune_interval_;
bool enable_watchdog_; // Watchdog python thread enabled flag
uint32_t multiprocessing_timeout_interval_; // Multiprocessing timeout interval in seconds
// Private helper function that takes a nlohmann json format and populates the settings
// @param j - The json nlohmann json info
Status FromJson(const nlohmann::json &j);

View File

@ -302,8 +302,10 @@ constexpr uint32_t kCfgOpConnectorSize = 16;
constexpr uint32_t kCfgSendingBatch = 0;
constexpr int32_t kCfgDefaultRankId = -1;
constexpr uint32_t kCfgDefaultSeed = std::mt19937::default_seed;
constexpr uint32_t kCfgMonitorSamplingInterval = 1000; // timeout value for sampling interval in milliseconds
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
constexpr uint32_t kCfgMonitorSamplingInterval = 1000; // timeout value for monitor sampling interval in
// milliseconds
constexpr uint32_t kCfgCallbackTimeout = 60; // timeout value for callback in seconds
constexpr uint32_t kCfgMultiprocessingTimeoutInterval = 300; // timeout value for multiprocessing interval in seconds
constexpr int32_t kCfgDefaultCachePort = 50052;
constexpr char kCfgDefaultCacheHost[] = "127.0.0.1";
constexpr int32_t kDftCachePrefetchSize = 20;

View File

@ -30,12 +30,19 @@ import mindspore._c_dataengine as cde
from mindspore import log as logger
from .validator_helpers import replace_none
__all__ = ['set_seed', 'get_seed', 'set_prefetch_size', 'get_prefetch_size', 'set_num_parallel_workers',
'get_num_parallel_workers', 'set_numa_enable', 'get_numa_enable', 'set_monitor_sampling_interval',
'get_monitor_sampling_interval', 'set_callback_timeout', 'get_callback_timeout',
'set_auto_num_workers', 'get_auto_num_workers', 'set_enable_shared_mem', 'get_enable_shared_mem',
'set_sending_batches', 'load', '_init_device_info', 'set_enable_autotune', 'get_enable_autotune',
'set_autotune_interval', 'get_autotune_interval']
__all__ = ['set_sending_batches', 'load', '_init_device_info',
'set_seed', 'get_seed',
'set_prefetch_size', 'get_prefetch_size',
'set_num_parallel_workers', 'get_num_parallel_workers',
'set_numa_enable', 'get_numa_enable',
'set_monitor_sampling_interval', 'get_monitor_sampling_interval',
'set_callback_timeout', 'get_callback_timeout',
'set_auto_num_workers', 'get_auto_num_workers',
'set_enable_shared_mem', 'get_enable_shared_mem',
'set_enable_autotune', 'get_enable_autotune',
'set_autotune_interval', 'get_autotune_interval',
'set_enable_watchdog', 'get_enable_watchdog',
'set_multiprocessing_timeout_interval', 'get_multiprocessing_timeout_interval']
INT32_MAX = 2147483647
UINT32_MAX = 4294967295
@ -633,3 +640,76 @@ def get_auto_offload():
>>> auto_offload = ds.config.get_auto_offload()
"""
return _config.get_auto_offload()
def set_enable_watchdog(enable):
"""
Set whether the watchdog Python thread is enabled; it is enabled by default.
The watchdog is a thread that cleans up hanging subprocesses.
Args:
enable (bool): Whether to launch a watchdog Python thread. System default: True.
Raises:
TypeError: If enable is not a boolean data type.
Examples:
>>> # Set a new global configuration value for the state of watchdog Python thread as enabled.
>>> ds.config.set_enable_watchdog(True)
"""
if not isinstance(enable, bool):
raise TypeError("enable must be a boolean dtype.")
_config.set_enable_watchdog(enable)
def get_enable_watchdog():
"""
Get the state of the watchdog Python thread (enabled or disabled).
This is the watchdog Python thread state used for all processes.
Returns:
bool, True if the watchdog Python thread is enabled.
Examples:
>>> # Get the global configuration of watchdog Python thread.
>>> watchdog_state = ds.config.get_enable_watchdog()
"""
return _config.get_enable_watchdog()
def set_multiprocessing_timeout_interval(interval):
"""
Set the default interval (in seconds) for the multiprocessing timeout used when the main process gets data from subprocesses.
Args:
interval (int): Interval (in seconds) to be used for the multiprocessing timeout when the main process gets data from
a subprocess. System default: 300.
Raises:
TypeError: If interval is not of type int.
ValueError: If interval <= 0 or interval > INT32_MAX.
Examples:
>>> # Set a new global configuration value for multiprocessing timeout when getting data.
>>> ds.config.set_multiprocessing_timeout_interval(300)
"""
if not isinstance(interval, int):
raise TypeError("interval must be of type int.")
if interval <= 0 or interval > INT32_MAX:
raise ValueError("Interval given is not within the required range (0, INT32_MAX].")
_config.set_multiprocessing_timeout_interval(interval)
def get_multiprocessing_timeout_interval():
"""
Get the global configuration of the multiprocessing timeout used when the main process gets data from subprocesses.
Returns:
int, interval (in seconds) for the multiprocessing timeout when the main process gets data from subprocesses
(default: 300).
Examples:
>>> # Get the global configuration of multiprocessing timeout when the main process gets data from subprocesses.
>>> # If set_multiprocessing_timeout_interval() has never been called, the default value (300) is returned.
>>> multiprocessing_timeout_interval = ds.config.get_multiprocessing_timeout_interval()
"""
return _config.get_multiprocessing_timeout_interval()
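Taken together, the two new knobs can be exercised end to end through ds.config; the values below are illustrative:

import mindspore.dataset as ds

# Disable the watchdog thread, e.g. while debugging subprocess lifecycles.
ds.config.set_enable_watchdog(False)
assert ds.config.get_enable_watchdog() is False

# Allow slow generators 10 minutes before the "still waiting" warning fires.
ds.config.set_multiprocessing_timeout_interval(600)
assert ds.config.get_multiprocessing_timeout_interval() == 600

# Restore the defaults.
ds.config.set_enable_watchdog(True)
ds.config.set_multiprocessing_timeout_interval(300)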

View File

@ -65,7 +65,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_rename, check_device_send, check_take, check_project, \
check_sync_wait, check_zip_dataset, check_add_column, check_concat, check_split, check_bucket_batch_by_length, \
check_save, check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
get_enable_watchdog
from ..core.datatypes import mstype_to_detype
from ..core.validator_helpers import replace_none
from ..core.py_util_helpers import ExceptionHandler
@ -227,8 +228,11 @@ def _get_operator_process():
keys = process_info.keys()
fetched_all = True
for key in keys:
op_process[key] = list(process_info[key][1])
item_full = (len(process_info[key][1]) == process_info[key][0])
try:
op_process[key] = list(process_info[key][1])
item_full = (len(process_info[key][1]) == process_info[key][0])
except KeyError as err:
raise err
fetched_all = fetched_all and item_full
return op_process, fetched_all
@ -1535,7 +1539,7 @@ class Dataset:
if self._col_names is None:
runtime_getter = self._init_tree_getters()
self._col_names = runtime_getter[0].GetColumnNames()
self.close_pool()
runtime_getter[2].close_pool()
runtime_getter[2].notify_watchdog()
return self._col_names
@ -1554,7 +1558,7 @@ class Dataset:
runtime_getter = self._init_tree_getters()
self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
self.saved_output_types = runtime_getter[0].GetOutputTypes()
self.close_pool()
runtime_getter[2].close_pool()
runtime_getter[2].notify_watchdog()
if self.dynamic_setting[0]:
self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()
@ -1575,7 +1579,7 @@ class Dataset:
runtime_getter = self._init_tree_getters()
self.saved_output_shapes = runtime_getter[0].GetOutputShapes()
self.saved_output_types = runtime_getter[0].GetOutputTypes()
self.close_pool()
runtime_getter[2].close_pool()
runtime_getter[2].notify_watchdog()
if self.dynamic_setting[0]:
self.saved_output_shapes, self.saved_min_shapes, self.saved_max_shapes = self._dynamic_output_shapes()
@ -1595,7 +1599,7 @@ class Dataset:
if self.dataset_size is None:
runtime_getter = self.__init_size_getter()
self.dataset_size = runtime_getter[0].GetDatasetSize(False)
self.close_pool()
runtime_getter[2].close_pool()
runtime_getter[2].notify_watchdog()
return self.dataset_size
@ -1743,7 +1747,7 @@ class Dataset:
if self._num_classes is None:
runtime_getter = self._init_tree_getters()
self._num_classes = runtime_getter[0].GetNumClasses()
self.close_pool()
runtime_getter[2].close_pool()
runtime_getter[2].notify_watchdog()
if self._num_classes == -1:
return None
@ -2656,6 +2660,9 @@ _LOCK = threading.Lock()
# Python multiprocessing library forbid sending lambda function through pipe.
# This init function allow us to add all Python function to a global collection and then fork afterwards.
def _pyfunc_worker_init(pyfunc_list, args_queue, ret_queue):
# Some threads in multiprocessing.Pool cannot handle the SIGINT signal and may hang,
# so ignore SIGINT here and let Ctrl+C be handled by the parent process.
signal.signal(signal.SIGINT, signal.SIG_IGN)
global _GLOBAL_PYFUNC_LIST
global _ARGS_QUEUE
global _RET_QUEUE
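The SIG_IGN line above follows the usual pool-worker pattern: workers ignore SIGINT so Ctrl+C is delivered once, to the parent, which then tears the pool down. A generic, MindSpore-independent sketch of the same pattern:

import multiprocessing
import signal

def _worker_init():
    # Workers ignore SIGINT; only the parent reacts to Ctrl+C.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def _square(x):
    return x * x

if __name__ == "__main__":
    with multiprocessing.Pool(4, initializer=_worker_init) as pool:
        try:
            print(pool.map(_square, range(8)))
        except KeyboardInterrupt:
            pool.terminate()   # the parent decides how to shut down
            pool.join()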
@ -2765,6 +2772,7 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
self.eot = None
self.watch_dog = None
self.workers = []
self.ppid = os.getpid()
self.hook = None
def Launch(self, op_id=-1):
@ -2974,19 +2982,20 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
raise TypeError("[Internal Error] The 3rd parameter of watch dog thread should be multiprocessing.Pool, " \
"but got {}".format(type(pool)))
while not eot.is_set():
subprocess_exit_num = 0
clear_subprocess_timeout = 0
# Monitor subprocess exit and get the timeout to wait before cleanup
subprocess_exit_num = _PythonMultiprocessing._monitor_subprocess_exit(workers)
clear_subprocess_timeout = _PythonMultiprocessing._monitor_subprocess_exit(workers)
# If a subprocess has exited, wait for the returned timeout and then do waitpid cleanup
if subprocess_exit_num > 0:
if clear_subprocess_timeout > 0:
if pool is not None:
# Python multiprocessing.pool has a bug: if a subprocess of the pool is killed, the pool will launch
# a new subprocess, so we have to set worker_handler._state to TERMINATE to stop relaunching.
if pool._state == RUN: # pylint: disable=W0212
pool._state = TERMINATE # pylint: disable=W0212
pool._worker_handler._state = TERMINATE # pylint: disable=W0212
pool._worker_handler.join() # pylint: disable=W0212
start = time.time()
while time.time() - start < 30:
while time.time() - start < clear_subprocess_timeout:
# We need to distinguish between get_dataset_size or training finishing normally and a hang scenario.
# If get_dataset_size or training finished normally, _stop_subprocess can be executed and
# self.need_abort can be set to True. If the main process hangs in get(), self.need_abort
@ -3002,7 +3011,8 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
else:
_PythonMultiprocessing._terminate_process(workers)
logger.critical("The subprocess of dataset may exit unexpected or be killed, "
"main process will exit.")
"main process will exit. If this is not an artificial operation, you can use "
"ds.config.set_enable_watchdog(False) to block this error.")
os.kill(os.getpid(), signal.SIGTERM)
@staticmethod
@ -3013,23 +3023,92 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
w.terminate()
for w in workers:
if w._closed is False: # pylint: disable=W0212
w.join()
# We don't use w.join() because join() can only be used in the main process; otherwise it raises an error.
w._popen.wait() # pylint: disable=W0212
# Monitor subprocess exit and return the timeout to wait before cleanup
@staticmethod
def _monitor_subprocess_exit(workers):
subprocess_exit_num = 0
"""
To monitor whether process is exit.
Args:
workers (list of multiprocessing.Process): multiprocessing.Process.
Returns:
int, the timeout(in seconds) when process exit.
"""
for w in workers:
if w.exitcode is not None:
subprocess_exit_num += 1
return subprocess_exit_num
exit_code = w.exitcode
if exit_code is not None:
# For kill -9, we can exit quickly
if exit_code == -9:
return 1
# For kill -15, we still exit after 30s
if exit_code == -15:
return 30
return 0
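_monitor_subprocess_exit leans on the multiprocessing convention that a child killed by signal N reports exitcode -N, so -9 (SIGKILL) triggers fast cleanup and -15 (SIGTERM) a 30s grace period. A standalone, POSIX-only demo of that convention:

import multiprocessing
import os
import signal
import time

def _sleep_forever():
    while True:
        time.sleep(1)

if __name__ == "__main__":
    p = multiprocessing.Process(target=_sleep_forever)
    p.start()
    os.kill(p.pid, signal.SIGKILL)   # equivalent to `kill -9 <pid>`
    p.join()
    print(p.exitcode)                # -9, so the watchdog would clean up after 1s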
# Check whether the main process is still alive
@staticmethod
def process_still_alive(ppid):
"""
We always hit deadlocks when using psutil or w.exitcode to check whether a process is still alive,
so we use os.kill(ppid, 0) as the safest way to check whether a process is still alive.
"""
try:
os.kill(ppid, 0)
except OSError:
return False
return True
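os.kill(pid, 0) delivers no signal at all; it only performs the existence and permission check, which is why it cannot deadlock the way exitcode polling can. A self-contained, POSIX-only illustration of the idiom:

import os
import subprocess

def process_still_alive(pid):
    # Signal 0 checks for existence without sending anything.
    try:
        os.kill(pid, 0)
    except OSError:
        return False
    return True

child = subprocess.Popen(["sleep", "1"])
print(process_still_alive(child.pid))   # True while the child runs
child.wait()                            # reap the child
print(process_still_alive(child.pid))   # False once it is reaped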
# When the main process exits, the subprocesses will be terminated
@staticmethod
def _clean_process(ppid, workers, pool=None):
"""
This is the execute function of clean process, if we found main process is exit, we will clean subprocesses.
:param ppid: The process id of main process.
:param workers: The list of subprocesses.
:param pool: multiprocessing.Pool object, we can get list of subprocesses from _pool.
"""
signal.signal(signal.SIGINT, signal.SIG_IGN)
while _PythonMultiprocessing.process_still_alive(ppid):
time.sleep(0.1)
if pool is not None:
# Python multiprocessing.pool has a bug: if a subprocess of the pool is killed, the pool will launch
# a new subprocess, so we have to set worker_handler._state to TERMINATE to stop relaunching.
# But this pool is not the same object as the one in the main process, so killing the main process
# and then killing a subprocess is not supported.
if pool._state == RUN: # pylint: disable=W0212
pool._state = TERMINATE # pylint: disable=W0212
pool._worker_handler._state = TERMINATE # pylint: disable=W0212
pool._worker_handler.join() # pylint: disable=W0212
if pool is not None:
_PythonMultiprocessing._terminate_process(pool._pool) # pylint: disable=W0212
else:
_PythonMultiprocessing._terminate_process(workers)
os.kill(os.getpid(), signal.SIGTERM)
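_clean_process is spawned as a separate child whose only job is to poll the parent and tear down the workers once the parent disappears. A stripped-down sketch of the same idea follows; like the code above, it relies on the fork start method (the Linux default) so the Process objects in workers are inherited rather than pickled:

import multiprocessing
import os
import time

def _cleaner(ppid, workers):
    while True:
        try:
            os.kill(ppid, 0)          # parent still alive
        except OSError:
            for w in workers:         # parent gone: terminate the siblings
                w.terminate()
            return
        time.sleep(0.1)

if __name__ == "__main__":
    workers = [multiprocessing.Process(target=time.sleep, args=(3600,))
               for _ in range(2)]
    for w in workers:
        w.start()
    cleaner = multiprocessing.Process(target=_cleaner,
                                      args=(os.getpid(), workers))
    cleaner.daemon = True   # cleaned up on graceful exit; survives kill -9
    cleaner.start()
    print("kill -9", os.getpid(), "and the sleeping workers disappear with us")
    for w in workers:
        w.join()

Daemon children are only terminated during a graceful interpreter exit, so if the parent is SIGKILLed the cleaner keeps running and does its job; that is exactly the behavior the watchdog design depends on.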
def _launch_watch_dog(self):
"""
Launch a watchdog thread and a cleaning process to clean up subprocesses when a process is killed.
The watchdog thread cleans up the subprocesses and the main process when one of the subprocesses is killed.
The cleaning process cleans up the subprocesses when the main process is killed.
"""
if platform.system().lower() != 'windows':
self.eot = threading.Event()
self.watch_dog = threading.Thread(target=self._watch_dog, args=(self.eot, self.workers, self.process_pool))
self.watch_dog.daemon = True
self.watch_dog.start()
self.cleaning_process = multiprocessing.Process(target=self._clean_process,
args=(self.ppid, self.workers, self.process_pool))
self.cleaning_process.daemon = True
self.cleaning_process.start()
if get_enable_watchdog():
self.eot = threading.Event()
self.watch_dog = threading.Thread(target=self._watch_dog,
args=(self.eot, self.workers + [self.cleaning_process],
self.process_pool))
self.watch_dog.daemon = True
self.watch_dog.start()
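_launch_watch_dog pairs a daemon thread with a threading.Event so that normal shutdown (_abort_watchdog) is just eot.set(). A minimal sketch of that pattern, with illustrative names:

import threading
import time

def watch_dog(eot, worker_died):
    while not eot.is_set():
        if worker_died():             # e.g. poll worker exit codes as above
            print("a worker died; escalate and clean up")
            return
        time.sleep(0.1)

eot = threading.Event()
t = threading.Thread(target=watch_dog, args=(eot, lambda: False))
t.daemon = True                       # never blocks interpreter exit
t.start()
# ... on normal shutdown:
eot.set()
t.join()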
def _abort_watchdog(self):
if not self.eot.is_set():
@ -3038,6 +3117,8 @@ class _PythonMultiprocessing(cde.PythonMultiprocessingRuntime):
def abort_watchdog(self):
if hasattr(self, 'watch_dog') and self.watch_dog is not None and hasattr(self, 'eot') and self.eot is not None:
self._abort_watchdog()
if hasattr(self, 'cleaning_process') and self.cleaning_process is not None:
_PythonMultiprocessing._terminate_process([self.cleaning_process])
def is_running(self):
# note here: the RUN state of python3.7 and python3.8 is different:

View File

@ -45,7 +45,8 @@ from .datasets import UnionBaseDataset, MappableDataset, Schema, to_list, _Pytho
from . import samplers
from .queue import _SharedQueue
from .validators import check_generatordataset, check_numpyslicesdataset, check_paddeddataset
from ..core.config import get_enable_shared_mem, get_prefetch_size
from ..core.config import get_enable_shared_mem, get_prefetch_size, get_multiprocessing_timeout_interval, \
get_enable_watchdog
from ..core.datatypes import mstypelist_to_detypelist
from ..core.py_util_helpers import ExceptionHandler
@ -161,7 +162,7 @@ class SamplerFn:
self.need_join = False
self.ppid = os.getpid()
self.pids = []
self.check_interval = 300 # the interval of check queue's size
self.check_interval = get_multiprocessing_timeout_interval() # the interval of check queue's size
self._final_join = True
# Event for end of epoch
@ -185,7 +186,7 @@ class SamplerFn:
for _ in range(num_worker):
if multi_process is True:
try:
worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size)
worker = _GeneratorWorkerMp(dataset, self.eof, max_rowsize, queue_size, self.ppid)
except Exception:
raise RuntimeError("Init multiprocessing.Queue() failed, This might be caused by insufficient shm,"
+ " and the recommended shm size is at least 5 GB.")
@ -200,19 +201,7 @@ class SamplerFn:
worker = _GeneratorWorkerMt(dataset, self.eof)
worker.daemon = True
self.workers.append(worker)
if multi_process is True and platform.system().lower() != 'windows':
self.eot = threading.Event()
self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog, # pylint: disable=W0212
args=(self.eot, self.workers))
self.watch_dog.daemon = True
self.watch_dog.start()
if self._final_join is True:
self._jointhread = Finalize(
self.watch_dog, self._finalize_join,
args=(weakref.ref(self.watch_dog), self.eot),
exitpriority=-5
)
self._launch_cleanup_worker(multi_process=multi_process)
def process(self, indices):
"""
@ -250,7 +239,11 @@ class SamplerFn:
if cost_time / self.check_interval >= wait_count:
wait_count += 1
logger.warning("It has been waiting for " + str(cost_time) + "s because the multi "
"thread/process of the generator generates data had been hung by gil lock.")
"thread/process of the generator generates data had been hung by gil lock. "
"Check whether the source of generator has an infinite loop operation or the "
"output data is too large. You can also set the timeout interval by "
"ds.config.set_multiprocessing_interval to adjust the output frequency of this "
"log.")
result = self.workers[i % self.num_worker].get()
if isinstance(result, ExceptionHandler):
@ -268,6 +261,33 @@ class SamplerFn:
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
yield _convert_row(result)
def _launch_cleanup_worker(self, multi_process):
"""
We need an extra thread and process to clean up if the main process or a subprocess is killed.
Args:
multi_process: Whether multiprocessing is used.
"""
if multi_process is True and platform.system().lower() != 'windows':
_clean_worker_func = _PythonMultiprocessing._clean_process # pylint: disable=W0212
self.cleaning_process = multiprocessing.Process(target=_clean_worker_func, args=(self.ppid, self.workers))
self.cleaning_process.daemon = True
self.cleaning_process.start()
if get_enable_watchdog():
self.eot = threading.Event()
self.watch_dog = threading.Thread(target=_PythonMultiprocessing._watch_dog, # pylint: disable=W0212
args=(self.eot, self.workers + [self.cleaning_process]))
self.watch_dog.daemon = True
self.watch_dog.start()
if self._final_join is True:
self._jointhread = Finalize(
self.watch_dog, self._finalize_join,
args=(weakref.ref(self.watch_dog), self.eot),
exitpriority=-5
)
def _stop_subprocess(self):
"""Only the main process can call join."""
if self.need_join is True and self.ppid == os.getpid():
@ -281,6 +301,8 @@ class SamplerFn:
def _abort_watchdog(self):
if hasattr(self, 'eot') and self.eot is not None and not self.eot.is_set():
self.eot.set()
if hasattr(self, 'cleaning_process') and self.cleaning_process is not None:
_PythonMultiprocessing._terminate_process([self.cleaning_process]) # pylint: disable=W0212
@classmethod
def _finalize_join(cls, twr, eot):
@ -306,7 +328,7 @@ def _ignore_sigint(is_multiprocessing):
signal.signal(signal.SIGINT, signal.SIG_IGN)
def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing):
def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing, ppid=-1):
"""
Multithread or multiprocess generator worker process loop.
"""
@ -318,14 +340,8 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiproces
# Fetch index, block
try:
idx = idx_queue.get(timeout=1)
except KeyboardInterrupt:
if is_multiprocessing:
eof.set()
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
raise Exception("Generator worker receives KeyboardInterrupt.")
except queue.Empty:
if eof.is_set():
if eof.is_set() or (is_multiprocessing and not _PythonMultiprocessing.process_still_alive(ppid)):
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
@ -352,14 +368,8 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiproces
while True:
try:
result_queue.put(result, timeout=5)
except KeyboardInterrupt:
if is_multiprocessing:
eof.set()
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
raise Exception("Generator worker receives KeyboardInterrupt.")
except queue.Full:
if eof.is_set():
if eof.is_set() or (is_multiprocessing and not _PythonMultiprocessing.process_still_alive(ppid)):
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
@ -407,7 +417,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
Worker process for multiprocess Generator.
"""
def __init__(self, dataset, eof, max_rowsize, queue_size):
def __init__(self, dataset, eof, max_rowsize, queue_size, ppid):
self.idx_queue = multiprocessing.Queue(queue_size)
if get_enable_shared_mem():
self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
@ -415,7 +425,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
self.res_queue = multiprocessing.Queue(queue_size)
self.idx_queue._joincancelled = True # pylint: disable=W0212
self.res_queue._joincancelled = True # pylint: disable=W0212
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid))
def put(self, item):
"""

View File

@ -394,6 +394,41 @@ def test_auto_num_workers():
assert saved_config == ds.config.get_auto_num_workers()
def test_enable_watchdog():
"""
Feature: Test the get_enable_watchdog and set_enable_watchdog functions.
Description: This new interface lets users disable the watchdog thread.
Expectation: The default state is True; after calling set_enable_watchdog, the state is updated.
"""
saved_config = ds.config.get_enable_watchdog()
assert isinstance(saved_config, bool)
assert saved_config is True
# change to a different config
flipped_config = not saved_config
ds.config.set_enable_watchdog(flipped_config)
assert flipped_config == ds.config.get_enable_watchdog()
# now flip this back
ds.config.set_enable_watchdog(saved_config)
assert saved_config == ds.config.get_enable_watchdog()
def test_multiprocessing_timeout_interval():
"""
Feature: Test the get_multiprocessing_timeout_interval and set_multiprocessing_timeout_interval functions.
Description: This new interface lets users adjust the timeout of the multiprocessing get() call.
Expectation: The default value is 300s; after calling set_multiprocessing_timeout_interval, the value is updated.
"""
saved_config = ds.config.get_multiprocessing_timeout_interval()
assert saved_config == 300
# change to a different config
flipped_config = 1000
ds.config.set_multiprocessing_timeout_interval(flipped_config)
assert flipped_config == ds.config.get_multiprocessing_timeout_interval()
# now flip this back
ds.config.set_multiprocessing_timeout_interval(saved_config)
assert saved_config == ds.config.get_multiprocessing_timeout_interval()
if __name__ == '__main__':
test_basic()
test_get_seed()
@ -406,3 +441,5 @@ if __name__ == '__main__':
test_deterministic_python_seed_multi_thread()
test_auto_num_workers_error()
test_auto_num_workers()
test_enable_watchdog()
test_multiprocessing_timeout_interval()