[MD] C++ api add MindDataset

This commit is contained in:
luoyang 2020-10-14 16:05:54 +08:00
parent d0a1a9b73c
commit 2dc8e5f421
9 changed files with 762 additions and 10 deletions

View File

@ -31,6 +31,7 @@
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#endif
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
@ -223,6 +224,27 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
}
#endif
// Function to create a MindDataDataset.
std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataDataset>(dataset_file, columns_list, sampler, padded_sample, num_padded);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a MindDataDataset.
std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded) {
auto ds = std::make_shared<MindDataDataset>(dataset_files, columns_list, sampler, padded_sample, num_padded);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
}
// Function to create a MnistDataset.
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
const std::shared_ptr<SamplerObj> &sampler) {
@ -709,6 +731,11 @@ Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vec
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (access(dataset_file.toString().c_str(), R_OK) == -1) {
std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
@ -1388,6 +1415,146 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestDataset::Build() {
}
#endif
#ifndef ENABLE_ANDROID
MindDataDataset::MindDataDataset(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded)
: dataset_file_(std::string()),
dataset_files_(dataset_files),
search_for_pattern_(false),
columns_list_(columns_list),
sampler_(sampler),
padded_sample_(padded_sample),
sample_bytes_({}),
num_padded_(num_padded) {}
MindDataDataset::MindDataDataset(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample,
int64_t num_padded)
: dataset_file_(dataset_file),
dataset_files_({}),
search_for_pattern_(true),
columns_list_(columns_list),
sampler_(sampler),
padded_sample_(padded_sample),
sample_bytes_({}),
num_padded_(num_padded) {}
Status MindDataDataset::ValidateParams() {
if (!search_for_pattern_ && dataset_files_.size() > 4096) {
std::string err_msg =
"MindDataDataset: length of dataset_file must be less than or equal to 4096, dataset_file length: " +
std::to_string(dataset_file_.size());
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
std::vector<std::string> dataset_file_vec =
search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataDataset", dataset_file_vec));
RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataDataset", sampler_));
if (!columns_list_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataDataset", "columns_list", columns_list_));
}
if (padded_sample_ != nullptr) {
if (num_padded_ < 0) {
std::string err_msg =
"MindDataDataset: num_padded must be greater than or equal to zero, num_padded: " + std::to_string(num_padded_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (columns_list_.empty()) {
std::string err_msg = "MindDataDataset: padded_sample is specified and requires columns_list as well";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
for (std::string &column : columns_list_) {
if (padded_sample_.find(column) == padded_sample_.end()) {
std::string err_msg =
"MindDataDataset: " + column + " in columns_list does not match any column in padded_sample";
MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
}
if (num_padded_ > 0) {
if (padded_sample_ == nullptr) {
std::string err_msg = "MindDataDataset: num_padded is specified but padded_sample is not";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
}
return Status::OK();
}
// Helper function to create runtime sampler for minddata dataset
Status MindDataDataset::BuildMindDatasetSamplerChain(
const std::shared_ptr<SamplerObj> &sampler, std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
int64_t num_padded) {
std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
if (op == nullptr) {
std::string err_msg =
"MindDataDataset: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
"SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
std::stack<std::shared_ptr<mindrecord::ShardOperator>> stack_ops;
while (op != nullptr) {
auto sampler_op = std::dynamic_pointer_cast<mindrecord::ShardDistributedSample>(op);
if (sampler_op && num_padded > 0) {
sampler_op->SetNumPaddedSamples(num_padded);
stack_ops.push(sampler_op);
} else {
stack_ops.push(op);
}
op = op->GetChildOp();
}
while (!stack_ops.empty()) {
operators_->push_back(stack_ops.top());
stack_ops.pop();
}
return Status::OK();
}
// Helper function to set sample_bytes from py::byte type
void MindDataDataset::SetSampleBytes(std::map<std::string, std::string> *sample_bytes) {
sample_bytes_ = *sample_bytes;
}
std::vector<std::shared_ptr<DatasetOp>> MindDataDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
std::vector<std::shared_ptr<DatasetOp>> node_ops;
std::vector<std::shared_ptr<ShardOperator>> operators_;
RETURN_EMPTY_IF_ERROR(BuildMindDatasetSamplerChain(sampler_, &operators_, num_padded_));
std::shared_ptr<MindRecordOp> mindrecord_op;
// If pass a string to MindData(), it will be treated as a pattern to search for matched files,
// else if pass a vector to MindData(), it will be treated as specified files to be read
if (search_for_pattern_) {
std::vector<std::string> dataset_file_vec_ = {dataset_file_};
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, rows_per_buffer_, dataset_file_vec_,
search_for_pattern_, connector_que_size_, columns_list_, operators_,
num_padded_, padded_sample_, sample_bytes_);
} else {
mindrecord_op = std::make_shared<MindRecordOp>(num_workers_, rows_per_buffer_, dataset_files_, search_for_pattern_,
connector_que_size_, columns_list_, operators_, num_padded_,
padded_sample_, sample_bytes_);
}
RETURN_EMPTY_IF_ERROR(mindrecord_op->Init());
node_ops.push_back(mindrecord_op);
return node_ops;
}
#endif
MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

View File

@ -69,7 +69,7 @@ PYBIND_REGISTER(ShardSequentialSample, 0, ([](const py::module *m) {
(void)py::class_<mindrecord::ShardSequentialSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardSequentialSample>>(*m,
"MindrecordSequentialSampler")
.def(py::init([](int num_samples, int start_index) {
.def(py::init([](int64_t num_samples, int64_t start_index) {
return std::make_shared<mindrecord::ShardSequentialSample>(num_samples, start_index);
}));
}));

View File

@ -23,10 +23,28 @@
#include "minddata/dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/pk_sampler.h"
#include "minddata/mindrecord/include/shard_distributed_sample.h"
#include "minddata/mindrecord/include/shard_operator.h"
#include "minddata/mindrecord/include/shard_pk_sample.h"
#include "minddata/mindrecord/include/shard_sample.h"
#include "minddata/mindrecord/include/shard_sequential_sample.h"
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
namespace api {
#define RETURN_NULL_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return nullptr; \
} \
} while (false)
// Constructor
SamplerObj::SamplerObj() {}
/// Function to create a Distributed Sampler.
@ -126,8 +144,17 @@ bool DistributedSamplerObj::ValidateParams() {
}
std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_, offset_,
even_dist_);
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
return sampler;
}
std::shared_ptr<mindrecord::ShardOperator> DistributedSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardDistributedSample>(num_shards_, shard_id_, shuffle_, seed_,
num_samples_, offset_);
return mind_sampler;
}
// PKSampler
@ -148,7 +175,23 @@ bool PKSamplerObj::ValidateParams() {
}
std::shared_ptr<Sampler> PKSamplerObj::Build() {
return std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
return sampler;
}
std::shared_ptr<mindrecord::ShardOperator> PKSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
std::shared_ptr<mindrecord::ShardOperator> mind_sampler;
if (shuffle_ == true) {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, std::numeric_limits<int64_t>::max(),
GetSeed(), num_samples_);
} else {
mind_sampler = std::make_shared<mindrecord::ShardPkSample>("label", num_val_, num_samples_);
}
return mind_sampler;
}
// RandomSampler
@ -164,11 +207,22 @@ bool RandomSamplerObj::ValidateParams() {
}
std::shared_ptr<Sampler> RandomSamplerObj::Build() {
// runtime sampler object
bool reshuffle_each_epoch = true;
auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch);
return sampler;
}
std::shared_ptr<mindrecord::ShardOperator> RandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
bool reshuffle_each_epoch_ = true;
auto mind_sampler =
std::make_shared<mindrecord::ShardShuffle>(GetSeed(), num_samples_, replacement_, reshuffle_each_epoch_);
return mind_sampler;
}
// SequentialSampler
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
: start_index_(start_index), num_samples_(num_samples) {}
@ -188,10 +242,19 @@ bool SequentialSamplerObj::ValidateParams() {
}
std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_);
return sampler;
}
std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSequentialSample>(num_samples_, start_index_);
return mind_sampler;
}
// SubsetRandomSampler
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: indices_(std::move(indices)), num_samples_(num_samples) {}
@ -206,10 +269,19 @@ bool SubsetRandomSamplerObj::ValidateParams() {
}
std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_);
return sampler;
}
std::shared_ptr<mindrecord::ShardOperator> SubsetRandomSamplerObj::BuildForMindDataset() {
// runtime mindrecord sampler object
auto mind_sampler = std::make_shared<mindrecord::ShardSample>(indices_, GetSeed());
return mind_sampler;
}
// WeightedRandomSampler
WeightedRandomSamplerObj::WeightedRandomSamplerObj(std::vector<double> weights, int64_t num_samples, bool replacement)
: weights_(std::move(weights)), num_samples_(num_samples), replacement_(replacement) {}

