fix get_dataset_size in CelebADataset when usage is not all
parent ba2fe29691
commit 4ff4c17632
@@ -4974,9 +4974,30 @@ class CelebADataset(MappableDataset):
                 with open(attr_file, 'r') as f:
                     num_rows = int(f.readline())
             except FileNotFoundError:
-                raise RuntimeError("attr_file not found.")
+                raise RuntimeError("attr file can not be found.")
             except BaseException:
                 raise RuntimeError("Get dataset size failed from attribution file.")
+            if self.usage != 'all':
+                partition_file = os.path.join(dir, "list_eval_partition.txt")
+                usage_type = 0
+                partition_num = 0
+                if self.usage == "train":
+                    usage_type = 0
+                elif self.usage == "valid":
+                    usage_type = 1
+                elif self.usage == "test":
+                    usage_type = 2
+                try:
+                    with open(partition_file, 'r') as f:
+                        for line in f.readlines():
+                            split_line = line.split(' ')
+                            if int(split_line[1]) == usage_type:
+                                partition_num += 1
+                except FileNotFoundError:
+                    raise RuntimeError("Partition file can not be found")
+                if partition_num < num_rows:
+                    num_rows = partition_num
+
             self.dataset_size = get_num_rows(num_rows, self.num_shards)
             if self.num_samples is not None and self.num_samples < self.dataset_size:
                 self.dataset_size = self.num_samples

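In short: when usage is not 'all', get_dataset_size() now also counts how many rows of list_eval_partition.txt belong to the requested split and caps the row count read from list_attr_celeba.txt by that number. The snippet below is a minimal standalone sketch of that counting step, not MindSpore code; the helper name count_partition_rows is made up for illustration, and it assumes the usual CelebA partition format of one "<file name> <partition index>" pair per line.

    import os

    def count_partition_rows(dataset_dir, usage):
        # Map the usage string to the partition index used in list_eval_partition.txt;
        # the commit maps train -> 0, valid -> 1, test -> 2 and defaults to 0.
        usage_type = {"train": 0, "valid": 1, "test": 2}.get(usage, 0)
        partition_file = os.path.join(dataset_dir, "list_eval_partition.txt")
        partition_num = 0
        with open(partition_file, 'r') as f:
            for line in f:
                split_line = line.split()
                # Count every entry whose partition index matches the requested split.
                if len(split_line) >= 2 and int(split_line[1]) == usage_type:
                    partition_num += 1
        return partition_num

Run against the fixture added below (1.JPEG 0, 2.jpeg 1, 2.jpeg 2, 2.jpeg 0), this counts 2 rows for "train", 1 for "valid" and 1 for "test", which is exactly what the new assertions in test_celeba_get_dataset_size expect; get_dataset_size() then derives the reported size from the smaller of this count and the attribute-file row count, further adjusted for num_shards and num_samples.
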
@@ -100,7 +100,7 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
     i++;
   }

-  EXPECT_EQ(i, 2);
+  EXPECT_EQ(i, 4);

   // Manually terminate the pipeline
   iter->Stop();

@@ -58,8 +58,10 @@ protected:

 TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
   std::string dir = datasets_root_path_ + "/testCelebAData/";
-  uint32_t expect_labels[2][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
-                                   {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
+  uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
+                                   {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
+                                   {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
+                                   {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
   uint32_t count = 0;
   auto tree = Build({Celeba(16, 2, 32, dir)});
   tree->Prepare();

@@ -81,16 +83,20 @@ TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
       count++;
       di.GetNextAsMap(&tersor_map);
     }
-    EXPECT_TRUE(count == 2);
+    EXPECT_TRUE(count == 4);
   }
 }

 TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
   std::string dir = datasets_root_path_ + "/testCelebAData/";
-  uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
+  uint32_t expect_labels[8][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
+                                   {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
                                    {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
                                    {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
-                                   {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
+                                   {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
+                                   {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
+                                   {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
+                                   {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
   uint32_t count = 0;
   auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)});
   tree->Prepare();

@@ -112,7 +118,7 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
       count++;
       di.GetNextAsMap(&tersor_map);
     }
-    EXPECT_TRUE(count == 4);
+    EXPECT_TRUE(count == 8);
   }
 }

@@ -1,4 +1,6 @@
-2
+4
 5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
 1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
 2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
+2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
+1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1

@@ -0,0 +1,4 @@
+1.JPEG 0
+2.jpeg 1
+2.jpeg 2
+2.jpeg 0

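(Each line of this new list_eval_partition.txt fixture pairs an image name with a partition index; per the counting logic above, 0, 1 and 2 stand for train, valid and test, so the fixture contributes 2 train, 1 valid and 1 test entries, matching the sizes asserted in the updated test_celeba_get_dataset_size further down.)
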
@@ -25,6 +25,10 @@ def test_celeba_dataset_label():
                      [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
                       0, 0, 1],
                      [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
+                      0, 0, 1],
+                     [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
+                      0, 0, 1],
+                     [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
                       0, 0, 1]]
     count = 0
     for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):

@@ -35,7 +39,7 @@ def test_celeba_dataset_label():
         for index in range(len(expect_labels[count])):
             assert item["attr"][index] == expect_labels[count][index]
         count = count + 1
-    assert count == 2
+    assert count == 4


 def test_celeba_dataset_op():

@@ -54,14 +58,17 @@ def test_celeba_dataset_op():
         logger.info("----------image--------")
         logger.info(item["image"])
         count = count + 1
-    assert count == 4
+    assert count == 8


 def test_celeba_dataset_ext():
     ext = [".JPEG"]
     data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
-    expect_labels = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
-                     0, 1, 0, 1, 0, 0, 1],
+    expect_labels = [
+        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
+         0, 1, 0, 1, 0, 0, 1],
+        [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
+         0, 1, 0, 1, 0, 0, 1]]
     count = 0
     for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
         logger.info("----------image--------")

@@ -71,7 +78,7 @@ def test_celeba_dataset_ext():
         for index in range(len(expect_labels[count])):
             assert item["attr"][index] == expect_labels[count][index]
         count = count + 1
-    assert count == 1
+    assert count == 2


 def test_celeba_dataset_distribute():

@@ -83,14 +90,25 @@ def test_celeba_dataset_distribute():
         logger.info("----------attr--------")
         logger.info(item["attr"])
         count = count + 1
-    assert count == 1
+    assert count == 2


 def test_celeba_get_dataset_size():
     data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
     size = data.get_dataset_size()
+    assert size == 4
+
+    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train")
+    size = data.get_dataset_size()
     assert size == 2
+
+    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid")
+    size = data.get_dataset_size()
+    assert size == 1
+
+    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test")
+    size = data.get_dataset_size()
+    assert size == 1


 if __name__ == '__main__':
     test_celeba_dataset_label()

@@ -504,7 +504,7 @@ def test_celeba_padded():
     count = 0
     for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
         count = count + 1
-    assert count == 2
+    assert count == 4


 if __name__ == '__main__':