From e4894a1bf1bc9a5dda8489375a4d32c3093abbd1 Mon Sep 17 00:00:00 2001
From: xiefangqi
Date: Thu, 12 Aug 2021 20:12:43 +0800
Subject: [PATCH] fix batch usability problem

---
 mindspore/dataset/engine/datasets.py  |  3 +++
 tests/ut/python/dataset/test_batch.py | 18 ++++++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py
index e9c352ecc5c..a9f52955d88 100644
--- a/mindspore/dataset/engine/datasets.py
+++ b/mindspore/dataset/engine/datasets.py
@@ -2151,6 +2151,9 @@ class BatchDataset(Dataset):
         Per iterator bootstrap callback.
         """
         if self.python_multiprocessing:
+            if self.per_batch_map is None:
+                logger.warning("per_batch_map is None so python_multiprocessing does not work.")
+                return
             arg_q_list = []
             res_q_list = []
 
diff --git a/tests/ut/python/dataset/test_batch.py b/tests/ut/python/dataset/test_batch.py
index 692c3f640ef..7044de4cec0 100644
--- a/tests/ut/python/dataset/test_batch.py
+++ b/tests/ut/python/dataset/test_batch.py
@@ -238,6 +238,23 @@ def test_batch_12():
     save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
 
 
+def test_batch_13():
+    """
+    Test batch: python_multiprocessing=True is ignored (with a warning) when per_batch_map is None
+    """
+    logger.info("test_batch_13")
+    # define parameters
+    batch_size = True  # boolean True behaves as batch_size=1
+
+    # apply dataset operations
+    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
+    data1 = data1.batch(batch_size=batch_size, python_multiprocessing=True)
+
+    assert sum([1 for _ in data1]) == 12
+    filename = "batch_12_result.npz"  # same pipeline as test_batch_12, so the golden file is reused
+    save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
+
+
 def test_batch_exception_01():
     """
     Test batch exception: num_parallel_workers=0
@@ -493,6 +510,7 @@ if __name__ == '__main__':
     test_batch_10()
     test_batch_11()
     test_batch_12()
+    test_batch_13()
     test_batch_exception_01()
     test_batch_exception_02()
     test_batch_exception_03()
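
Reviewer note (not part of the patch): a minimal sketch of the user-facing behavior after this change, assuming NumpySlicesDataset as a stand-in source (the unit test above uses a TFRecord dataset). The warning is only emitted when an iterator is created, since that is when the patched bootstrap callback runs.

import numpy as np
import mindspore.dataset as ds

# Small in-memory dataset with 12 rows (stand-in for the TFRecord data in the test).
data = ds.NumpySlicesDataset({"col": np.arange(12)}, shuffle=False)

# python_multiprocessing is requested but per_batch_map is omitted, so with this
# patch the per-iterator bootstrap logs
# "per_batch_map is None so python_multiprocessing does not work."
# and returns early instead of setting up worker processes and queues.
data = data.batch(batch_size=4, python_multiprocessing=True)

# Iterating still works; batching simply runs in the main process.
assert sum(1 for _ in data) == 3

In other words, the flag is now silently downgraded (with a warning) rather than bootstrapping a multiprocessing pool that has no per_batch_map work to distribute.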