all Dataset support get_dataset_size

This commit is contained in:
anzhengqi 2020-08-12 15:35:47 +08:00
parent 0082c0e5df
commit 3e31ac6d62
6 changed files with 336 additions and 153 deletions

View File

@ -143,7 +143,7 @@ class Dataset:
self._input_indexs = ()
self._output_types = None
self._output_shapes = None
self._dataset_size = None
self.dataset_size = None
self._batch_size = None
self._num_classes = None
self._repeat_count = None
@ -1189,8 +1189,6 @@ class Dataset:
device_iter = TupleIterator(self)
self._output_shapes = device_iter.get_output_shapes()
self._output_types = device_iter.get_output_types()
if self._dataset_size is None:
self._dataset_size = device_iter.get_dataset_size()
self._batch_size = device_iter.get_batch_size()
self._num_classes = device_iter.num_classes()
self._repeat_count = device_iter.get_repeat_count()
@ -1225,9 +1223,10 @@ class Dataset:
Return:
Number, number of batches.
"""
if self.children:
return self.children[0].get_dataset_size()
return None
if self.dataset_size is None:
if self.children:
self.dataset_size = self.children[0].get_dataset_size()
return self.dataset_size
def num_classes(self):
"""
@ -1378,6 +1377,8 @@ class MappableDataset(SourceDataset):
def add_sampler(self, new_sampler):
# note: by adding a sampler, we mean that the sampled ids will flow to new_sampler
# after first passing through the current samplers attached to this dataset.
if self.dataset_size is not None:
self.dataset_size = None
new_sampler.add_child(self.sampler)
self.sampler = new_sampler
@ -1406,6 +1407,8 @@ class MappableDataset(SourceDataset):
raise TypeError("Input sampler can not be None.")
if not isinstance(new_sampler, (samplers.BuiltinSampler, samplers.Sampler)):
raise TypeError("Input sampler is not an instance of a sampler.")
if self.dataset_size is not None:
self.dataset_size = None
self.sampler = self.sampler.child_sampler
self.add_sampler(new_sampler)
@ -1505,6 +1508,7 @@ class MappableDataset(SourceDataset):
current_split_start_index = 0
for size in absolute_sizes:
ds = copy.deepcopy(self)
ds.dataset_size = None
if randomize:
# want to shuffle the same way every epoch before split, we are assuming
# that the user will call set_seed
@ -1582,7 +1586,12 @@ class BucketBatchByLengthDataset(DatasetOp):
Return:
Number, number of batches.
"""
return None
if self.dataset_size is None:
num_rows = 0
for _ in self.create_dict_iterator():
num_rows += 1
self.dataset_size = num_rows
return self.dataset_size
class BatchDataset(DatasetOp):
@ -1643,12 +1652,14 @@ class BatchDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.children[0].get_dataset_size()
if child_size is not None and isinstance(self.batch_size, int):
if self.drop_remainder:
return math.floor(child_size / self.batch_size)
return math.ceil(child_size / self.batch_size)
return None
if self.dataset_size is None:
child_size = self.children[0].get_dataset_size()
if child_size is not None and isinstance(self.batch_size, int):
if self.drop_remainder:
self.dataset_size = math.floor(child_size / self.batch_size)
else:
self.dataset_size = math.ceil(child_size / self.batch_size)
return self.dataset_size
def get_batch_size(self):
"""
@ -2000,7 +2011,9 @@ class MapDataset(DatasetOp):
Return:
Number, number of batches.
"""
return self.children[0].get_dataset_size()
if self.dataset_size is None:
self.dataset_size = self.children[0].get_dataset_size()
return self.dataset_size
def __deepcopy__(self, memodict):
if id(self) in memodict:
@ -2019,6 +2032,7 @@ class MapDataset(DatasetOp):
new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
new_op.cache = copy.deepcopy(self.cache, memodict)
new_op.operations = self.operations
new_op.dataset_size = self.dataset_size
return new_op
# Iterator bootstrap will be called on iterator construction.
@ -2091,11 +2105,16 @@ class FilterDataset(DatasetOp):
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
the size cannot be determined before we run the pipeline.
Return:
0
Number, num of batches.
"""
return 0
if self.dataset_size is None:
num_rows = 0
for _ in self.create_dict_iterator():
num_rows += 1
self.dataset_size = num_rows
return self.dataset_size
class RepeatDataset(DatasetOp):
@ -2129,10 +2148,11 @@ class RepeatDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.children[0].get_dataset_size()
if child_size is not None:
return child_size * self.count
return None
if self.dataset_size is None:
child_size = self.children[0].get_dataset_size()
if child_size is not None:
self.dataset_size = child_size * self.count
return self.dataset_size
def get_repeat_count(self):
"""
@ -2172,11 +2192,12 @@ class SkipDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.children[0].get_dataset_size()
output_size = 0
if self.count >= 0 and self.count < child_size:
output_size = child_size - self.count
return output_size
if self.dataset_size is None:
child_size = self.children[0].get_dataset_size()
self.dataset_size = 0
if self.count >= 0 and self.count < child_size:
self.dataset_size = child_size - self.count
return self.dataset_size
class TakeDataset(DatasetOp):
@ -2207,10 +2228,13 @@ class TakeDataset(DatasetOp):
Return:
Number, number of batches.
"""
child_size = self.children[0].get_dataset_size()
if child_size < self.count:
return child_size
return self.count
if self.dataset_size is None:
child_size = self.children[0].get_dataset_size()
if child_size < self.count:
self.dataset_size = child_size
else:
self.dataset_size = self.count
return self.dataset_size
class ZipDataset(DatasetOp):
@ -2241,10 +2265,11 @@ class ZipDataset(DatasetOp):
Return:
Number, number of batches.
"""
children_sizes = [c.get_dataset_size() for c in self.children]
if all(c is not None for c in children_sizes):
return min(children_sizes)
return None
if self.dataset_size is None:
children_sizes = [c.get_dataset_size() for c in self.children]
if all(c is not None for c in children_sizes):
self.dataset_size = min(children_sizes)
return self.dataset_size
def num_classes(self):
"""
@ -2291,9 +2316,10 @@ class ConcatDataset(DatasetOp):
Return:
Number, number of batches.
"""
children_sizes = [c.get_dataset_size() for c in self.children]
dataset_size = sum(children_sizes)
return dataset_size
if self.dataset_size is None:
children_sizes = [c.get_dataset_size() for c in self.children]
self.dataset_size = sum(children_sizes)
return self.dataset_size
class RenameDataset(DatasetOp):
@ -2439,6 +2465,11 @@ class RangeDataset(MappableDataset):
def is_sharded(self):
return False
def get_dataset_size(self):
if self.dataset_size is None:
self.dataset_size = math.ceil((self.stop - self.start)/self.step)
return self.dataset_size
def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id, non_mappable=False):
"""
@ -2617,14 +2648,13 @@ class ImageFolderDatasetV2(MappableDataset):
Return:
Number, number of batches.
"""
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
if self.dataset_size is None:
num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0]
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
return self.dataset_size
def num_classes(self):
"""
@ -2758,14 +2788,13 @@ class MnistDataset(MappableDataset):
Return:
Number, number of batches.
"""
num_rows = MnistOp.get_num_rows(self.dataset_dir)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
if self.dataset_size is None:
num_rows = MnistOp.get_num_rows(self.dataset_dir)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
return self.dataset_size
def is_shuffled(self):
if self.shuffle_level is None:
@ -2868,20 +2897,20 @@ class MindDataset(MappableDataset):
Return:
Number, number of batches.
"""
if self._dataset_size is None:
if self.dataset_size is None:
if self.load_dataset:
dataset_file = [self.dataset_file]
else:
dataset_file = self.dataset_file
num_rows = MindRecordOp.get_num_rows(dataset_file, self.load_dataset, self.sampler, self.num_padded)
return num_rows
return self._dataset_size
self.dataset_size = num_rows
return self.dataset_size
# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self._dataset_size = value
self.dataset_size = value
else:
raise ValueError('Set dataset_size with negative value {}'.format(value))
@ -3205,6 +3234,7 @@ class GeneratorDataset(MappableDataset):
self.source = source
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples
self.num_shards = num_shards
if column_names is not None and not isinstance(column_names, list):
column_names = [column_names]
@ -3225,9 +3255,6 @@ class GeneratorDataset(MappableDataset):
self.column_names.append(col["name"])
self.column_types.append(DataType(col["type"]))
if source is not None and hasattr(source, "__len__"):
self._dataset_size = len(source)
def get_args(self):
args = super().get_args()
args["source"] = self.source
@ -3242,19 +3269,27 @@ class GeneratorDataset(MappableDataset):
Return:
Number, number of batches.
"""
rows_from_sampler = self._get_sampler_dataset_size()
if self.dataset_size is None:
if hasattr(self.source, "__len__"):
if not self.num_shards:
self.dataset_size = len(self.source)
else:
self.dataset_size = math.ceil(len(self.source)/self.num_shards)
if rows_from_sampler is None:
return self._dataset_size
if self._dataset_size is None:
return None
return min(rows_from_sampler, self._dataset_size)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
else:
num_rows = 0
for _ in self.create_dict_iterator():
num_rows += 1
self.dataset_size = num_rows
return self.dataset_size
# manually set dataset_size as a temporary solution.
def set_dataset_size(self, value):
if value >= 0:
self._dataset_size = value
self.dataset_size = value
else:
raise ValueError('Set dataset_size with negative value {}'.format(value))
@ -3271,6 +3306,7 @@ class GeneratorDataset(MappableDataset):
new_op.column_types = copy.deepcopy(self.column_types, memodict)
new_op.column_names = copy.deepcopy(self.column_names, memodict)
new_op.num_samples = copy.deepcopy(self.num_samples, memodict)
new_op.dataset_size = self.dataset_size
new_op.sampler = copy.deepcopy(self.sampler)
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
@ -3433,19 +3469,18 @@ class TFRecordDataset(SourceDataset):
Return:
Number, number of batches.
"""
if self._dataset_size is None:
if self.dataset_size is None:
num_rows = TFReaderOp.get_num_rows(self.dataset_files, 8, estimate)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples is None:
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
self.dataset_size = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None and self.num_samples < self.dataset_size:
self.dataset_size = self.num_samples
return self.dataset_size
# manually set dataset_size as a tempoary solution.
def set_dataset_size(self, value):
logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
if value >= 0:
self._dataset_size = value
self.dataset_size = value
else:
raise ValueError('Set dataset_size with negative value {}'.format(value))
@ -3574,19 +3609,19 @@ class ManifestDataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
if self.dataset_size is None:
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0]
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
return self.dataset_size
def num_classes(self):
"""
@ -3742,15 +3777,15 @@ class Cifar10Dataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.dataset_size is None:
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
return self.dataset_size
def is_shuffled(self):
if self.shuffle_level is None:
@ -3878,15 +3913,15 @@ class Cifar100Dataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.dataset_size is None:
num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
num_rows = CifarOp.get_num_rows(self.dataset_dir, False)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
return self.dataset_size
def is_shuffled(self):
if self.shuffle_level is None:
@ -3971,16 +4006,16 @@ class RandomDataset(SourceDataset):
Return:
Number, number of batches.
"""
if self.dataset_size is None:
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
num_rows = CifarOp.get_num_rows(self.dataset_dir, True)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
return self.dataset_size
def is_shuffled(self):
if self.shuffle_level is None:
@ -4317,24 +4352,25 @@ class VOCDataset(MappableDataset):
Return:
Number, number of batches.
"""
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
if self.dataset_size is None:
if self.num_samples is None:
num_samples = 0
else:
num_samples = self.num_samples
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
if self.class_indexing is None:
class_indexing = dict()
else:
class_indexing = self.class_indexing
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
return min(rows_from_sampler, rows_per_shard)
return self.dataset_size
def get_class_indexing(self):
"""
@ -4514,14 +4550,15 @@ class CocoDataset(MappableDataset):
Return:
Number, number of batches.
"""
num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
rows_per_shard = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if self.dataset_size is None:
num_rows = CocoOp.get_num_rows(self.dataset_dir, self.annotation_file, self.task)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
return min(rows_from_sampler, rows_per_shard)
return self.dataset_size
def get_class_indexing(self):
"""
@ -4638,7 +4675,7 @@ class CelebADataset(MappableDataset):
Return:
Number, number of batches.
"""
if self._dataset_size is None:
if self.dataset_size is None:
dir = os.path.realpath(self.dataset_dir)
attr_file = os.path.join(dir, "list_attr_celeba.txt")
num_rows = ''
@ -4649,14 +4686,13 @@ class CelebADataset(MappableDataset):
raise RuntimeError("attr_file not found.")
except BaseException:
raise RuntimeError("Get dataset size failed from attribution file.")
rows_per_shard = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None:
rows_per_shard = min(self.num_samples, rows_per_shard)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None and self.num_samples < self.dataset_size:
self.dataset_size = self.num_samples
rows_from_sampler = self._get_sampler_dataset_size()
if rows_from_sampler is None:
return rows_per_shard
return min(rows_from_sampler, rows_per_shard)
return self._dataset_size
if rows_from_sampler is not None and rows_from_sampler < self.dataset_size:
self.dataset_size = rows_from_sampler
return self.dataset_size
def is_shuffled(self):
if self.shuffle_level is None:
@ -4888,13 +4924,12 @@ class CLUEDataset(SourceDataset):
Return:
Number, number of batches.
"""
if self._dataset_size is None:
if self.dataset_size is None:
num_rows = ClueOp.get_num_rows(self.dataset_files)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples is None:
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
self.dataset_size = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None and self.num_samples < self.dataset_size:
self.dataset_size = self.num_samples
return self.dataset_size
def is_shuffled(self):
return self.shuffle_files
@ -4991,13 +5026,12 @@ class CSVDataset(SourceDataset):
Return:
Number, number of batches.
"""
if self._dataset_size is None:
if self.dataset_size is None:
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
num_rows = get_num_rows(num_rows, self.num_shards)
if self.num_samples == -1:
return num_rows
return min(self.num_samples, num_rows)
return self._dataset_size
self.dataset_size = get_num_rows(num_rows, self.num_shards)
if self.num_samples != -1 and self.num_samples < self.dataset_size:
self.dataset_size = num_rows
return self.dataset_size
def is_shuffled(self):
return self.shuffle_files
@ -5082,15 +5116,14 @@ class TextFileDataset(SourceDataset):
Return:
Number, number of batches.
"""
if self._dataset_size is None:
if self.dataset_size is None:
num_rows = TextFileOp.get_num_rows(self.dataset_files)
num_rows = get_num_rows(num_rows, self.num_shards)
self.dataset_size = get_num_rows(num_rows, self.num_shards)
# If the user gave a num samples in the dataset, then the sampler will limit the rows returned
# to that amount. Account for that here in the row count
if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples:
num_rows = self.num_samples
return num_rows
return self._dataset_size
self.dataset_size = self.num_samples
return self.dataset_size
def is_shuffled(self):
return self.shuffle_files
@ -5308,6 +5341,7 @@ class BuildVocabDataset(DatasetOp):
new_op.vocab = self.vocab
new_op.special_tokens = copy.deepcopy(self.special_tokens)
new_op.special_first = copy.deepcopy(self.special_first)
new_op.dataset_size = self.dataset_size
return new_op
@ -5365,4 +5399,5 @@ class BuildSentencePieceVocabDataset(DatasetOp):
new_op.params = copy.deepcopy(self.params, memodict)
new_op.vocab = self.vocab
new_op.model_type = copy.deepcopy(self.model_type)
new_op.dataset_size = self.dataset_size
return new_op

