forked from mindspore-Ecosystem/mindspore
!4346 all dataset support get_dataset_size
Merge pull request !4346 from anzhengqi/I1PXKS-Dataset-support-getsize
This commit is contained in: commit 58be11907c
@@ -143,7 +143,7 @@ class Dataset:
         self._input_indexs = ()
         self._output_types = None
         self._output_shapes = None
-        self._dataset_size = None
+        self.dataset_size = None
         self._batch_size = None
         self._num_classes = None
         self._repeat_count = None
@@ -1189,8 +1189,6 @@ class Dataset:
         device_iter = TupleIterator(self)
         self._output_shapes = device_iter.get_output_shapes()
         self._output_types = device_iter.get_output_types()
-        if self._dataset_size is None:
-            self._dataset_size = device_iter.get_dataset_size()
         self._batch_size = device_iter.get_batch_size()
         self._num_classes = device_iter.num_classes()
         self._repeat_count = device_iter.get_repeat_count()
@@ -1225,9 +1223,10 @@ class Dataset:
         Return:
             Number, number of batches.
         """
-        if self.children:
-            return self.children[0].get_dataset_size()
-        return None
+        if self.dataset_size is None:
+            if self.children:
+                self.dataset_size = self.children[0].get_dataset_size()
+        return self.dataset_size

     def num_classes(self):
         """
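The caching pattern above is the heart of this change: the first call to get_dataset_size computes the size and stores it in self.dataset_size, and later calls return the stored value. A minimal sketch of the intended caller-visible behaviour (the generator below is hypothetical, not code from this commit):

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(8):
        yield (np.array([i]),)

data = ds.GeneratorDataset(gen, ["data"])
first = data.get_dataset_size()    # computed on the first call and cached in dataset_size
second = data.get_dataset_size()   # answered from the cache, no second pass over the data
assert first == second == 8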
@@ -1378,6 +1377,8 @@ class MappableDataset(SourceDataset):
     def add_sampler(self, new_sampler):
         # note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
         # after first passing through the current samplers attached to this dataset.
+        if self.dataset_size is not None:
+            self.dataset_size = None
         new_sampler.add_child(self.sampler)
         self.sampler = new_sampler

@@ -1406,6 +1407,8 @@ class MappableDataset(SourceDataset):
             raise TypeError("Input sampler can not be None.")
         if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
             raise TypeError("Input sampler is not an instance of a sampler.")
+        if self.dataset_size is not None:
+            self.dataset_size = None

         self.sampler = self.sampler.child_sampler
         self.add_sampler(new_sampler)
@@ -1505,6 +1508,7 @@ class MappableDataset(SourceDataset):
         current_split_start_index = 0
         for size in absolute_sizes:
             ds = copy.deepcopy(self)
+            ds.dataset_size = None
             if randomize:
                 # want to shuffle the same way every epoch before split, we are assuming
                 # that the user will call set_seed
@@ -1582,7 +1586,12 @@ class BucketBatchByLengthDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        return None
+        if self.dataset_size is None:
+            num_rows = 0
+            for _ in self.create_dict_iterator():
+                num_rows += 1
+            self.dataset_size = num_rows
+        return self.dataset_size


 class BatchDataset(DatasetOp):
@@ -1643,12 +1652,14 @@ class BatchDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        if child_size is not None and isinstance(self.batch_size, int):
-            if self.drop_remainder:
-                return math.floor(child_size / self.batch_size)
-            return math.ceil(child_size / self.batch_size)
-        return None
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            if child_size is not None and isinstance(self.batch_size, int):
+                if self.drop_remainder:
+                    self.dataset_size = math.floor(child_size / self.batch_size)
+                else:
+                    self.dataset_size = math.ceil(child_size / self.batch_size)
+        return self.dataset_size

     def get_batch_size(self):
         """
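The batch arithmetic above reduces to a floor or ceiling division of the child size; a small standalone sketch of the same computation with made-up numbers:

import math

child_size, batch_size = 10, 3               # hypothetical child dataset of 10 rows
print(math.ceil(child_size / batch_size))    # 4 batches when the last partial batch is kept
print(math.floor(child_size / batch_size))   # 3 batches when drop_remainder=True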
@@ -2000,7 +2011,9 @@ class MapDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        return self.children[0].get_dataset_size()
+        if self.dataset_size is None:
+            self.dataset_size = self.children[0].get_dataset_size()
+        return self.dataset_size

     def __deepcopy__(self, memodict):
         if id(self) in memodict:
@@ -2019,6 +2032,7 @@ class MapDataset(DatasetOp):
         new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
         new_op.cache = copy.deepcopy(self.cache, memodict)
         new_op.operations = self.operations
+        new_op.dataset_size = self.dataset_size
         return new_op

     # Iterator bootstrap will be called on iterator construction.
@@ -2091,11 +2105,16 @@ class FilterDataset(DatasetOp):
     def get_dataset_size(self):
         """
         Get the number of batches in an epoch.
-        the size cannot be determined before we run the pipeline.
+
         Return:
-            0
+            Number, num of batches.
         """
-        return 0
+        if self.dataset_size is None:
+            num_rows = 0
+            for _ in self.create_dict_iterator():
+                num_rows += 1
+            self.dataset_size = num_rows
+        return self.dataset_size


 class RepeatDataset(DatasetOp):
@@ -2129,10 +2148,11 @@ class RepeatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        if child_size is not None:
-            return child_size * self.count
-        return None
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            if child_size is not None:
+                self.dataset_size = child_size * self.count
+        return self.dataset_size

     def get_repeat_count(self):
         """
@@ -2172,11 +2192,12 @@ class SkipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        output_size = 0
-        if self.count >= 0 and self.count < child_size:
-            output_size = child_size - self.count
-        return output_size
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            self.dataset_size = 0
+            if self.count >= 0 and self.count < child_size:
+                self.dataset_size = child_size - self.count
+        return self.dataset_size


 class TakeDataset(DatasetOp):
@@ -2207,10 +2228,13 @@ class TakeDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.children[0].get_dataset_size()
-        if child_size < self.count:
-            return child_size
-        return self.count
+        if self.dataset_size is None:
+            child_size = self.children[0].get_dataset_size()
+            if child_size < self.count:
+                self.dataset_size = child_size
+            else:
+                self.dataset_size = self.count
+        return self.dataset_size


 class ZipDataset(DatasetOp):
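The skip and take rules above can be restated as two small pure functions; this is a sketch of the size bookkeeping only, not code from the commit:

def skipped_size(child_size, count):
    # mirrors SkipDataset: rows left after skipping, never negative
    return child_size - count if 0 <= count < child_size else 0

def taken_size(child_size, count):
    # mirrors TakeDataset: at most count rows, bounded by the child size
    return child_size if child_size < count else count

assert skipped_size(10, 3) == 7
assert taken_size(10, 3) == 3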
@@ -2241,10 +2265,11 @@ class ZipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.children]
-        if all(c is not None for c in children_sizes):
-            return min(children_sizes)
-        return None
+        if self.dataset_size is None:
+            children_sizes = [c.get_dataset_size() for c in self.children]
+            if all(c is not None for c in children_sizes):
+                self.dataset_size = min(children_sizes)
+        return self.dataset_size

     def num_classes(self):
         """
@@ -2291,9 +2316,10 @@ class ConcatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.children]
-        dataset_size = sum(children_sizes)
-        return dataset_size
+        if self.dataset_size is None:
+            children_sizes = [c.get_dataset_size() for c in self.children]
+            self.dataset_size = sum(children_sizes)
+        return self.dataset_size


 class RenameDataset(DatasetOp):
@@ -2439,6 +2465,11 @@ class RangeDataset(MappableDataset):
     def is_sharded(self):
         return False

+    def get_dataset_size(self):
+        if self.dataset_size is None:
+            self.dataset_size = math.ceil((self.stop - self.start)/self.step)
+        return self.dataset_size
+

 def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
     """
@@ -2617,14 +2648,13 @@ class ImageFolderDatasetV2(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size

     def num_classes(self):
         """
@@ -2758,14 +2788,13 @@ class MnistDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = MnistOp.get_num_rows(self.dataset_dir)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
-
-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+        if self.dataset_size is None:
+            num_rows = MnistOp.get_num_rows(self.dataset_dir)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -2868,20 +2897,20 @@ class MindDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             if self.load_dataset:
                 dataset_file = [self.dataset_file]
             else:
                 dataset_file = self.dataset_file
             num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
-            return num_rows
-        return self._dataset_size
+            self.dataset_size = num_rows
+        return self.dataset_size

     # manually set dataset_size as a tempoary solution.
     def set_dataset_size(self, value):
         logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))

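As the deprecation warning in set_dataset_size above suggests, querying the size is now preferred over setting it manually; a hedged sketch, where the MindRecord file name is a placeholder:

import mindspore.dataset as ds

data = ds.MindDataset(dataset_file="demo.mindrecord")   # placeholder file
print(data.get_dataset_size())                          # read from MindRecord metadata and cached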
@@ -3205,6 +3234,7 @@ class GeneratorDataset(MappableDataset):
         self.source = source
         self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
         self.num_samples = num_samples
+        self.num_shards = num_shards

         if column_names is not None and not isinstance(column_names, list):
             column_names = [column_names]
@@ -3225,9 +3255,6 @@ class GeneratorDataset(MappableDataset):
                 self.column_names.append(col["name"])
                 self.column_types.append(DataType(col["type"]))

-        if source is not None and hasattr(source, "__len__"):
-            self._dataset_size = len(source)
-
     def get_args(self):
         args = super().get_args()
         args["source"] = self.source
@@ -3242,19 +3269,27 @@ class GeneratorDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        rows_from_sampler = self._get_sampler_dataset_size()
+        if self.dataset_size is None:
+            if hasattr(self.source, "__len__"):
+                if not self.num_shards:
+                    self.dataset_size = len(self.source)
+                else:
+                    self.dataset_size = math.ceil(len(self.source)/self.num_shards)

-        if rows_from_sampler is None:
-            return self._dataset_size
-        if self._dataset_size is None:
-            return None
-
-        return min(rows_from_sampler, self._dataset_size)
+                rows_from_sampler = self._get_sampler_dataset_size()
+                if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                    self.dataset_size = rows_from_sampler
+            else:
+                num_rows = 0
+                for _ in self.create_dict_iterator():
+                    num_rows += 1
+                self.dataset_size = num_rows
+        return self.dataset_size

     # manually set dataset_size as a temporary solution.
     def set_dataset_size(self, value):
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))

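When the source exposes __len__ and the dataset is sharded, the size reported above is the per-shard row count; a one-line restatement with hypothetical numbers:

import math

source_len, num_shards = 10, 3
per_shard = math.ceil(source_len / num_shards)   # 4 rows per shard, as computed in the branch above
print(per_shard)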
@@ -3271,6 +3306,7 @@ class GeneratorDataset(MappableDataset):
         new_op.column_types = copy.deepcopy(self.column_types, memodict)
         new_op.column_names = copy.deepcopy(self.column_names, memodict)
         new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
+        new_op.dataset_size = self.dataset_size

         new_op.sampler = copy.deepcopy(self.sampler)
         if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
@@ -3433,19 +3469,18 @@ class TFRecordDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
-            num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is None:
-                return num_rows
-            return min(self.num_samples, num_rows)
-        return self._dataset_size
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
+        return self.dataset_size

     # manually set dataset_size as a tempoary solution.
     def set_dataset_size(self, value):
         logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
         if value >= 0:
-            self._dataset_size = value
+            self.dataset_size = value
         else:
             raise ValueError('Set dataset_size with negative value {}'.format(value))

@@ -3574,19 +3609,19 @@ class ManifestDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.class_indexing is None:
-            class_indexing = dict()
-        else:
-            class_indexing = self.class_indexing
+        if self.dataset_size is None:
+            if self.class_indexing is None:
+                class_indexing = dict()
+            else:
+                class_indexing = self.class_indexing

-        num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
+            num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()

-        if rows_from_sampler is None:
-            return rows_per_shard
-
-        return min(rows_from_sampler, rows_per_shard)
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size

     def num_classes(self):
         """
@@ -3742,15 +3777,15 @@ class Cifar10Dataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
+        if self.dataset_size is None:
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()

-        if rows_from_sampler is None:
-            return rows_per_shard
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler

-        return min(rows_from_sampler, rows_per_shard)
+        return self.dataset_size

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -3878,15 +3913,15 @@ class Cifar100Dataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
+        if self.dataset_size is None:
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()

-        if rows_from_sampler is None:
-            return rows_per_shard
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler

-        return min(rows_from_sampler, rows_per_shard)
+        return self.dataset_size

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -3971,16 +4006,16 @@ class RandomDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
+        if self.dataset_size is None:
+            num_rows = CifarOp.get_num_rows(self.dataset_dir, True)

-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()

-        if rows_from_sampler is None:
-            return rows_per_shard
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler

-        return min(rows_from_sampler, rows_per_shard)
+        return self.dataset_size

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -4317,24 +4352,25 @@ class VOCDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self.num_samples is None:
-            num_samples = 0
-        else:
-            num_samples = self.num_samples
+        if self.dataset_size is None:
+            if self.num_samples is None:
+                num_samples = 0
+            else:
+                num_samples = self.num_samples

-        if self.class_indexing is None:
-            class_indexing = dict()
-        else:
-            class_indexing = self.class_indexing
+            if self.class_indexing is None:
+                class_indexing = dict()
+            else:
+                class_indexing = self.class_indexing

-        num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
+            num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()

-        if rows_from_sampler is None:
-            return rows_per_shard
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler

-        return min(rows_from_sampler, rows_per_shard)
+        return self.dataset_size

     def get_class_indexing(self):
         """
@@ -4514,14 +4550,15 @@ class CocoDataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
-        rows_per_shard = get_num_rows(num_rows, self.num_shards)
-        rows_from_sampler = self._get_sampler_dataset_size()
+        if self.dataset_size is None:
+            num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            rows_from_sampler = self._get_sampler_dataset_size()

-        if rows_from_sampler is None:
-            return rows_per_shard
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler

-        return min(rows_from_sampler, rows_per_shard)
+        return self.dataset_size

     def get_class_indexing(self):
         """
@@ -4638,7 +4675,7 @@ class CelebADataset(MappableDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             dir = os.path.realpath(self.dataset_dir)
             attr_file = os.path.join(dir, "list_attr_celeba.txt")
             num_rows = ''
@@ -4649,14 +4686,13 @@ class CelebADataset(MappableDataset):
                 raise RuntimeError("attr_file not found.")
             except BaseException:
                 raise RuntimeError("Get dataset size failed from attribution file.")
-            rows_per_shard = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is not None:
-                rows_per_shard = min(self.num_samples, rows_per_shard)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
             rows_from_sampler = self._get_sampler_dataset_size()
-            if rows_from_sampler is None:
-                return rows_per_shard
-            return min(rows_from_sampler, rows_per_shard)
-        return self._dataset_size
+            if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
+                self.dataset_size = rows_from_sampler
+        return self.dataset_size

     def is_shuffled(self):
         if self.shuffle_level is None:
@@ -4888,13 +4924,12 @@ class CLUEDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = ClueOp.get_num_rows(self.dataset_files)
-            num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is None:
-                return num_rows
-            return min(self.num_samples, num_rows)
-        return self._dataset_size
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples is not None and self.num_samples < self.dataset_size:
+                self.dataset_size = self.num_samples
+        return self.dataset_size

     def is_shuffled(self):
         return self.shuffle_files
@@ -4991,13 +5026,12 @@ class CSVDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
-            num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples == -1:
-                return num_rows
-            return min(self.num_samples, num_rows)
-        return self._dataset_size
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
+            if self.num_samples != -1 and self.num_samples < self.dataset_size:
+                self.dataset_size = num_rows
+        return self.dataset_size

     def is_shuffled(self):
         return self.shuffle_files
@@ -5082,15 +5116,14 @@ class TextFileDataset(SourceDataset):
         Return:
             Number, number of batches.
         """
-        if self._dataset_size is None:
+        if self.dataset_size is None:
             num_rows = TextFileOp.get_num_rows(self.dataset_files)
-            num_rows = get_num_rows(num_rows, self.num_shards)
+            self.dataset_size = get_num_rows(num_rows, self.num_shards)
             # If the user gave a num samples in the dataset, then the sampler will limit the rows returned
             # to that amount. Account for that here in the row count
             if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
-                num_rows = self.num_samples
-            return num_rows
-        return self._dataset_size
+                self.dataset_size = self.num_samples
+        return self.dataset_size

     def is_shuffled(self):
         return self.shuffle_files
@@ -5308,6 +5341,7 @@ class BuildVocabDataset(DatasetOp):
         new_op.vocab = self.vocab
         new_op.special_tokens = copy.deepcopy(self.special_tokens)
         new_op.special_first = copy.deepcopy(self.special_first)
+        new_op.dataset_size = self.dataset_size

         return new_op

@@ -5365,4 +5399,5 @@ class BuildSentencePieceVocabDataset(DatasetOp):
         new_op.params = copy.deepcopy(self.params, memodict)
         new_op.vocab = self.vocab
         new_op.model_type = copy.deepcopy(self.model_type)
+        new_op.dataset_size = self.dataset_size
         return new_op
@@ -72,6 +72,7 @@ class Iterator:
         ITERATORS_LIST.append(weakref.ref(self))
         # create a copy of tree and work on it.
         self.dataset = copy.deepcopy(dataset)
+        self.ori_dataset = dataset
         self.parent_subtree = []

         # The dataset passed into the iterator is not the root of the tree.
@@ -247,6 +248,8 @@ class Iterator:
         if not data:
             if self._index == 0:
                 logger.warning("No records available.")
+            if self.ori_dataset.dataset_size is None:
+                self.ori_dataset.dataset_size = self._index
             raise StopIteration
         self._index += 1
         return data
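With ori_dataset kept on the iterator, a full pass over the data now records the row count on the original dataset object, which is what the new test_generator_dataset_size_5 below exercises; a hedged sketch with a hypothetical five-row generator:

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(5):
        yield (np.array([i]),)

data = ds.GeneratorDataset(gen, ["data"])
count = sum(1 for _ in data.create_dict_iterator())   # exhausting the iterator caches the size
assert data.get_dataset_size() == count               # served from the value recorded above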
@@ -31,7 +31,8 @@ class OneHot(cde.OneHotOp):
     Tensor operation to apply one hot encoding.

     Args:
-        num_classes (int): Number of classes of the label, it should be bigger than feature size.
+        num_classes (int): Number of classes of the label
+            it should be bigger than or equal to label class number.

     Raises:
         RuntimeError: feature size is bigger than num_classes.
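For context, a hedged usage example of the OneHot operation whose docstring is adjusted above; the dataset path, column name, and num_classes are illustrative and assume the c_transforms API of this MindSpore version:

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C

one_hot = C.OneHot(num_classes=10)                        # must cover every label value
data = ds.Cifar10Dataset("./cifar-10-batches-bin")        # placeholder path
data = data.map(operations=one_hot, input_columns=["label"])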
@@ -382,6 +382,25 @@ def test_bucket_batch_multi_column():
     assert same_shape_output == same_shape_expected_output
     assert variable_shape_output == variable_shape_expected_output

+def test_bucket_batch_get_dataset_size():
+    dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
+
+    column_names = ["col1"]
+    bucket_boundaries = [1, 2, 3]
+    bucket_batch_sizes = [3, 3, 2, 2]
+    element_length_function = (lambda x: x[0] % 4)
+
+    dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
+                                             bucket_batch_sizes, element_length_function)
+
+    data_size = dataset.get_dataset_size()
+
+    num_rows = 0
+    for _ in dataset.create_dict_iterator():
+        num_rows += 1
+
+    assert data_size == num_rows
+

 if __name__ == '__main__':
     test_bucket_batch_invalid_input()
@@ -394,3 +413,4 @@ if __name__ == '__main__':
     test_bucket_batch_drop_remainder()
     test_bucket_batch_default_length_function()
     test_bucket_batch_multi_column()
+    test_bucket_batch_get_dataset_size()
@@ -25,6 +25,16 @@ def generator_1d():
     for i in range(64):
         yield (np.array([i]),)

+class DatasetGenerator:
+    def __init__(self):
+        pass
+
+    def __getitem__(self, item):
+        return (np.array([item]),)
+
+    def __len__(self):
+        return 10
+

 def test_generator_0():
     """
@@ -615,6 +625,103 @@ def test_generator_schema():
         type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])


+def test_generator_dataset_size_0():
+    """
+    Test GeneratorDataset get_dataset_size by iterator method.
+    """
+    logger.info("Test 1D Generator : 0 - 63 get_dataset_size")
+
+    data1 = ds.GeneratorDataset(generator_1d, ["data"])
+    data_size = data1.get_dataset_size()
+
+    num_rows = 0
+    for _ in data1.create_dict_iterator():  # each data is a dictionary
+        num_rows = num_rows + 1
+    assert data_size == num_rows
+
+
+def test_generator_dataset_size_1():
+    """
+    Test GeneratorDataset get_dataset_size by __len__ method.
+    """
+    logger.info("Test DatasetGenerator get_dataset_size")
+
+    dataset_generator = DatasetGenerator()
+    data1 = ds.GeneratorDataset(dataset_generator, ["data"])
+
+    data_size = data1.get_dataset_size()
+
+    num_rows = 0
+    for _ in data1.create_dict_iterator():
+        num_rows = num_rows + 1
+    assert data_size == num_rows
+
+
+def test_generator_dataset_size_2():
+    """
+    Test GeneratorDataset + repeat get_dataset_size
+    """
+    logger.info("Test 1D Generator + repeat get_dataset_size")
+
+    data1 = ds.GeneratorDataset(generator_1d, ["data"])
+    data1 = data1.repeat(2)
+
+    data_size = data1.get_dataset_size()
+
+    num_rows = 0
+    for _ in data1.create_dict_iterator():
+        num_rows = num_rows + 1
+    assert data_size == num_rows
+
+
+def test_generator_dataset_size_3():
+    """
+    Test GeneratorDataset + batch get_dataset_size
+    """
+    logger.info("Test 1D Generator + batch get_dataset_size")
+
+    data1 = ds.GeneratorDataset(generator_1d, ["data"])
+    data1 = data1.batch(4)
+
+    data_size = data1.get_dataset_size()
+
+    num_rows = 0
+    for _ in data1.create_dict_iterator():
+        num_rows += 1
+    assert data_size == num_rows
+
+
+def test_generator_dataset_size_4():
+    """
+    Test GeneratorDataset + num_shards
+    """
+    logger.info("Test 1D Generator : 0 - 63 + num_shards get_dataset_size")
+
+    dataset_generator = DatasetGenerator()
+    data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
+    data_size = data1.get_dataset_size()
+
+    num_rows = 0
+    for _ in data1.create_dict_iterator():  # each data is a dictionary
+        num_rows = num_rows + 1
+    assert data_size == num_rows
+
+def test_generator_dataset_size_5():
+    """
+    Test get_dataset_size after create_dict_iterator
+    """
+    logger.info("Test get_dataset_size after create_dict_iterator")
+
+    dataset_generator = DatasetGenerator()
+    data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
+
+    num_rows = 0
+    for _ in data1.create_dict_iterator():  # each data is a dictionary
+        num_rows = num_rows + 1
+    data_size = data1.get_dataset_size()
+    assert data_size == num_rows
+
+
 def manual_test_generator_keyboard_interrupt():
     """
     Test keyboard_interrupt
@@ -663,3 +770,9 @@ if __name__ == "__main__":
     test_generator_num_samples()
     test_generator_num_samples_underflow()
     test_generator_schema()
+    test_generator_dataset_size_0()
+    test_generator_dataset_size_1()
+    test_generator_dataset_size_2()
+    test_generator_dataset_size_3()
+    test_generator_dataset_size_4()
+    test_generator_dataset_size_5()
@@ -484,6 +484,16 @@ def test_filter_by_generator_with_map_all_sort():
     assert ret_data[0]["col1"] == 0
     assert ret_data[9]["col6"] == 509

+def test_filter_by_generator_get_dataset_size():
+    dataset = ds.GeneratorDataset(generator_1d, ["data"])
+    dataset = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
+    data_sie = dataset.get_dataset_size()
+
+    num_iter = 0
+    for _ in dataset.create_dict_iterator():
+        num_iter += 1
+    assert data_sie == num_iter
+

 if __name__ == '__main__':
     test_diff_predicate_func()
@@ -506,3 +516,4 @@ if __name__ == '__main__':
     test_filter_by_generator_with_zip()
     test_filter_by_generator_with_zip_after()
     test_filter_by_generator_Partial()
+    test_filter_by_generator_get_dataset_size()