Fix get_dataset_size in CelebADataset when usage is not 'all'

This commit is contained in:
yanghaitao1 2020-09-17 23:31:11 -04:00
parent ba2fe29691
commit 4ff4c17632
7 changed files with 67 additions and 16 deletions

View File

@ -4974,9 +4974,30 @@ class CelebADataset(MappableDataset):
with open(attr_file, 'r') as f: with open(attr_file, 'r') as f:
num_rows = int(f.readline()) num_rows = int(f.readline())
except FileNotFoundError: except FileNotFoundError:
raise RuntimeError("attr_file not found.") raise RuntimeError("attr file can not be found.")
except BaseException: except BaseException:
raise RuntimeError("Get dataset size failed from attribution file.") raise RuntimeError("Get dataset size failed from attribution file.")
if self.usage != 'all':
partition_file = os.path.join(dir, "list_eval_partition.txt")
usage_type = 0
partition_num = 0
if self.usage == "train":
usage_type = 0
elif self.usage == "valid":
usage_type = 1
elif self.usage == "test":
usage_type = 2
try:
with open(partition_file, 'r') as f:
for line in f.readlines():
split_line = line.split(' ')
if int(split_line[1]) == usage_type:
partition_num += 1
except FileNotFoundError:
raise RuntimeError("Partition file can not be found")
if partition_num < num_rows:
num_rows = partition_num
self.dataset_size = get_num_rows(num_rows, self.num_shards) self.dataset_size = get_num_rows(num_rows, self.num_shards)
if self.num_samples is not None and self.num_samples < self.dataset_size: if self.num_samples is not None and self.num_samples < self.dataset_size:
self.dataset_size = self.num_samples self.dataset_size = self.num_samples

View File

@ -100,7 +100,7 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
i++; i++;
} }
EXPECT_EQ(i, 2); EXPECT_EQ(i, 4);
// Manually terminate the pipeline // Manually terminate the pipeline
iter->Stop(); iter->Stop();

View File

@ -58,8 +58,10 @@ protected:
TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) { TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
std::string dir = datasets_root_path_ + "/testCelebAData/"; std::string dir = datasets_root_path_ + "/testCelebAData/";
uint32_t expect_labels[2][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
uint32_t count = 0; uint32_t count = 0;
auto tree = Build({Celeba(16, 2, 32, dir)}); auto tree = Build({Celeba(16, 2, 32, dir)});
tree->Prepare(); tree->Prepare();
@ -81,16 +83,20 @@ TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
count++; count++;
di.GetNextAsMap(&tersor_map); di.GetNextAsMap(&tersor_map);
} }
EXPECT_TRUE(count == 2); EXPECT_TRUE(count == 4);
} }
} }
TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
std::string dir = datasets_root_path_ + "/testCelebAData/"; std::string dir = datasets_root_path_ + "/testCelebAData/";
uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, uint32_t expect_labels[8][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
uint32_t count = 0; uint32_t count = 0;
auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)}); auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)});
tree->Prepare(); tree->Prepare();
@ -112,7 +118,7 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
count++; count++;
di.GetNextAsMap(&tersor_map); di.GetNextAsMap(&tersor_map);
} }
EXPECT_TRUE(count == 4); EXPECT_TRUE(count == 8);
} }
} }

View File

@ -1,4 +1,6 @@
2 4
5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young 5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1 1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1

View File

@ -0,0 +1,4 @@
1.JPEG 0
2.jpg 1
2.jpg 2
2.jpg 0

View File

@ -25,6 +25,10 @@ def test_celeba_dataset_label():
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
0, 0, 1], 0, 0, 1],
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 1],
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 1],
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
0, 0, 1]] 0, 0, 1]]
count = 0 count = 0
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
@ -35,7 +39,7 @@ def test_celeba_dataset_label():
for index in range(len(expect_labels[count])): for index in range(len(expect_labels[count])):
assert item["attr"][index] == expect_labels[count][index] assert item["attr"][index] == expect_labels[count][index]
count = count + 1 count = count + 1
assert count == 2 assert count == 4
def test_celeba_dataset_op(): def test_celeba_dataset_op():
@ -54,14 +58,17 @@ def test_celeba_dataset_op():
logger.info("----------image--------") logger.info("----------image--------")
logger.info(item["image"]) logger.info(item["image"])
count = count + 1 count = count + 1
assert count == 4 assert count == 8
def test_celeba_dataset_ext(): def test_celeba_dataset_ext():
ext = [".JPEG"] ext = [".JPEG"]
data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext) data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
expect_labels = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, expect_labels = [
0, 1, 0, 1, 0, 0, 1], [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
0, 1, 0, 1, 0, 0, 1],
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
0, 1, 0, 1, 0, 0, 1]]
count = 0 count = 0
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("----------image--------") logger.info("----------image--------")
@ -71,7 +78,7 @@ def test_celeba_dataset_ext():
for index in range(len(expect_labels[count])): for index in range(len(expect_labels[count])):
assert item["attr"][index] == expect_labels[count][index] assert item["attr"][index] == expect_labels[count][index]
count = count + 1 count = count + 1
assert count == 1 assert count == 2
def test_celeba_dataset_distribute(): def test_celeba_dataset_distribute():
@ -83,14 +90,25 @@ def test_celeba_dataset_distribute():
logger.info("----------attr--------") logger.info("----------attr--------")
logger.info(item["attr"]) logger.info(item["attr"])
count = count + 1 count = count + 1
assert count == 1 assert count == 2
def test_celeba_get_dataset_size(): def test_celeba_get_dataset_size():
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True) data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
size = data.get_dataset_size() size = data.get_dataset_size()
assert size == 4
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train")
size = data.get_dataset_size()
assert size == 2 assert size == 2
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid")
size = data.get_dataset_size()
assert size == 1
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test")
size = data.get_dataset_size()
assert size == 1
if __name__ == '__main__': if __name__ == '__main__':
test_celeba_dataset_label() test_celeba_dataset_label()

View File

@ -504,7 +504,7 @@ def test_celeba_padded():
count = 0 count = 0
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
count = count + 1 count = count + 1
assert count == 2 assert count == 4
if __name__ == '__main__': if __name__ == '__main__':