!5187 C++ API: Lexicographical order support for CLUE, CSV & TextFile Datasets
Merge pull request !5187 from cathwong/ckw_c_api_fixes2
This commit is contained in:
commit
cd1240180b
|
@ -1009,9 +1009,14 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
|
|||
}
|
||||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
||||
std::shared_ptr<ClueOp> clue_op =
|
||||
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
|
||||
dataset_files_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||
sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||
RETURN_EMPTY_IF_ERROR(clue_op->Init());
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
@ -1019,10 +1024,10 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(dataset_files_, &num_rows));
|
||||
RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
|
@ -1162,6 +1167,11 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
|
|||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
||||
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list;
|
||||
for (auto v : column_defaults_) {
|
||||
if (v->type == CsvType::INT) {
|
||||
|
@ -1177,8 +1187,8 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
|
|||
}
|
||||
|
||||
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
|
||||
dataset_files_, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_, num_samples_,
|
||||
worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||
sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
|
||||
num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
|
||||
RETURN_EMPTY_IF_ERROR(csv_op->Init());
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
@ -1186,10 +1196,10 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(dataset_files_, column_names_.empty(), &num_rows));
|
||||
RETURN_EMPTY_IF_ERROR(CsvOp::CountAllFileRows(sorted_dataset_files, column_names_.empty(), &num_rows));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
|
@ -1398,6 +1408,10 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
|
|||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Sort the dataset files in a lexicographical order
|
||||
std::vector<std::string> sorted_dataset_files = dataset_files_;
|
||||
std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
|
@ -1405,7 +1419,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
|
|||
|
||||
// Create and initialize TextFileOp
|
||||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
|
||||
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
|
||||
|
||||
|
@ -1415,10 +1429,10 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
|
||||
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(sorted_dataset_files, &num_rows));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(dataset_files_.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
|
||||
rows_per_buffer_, &shuffle_op));
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
|
|
|
@ -362,8 +362,8 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetIFLYTEK) {
|
|||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFiles.";
|
||||
TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesA) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFilesA.";
|
||||
// Test CLUE Dataset with files shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
|
@ -373,7 +373,74 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
|
|||
GlobalContext::config_manager()->set_seed(135);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a CLUE Dataset, with two text files
|
||||
// Create a CLUE Dataset, with two text files, dev.json and train.json, in lexicographical order
|
||||
// Note: train.json has 3 rows
|
||||
// Note: dev.json has 3 rows
|
||||
// Use default of all samples
|
||||
// They have the same keywords
|
||||
// Set shuffle to files shuffle
|
||||
std::string clue_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json";
|
||||
std::string clue_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json";
|
||||
std::string task = "AFQMC";
|
||||
std::string usage = "train";
|
||||
std::shared_ptr<Dataset> ds = CLUE({clue_file2, clue_file1}, task, usage, 0, ShuffleMode::kFiles);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("sentence1"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"你有花呗吗",
|
||||
"吃饭能用花呗吗",
|
||||
"蚂蚁花呗支付金额有什么限制",
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["sentence1"];
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 3 + 3 = 6 samples
|
||||
EXPECT_EQ(i, 6);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFilesB) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCLUEDatasetShuffleFilesB.";
|
||||
// Test CLUE Dataset with files shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(135);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a CLUE Dataset, with two text files, train.json and dev.json, in non-lexicographical order
|
||||
// Note: train.json has 3 rows
|
||||
// Note: dev.json has 3 rows
|
||||
// Use default of all samples
|
||||
|
@ -397,12 +464,12 @@ TEST_F(MindDataTestPipeline, TestCLUEDatasetShuffleFiles) {
|
|||
|
||||
EXPECT_NE(row.find("sentence1"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没",
|
||||
"你有花呗吗",
|
||||
"吃饭能用花呗吗",
|
||||
"蚂蚁花呗支付金额有什么限制"
|
||||
"蚂蚁花呗支付金额有什么限制",
|
||||
"蚂蚁借呗等额还款能否换成先息后本",
|
||||
"蚂蚁花呗说我违约了",
|
||||
"帮我看看本月花呗账单结清了没"
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
|
|
|
@ -359,8 +359,8 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetException) {
|
|||
EXPECT_EQ(ds5, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFiles.";
|
||||
TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFilesA.";
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
|
@ -369,7 +369,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
|
|||
GlobalContext::config_manager()->set_seed(130);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
// Create a CSVDataset, with single CSV file
|
||||
// Create a CSVDataset, with 2 CSV files, 1.csv and append.csv in lexicographical order
|
||||
std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
|
||||
std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
|
||||
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
|
||||
|
@ -418,6 +418,66 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFiles) {
|
|||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
// Test CSV Dataset with files shuffle when the input files are supplied in
// non-lexicographical order (append.csv before 1.csv), num_parallel_workers=4.
// The rows read back are compared, column by column, against expected_result.
TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleFilesB.";

  // Set configuration; the original values are restored at the end of the test
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(130);
  GlobalContext::config_manager()->set_num_parallel_workers(4);

  // Create a CSVDataset, with 2 CSV files, append.csv and 1.csv in non-lexicographical order
  std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
  std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
  std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
  std::shared_ptr<Dataset> ds = CSV({file2, file1}, ',', {}, column_names, -1, ShuffleMode::kFiles);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);
  EXPECT_NE(row.find("col1"), row.end());
  std::vector<std::vector<std::string>> expected_result = {
    {"13", "14", "15", "16"},
    {"1", "2", "3", "4"},
    {"17", "18", "19", "20"},
    {"5", "6", "7", "8"},
    {"21", "22", "23", "24"},
    {"9", "10", "11", "12"},
  };

  uint64_t i = 0;
  while (row.size() != 0) {
    for (size_t j = 0; j < column_names.size(); j++) {
      auto text = row[column_names[j]];
      std::string_view sv;
      text->GetItemAt(&sv, {0});
      std::string ss(sv);
      MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
      // Only compare while a matching expected row exists. If the pipeline yields
      // extra rows, the count check below reports it instead of indexing past the
      // end of expected_result.
      if (i < expected_result.size()) {
        EXPECT_STREQ(ss.c_str(), expected_result[i][j].c_str());
      }
    }
    iter->GetNextRow(&row);
    i++;
  }

  // Expect 6 samples
  EXPECT_EQ(i, 6);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCSVDatasetShuffleGlobal.";
|
||||
// Test CSV Dataset with GLOBAL shuffle
|
||||
|
|
|
@ -165,8 +165,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetFail7) {
|
|||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1.";
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1A) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1A.";
|
||||
// Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
|
@ -176,7 +176,7 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
|
|||
GlobalContext::config_manager()->set_seed(654);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Create a TextFile Dataset, with two text files, 1.txt then 2.txt, in lexicographical order.
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Use default of all samples
|
||||
|
@ -223,6 +223,64 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
|
|||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
// Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1,
// where the files are supplied in non-lexicographical order (2.txt before 1.txt).
// The rows read back are compared in order against expected_result.
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1B) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1B.";

  // Set configuration; the original values are restored at the end of the test
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(654);
  GlobalContext::config_manager()->set_num_parallel_workers(1);

  // Create a TextFile Dataset, with two text files, 2.txt then 1.txt, in non-lexicographical order
  // Note: 1.txt has 3 rows
  // Note: 2.txt has 2 rows
  // Use default of all samples
  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
  std::shared_ptr<Dataset> ds = TextFile({tf_file2, tf_file1}, 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  EXPECT_NE(row.find("text"), row.end());
  std::vector<std::string> expected_result = {"This is a text file.", "Be happy every day.", "Good luck to everyone.",
                                              "Another file.", "End of file."};

  uint64_t i = 0;
  while (row.size() != 0) {
    auto text = row["text"];
    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
    std::string_view sv;
    text->GetItemAt(&sv, {0});
    std::string ss(sv);
    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
    // Compare against expected result. Skip the comparison once past the last
    // expected entry so a surplus row is reported by the count check below
    // rather than by an out-of-range read of expected_result.
    if (i < expected_result.size()) {
      EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
    }
    i++;
    iter->GetNextRow(&row);
  }

  // Expect 2 + 3 = 5 samples
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse4Shard.";
|
||||
// Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=4, shard coverage
|
||||
|
@ -280,8 +338,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
|
|||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1.";
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1A) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1A.";
|
||||
// Test TextFile Dataset with files shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
|
@ -291,7 +349,7 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
|
|||
GlobalContext::config_manager()->set_seed(135);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Create a TextFile Dataset, with two text files, 1.txt then 2.txt, in lexicographical order.
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Use default of all samples
|
||||
|
@ -340,6 +398,66 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
|
|||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
// Test TextFile Dataset with files shuffle, num_parallel_workers=1, where the
// files are supplied in non-lexicographical order (2.txt before 1.txt).
// The rows read back are compared in order against expected_result.
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1B) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1B.";

  // Set configuration; the original values are restored at the end of the test
  uint32_t original_seed = GlobalContext::config_manager()->seed();
  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  GlobalContext::config_manager()->set_seed(135);
  GlobalContext::config_manager()->set_num_parallel_workers(1);

  // Create a TextFile Dataset, with two text files, 2.txt then 1.txt, in non-lexicographical order.
  // Note: 1.txt has 3 rows
  // Note: 2.txt has 2 rows
  // Use default of all samples
  // Set shuffle to files shuffle
  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
  std::shared_ptr<Dataset> ds = TextFile({tf_file2, tf_file1}, 0, ShuffleMode::kFiles);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  EXPECT_NE(row.find("text"), row.end());
  std::vector<std::string> expected_result = {
    "This is a text file.", "Be happy every day.", "Good luck to everyone.", "Another file.", "End of file.",
  };

  uint64_t i = 0;
  while (row.size() != 0) {
    auto text = row["text"];
    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
    std::string_view sv;
    text->GetItemAt(&sv, {0});
    std::string ss(sv);
    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
    // Compare against expected result. Skip the comparison once past the last
    // expected entry so a surplus row is reported by the count check below
    // rather than by an out-of-range read of expected_result.
    if (i < expected_result.size()) {
      EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
    }
    i++;
    iter->GetNextRow(&row);
  }

  // Expect 2 + 3 = 5 samples
  EXPECT_EQ(i, 5);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  GlobalContext::config_manager()->set_seed(original_seed);
  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles4) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles4.";
|
||||
// Test TextFile Dataset with files shuffle, num_parallel_workers=4
|
||||
|
|
Loading…
Reference in New Issue