fix bug in c-api: rename, concat, take

This commit is contained in:
luoyang 2020-08-25 11:37:31 +08:00
parent 2de216961f
commit b26f6b3d5e
2 changed files with 40 additions and 6 deletions

View File

@ -221,7 +221,7 @@ std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::s
// Function to overload "+" operator to concat two datasets
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
const std::shared_ptr<Dataset> &datasets2) {
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1}));
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
@ -1592,6 +1592,10 @@ bool ConcatDataset::ValidateParams() {
MS_LOG(ERROR) << "Concat: concatenated datasets are not specified.";
return false;
}
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
MS_LOG(ERROR) << "Concat: concatenated dataset should not be null.";
return false;
}
return true;
}
@ -1676,6 +1680,16 @@ bool RenameDataset::ValidateParams() {
MS_LOG(ERROR) << "input and output columns must be the same size";
return false;
}
for (uint32_t i = 0; i < input_columns_.size(); ++i) {
if (input_columns_[i].empty()) {
MS_LOG(ERROR) << "input_columns: column name should not be empty.";
return false;
}
if (output_columns_[i].empty()) {
MS_LOG(ERROR) << "output_columns: column name should not be empty.";
return false;
}
}
return true;
}
@ -1766,7 +1780,7 @@ std::vector<std::shared_ptr<DatasetOp>> TakeDataset::Build() {
// Function to validate the parameters for TakeDataset
bool TakeDataset::ValidateParams() {
if (take_count_ < 0 && take_count_ != -1) {
if (take_count_ <= 0 && take_count_ != -1) {
MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_;
return false;
}

View File

@ -362,8 +362,8 @@ TEST_F(MindDataTestPipeline, TestProjectMapAutoInjection) {
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRenameFail) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail.";
TEST_F(MindDataTestPipeline, TestRenameFail1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail1.";
// We expect this test to fail because input and output in Rename are not the same size
// Create an ImageFolder Dataset
@ -381,6 +381,20 @@ TEST_F(MindDataTestPipeline, TestRenameFail) {
EXPECT_EQ(ds, nullptr);
}
TEST_F(MindDataTestPipeline, TestRenameFail2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail2.";
// We expect this test to fail because input or output column name is empty
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Rename operation on ds
ds = ds->Rename({"image", "label"}, {"col2", ""});
EXPECT_EQ(ds, nullptr);
}
TEST_F(MindDataTestPipeline, TestRenameSuccess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameSuccess.";
@ -688,9 +702,15 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetError1) {
// Create a Take operation on ds with invalid count input
int32_t count = -5;
ds = ds->Take(count);
auto ds1 = ds->Take(count);
// Expect nullptr for invalid input take_count
EXPECT_EQ(ds, nullptr);
EXPECT_EQ(ds1, nullptr);
// Create a Take operation on ds with invalid count input
count = 0;
auto ds2 = ds->Take(count);
// Expect nullptr for invalid input take_count
EXPECT_EQ(ds2, nullptr);
}
TEST_F(MindDataTestPipeline, TestTakeDatasetNormal) {