View File

@ -72,6 +72,7 @@ class Iterator:
ITERATORS_LIST.append(weakref.ref(self))
# create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset)
self.ori_dataset = dataset
self.parent_subtree = []
# The dataset passed into the iterator is not the root of the tree.
@ -247,6 +248,8 @@ class Iterator:
if not data:
if self._index == 0:
logger.warning("No records available.")
if self.ori_dataset.dataset_size is None:
self.ori_dataset.dataset_size = self._index
raise StopIteration
self._index += 1
return data

View File

@ -31,7 +31,8 @@ class OneHot(cde.OneHotOp):
Tensor operation to apply one hot encoding.
Args:
num_classes (int): Number of classes of the label, it should be bigger than feature size.
num_classes (int): Number of classes of the label
it should be bigger than or equal to label class number.
Raises:
RuntimeError: feature size is bigger than num_classes.

View File

@ -382,6 +382,25 @@ def test_bucket_batch_multi_column():
assert same_shape_output == same_shape_expected_output
assert variable_shape_output == variable_shape_expected_output
def test_bucket_batch_get_dataset_size():
dataset = ds.GeneratorDataset((lambda: generate_sequential_same_shape(10)), ["col1"])
column_names = ["col1"]
bucket_boundaries = [1, 2, 3]
bucket_batch_sizes = [3, 3, 2, 2]
element_length_function = (lambda x: x[0] % 4)
dataset = dataset.bucket_batch_by_length(column_names, bucket_boundaries,
bucket_batch_sizes, element_length_function)
data_size = dataset.get_dataset_size()
num_rows = 0
for _ in dataset.create_dict_iterator():
num_rows += 1
assert data_size == num_rows
if __name__ == '__main__':
test_bucket_batch_invalid_input()
@ -394,3 +413,4 @@ if __name__ == '__main__':
test_bucket_batch_drop_remainder()
test_bucket_batch_default_length_function()
test_bucket_batch_multi_column()
test_bucket_batch_get_dataset_size()

