!5354 Fix randomdata issue

Merge pull request !5354 from xiefangqi/fix_randomdata_columnlist
This commit is contained in:
mindspore-ci-bot 2020-09-01 11:03:54 +08:00 committed by Gitee
commit 3be136f6f0
3 changed files with 164 additions and 5 deletions

View File

@ -1367,7 +1367,7 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
// ValidateParams for RandomDataset
bool RandomDataset::ValidateParams() {
if (total_rows_ < 0) {
MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than 0, now get " << total_rows_;
MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows_;
return false;
}
if (!ValidateDatasetSampler("RandomDataset", sampler_)) {
@ -1413,6 +1413,9 @@ std::vector<std::shared_ptr<DatasetOp>> RandomDataset::Build() {
std::unique_ptr<DataSchema> data_schema;
std::vector<std::string> columns_to_load;
if (columns_list_.size() > 0) {
columns_to_load = columns_list_;
}
if (!schema_file_path.empty() || !schema_json_string.empty()) {
data_schema = std::make_unique<DataSchema>();
if (!schema_file_path.empty()) {

View File

@ -265,11 +265,40 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
/// \return Shared pointer to the current Dataset
// Factory for a RandomDataset. Validates total_rows, sampler and columns_list, logging an error and
// returning nullptr on the first failed check; otherwise returns the constructed RandomDataset.
// NOTE(review): this span looks like a diff rendering with removed and added lines shown together
// (two consecutive signature lines, and a make_shared/ValidateParams pair ahead of the explicit
// checks) - confirm which lines belong to the current revision before compiling.
template <typename T = std::shared_ptr<SchemaObj>>
std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schema = nullptr,
std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
const std::vector<std::string> &columns_list = {},
const std::shared_ptr<SamplerObj> &sampler = RandomSampler()) {
auto ds = std::make_shared<RandomDataset>(total_rows, schema, columns_list, std::move(sampler));
return ds->ValidateParams() ? ds : nullptr;
// Negative row counts are rejected; 0 is accepted.
if (total_rows < 0) {
MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows;
return nullptr;
}
// A null sampler is rejected.
if (sampler == nullptr) {
MS_LOG(ERROR) << "RandomDataset: Sampler is not constructed correctly, sampler: nullptr";
return nullptr;
}
// columns_list is optional; when given, every name must be non-empty and unique.
if (!columns_list.empty()) {
for (uint32_t i = 0; i < columns_list.size(); ++i) {
if (columns_list[i].empty()) {
MS_LOG(ERROR) << "RandomDataset:columns_list"
<< "[" << i << "] should not be empty";
return nullptr;
}
}
// Duplicate detection: a set built from the list shrinks if any name repeats.
std::set<std::string> columns_set(columns_list.begin(), columns_list.end());
if (columns_set.size() != columns_list.size()) {
MS_LOG(ERROR) << "RandomDataset:columns_list: Every column name should not be same with others";
return nullptr;
}
}
std::shared_ptr<RandomDataset> ds;
// When T is std::nullptr_t or std::shared_ptr<SchemaObj>, the schema is first converted to a
// std::shared_ptr<SchemaObj>; any other T is passed to the RandomDataset constructor as given.
// NOTE(review): schema, columns_list and sampler are const references in the second signature
// line, so std::move on them cannot enable a move - confirm whether copies are intended.
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
std::shared_ptr<SchemaObj> schema_obj = schema;
ds =
std::make_shared<RandomDataset>(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler));
} else {
ds = std::make_shared<RandomDataset>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler));
}
return ds;
}
/// \brief Function to create a TextFileDataset

View File

@ -191,7 +191,7 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic3) {
}
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic4) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic3.";
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic4.";
// Create a RandomDataset
u_int32_t curr_seed = GlobalContext::config_manager()->seed();
@ -267,6 +267,133 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic4) {
GlobalContext::config_manager()->set_seed(curr_seed);
}
// RandomData with a schema file path and an explicit columns_list: every produced row must contain
// exactly the three requested columns with the expected shape, rank and type.
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic5) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic5.";

  // Pin the global seed for this test; the previous value is restored at the end.
  u_int32_t curr_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(246);

  // Create a RandomDataset restricted to three columns of the schema file
  std::string SCHEMA_FILE = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
  std::shared_ptr<Dataset> ds = RandomData(0, SCHEMA_FILE, {"col_sint32", "col_sint64", "col_1d"});
  // Fatal check: ds is dereferenced on the next statement.
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  ds = ds->Repeat(2);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Check if RandomDataOp read correct columns
  uint64_t i = 0;
  while (row.size() != 0) {
    EXPECT_EQ(row.size(), 3);

    auto col_sint32 = row["col_sint32"];
    auto col_sint64 = row["col_sint64"];
    auto col_1d = row["col_1d"];

    // operator[] inserts a null tensor for a missing key, so guard before dereferencing.
    ASSERT_NE(col_sint32, nullptr);
    ASSERT_NE(col_sint64, nullptr);
    ASSERT_NE(col_1d, nullptr);

    // validate shape
    ASSERT_EQ(col_sint32->shape(), TensorShape({1}));
    ASSERT_EQ(col_sint64->shape(), TensorShape({1}));
    ASSERT_EQ(col_1d->shape(), TensorShape({2}));

    // validate Rank
    ASSERT_EQ(col_sint32->Rank(), 1);
    ASSERT_EQ(col_sint64->Rank(), 1);
    ASSERT_EQ(col_1d->Rank(), 1);

    // validate type
    ASSERT_EQ(col_sint32->type(), DataType::DE_INT32);
    ASSERT_EQ(col_sint64->type(), DataType::DE_INT64);
    ASSERT_EQ(col_1d->type(), DataType::DE_INT64);

    iter->GetNextRow(&row);
    i++;
  }

  // NOTE(review): total_rows is 0 above, so this count presumably depends on seed 246 and Repeat(2) - confirm.
  EXPECT_EQ(i, 984);

  // Manually terminate the pipeline
  iter->Stop();
  GlobalContext::config_manager()->set_seed(curr_seed);
}
// RandomData with a nullptr schema and an explicit columns_list: the pipeline must yield
// exactly total_rows (10) rows.
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic6) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic6.";

  // Pin the global seed for this test; the previous value is restored at the end.
  u_int32_t curr_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(246);

  // Create a RandomDataset without a schema
  std::shared_ptr<Dataset> ds = RandomData(10, nullptr, {"col_sint32", "col_sint64", "col_1d"});
  // Fatal check: ds is dereferenced on the next statement.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Count the rows produced; an empty row marks the end of the data.
  uint64_t i = 0;
  while (row.size() != 0) {
    iter->GetNextRow(&row);
    i++;
  }
  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
  GlobalContext::config_manager()->set_seed(curr_seed);
}
// RandomData with an empty schema string and an explicit columns_list: the pipeline must yield
// exactly total_rows (10) rows.
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic7) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic7.";

  // Pin the global seed for this test; the previous value is restored at the end.
  u_int32_t curr_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(246);

  // Create a RandomDataset with an empty schema string
  std::shared_ptr<Dataset> ds = RandomData(10, "", {"col_sint32", "col_sint64", "col_1d"});
  // Fatal check: ds is dereferenced on the next statement.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Count the rows produced; an empty row marks the end of the data.
  uint64_t i = 0;
  while (row.size() != 0) {
    iter->GetNextRow(&row);
    i++;
  }
  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
  GlobalContext::config_manager()->set_seed(curr_seed);
}
TEST_F(MindDataTestPipeline, TestRandomDatasetWithNullSampler) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetWithNullSampler.";