forked from mindspore-Ecosystem/mindspore
!2228 cache get_dataset_size value
Merge pull request !2228 from yanghaitao/yht_get_dataset_size
This commit is contained in:
commit
1127ace7ec
|
@ -2284,6 +2284,7 @@ class ImageFolderDatasetV2(MappableDataset):
|
|||
self.decode = decode
|
||||
self.num_shards = num_shards
|
||||
self.shard_id = shard_id
|
||||
self.cur_dataset_size = None
|
||||
|
||||
def get_args(self):
|
||||
args = super().get_args()
|
||||
|
@ -2305,6 +2306,9 @@ class ImageFolderDatasetV2(MappableDataset):
|
|||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self.cur_dataset_size is not None:
|
||||
return self.cur_dataset_size
|
||||
|
||||
if self.num_samples is None:
|
||||
num_samples = 0
|
||||
else:
|
||||
|
@ -2314,9 +2318,11 @@ class ImageFolderDatasetV2(MappableDataset):
|
|||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
|
||||
if rows_from_sampler is None:
|
||||
self.cur_dataset_size = rows_per_shard
|
||||
return rows_per_shard
|
||||
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
self.cur_dataset_size = min(rows_from_sampler, rows_per_shard)
|
||||
return self.cur_dataset_size
|
||||
|
||||
def num_classes(self):
|
||||
"""
|
||||
|
@ -2509,6 +2515,7 @@ class MindDataset(SourceDataset):
|
|||
self.shuffle_option = shuffle
|
||||
self.distribution = ""
|
||||
self.sampler = sampler
|
||||
self.cur_dataset_size = None
|
||||
|
||||
if num_shards is None or shard_id is None:
|
||||
self.partitions = None
|
||||
|
@ -2578,6 +2585,9 @@ class MindDataset(SourceDataset):
|
|||
Number, number of batches.
|
||||
"""
|
||||
if self._dataset_size is None:
|
||||
if self.cur_dataset_size is not None:
|
||||
return self.cur_dataset_size
|
||||
|
||||
if self.load_dataset:
|
||||
dataset_file = [self.dataset_file]
|
||||
else:
|
||||
|
@ -2591,6 +2601,7 @@ class MindDataset(SourceDataset):
|
|||
raise RuntimeError(
|
||||
"Dataset size plus number of padded samples is not divisible by number of shards.")
|
||||
num_rows = num_rows // self.partitions[0] + 1
|
||||
self.cur_dataset_size = num_rows
|
||||
return num_rows
|
||||
return self._dataset_size
|
||||
|
||||
|
|
Loading…
Reference in New Issue