View File

@ -25,6 +25,16 @@ def generator_1d():
for i in range(64):
yield (np.array([i]),)
class DatasetGenerator:
def __init__(self):
pass
def __getitem__(self, item):
return (np.array([item]),)
def __len__(self):
return 10
def test_generator_0():
"""
@ -615,6 +625,103 @@ def test_generator_schema():
type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
def test_generator_dataset_size_0():
"""
Test GeneratorDataset get_dataset_size by iterator method.
"""
logger.info("Test 1D Generator : 0 - 63 get_dataset_size")
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data_size = data1.get_dataset_size()
num_rows = 0
for _ in data1.create_dict_iterator(): # each data is a dictionary
num_rows = num_rows + 1
assert data_size == num_rows
def test_generator_dataset_size_1():
"""
Test GeneratorDataset get_dataset_size by __len__ method.
"""
logger.info("Test DatasetGenerator get_dataset_size")
dataset_generator = DatasetGenerator()
data1 = ds.GeneratorDataset(dataset_generator, ["data"])
data_size = data1.get_dataset_size()
num_rows = 0
for _ in data1.create_dict_iterator():
num_rows = num_rows + 1
assert data_size == num_rows
def test_generator_dataset_size_2():
"""
Test GeneratorDataset + repeat get_dataset_size
"""
logger.info("Test 1D Generator + repeat get_dataset_size")
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.repeat(2)
data_size = data1.get_dataset_size()
num_rows = 0
for _ in data1.create_dict_iterator():
num_rows = num_rows + 1
assert data_size == num_rows
def test_generator_dataset_size_3():
"""
Test GeneratorDataset + batch get_dataset_size
"""
logger.info("Test 1D Generator + batch get_dataset_size")
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.batch(4)
data_size = data1.get_dataset_size()
num_rows = 0
for _ in data1.create_dict_iterator():
num_rows += 1
assert data_size == num_rows
def test_generator_dataset_size_4():
"""
Test GeneratorDataset + num_shards
"""
logger.info("Test 1D Generator : 0 - 63 + num_shards get_dataset_size")
dataset_generator = DatasetGenerator()
data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
data_size = data1.get_dataset_size()
num_rows = 0
for _ in data1.create_dict_iterator(): # each data is a dictionary
num_rows = num_rows + 1
assert data_size == num_rows
def test_generator_dataset_size_5():
"""
Test get_dataset_size after create_dict_iterator
"""
logger.info("Test get_dataset_size after create_dict_iterator")
dataset_generator = DatasetGenerator()
data1 = ds.GeneratorDataset(dataset_generator, ["data"], num_shards=3, shard_id=0)
num_rows = 0
for _ in data1.create_dict_iterator(): # each data is a dictionary
num_rows = num_rows + 1
data_size = data1.get_dataset_size()
assert data_size == num_rows
def manual_test_generator_keyboard_interrupt():
"""
Test keyboard_interrupt
@ -663,3 +770,9 @@ if __name__ == "__main__":
test_generator_num_samples()
test_generator_num_samples_underflow()
test_generator_schema()
test_generator_dataset_size_0()
test_generator_dataset_size_1()
test_generator_dataset_size_2()
test_generator_dataset_size_3()
test_generator_dataset_size_4()
test_generator_dataset_size_5()

View File

@ -484,6 +484,16 @@ def test_filter_by_generator_with_map_all_sort():
assert ret_data[0]["col1"] == 0
assert ret_data[9]["col6"] == 509
def test_filter_by_generator_get_dataset_size():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
data_sie = dataset.get_dataset_size()
num_iter = 0
for _ in dataset.create_dict_iterator():
num_iter += 1
assert data_sie == num_iter
if __name__ == '__main__':
test_diff_predicate_func()
@ -506,3 +516,4 @@ if __name__ == '__main__':
test_filter_by_generator_with_zip()
test_filter_by_generator_with_zip_after()
test_filter_by_generator_Partial()
test_filter_by_generator_get_dataset_size()