forked from mindspore-Ecosystem/mindspore
C++ API Support for TextFile Dataset and Unit Tests
This commit is contained in:
parent
4f75adb11a
commit
7f39b5cfd7
|
@ -26,6 +26,7 @@
|
|||
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
|
||||
// Dataset operator headers (in alphabetical order)
|
||||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||
|
@ -95,6 +96,7 @@ Dataset::Dataset() {
|
|||
num_workers_ = cfg->num_parallel_workers();
|
||||
rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
connector_que_size_ = cfg->op_connector_size();
|
||||
worker_connector_size_ = cfg->worker_connector_size();
|
||||
}
|
||||
|
||||
// FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
|
||||
|
@ -140,7 +142,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
|
|||
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
|
||||
std::map<std::string, int32_t> class_indexing) {
|
||||
// This arg is exist in ImageFolderOp, but not externalized (in Python API). The default value is false.
|
||||
// This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
|
||||
bool recursive = false;
|
||||
|
||||
// Create logical representation of ImageFolderDataset.
|
||||
|
@ -163,6 +165,16 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|||
const std::shared_ptr<Dataset> &datasets2) {
|
||||
std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
|
||||
// Function to create a TextFileDataset.
|
||||
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples,
|
||||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
|
||||
auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
|
||||
|
||||
// Call derived class validation method.
|
||||
return ds->ValidateParams() ? ds : nullptr;
|
||||
}
|
||||
|
||||
|
@ -340,6 +352,34 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
|
|||
return std::make_shared<RandomSamplerObj>(replacement, num_samples);
|
||||
}
|
||||
|
||||
// Helper function to compute a default shuffle size
|
||||
int64_t ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows) {
|
||||
const int64_t average_files_multiplier = 4;
|
||||
const int64_t shuffle_max = 10000;
|
||||
int64_t avg_rows_per_file = 0;
|
||||
int64_t shuffle_size = 0;
|
||||
|
||||
// Adjust the num rows per shard if sharding was given
|
||||
if (num_devices > 0) {
|
||||
if (num_rows % num_devices == 0) {
|
||||
num_rows = num_rows / num_devices;
|
||||
} else {
|
||||
num_rows = (num_rows / num_devices) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Cap based on total rows directive. Some ops do not have this and give value of 0.
|
||||
if (total_rows > 0) {
|
||||
num_rows = std::min(num_rows, total_rows);
|
||||
}
|
||||
|
||||
// get the average per file
|
||||
avg_rows_per_file = num_rows / num_files;
|
||||
|
||||
shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
|
||||
return shuffle_size;
|
||||
}
|
||||
|
||||
// Helper function to validate dataset params
|
||||
bool ValidateCommonDatasetParams(std::string dataset_dir) {
|
||||
if (dataset_dir.empty()) {
|
||||
|
@ -613,6 +653,87 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
|
|||
return node_ops;
|
||||
}
|
||||
|
||||
// Constructor for TextFileDataset
|
||||
TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
|
||||
int32_t num_shards, int32_t shard_id)
|
||||
: dataset_files_(dataset_files),
|
||||
num_samples_(num_samples),
|
||||
shuffle_(shuffle),
|
||||
num_shards_(num_shards),
|
||||
shard_id_(shard_id) {}
|
||||
|
||||
bool TextFileDataset::ValidateParams() {
|
||||
if (dataset_files_.empty()) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified.";
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto file : dataset_files_) {
|
||||
std::ifstream handle(file);
|
||||
if (!handle.is_open()) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_shards_ <= 0) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to build TextFileDataset
|
||||
std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
|
||||
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
|
||||
// Create and initalize TextFileOp
|
||||
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
|
||||
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
|
||||
connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
|
||||
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
|
||||
|
||||
if (shuffle_ == ShuffleMode::kGlobal) {
|
||||
// Inject ShuffleOp
|
||||
|
||||
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
|
||||
int64_t shuffle_size = 0;
|
||||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset and then compute the shuffle size
|
||||
RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
|
||||
shuffle_size = ComputeShuffleSize(dataset_files_.size(), num_shards_, num_rows, 0);
|
||||
MS_LOG(INFO) << "TextFileDataset::Build - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
|
||||
|
||||
// Add the shuffle op after this op
|
||||
shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size_, true, rows_per_buffer_);
|
||||
node_ops.push_back(shuffle_op);
|
||||
}
|
||||
|
||||
// Add TextFileOp
|
||||
node_ops.push_back(text_file_op);
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
// Constructor for VOCDataset
|
||||
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
|
||||
const std::map<std::string, int32_t> &class_index, bool decode,
|
||||
|
|
|
@ -35,6 +35,9 @@ enum class DatasetType { kUnknown, kArrow, kTf };
|
|||
// Possible flavours of Tensor implementations
|
||||
enum class TensorImpl { kNone, kFlexible, kCv, kNP };
|
||||
|
||||
// Possible values for shuffle
|
||||
enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 };
|
||||
|
||||
// Possible values for Border types
|
||||
enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 };
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <map>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/include/tensor.h"
|
||||
#include "minddata/dataset/include/iterator.h"
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
|
@ -47,6 +48,7 @@ class Cifar100Dataset;
|
|||
class CocoDataset;
|
||||
class ImageFolderDataset;
|
||||
class MnistDataset;
|
||||
class TextFileDataset;
|
||||
class VOCDataset;
|
||||
// Dataset Op classes (in alphabetical order)
|
||||
class BatchDataset;
|
||||
|
@ -83,7 +85,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
|
|||
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
|
||||
|
||||
/// \brief Function to create a Cifar100 Dataset
|
||||
/// \notes The generated dataset has two columns ['image', 'coarse_label', 'fine_label']
|
||||
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
|
||||
/// \param[in] dataset_dir Path to the root directory that contains the dataset
|
||||
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
|
||||
/// will be used to randomly iterate the entire dataset
|
||||
|
@ -143,6 +145,25 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
|
|||
std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
|
||||
const std::shared_ptr<Dataset> &datasets2);
|
||||
|
||||
/// \brief Function to create a TextFileDataset
|
||||
/// \notes The generated dataset has one column ['text']
|
||||
/// \param[in] dataset_files List of files to be read to search for a pattern of files. The list
|
||||
/// will be sorted in a lexicographical order.
|
||||
/// \param[in] num_samples The number of samples to be included in the dataset.
|
||||
/// (Default = 0 means all samples.)
|
||||
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode.kGlobal)
|
||||
/// Can be any of:
|
||||
/// ShuffleMode.kFalse - No shuffling is performed.
|
||||
/// ShuffleMode.kFiles - Shuffle files only.
|
||||
/// ShuffleMode.kGlobal - Shuffle both the files and samples.
|
||||
/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
|
||||
/// \param[in] shard_id The shard ID within num_shards. This argument should be
|
||||
/// specified only when num_shards is also specified. (Default = 0)
|
||||
/// \return Shared pointer to the current TextFileDataset
|
||||
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
|
||||
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
|
||||
int32_t shard_id = 0);
|
||||
|
||||
/// \brief Function to create a VOCDataset
|
||||
/// \notes The generated dataset has multi-columns :
|
||||
/// - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
|
||||
|
@ -289,10 +310,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
int32_t num_workers_;
|
||||
int32_t rows_per_buffer_;
|
||||
int32_t connector_que_size_;
|
||||
int32_t worker_connector_size_;
|
||||
};
|
||||
|
||||
/* ####################################### Derived Dataset classes ################################# */
|
||||
|
||||
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
class CelebADataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -318,6 +343,8 @@ class CelebADataset : public Dataset {
|
|||
std::set<std::string> extensions_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
|
||||
// (In alphabetical order)
|
||||
|
||||
class Cifar10Dataset : public Dataset {
|
||||
public:
|
||||
|
@ -435,6 +462,33 @@ class MnistDataset : public Dataset {
|
|||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
/// \class TextFileDataset
|
||||
/// \brief A Dataset derived class to represent TextFile dataset
|
||||
class TextFileDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
|
||||
int32_t shard_id);
|
||||
|
||||
/// \brief Destructor
|
||||
~TextFileDataset() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return bool true if all the params are valid
|
||||
bool ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> dataset_files_;
|
||||
int32_t num_samples_;
|
||||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
ShuffleMode shuffle_;
|
||||
};
|
||||
|
||||
class VOCDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -467,6 +521,9 @@ class VOCDataset : public Dataset {
|
|||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
// DERIVED DATASET CLASSES FOR DATASET OPS
|
||||
// (In alphabetical order)
|
||||
|
||||
class BatchDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
|
|
@ -5012,7 +5012,7 @@ class CSVDataset(SourceDataset):
|
|||
class TextFileDataset(SourceDataset):
|
||||
"""
|
||||
A source dataset that reads and parses datasets stored on disk in text format.
|
||||
The generated dataset has one columns ['text'].
|
||||
The generated dataset has one column ['text'].
|
||||
|
||||
Args:
|
||||
dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
|
||||
|
|
|
@ -97,6 +97,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_ops_test.cc
|
||||
c_api_dataset_cifar_test.cc
|
||||
c_api_dataset_coco_test.cc
|
||||
c_api_dataset_filetext_test.cc
|
||||
c_api_dataset_voc_test.cc
|
||||
c_api_datasets_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
|
|
|
@ -0,0 +1,596 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "./securec.h"
|
||||
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/iterator.h"
|
||||
#include "minddata/dataset/include/samplers.h"
|
||||
#include "minddata/dataset/include/status.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
using mindspore::dataset::Status;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorImpl;
|
||||
using mindspore::dataset::TensorShape;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::MsLogLevel::ERROR;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetBasic.";
|
||||
// Test TextFile Dataset with single text file and many default inputs
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(987);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
// Create a TextFile Dataset, with single text file
|
||||
// Note: 1.txt has 3 rows
|
||||
// Use 2 samples
|
||||
// Use defaults for other input parameters
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {"Be happy every day.", "This is a text file."};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 2 samples
|
||||
EXPECT_EQ(i, 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1.";
|
||||
// Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(654);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Use default of all samples
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFalse);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {"This is a text file.", "Be happy every day.", "Good luck to everyone.",
|
||||
"Another file.", "End of file."};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 2 + 3 = 5 samples
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse4Shard.";
|
||||
// Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=4, shard coverage
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(654);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Set shuffle to file shuffle, num_shards=2, shard_id=0
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFalse, 2, 0);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {"This is a text file.", "Be happy every day.", "Good luck to everyone."};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 3 samples for this shard
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleGlobal1A) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleGlobal1A.";
|
||||
// Test TextFile Dataset with 1 text file, global shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(246);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Note: 1.txt has 3 rows
|
||||
// Set shuffle to global shuffle
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 0, ShuffleMode::kGlobal);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {"Good luck to everyone.", "This is a text file.", "Be happy every day."};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 3 samples
|
||||
EXPECT_EQ(i, 3);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleGlobal1B) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleGlobal1B.";
|
||||
// Test TextFile Dataset with 2 text files, global shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(246);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Set shuffle to global shuffle
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kGlobal);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {"Another file.", "Good luck to everyone.", "This is a text file.",
|
||||
"End of file.", "Be happy every day."};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 2 + 3 = 5 samples
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleGlobal4) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleGlobal4.";
|
||||
// Test TextFile Dataset with 2 text files, global shuffle, num_parallel_workers=4
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(246);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Set shuffle to global shuffle
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kGlobal);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {"Another file.", "Good luck to everyone.", "End of file.",
|
||||
"This is a text file.", "Be happy every day."};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 2 + 3 = 5 samples
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1.";
|
||||
// Test TextFile Dataset with files shuffle, num_parallel_workers=1
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(135);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(1);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Use default of all samples
|
||||
// Set shuffle to files shuffle
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFiles);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {
|
||||
"This is a text file.", "Be happy every day.", "Good luck to everyone.", "Another file.", "End of file.",
|
||||
};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 2 + 3 = 5 samples
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles4) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles4.";
|
||||
// Test TextFile Dataset with files shuffle, num_parallel_workers=4
|
||||
|
||||
// Set configuration
|
||||
uint32_t original_seed = GlobalContext::config_manager()->seed();
|
||||
uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
|
||||
MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
|
||||
GlobalContext::config_manager()->set_seed(135);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(4);
|
||||
|
||||
// Create a TextFile Dataset, with two text files
|
||||
// Note: 1.txt has 3 rows
|
||||
// Note: 2.txt has 2 rows
|
||||
// Use default of all samples
|
||||
// Set shuffle to files shuffle
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFiles);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset.
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
EXPECT_NE(row.find("text"), row.end());
|
||||
std::vector<std::string> expected_result = {"This is a text file.", "Another file.", "Be happy every day.",
|
||||
"End of file.", "Good luck to everyone."};
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto text = row["text"];
|
||||
MS_LOG(INFO) << "Tensor text shape: " << text->shape();
|
||||
std::string_view sv;
|
||||
text->GetItemAt(&sv, {0});
|
||||
std::string ss(sv);
|
||||
MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
|
||||
// Compare against expected result
|
||||
EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Expect 2 + 3 = 5 samples
|
||||
EXPECT_EQ(i, 5);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
// Restore configuration
|
||||
GlobalContext::config_manager()->set_seed(original_seed);
|
||||
GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail1) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail1.";
|
||||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with invalid samplers=-1
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, -1);
|
||||
|
||||
// Expect failure: Number of samples cannot be negative
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail2) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail2.";
|
||||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with wrongful empty dataset_files input
|
||||
std::shared_ptr<Dataset> ds = TextFile({});
|
||||
|
||||
// Expect failure: dataset_files is not specified
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail3) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail3.";
|
||||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with non-existent dataset_files input
|
||||
std::shared_ptr<Dataset> ds = TextFile({"notexist.txt"}, 0, ShuffleMode::kFalse);
|
||||
|
||||
// Expect failure: specified dataset_files does not exist
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail4) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail4.";
|
||||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with empty string dataset_files input
|
||||
std::shared_ptr<Dataset> ds = TextFile({""}, 0, ShuffleMode::kFiles);
|
||||
|
||||
// Expect failure: specified dataset_files does not exist
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail5) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail5.";
|
||||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with invalid num_shards=0 value
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 1, ShuffleMode::kFalse, 0);
|
||||
|
||||
// Expect failure: Number of shards cannot be <=0
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail6) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail6.";
|
||||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with invalid shard_id=-1 value
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 0, ShuffleMode::kFiles, -1);
|
||||
|
||||
// Expect failure: shard_id cannot be negative
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTextFileDatasetFail7) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail7.";
|
||||
|
||||
// Attempt to create a TextFile Dataset
|
||||
// with invalid shard_id=2 and num_shards=2 combination
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 0, ShuffleMode::kGlobal, 2, 2);
|
||||
|
||||
// Expect failure: Cannot have shard_id >= num_shards
|
||||
EXPECT_EQ(ds, nullptr);
|
||||
}
|
|
@ -89,6 +89,23 @@ TEST_F(MindDataTestTextFileOp, TestTextFileBasic) {
|
|||
ASSERT_EQ(row_count, 3);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTextFileOp, TestTextFileFileNotExist) {
|
||||
// Start with an empty execution tree
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::string dataset_path = datasets_root_path_ + "/does/not/exist/0.txt";
|
||||
|
||||
std::shared_ptr<TextFileOp> op;
|
||||
TextFileOp::Builder builder;
|
||||
builder.SetTextFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(16)
|
||||
.SetNumWorkers(16)
|
||||
.SetOpConnectorSize(2);
|
||||
|
||||
Status rc = builder.Build(&op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTextFileOp, TestTotalRows) {
|
||||
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
|
||||
std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
|
||||
|
@ -110,3 +127,14 @@ TEST_F(MindDataTestTextFileOp, TestTotalRows) {
|
|||
ASSERT_EQ(total_rows, 5);
|
||||
files.clear();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTextFileOp, TestTotalRowsFileNotExist) {
|
||||
std::string tf_file1 = datasets_root_path_ + "/does/not/exist/0.txt";
|
||||
std::vector<std::string> files;
|
||||
files.push_back(tf_file1);
|
||||
int64_t total_rows = 0;
|
||||
TextFileOp::CountAllFileRows(files, &total_rows);
|
||||
ASSERT_EQ(total_rows, 0);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -12,9 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from util import config_get_set_num_parallel_workers
|
||||
from util import config_get_set_num_parallel_workers, config_get_set_seed
|
||||
|
||||
|
||||
DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
|
||||
|
@ -39,8 +40,18 @@ def test_textline_dataset_all_file():
|
|||
assert count == 5
|
||||
|
||||
|
||||
def test_textline_dataset_totext():
|
||||
def test_textline_dataset_num_samples_zero():
|
||||
data = ds.TextFileDataset(DATA_FILE, num_samples=0)
|
||||
count = 0
|
||||
for i in data.create_dict_iterator():
|
||||
logger.info("{}".format(i["text"]))
|
||||
count += 1
|
||||
assert count == 3
|
||||
|
||||
|
||||
def test_textline_dataset_shuffle_false4():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
original_seed = config_get_set_seed(987)
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
|
||||
count = 0
|
||||
line = ["This is a text file.", "Another file.",
|
||||
|
@ -50,8 +61,94 @@ def test_textline_dataset_totext():
|
|||
assert strs == line[count]
|
||||
count += 1
|
||||
assert count == 5
|
||||
# Restore configuration num_parallel_workers
|
||||
# Restore configuration
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_textline_dataset_shuffle_false1():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
original_seed = config_get_set_seed(987)
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
|
||||
count = 0
|
||||
line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
|
||||
"Another file.", "End of file."]
|
||||
for i in data.create_dict_iterator():
|
||||
strs = i["text"].item().decode("utf8")
|
||||
assert strs == line[count]
|
||||
count += 1
|
||||
assert count == 5
|
||||
# Restore configuration
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_textline_dataset_shuffle_files4():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
original_seed = config_get_set_seed(135)
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
|
||||
count = 0
|
||||
line = ["This is a text file.", "Another file.",
|
||||
"Be happy every day.", "End of file.", "Good luck to everyone."]
|
||||
for i in data.create_dict_iterator():
|
||||
strs = i["text"].item().decode("utf8")
|
||||
assert strs == line[count]
|
||||
count += 1
|
||||
assert count == 5
|
||||
# Restore configuration
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_textline_dataset_shuffle_files1():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
original_seed = config_get_set_seed(135)
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
|
||||
count = 0
|
||||
line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
|
||||
"Another file.", "End of file."]
|
||||
for i in data.create_dict_iterator():
|
||||
strs = i["text"].item().decode("utf8")
|
||||
assert strs == line[count]
|
||||
count += 1
|
||||
assert count == 5
|
||||
# Restore configuration
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_textline_dataset_shuffle_global4():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(4)
|
||||
original_seed = config_get_set_seed(246)
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
|
||||
count = 0
|
||||
line = ["Another file.", "Good luck to everyone.", "End of file.",
|
||||
"This is a text file.", "Be happy every day."]
|
||||
for i in data.create_dict_iterator():
|
||||
strs = i["text"].item().decode("utf8")
|
||||
assert strs == line[count]
|
||||
count += 1
|
||||
assert count == 5
|
||||
# Restore configuration
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_textline_dataset_shuffle_global1():
|
||||
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
|
||||
original_seed = config_get_set_seed(246)
|
||||
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
|
||||
count = 0
|
||||
line = ["Another file.", "Good luck to everyone.", "This is a text file.",
|
||||
"End of file.", "Be happy every day."]
|
||||
for i in data.create_dict_iterator():
|
||||
strs = i["text"].item().decode("utf8")
|
||||
assert strs == line[count]
|
||||
count += 1
|
||||
assert count == 5
|
||||
# Restore configuration
|
||||
ds.config.set_num_parallel_workers(original_num_parallel_workers)
|
||||
ds.config.set_seed(original_seed)
|
||||
|
||||
|
||||
def test_textline_dataset_num_samples():
|
||||
|
@ -94,11 +191,33 @@ def test_textline_dataset_to_device():
|
|||
data = data.to_device()
|
||||
data.send()
|
||||
|
||||
def test_textline_dataset_exceptions():
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
_ = ds.TextFileDataset(DATA_FILE, num_samples=-1)
|
||||
assert "Input num_samples is not within the required interval" in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
_ = ds.TextFileDataset("does/not/exist/no.txt")
|
||||
assert "The following patterns did not match any files" in str(error_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as error_info:
|
||||
_ = ds.TextFileDataset("")
|
||||
assert "The following patterns did not match any files" in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_textline_dataset_one_file()
|
||||
test_textline_dataset_all_file()
|
||||
test_textline_dataset_totext()
|
||||
test_textline_dataset_num_samples_zero()
|
||||
test_textline_dataset_shuffle_false4()
|
||||
test_textline_dataset_shuffle_false1()
|
||||
test_textline_dataset_shuffle_files4()
|
||||
test_textline_dataset_shuffle_files1()
|
||||
test_textline_dataset_shuffle_global4()
|
||||
test_textline_dataset_shuffle_global1()
|
||||
test_textline_dataset_num_samples()
|
||||
test_textline_dataset_distribution()
|
||||
test_textline_dataset_repeat()
|
||||
test_textline_dataset_get_datasetsize()
|
||||
test_textline_dataset_to_device()
|
||||
test_textline_dataset_exceptions()
|
||||
|
|
Loading…
Reference in New Issue