fix CocoDataset issue

This commit is contained in:
xiefangqi 2020-06-20 09:45:41 +08:00
parent e9670f3c28
commit 5703a10b8b
3 changed files with 17 additions and 3 deletions

View File

@ -997,7 +997,8 @@ class Dataset:
def get_distribution(output_dataset):
dev_id = 0
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
ManifestDataset, MnistDataset, VOCDataset, CocoDataset, CelebADataset,
MindDataset)):
sampler = output_dataset.sampler
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
@ -4172,8 +4173,8 @@ class CocoDataset(MappableDataset):
- task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
['iscrowd', dtype=uint32], ['area', dtype=uint32]].
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
below shows what input args are allowed and their expected behavior.
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. CocoDataset doesn't support
PKSampler. Table below shows what input args are allowed and their expected behavior.
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
:widths: 25 25 50

View File

@ -560,6 +560,9 @@ def check_cocodataset(method):
check_param_type(nreq_param_bool, param_dict, bool)
sampler = param_dict.get('sampler')
if sampler is not None and isinstance(sampler, samplers.PKSampler):
raise ValueError("CocoDataset doesn't support PKSampler")
check_sampler_shuffle_shard_options(param_dict)
return method(*args, **kwargs)

View File

@ -251,6 +251,16 @@ def test_coco_case_exception():
except RuntimeError as e:
assert "json.exception.parse_error" in str(e)
try:
sampler = ds.PKSampler(3)
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection", sampler=sampler)
for _ in data1.__iter__():
pass
assert False
except ValueError as e:
assert "CocoDataset doesn't support PKSampler" in str(e)
if __name__ == '__main__':
test_coco_detection()
test_coco_stuff()