!2228 cache get_dataset_size value

Merge pull request !2228 from yanghaitao/yht_get_dataset_size
This commit is contained in:
mindspore-ci-bot 2020-06-18 14:44:05 +08:00 committed by Gitee
commit 1127ace7ec
1 changed file with 12 additions and 1 deletion

View File

@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset):
self.decode = decode
self.num_shards = num_shards
self.shard_id = shard_id
self.cur_dataset_size = None
def get_args(self):
args = super().get_args()
@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset):
Return:
Number, number of batches.
"""
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.num_samples is None:
num_samples = 0
else:
@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset):
 rows_from_sampler = self._get_sampler_dataset_size()
 if rows_from_sampler is None:
+self.cur_dataset_size = rows_per_shard
 return rows_per_shard
-return min(rows_from_sampler, rows_per_shard)
+self.cur_dataset_size = min(rows_from_sampler, rows_per_shard)
+return self.cur_dataset_size
def num_classes(self):
"""
@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset):
self.shuffle_option = shuffle
self.distribution = ""
self.sampler = sampler
self.cur_dataset_size = None
if num_shards is None or shard_id is None:
self.partitions = None
@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset):
Number, number of batches.
"""
if self._dataset_size is None:
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset):
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
self.cur_dataset_size = num_rows
return num_rows
return self._dataset_size