From 53c01c437d8d2ed810bd3bed90bc1a9c013ee6a5 Mon Sep 17 00:00:00 2001 From: Xiao Tianci Date: Thu, 22 Jul 2021 11:31:37 +0800 Subject: [PATCH] fix NumpySlicesDataset not accept numpy --- mindspore/dataset/engine/validators.py | 2 +- tests/ut/python/dataset/test_dataset_numpy_slices.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index e89cbe451a9..bf688a8e577 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1260,7 +1260,7 @@ def check_numpyslicesdataset(method): data = param_dict.get("data") column_names = param_dict.get("column_names") - if not data: + if data is None or len(data) == 0: # pylint: disable=len-as-condition raise ValueError("Argument data cannot be empty") type_check(data, (list, tuple, dict, np.ndarray), "data") if isinstance(data, tuple): diff --git a/tests/ut/python/dataset/test_dataset_numpy_slices.py b/tests/ut/python/dataset/test_dataset_numpy_slices.py index f61c8db903b..4d36db50448 100644 --- a/tests/ut/python/dataset/test_dataset_numpy_slices.py +++ b/tests/ut/python/dataset/test_dataset_numpy_slices.py @@ -51,6 +51,16 @@ def test_numpy_slices_list_3(): assert np.equal(data[0].asnumpy(), np_data[i]).all() +def test_numpy_slices_numpy(): + logger.info("Test NumPy structure data.") + + np_data = np.array([[[1, 1], [2, 2]], [[3, 3], [4, 4]]]) + ds = de.NumpySlicesDataset(np_data, column_names=["col1"], shuffle=False) + + for i, data in enumerate(ds): + assert np.equal(data[0].asnumpy(), np_data[i]).all() + + def test_numpy_slices_list_append(): logger.info("Test reading data of image list.")