!18673 [MD] Fix crash with large max_rowsize and add argument validations

Merge pull request !18673 from harshvardhangupta/max_rowsize_fix
This commit is contained in:
i-robot 2021-06-23 13:32:17 +00:00 committed by Gitee
commit 35c1f14cf3
4 changed files with 33 additions and 4 deletions

View File

@ -266,7 +266,7 @@ def set_auto_num_workers(enable):
>>> ds.config.set_auto_num_workers(True)
"""
if not isinstance(enable, bool):
raise TypeError("enable isn't of type bool.")
raise TypeError("enable must be of type bool.")
_config.set_auto_num_workers(enable)
@ -411,6 +411,8 @@ def set_enable_shared_mem(enable):
Examples:
>>> ds.config.set_enable_shared_mem(True)
"""
if not isinstance(enable, bool):
raise TypeError("enable must be of type bool.")
_config.set_enable_shared_mem(enable)
def set_sending_batches(batch_num):
@ -429,5 +431,5 @@ def set_sending_batches(batch_num):
>>> ds.config.set_sending_batches(10)
"""
if not isinstance(batch_num, int):
raise TypeError("batch_num must be a int dtype.")
raise TypeError("batch_num must be an int dtype.")
_config.set_sending_batches(batch_num)

View File

@ -38,6 +38,8 @@ import threading
import copy
import weakref
import platform
import psutil
import numpy as np
import mindspore._c_dataengine as cde
@ -45,6 +47,7 @@ from mindspore._c_expression import typing
from mindspore import log as logger
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore.parallel._utils import _get_device_num
import mindspore.dataset.transforms.py_transforms as py_transforms
@ -2168,6 +2171,7 @@ class BatchDataset(Dataset):
num_parallel = get_num_parallel_workers()
if get_enable_shared_mem():
_check_shm_usage(num_parallel, 1, self.max_rowsize * self.batch_size, 2)
for _ in range(num_parallel):
arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
@ -2631,6 +2635,7 @@ class MapDataset(Dataset):
num_parallel = get_num_parallel_workers()
if get_enable_shared_mem():
_check_shm_usage(num_parallel, 1, self.max_rowsize, 2)
for _ in range(num_parallel):
arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
@ -3542,6 +3547,26 @@ def _fill_worker_indices(workers, indices, idx):
return idx
def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
    """
    Verify that /dev/shm has enough free space for the shared memory queues
    used when training in parallel mode.

    The check is skipped on Windows and whenever `_get_device_num()` reports
    a single device (or fewer).

    Args:
        num_worker: number of workers, each owning `num_queues` queues.
        queue_size: queue capacity; two extra slots per queue are added to
            the estimate.
        max_rowsize: per-slot size, scaled by 1024 * 1024 in the estimate
            (presumably expressed in MB - confirm against callers).
        num_queues: queues created per worker (default 1).

    Raises:
        RuntimeError: if the estimated usage reaches 80% or more of the free
            space reported for /dev/shm.
    """
    # Nothing to verify on Windows or for single-device runs.
    if platform.system() == "Windows" or _get_device_num() <= 1:
        return

    threshold_ratio = 0.8
    slots_per_queue = queue_size + 2
    shm_estimate_usage = (
        _get_device_num() * num_worker * num_queues
        * slots_per_queue * max_rowsize * 1024 * 1024
    )

    # Only the disk_usage lookup can raise FileNotFoundError, so keep the
    # try body limited to that call.
    try:
        shm_available = psutil.disk_usage('/dev/shm').free
    except FileNotFoundError:
        logger.warning("Expected /dev/shm to exist.")
        return

    if shm_estimate_usage >= threshold_ratio * shm_available:
        raise RuntimeError(
            "Insufficient shared memory available. Required: {}, Available: {}. "
            "Recommend to set_enable_shared_mem to False, reduce max_rowsize or reduce num_parallel_workers."
            .format(shm_estimate_usage, shm_available))
class SamplerFn:
"""
Multiprocessing or multithread generator function wrapper master process.
@ -3570,6 +3595,8 @@ class SamplerFn:
queue_size = min(queue_size, queue_size * 4 // num_worker)
queue_size = max(2, queue_size)
if multi_process and get_enable_shared_mem():
_check_shm_usage(num_worker, queue_size, max_rowsize)
for _ in range(num_worker):
if multi_process is True:
try:

View File

@ -391,7 +391,7 @@ def check_generatordataset(method):
raise ValueError("schema should be a path to schema file or a schema object.")
# check optional argument
nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
validate_dataset_param_value(nreq_param_int, param_dict, int)
nreq_param_list = ["column_types"]
validate_dataset_param_value(nreq_param_list, param_dict, list)

View File

@ -364,7 +364,7 @@ def test_auto_num_workers_error():
except TypeError as e:
err_msg = str(e)
assert "isn't of type bool" in err_msg
assert "must be of type bool" in err_msg
def test_auto_num_workers():