From 16147669a68fc2ea0fee96820119664fbe3c73cf Mon Sep 17 00:00:00 2001 From: liyong Date: Sat, 12 Sep 2020 15:02:54 +0800 Subject: [PATCH] fix num_samples in concatDataset --- mindspore/dataset/engine/datasets.py | 2 ++ .../test_tf_file_3_images/datasetSchema.json | 1 - tests/ut/python/dataset/test_paddeddataset.py | 17 +++++++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 00b68878c18..cc99dc90527 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2438,6 +2438,8 @@ class ConcatDataset(DatasetOp): self._sampler = _select_sampler(None, sampler, None, None, None) cumulative_samples_nums = 0 for index, child in enumerate(self.children): + if hasattr(child, 'sampler') and child.sampler.get_num_samples() is not None: + raise ValueError("The parameter NumSamples of %s is not support to be set!" % (child)) if isinstance(child, BatchDataset): raise TypeError("The parameter %s of concat should't be BatchDataset!" % (child)) diff --git a/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json b/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json index eafcfd69eab..e00fd39c10d 100644 --- a/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json +++ b/tests/ut/data/dataset/test_tf_file_3_images/datasetSchema.json @@ -1,6 +1,5 @@ { "datasetType": "TF", - "numRows": 3, "columns": { "image": { "type": "uint8", diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py index cd7ef07ae7f..fd21ee58821 100644 --- a/tests/ut/python/dataset/test_paddeddataset.py +++ b/tests/ut/python/dataset/test_paddeddataset.py @@ -213,6 +213,23 @@ def test_raise_error(): ds3.use_sampler(testsampler) assert excinfo.type == 'ValueError' +def test_imagefolder_error(): + DATA_DIR = "../data/dataset/testPK/data" + data = ds.ImageFolderDataset(DATA_DIR, num_samples=14) + + data1 = [{'image': np.zeros(1, np.uint8), 'label': np.array(0, np.int32)}, + {'image': np.zeros(2, np.uint8), 'label': np.array(1, np.int32)}, + {'image': np.zeros(3, np.uint8), 'label': np.array(0, np.int32)}, + {'image': np.zeros(4, np.uint8), 'label': np.array(1, np.int32)}, + {'image': np.zeros(5, np.uint8), 'label': np.array(0, np.int32)}, + {'image': np.zeros(6, np.uint8), 'label': np.array(1, np.int32)}] + + data2 = ds.PaddedDataset(data1) + data3 = data + data2 + with pytest.raises(ValueError) as excinfo: + testsampler = ds.DistributedSampler(num_shards=5, shard_id=4, shuffle=False, num_samples=None) + data3.use_sampler(testsampler) + assert excinfo.type == 'ValueError' def test_imagefolder_padded(): DATA_DIR = "../data/dataset/testPK/data"