forked from mindspore-Ecosystem/mindspore
fix bug in c-api: rename, concat, take
This commit is contained in:
parent
2de216961f
commit
b26f6b3d5e
|
@ -221,7 +221,7 @@ std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::s
|
|||
// Function to overload "+" operator to concat two datasets
|
||||
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
|
||||
const std::shared_ptr<Dataset> &datasets2) {
|
||||
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
|
||||
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1}));
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
|
@ -1592,6 +1592,10 @@ bool ConcatDataset::ValidateParams() {
|
|||
MS_LOG(ERROR) << "Concat: concatenated datasets are not specified.";
|
||||
return false;
|
||||
}
|
||||
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
|
||||
MS_LOG(ERROR) << "Concat: concatenated dataset should not be null.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1676,6 +1680,16 @@ bool RenameDataset::ValidateParams() {
|
|||
MS_LOG(ERROR) << "input and output columns must be the same size";
|
||||
return false;
|
||||
}
|
||||
for (uint32_t i = 0; i < input_columns_.size(); ++i) {
|
||||
if (input_columns_[i].empty()) {
|
||||
MS_LOG(ERROR) << "input_columns: column name should not be empty.";
|
||||
return false;
|
||||
}
|
||||
if (output_columns_[i].empty()) {
|
||||
MS_LOG(ERROR) << "output_columns: column name should not be empty.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1766,7 +1780,7 @@ std::vector<std::shared_ptr<DatasetOp>> TakeDataset::Build() {
|
|||
|
||||
// Function to validate the parameters for TakeDataset
|
||||
bool TakeDataset::ValidateParams() {
|
||||
if (take_count_ < 0 && take_count_ != -1) {
|
||||
if (take_count_ <= 0 && take_count_ != -1) {
|
||||
MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_;
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -362,8 +362,8 @@ TEST_F(MindDataTestPipeline, TestProjectMapAutoInjection) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRenameFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail.";
|
||||
TEST_F(MindDataTestPipeline, TestRenameFail1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail1.";
|
||||
// We expect this test to fail because input and output in Rename are not the same size
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
|
@ -381,6 +381,20 @@ TEST_F(MindDataTestPipeline, TestRenameFail) {
|
|||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRenameFail2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail2.";
|
||||
// We expect this test to fail because input or output column name is empty
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create a Rename operation on ds
|
||||
ds = ds->Rename({"image", "label"}, {"col2", ""});
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRenameSuccess) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameSuccess.";
|
||||
|
||||
|
@ -688,9 +702,15 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetError1) {
|
|||
|
||||
// Create a Take operation on ds with invalid count input
|
||||
int32_t count = -5;
|
||||
ds = ds->Take(count);
|
||||
auto ds1 = ds->Take(count);
|
||||
// Expect nullptr for invalid input take_count
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
EXPECT_EQ(ds1, nullptr);
|
||||
|
||||
// Create a Take operation on ds with invalid count input
|
||||
count = 0;
|
||||
auto ds2 = ds->Take(count);
|
||||
// Expect nullptr for invalid input take_count
|
||||
EXPECT_EQ(ds2, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTakeDatasetNormal) {
|
||||
|
|
Loading…
Reference in New Issue