!1591 add get_dataset_size for CelebADataset
Merge pull request !1591 from yanghaitao/yht_celeba_get_dataset_size
This commit is contained in:
commit
21da86b393
|
@ -4024,6 +4024,31 @@ class CelebADataset(MappableDataset):
|
||||||
args["shard_id"] = self.shard_id
|
args["shard_id"] = self.shard_id
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
def get_dataset_size(self):
|
||||||
|
"""
|
||||||
|
Get the number of batches in an epoch.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Number, number of batches.
|
||||||
|
"""
|
||||||
|
if self._dataset_size is None:
|
||||||
|
dir = os.path.realpath(self.dataset_dir)
|
||||||
|
attr_file = os.path.join(dir, "list_attr_celeba.txt")
|
||||||
|
num_rows = ''
|
||||||
|
try:
|
||||||
|
with open(attr_file, 'r') as f:
|
||||||
|
num_rows = int(f.readline())
|
||||||
|
except Exception:
|
||||||
|
raise RuntimeError("Get dataset size failed from attribution file.")
|
||||||
|
rows_per_shard = get_num_rows(num_rows, self.num_shards)
|
||||||
|
if self.num_samples is not None:
|
||||||
|
rows_per_shard = min(self.num_samples, rows_per_shard)
|
||||||
|
rows_from_sampler = self._get_sampler_dataset_size()
|
||||||
|
if rows_from_sampler is None:
|
||||||
|
return rows_per_shard
|
||||||
|
return min(rows_from_sampler, rows_per_shard)
|
||||||
|
return self._dataset_size
|
||||||
|
|
||||||
def is_shuffled(self):
|
def is_shuffled(self):
|
||||||
if self.shuffle_level is None:
|
if self.shuffle_level is None:
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -85,9 +85,14 @@ def test_celeba_dataset_distribute():
|
||||||
count = count + 1
|
count = count + 1
|
||||||
assert count == 1
|
assert count == 1
|
||||||
|
|
||||||
|
def test_celeba_get_dataset_size():
|
||||||
|
data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
|
||||||
|
size = data.get_dataset_size()
|
||||||
|
assert size == 2
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_celeba_dataset_label()
|
test_celeba_dataset_label()
|
||||||
test_celeba_dataset_op()
|
test_celeba_dataset_op()
|
||||||
test_celeba_dataset_ext()
|
test_celeba_dataset_ext()
|
||||||
test_celeba_dataset_distribute()
|
test_celeba_dataset_distribute()
|
||||||
|
test_celeba_get_dataset_size()
|
||||||
|
|
Loading…
Reference in New Issue