fix security check

This commit is contained in:
liyong 2021-07-20 17:19:02 +08:00
parent 31a4c3116e
commit e262b11901
3 changed files with 126 additions and 71 deletions

View File

@ -106,6 +106,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
std::vector<std::string> field_name = StringSplit(field_path, kPoint);
for (uint64_t i = 0; i < field_name.size(); i++) {
try {
if (i != field_name.size() - 1) {
// Get type information from json schema
schema = schema.at(field_name[i]);
@ -123,6 +124,10 @@ std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json sc
return field_type.substr(1, field_type.length() - 2);
}
}
} catch (...) {
MS_LOG(WARNING) << "Exception occurred while get field type.";
return "";
}
}
return "";
}
@ -330,6 +335,9 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
sqlite3_stmt *stmt = nullptr;
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
if (stmt) {
(void)sqlite3_finalize(stmt);
}
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
return FAILED;
}
@ -342,29 +350,34 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
if (field_type == "INTEGER") {
if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stoll(field_value);
return FAILED;
}
} else if (field_type == "NUMERIC") {
if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
<< ", field value: " << std::stold(field_value);
return FAILED;
}
} else if (field_type == "NULL") {
if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL";
return FAILED;
}
} else {
if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value;
return FAILED;
}
}
}
if (sqlite3_step(stmt) != SQLITE_DONE) {
(void)sqlite3_finalize(stmt);
MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
return FAILED;
}
@ -422,7 +435,12 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data;
// current raw data page
std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first;
auto ret1 = shard_header_.GetPage(shard_no, raw_page_id);
if (ret1.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_raw_page = ret1.first;
// related blob page
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
@ -430,7 +448,17 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
// pair: row_group id, offset in raw data page
for (pair<int, int> blob_ids : row_group_list) {
// get blob data page according to row_group id
std::shared_ptr<Page> cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first;
auto iter = blob_id_to_page_id.find(blob_ids.first);
if (iter == blob_id_to_page_id.end()) {
MS_LOG(ERROR) << "Convert blob id failed";
return {FAILED, {}};
}
auto ret2 = shard_header_.GetPage(shard_no, iter->second);
if (ret2.second != SUCCESS) {
MS_LOG(ERROR) << "Get page failed";
return {FAILED, {}};
}
std::shared_ptr<Page> cur_blob_page = ret2.first;
// offset in current raw data page
auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
@ -618,7 +646,12 @@ void ShardIndexGenerator::DatabaseWriter() {
std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) {
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
auto ret = shard_header_.GetPage(shard_no, i);
if (ret.second != SUCCESS) {
write_success_ = false;
return;
}
std::shared_ptr<Page> cur_page = ret.first;
if (cur_page->GetPageType() == "RAW_DATA") {
raw_page_ids.push_back(i);
} else if (cur_page->GetPageType() == "BLOB_DATA") {

View File

@ -340,6 +340,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
try {
uint64_t group_id = std::stoull(labels[i][0]);
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
uint64_t offset_end = std::stoull(labels[i][2]);
@ -397,10 +398,20 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
}
(*col_val_ptr)[shard_id].emplace_back(construct_json);
}
} catch (std::out_of_range &e) {
MS_LOG(ERROR) << "Out of range: " << e.what();
return FAILED;
} catch (std::invalid_argument &e) {
MS_LOG(ERROR) << "Invalid argument: " << e.what();
return FAILED;
} catch (...) {
MS_LOG(ERROR) << "Exception was caught while convert label to json.";
return FAILED;
}
}
return SUCCESS;
}
} // namespace mindrecord
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
@ -961,9 +972,13 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
num_samples = category_op->GetNumSamples(num_samples, num_classes);
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
if (tmp != 0) {
if (tmp != 0 && num_samples != -1) {
num_samples = std::min(num_samples, tmp);
}
if (-1 == num_samples) {
MS_LOG(ERROR) << "Number of samples exceeds the upper limit: " << std::numeric_limits<int64_t>::max();
return FAILED;
}
}
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {

View File

@ -39,7 +39,14 @@ MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (dataset_size == 0) return dataset_size;
if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
return std::min(num_categories_, num_classes) * num_elements_;
num_classes = std::min(num_categories_, num_classes);
if (num_classes == 0) {
return 0;
}
if (num_elements_ > std::numeric_limits<int64_t>::max() / num_classes) {
return -1;
}
return num_classes * num_elements_;
}
return 0;
}