diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 27a3d478b2e..f7dc497f688 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -996,7 +996,8 @@ class Dataset: def get_distribution(output_dataset): dev_id = 0 if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2, - ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)): + ManifestDataset, MnistDataset, VOCDataset, CocoDataset, CelebADataset, + MindDataset)): sampler = output_dataset.sampler if isinstance(sampler, samplers.DistributedSampler): dev_id = sampler.shard_id @@ -4171,8 +4172,8 @@ class CocoDataset(MappableDataset): - task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32], ['iscrowd', dtype=uint32], ['area', dtype=uint32]]. - This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table - below shows what input args are allowed and their expected behavior. + This dataset can take in a sampler. sampler and shuffle are mutually exclusive. CocoDataset doesn't support + PKSampler. Table below shows what input args are allowed and their expected behavior. .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle' :widths: 25 25 50 diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 7dfbdf149af..6cfea7c8b89 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -397,6 +397,9 @@ def check_cocodataset(method): check_param_type(nreq_param_bool, param_dict, bool) + sampler = param_dict.get('sampler') + if sampler is not None and isinstance(sampler, samplers.PKSampler): + raise ValueError("CocoDataset doesn't support PKSampler") check_sampler_shuffle_shard_options(param_dict) return method(*args, **kwargs) diff --git a/tests/ut/python/dataset/test_datasets_coco.py b/tests/ut/python/dataset/test_datasets_coco.py index 8166a952b8d..f5bf7caa6c4 100644 --- a/tests/ut/python/dataset/test_datasets_coco.py +++ b/tests/ut/python/dataset/test_datasets_coco.py @@ -251,6 +251,16 @@ def test_coco_case_exception(): except RuntimeError as e: assert "json.exception.parse_error" in str(e) + try: + sampler = ds.PKSampler(3) + data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection", sampler=sampler) + for _ in data1.__iter__(): + pass + assert False + except ValueError as e: + assert "CocoDataset doesn't support PKSampler" in str(e) + + if __name__ == '__main__': test_coco_detection() test_coco_stuff()