!5114 fix bug in c-api: rename, concat, take
Merge pull request !5114 from luoyang/c-api-pyfunc
This commit is contained in:
commit
2e3f5cd41b
|
@ -221,7 +221,7 @@ std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::s
|
||||||
// Function to overload "+" operator to concat two datasets
|
// Function to overload "+" operator to concat two datasets
|
||||||
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
|
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
|
||||||
const std::shared_ptr<Dataset> &datasets2) {
|
const std::shared_ptr<Dataset> &datasets2) {
|
||||||
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
|
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets2, datasets1}));
|
||||||
|
|
||||||
// Call derived class validation method.
|
// Call derived class validation method.
|
||||||
return ds->ValidateParams() ? ds : nullptr;
|
return ds->ValidateParams() ? ds : nullptr;
|
||||||
|
@ -1592,6 +1592,10 @@ bool ConcatDataset::ValidateParams() {
|
||||||
MS_LOG(ERROR) << "Concat: concatenated datasets are not specified.";
|
MS_LOG(ERROR) << "Concat: concatenated datasets are not specified.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
|
||||||
|
MS_LOG(ERROR) << "Concat: concatenated dataset should not be null.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1676,6 +1680,16 @@ bool RenameDataset::ValidateParams() {
|
||||||
MS_LOG(ERROR) << "input and output columns must be the same size";
|
MS_LOG(ERROR) << "input and output columns must be the same size";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
for (uint32_t i = 0; i < input_columns_.size(); ++i) {
|
||||||
|
if (input_columns_[i].empty()) {
|
||||||
|
MS_LOG(ERROR) << "input_columns: column name should not be empty.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (output_columns_[i].empty()) {
|
||||||
|
MS_LOG(ERROR) << "output_columns: column name should not be empty.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1766,7 +1780,7 @@ std::vector<std::shared_ptr<DatasetOp>> TakeDataset::Build() {
|
||||||
|
|
||||||
// Function to validate the parameters for TakeDataset
|
// Function to validate the parameters for TakeDataset
|
||||||
bool TakeDataset::ValidateParams() {
|
bool TakeDataset::ValidateParams() {
|
||||||
if (take_count_ < 0 && take_count_ != -1) {
|
if (take_count_ <= 0 && take_count_ != -1) {
|
||||||
MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_;
|
MS_LOG(ERROR) << "Take: take_count should be either -1 or positive integer, take_count: " << take_count_;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -362,8 +362,8 @@ TEST_F(MindDataTestPipeline, TestProjectMapAutoInjection) {
|
||||||
iter->Stop();
|
iter->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestRenameFail) {
|
TEST_F(MindDataTestPipeline, TestRenameFail1) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail.";
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail1.";
|
||||||
// We expect this test to fail because input and output in Rename are not the same size
|
// We expect this test to fail because input and output in Rename are not the same size
|
||||||
|
|
||||||
// Create an ImageFolder Dataset
|
// Create an ImageFolder Dataset
|
||||||
|
@ -381,6 +381,20 @@ TEST_F(MindDataTestPipeline, TestRenameFail) {
|
||||||
EXPECT_EQ(ds, nullptr);
|
EXPECT_EQ(ds, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestPipeline, TestRenameFail2) {
|
||||||
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameFail2.";
|
||||||
|
// We expect this test to fail because input or output column name is empty
|
||||||
|
|
||||||
|
// Create an ImageFolder Dataset
|
||||||
|
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||||
|
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
|
||||||
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
|
// Create a Rename operation on ds
|
||||||
|
ds = ds->Rename({"image", "label"}, {"col2", ""});
|
||||||
|
EXPECT_EQ(ds, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestRenameSuccess) {
|
TEST_F(MindDataTestPipeline, TestRenameSuccess) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameSuccess.";
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRenameSuccess.";
|
||||||
|
|
||||||
|
@ -688,9 +702,15 @@ TEST_F(MindDataTestPipeline, TestTakeDatasetError1) {
|
||||||
|
|
||||||
// Create a Take operation on ds with invalid count input
|
// Create a Take operation on ds with invalid count input
|
||||||
int32_t count = -5;
|
int32_t count = -5;
|
||||||
ds = ds->Take(count);
|
auto ds1 = ds->Take(count);
|
||||||
// Expect nullptr for invalid input take_count
|
// Expect nullptr for invalid input take_count
|
||||||
EXPECT_EQ(ds, nullptr);
|
EXPECT_EQ(ds1, nullptr);
|
||||||
|
|
||||||
|
// Create a Take operation on ds with invalid count input
|
||||||
|
count = 0;
|
||||||
|
auto ds2 = ds->Take(count);
|
||||||
|
// Expect nullptr for invalid input take_count
|
||||||
|
EXPECT_EQ(ds2, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestTakeDatasetNormal) {
|
TEST_F(MindDataTestPipeline, TestTakeDatasetNormal) {
|
||||||
|
|
Loading…
Reference in New Issue