diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py
index f5111ea7ec4..55677813362 100644
--- a/mindspore/dataset/core/config.py
+++ b/mindspore/dataset/core/config.py
@@ -266,7 +266,7 @@ def set_auto_num_workers(enable):
         >>> ds.config.set_auto_num_workers(True)
     """
     if not isinstance(enable, bool):
-        raise TypeError("enable isn't of type bool.")
+        raise TypeError("enable must be of type bool.")
     _config.set_auto_num_workers(enable)
 
 
@@ -411,6 +411,8 @@ def set_enable_shared_mem(enable):
     Examples:
         >>> ds.config.set_enable_shared_mem(True)
     """
+    if not isinstance(enable, bool):
+        raise TypeError("enable must be of type bool.")
     _config.set_enable_shared_mem(enable)
 
 
 def set_sending_batches(batch_num):
@@ -429,5 +431,5 @@
         >>> ds.config.set_sending_batches(10)
     """
     if not isinstance(batch_num, int):
-        raise TypeError("batch_num must be a int dtype.")
+        raise TypeError("batch_num must be of type int.")
     _config.set_sending_batches(batch_num)
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index 7ecda0e038e..510a3d38aab 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -38,6 +38,8 @@ import threading
 import copy
 import weakref
+import platform
+import psutil
 import numpy as np
 
 import mindspore._c_dataengine as cde
@@ -45,6 +47,7 @@ from mindspore._c_expression import typing
 from mindspore import log as logger
 from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
+from mindspore.parallel._utils import _get_device_num
 import mindspore.dataset.transforms.py_transforms as py_transforms
@@ -2168,6 +2171,7 @@ class BatchDataset(Dataset):
         num_parallel = get_num_parallel_workers()
         if get_enable_shared_mem():
+            _check_shm_usage(num_parallel, 1, self.max_rowsize * self.batch_size, 2)
             for _ in range(num_parallel):
                 arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
                 res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize * self.batch_size))
@@ -2631,6 +2635,7 @@ class MapDataset(Dataset):
         num_parallel = get_num_parallel_workers()
         if get_enable_shared_mem():
+            _check_shm_usage(num_parallel, 1, self.max_rowsize, 2)
             for _ in range(num_parallel):
                 arg_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
                 res_q_list.append(_SharedQueue(1, max_rowsize=self.max_rowsize))
@@ -3542,6 +3547,26 @@ def _fill_worker_indices(workers, indices, idx):
     return idx
 
 
+def _check_shm_usage(num_worker, queue_size, max_rowsize, num_queues=1):
+    """
+    Check that sufficient shared memory is available for the shared memory
+    queues when training in parallel mode.
+    """
+    threshold_ratio = 0.8
+    if platform.system() != "Windows" and _get_device_num() > 1:
+        shm_estimate_usage = _get_device_num() * num_worker * num_queues * \
+            (queue_size + 2) * max_rowsize * 1024 * 1024
+        try:
+            shm_available = psutil.disk_usage('/dev/shm').free
+            if shm_estimate_usage >= threshold_ratio * shm_available:
+                raise RuntimeError(
+                    "Insufficient shared memory available. Required: {}, Available: {}. "
+                    "Recommend to set enable_shared_mem to False, or reduce max_rowsize or num_parallel_workers."
+                    .format(shm_estimate_usage, shm_available))
+        except FileNotFoundError:
+            logger.warning("Expected /dev/shm to exist; skipping shared memory check.")
+
+
 class SamplerFn:
     """
     Multiprocessing or multithread generator function wrapper master process.
@@ -3570,6 +3595,8 @@ class SamplerFn:
         queue_size = min(queue_size, queue_size * 4 // num_worker)
         queue_size = max(2, queue_size)
 
+        if multi_process and get_enable_shared_mem():
+            _check_shm_usage(num_worker, queue_size, max_rowsize)
         for _ in range(num_worker):
             if multi_process is True:
                 try:
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py
index 19e51ec7419..e1c5a4ce5cd 100644
--- a/mindspore/dataset/engine/validators.py
+++ b/mindspore/dataset/engine/validators.py
@@ -391,7 +391,7 @@ def check_generatordataset(method):
             raise ValueError("schema should be a path to schema file or a schema object.")
 
     # check optional argument
-    nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
+    nreq_param_int = ["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"]
     validate_dataset_param_value(nreq_param_int, param_dict, int)
     nreq_param_list = ["column_types"]
     validate_dataset_param_value(nreq_param_list, param_dict, list)
diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py
index fdeab39e3b3..08b20a28fe9 100644
--- a/tests/ut/python/dataset/test_config.py
+++ b/tests/ut/python/dataset/test_config.py
@@ -364,7 +364,7 @@ def test_auto_num_workers_error():
     except TypeError as e:
         err_msg = str(e)
-        assert "isn't of type bool" in err_msg
+        assert "must be of type bool" in err_msg
 
 
 def test_auto_num_workers():
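
For reference, the shared memory check added in datasets.py can be reproduced outside MindSpore. Below is a minimal standalone sketch of the same estimate, assuming psutil is installed and /dev/shm exists (i.e. Linux); the device, worker, queue and rowsize numbers are hypothetical examples, and estimate_shm_usage is a helper invented here for illustration:

    import platform
    import psutil

    def estimate_shm_usage(device_num, num_worker, queue_size, max_rowsize, num_queues=1):
        # Same formula as the patch's _check_shm_usage: per device, each worker
        # allocates num_queues queues of (queue_size + 2) slots, max_rowsize MB each.
        return device_num * num_worker * num_queues * (queue_size + 2) * max_rowsize * 1024 * 1024

    if platform.system() != "Windows":
        available = psutil.disk_usage('/dev/shm').free
        # Hypothetical example values: 8 devices, 4 workers, queue_size 2, 16 MB rows.
        required = estimate_shm_usage(8, 4, 2, 16)
        print("required %d MB, available %d MB, within 0.8 threshold: %s"
              % (required >> 20, available >> 20, required < 0.8 * available))

The (queue_size + 2) slack presumably covers rows in flight between the queues and the workers, and the 0.8 threshold leaves headroom for other consumers of /dev/shm.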
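The config.py changes also tighten input validation: with this patch applied, a non-bool argument to set_enable_shared_mem is rejected up front instead of being forwarded to the C++ layer. A quick sketch of the expected behavior, assuming a MindSpore build containing this patch:

    import mindspore.dataset as ds

    ds.config.set_enable_shared_mem(True)       # bool: accepted
    try:
        ds.config.set_enable_shared_mem("yes")  # non-bool: raises immediately
    except TypeError as e:
        assert "must be of type bool" in str(e)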