View File

@ -66,6 +66,7 @@ class CsvBase;
class ImageFolderDataset;
#ifndef ENABLE_ANDROID
class ManifestDataset;
class MindDataDataset;
#endif
class MnistDataset;
class RandomDataset;
@ -244,6 +245,37 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
bool decode = false);
#endif
#ifndef ENABLE_ANDROID
/// \brief Function to create a MindDataDataset
/// \param[in] dataset_file File name of one component of a mindrecord source. Other files with identical source
/// in the same path will be found and loaded automatically.
/// \param[in] columns_list List of columns to be read (default={})
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()),
/// supported sampler list: SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
/// \param[in] padded_sample Samples will be appended to dataset, where keys are the same as column_list.
/// \param[in] num_padded Number of padding samples. Dataset size plus num_padded should be divisible by num_shards.
/// \return Shared pointer to the current MindDataDataset
std::shared_ptr<MindDataDataset> MindData(const std::string &dataset_file,
const std::vector<std::string> &columns_list = {},
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
nlohmann::json padded_sample = nullptr, int64_t num_padded = 0);
/// \brief Function to create a MindDataDataset
/// \param[in] dataset_files List of dataset files to be read directly.
/// \param[in] columns_list List of columns to be read (default={})
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given,
/// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()),
/// supported sampler list: SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
/// \param[in] padded_sample Samples will be appended to dataset, where keys are the same as column_list.
/// \param[in] num_padded Number of padding samples. Dataset size plus num_padded should be divisible by num_shards.
/// \return Shared pointer to the current MindDataDataset
std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string> &dataset_files,
const std::vector<std::string> &columns_list = {},
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
nlohmann::json padded_sample = nullptr, int64_t num_padded = 0);
#endif
/// \brief Function to create a MnistDataset
/// \notes The generated dataset has two columns ["image", "label"]
/// \param[in] dataset_dir Path to the root directory that contains the dataset
@ -938,6 +970,50 @@ class ManifestDataset : public Dataset {
};
#endif
#ifndef ENABLE_ANDROID
class MindDataDataset : public Dataset {
public:
/// \brief Constructor
MindDataDataset(const std::vector<std::string> &dataset_files, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);
/// \brief Constructor
MindDataDataset(const std::string &dataset_file, const std::vector<std::string> &columns_list,
const std::shared_ptr<SamplerObj> &sampler, nlohmann::json padded_sample, int64_t num_padded);
/// \brief Destructor
~MindDataDataset() = default;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return The list of shared pointers to the newly created DatasetOps
std::vector<std::shared_ptr<DatasetOp>> Build() override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Build sampler chain for minddata dataset
/// \return Status Status::OK() if input sampler is valid
Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,
std::vector<std::shared_ptr<mindrecord::ShardOperator>> *operators_,
int64_t num_padded);
/// \brief Set sample_bytes when padded_sample has py::byte value
/// \note Pybind will use this function to set sample_bytes into MindDataDataset
void SetSampleBytes(std::map<std::string, std::string> *sample_bytes);
private:
std::string dataset_file_; // search_for_pattern_ will be true in this mode
std::vector<std::string> dataset_files_; // search_for_pattern_ will be false in this mode
bool search_for_pattern_;
std::vector<std::string> columns_list_;
std::shared_ptr<SamplerObj> sampler_;
nlohmann::json padded_sample_;
std::map<std::string, std::string> sample_bytes_; // enable in python
int64_t num_padded_;
};
#endif
class MnistDataset : public Dataset {
public:
/// \brief Constructor

View File

@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
namespace mindspore {
namespace dataset {
@ -30,12 +31,24 @@ namespace api {
class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
public:
/// \brief Constructor
SamplerObj();
/// \brief Destructor
~SamplerObj() = default;
virtual std::shared_ptr<Sampler> Build() = 0;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return bool true if all the parameters are valid
virtual bool ValidateParams() = 0;
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<Sampler> Build() = 0;
/// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
/// only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() { return nullptr; }
};
class DistributedSamplerObj;
@ -110,6 +123,8 @@ class DistributedSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
bool ValidateParams() override;
private:
@ -130,6 +145,8 @@ class PKSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
bool ValidateParams() override;
private:
@ -146,6 +163,8 @@ class RandomSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
bool ValidateParams() override;
private:
@ -161,6 +180,8 @@ class SequentialSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
bool ValidateParams() override;
private:
@ -176,6 +197,8 @@ class SubsetRandomSamplerObj : public SamplerObj {
std::shared_ptr<Sampler> Build() override;
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
bool ValidateParams() override;
private:

View File

@ -27,7 +27,7 @@ namespace mindspore {
namespace mindrecord {
class ShardSequentialSample : public ShardSample {
public:
ShardSequentialSample(int n, int offset);
ShardSequentialSample(int64_t n, int64_t offset);
ShardSequentialSample(float per, float per_offset);
@ -38,7 +38,7 @@ class ShardSequentialSample : public ShardSample {
int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override;
private:
int offset_;
int64_t offset_;
float per_;
float per_offset_;
};

View File

@ -22,7 +22,7 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
ShardSequentialSample::ShardSequentialSample(int n, int offset)
ShardSequentialSample::ShardSequentialSample(int64_t n, int64_t offset)
: ShardSample(n), offset_(offset), per_(0.0f), per_offset_(0.0f) {}
ShardSequentialSample::ShardSequentialSample(float per, float per_offset)

View File

@ -3047,7 +3047,10 @@ class MindDataset(MappableDataset):
A source dataset that reads MindRecord files.
Args:
dataset_file (Union[str, list[str]]): One of file names or file list in dataset.
dataset_file (Union[str, list[str]]): If dataset_file is a str, it represents for
a file name of one component of a mindrecord source, other files with identical source
in the same path will be found and loaded automatically. If dataset_file is a list,
it represents for a list of dataset files to be read directly.
columns_list (list[str], optional): List of columns to be read (default=None).
num_parallel_workers (int, optional): The number of readers (default=None).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
@ -3059,7 +3062,7 @@ class MindDataset(MappableDataset):
dataset (default=None, sampler is exclusive
with shuffle and block_reader). Support list: SubsetRandomSampler,
PkSampler, RandomSampler, SequentialSampler, DistributedSampler.
padded_sample (dict, optional): Samples will be appended to dataset, which
padded_sample (dict, optional): Samples will be appended to dataset, where
keys are the same as column_list.
num_padded (int, optional): Number of padding samples. Dataset size
plus num_padded should be divisible by num_shards.

View File

@ -0,0 +1,411 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/datasets.h"
using namespace mindspore::dataset::api;
using mindspore::dataset::Tensor;
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestMindDataSuccess1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess1 with string file pattern.";
// Create a MindData Dataset
// Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info,
// thus all records in imagenet.mindrecord0 ~ imagenet.mindrecord3 will be read
std::string file_path = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds = MindData(file_path);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["file_name"];
MS_LOG(INFO) << "Tensor image file name: " << *image;
iter->GetNextRow(&row);
}
// Each *.mindrecord file has 5 rows, so there are 20 rows in total(imagenet.mindrecord0 ~ imagenet.mindrecord3)
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMindDataSuccess2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess2 with a vector of single mindrecord file.";
// Create a MindData Dataset
// Pass a list of mindrecord file name, files in list will be read directly but not search for related files
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds = MindData(std::vector<std::string>{file_path1});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["file_name"];
MS_LOG(INFO) << "Tensor image file name: " << *image;
iter->GetNextRow(&row);
}
// Only records in imagenet.mindrecord0 are read
EXPECT_EQ(i, 5);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMindDataSuccess3) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess3 with a vector of multiple mindrecord files.";
// Create a MindData Dataset
// Pass a list of mindrecord file name, files in list will be read directly but not search for related files
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::string file_path2 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord1";
std::vector<std::string> file_list = {file_path1, file_path2};
std::shared_ptr<Dataset> ds = MindData(file_list);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["file_name"];
MS_LOG(INFO) << "Tensor image file name: " << *image;
iter->GetNextRow(&row);
}
// Only records in imagenet.mindrecord0 and imagenet.mindrecord1 are read
EXPECT_EQ(i, 10);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMindDataSuccess4) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess4 with specified column.";
// Create a MindData Dataset
// Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info,
// thus all records in imagenet.mindrecord0 ~ imagenet.mindrecord3 will be read
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord1";
std::shared_ptr<Dataset> ds = MindData(file_path1, {"label"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto label = row["label"];
MS_LOG(INFO) << "Tensor label: " << *label;
iter->GetNextRow(&row);
}
// Shard file "mindrecord0/mindrecord1/mindrecord2/mindrecord3" have same dataset info,
// thus if input file is any of them, all records in imagenet.mindrecord* will be read
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMindDataSuccess5) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess5 with specified sampler.";
// Create a MindData Dataset
// Pass one mindrecord shard file to parse dataset info, and search for other mindrecord files with same dataset info,
// thus all records in imagenet.mindrecord0 ~ imagenet.mindrecord3 will be read
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds = MindData(file_path1, {}, SequentialSampler(0, 3));
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto label = row["label"];
std::shared_ptr<Tensor> expected_item;
Tensor::CreateScalar((int64_t)0, &expected_item);
EXPECT_EQ(*expected_item, *label);
iter->GetNextRow(&row);
}
// SequentialSampler will return 3 samples
EXPECT_EQ(i, 3);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMindDataSuccess6) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess6 with num_samples out of range.";
// Create a MindData Dataset
// Pass a list of mindrecord file name, files in list will be read directly but not search for related files
// imagenet.mindrecord0 file has 5 rows, but num_samples is larger than 5
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::vector<std::string> file_list = {file_path1};
// Check sequential sampler, output number is 10, with duplicate samples(a little weird, wait to fix)
std::shared_ptr<Dataset> ds1 = MindData(file_list, {}, SequentialSampler(0, 10));
EXPECT_NE(ds1, nullptr);
// Check random sampler, output number is 5, same rows with file
std::shared_ptr<Dataset> ds2 = MindData(file_list, {}, RandomSampler(false, 10));
EXPECT_NE(ds2, nullptr);
// Check pk sampler, output number is 2, get 2 samples with label 0
std::shared_ptr<Dataset> ds3 = MindData(file_list, {}, PKSampler(2, false, 10));
EXPECT_NE(ds3, nullptr);
// Check distributed sampler, output number is 3, get 3 samples in shard 0
std::shared_ptr<Dataset> ds4 = MindData(file_list, {}, DistributedSampler(2, 0, false, 10));
EXPECT_NE(ds4, nullptr);
// Check distributed sampler get 3 samples with indice 0, 1 ,2
std::shared_ptr<Dataset> ds5 = MindData(file_list, {}, SubsetRandomSampler({0, 1, 2}, 10));
EXPECT_NE(ds5, nullptr);
std::vector<std::shared_ptr<Dataset>> ds = {ds1, ds2, ds3, ds4, ds5};
std::vector<int32_t> expected_samples = {10, 5, 2, 3, 3};
for (int32_t i = 0; i < ds.size(); i++) {
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds[i]->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t j = 0;
while (row.size() != 0) {
j++;
MS_LOG(INFO) << "Tensor label: " << *row["label"];
iter->GetNextRow(&row);
}
EXPECT_EQ(j, expected_samples[i]);
// Manually terminate the pipeline
iter->Stop();
}
}
TEST_F(MindDataTestPipeline, TestMindDataSuccess7) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataSuccess7 with padded sample.";
// Create pad sample for MindDataset
auto pad = nlohmann::json::object();
pad["file_name"] = "does_not_exist.jpg";
pad["label"] = 999;
// Create a MindData Dataset
// Pass a list of mindrecord file name, files in list will be read directly but not search for related files
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::vector<std::string> file_list = {file_path1};
std::shared_ptr<Dataset> ds = MindData(file_list, {"file_name", "label"}, SequentialSampler(), pad, 4);
EXPECT_NE(ds, nullptr);
// Create a Skip operation on ds, skip original data in mindrecord and get padded samples
ds = ds->Skip(5);
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["file_name"];
auto label = row["label"];
MS_LOG(INFO) << "Tensor file name: " << *image;
MS_LOG(INFO) << "Tensor label: " << *label;
std::shared_ptr<Tensor> expected_item;
Tensor::CreateScalar((int64_t)999, &expected_item);
EXPECT_EQ(*expected_item, *label);
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 4);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMindDataFail1) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataFail1 with incorrect file path.";
// Create a MindData Dataset with incorrect pattern
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/apple.mindrecord0";
std::shared_ptr<Dataset> ds1 = MindData(file_path1);
EXPECT_EQ(ds1, nullptr);
// Create a MindData Dataset with incorrect file path
std::string file_path2 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/apple.mindrecord0";
std::vector<std::string> file_list = {file_path2};
std::shared_ptr<Dataset> ds2 = MindData(file_list);
EXPECT_EQ(ds2, nullptr);
// Create a MindData Dataset with incorrect file path
// ATTENTION: file_path3 is not a pattern to search for ".mindrecord*"
std::string file_path3 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord";
std::shared_ptr<Dataset> ds3 = MindData(file_path3);
EXPECT_EQ(ds3, nullptr);
}
TEST_F(MindDataTestPipeline, TestMindDataFail2) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataFail2 with incorrect column name.";
// Create a MindData Dataset with incorrect column name
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds1 = MindData(file_path1, {""});
EXPECT_EQ(ds1, nullptr);
// Create a MindData Dataset with duplicate column name
std::string file_path2 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds2 = MindData(file_path2, {"label", "label"});
EXPECT_EQ(ds2, nullptr);
// Create a MindData Dataset with unexpected column name
std::string file_path3 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::vector<std::string> file_list = {file_path3};
std::shared_ptr<Dataset> ds3 = MindData(file_list, {"label", "not_exist"});
EXPECT_NE(ds3, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds3->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestMindDataFail3) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindDataFail3 with unsupported sampler.";
// Create a MindData Dataset with unsupported sampler
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds1 = MindData(file_path1, {}, WeightedRandomSampler({1, 1, 1, 1}));
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
EXPECT_EQ(iter1, nullptr);
// Create a MindData Dataset with incorrect sampler
std::string file_path2 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds2 = MindData(file_path2, {}, nullptr);
EXPECT_EQ(ds2, nullptr);
}
TEST_F(MindDataTestPipeline, TestMindDataFail4) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMindData with padded sample.";
// Create a MindData Dataset
std::string file_path1 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds1 = MindData(file_path1, {}, RandomSampler(), nullptr, 2);
// num_padded is specified but padded_sample is not
EXPECT_EQ(ds1, nullptr);
// Create paded sample for MindDataset
auto pad = nlohmann::json::object();
pad["file_name"] = "1.jpg";
pad["label"] = 123456;
// Create a MindData Dataset
std::string file_path2 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds2 = MindData(file_path2, {"label"}, RandomSampler(), pad, -2);
// num_padded must be greater than or equal to zero
EXPECT_EQ(ds2, nullptr);
// Create a MindData Dataset
std::string file_path3 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds3 = MindData(file_path3, {}, RandomSampler(), pad, 1);
// padded_sample is specified and requires columns_list as well
EXPECT_EQ(ds3, nullptr);
// Create paded sample with unmatch column name
auto pad2 = nlohmann::json::object();
pad2["a"] = "1.jpg";
pad2["b"] = 123456;
// Create a MindData Dataset
std::string file_path4 = datasets_root_path_ + "/../mindrecord/testMindDataSet/testImageNetData/imagenet.mindrecord0";
std::shared_ptr<Dataset> ds4 = MindData(file_path4, {"file_name", "label"}, RandomSampler(), pad2, 1);
// columns_list does not match any column in padded_sample
EXPECT_EQ(ds4, nullptr);
}