!5354 Fix randomdata issue
Merge pull request !5354 from xiefangqi/fix_randomdata_columnlist
This commit is contained in:
commit
3be136f6f0
|
@ -1367,7 +1367,7 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
|||
// ValidateParams for RandomDataset
|
||||
bool RandomDataset::ValidateParams() {
|
||||
if (total_rows_ < 0) {
|
||||
MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than 0, now get " << total_rows_;
|
||||
MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows_;
|
||||
return false;
|
||||
}
|
||||
if (!ValidateDatasetSampler("RandomDataset", sampler_)) {
|
||||
|
@ -1413,6 +1413,9 @@ std::vector<std::shared_ptr<DatasetOp>> RandomDataset::Build() {
|
|||
|
||||
std::unique_ptr<DataSchema> data_schema;
|
||||
std::vector<std::string> columns_to_load;
|
||||
if (columns_list_.size() > 0) {
|
||||
columns_to_load = columns_list_;
|
||||
}
|
||||
if (!schema_file_path.empty() || !schema_json_string.empty()) {
|
||||
data_schema = std::make_unique<DataSchema>();
|
||||
if (!schema_file_path.empty()) {
|
||||
|
|
|
@ -265,11 +265,40 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|||
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler())
|
||||
/// \return Shared pointer to the current Dataset
|
||||
template <typename T = std::shared_ptr<SchemaObj>>
|
||||
std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, T schema = nullptr,
|
||||
std::shared_ptr<RandomDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
|
||||
const std::vector<std::string> &columns_list = {},
|
||||
const std::shared_ptr<SamplerObj> &sampler = RandomSampler()) {
|
||||
auto ds = std::make_shared<RandomDataset>(total_rows, schema, columns_list, std::move(sampler));
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
if (total_rows < 0) {
|
||||
MS_LOG(ERROR) << "RandomDataset: total_rows must be greater than or equal 0, now get " << total_rows;
|
||||
return nullptr;
|
||||
}
|
||||
if (sampler == nullptr) {
|
||||
MS_LOG(ERROR) << "RandomDataset: Sampler is not constructed correctly, sampler: nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
if (!columns_list.empty()) {
|
||||
for (uint32_t i = 0; i < columns_list.size(); ++i) {
|
||||
if (columns_list[i].empty()) {
|
||||
MS_LOG(ERROR) << "RandomDataset:columns_list"
|
||||
<< "[" << i << "] should not be empty";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
std::set<std::string> columns_set(columns_list.begin(), columns_list.end());
|
||||
if (columns_set.size() != columns_list.size()) {
|
||||
MS_LOG(ERROR) << "RandomDataset:columns_list: Every column name should not be same with others";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
std::shared_ptr<RandomDataset> ds;
|
||||
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
|
||||
std::shared_ptr<SchemaObj> schema_obj = schema;
|
||||
ds =
|
||||
std::make_shared<RandomDataset>(total_rows, std::move(schema_obj), std::move(columns_list), std::move(sampler));
|
||||
} else {
|
||||
ds = std::make_shared<RandomDataset>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler));
|
||||
}
|
||||
return ds;
|
||||
}
|
||||
|
||||
/// \brief Function to create a TextFileDataset
|
||||
|
|
|
@ -191,7 +191,7 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic3) {
|
|||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic4) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic3.";
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic4.";
|
||||
|
||||
// Create a RandomDataset
|
||||
u_int32_t curr_seed = GlobalContext::config_manager()->seed();
|
||||
|
@ -267,6 +267,133 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasic4) {
|
|||
GlobalContext::config_manager()->set_seed(curr_seed);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic5) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic5.";

  // Pin the global seed for this test; the previous value is restored at the end.
  u_int32_t curr_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(246);

  // Create a RandomDataset from a schema file path, requesting three named columns.
  std::string SCHEMA_FILE = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json";
  std::shared_ptr<Dataset> ds = RandomData(0, SCHEMA_FILE, {"col_sint32", "col_sint64", "col_1d"});
  // ASSERT (not EXPECT): ds is dereferenced immediately below.
  ASSERT_NE(ds, nullptr);

  // Create a Repeat operation on ds
  ds = ds->Repeat(2);
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Check that every row holds exactly the requested columns with the expected shape, rank and type
  uint64_t i = 0;
  while (row.size() != 0) {
    EXPECT_EQ(row.size(), 3);

    auto col_sint32 = row["col_sint32"];
    auto col_sint64 = row["col_sint64"];
    auto col_1d = row["col_1d"];

    // operator[] default-inserts a null tensor for a missing column, so check before dereferencing
    ASSERT_NE(col_sint32, nullptr);
    ASSERT_NE(col_sint64, nullptr);
    ASSERT_NE(col_1d, nullptr);

    // validate shape
    ASSERT_EQ(col_sint32->shape(), TensorShape({1}));
    ASSERT_EQ(col_sint64->shape(), TensorShape({1}));
    ASSERT_EQ(col_1d->shape(), TensorShape({2}));

    // validate Rank
    ASSERT_EQ(col_sint32->Rank(), 1);
    ASSERT_EQ(col_sint64->Rank(), 1);
    ASSERT_EQ(col_1d->Rank(), 1);

    // validate type
    ASSERT_EQ(col_sint32->type(), DataType::DE_INT32);
    ASSERT_EQ(col_sint64->type(), DataType::DE_INT64);
    ASSERT_EQ(col_1d->type(), DataType::DE_INT64);

    iter->GetNextRow(&row);
    i++;
  }

  // Total number of rows seen across both repeats
  EXPECT_EQ(i, 984);

  // Manually terminate the pipeline
  iter->Stop();
  GlobalContext::config_manager()->set_seed(curr_seed);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic6) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic6.";

  // Pin the global seed for this test; the previous value is restored at the end.
  u_int32_t curr_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(246);

  // Create a RandomDataset with a null schema, 10 rows and an explicit columns list
  std::shared_ptr<Dataset> ds = RandomData(10, nullptr, {"col_sint32", "col_sint64", "col_1d"});
  // ASSERT (not EXPECT): ds is dereferenced immediately below.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Count the rows produced; the row contents are not inspected in this test
  uint64_t i = 0;
  while (row.size() != 0) {
    iter->GetNextRow(&row);
    i++;
  }

  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
  GlobalContext::config_manager()->set_seed(curr_seed);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetBasic7) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetBasic7.";

  // Pin the global seed for this test; the previous value is restored at the end.
  u_int32_t curr_seed = GlobalContext::config_manager()->seed();
  GlobalContext::config_manager()->set_seed(246);

  // Create a RandomDataset with an empty schema string, 10 rows and an explicit columns list
  std::shared_ptr<Dataset> ds = RandomData(10, "", {"col_sint32", "col_sint64", "col_1d"});
  // ASSERT (not EXPECT): ds is dereferenced immediately below.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  // Count the rows produced; the row contents are not inspected in this test
  uint64_t i = 0;
  while (row.size() != 0) {
    iter->GetNextRow(&row);
    i++;
  }

  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
  GlobalContext::config_manager()->set_seed(curr_seed);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRandomDatasetWithNullSampler) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetWithNullSampler.";
|
||||
|
||||
|
|
Loading…
Reference in New Issue