From b26f6b3d5ec2b4e73d36eb1d79bc54d4ef47d3fb Mon Sep 17 00:00:00 2001 From: luoyang Date: Tue, 25 Aug 2020 11:37:31 +0800 Subject: [PATCH] fix bug in c-api: rename, concat, take --- .../ccsrc/minddata/dataset/api/datasets.cc | 18 ++++++++++-- .../ut/cpp/dataset/c_api_dataset_ops_test.cc | 28 ++++++++++++++++--- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 375f7b521b8..add3324ec05 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -221,7 +221,7 @@ std::shared_ptr Mnist(const std::string &dataset_dir, const std::s // Function to overload "+" operator to concat two datasets std::shared_ptr operator+(const std::shared_ptr &datasets1, const std::shared_ptr &datasets2) { - std::shared_ptr ds = std::make_shared(std::vector({datasets1, datasets2})); + std::shared_ptr ds = std::make_shared(std::vector({datasets2, datasets1})); // Call derived class validation method. return ds->ValidateParams() ? ds : nullptr; @@ -1592,6 +1592,10 @@ bool ConcatDataset::ValidateParams() { MS_LOG(ERROR) << "Concat: concatenated datasets are not specified."; return false; } + if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) { + MS_LOG(ERROR) << "Concat: concatenated dataset should not be null."; + return false; + } return true; } @@ -1676,6 +1680,16 @@ bool RenameDataset::ValidateParams() { MS_LOG(ERROR) << "input and output columns must be the same size"; return false; } + for (uint32_t i = 0; i < input_columns_.size(); ++i) { + if (input_columns_[i].empty()) { + MS_LOG(ERROR) << "input_columns: column name should not be empty."; + return false; + } + if (output_columns_[i].empty()) { + MS_LOG(ERROR) << "output_columns: column name should not be empty."; + return false; + } + } return true; } @@ -1766,7 +1780,7 @@ std::vector> TakeDataset::Build() { // Function to validate the parameters for TakeDataset bool TakeDataset::ValidateParams() { - if (take_count_ < 0 && take_count_ != -1) { + if (take_count_ <= 0 && take_count_ != -1) { MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_; return false; } diff --git a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc index e65cf8392d4..48b2f716692 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_ops_test.cc @@ -362,8 +362,8 @@ TEST_F(MindDataTestPipeline, TestProjectMapAutoInjection) { iter->Stop(); } -TEST_F(MindDataTestPipeline, TestRenameFail) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail."; +TEST_F(MindDataTestPipeline, TestRenameFail1) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail1."; // We expect this test to fail because input and output in Rename are not the same size // Create an ImageFolder Dataset @@ -381,6 +381,20 @@ TEST_F(MindDataTestPipeline, TestRenameFail) { EXPECT_EQ(ds, nullptr); } +TEST_F(MindDataTestPipeline, TestRenameFail2) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail2."; + // We expect this test to fail because input or output column name is empty + + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Rename operation on ds + ds = ds->Rename({"image", "label"}, {"col2", ""}); + EXPECT_EQ(ds, nullptr); +} + TEST_F(MindDataTestPipeline, TestRenameSuccess) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameSuccess."; @@ -688,9 +702,15 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetError1) { // Create a Take operation on ds with invalid count input int32_t count = -5; - ds = ds->Take(count); + auto ds1 = ds->Take(count); // Expect nullptr for invalid input take_count - EXPECT_EQ(ds, nullptr); + EXPECT_EQ(ds1, nullptr); + + // Create a Take operation on ds with invalid count input + count = 0; + auto ds2 = ds->Take(count); + // Expect nullptr for invalid input take_count + EXPECT_EQ(ds2, nullptr); } TEST_F(MindDataTestPipeline, TestTakeDatasetNormal) {