fix coredump with large max_rowsize and add arg validations

Harshvardhan Gupta 2021-06-21 12:37:15 -04:00
parent f98497ca09
commit 6845a16b44
4 changed files with 33 additions and 4 deletions

View File

@@ -266,7 +266,7 @@ def set_auto_num_workers(enable):
         >>> ds.config.set_auto_num_workers(True)
     """
     if not isinstance(enable, bool):
-        raise TypeError("enable isn't of type bool.")
+        raise TypeError("enable must be of type bool.")
     _config.set_auto_num_workers(enable)
@@ -411,6 +411,8 @@ def set_enable_shared_mem(enable):
     Examples:
         >>> ds.config.set_enable_shared_mem(True)
     """
+    if not isinstance(enable, bool):
+        raise TypeError("enable must be of type bool.")
     _config.set_enable_shared_mem(enable)

 def set_sending_batches(batch_num):
@@ -429,5 +431,5 @@ def set_sending_batches(batch_num):
         >>> ds.config.set_sending_batches(10)
     """
     if not isinstance(batch_num, int):
-        raise TypeError("batch_num must be a int dtype.")
+        raise TypeError("batch_num must be an int dtype.")
     _config.set_sending_batches(batch_num)
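
A minimal sketch of the validation these config changes add, assuming mindspore.dataset is importable as ds (the string argument below is only an illustrative bad value): a non-bool argument to set_enable_shared_mem is now rejected in the Python layer instead of being passed straight down to _config.

import mindspore.dataset as ds

# Accepted: shared-memory queues between dataset worker processes stay enabled.
ds.config.set_enable_shared_mem(True)

# Rejected after this commit: the new isinstance check raises
# TypeError("enable must be of type bool.") before _config is touched.
try:
    ds.config.set_enable_shared_mem("yes")
except TypeError as err:
    print(err)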

View File

@@ -38,6 +38,8 @@ import threading
 import copy
 import weakref
+import platform
+import psutil

 import numpy as np

 import mindspore._c_dataengine as cde
@@ -45,6 +47,7 @@ from mindspore._c_expression import typing
 from mindspore import log as logger
 from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
+from mindspore.parallel._utils import _get_device_num
 import mindspore.dataset.transforms.py_transforms as py_transforms
@@ -2168,6 +2171,7 @@ class BatchDataset(Dataset):
         num_parallel = get_num_parallel_workers()
         if get_enable_shared_mem():
+            _check_shm_usage(num_parallel, 1, self.max_rowsize * self.batch_size, 2)
             for _ in range(num_parallel):
                 arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
                 res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
@@ -2631,6 +2635,7 @@ class MapDataset(Dataset):
         num_parallel = get_num_parallel_workers()
         if get_enable_shared_mem():
+            _check_shm_usage(num_parallel, 1, self.max_rowsize, 2)
             for _ in range(num_parallel):
                 arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
                 res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
@@ -3542,6 +3547,26 @@ def _fill_worker_indices(workers, indices, idx):
     return idx


+def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
+    """
+    Check sufficient shared memory is available for shared memory queues
+    when training in parallel mode.
+    """
+    threshold_ratio = 0.8
+    if platform.system() != "Windows" and _get_device_num() > 1:
+        shm_estimate_usage = _get_device_num() * num_worker * num_queues * \
+            (queue_size + 2) * max_rowsize * 1024 * 1024
+        try:
+            shm_available = psutil.disk_usage('/dev/shm').free
+            if shm_estimate_usage >= threshold_ratio * shm_available:
+                raise RuntimeError(
+                    "Insufficient shared memory available. Required: {}, Available: {}. "
+                    "Recommend to set_enable_shared_mem to False, reduce max_rowsize or reduce num_parallel_workers."
+                    .format(shm_estimate_usage, shm_available))
+        except FileNotFoundError:
+            logger.warning("Expected /dev/shm to exist.")
+
+
 class SamplerFn:
     """
     Multiprocessing or multithread generator function wrapper master process.
@@ -3570,6 +3595,8 @@ class SamplerFn:
         queue_size = min(queue_size, queue_size * 4 // num_worker)
         queue_size = max(2, queue_size)
+        if multi_process and get_enable_shared_mem():
+            _check_shm_usage(num_worker, queue_size, max_rowsize)
         for _ in range(num_worker):
             if multi_process is True:
                 try:
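
For orientation, a self-contained sketch of the estimate _check_shm_usage performs; the device count, worker count, and max_rowsize below are made-up illustrative numbers, whereas the real function reads the device count from _get_device_num() and the free space from /dev/shm (Linux-only, as in the check itself).

import psutil

# Hypothetical values, chosen only to make the arithmetic concrete.
device_num = 8      # what _get_device_num() might return in parallel training
num_worker = 8      # num_parallel_workers
num_queues = 2      # BatchDataset/MapDataset use one arg queue and one result queue
queue_size = 1
max_rowsize = 16    # MB per row, as passed to the dataset op

# Mirrors the formula in the diff: the estimate is in bytes.
shm_estimate_usage = device_num * num_worker * num_queues * \
    (queue_size + 2) * max_rowsize * 1024 * 1024
print(shm_estimate_usage / 1024 ** 3, "GiB")   # 6.0 GiB for these numbers

# The new check raises RuntimeError once the estimate reaches 80% of free /dev/shm.
shm_available = psutil.disk_usage('/dev/shm').free
print(shm_estimate_usage >= 0.8 * shm_available)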

View File

@@ -391,7 +391,7 @@ def check_generatordataset(method):
             raise ValueError("schema should be a path to schema file or a schema object.")

         # check optional argument
-        nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
+        nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
         validate_dataset_param_value(nreq_param_int, param_dict, int)
         nreq_param_list = ["column_types"]
         validate_dataset_param_value(nreq_param_list, param_dict, list)
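
A hedged usage note for the validator change: with "max_rowsize" added to nreq_param_int, a GeneratorDataset built with a non-integer max_rowsize should now fail validation up front rather than carrying a bad size into the shared-memory queues. The generator and column name below are placeholders, and the max_rowsize keyword on GeneratorDataset is assumed here as implied by check_generatordataset.

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(3):
        yield (np.array(i),)

# Expected to raise TypeError from validate_dataset_param_value,
# because max_rowsize must be an int.
data = ds.GeneratorDataset(gen, column_names=["data"], max_rowsize=16.5)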

View File

@@ -364,7 +364,7 @@ def test_auto_num_workers_error():
     except TypeError as e:
         err_msg = str(e)
-        assert "isn't of type bool" in err_msg
+        assert "must be of type bool" in err_msg


 def test_auto_num_workers():