forked from mindspore-Ecosystem/mindspore
Update minddata log message
This commit is contained in:
parent
f513e0bfdd
commit
1a6a42c083
|
@ -220,7 +220,7 @@ Status DataSchema::ColumnOrderLoad(nlohmann::json column_tree, const std::vector
|
|||
// Find the column in the json document
|
||||
auto column_info = column_tree.find(common::SafeCStr(curr_col_name));
|
||||
if (column_info == column_tree.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to find column name: " + curr_col_name + " in given json file.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to find column: " + curr_col_name + " in JSON schema file.");
|
||||
}
|
||||
// At this point, columnInfo.value() is the subtree in the json document that contains
|
||||
// all of the data for a given column. This data will formulate our schema column.
|
||||
|
@ -238,7 +238,8 @@ Status DataSchema::ColumnOrderLoad(nlohmann::json column_tree, const std::vector
|
|||
for (const auto &it_child : column_tree.items()) {
|
||||
auto name = it_child.value().find("name");
|
||||
if (name == it_child.value().end()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, \"name\" field is missing for column: " + curr_col_name);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, \"name\" field is missing for column: " + curr_col_name +
|
||||
" in JSON schema file.");
|
||||
}
|
||||
if (name.value() == curr_col_name) {
|
||||
index = i;
|
||||
|
@ -247,7 +248,7 @@ Status DataSchema::ColumnOrderLoad(nlohmann::json column_tree, const std::vector
|
|||
i++;
|
||||
}
|
||||
if (index == -1) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to find column name: " + curr_col_name + " in given json file.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to find column: " + curr_col_name + " in JSON schema file.");
|
||||
}
|
||||
nlohmann::json column_child_tree = column_tree[index];
|
||||
RETURN_IF_NOT_OK(ColumnLoad(column_child_tree, curr_col_name));
|
||||
|
@ -301,14 +302,12 @@ Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::strin
|
|||
}
|
||||
if (!name.empty()) {
|
||||
if (!col_name.empty() && col_name != name) {
|
||||
std::string err_msg =
|
||||
"Invalid data, json schema file for column " + col_name + " has column name that does not match columnsToLoad";
|
||||
std::string err_msg = "Invalid data, failed to find column: " + col_name + " in JSON schema file.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
} else {
|
||||
if (col_name.empty()) {
|
||||
std::string err_msg =
|
||||
"Invalid data, json schema file for column " + col_name + " has invalid or missing column name.";
|
||||
std::string err_msg = "Invalid data, \"name\" field is missing for column " + col_name + " in JSON schema file.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
name = col_name;
|
||||
|
@ -317,12 +316,12 @@ Status DataSchema::ColumnLoad(nlohmann::json column_child_tree, const std::strin
|
|||
// data type is mandatory field
|
||||
if (type_str.empty())
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid data, json schema file for column " + col_name + " has invalid or missing column type.");
|
||||
"Invalid data, \"type\" field is missing for column " + col_name + " in JSON schema file.");
|
||||
|
||||
// rank number is mandatory field
|
||||
if (rank_value <= -1)
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid data, json schema file for column " + col_name + " must define a positive rank value.");
|
||||
"Invalid data, \"rank\" field of column " + col_name + " must have value >= 0 in JSON schema file.");
|
||||
|
||||
// Create the column descriptor for this column from the data we pulled from the json file
|
||||
TensorShape col_shape = TensorShape(tmp_shape);
|
||||
|
@ -349,12 +348,13 @@ Status DataSchema::LoadSchemaFile(const std::string &schema_file_path,
|
|||
num_rows_ = 0;
|
||||
} catch (nlohmann::json::exception &e) {
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, unable to parse \"numRows\" from schema file: " + schema_file_path);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, unable to parse \"numRows\" field from JSON schema file: " +
|
||||
schema_file_path + ", check syntax with JSON tool.");
|
||||
}
|
||||
nlohmann::json column_tree = js.at("columns");
|
||||
if (column_tree.empty()) {
|
||||
in.close();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, \"columns\" field is missing in schema file: " + schema_file_path);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, \"columns\" field is missing in JSON schema file: " + schema_file_path);
|
||||
}
|
||||
if (columns_to_load.empty()) {
|
||||
// Parse the json tree and load the schema's columns in whatever order that the json
|
||||
|
@ -375,7 +375,8 @@ Status DataSchema::LoadSchemaFile(const std::string &schema_file_path,
|
|||
in.close();
|
||||
} catch (const std::exception &err) {
|
||||
// Catch any exception and convert to Status return code
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to load and parse schema file: " + schema_file_path);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, failed to load and parse JSON schema file: " + schema_file_path +
|
||||
", check syntax with JSON tools.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -389,7 +390,7 @@ Status DataSchema::LoadSchemaString(const std::string &schema_json_string,
|
|||
num_rows_ = js.value("numRows", 0);
|
||||
nlohmann::json column_tree = js.at("columns");
|
||||
if (column_tree.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, \"columns\" field is missing in schema string.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, \"columns\" field is missing in JSON schema string.");
|
||||
}
|
||||
if (columns_to_load.empty()) {
|
||||
// Parse the json tree and load the schema's columns in whatever order that the json
|
||||
|
@ -404,7 +405,7 @@ Status DataSchema::LoadSchemaString(const std::string &schema_json_string,
|
|||
}
|
||||
} catch (const std::exception &err) {
|
||||
// Catch any exception and convert to Status return code
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to load and parse schema string.");
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to load and parse JSON schema string, check syntax with JSON tool.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -446,7 +447,7 @@ Status DataSchema::PreLoadExceptionCheck(const nlohmann::json &js) {
|
|||
// Check if columns node exists. It is required for building schema from file.
|
||||
if (js.find("columns") == js.end())
|
||||
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
|
||||
"Invalid data, \"columns\" node is required in the schema json file.");
|
||||
"Invalid data, \"columns\" field is missing in the JSON schema file.");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -72,18 +72,18 @@ void BatchNode::Print(std::ostream &out) const {
|
|||
Status BatchNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (batch_size_ <= 0) {
|
||||
std::string err_msg = "BatchNode: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
|
||||
std::string err_msg = "Batch: 'batch_size' should be positive integer, but got: " + std::to_string(batch_size_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
if (batch_map_func_ && pad_) {
|
||||
std::string err_msg = "BatchNode: per_batch_map and pad should not be used at the same time.";
|
||||
std::string err_msg = "Batch: 'per_batch_map' and 'pad_info' should not be used at the same time.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (batch_map_func_ && in_col_names_.empty()) {
|
||||
std::string err_msg = "BatchNode: in_col_names cannot be empty when per_batch_map is used.";
|
||||
std::string err_msg = "Batch: 'in_col_names' cannot be empty when per_batch_map is used.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
#endif
|
||||
|
@ -169,10 +169,9 @@ Status BatchNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
Status BatchNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("batch_size") != json_obj.end(), "Failed to find batch_size");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("drop_remainder") != json_obj.end(), "Failed to find drop_remainder");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kBatchNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "batch_size", kBatchNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "drop_remainder", kBatchNode));
|
||||
int32_t batch_size = json_obj["batch_size"];
|
||||
bool drop_remainder = json_obj["drop_remainder"];
|
||||
*result = std::make_shared<BatchNode>(ds, batch_size, drop_remainder);
|
||||
|
|
|
@ -56,19 +56,19 @@ Status ConcatNode::ValidateParams() {
|
|||
constexpr size_t kMinChildrenSize = 2;
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (children_.size() < kMinChildrenSize) {
|
||||
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
|
||||
std::string err_msg = "Concat: concatenated datasets are not specified.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (find(children_.begin(), children_.end(), nullptr) != children_.end()) {
|
||||
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
|
||||
std::string err_msg = "Concat: concatenated datasets should not be null.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
// Either one of children_flag_and_nums_ or children_start_end_index_ should be non-empty.
|
||||
if ((children_flag_and_nums_.empty() && !children_start_end_index_.empty()) ||
|
||||
(!children_flag_and_nums_.empty() && children_start_end_index_.empty())) {
|
||||
std::string err_msg = "ConcatNode: children_flag_and_nums and children_start_end_index should be used together";
|
||||
std::string err_msg = "Concat: 'children_flag_and_nums' and 'children_start_end_index' should be used together";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -162,11 +162,9 @@ Status ConcatNode::to_json(nlohmann::json *out_json) {
|
|||
#ifndef ENABLE_ANDROID
|
||||
Status ConcatNode::from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<DatasetNode>> datasets,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("children_flag_and_nums") != json_obj.end(),
|
||||
"Failed to find children_flag_and_nums");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("children_start_end_index") != json_obj.end(),
|
||||
"Failed to find children_start_end_index");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kConcatNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "children_flag_and_nums", kConcatNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "children_start_end_index", kConcatNode));
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
RETURN_IF_NOT_OK(Serdes::ConstructSampler(json_obj["sampler"], &sampler));
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums = json_obj["children_flag_and_nums"];
|
||||
|
|
|
@ -94,7 +94,8 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data
|
|||
}
|
||||
|
||||
// Helper function to validate dataset files parameter
|
||||
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files) {
|
||||
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files,
|
||||
const std::string &file_name) {
|
||||
if (dataset_files.empty()) {
|
||||
std::string err_msg = dataset_name + ": dataset_files is not specified.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
|
@ -103,11 +104,11 @@ Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vec
|
|||
for (auto f : dataset_files) {
|
||||
Path dataset_file(f);
|
||||
if (!dataset_file.Exists()) {
|
||||
std::string err_msg = dataset_name + ": dataset file: [" + f + "] is invalid or does not exist.";
|
||||
std::string err_msg = dataset_name + ": " + file_name + ": [" + f + "] is invalid or does not exist.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (access(dataset_file.ToString().c_str(), R_OK) == -1) {
|
||||
std::string err_msg = dataset_name + ": No access to specified dataset file: " + f;
|
||||
std::string err_msg = dataset_name + ": No access to specified " + file_name + ": " + f;
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
@ -158,12 +159,12 @@ Status ValidateStringValue(const std::string &dataset_name, const std::string &s
|
|||
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
|
||||
const std::vector<std::string> &columns) {
|
||||
if (columns.empty()) {
|
||||
std::string err_msg = dataset_name + ":" + column_param + " should not be empty string";
|
||||
std::string err_msg = dataset_name + ": '" + column_param + "' should not be empty string";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (uint32_t i = 0; i < columns.size(); ++i) {
|
||||
if (columns[i].empty()) {
|
||||
std::string err_msg = dataset_name + ":" + column_param + "[" + std::to_string(i) + "] must not be empty";
|
||||
std::string err_msg = dataset_name + ": '" + column_param + "' [" + std::to_string(i) + "] must not be empty";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
@ -171,8 +172,8 @@ Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::st
|
|||
for (auto &column_name : columns) {
|
||||
auto result = columns_set.insert(column_name);
|
||||
if (result.second == false) {
|
||||
std::string err_msg = dataset_name + ":" + column_param +
|
||||
": Invalid parameter, duplicate column names are not allowed: " + *result.first;
|
||||
std::string err_msg = dataset_name + ": '" + column_param +
|
||||
"' : Invalid parameter, duplicate column names are not allowed: " + *result.first;
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,8 +38,10 @@
|
|||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -110,7 +112,8 @@ Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, in
|
|||
int32_t connector_que_size, std::shared_ptr<DatasetOp> *shuffle_op);
|
||||
|
||||
// Helper function to validate dataset files parameter
|
||||
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files);
|
||||
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files,
|
||||
const std::string &file_name = "dataset file");
|
||||
|
||||
// Helper function to validate dataset num_shards and shard_id parameters
|
||||
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id);
|
||||
|
|
|
@ -103,27 +103,27 @@ Status MapNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
|||
Status MapNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (operations_.empty()) {
|
||||
std::string err_msg = "MapNode: No operation is specified.";
|
||||
std::string err_msg = "Map: No 'operations' are specified.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (const auto &op : operations_) {
|
||||
if (op == nullptr) {
|
||||
std::string err_msg = "MapNode: operation must not be nullptr.";
|
||||
std::string err_msg = "Map: 'operations' must not be nullptr.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(op->ValidateParams());
|
||||
}
|
||||
}
|
||||
if (!input_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("Map", "input_columns", input_columns_));
|
||||
}
|
||||
|
||||
if (!output_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "output_columns", output_columns_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("Map", "output_columns", output_columns_));
|
||||
}
|
||||
|
||||
if (!project_columns_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "project_columns", project_columns_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("Map", "project_columns", project_columns_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -191,12 +191,11 @@ Status MapNode::to_json(nlohmann::json *out_json) {
|
|||
#ifndef ENABLE_ANDROID
|
||||
Status MapNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("project_columns") != json_obj.end(), "Failed to find project_columns");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("operations") != json_obj.end(), "Failed to find operations");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kMapNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "input_columns", kMapNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "output_columns", kMapNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "project_columns", kMapNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "operations", kMapNode));
|
||||
std::vector<std::string> input_columns = json_obj["input_columns"];
|
||||
std::vector<std::string> output_columns = json_obj["output_columns"];
|
||||
std::vector<std::string> project_columns = json_obj["project_columns"];
|
||||
|
|
|
@ -42,11 +42,11 @@ void ProjectNode::Print(std::ostream &out) const { out << (Name() + "(column: "
|
|||
Status ProjectNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (columns_.empty()) {
|
||||
std::string err_msg = "ProjectNode: No columns are specified.";
|
||||
std::string err_msg = "Project: No 'columns' are specified.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("ProjectNode", "columns", columns_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("Project", "columns", columns_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ Status ProjectNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
Status ProjectNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns") != json_obj.end(), "Failed to find columns");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "columns", kProjectNode));
|
||||
std::vector<std::string> columns = json_obj["columns"];
|
||||
*result = std::make_shared<ProjectNode>(ds, columns);
|
||||
return Status::OK();
|
||||
|
|
|
@ -45,13 +45,13 @@ void RenameNode::Print(std::ostream &out) const {
|
|||
Status RenameNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (input_columns_.size() != output_columns_.size()) {
|
||||
std::string err_msg = "RenameNode: input and output columns must be the same size";
|
||||
std::string err_msg = "Rename: 'input columns' and 'output columns' must have the same size.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "input_columns", input_columns_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("Rename", "input_columns", input_columns_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RenameNode", "output_columns", output_columns_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("Rename", "output_columns", output_columns_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -74,8 +74,8 @@ Status RenameNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
Status RenameNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("input_columns") != json_obj.end(), "Failed to find input_columns");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("output_columns") != json_obj.end(), "Failed to find output_columns");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "input_columns", kRenameNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "output_columns", kRenameNode));
|
||||
std::vector<std::string> input_columns = json_obj["input_columns"];
|
||||
std::vector<std::string> output_columns = json_obj["output_columns"];
|
||||
*result = std::make_shared<RenameNode>(ds, input_columns, output_columns);
|
||||
|
|
|
@ -60,8 +60,8 @@ Status RepeatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
Status RepeatNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (repeat_count_ <= 0 && repeat_count_ != -1) {
|
||||
std::string err_msg = "RepeatNode: repeat_count should be either -1 or positive integer, repeat_count_: " +
|
||||
std::to_string(repeat_count_);
|
||||
std::string err_msg =
|
||||
"Repeat: 'repeat_count' should be either -1 or positive integer, but got: " + std::to_string(repeat_count_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
|
@ -106,7 +106,7 @@ Status RepeatNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
Status RepeatNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "count", kRepeatNode));
|
||||
int32_t count = json_obj["count"];
|
||||
*result = std::make_shared<RepeatNode>(ds, count);
|
||||
return Status::OK();
|
||||
|
|
|
@ -54,11 +54,7 @@ Status ShuffleNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
|
|||
// Function to validate the parameters for ShuffleNode
|
||||
Status ShuffleNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (shuffle_size_ <= 1) {
|
||||
std::string err_msg = "ShuffleNode: Invalid input, shuffle_size: " + std::to_string(shuffle_size_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateScalar("Shuffle", "shuffle_size", shuffle_size_, {1}, true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
|
|||
Status SkipNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (skip_count_ <= -1) {
|
||||
std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
|
||||
std::string err_msg = "Skip: 'skip_count' should not be negative, but got: " + std::to_string(skip_count_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
|
@ -95,7 +95,7 @@ Status SkipNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
Status SkipNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "count", kSkipNode));
|
||||
int32_t count = json_obj["count"];
|
||||
*result = std::make_shared<SkipNode>(ds, count);
|
||||
return Status::OK();
|
||||
|
|
|
@ -55,20 +55,13 @@ void AGNewsNode::Print(std::ostream &out) const {
|
|||
|
||||
Status AGNewsNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AGNewsNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("AGNewsNode", usage_, {"train", "test", "all"}));
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "AGNewsNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (num_shards_ < 1) {
|
||||
std::string err_msg = "AGNewsNode: Invalid number of shards: " + std::to_string(num_shards_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("AGNewsNode", num_shards_, shard_id_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AGNewsDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("AGNewsDataset", usage_, {"train", "test", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("AGNewsDataset", "num_samples", num_samples_, {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("AGNewsDataset", num_shards_, shard_id_));
|
||||
|
||||
if (!column_names_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AGNewsNode", "column_names", column_names_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AGNewsDataset", "column_names", column_names_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -54,14 +54,14 @@ void AlbumNode::Print(std::ostream &out) const {
|
|||
|
||||
Status AlbumNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumDataset", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumDataset", {schema_path_}));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumDataset", sampler_));
|
||||
|
||||
if (!column_names_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumNode", "column_names", column_names_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumDataset", "column_names", column_names_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -148,13 +148,12 @@ Status AlbumNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status AlbumNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("data_schema") != json_obj.end(), "Failed to find data_schema");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kAlbumNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kAlbumNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "data_schema", kAlbumNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "column_names", kAlbumNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kAlbumNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kAlbumNode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string data_schema = json_obj["data_schema"];
|
||||
std::vector<std::string> column_names = json_obj["column_names"];
|
||||
|
|
|
@ -55,11 +55,11 @@ void CelebANode::Print(std::ostream &out) const {
|
|||
|
||||
Status CelebANode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebADataset", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebADataset", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"}));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CelebADataset", usage_, {"all", "train", "valid", "test"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -99,8 +99,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
|||
|
||||
auto realpath = FileUtils::GetRealPath((folder_path / "list_attr_celeba.txt").ToString().data());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << (folder_path / "list_attr_celeba.txt").ToString();
|
||||
RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + (folder_path / "list_attr_celeba.txt").ToString());
|
||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << (folder_path / "list_attr_celeba.txt").ToString();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" +
|
||||
(folder_path / "list_attr_celeba.txt").ToString());
|
||||
}
|
||||
|
||||
std::ifstream attr_file(realpath.value());
|
||||
|
@ -138,8 +139,10 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
|
|||
if (!partition_file.is_open()) {
|
||||
auto realpath_eval = FileUtils::GetRealPath((folder_path / "list_eval_partition.txt").ToString().data());
|
||||
if (!realpath_eval.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << (folder_path / "list_eval_partition.txt").ToString();
|
||||
RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + (folder_path / "list_eval_partition.txt").ToString());
|
||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path="
|
||||
<< (folder_path / "list_eval_partition.txt").ToString();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" +
|
||||
(folder_path / "list_eval_partition.txt").ToString());
|
||||
}
|
||||
|
||||
partition_file.open(realpath_eval.value());
|
||||
|
@ -188,13 +191,12 @@ Status CelebANode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status CelebANode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extension");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCelebANode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCelebANode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCelebANode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCelebANode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kCelebANode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "extensions", kCelebANode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
|
|
|
@ -47,11 +47,11 @@ void Cifar100Node::Print(std::ostream &out) const {
|
|||
|
||||
Status Cifar100Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Dataset", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Dataset", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -123,11 +123,10 @@ Status Cifar100Node::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status Cifar100Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCifar100Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCifar100Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCifar100Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCifar100Node));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
|
|
|
@ -47,11 +47,11 @@ void Cifar10Node::Print(std::ostream &out) const {
|
|||
|
||||
Status Cifar10Node::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Dataset", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Dataset", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -124,11 +124,10 @@ Status Cifar10Node::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status Cifar10Node::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCifar10Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCifar10Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCifar10Node));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCifar10Node));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
|
|
|
@ -50,23 +50,13 @@ void CLUENode::Print(std::ostream &out) const {
|
|||
|
||||
Status CLUENode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"}));
|
||||
|
||||
if (shuffle_ != ShuffleMode::kFalse && shuffle_ != ShuffleMode::kFiles && shuffle_ != ShuffleMode::kGlobal) {
|
||||
std::string err_msg = "CLUENode: Invalid ShuffleMode, check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUENode", num_shards_, shard_id_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUEDataset", dataset_files_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CLUEDataset", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CLUEDataset", usage_, {"train", "test", "eval"}));
|
||||
RETURN_IF_NOT_OK(ValidateEnum("CLUEDataset", "ShuffleMode", shuffle_,
|
||||
{ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("CLUEDataset", "num_samples", num_samples_, {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUEDataset", num_shards_, shard_id_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -250,15 +240,14 @@ Status CLUENode::to_json(nlohmann::json *out_json) {
|
|||
}
|
||||
|
||||
Status CLUENode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCLUENode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCLUENode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "task", kCLUENode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kCLUENode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_samples", kCLUENode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shuffle", kCLUENode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_shards", kCLUENode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shard_id", kCLUENode));
|
||||
std::vector<std::string> dataset_files = json_obj["dataset_dir"];
|
||||
std::string task = json_obj["task"];
|
||||
std::string usage = json_obj["usage"];
|
||||
|
|
|
@ -53,21 +53,10 @@ void CocoNode::Print(std::ostream &out) const { out << Name(); }
|
|||
|
||||
Status CocoNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoNode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoNode", sampler_));
|
||||
|
||||
Path annotation_file(annotation_file_);
|
||||
if (!annotation_file.Exists()) {
|
||||
std::string err_msg = "CocoNode: annotation_file is invalid or does not exist.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (access(annotation_file_.c_str(), R_OK) == -1) {
|
||||
std::string err_msg = "CocoNode: No access to specified annotation file: " + annotation_file_;
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CocoNode", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("CocoDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("CocoDataset", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CocoDataset", {annotation_file_}, "annotation file"));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("CocoDataset", task_, {"Detection", "Stuff", "Panoptic", "Keypoint"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -164,7 +153,7 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
|
|||
int64_t num_rows = 0, sample_size;
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops;
|
||||
RETURN_IF_NOT_OK(Build(&ops));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build CocoOp.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "[Internal ERROR] Unable to build CocoOp.");
|
||||
auto op = std::dynamic_pointer_cast<CocoOp>(ops.front());
|
||||
RETURN_IF_NOT_OK(op->CountTotalRows(&num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
|
@ -199,14 +188,13 @@ Status CocoNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status CocoNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extra_metadata") != json_obj.end(), "Failed to find extra_metadata");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCocoNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kCocoNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "annotation_file", kCocoNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "task", kCocoNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kCocoNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kCocoNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "extra_metadata", kCocoNode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string annotation_file = json_obj["annotation_file"];
|
||||
std::string task = json_obj["task"];
|
||||
|
|
|
@ -66,25 +66,18 @@ Status CSVNode::ValidateParams() {
|
|||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (shuffle_ != ShuffleMode::kFalse && shuffle_ != ShuffleMode::kFiles && shuffle_ != ShuffleMode::kGlobal) {
|
||||
std::string err_msg = "CSVNode: Invalid ShuffleMode, check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "CSVNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVNode", num_shards_, shard_id_));
|
||||
RETURN_IF_NOT_OK(ValidateEnum("CSVDataset", "ShuffleMode", shuffle_,
|
||||
{ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("CSVDataset", "num_samples", num_samples_, {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("CSVDataset", num_shards_, shard_id_));
|
||||
|
||||
if (find(column_defaults_.begin(), column_defaults_.end(), nullptr) != column_defaults_.end()) {
|
||||
std::string err_msg = "CSVNode: column_default should not be null.";
|
||||
std::string err_msg = "CSVDataset: column_default should not be null.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (!column_names_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVNode", "column_names", column_names_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("CSVDataset", "column_names", column_names_));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -187,15 +180,14 @@ Status CSVNode::to_json(nlohmann::json *out_json) {
|
|||
}
|
||||
|
||||
Status CSVNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("field_delim") != json_obj.end(), "Failed to find field_delim");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("column_names") != json_obj.end(), "Failed to find column_names");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kCSVNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_files", kCSVNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "field_delim", kCSVNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "column_names", kCSVNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_samples", kCSVNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shuffle", kCSVNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_shards", kCSVNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shard_id", kCSVNode));
|
||||
std::vector<std::string> dataset_files = json_obj["dataset_files"];
|
||||
std::string field_delim = json_obj["field_delim"];
|
||||
std::vector<std::shared_ptr<CsvBase>> column_defaults = {};
|
||||
|
|
|
@ -56,15 +56,10 @@ void DBpediaNode::Print(std::ostream &out) const {
|
|||
|
||||
Status DBpediaNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("DBpediaNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("DBpediaNode", usage_, {"train", "test", "all"}));
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "DBpediaNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("DBpediaNode", num_shards_, shard_id_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("DBpediaDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("DBpediaDataset", usage_, {"train", "test", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("DBpediaDataset", "num_samples", num_samples_, {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("DBpediaDataset", num_shards_, shard_id_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -49,29 +49,21 @@ void FakeImageNode::Print(std::ostream &out) const {
|
|||
Status FakeImageNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("FakeImageNode", sampler_));
|
||||
if (num_images_ <= 0) {
|
||||
std::string err_msg = "FakeImageNode: num_images must be greater than 0, but got: " + std::to_string(num_images_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("FakeImageDataset", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("FakeImageDataset", "num_images", num_images_, {0}, true));
|
||||
|
||||
if (image_size_.size() != 3) {
|
||||
std::string err_msg =
|
||||
"FakeImageNode: image_size expecting size 3, but got image_size.size(): " + std::to_string(image_size_.size());
|
||||
std::string err_msg = "FakeImageDataset: 'image_size' expecting size 3, but got image_size.size(): " +
|
||||
std::to_string(image_size_.size());
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
for (auto i = 0; i < 3; i++) {
|
||||
if (image_size_[i] <= 0) {
|
||||
std::string err_msg = "FakeImageNode: image_size[" + std::to_string(i) +
|
||||
"] must be greater than 0, but got: " + std::to_string(image_size_[i]);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
if (num_classes_ <= 0) {
|
||||
std::string err_msg = "FakeImageNode: num_classes must be greater than 0, but got: " + std::to_string(num_classes_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
RETURN_IF_NOT_OK(
|
||||
ValidateScalar("FakeImageDataset", "image_size[" + std::to_string(i) + "]", image_size_[i], {0}, true));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateScalar("FakeImageDataset", "num_classes", num_classes_, {0}, true));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -59,10 +59,10 @@ void FlickrNode::Print(std::ostream &out) const {
|
|||
|
||||
Status FlickrNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("FlickrNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("FlickrDataset", dataset_dir_));
|
||||
|
||||
if (annotation_file_.empty()) {
|
||||
std::string err_msg = "FlickrNode: annotation_file is not specified.";
|
||||
std::string err_msg = "FlickrDataset: 'annotation_file' is not specified.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
|
@ -70,17 +70,14 @@ Status FlickrNode::ValidateParams() {
|
|||
for (char c : annotation_file_) {
|
||||
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
|
||||
if (p != forbidden_symbols.end()) {
|
||||
std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] should not contain :*?\"<>|`&;\'.";
|
||||
std::string err_msg =
|
||||
"FlickrDataset: 'annotation_file': [" + annotation_file_ + "] should not contain :*?\"<>|`&;\'.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
Path annotation_file(annotation_file_);
|
||||
if (!annotation_file.Exists()) {
|
||||
std::string err_msg = "FlickrNode: annotation_file: [" + annotation_file_ + "] is invalid or not exist.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("FlickrNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("FlickrDataset", {annotation_file_}, "annotation file"));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("FlickrDataset", sampler_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -149,8 +146,7 @@ Status FlickrNode::to_json(nlohmann::json *out_json) {
|
|||
}
|
||||
|
||||
Status FlickrNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kFlickrNode));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("annotation_file") != json_obj.end(), "Failed to find annotation_file");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
|
||||
|
|
|
@ -57,9 +57,9 @@ void ImageFolderNode::Print(std::ostream &out) const {
|
|||
|
||||
Status ImageFolderNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderDataset", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderDataset", sampler_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -131,14 +131,13 @@ Status ImageFolderNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status ImageFolderNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("recursive") != json_obj.end(), "Failed to find recursive");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extensions") != json_obj.end(), "Failed to find extension");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kImageFolderNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kImageFolderNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kImageFolderNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kImageFolderNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "recursive", kImageFolderNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "extensions", kImageFolderNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "class_indexing", kImageFolderNode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
bool decode = json_obj["decode"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
|
|
|
@ -65,24 +65,18 @@ Status ManifestNode::ValidateParams() {
|
|||
for (char c : dataset_file_) {
|
||||
auto p = std::find(forbidden_symbols.begin(), forbidden_symbols.end(), c);
|
||||
if (p != forbidden_symbols.end()) {
|
||||
std::string err_msg = "ManifestNode: filename should not contain :*?\"<>|`&;\'";
|
||||
std::string err_msg =
|
||||
"ManifestDataset: filename of 'dataset_file' should not contain :*?\"<>|`&;\', check dataset_file: " +
|
||||
dataset_file_;
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
Path manifest_file(dataset_file_);
|
||||
if (!manifest_file.Exists()) {
|
||||
std::string err_msg = "ManifestNode: dataset file: [" + dataset_file_ + "] is invalid or not exist";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (access(dataset_file_.c_str(), R_OK) == -1) {
|
||||
std::string err_msg = "ManifestNode: No access to specified annotation file: " + dataset_file_;
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("ManifestDataset", {dataset_file_}, "annotation file"));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("ManifestDataset", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("ManifestNode", usage_, {"train", "eval", "inference"}));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("ManifestDataset", usage_, {"train", "eval", "inference"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -125,7 +119,7 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
int64_t num_rows, sample_size;
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops;
|
||||
RETURN_IF_NOT_OK(Build(&ops));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build op.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "[Internal ERROR] Unable to build op.");
|
||||
auto op = std::dynamic_pointer_cast<ManifestOp>(ops.front());
|
||||
RETURN_IF_NOT_OK(op->CountTotalRows(&num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
|
@ -160,13 +154,12 @@ Status ManifestNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status ManifestNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_file") != json_obj.end(), "Failed to find dataset_file");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kManifestNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_file", kManifestNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kManifestNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kManifestNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "class_indexing", kManifestNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kManifestNode));
|
||||
std::string dataset_file = json_obj["dataset_file"];
|
||||
std::string usage = json_obj["usage"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
|
|
|
@ -80,40 +80,39 @@ Status MindDataNode::ValidateParams() {
|
|||
constexpr size_t max_len = 4096;
|
||||
if (!search_for_pattern_ && dataset_files_.size() > max_len) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: length of dataset_file must be less than or equal to 4096, dataset_file length: " +
|
||||
"MindDataset: length of dataset_file must be less than or equal to 4096, dataset_file length: " +
|
||||
std::to_string(dataset_file_.size());
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (shuffle_mode_ != ShuffleMode::kFalse && shuffle_mode_ != ShuffleMode::kFiles &&
|
||||
shuffle_mode_ != ShuffleMode::kGlobal && shuffle_mode_ != ShuffleMode::kInfile) {
|
||||
std::string err_msg = "MindDataNode: Invalid ShuffleMode, check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(
|
||||
ValidateEnum("MindDataset", "ShuffleMode", shuffle_mode_,
|
||||
{ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal, ShuffleMode::kInfile}));
|
||||
|
||||
std::vector<std::string> dataset_file_vec =
|
||||
search_for_pattern_ ? std::vector<std::string>{dataset_file_} : dataset_files_;
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataNode", dataset_file_vec));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("MindDataset", dataset_file_vec));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataNode", input_sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("MindDataset", input_sampler_));
|
||||
|
||||
if (!columns_list_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataNode", "columns_list", columns_list_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MindDataset", "columns_list", columns_list_));
|
||||
}
|
||||
|
||||
if (padded_sample_ != nullptr) {
|
||||
if (num_padded_ < 0 || num_padded_ > INT_MAX) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: num_padded must to be between 0 and INT32_MAX, but got: " + std::to_string(num_padded_);
|
||||
"MindDataset: 'num_padded' must to be between 0 and INT32_MAX, but got: " + std::to_string(num_padded_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (columns_list_.empty()) {
|
||||
std::string err_msg = "MindDataNode: padded_sample is specified and requires columns_list as well";
|
||||
std::string err_msg = "MindDataset: 'padded_sample' is specified and requires 'columns_list' as well";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
for (std::string &column : columns_list_) {
|
||||
if (padded_sample_.find(column) == padded_sample_.end()) {
|
||||
std::string err_msg = "MindDataNode: " + column + " in columns_list does not match any column in padded_sample";
|
||||
std::string err_msg =
|
||||
"MindDataset: " + column + " in 'columns_list' does not match any column in 'padded_sample'";
|
||||
MS_LOG(ERROR) << err_msg << ", padded_sample: " << padded_sample_;
|
||||
return Status(StatusCode::kMDSyntaxError, err_msg);
|
||||
}
|
||||
|
@ -121,7 +120,7 @@ Status MindDataNode::ValidateParams() {
|
|||
}
|
||||
if (num_padded_ > 0) {
|
||||
if (padded_sample_ == nullptr) {
|
||||
std::string err_msg = "MindDataNode: num_padded is specified but padded_sample is not";
|
||||
std::string err_msg = "MindDataset: num_padded is specified but padded_sample is not";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
|
@ -136,7 +135,7 @@ Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerO
|
|||
std::shared_ptr<mindrecord::ShardOperator> op = sampler->BuildForMindDataset();
|
||||
if (op == nullptr) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
|
||||
"MindDataset: Unsupported sampler is supplied for MindDataset. Supported sampler list: "
|
||||
"SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler and DistributedSampler";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
@ -149,7 +148,7 @@ Status MindDataNode::BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerO
|
|||
if (op->GetNumSamples() != 0 &&
|
||||
(op->GetShuffleMode() == ShuffleMode::kFiles || op->GetShuffleMode() == ShuffleMode::kInfile)) {
|
||||
std::string err_msg =
|
||||
"MindDataNode: Shuffle.FILES or Shuffle.INFILE and num_samples cannot be specified at the same time.";
|
||||
"MindDataset: Shuffle.kFiles or Shuffle.kInfile and num_samples cannot be specified at the same time.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
|
|
|
@ -44,11 +44,11 @@ void MnistNode::Print(std::ostream &out) const { out << Name(); }
|
|||
|
||||
Status MnistNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("MnistDataset", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("MnistDataset", sampler_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("MnistNode", usage_, {"train", "test", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("MnistDataset", usage_, {"train", "test", "all"}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -117,11 +117,10 @@ Status MnistNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status MnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kMnistNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kMnistNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kMnistNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kMnistNode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
std::shared_ptr<SamplerObj> sampler;
|
||||
|
|
|
@ -51,9 +51,10 @@ void QMnistNode::Print(std::ostream &out) const {
|
|||
|
||||
Status QMnistNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("QMnistNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("QMnistNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("QMnistNode", usage_, {"train", "test", "test10k", "test50k", "nist", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("QMnistDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("QMnistDataset", sampler_));
|
||||
RETURN_IF_NOT_OK(
|
||||
ValidateStringValue("QMnistDataset", usage_, {"train", "test", "test10k", "test50k", "nist", "all"}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -128,12 +129,11 @@ Status QMnistNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status QMnistNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("compat") != json_obj.end(), "Failed to find compat");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kQMnistNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kQMnistNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kQMnistNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "compat", kQMnistNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kQMnistNode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string usage = json_obj["usage"];
|
||||
bool compat = json_obj["compat"];
|
||||
|
|
|
@ -41,20 +41,16 @@ void RandomNode::Print(std::ostream &out) const {
|
|||
// ValidateParams for RandomNode
|
||||
Status RandomNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (total_rows_ < 0) {
|
||||
std::string err_msg =
|
||||
"RandomNode: total_rows must be greater than or equal 0, now get " + std::to_string(total_rows_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateScalar("RandomDataset", "total_rows", total_rows_, {0}, false));
|
||||
|
||||
if (!columns_list_.empty()) {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomDataset", "columns_list", columns_list_));
|
||||
}
|
||||
|
||||
// allow total_rows == 0 for now because RandomOp would generate a random row when it gets a 0
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(total_rows_ == 0 || total_rows_ >= num_workers_,
|
||||
"RandomNode needs total_rows >= num_workers, total_rows=" + std::to_string(total_rows_) +
|
||||
", num_workers=" + std::to_string(num_workers_) + ".");
|
||||
"RandomDataset needs 'total_rows' >= 'num_workers', total_rows=" +
|
||||
std::to_string(total_rows_) + ", num_workers=" + std::to_string(num_workers_) + ".");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -72,7 +68,7 @@ Status RandomNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
|
|||
if (!schema_path_.empty()) {
|
||||
schema_obj = Schema(schema_path_);
|
||||
if (schema_obj == nullptr) {
|
||||
std::string err_msg = "RandomNode::Build : Invalid schema path";
|
||||
std::string err_msg = "RandomDataset: Invalid schema path, check schema path:" + schema_path_;
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
|
|
@ -110,12 +110,12 @@ Status DistributedSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
#ifndef ENABLE_ANDROID
|
||||
Status DistributedSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("seed") != json_obj.end(), "Failed to find seed");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("offset") != json_obj.end(), "Failed to find offset");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("even_dist") != json_obj.end(), "Failed to find even_dist");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_shards", "DistributedSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shard_id", "DistributedSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shuffle", "DistributedSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "seed", "DistributedSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "offset", "DistributedSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "even_dist", "DistributedSampler"));
|
||||
int64_t num_shards = json_obj["num_shards"];
|
||||
int64_t shard_id = json_obj["shard_id"];
|
||||
bool shuffle = json_obj["shuffle"];
|
||||
|
@ -135,7 +135,7 @@ std::shared_ptr<SamplerObj> DistributedSamplerObj::SamplerCopy() {
|
|||
std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_, offset_, even_dist_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ std::shared_ptr<SamplerObj> MindRecordSamplerObj::SamplerCopy() {
|
|||
// Note this function can only be called after SamplerBuild is finished, and can only be called once. Otherwise this
|
||||
// function will return error status.
|
||||
Status MindRecordSamplerObj::GetShardReader(std::unique_ptr<mindrecord::ShardReader> *shard_reader) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_ != nullptr, "Internal error. Attempt to get an empty shard reader.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_ != nullptr, "[Internal ERROR] Attempt to get an empty shard reader.");
|
||||
*shard_reader = std::move(shard_reader_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -62,8 +62,8 @@ Status PKSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status PKSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_val") != json_obj.end(), "Failed to find num_val");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_val", "PKSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shuffle", "PKSampler"));
|
||||
int64_t num_val = json_obj["num_val"];
|
||||
bool shuffle = json_obj["shuffle"];
|
||||
*sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
|
||||
|
@ -100,7 +100,7 @@ std::shared_ptr<SamplerObj> PKSamplerObj::SamplerCopy() {
|
|||
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -62,7 +62,9 @@ std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
|
|||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_minddataset_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
@ -70,7 +72,9 @@ std::shared_ptr<SamplerObj> PreBuiltSamplerObj::SamplerCopy() {
|
|||
auto sampler = std::make_shared<PreBuiltSamplerObj>(sp_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -58,9 +58,8 @@ Status RandomSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status RandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("reshuffle_each_epoch") != json_obj.end(),
|
||||
"Failed to find reshuffle_each_epoch");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "replacement", "RandomSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "reshuffle_each_epoch", "RandomSampler"));
|
||||
bool replacement = json_obj["replacement"];
|
||||
bool reshuffle_each_epoch = json_obj["reshuffle_each_epoch"];
|
||||
*sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples, reshuffle_each_epoch);
|
||||
|
@ -92,7 +91,9 @@ std::shared_ptr<SamplerObj> RandomSamplerObj::SamplerCopy() {
|
|||
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/util/validators.h"
|
||||
#ifndef ENABLE_ANDROID
|
||||
#include "minddata/mindrecord/include/shard_operator.h"
|
||||
#endif
|
||||
|
|
|
@ -64,7 +64,7 @@ Status SequentialSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
#ifndef ENABLE_ANDROID
|
||||
Status SequentialSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("start_index") != json_obj.end(), "Failed to find start_index");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "start_index", "SequentialSampler"));
|
||||
int64_t start_index = json_obj["start_index"];
|
||||
*sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples);
|
||||
// Run common code in super class to add children samplers
|
||||
|
@ -89,11 +89,14 @@ std::shared_ptr<mindrecord::ShardOperator> SequentialSamplerObj::BuildForMindDat
|
|||
return mind_sampler;
|
||||
}
|
||||
#endif
|
||||
|
||||
std::shared_ptr<SamplerObj> SequentialSamplerObj::SamplerCopy() {
|
||||
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ Status SubsetRandomSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
#ifndef ENABLE_ANDROID
|
||||
Status SubsetRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("indices") != json_obj.end(), "Failed to find indices");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "indices", "SubsetRandomSampler"));
|
||||
std::vector<int64_t> indices = json_obj["indices"];
|
||||
*sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
|
||||
// Run common code in super class to add children samplers
|
||||
|
@ -80,7 +80,9 @@ std::shared_ptr<SamplerObj> SubsetRandomSamplerObj::SamplerCopy() {
|
|||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ Status SubsetSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status SubsetSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples, std::shared_ptr<SamplerObj> *sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("indices") != json_obj.end(), "Failed to find indices");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "indices", "SubsetSampler"));
|
||||
std::vector<int64_t> indices = json_obj["indices"];
|
||||
*sampler = std::make_shared<SubsetSamplerObj>(indices, num_samples);
|
||||
// Run common code in super class to add children samplers
|
||||
|
@ -87,7 +87,9 @@ std::shared_ptr<SamplerObj> SubsetSamplerObj::SamplerCopy() {
|
|||
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -35,8 +35,10 @@ Status WeightedRandomSamplerObj::ValidateParams() {
|
|||
int32_t zero_elem = 0;
|
||||
for (int32_t i = 0; i < weights_.size(); ++i) {
|
||||
if (weights_[i] < 0) {
|
||||
RETURN_STATUS_UNEXPECTED("WeightedRandomSampler: weights vector must not contain negative number, got: " +
|
||||
std::to_string(weights_[i]));
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"WeightedRandomSampler: weights vector must not contain negative numbers, got: "
|
||||
"weights[" +
|
||||
std::to_string(i) + "] = " + std::to_string(weights_[i]));
|
||||
}
|
||||
if (weights_[i] == 0.0) {
|
||||
zero_elem++;
|
||||
|
@ -66,8 +68,8 @@ Status WeightedRandomSamplerObj::to_json(nlohmann::json *const out_json) {
|
|||
#ifndef ENABLE_ANDROID
|
||||
Status WeightedRandomSamplerObj::from_json(nlohmann::json json_obj, int64_t num_samples,
|
||||
std::shared_ptr<SamplerObj> *sampler) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("weights") != json_obj.end(), "Failed to find weights");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("replacement") != json_obj.end(), "Failed to find replacement");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "weights", "WeightedRandomSampler"));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "replacement", "WeightedRandomSampler"));
|
||||
std::vector<double> weights = json_obj["weights"];
|
||||
bool replacement = json_obj["replacement"];
|
||||
*sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
|
||||
|
@ -83,11 +85,14 @@ Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sample
|
|||
sampler = s.IsOk() ? sampler : nullptr;
|
||||
return s;
|
||||
}
|
||||
|
||||
std::shared_ptr<SamplerObj> WeightedRandomSamplerObj::SamplerCopy() {
|
||||
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);
|
||||
for (const auto &child : children_) {
|
||||
Status rc = sampler->AddChildSampler(child);
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error in copying the sampler. Message: " << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "[Internal ERROR] Error in copying the sampler. Message: " << rc;
|
||||
}
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
|
|
@ -43,8 +43,8 @@ void SBUNode::Print(std::ostream &out) const {
|
|||
|
||||
Status SBUNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("SBUNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("SBUDataset", sampler_));
|
||||
|
||||
Path root_dir(dataset_dir_);
|
||||
|
||||
|
@ -52,9 +52,9 @@ Status SBUNode::ValidateParams() {
|
|||
Path caption_path = root_dir / Path("SBU_captioned_photo_dataset_captions.txt");
|
||||
Path image_path = root_dir / Path("sbu_images");
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUNode", {url_path.ToString()}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUNode", {caption_path.ToString()}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUNode", {image_path.ToString()}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUDataset", {url_path.ToString()}, "url file"));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("SBUDataset", {caption_path.ToString()}, "caption file"));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("SBUDataset", {image_path.ToString()}));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -53,19 +53,11 @@ void TextFileNode::Print(std::ostream &out) const {
|
|||
|
||||
Status TextFileNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileNode", dataset_files_));
|
||||
|
||||
if (shuffle_ != ShuffleMode::kFalse && shuffle_ != ShuffleMode::kFiles && shuffle_ != ShuffleMode::kGlobal) {
|
||||
std::string err_msg = "TextFileNode: Invalid ShuffleMode, check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "TextFileNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileNode", num_shards_, shard_id_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TextFileDataset", dataset_files_));
|
||||
RETURN_IF_NOT_OK(ValidateEnum("TextFileDataset", "ShuffleMode", shuffle_,
|
||||
{ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("TextFileDataset", "num_samples", num_samples_, {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("TextFileDataset", num_shards_, shard_id_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -155,13 +147,12 @@ Status TextFileNode::to_json(nlohmann::json *out_json) {
|
|||
}
|
||||
|
||||
Status TextFileNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kTextFileNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_files", kTextFileNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_samples", kTextFileNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shuffle", kTextFileNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_shards", kTextFileNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shard_id", kTextFileNode));
|
||||
std::vector<std::string> dataset_files = json_obj["dataset_files"];
|
||||
int64_t num_samples = json_obj["num_samples"];
|
||||
ShuffleMode shuffle = static_cast<ShuffleMode>(json_obj["shuffle"]);
|
||||
|
|
|
@ -52,45 +52,11 @@ void TFRecordNode::Print(std::ostream &out) const {
|
|||
// Validator for TFRecordNode
|
||||
Status TFRecordNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
|
||||
if (shuffle_ != ShuffleMode::kFalse && shuffle_ != ShuffleMode::kFiles && shuffle_ != ShuffleMode::kGlobal) {
|
||||
std::string err_msg = "TFRecordNode: Invalid ShuffleMode, check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (dataset_files_.empty()) {
|
||||
std::string err_msg = "TFRecordNode: dataset_files is not specified.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
for (const auto &f : dataset_files_) {
|
||||
auto realpath = FileUtils::GetRealPath(f.data());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(),
|
||||
"TFRecordNode: dataset file: [" + f + "] is invalid or does not exist.");
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "TFRecordNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
if (num_shards_ <= 0) {
|
||||
std::string err_msg = "TFRecordNode: Invalid num_shards: " + std::to_string(num_shards_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
std::string err_msg = "TFRecordNode: Invalid input, shard_id: " + std::to_string(shard_id_) +
|
||||
", num_shards: " + std::to_string(num_shards_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateEnum("TFRecordDataset", "ShuffleMode", shuffle_,
|
||||
{ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal}));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetFilesParam("TFRecordDataset", dataset_files_));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("TFRecordDataset", "num_samples", num_samples_, {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("TFRecordDataset", num_shards_, shard_id_));
|
||||
|
||||
std::vector<std::string> invalid_files(dataset_files_.size());
|
||||
auto it = std::copy_if(dataset_files_.begin(), dataset_files_.end(), invalid_files.begin(),
|
||||
|
@ -239,15 +205,14 @@ Status TFRecordNode::to_json(nlohmann::json *out_json) {
|
|||
}
|
||||
|
||||
Status TFRecordNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_files") != json_obj.end(), "Failed to find dataset_files");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("columns_list") != json_obj.end(), "Failed to find columns_list");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_samples") != json_obj.end(), "Failed to find num_samples");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shuffle") != json_obj.end(), "Failed to find shuffle");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_shards") != json_obj.end(), "Failed to find num_shards");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_id") != json_obj.end(), "Failed to find shard_id");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("shard_equal_rows") != json_obj.end(), "Failed to find shard_equal_rows");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_files", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "columns_list", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_samples", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shuffle", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_shards", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shard_id", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "shard_equal_rows", kTFRecordNode));
|
||||
std::vector<std::string> dataset_files = json_obj["dataset_files"];
|
||||
std::vector<std::string> columns_list = json_obj["columns_list"];
|
||||
int64_t num_samples = json_obj["num_samples"];
|
||||
|
|
|
@ -56,15 +56,10 @@ void USPSNode::Print(std::ostream &out) const {
|
|||
|
||||
Status USPSNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("USPSNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("USPSNode", usage_, {"train", "test", "all"}));
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
std::string err_msg = "USPSNode: Invalid number of samples: " + std::to_string(num_samples_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("USPSNode", num_shards_, shard_id_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("USPSDataset", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateStringValue("USPSDataset", usage_, {"train", "test", "all"}));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("USPSDataset", "num_samples", num_samples_, {0}, false));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetShardParams("USPSDataset", num_shards_, shard_id_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -57,30 +57,32 @@ Status VOCNode::ValidateParams() {
|
|||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
Path dir(dataset_dir_);
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCNode", dataset_dir_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("VOCDataset", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCNode", sampler_));
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("VOCDataset", sampler_));
|
||||
|
||||
if (task_ == "Segmentation") {
|
||||
if (!class_index_.empty()) {
|
||||
std::string err_msg = "VOCNode: class_indexing is invalid in Segmentation task.";
|
||||
std::string err_msg = "VOCDataset: 'class_indexing' is invalid in Segmentation task.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
|
||||
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
|
||||
std::string err_msg = "VOCDataset: Invalid 'usage': " + usage_ + ", file does not exist";
|
||||
MS_LOG(ERROR) << "VOCDataset: Invalid 'usage': " << usage_ << ", file \"" << imagesets_file
|
||||
<< "\" does not exist!";
|
||||
return Status(StatusCode::kMDSyntaxError, err_msg);
|
||||
}
|
||||
} else if (task_ == "Detection") {
|
||||
Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
|
||||
if (!imagesets_file.Exists()) {
|
||||
std::string err_msg = "VOCNode: Invalid usage: " + usage_ + ", file does not exist";
|
||||
MS_LOG(ERROR) << "VOCNode: Invalid usage: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
|
||||
std::string err_msg = "VOCDataset: Invalid 'usage': " + usage_ + ", file does not exist";
|
||||
MS_LOG(ERROR) << "VOCDataset: Invalid 'usage': " << usage_ << ", file \"" << imagesets_file
|
||||
<< "\" does not exist!";
|
||||
return Status(StatusCode::kMDSyntaxError, err_msg);
|
||||
}
|
||||
} else {
|
||||
std::string err_msg = "VOCNode: Invalid task: " + task_;
|
||||
std::string err_msg = "VOCDataset: Invalid 'task': " + task_ + ", expected Segmentation or Detection.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
|
@ -146,7 +148,7 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
|
|||
int64_t num_rows = 0, sample_size;
|
||||
std::vector<std::shared_ptr<DatasetOp>> ops;
|
||||
RETURN_IF_NOT_OK(Build(&ops));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build VocOp.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "[Internal ERROR] Unable to build VocOp.");
|
||||
auto op = std::dynamic_pointer_cast<VOCOp>(ops.front());
|
||||
RETURN_IF_NOT_OK(op->CountTotalRows(&num_rows));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
|
@ -182,15 +184,14 @@ Status VOCNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
#ifndef ENABLE_ANDROID
|
||||
Status VOCNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> *ds) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("num_parallel_workers") != json_obj.end(),
|
||||
"Failed to find num_parallel_workers");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("dataset_dir") != json_obj.end(), "Failed to find dataset_dir");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("task") != json_obj.end(), "Failed to find task");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("usage") != json_obj.end(), "Failed to find usage");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("class_indexing") != json_obj.end(), "Failed to find class_indexing");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("decode") != json_obj.end(), "Failed to find decode");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("sampler") != json_obj.end(), "Failed to find sampler");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("extra_metadata") != json_obj.end(), "Failed to find extra_metadata");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "num_parallel_workers", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "dataset_dir", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "task", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "usage", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "class_indexing", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "decode", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "sampler", kTFRecordNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "extra_metadata", kTFRecordNode));
|
||||
std::string dataset_dir = json_obj["dataset_dir"];
|
||||
std::string task = json_obj["task"];
|
||||
std::string usage = json_obj["usage"];
|
||||
|
|
|
@ -52,7 +52,7 @@ Status TakeNode::ValidateParams() {
|
|||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (take_count_ <= 0 && take_count_ != -1) {
|
||||
std::string err_msg =
|
||||
"TakeNode: take_count should be either -1 or positive integer, take_count: " + std::to_string(take_count_);
|
||||
"TakeNode: 'take_count' should be either -1 or positive integer, but got: " + std::to_string(take_count_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -93,7 +93,7 @@ Status TakeNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
Status TakeNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("count") != json_obj.end(), "Failed to find count");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "count", kTakeNode));
|
||||
int32_t count = json_obj["count"];
|
||||
*result = std::make_shared<TakeNode>(ds, count);
|
||||
return Status::OK();
|
||||
|
|
|
@ -57,10 +57,7 @@ void TransferNode::Print(std::ostream &out) const {
|
|||
// Validator for TransferNode
|
||||
Status TransferNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
if (total_batch_ < 0) {
|
||||
std::string err_msg = "TransferNode: Total batches should be >= 0, value given: " + std::to_string(total_batch_);
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
RETURN_IF_NOT_OK(ValidateScalar("Transfer", "Total batches", total_batch_, {0}, false));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -89,7 +86,7 @@ Status TransferNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
} else if (device_type_ == kAscendDevice) {
|
||||
type = DeviceQueueOp::DeviceType::Ascend;
|
||||
} else {
|
||||
std::string err_msg = "Unknown device target.";
|
||||
std::string err_msg = "Unknown device target, support CPU, GPU or Ascend";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
@ -128,13 +125,12 @@ Status TransferNode::to_json(nlohmann::json *out_json) {
|
|||
|
||||
Status TransferNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode> ds,
|
||||
std::shared_ptr<DatasetNode> *result) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("queue_name") != json_obj.end(), "Failed to find queue_name");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("device_type") != json_obj.end(), "Failed to find device_type");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("device_id") != json_obj.end(), "Failed to find device_id");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("send_epoch_end") != json_obj.end(), "Failed to find send_epoch_end");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("total_batch") != json_obj.end(), "Failed to find total_batch");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("create_data_info_queue") != json_obj.end(),
|
||||
"Failed to find create_data_info_queue");
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "queue_name", kTransferNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "device_type", kTransferNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "device_id", kTransferNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "send_epoch_end", kTransferNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "total_batch", kTransferNode));
|
||||
RETURN_IF_NOT_OK(ValidateParamInJson(json_obj, "create_data_info_queue", kTransferNode));
|
||||
std::string queue_name = json_obj["queue_name"];
|
||||
std::string device_type = json_obj["device_type"];
|
||||
int32_t device_id = json_obj["device_id"];
|
||||
|
|
|
@ -23,7 +23,7 @@ Status PythonRuntimeContext::Terminate() {
|
|||
if (tree_consumer_ != nullptr) {
|
||||
return TerminateImpl();
|
||||
}
|
||||
MS_LOG(WARNING) << "Dataset TreeConsumer was not initialized.";
|
||||
MS_LOG(INFO) << "Dataset TreeConsumer was not initialized.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,9 @@ Status PythonRuntimeContext::TerminateImpl() {
|
|||
|
||||
PythonRuntimeContext::~PythonRuntimeContext() {
|
||||
Status rc = PythonRuntimeContext::Terminate();
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
|
||||
}
|
||||
{
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
tree_consumer_.reset();
|
||||
|
|
|
@ -24,7 +24,7 @@ Status NativeRuntimeContext::Terminate() {
|
|||
if (tree_consumer_ != nullptr) {
|
||||
return TerminateImpl();
|
||||
}
|
||||
MS_LOG(WARNING) << "Dataset TreeConsumer was not initialized.";
|
||||
MS_LOG(INFO) << "Dataset TreeConsumer was not initialized.";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -35,7 +35,9 @@ Status NativeRuntimeContext::TerminateImpl() {
|
|||
|
||||
NativeRuntimeContext::~NativeRuntimeContext() {
|
||||
Status rc = NativeRuntimeContext::Terminate();
|
||||
if (rc.IsError()) MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
|
||||
}
|
||||
}
|
||||
|
||||
TreeConsumer *RuntimeContext::GetConsumer() { return tree_consumer_.get(); }
|
||||
|
|
|
@ -65,8 +65,8 @@ Status Serdes::SaveJSONToFile(nlohmann::json json_string, const std::string &fil
|
|||
}
|
||||
auto realpath = FileUtils::GetRealPath(dir.value().data());
|
||||
if (!realpath.has_value()) {
|
||||
MS_LOG(ERROR) << "Get real path failed, path=" << file_name;
|
||||
RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + file_name);
|
||||
MS_LOG(ERROR) << "Invalid file, get real path failed, path=" << file_name;
|
||||
RETURN_STATUS_UNEXPECTED("Invalid file, get real path failed, path=" + file_name);
|
||||
}
|
||||
|
||||
std::optional<std::string> whole_path = "";
|
||||
|
@ -78,7 +78,8 @@ Status Serdes::SaveJSONToFile(nlohmann::json json_string, const std::string &fil
|
|||
|
||||
ChangeFileMode(whole_path.value(), S_IRUSR | S_IWUSR);
|
||||
} catch (const std::exception &err) {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to save json string into file: " + file_name);
|
||||
RETURN_STATUS_UNEXPECTED("Invalid data, failed to save json string into file: " + file_name +
|
||||
", error message: " + err.what());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -91,7 +92,8 @@ Status Serdes::Deserialize(const std::string &json_filepath, std::shared_ptr<Dat
|
|||
try {
|
||||
json_in >> json_obj;
|
||||
} catch (const std::exception &e) {
|
||||
return Status(StatusCode::kMDSyntaxError, "Invalid file, failed to parse json file: " + json_filepath);
|
||||
return Status(StatusCode::kMDSyntaxError,
|
||||
"Invalid file, failed to parse json file: " + json_filepath + ", error message: " + e.what());
|
||||
}
|
||||
RETURN_IF_NOT_OK(ConstructPipeline(json_obj, ds));
|
||||
return Status::OK();
|
||||
|
@ -337,7 +339,7 @@ Status Serdes::ParseMindIRPreprocess(const std::string &dataset_json, const std:
|
|||
try {
|
||||
dataset_js = nlohmann::json::parse(dataset_json);
|
||||
} catch (const std::exception &err) {
|
||||
MS_LOG(ERROR) << "Invalid json content, failed to parse JSON data.";
|
||||
MS_LOG(ERROR) << "Invalid json content, failed to parse JSON data, error message: " << err.what();
|
||||
RETURN_STATUS_UNEXPECTED("Invalid json content, failed to parse JSON data.");
|
||||
}
|
||||
|
||||
|
|
|
@ -50,8 +50,8 @@ Status ValidateScalar(const std::string &op_name, const std::string &scalar_name
|
|||
}
|
||||
if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
|
||||
std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
|
||||
std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) +
|
||||
", got: " + std::to_string(scalar);
|
||||
std::string err_msg = op_name + ": '" + scalar_name + "' must be" + interval_description +
|
||||
std::to_string(range[0]) + ", got: " + std::to_string(scalar);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
@ -69,6 +69,15 @@ Status ValidateScalar(const std::string &op_name, const std::string &scalar_name
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to validate enum
|
||||
template <typename T>
|
||||
Status ValidateEnum(const std::string &op_name, const std::string &enum_name, const T enumeration,
|
||||
const std::vector<T> &enum_list) {
|
||||
auto existed = std::find(enum_list.begin(), enum_list.end(), enumeration);
|
||||
std::string err_msg = op_name + ": Invalid " + enum_name + ", check input value of enum.";
|
||||
return existed != enum_list.end() ? Status::OK() : Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
|
||||
// Helper function to validate color attribute
|
||||
Status ValidateVectorColorAttribute(const std::string &op_name, const std::string &attr_name,
|
||||
const std::vector<float> &attr, const std::vector<float> &range);
|
||||
|
|
|
@ -25,11 +25,11 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// validator Parameter in json file
|
||||
inline Status ValidateParamInJson(nlohmann::json op_params, const std::string ¶m_name,
|
||||
inline Status ValidateParamInJson(const nlohmann::json &json_obj, const std::string ¶m_name,
|
||||
const std::string &operator_name) {
|
||||
if (op_params.find(param_name) == op_params.end()) {
|
||||
std::string err_msg = "Failed to find parameter '" + param_name + "' of '" + operator_name +
|
||||
"' operator in input json file or input dict, check input parameter of API 'deserialize.";
|
||||
if (json_obj.find(param_name) == json_obj.end()) {
|
||||
std::string err_msg = "Failed to find key '" + param_name + "' in " + operator_name +
|
||||
"' JSON file or input dict, check input content of deserialize().";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -407,7 +407,7 @@ def test_weighted_random_sampler_exception():
|
|||
sampler = ds.WeightedRandomSampler(weights)
|
||||
sampler.parse()
|
||||
|
||||
error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative number, got: "
|
||||
error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative numbers, got: "
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
weights = [1.0, 0.1, 0.02, 0.3, -0.4]
|
||||
sampler = ds.WeightedRandomSampler(weights)
|
||||
|
|
Loading…
Reference in New Issue