!1591 add get_dataset_size for CelebADataset
Merge pull request !1591 from yanghaitao/yht_celeba_get_dataset_size
This commit is contained in:
commit
21da86b393
|
@ -4024,6 +4024,31 @@ class CelebADataset(MappableDataset):
|
|||
args["shard_id"] = self.shard_id
|
||||
return args
|
||||
|
||||
def get_dataset_size(self):
|
||||
"""
|
||||
Get the number of batches in an epoch.
|
||||
|
||||
Return:
|
||||
Number, number of batches.
|
||||
"""
|
||||
if self._dataset_size is None:
|
||||
dir = os.path.realpath(self.dataset_dir)
|
||||
attr_file = os.path.join(dir, "list_attr_celeba.txt")
|
||||
num_rows = ''
|
||||
try:
|
||||
with open(attr_file, 'r') as f:
|
||||
num_rows = int(f.readline())
|
||||
except Exception:
|
||||
raise RuntimeError("Get dataset size failed from attribution file.")
|
||||
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is not None:
|
||||
rows_per_shard = min(self.num_samples, rows_per_shard)
|
||||
rows_from_sampler = self._get_sampler_dataset_size()
|
||||
if rows_from_sampler is None:
|
||||
return rows_per_shard
|
||||
return min(rows_from_sampler, rows_per_shard)
|
||||
return self._dataset_size
|
||||
|
||||
def is_shuffled(self):
|
||||
if self.shuffle_level is None:
|
||||
return True
|
||||
|
|
|
@ -85,9 +85,14 @@ def test_celeba_dataset_distribute():
|
|||
count = count + 1
|
||||
assert count == 1
|
||||
|
||||
def test_celeba_get_dataset_size():
|
||||
data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
|
||||
size = data.get_dataset_size()
|
||||
assert size == 2
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_celeba_dataset_label()
|
||||
test_celeba_dataset_op()
|
||||
test_celeba_dataset_ext()
|
||||
test_celeba_dataset_distribute()
|
||||
test_celeba_get_dataset_size()
|
||||
|
|
Loading…
Reference in New Issue