store get dataset size

This commit is contained in:
yanghaitao1 2020-06-17 04:22:19 -04:00
parent 1e90e7be05
commit 038040750d
1 changed files with 12 additions and 1 deletions

View File

@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset):
self.decode = decode
self.num_shards = num_shards
self.shard_id = shard_id
self.cur_dataset_size = None
def get_args(self):
args = super().get_args()
@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset):
Return:
Number, number of batches.
"""
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.num_samples is None:
num_samples = 0
else:
@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset):
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
self.cur_dataset_size = rows_per_shard
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
self.cur_dataset_size = min(rows_from_sampler, rows_per_shard)
return self.cur_dataset_size
def num_classes(self):
"""
@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset):
self.shuffle_option = shuffle
self.distribution = ""
self.sampler = sampler
self.cur_dataset_size = None
if num_shards is None or shard_id is None:
self.partitions = None
@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset):
Number, number of batches.
"""
if self._dataset_size is None:
if self.cur_dataset_size is not None:
return self.cur_dataset_size
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset):
raise RuntimeError(
"Dataset size plus number of padded samples is not divisible by number of shards.")
num_rows = num_rows // self.partitions[0] + 1
self.cur_dataset_size = num_rows
return num_rows
return self._dataset_size