From 4ff4c17632938b2da63e89f63c110cc2f79f6760 Mon Sep 17 00:00:00 2001 From: yanghaitao1 Date: Thu, 17 Sep 2020 23:31:11 -0400 Subject: [PATCH] fix get_dataset_size in CelebADataset when usage is not all --- mindspore/dataset/engine/datasets.py | 23 +++++++++++++- tests/ut/cpp/dataset/c_api_datasets_test.cc | 2 +- tests/ut/cpp/dataset/celeba_op_test.cc | 18 +++++++---- .../testCelebAData/list_attr_celeba.txt | 4 ++- .../testCelebAData/list_eval_partition.txt | 4 +++ .../ut/python/dataset/test_datasets_celeba.py | 30 +++++++++++++++---- tests/ut/python/dataset/test_paddeddataset.py | 2 +- 7 files changed, 67 insertions(+), 16 deletions(-) create mode 100644 tests/ut/data/dataset/testCelebAData/list_eval_partition.txt diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index a392a8c1ede..c2577dd3ffa 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -4974,9 +4974,30 @@ class CelebADataset(MappableDataset): with open(attr_file, 'r') as f: num_rows = int(f.readline()) except FileNotFoundError: - raise RuntimeError("attr_file not found.") + raise RuntimeError("attr file can not be found.") except BaseException: raise RuntimeError("Get dataset size failed from attribution file.") + if self.usage != 'all': + partition_file = os.path.join(dir, "list_eval_partition.txt") + usage_type = 0 + partition_num = 0 + if self.usage == "train": + usage_type = 0 + elif self.usage == "valid": + usage_type = 1 + elif self.usage == "test": + usage_type = 2 + try: + with open(partition_file, 'r') as f: + for line in f.readlines(): + split_line = line.split(' ') + if int(split_line[1]) == usage_type: + partition_num += 1 + except FileNotFoundError: + raise RuntimeError("Partition file can not be found") + if partition_num < num_rows: + num_rows = partition_num + self.dataset_size = get_num_rows(num_rows, self.num_shards) if self.num_samples is not None and self.num_samples < self.dataset_size: self.dataset_size = self.num_samples diff --git a/tests/ut/cpp/dataset/c_api_datasets_test.cc b/tests/ut/cpp/dataset/c_api_datasets_test.cc index ad950f83345..5597398ee17 100644 --- a/tests/ut/cpp/dataset/c_api_datasets_test.cc +++ b/tests/ut/cpp/dataset/c_api_datasets_test.cc @@ -100,7 +100,7 @@ TEST_F(MindDataTestPipeline, TestCelebADefault) { i++; } - EXPECT_EQ(i, 2); + EXPECT_EQ(i, 4); // Manually terminate the pipeline iter->Stop(); diff --git a/tests/ut/cpp/dataset/celeba_op_test.cc b/tests/ut/cpp/dataset/celeba_op_test.cc index 915c7b25b26..ded5490fd55 100644 --- a/tests/ut/cpp/dataset/celeba_op_test.cc +++ b/tests/ut/cpp/dataset/celeba_op_test.cc @@ -58,8 +58,10 @@ protected: TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) { std::string dir = datasets_root_path_ + "/testCelebAData/"; - uint32_t expect_labels[2][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, - {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; + uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, + {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, + {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, + {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}}; uint32_t count = 0; auto tree = Build({Celeba(16, 2, 32, dir)}); tree->Prepare(); @@ -81,16 +83,20 @@ TEST_F(MindDataTestCelebaDataset, TestSequentialCeleba) { count++; di.GetNextAsMap(&tersor_map); } - EXPECT_TRUE(count == 2); + EXPECT_TRUE(count == 4); } } TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { std::string dir = datasets_root_path_ + "/testCelebAData/"; - uint32_t expect_labels[4][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, + uint32_t expect_labels[8][40] = {{0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, + {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, - {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; + {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}, + {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, + {0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}, + {0,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,0,1,1,0,1,0,1,0,0,1}}; uint32_t count = 0; auto tree = Build({Celeba(16, 2, 32, dir), Repeat(2)}); tree->Prepare(); @@ -112,7 +118,7 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { count++; di.GetNextAsMap(&tersor_map); } - EXPECT_TRUE(count == 4); + EXPECT_TRUE(count == 8); } } diff --git a/tests/ut/data/dataset/testCelebAData/list_attr_celeba.txt b/tests/ut/data/dataset/testCelebAData/list_attr_celeba.txt index 0e57965ea6c..289044ab8e9 100644 --- a/tests/ut/data/dataset/testCelebAData/list_attr_celeba.txt +++ b/tests/ut/data/dataset/testCelebAData/list_attr_celeba.txt @@ -1,4 +1,6 @@ -2 +4 5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young 1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1 2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 +2.jpg -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 -1 1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 -1 1 +1.JPEG -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1 diff --git a/tests/ut/data/dataset/testCelebAData/list_eval_partition.txt b/tests/ut/data/dataset/testCelebAData/list_eval_partition.txt new file mode 100644 index 00000000000..156c46cc3d4 --- /dev/null +++ b/tests/ut/data/dataset/testCelebAData/list_eval_partition.txt @@ -0,0 +1,4 @@ +1.JPEG 0 +2.jpeg 1 +2.jpeg 2 +2.jpeg 0 diff --git a/tests/ut/python/dataset/test_datasets_celeba.py b/tests/ut/python/dataset/test_datasets_celeba.py index 329ff8888a9..3710e6ad491 100644 --- a/tests/ut/python/dataset/test_datasets_celeba.py +++ b/tests/ut/python/dataset/test_datasets_celeba.py @@ -25,6 +25,10 @@ def test_celeba_dataset_label(): [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 0, 1], + [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1]] count = 0 for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): @@ -35,7 +39,7 @@ def test_celeba_dataset_label(): for index in range(len(expect_labels[count])): assert item["attr"][index] == expect_labels[count][index] count = count + 1 - assert count == 2 + assert count == 4 def test_celeba_dataset_op(): @@ -54,14 +58,17 @@ def test_celeba_dataset_op(): logger.info("----------image--------") logger.info(item["image"]) count = count + 1 - assert count == 4 + assert count == 8 def test_celeba_dataset_ext(): ext = [".JPEG"] data = ds.CelebADataset(DATA_DIR, decode=True, extensions=ext) - expect_labels = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, - 0, 1, 0, 1, 0, 0, 1], + expect_labels = [ + [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, + 0, 1, 0, 1, 0, 0, 1], + [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, + 0, 1, 0, 1, 0, 0, 1]] count = 0 for item in data.create_dict_iterator(num_epochs=1, output_numpy=True): logger.info("----------image--------") @@ -71,7 +78,7 @@ def test_celeba_dataset_ext(): for index in range(len(expect_labels[count])): assert item["attr"][index] == expect_labels[count][index] count = count + 1 - assert count == 1 + assert count == 2 def test_celeba_dataset_distribute(): @@ -83,14 +90,25 @@ def test_celeba_dataset_distribute(): logger.info("----------attr--------") logger.info(item["attr"]) count = count + 1 - assert count == 1 + assert count == 2 def test_celeba_get_dataset_size(): data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True) size = data.get_dataset_size() + assert size == 4 + + data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="train") + size = data.get_dataset_size() assert size == 2 + data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="valid") + size = data.get_dataset_size() + assert size == 1 + + data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True, usage="test") + size = data.get_dataset_size() + assert size == 1 if __name__ == '__main__': test_celeba_dataset_label() diff --git a/tests/ut/python/dataset/test_paddeddataset.py b/tests/ut/python/dataset/test_paddeddataset.py index 4dbc187447b..3de4b50833e 100644 --- a/tests/ut/python/dataset/test_paddeddataset.py +++ b/tests/ut/python/dataset/test_paddeddataset.py @@ -504,7 +504,7 @@ def test_celeba_padded(): count = 0 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): count = count + 1 - assert count == 2 + assert count == 4 if __name__ == '__main__':