!18673 [MD] Fix crash with large max_rowsize and add argument validations
Merge pull request !18673 from harshvardhangupta/max_rowsize_fix
This commit is contained in:
commit
35c1f14cf3
|
@ -266,7 +266,7 @@ def set_auto_num_workers(enable):
|
|||
>>> ds.config.set_auto_num_workers(True)
|
||||
"""
|
||||
if not isinstance(enable, bool):
|
||||
raise TypeError("enable isn't of type bool.")
|
||||
raise TypeError("enable must be of type bool.")
|
||||
_config.set_auto_num_workers(enable)
|
||||
|
||||
|
||||
|
@ -411,6 +411,8 @@ def set_enable_shared_mem(enable):
|
|||
Examples:
|
||||
>>> ds.config.set_enable_shared_mem(True)
|
||||
"""
|
||||
if not isinstance(enable, bool):
|
||||
raise TypeError("enable must be of type bool.")
|
||||
_config.set_enable_shared_mem(enable)
|
||||
|
||||
def set_sending_batches(batch_num):
|
||||
|
@ -429,5 +431,5 @@ def set_sending_batches(batch_num):
|
|||
>>> ds.config.set_sending_batches(10)
|
||||
"""
|
||||
if not isinstance(batch_num, int):
|
||||
raise TypeError("batch_num must be a int dtype.")
|
||||
raise TypeError("batch_num must be an int dtype.")
|
||||
_config.set_sending_batches(batch_num)
|
||||
|
|
|
@ -38,6 +38,8 @@ import threading
|
|||
|
||||
import copy
|
||||
import weakref
|
||||
import platform
|
||||
import psutil
|
||||
import numpy as np
|
||||
|
||||
import mindspore._c_dataengine as cde
|
||||
|
@ -45,6 +47,7 @@ from mindspore._c_expression import typing
|
|||
|
||||
from mindspore import log as logger
|
||||
from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
|
||||
from mindspore.parallel._utils import _get_device_num
|
||||
|
||||
import mindspore.dataset.transforms.py_transforms as py_transforms
|
||||
|
||||
|
@ -2168,6 +2171,7 @@ class BatchDataset(Dataset):
|
|||
num_parallel = get_num_parallel_workers()
|
||||
|
||||
if get_enable_shared_mem():
|
||||
_check_shm_usage(num_parallel, 1, self.max_rowsize * self.batch_size, 2)
|
||||
for _ in range(num_parallel):
|
||||
arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
|
||||
res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
|
||||
|
@ -2631,6 +2635,7 @@ class MapDataset(Dataset):
|
|||
num_parallel = get_num_parallel_workers()
|
||||
|
||||
if get_enable_shared_mem():
|
||||
_check_shm_usage(num_parallel, 1, self.max_rowsize, 2)
|
||||
for _ in range(num_parallel):
|
||||
arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
|
||||
res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
|
||||
|
@ -3542,6 +3547,26 @@ def _fill_worker_indices(workers, indices, idx):
|
|||
return idx
|
||||
|
||||
|
||||
def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
    """
    Verify that /dev/shm has enough free space for the shared memory queues.

    The check is skipped on Windows and whenever `_get_device_num()` reports
    a single device (or fewer).

    Args:
        num_worker: Worker count; multiplies the estimated usage.
        queue_size: Queue size; the estimate reserves (queue_size + 2) rows per queue.
        max_rowsize: Row size; scaled by 1024 * 1024 in the estimate
            (presumably expressed in MB - confirm against callers).
        num_queues: Number of queues accounted for per worker. Default: 1.

    Raises:
        RuntimeError: If the estimated usage is at least 80% of the free
            space reported for /dev/shm.
    """
    # Nothing to verify on Windows or when only one device is in use.
    if platform.system() == "Windows" or _get_device_num() <= 1:
        return

    required_bytes = (_get_device_num() * num_worker * num_queues
                      * (queue_size + 2) * max_rowsize * 1024 * 1024)

    try:
        free_bytes = psutil.disk_usage('/dev/shm').free
    except FileNotFoundError:
        # A missing /dev/shm is only reported, not treated as fatal.
        logger.warning("Expected /dev/shm to exist.")
        return

    # Only 80% of the free space is considered usable by the queues.
    usable_fraction = 0.8
    if required_bytes >= usable_fraction * free_bytes:
        message = ("Insufficient shared memory available. Required: {}, Available: {}. "
                   "Recommend to set_enable_shared_mem to False, reduce max_rowsize or reduce num_parallel_workers."
                   .format(required_bytes, free_bytes))
        raise RuntimeError(message)
|
||||
|
||||
|
||||
class SamplerFn:
|
||||
"""
|
||||
Multiprocessing or multithread generator function wrapper master process.
|
||||
|
@ -3570,6 +3595,8 @@ class SamplerFn:
|
|||
queue_size = min(queue_size, queue_size * 4 // num_worker)
|
||||
queue_size = max(2, queue_size)
|
||||
|
||||
if multi_process and get_enable_shared_mem():
|
||||
_check_shm_usage(num_worker, queue_size, max_rowsize)
|
||||
for _ in range(num_worker):
|
||||
if multi_process is True:
|
||||
try:
|
||||
|
|
|
@ -391,7 +391,7 @@ def check_generatordataset(method):
|
|||
raise ValueError("schema should be a path to schema file or a schema object.")
|
||||
|
||||
# check optional argument
|
||||
nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
|
||||
nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
nreq_param_list = ["column_types"]
|
||||
validate_dataset_param_value(nreq_param_list, param_dict, list)
|
||||
|
|
|
@ -364,7 +364,7 @@ def test_auto_num_workers_error():
|
|||
except TypeError as e:
|
||||
err_msg = str(e)
|
||||
|
||||
assert "isn't of type bool" in err_msg
|
||||
assert "must be of type bool" in err_msg
|
||||
|
||||
|
||||
def test_auto_num_workers():
|
||||
|
|
Loading…
Reference in New Issue