C++ API: Minor improvements to ValidateParams support

This commit is contained in:
Cathy Wong 2020-10-20 15:37:58 -04:00
parent db0868d745
commit 0e5f7beebc
1 changed file with 59 additions and 72 deletions

View File

@ -14,10 +14,10 @@
* limitations under the License.
*/
#include "minddata/dataset/include/datasets.h"
#include <algorithm>
#include <fstream>
#include <unordered_set>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/include/transforms.h"
// Source dataset headers (in alphabetical order)
@ -696,7 +696,7 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data
return Status::OK();
}
// Helper function to validate dataset dataset files parameter
// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
if (dataset_files.empty()) {
std::string err_msg = dataset_name + ": dataset_files is not specified.";
@ -743,7 +743,6 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s
// Helper function to validate dataset sampler parameter
// Reports a syntax error when `sampler` is nullptr; the message is prefixed with `dataset_name`.
// The visible tail of the function returns Status::OK().
// NOTE(review): this text appears to be a rendered commit diff (a hunk header sits inside the
// function and removed/added lines are not marked) - confirm against the real file before
// relying on the exact statement sequence shown here.
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
if (sampler == nullptr) {
// NOTE(review): the next MS_LOG emits the same text that is logged again through err_msg two
// lines below; presumably it is the line this commit removed - verify.
MS_LOG(ERROR) << dataset_name << ": Sampler is not constructed correctly, sampler: nullptr";
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
@ -751,12 +750,13 @@ Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared
return Status::OK();
}
Status ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings) {
Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings) {
if (valid_strings.find(str) == valid_strings.end()) {
std::string mode;
mode = std::accumulate(valid_strings.begin(), valid_strings.end(), mode,
[](std::string a, std::string b) { return std::move(a) + " " + std::move(b); });
std::string err_msg = str + " does not match any mode in [" + mode + " ]";
std::string err_msg = dataset_name + ": " + str + " does not match any mode in [" + mode + " ]";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -842,7 +842,7 @@ Status CelebANode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_));
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"all", "train", "valid", "test"}));
RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"}));
return Status::OK();
}
@ -873,7 +873,7 @@ Status Cifar10Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_));
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"}));
RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"}));
return Status::OK();
}
@ -906,7 +906,7 @@ Status Cifar100Node::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_));
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"}));
RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"}));
return Status::OK();
}
@ -945,20 +945,9 @@ CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task,
Status CLUENode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));
std::vector<std::string> task_list = {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"};
std::vector<std::string> usage_list = {"train", "test", "eval"};
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}));
if (find(task_list.begin(), task_list.end(), task_) == task_list.end()) {
std::string err_msg = "task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) {
std::string err_msg = "usage should be train, test or eval.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"}));
if (num_samples_ < 0) {
std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_);
@ -1133,18 +1122,12 @@ Status CocoNode::ValidateParams() {
Path annotation_file(annotation_file_);
if (!annotation_file.Exists()) {
std::string err_msg = "annotation_file is invalid or not exist";
std::string err_msg = "CocoNode: annotation_file is invalid or does not exist.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
std::set<std::string> task_list = {"Detection", "Stuff", "Panoptic", "Keypoint"};
auto task_iter = task_list.find(task_);
if (task_iter == task_list.end()) {
std::string err_msg = "Invalid task type";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateStringValue("CocoNode", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"}));
return Status::OK();
}
@ -1348,7 +1331,7 @@ Status ManifestNode::ValidateParams() {
for (char c : dataset_file_) {
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
if (p != forbidden_symbols.end()) {
std::string err_msg = "filename should not contains :*?\"<>|`&;\'";
std::string err_msg = "ManifestNode: filename should not contain :*?\"<>|`&;\'";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -1356,19 +1339,14 @@ Status ManifestNode::ValidateParams() {
Path manifest_file(dataset_file_);
if (!manifest_file.Exists()) {
std::string err_msg = "dataset file: [" + dataset_file_ + "] is invalid or not exist";
std::string err_msg = "ManifestNode: dataset file: [" + dataset_file_ + "] is invalid or not exist";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_));
std::vector<std::string> usage_list = {"train", "eval", "inference"};
if (find(usage_list.begin(), usage_list.end(), usage_) == usage_list.end()) {
std::string err_msg = "usage should be train, eval or inference.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateStringValue("ManifestNode", usage_, {"train", "eval", "inference"}));
return Status::OK();
}
@ -1536,7 +1514,7 @@ Status MnistNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_));
RETURN_IF_NOT_OK(ValidateStringValue(usage_, {"train", "test", "all"}));
RETURN_IF_NOT_OK(ValidateStringValue("MnistNode", usage_, {"train", "test", "all"}));
return Status::OK();
}
@ -1753,35 +1731,32 @@ VOCNode::VOCNode(const std::string &dataset_dir, const std::string &task, const
Status VOCNode::ValidateParams() {
Path dir(dataset_dir_);
if (!dir.IsDirectory()) {
std::string err_msg = "Invalid dataset path or no dataset path is specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_));
RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_));
if (task_ == "Segmentation") {
if (!class_index_.empty()) {
std::string err_msg = "class_indexing is invalid in Segmentation task.";
std::string err_msg = "VOCNode: class_indexing is invalid in Segmentation task.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
if (!imagesets_file.Exists()) {
std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist";
MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
} else if (task_ == "Detection") {
Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
if (!imagesets_file.Exists()) {
std::string err_msg = "Invalid usage: " + usage_ + ", file does not exist";
MS_LOG(ERROR) << "Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
} else {
std::string err_msg = "Invalid task: " + task_;
std::string err_msg = "VOCNode: Invalid task: " + task_;
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -1859,15 +1834,17 @@ std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
// Validates the parameters held by BatchNode.
// Reports a syntax error (logged with MS_LOG(ERROR), then RETURN_STATUS_SYNTAX_ERROR) when
// batch_size_ is not positive or when cols_to_map_ is non-empty; otherwise returns Status::OK().
// NOTE(review): this text appears to be a rendered commit diff without +/- markers; each pair of
// consecutive err_msg declarations below is presumably the old and the new wording of one line
// ("Batch:"/unprefixed vs. "BatchNode:") - confirm against the real file.
Status BatchNode::ValidateParams() {
if (batch_size_ <= 0) {
// Zero and negative batch sizes are rejected; the offending value is appended to the message.
std::string err_msg = "Batch: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (!cols_to_map_.empty()) {
// Any entry in cols_to_map_ is rejected: per the message, the feature is not implemented in C++.
std::string err_msg = "cols_to_map functionality is not implemented in C++; this should be left empty.";
std::string err_msg = "BatchNode: cols_to_map functionality is not implemented in C++; this should be left empty.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
@ -1906,28 +1883,29 @@ std::vector<std::shared_ptr<DatasetOp>> BucketBatchByLengthNode::Build() {
// Validates the parameters held by BucketBatchByLengthNode. The visible checks are:
//   - when element_length_function_ is nullptr, column_names_ must contain exactly one name;
//   - bucket_boundaries_ must be non-empty, every element positive, and strictly increasing;
//   - bucket_batch_sizes_ must be non-empty, have bucket_boundaries_.size() + 1 elements, and
//     contain only positive values.
// Each failure is logged with MS_LOG(ERROR) and reported through RETURN_STATUS_SYNTAX_ERROR;
// the visible tail of the function returns Status::OK().
// NOTE(review): this text appears to be a rendered commit diff without +/- markers (a hunk header
// sits inside the function, so some lines are not visible); consecutive err_msg declarations are
// presumably the old ("BucketBatchByLength:") and new ("BucketBatchByLengthNode:") wording of one
// line - confirm against the real file.
Status BucketBatchByLengthNode::ValidateParams() {
if (element_length_function_ == nullptr && column_names_.size() != 1) {
// NOTE(review): in both versions below, `"<string literal>" + column_names_.size()` adds an
// integer to a const char*, i.e. pointer arithmetic rather than text concatenation, so the
// message is built from an offset into the literal instead of ending with the count.
// std::to_string(column_names_.size()) is presumably what was intended (compare the top_k_
// message in BuildVocabNode::ValidateParams, which this commit switches to std::to_string).
std::string err_msg =
"BucketBatchByLength: element_length_function not specified, but not one column name: " + column_names_.size();
std::string err_msg = "BucketBatchByLengthNode: element_length_function not specified, but not one column name: " +
column_names_.size();
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Check bucket_boundaries: must be positive and strictly increasing
if (bucket_boundaries_.empty()) {
std::string err_msg = "BucketBatchByLength: bucket_boundaries cannot be empty.";
std::string err_msg = "BucketBatchByLengthNode: bucket_boundaries cannot be empty.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// NOTE(review): `int i` is compared against bucket_boundaries_.size(), a signed/unsigned mix.
for (int i = 0; i < bucket_boundaries_.size(); i++) {
if (bucket_boundaries_[i] <= 0) {
// NOTE(review): err_msg ends with "index: " but no index is appended to it; only the log
// statement below carries the index and the value. The log text also keeps the
// "BucketBatchByLength:" prefix while the new err_msg uses "BucketBatchByLengthNode:".
std::string err_msg = "BucketBatchByLength: Invalid non-positive bucket_boundaries, index: ";
std::string err_msg = "BucketBatchByLengthNode: Invalid non-positive bucket_boundaries, index: ";
MS_LOG(ERROR)
<< "BucketBatchByLength: bucket_boundaries must only contain positive numbers. However, the element at index: "
<< i << " was: " << bucket_boundaries_[i];
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Each boundary (from the second one on) must be greater than its predecessor.
if (i > 0 && bucket_boundaries_[i - 1] >= bucket_boundaries_[i]) {
std::string err_msg = "BucketBatchByLength: Invalid bucket_boundaries not be strictly increasing.";
std::string err_msg = "BucketBatchByLengthNode: Invalid bucket_boundaries not be strictly increasing.";
MS_LOG(ERROR)
<< "BucketBatchByLength: bucket_boundaries must be strictly increasing. However, the elements at index: "
<< i - 1 << " and " << i << " were: " << bucket_boundaries_[i - 1] << " and " << bucket_boundaries_[i]
@ -1938,20 +1916,24 @@ Status BucketBatchByLengthNode::ValidateParams() {
// Check bucket_batch_sizes: must be positive
if (bucket_batch_sizes_.empty()) {
std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must be non-empty";
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must be non-empty";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// One batch size is required per bucket, and the check expects one more bucket than boundaries.
if (bucket_batch_sizes_.size() != bucket_boundaries_.size() + 1) {
std::string err_msg = "BucketBatchByLength: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
std::string err_msg =
"BucketBatchByLengthNode: bucket_batch_sizes's size must equal the size of bucket_boundaries + 1";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Reject the whole list as soon as any single batch size is zero or negative.
if (std::any_of(bucket_batch_sizes_.begin(), bucket_batch_sizes_.end(), [](int i) { return i <= 0; })) {
std::string err_msg = "BucketBatchByLength: bucket_batch_sizes must only contain positive numbers.";
std::string err_msg = "BucketBatchByLengthNode: bucket_batch_sizes must only contain positive numbers.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
@ -1981,26 +1963,26 @@ std::vector<std::shared_ptr<DatasetOp>> BuildVocabNode::Build() {
Status BuildVocabNode::ValidateParams() {
if (vocab_ == nullptr) {
std::string err_msg = "BuildVocab: vocab is null.";
std::string err_msg = "BuildVocabNode: vocab is null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (top_k_ <= 0) {
std::string err_msg = "BuildVocab: top_k should be positive, but got: " + top_k_;
std::string err_msg = "BuildVocabNode: top_k should be positive, but got: " + std::to_string(top_k_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (freq_range_.first < 0 || freq_range_.second > kDeMaxFreq || freq_range_.first > freq_range_.second) {
std::string err_msg = "BuildVocab: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
MS_LOG(ERROR) << "BuildVocab: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
std::string err_msg = "BuildVocabNode: frequency_range [a,b] violates 0 <= a <= b (a,b are inclusive)";
MS_LOG(ERROR) << "BuildVocabNode: frequency_range [a,b] should be 0 <= a <= b (a,b are inclusive), "
<< "but got [" << freq_range_.first << ", " << freq_range_.second << "]";
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (!columns_.empty()) {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocab", "columns", columns_));
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("BuildVocabNode", "columns", columns_));
}
return Status::OK();
@ -2014,15 +1996,17 @@ ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) :
// Validates the parameters held by ConcatNode.
// Reports a syntax error (logged with MS_LOG(ERROR), then RETURN_STATUS_SYNTAX_ERROR) when
// datasets_ is empty or contains a nullptr entry; otherwise returns Status::OK().
// NOTE(review): this text appears to be a rendered commit diff without +/- markers; each pair of
// consecutive err_msg declarations below is presumably the old ("Concat:") and new
// ("ConcatNode:") wording of one line - confirm against the real file.
Status ConcatNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "Concat: concatenated datasets are not specified.";
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Searches datasets_ for a nullptr element; finding one is an error.
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "Concat: concatenated datasets should not be null.";
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
@ -2070,7 +2054,7 @@ std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
Status MapNode::ValidateParams() {
if (operations_.empty()) {
std::string err_msg = "Map: No operation is specified.";
std::string err_msg = "MapNode: No operation is specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -2158,8 +2142,8 @@ std::vector<std::shared_ptr<DatasetOp>> RepeatNode::Build() {
Status RepeatNode::ValidateParams() {
if (repeat_count_ <= 0 && repeat_count_ != -1) {
std::string err_msg =
"Repeat: repeat_count should be either -1 or positive integer, repeat_count_: " + std::to_string(repeat_count_);
std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
std::to_string(repeat_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -2211,10 +2195,11 @@ std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
// Function to validate the parameters for SkipNode
// Reports a syntax error (logged with MS_LOG(ERROR), then RETURN_STATUS_SYNTAX_ERROR) when
// skip_count_ is negative; a skip_count_ of zero or more yields Status::OK().
// NOTE(review): this text appears to be a rendered commit diff without +/- markers; the two
// consecutive err_msg declarations below are presumably the old ("Skip:") and new ("SkipNode:")
// wording of one line - confirm against the real file.
Status SkipNode::ValidateParams() {
// `<= -1` is the negative-value test; the offending value is appended to the message.
if (skip_count_ <= -1) {
std::string err_msg = "Skip: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}
@ -2236,7 +2221,7 @@ std::vector<std::shared_ptr<DatasetOp>> TakeNode::Build() {
Status TakeNode::ValidateParams() {
if (take_count_ <= 0 && take_count_ != -1) {
std::string err_msg =
"Take: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
@ -2252,15 +2237,17 @@ ZipNode::ZipNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datase
// Validates the parameters held by ZipNode.
// Reports a syntax error (logged with MS_LOG(ERROR), then RETURN_STATUS_SYNTAX_ERROR) when
// datasets_ is empty or contains a nullptr entry; otherwise returns Status::OK().
// NOTE(review): this text appears to be a rendered commit diff without +/- markers; the two
// consecutive err_msg declarations in the first check are presumably the old ("Zip:") and new
// ("ZipNode:") wording of one line - confirm against the real file.
Status ZipNode::ValidateParams() {
if (datasets_.empty()) {
std::string err_msg = "Zip: datasets to zip are not specified.";
std::string err_msg = "ZipNode: datasets to zip are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// Searches datasets_ for a nullptr element; finding one is an error.
if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
std::string err_msg = "ZipNode: zip datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}