forked from mindspore-Ecosystem/mindspore
!2380 Fix CocoDataset issue
Merge pull request !2380 from xiefangqi/xfq_fix_coco_issue_01
This commit is contained in:
commit
78a8bc302d
|
@ -996,7 +996,8 @@ class Dataset:
|
|||
def get_distribution(output_dataset):
|
||||
dev_id = 0
|
||||
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
|
||||
ManifestDataset, MnistDataset, VOCDataset, CelebADataset, MindDataset)):
|
||||
ManifestDataset, MnistDataset, VOCDataset, CocoDataset, CelebADataset,
|
||||
MindDataset)):
|
||||
sampler = output_dataset.sampler
|
||||
if isinstance(sampler, samplers.DistributedSampler):
|
||||
dev_id = sampler.shard_id
|
||||
|
@ -4171,8 +4172,8 @@ class CocoDataset(MappableDataset):
|
|||
- task='Panoptic', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['category_id', dtype=uint32],
|
||||
['iscrowd', dtype=uint32], ['area', dtype=uint32]].
|
||||
|
||||
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. Table
|
||||
below shows what input args are allowed and their expected behavior.
|
||||
This dataset can take in a sampler. sampler and shuffle are mutually exclusive. CocoDataset doesn't support
|
||||
PKSampler. Table below shows what input args are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
|
||||
:widths: 25 25 50
|
||||
|
|
|
@ -397,6 +397,9 @@ def check_cocodataset(method):
|
|||
|
||||
check_param_type(nreq_param_bool, param_dict, bool)
|
||||
|
||||
sampler = param_dict.get('sampler')
|
||||
if sampler is not None and isinstance(sampler, samplers.PKSampler):
|
||||
raise ValueError("CocoDataset doesn't support PKSampler")
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
|
|
@ -251,6 +251,16 @@ def test_coco_case_exception():
|
|||
except RuntimeError as e:
|
||||
assert "json.exception.parse_error" in str(e)
|
||||
|
||||
try:
|
||||
sampler = ds.PKSampler(3)
|
||||
data1 = ds.CocoDataset(DATA_DIR, annotation_file=INVALID_FILE, task="Detection", sampler=sampler)
|
||||
for _ in data1.__iter__():
|
||||
pass
|
||||
assert False
|
||||
except ValueError as e:
|
||||
assert "CocoDataset doesn't support PKSampler" in str(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_coco_detection()
|
||||
test_coco_stuff()
|
||||
|
|
Loading…
Reference in New Issue