fix get_dataset_size in CelebADataset when usage is not all
This commit is contained in:
parent
ba2fe29691
commit
4ff4c17632
|
@ -4974,9 +4974,30 @@ class CelebADataset(MappableDataset):
|
|||
with open(attr_file, 'r') as f:
|
||||
num_rows = int(f.readline())
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError("attr_file not found.")
|
||||
raise RuntimeError("attr file can not be found.")
|
||||
except BaseException:
|
||||
raise RuntimeError("Get dataset size failed from attribution file.")
|
||||
if self.usage != 'all':
|
||||
partition_file = os.path.join(dir, "list_eval_partition.txt")
|
||||
usage_type = 0
|
||||
partition_num = 0
|
||||
if self.usage == "train":
|
||||
usage_type = 0
|
||||
elif self.usage == "valid":
|
||||
usage_type = 1
|
||||
elif self.usage == "test":
|
||||
usage_type = 2
|
||||
try:
|
||||
with open(partition_file, 'r') as f:
|
||||
for line in f.readlines():
|
||||
split_line = line.split(' ')
|
||||
if int(split_line[1]) == usage_type:
|
||||
partition_num += 1
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError("Partition file can not be found")
|
||||
if partition_num < num_rows:
|
||||
num_rows = partition_num
|
||||
|
||||
self.dataset_size = get_num_rows(num_rows, self.num_shards)
|
||||
if self.num_samples is not None and self.num_samples < self.dataset_size:
|
||||
self.dataset_size = self.num_samples
|
||||
|
|
|
@ -100,7 +100,7 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) {
|
|||
i++;
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 2);
|
||||
EXPECT_EQ(i, 4);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
|
|
@ -58,8 +58,10 @@ protected:
|
|||
|
||||
TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
|
||||
std::string dir = datasets_root_path_ + "/testCelebAData/";
|
||||
uint32_t expect_labels[2][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
|
||||
uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
|
||||
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
|
||||
uint32_t count = 0;
|
||||
auto tree = Build({Celeba(16, 2, 32, dir)});
|
||||
tree->Prepare();
|
||||
|
@ -81,16 +83,20 @@ TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) {
|
|||
count++;
|
||||
di.GetNextAsMap(&tersor_map);
|
||||
}
|
||||
EXPECT_TRUE(count == 2);
|
||||
EXPECT_TRUE(count == 4);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
|
||||
std::string dir = datasets_root_path_ + "/testCelebAData/";
|
||||
uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
|
||||
uint32_t expect_labels[8][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
|
||||
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}};
|
||||
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
|
||||
{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1},
|
||||
{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}};
|
||||
uint32_t count = 0;
|
||||
auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)});
|
||||
tree->Prepare();
|
||||
|
@ -112,7 +118,7 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) {
|
|||
count++;
|
||||
di.GetNextAsMap(&tersor_map);
|
||||
}
|
||||
EXPECT_TRUE(count == 4);
|
||||
EXPECT_TRUE(count == 8);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
2
|
||||
4
|
||||
5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
|
||||
1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
|
||||
2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
|
||||
2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1
|
||||
1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
1.JPEG 0
|
||||
2.jpeg 1
|
||||
2.jpeg 2
|
||||
2.jpeg 0
|
|
@ -25,6 +25,10 @@ def test_celeba_dataset_label():
|
|||
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
|
||||
0, 0, 1],
|
||||
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
|
||||
0, 0, 1],
|
||||
[0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
|
||||
0, 0, 1],
|
||||
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
|
||||
0, 0, 1]]
|
||||
count = 0
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
|
@ -35,7 +39,7 @@ def test_celeba_dataset_label():
|
|||
for index in range(len(expect_labels[count])):
|
||||
assert item["attr"][index] == expect_labels[count][index]
|
||||
count = count + 1
|
||||
assert count == 2
|
||||
assert count == 4
|
||||
|
||||
|
||||
def test_celeba_dataset_op():
|
||||
|
@ -54,14 +58,17 @@ def test_celeba_dataset_op():
|
|||
logger.info("----------image--------")
|
||||
logger.info(item["image"])
|
||||
count = count + 1
|
||||
assert count == 4
|
||||
assert count == 8
|
||||
|
||||
|
||||
def test_celeba_dataset_ext():
|
||||
ext = [".JPEG"]
|
||||
data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext)
|
||||
expect_labels = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
|
||||
0, 1, 0, 1, 0, 0, 1],
|
||||
expect_labels = [
|
||||
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
|
||||
0, 1, 0, 1, 0, 0, 1],
|
||||
[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1,
|
||||
0, 1, 0, 1, 0, 0, 1]]
|
||||
count = 0
|
||||
for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
logger.info("----------image--------")
|
||||
|
@ -71,7 +78,7 @@ def test_celeba_dataset_ext():
|
|||
for index in range(len(expect_labels[count])):
|
||||
assert item["attr"][index] == expect_labels[count][index]
|
||||
count = count + 1
|
||||
assert count == 1
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_celeba_dataset_distribute():
|
||||
|
@ -83,14 +90,25 @@ def test_celeba_dataset_distribute():
|
|||
logger.info("----------attr--------")
|
||||
logger.info(item["attr"])
|
||||
count = count + 1
|
||||
assert count == 1
|
||||
assert count == 2
|
||||
|
||||
|
||||
def test_celeba_get_dataset_size():
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
|
||||
size = data.get_dataset_size()
|
||||
assert size == 4
|
||||
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train")
|
||||
size = data.get_dataset_size()
|
||||
assert size == 2
|
||||
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid")
|
||||
size = data.get_dataset_size()
|
||||
assert size == 1
|
||||
|
||||
data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test")
|
||||
size = data.get_dataset_size()
|
||||
assert size == 1
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_celeba_dataset_label()
|
||||
|
|
|
@ -504,7 +504,7 @@ def test_celeba_padded():
|
|||
count = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
count = count + 1
|
||||
assert count == 2
|
||||
assert count == 4
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue