diff --git a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc
index c705734af62..fedd82260be 100644
--- a/mindspore/ccsrc/minddata/dataset/core/config_manager.cc
+++ b/mindspore/ccsrc/minddata/dataset/core/config_manager.cc
@@ -50,7 +50,7 @@ ConfigManager::ConfigManager()
       num_cpu_threads_(std::thread::hardware_concurrency()),
       auto_num_workers_num_shards_(1),
       auto_worker_config_(0),
-      enable_shared_mem_(false) {
+      enable_shared_mem_(true) {
   num_cpu_threads_ = num_cpu_threads_ > 0 ? num_cpu_threads_ : std::numeric_limits<int32_t>::max();
   num_parallel_workers_ = num_parallel_workers_ < num_cpu_threads_ ? num_parallel_workers_ : num_cpu_threads_;
   auto env_cache_host = std::getenv("MS_CACHE_HOST");
diff --git a/mindspore/dataset/core/config.py b/mindspore/dataset/core/config.py
index 190928e068d..ac4718af42d 100644
--- a/mindspore/dataset/core/config.py
+++ b/mindspore/dataset/core/config.py
@@ -381,7 +381,7 @@ def get_enable_shared_mem():
 
     Returns:
-        bool, the state of shared mem enabled variable (default: True).
+        bool, the state of shared mem enabled variable (default=True).
     """
     return _config.get_enable_shared_mem()
 
 
diff --git a/mindspore/dataset/engine/queue.py b/mindspore/dataset/engine/queue.py
index 2c1911af649..b30df480418 100644
--- a/mindspore/dataset/engine/queue.py
+++ b/mindspore/dataset/engine/queue.py
@@ -66,7 +66,7 @@ class _SharedQueue(multiprocessing.queues.Queue):
                     if (isinstance(r, np.ndarray) and r.size > self.min_shared_mem
                             and start_bytes + r.nbytes < self.seg_size):
                         ##need to convert start_bytes to offset in array
-                        start_offset = start_bytes // r.dtype.itemsize
+                        start_offset = start_bytes
                         dest = np.ndarray(r.shape, r.dtype, buffer=self.shm_list[self.seg_pos].get_obj(), offset=start_offset)
                         np.copyto(dest, r)
                         byte = r.nbytes
@@ -101,7 +101,7 @@ class _SharedQueue(multiprocessing.queues.Queue):
                 byte = x[2]
                 dtype = x[3]
                 shape = x[4]
-                start_offset = start_bytes // dtype.itemsize
+                start_offset = start_bytes
                 b = self.shm_list[seg_pos]
                 data = np.ndarray(shape, dtype, buffer=b.get_obj(), offset=start_offset)
                 start_bytes += byte
diff --git a/tests/ut/python/dataset/test_datasets_generator.py b/tests/ut/python/dataset/test_datasets_generator.py
index 751d45a0d00..9dcf3c8c422 100644
--- a/tests/ut/python/dataset/test_datasets_generator.py
+++ b/tests/ut/python/dataset/test_datasets_generator.py
@@ -36,6 +36,15 @@ class DatasetGenerator:
 
     def __len__(self):
         return 10
 
+class DatasetGeneratorLarge:
+    def __init__(self):
+        self.data = np.array(range(4000))
+
+    def __getitem__(self, item):
+        return (self.data + item, self.data * 10)
+
+    def __len__(self):
+        return 10
 
 def test_generator_0():
@@ -503,6 +512,26 @@ def test_generator_18():
         golden = np.array([i * 5])
         np.testing.assert_array_equal(item["out0"], golden)
 
+def test_generator_19():
+    """
+    Test multiprocessing flag with 2 different large columns
+    """
+    logger.info("Test generator with python_multiprocessing=True and two large columns.")
+
+    # apply dataset operations
+    data1 = ds.GeneratorDataset(DatasetGeneratorLarge(), ["col0", "col1"], python_multiprocessing=True, shuffle=False)
+
+    # Each row yields (data + i, data * 10); both large columns must survive the shared-memory round trip.
+    i = 0
+    for item in data1.create_tuple_iterator(num_epochs=1, output_numpy=True):
+        assert len(item) == 2
+        golden = np.array(range(4000)) + i
+        np.testing.assert_array_equal(item[0], golden)
+        golden = np.array(range(4000)) * 10
+        np.testing.assert_array_equal(item[1], golden)
+        i = i + 1
+
+
 def test_generator_error_1():
     def generator_np():
         for i in range(64):
@@ -804,6 +833,7 @@ if __name__ == "__main__":
     test_generator_16()
     test_generator_17()
     test_generator_18()
+    test_generator_19()
     test_generator_error_1()
     test_generator_error_2()
     test_generator_error_3()
diff --git a/tests/ut/python/dataset/test_pyfunc.py b/tests/ut/python/dataset/test_pyfunc.py
index 6b234c287a7..66ef9f0e2fb 100644
--- a/tests/ut/python/dataset/test_pyfunc.py
+++ b/tests/ut/python/dataset/test_pyfunc.py
@@ -341,6 +341,7 @@ if __name__ == "__main__":
     test_case_7()
     test_case_8()
     test_case_9()
+    test_case_10()
     test_pyfunc_implicit_compose()
     test_pyfunc_exception()
     skip_test_pyfunc_exception_multiprocess()