C++ API: Provide validate param functions

This commit is contained in:
Cathy Wong 2020-08-14 17:38:16 -04:00
parent 7d70fb4dc4
commit 9c8af0d1cf
3 changed files with 93 additions and 81 deletions

View File

@ -105,7 +105,7 @@ Dataset::Dataset() {
// Function to create a CelebADataset.
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
const std::shared_ptr<SamplerObj> &sampler, bool decode,
const std::set<std::string> &extensions) {
auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
@ -114,7 +114,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
}
// Function to create a Cifar10Dataset.
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
// Call derived class validation method.
@ -122,7 +122,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
}
// Function to create a Cifar100Dataset.
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler) {
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
// Call derived class validation method.
@ -131,8 +131,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, std::s
// Function to create a CLUEDataset.
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &clue_files, const std::string &task,
const std::string &usage, int64_t num_samples, ShuffleMode shuffle, int num_shards,
int shard_id) {
const std::string &usage, int64_t num_samples, ShuffleMode shuffle,
int32_t num_shards, int32_t shard_id) {
auto ds = std::make_shared<CLUEDataset>(clue_files, task, usage, num_samples, shuffle, num_shards, shard_id);
// Call derived class validation method.
@ -150,9 +150,10 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
}
// Function to create a ImageFolderDataset.
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
std::map<std::string, int32_t> class_indexing) {
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode,
const std::shared_ptr<SamplerObj> &sampler,
const std::set<std::string> &extensions,
const std::map<std::string, int32_t> &class_indexing) {
// This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
bool recursive = false;
@ -164,7 +165,7 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
}
// Function to create a MnistDataset.
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler) {
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
// Call derived class validation method.
@ -181,7 +182,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
}
// Function to create a TextFileDataset.
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples,
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples,
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
@ -191,9 +192,9 @@ std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files
// Function to create a VOCDataset.
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode,
std::shared_ptr<SamplerObj> sampler) {
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_index, decode, sampler);
const std::map<std::string, int32_t> &class_indexing, bool decode,
const std::shared_ptr<SamplerObj> &sampler) {
auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler);
// Call derived class validation method.
return ds->ValidateParams() ? ds : nullptr;
@ -402,16 +403,57 @@ Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, in
return Status::OK();
}
// Helper function to validate dataset params
bool ValidateCommonDatasetParams(std::string dataset_dir) {
// Helper function to validate dataset directory parameter
bool ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir) {
if (dataset_dir.empty()) {
MS_LOG(ERROR) << "No dataset path is specified";
MS_LOG(ERROR) << dataset_name << ": dataset_dir is not specified.";
return false;
}
Path dir(dataset_dir);
if (!dir.IsDirectory()) {
MS_LOG(ERROR) << dataset_name << ": dataset_dir: [" << dataset_dir << "] is an invalid directory path.";
return false;
}
if (access(dataset_dir.c_str(), R_OK) == -1) {
MS_LOG(ERROR) << "No access to specified dataset path: " << dataset_dir;
MS_LOG(ERROR) << dataset_name << ": No access to specified dataset path: " << dataset_dir;
return false;
}
return true;
}
// Helper function to validate the dataset files parameter.
// \param[in] dataset_name Name of the calling dataset class; used as the prefix of every error message.
// \param[in] dataset_files List of dataset file paths to check.
// \return true if the list is non-empty and Path::Exists() holds for every entry; otherwise an error is
//     logged and false is returned at the first failing check.
bool ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
  if (dataset_files.empty()) {
    MS_LOG(ERROR) << dataset_name << ": dataset_files is not specified.";
    return false;
  }

  // Bind by const reference: the loop only reads each path, so copying every std::string is wasted work.
  for (const auto &f : dataset_files) {
    Path dataset_file(f);
    if (!dataset_file.Exists()) {
      MS_LOG(ERROR) << dataset_name << ": dataset file: [" << f << "] is invalid or does not exist.";
      return false;
    }
  }

  return true;
}
// Helper function to check the sharding arguments of a dataset.
// \param[in] dataset_name Name of the calling dataset class; used as the prefix of every error message.
// \param[in] num_shards Number of shards; must be greater than zero.
// \param[in] shard_id Shard index; must satisfy 0 <= shard_id < num_shards.
// \return true when both arguments pass; otherwise an error is logged and false is returned.
bool ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id) {
  const bool shard_count_ok = num_shards > 0;
  if (!shard_count_ok) {
    MS_LOG(ERROR) << dataset_name << ": Invalid num_shards: " << num_shards;
    return false;
  }

  // The shard count is known to be positive here, so only the index range remains to be checked.
  const bool shard_id_ok = (shard_id >= 0) && (shard_id < num_shards);
  if (shard_id_ok) {
    return true;
  }

  MS_LOG(ERROR) << dataset_name << ": Invalid input, shard_id: " << shard_id << ", num_shards: " << num_shards;
  return false;
}
@ -431,9 +473,7 @@ CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &
extensions_(extensions) {}
bool CelebADataset::ValidateParams() {
Path dir(dataset_dir_);
if (!dir.IsDirectory()) {
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) {
return false;
}
std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
@ -471,7 +511,7 @@ std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), sampler_(sampler) {}
bool Cifar10Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
bool Cifar10Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_); }
// Function to build CifarOp for Cifar10
std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
@ -500,7 +540,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), sampler_(sampler) {}
bool Cifar100Dataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
bool Cifar100Dataset::ValidateParams() { return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_); }
// Function to build CifarOp for Cifar100
std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
@ -529,7 +569,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
// Constructor for CLUEDataset
CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string task, std::string usage,
int64_t num_samples, ShuffleMode shuffle, int num_shards, int shard_id)
int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id)
: dataset_files_(clue_files),
task_(task),
usage_(usage),
@ -539,19 +579,10 @@ CLUEDataset::CLUEDataset(const std::vector<std::string> clue_files, std::string
shard_id_(shard_id) {}
bool CLUEDataset::ValidateParams() {
if (dataset_files_.empty()) {
MS_LOG(ERROR) << "CLUEDataset: dataset_files is not specified.";
if (!ValidateDatasetFilesParam("CLUEDataset", dataset_files_)) {
return false;
}
for (auto f : dataset_files_) {
Path clue_file(f);
if (!clue_file.Exists()) {
MS_LOG(ERROR) << "dataset file: [" << f << "] is invalid or not exist";
return false;
}
}
std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
std::vector<std::string> usage_list = {"train", "test", "eval"};
@ -570,13 +601,7 @@ bool CLUEDataset::ValidateParams() {
return false;
}
if (num_shards_ <= 0) {
MS_LOG(ERROR) << "CLUEDataset: Invalid num_shards: " << num_shards_;
return false;
}
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
MS_LOG(ERROR) << "CLUEDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
if (!ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_)) {
return false;
}
@ -734,9 +759,7 @@ CocoDataset::CocoDataset(const std::string &dataset_dir, const std::string &anno
: dataset_dir_(dataset_dir), annotation_file_(annotation_file), task_(task), decode_(decode), sampler_(sampler) {}
bool CocoDataset::ValidateParams() {
Path dir(dataset_dir_);
if (!dir.IsDirectory()) {
MS_LOG(ERROR) << "Invalid dataset path or no dataset path is specified.";
if (!ValidateDatasetDirParam("CocoDataset", dataset_dir_)) {
return false;
}
Path annotation_file(annotation_file_);
@ -829,7 +852,7 @@ ImageFolderDataset::ImageFolderDataset(std::string dataset_dir, bool decode, std
class_indexing_(class_indexing),
exts_(extensions) {}
bool ImageFolderDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
bool ImageFolderDataset::ValidateParams() { return ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_); }
std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
@ -857,7 +880,7 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderDataset::Build() {
MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir), sampler_(sampler) {}
bool MnistDataset::ValidateParams() { return ValidateCommonDatasetParams(dataset_dir_); }
bool MnistDataset::ValidateParams() { return ValidateDatasetDirParam("MnistDataset", dataset_dir_); }
std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
// A vector containing shared pointer to the Dataset Ops that this object will create
@ -890,31 +913,16 @@ TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t
shard_id_(shard_id) {}
bool TextFileDataset::ValidateParams() {
if (dataset_files_.empty()) {
MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified.";
if (!ValidateDatasetFilesParam("TextFileDataset", dataset_files_)) {
return false;
}
for (auto file : dataset_files_) {
std::ifstream handle(file);
if (!handle.is_open()) {
MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file;
return false;
}
}
if (num_samples_ < 0) {
MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
return false;
}
if (num_shards_ <= 0) {
MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_;
return false;
}
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
// Pass the exact class name so the shard errors carry the same "TextFileDataset" prefix as the
// dataset_files and num_samples errors reported by this method.
if (!ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_)) {
  return false;
}
@ -960,12 +968,12 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
// Constructor for VOCDataset
VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode,
const std::map<std::string, int32_t> &class_indexing, bool decode,
std::shared_ptr<SamplerObj> sampler)
: dataset_dir_(dataset_dir),
task_(task),
mode_(mode),
class_index_(class_index),
class_index_(class_indexing),
decode_(decode),
sampler_(sampler) {}

View File

@ -75,7 +75,7 @@ class ZipDataset;
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all",
const std::shared_ptr<SamplerObj> &sampler = nullptr, const bool &decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr, bool decode = false,
const std::set<std::string> &extensions = {});
/// \brief Function to create a Cifar10 Dataset
@ -84,7 +84,8 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, a `RandomSampler`
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a Cifar100 Dataset
/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
@ -93,7 +94,7 @@ std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::sha
/// will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current Dataset
std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
std::shared_ptr<SamplerObj> sampler = nullptr);
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a CLUEDataset
/// \notes The generated dataset has a variable number of columns depending on the task and usage
@ -114,7 +115,8 @@ std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir,
/// \return Shared pointer to the current CLUEDataset
std::shared_ptr<CLUEDataset> CLUE(const std::vector<std::string> &dataset_files, const std::string &task = "AFQMC",
const std::string &usage = "train", int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int num_shards = 1, int shard_id = 0);
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
/// \brief Function to create a CocoDataset
/// \notes The generated dataset has multi-columns :
@ -147,10 +149,10 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
/// \param[in] extensions File extensions to be read
/// \param[in] class_indexing a class name to label map
/// \return Shared pointer to the current ImageFolderDataset
std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode = false,
std::shared_ptr<SamplerObj> sampler = nullptr,
std::set<std::string> extensions = {},
std::map<std::string, int32_t> class_indexing = {});
std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr,
const std::set<std::string> &extensions = {},
const std::map<std::string, int32_t> &class_indexing = {});
/// \brief Function to create a MnistDataset
/// \notes The generated dataset has two columns ['image', 'label']
@ -158,7 +160,8 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool de
/// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`,
/// a `RandomSampler` will be used to randomly iterate the entire dataset
/// \return Shared pointer to the current MnistDataset
std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);
std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a ConcatDataset
/// \notes Overloads the "+" operator to concatenate two datasets
@ -183,7 +186,7 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
/// \param[in] shard_id The shard ID within num_shards. This argument should be
/// specified only when num_shards is also specified. (Default = 0)
/// \return Shared pointer to the current TextFileDataset
std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &dataset_files, int32_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0);
@ -202,8 +205,8 @@ std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files
/// \return Shared pointer to the current Dataset
std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation",
const std::string &mode = "train",
const std::map<std::string, int32_t> &class_index = {}, bool decode = false,
std::shared_ptr<SamplerObj> sampler = nullptr);
const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false,
const std::shared_ptr<SamplerObj> &sampler = nullptr);
/// \brief Function to create a ZipDataset
/// \notes Applies zip to the dataset
@ -417,7 +420,7 @@ class CLUEDataset : public Dataset {
public:
/// \brief Constructor
CLUEDataset(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
ShuffleMode shuffle, int num_shards, int shard_id);
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id);
/// \brief Destructor
~CLUEDataset() = default;
@ -440,8 +443,8 @@ class CLUEDataset : public Dataset {
std::string usage_;
int64_t num_samples_;
ShuffleMode shuffle_;
int num_shards_;
int shard_id_;
int32_t num_shards_;
int32_t shard_id_;
};
class CocoDataset : public Dataset {
@ -549,7 +552,7 @@ class VOCDataset : public Dataset {
public:
/// \brief Constructor
VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
const std::map<std::string, int32_t> &class_index, bool decode, std::shared_ptr<SamplerObj> sampler);
const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler);
/// \brief Destructor
~VOCDataset() = default;

View File

@ -111,7 +111,8 @@ TEST_F(MindDataTestPipeline, TestTextFileDatasetFail3) {
// Attempt to create a TextFile Dataset
// with non-existent dataset_files input
std::shared_ptr<Dataset> ds = TextFile({"notexist.txt"}, 0, ShuffleMode::kFalse);
std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
std::shared_ptr<Dataset> ds = TextFile({tf_file1, "notexist.txt"}, 0, ShuffleMode::kFalse);
// Expect failure: specified dataset_files does not exist
EXPECT_EQ(ds, nullptr);