diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 679f4d3f6e3..63b0fa44054 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -599,9 +599,9 @@ class Dataset: def get_distribution(output_dataset): dev_id = 0 - if isinstance(output_dataset, (StorageDataset, GeneratorDataset, MindDataset)): + if isinstance(output_dataset, (StorageDataset, MindDataset)): return output_dataset.distribution, dev_id - if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, ImageFolderDatasetV2, + if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2, ManifestDataset, MnistDataset, VOCDataset, CelebADataset)): sampler = output_dataset.sampler if isinstance(sampler, samplers.DistributedSampler):