diff --git a/mindspore/python/mindspore/dataset/engine/datasets.py b/mindspore/python/mindspore/dataset/engine/datasets.py
index 195cc559ddb..7764b11ec45 100644
--- a/mindspore/python/mindspore/dataset/engine/datasets.py
+++ b/mindspore/python/mindspore/dataset/engine/datasets.py
@@ -5390,21 +5390,15 @@ class GeneratorDataset(MappableDataset, TextBaseDataset):
             # get process memory usage
             process = psutil.Process(os.getpid())
             process_memory = process.memory_info().rss
-            sys_memory = psutil.virtual_memory().total
+            sys_memory_free = psutil.virtual_memory().free
 
-            total_memory_maybe_used = process_memory * (self.num_parallel_workers + 1) * valid_num_shards
-            if total_memory_maybe_used / sys_memory > 0.85:
-                valid_num_worker = math.floor(sys_memory * 0.85 / valid_num_shards / process_memory - 1)
+            total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards
+            if total_memory_maybe_used / sys_memory_free > 0.85:
+                valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory)
                 valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker
-                if total_memory_maybe_used / sys_memory > 1.0:
-                    info = "GeneratorDataset num_parallel_workers: " + str(self.num_parallel_workers) + \
-                        " is too large which maybe cause a lot of memory occupation (>100%) during" \
-                        " multi process running. Therefore, it is recommended to" \
-                        " reduce num_parallel_workers to " + str(valid_num_worker) + " or smaller."
-                    raise RuntimeError(info)
                 info = "GeneratorDataset num_parallel_workers: " + str(self.num_parallel_workers) + \
-                    " is too large which maybe cause a lot of memory occupation (>85%) during multi " \
-                    "process running. Therefore, it is recommended to reduce num_parallel_workers to " \
+                    " is too large which maybe cause a lot of memory occupation (>85%) or out of memory(OOM) " \
+                    "during multi process running. Therefore, it is recommended to reduce num_parallel_workers to " \
                     + str(valid_num_worker) + " or smaller."
                 logger.warning(info)
 
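For reference, the revised heuristic can be exercised on its own. The sketch below is illustrative, not part of the patch: the helper name suggest_num_parallel_workers, the threshold argument, and returning a suggested value instead of logging a warning are assumptions; the variable names and the formula mirror the patched code.

# Standalone sketch of the new memory check (assumed helper name; not part of the patch).
import math
import os

import psutil


def suggest_num_parallel_workers(num_parallel_workers, valid_num_shards=1, threshold=0.85):
    """Suggest a worker count whose estimated memory stays under `threshold` of free memory."""
    # Each worker process is assumed to cost roughly one copy of the parent's resident set size.
    process_memory = psutil.Process(os.getpid()).memory_info().rss
    sys_memory_free = psutil.virtual_memory().free

    total_memory_maybe_used = process_memory * num_parallel_workers * valid_num_shards
    if total_memory_maybe_used / sys_memory_free > threshold:
        # Same formula as the patch: how many worker-sized copies fit in 85% of free memory.
        valid_num_worker = math.floor(sys_memory_free * threshold / valid_num_shards / process_memory)
        return max(valid_num_worker, 1)
    return num_parallel_workers

Reading psutil.virtual_memory().free instead of .total means the recommendation reflects memory actually available when the dataset is constructed, and replacing the hard RuntimeError with a warning leaves the user with a suggested worker count rather than an immediate failure.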