diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index c48114e65c2..524f6cab689 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -1367,7 +1367,7 @@ std::vector> MnistDataset::Build() { // ValideParams for RandomDataset bool RandomDataset::ValidateParams() { if (total_rows_ < 0) { - MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than 0, now get " << total_rows_; + MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows_; return false; } if (!ValidateDatasetSampler("RandomDataset", sampler_)) { @@ -1413,6 +1413,9 @@ std::vector> RandomDataset::Build() { std::unique_ptr data_schema; std::vector columns_to_load; + if (columns_list_.size() > 0) { + columns_to_load = columns_list_; + } if (!schema_file_path.empty() || !schema_json_string.empty()) { data_schema = std::make_unique(); if (!schema_file_path.empty()) { diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index 41dc9993671..1cb272d4a29 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -265,11 +265,40 @@ std::shared_ptr operator+(const std::shared_ptr &dataset /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) /// \return Shared pointer to the current Dataset template > -std::shared_ptr RandomData(const int32_t &total_rows = 0, T schema = nullptr, +std::shared_ptr RandomData(const int32_t &total_rows = 0, const T &schema = nullptr, const std::vector &columns_list = {}, const std::shared_ptr &sampler = RandomSampler()) { - auto ds = std::make_shared(total_rows, schema, columns_list, std::move(sampler)); - return ds->ValidateParams() ? ds : nullptr; + if (total_rows < 0) { + MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows; + return nullptr; + } + if (sampler == nullptr) { + MS_LOG(ERROR) << "RandomDataset: Sampler is not constructed correctly, sampler: nullptr"; + return nullptr; + } + if (!columns_list.empty()) { + for (uint32_t i = 0; i < columns_list.size(); ++i) { + if (columns_list[i].empty()) { + MS_LOG(ERROR) << "RandomDataset:columns_list" + << "[" << i << "] should not be empty"; + return nullptr; + } + } + std::set columns_set(columns_list.begin(), columns_list.end()); + if (columns_set.size() != columns_list.size()) { + MS_LOG(ERROR) << "RandomDataset:columns_list: Every column name should not be same with others"; + return nullptr; + } + } + std::shared_ptr ds; + if constexpr (std::is_same::value || std::is_same>::value) { + std::shared_ptr schema_obj = schema; + ds = + std::make_shared(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler)); + } else { + ds = std::make_shared(total_rows, std::move(schema), std::move(columns_list), std::move(sampler)); + } + return ds; } /// \brief Function to create a TextFileDataset diff --git a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc index 0506a58134e..9a8444f0afe 100644 --- a/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc +++ b/tests/ut/cpp/dataset/c_api_dataset_randomdata_test.cc @@ -191,7 +191,7 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic3) { } TEST_F(MindDataTestPipeline, TestRandomDatasetBasic4) { - MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic3."; + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic4."; // Create a RandomDataset u_int32_t curr_seed = GlobalContext::config_manager()->seed(); @@ -267,6 +267,133 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic4) { GlobalContext::config_manager()->set_seed(curr_seed); } +TEST_F(MindDataTestPipeline, TestRandomDatasetBasic5) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic5."; + + // Create a RandomDataset + u_int32_t curr_seed = GlobalContext::config_manager()->seed(); + GlobalContext::config_manager()->set_seed(246); + + std::string SCHEMA_FILE = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + std::shared_ptr ds = RandomData(0, SCHEMA_FILE, {"col_sint32", "col_sint64", "col_1d"}); + EXPECT_NE(ds, nullptr); + + // Create a Repeat operation on ds + ds = ds->Repeat(2); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if RandomDataOp read correct columns + uint64_t i = 0; + while (row.size() != 0) { + EXPECT_EQ(row.size(), 3); + + auto col_sint32 = row["col_sint32"]; + auto col_sint64 = row["col_sint64"]; + auto col_1d = row["col_1d"]; + + // validate shape + ASSERT_EQ(col_sint32->shape(), TensorShape({1})); + ASSERT_EQ(col_sint64->shape(), TensorShape({1})); + ASSERT_EQ(col_1d->shape(), TensorShape({2})); + + // validate Rank + ASSERT_EQ(col_sint32->Rank(), 1); + ASSERT_EQ(col_sint64->Rank(), 1); + ASSERT_EQ(col_1d->Rank(), 1); + + // validate type + ASSERT_EQ(col_sint32->type(), DataType::DE_INT32); + ASSERT_EQ(col_sint64->type(), DataType::DE_INT64); + ASSERT_EQ(col_1d->type(), DataType::DE_INT64); + + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 984); + + // Manually terminate the pipeline + iter->Stop(); + GlobalContext::config_manager()->set_seed(curr_seed); +} + +TEST_F(MindDataTestPipeline, TestRandomDatasetBasic6) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic6."; + + // Create a RandomDataset + u_int32_t curr_seed = GlobalContext::config_manager()->seed(); + GlobalContext::config_manager()->set_seed(246); + + std::string SCHEMA_FILE = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + std::shared_ptr ds = RandomData(10, nullptr, {"col_sint32", "col_sint64", "col_1d"}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if RandomDataOp read correct columns + uint64_t i = 0; + while (row.size() != 0) { + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); + GlobalContext::config_manager()->set_seed(curr_seed); +} + +TEST_F(MindDataTestPipeline, TestRandomDatasetBasic7) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic7."; + + // Create a RandomDataset + u_int32_t curr_seed = GlobalContext::config_manager()->seed(); + GlobalContext::config_manager()->set_seed(246); + + std::string SCHEMA_FILE = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + std::shared_ptr ds = RandomData(10, "", {"col_sint32", "col_sint64", "col_1d"}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + // Check if RandomDataOp read correct columns + uint64_t i = 0; + while (row.size() != 0) { + iter->GetNextRow(&row); + i++; + } + + EXPECT_EQ(i, 10); + + // Manually terminate the pipeline + iter->Stop(); + GlobalContext::config_manager()->set_seed(curr_seed); +} + TEST_F(MindDataTestPipeline, TestRandomDatasetWithNullSampler) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetWithNullSampler.";