!20589 [MD] fix security check in master

Merge pull request !20589 from liyong126/fix_security_check
i-robot 2021-07-21 06:24:35 +00:00 committed by Gitee
commit 7fe8eddc42
3 changed files with 126 additions and 71 deletions


@@ -106,22 +106,27 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
std::vector<std::string> field_name = StringSplit(field_path, kPoint);
for (uint64_t i = 0; i < field_name.size(); i++) {
if (i != field_name.size() - 1) {
// Get type information from json schema
schema = schema.at(field_name[i]);
schema = schema.at("properties");
} else {
// standard root layer exist "properties" if type is "object"
if (schema.find("properties") != schema.end()) {
try {
if (i != field_name.size() - 1) {
// Get type information from json schema
schema = schema.at(field_name[i]);
schema = schema.at("properties");
}
schema = schema.at(field_name[i]);
std::string field_type = schema.at("type").dump();
if (field_type.length() <= 2) {
return "";
} else {
return field_type.substr(1, field_type.length() - 2);
// standard root layer exist "properties" if type is "object"
if (schema.find("properties") != schema.end()) {
schema = schema.at("properties");
}
schema = schema.at(field_name[i]);
std::string field_type = schema.at("type").dump();
if (field_type.length() <= 2) {
return "";
} else {
return field_type.substr(1, field_type.length() - 2);
}
}
} catch (...) {
MS_LOG(WARNING) << "Exception occurred while get field type.";
return "";
}
}
return "";
@@ -330,6 +335,9 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
sqlite3_stmt *stmt = nullptr;
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
if (stmt) {
(void)sqlite3_finalize(stmt);
}
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
return FAILED;
}
@@ -342,29 +350,34 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
if (field_type == "INTEGER") {
if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stoll(field_value);
return FAILED;
}
} else if (field_type == "NUMERIC") {
if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stold(field_value);
return FAILED;
}
} else if (field_type == "NULL") {
if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL";
return FAILED;
}
} else {
if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value;
return FAILED;
}
}
}
if (sqlite3_step(stmt) != SQLITE_DONE) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
return FAILED;
}
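Each early return between sqlite3_prepare_v2 and the final sqlite3_finalize now finalizes the statement first, so a failed bind or step no longer leaks the prepared statement. A sketch of the same guarantee expressed as an RAII guard rather than explicit calls (illustrative only, not how the patch is written; sqlite3_finalize on a null pointer is a harmless no-op):

```cpp
#include <cstdint>
#include <memory>
#include <sqlite3.h>

bool BindAndStep(sqlite3 *db, const char *sql, int64_t value) {
  sqlite3_stmt *raw = nullptr;
  if (sqlite3_prepare_v2(db, sql, -1, &raw, nullptr) != SQLITE_OK) {
    (void)sqlite3_finalize(raw);  // defensive, as in the patch; no-op when raw is null
    return false;
  }
  // The guard finalizes the statement on every exit path below.
  std::unique_ptr<sqlite3_stmt, decltype(&sqlite3_finalize)> stmt(raw, sqlite3_finalize);
  if (sqlite3_bind_int64(stmt.get(), 1, value) != SQLITE_OK) {
    return false;
  }
  return sqlite3_step(stmt.get()) == SQLITE_DONE;
}
```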
@@ -422,7 +435,12 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data;
// current raw data page
std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first;
auto ret1 = shard_header_.GetPage(shard_no, raw_page_id);
if (ret1.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_raw_page = ret1.first;
// related blob page
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
@@ -430,7 +448,17 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
// pair: row_group id, offset in raw data page
for (pair<int, int> blob_ids : row_group_list) {
// get blob data page according to row_group id
std::shared_ptr<Page> cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first;
auto iter = blob_id_to_page_id.find(blob_ids.first);
if (iter == blob_id_to_page_id.end()) {
MS_LOG(ERROR) << "Convert blob id failed";
return {FAILED, {}};
}
auto ret2 = shard_header_.GetPage(shard_no, iter->second);
if (ret2.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_blob_page = ret2.first;
// offset in current raw data page
auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
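The blob-page lookup now goes through find() instead of at(), and the GetPage status is checked before the page pointer is used, so a missing mapping is reported as FAILED rather than escaping as std::out_of_range. A standalone sketch of the lookup half (container contents and ids are invented for illustration):

```cpp
#include <iostream>
#include <map>

int main() {
  std::map<int, int> blob_id_to_page_id = {{0, 10}, {1, 11}};
  int blob_id = 7;  // hypothetical row-group id with no page mapping
  auto iter = blob_id_to_page_id.find(blob_id);
  if (iter == blob_id_to_page_id.end()) {
    std::cout << "Convert blob id failed" << std::endl;  // handled, no exception escapes
    return 1;
  }
  std::cout << "blob page id: " << iter->second << std::endl;
  return 0;
}
```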
@@ -619,7 +647,12 @@ void ShardIndexGenerator::DatabaseWriter() {
std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) {
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
auto ret = shard_header_.GetPage(shard_no, i);
if (ret.second != SUCCESS) {
write_success_ = false;
return;
}
std::shared_ptr<Page> cur_page = ret.first;
if (cur_page->GetPageType() == "RAW_DATA") {
raw_page_ids.push_back(i);
} else if (cur_page->GetPageType() == "BLOB_DATA") {
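DatabaseWriter and GenerateRowData previously took shard_header_.GetPage(...).first unconditionally, so a failed lookup handed back a null shared_ptr that was dereferenced on the next line. The patch checks the status half of the returned pair first. A self-contained sketch of that pattern (the Page struct, MSRStatus enum, and GetPage here are stand-ins, not the MindRecord types):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>

enum MSRStatus { SUCCESS, FAILED };
struct Page {
  std::string GetPageType() const { return "RAW_DATA"; }
};

// Hypothetical accessor with the same shape as ShardHeader::GetPage.
std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(int shard_no, int page_id) {
  if (page_id < 0) {
    return {nullptr, FAILED};
  }
  return {std::make_shared<Page>(), SUCCESS};
}

int main() {
  auto ret = GetPage(0, -1);
  if (ret.second != SUCCESS) {  // the added status check
    std::cout << "Get page failed" << std::endl;
    return 1;                   // bail out instead of dereferencing a null pointer
  }
  std::cout << ret.first->GetPageType() << std::endl;
  return 0;
}
```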


@@ -340,67 +340,78 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
uint64_t group_id = std::stoull(labels[i][0]);
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
uint64_t offset_end = std::stoull(labels[i][2]);
(*offset_ptr)[shard_id].emplace_back(
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
if (!all_in_index_) {
int raw_page_id = std::stoi(labels[i][3]);
uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
uint64_t label_end = std::stoull(labels[i][5]);
auto len = label_end - label_start;
auto label_raw = std::vector<uint8_t>(len);
auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
fs->close();
return FAILED;
}
try {
uint64_t group_id = std::stoull(labels[i][0]);
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
uint64_t offset_end = std::stoull(labels[i][2]);
(*offset_ptr)[shard_id].emplace_back(
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
if (!all_in_index_) {
int raw_page_id = std::stoi(labels[i][3]);
uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
uint64_t label_end = std::stoull(labels[i][5]);
auto len = label_end - label_start;
auto label_raw = std::vector<uint8_t>(len);
auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
MS_LOG(ERROR) << "File seekg failed";
fs->close();
return FAILED;
}
auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
fs->close();
return FAILED;
}
json label_json = json::from_msgpack(label_raw);
json tmp;
if (!columns.empty()) {
for (auto &col : columns) {
if (label_json.find(col) != label_json.end()) {
tmp[col] = label_json[col];
auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
fs->close();
return FAILED;
}
json label_json = json::from_msgpack(label_raw);
json tmp;
if (!columns.empty()) {
for (auto &col : columns) {
if (label_json.find(col) != label_json.end()) {
tmp[col] = label_json[col];
}
}
} else {
tmp = label_json;
}
(*col_val_ptr)[shard_id].emplace_back(tmp);
} else {
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
// construct json "f1": value
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
// convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "int64") {
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float32") {
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float64") {
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
} else {
construct_json[columns[j]] = std::string(labels[i][j + 3]);
}
}
} else {
tmp = label_json;
(*col_val_ptr)[shard_id].emplace_back(construct_json);
}
(*col_val_ptr)[shard_id].emplace_back(tmp);
} else {
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
// construct json "f1": value
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
// convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "int64") {
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float32") {
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
} else if (schema[columns[j]]["type"] == "float64") {
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
} else {
construct_json[columns[j]] = std::string(labels[i][j + 3]);
}
}
(*col_val_ptr)[shard_id].emplace_back(construct_json);
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Out of range: " << e.what();
return FAILED;
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Invalid argument: " << e.what();
return FAILED;
} catch (...) {
MS_LOG(ERROR) << "Exception was caught while convert label to json.";
return FAILED;
}
}
return SUCCESS;
}
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
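ConvertLabelToJson now runs each row's conversion inside a try/catch because std::stoull and std::stoi throw std::invalid_argument on non-numeric text and std::out_of_range on values that do not fit, so a corrupted index entry makes the call return FAILED instead of aborting the process. A standalone sketch of those two failure modes (the sample strings are invented):

```cpp
#include <iostream>
#include <stdexcept>
#include <string>

int main() {
  const std::string fields[] = {"42", "not_a_number", "99999999999999999999999999"};
  for (const auto &field : fields) {
    try {
      std::cout << std::stoull(field) << std::endl;
    } catch (std::invalid_argument &e) {
      std::cout << "Invalid argument: " << e.what() << std::endl;  // non-numeric text
    } catch (std::out_of_range &e) {
      std::cout << "Out of range: " << e.what() << std::endl;      // does not fit in 64 bits
    }
  }
  return 0;
}
```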
@@ -965,9 +976,13 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
num_samples = category_op->GetNumSamples(num_samples, num_classes);
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
if (tmp != 0) {
if (tmp != 0 && num_samples != -1) {
num_samples = std::min(num_samples, tmp);
}
if (-1 == num_samples) {
MS_LOG(ERROR) << "Number of samples exceeds the upper limit: " << std::numeric_limits<int64_t>::max();
return FAILED;
}
}
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {


@@ -39,7 +39,14 @@ MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size;
if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
return std::min(num_categories_, num_classes) * num_elements_;
num_classes = std::min(num_categories_, num_classes);
if (num_classes == 0) {
return 0;
}
if (num_elements_ > std::numeric_limits<int64_t>::max() / num_classes) {
return -1;
}
return num_classes * num_elements_;
}
return 0;
}
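The GetNumSamples change guards the num_classes * num_elements_ product: dividing the int64_t maximum by one factor before multiplying detects the overflow that the old direct multiplication would silently wrap, and -1 is returned as a sentinel that CountTotalRows now rejects with the "exceeds the upper limit" error. The same guard in isolation (a free function with invented values, not the member function above):

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

int64_t CheckedProduct(int64_t num_classes, int64_t num_elements) {
  if (num_classes <= 0 || num_elements <= 0) {
    return 0;
  }
  if (num_elements > std::numeric_limits<int64_t>::max() / num_classes) {
    return -1;  // sentinel meaning "product would overflow int64_t"
  }
  return num_classes * num_elements;
}

int main() {
  std::cout << CheckedProduct(1000, 1000) << std::endl;  // 1000000
  // 4e9 * 4e9 is about 1.6e19, above INT64_MAX (about 9.2e18), so the guard fires.
  std::cout << CheckedProduct(INT64_C(4000000000), INT64_C(4000000000)) << std::endl;  // -1
  return 0;
}
```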