fix security check
parent 31a4c3116e
commit e262b11901
@@ -106,6 +106,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str
 std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
   std::vector<std::string> field_name = StringSplit(field_path, kPoint);
   for (uint64_t i = 0; i < field_name.size(); i++) {
+    try {
       if (i != field_name.size() - 1) {
         // Get type information from json schema
         schema = schema.at(field_name[i]);
@@ -123,6 +124,10 @@ std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json sc
         return field_type.substr(1, field_type.length() - 2);
       }
     }
+    } catch (...) {
+      MS_LOG(WARNING) << "Exception occurred while get field type.";
+      return "";
+    }
   }
   return "";
 }
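For reference, a minimal standalone sketch (not part of the commit) of the guarded nested-field lookup the two hunks above introduce, assuming the mindrecord json type is nlohmann::json, whose at() throws when a key is missing or a node is not an object; the schema layout, sample values, and helper name below are made up for illustration.

#include <iostream>
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Walk a dotted path ("a.b.c") through a JSON schema; return "" on any error
// instead of letting the exception escape, as the patched TakeFieldType does.
std::string TakeFieldTypeSketch(const std::vector<std::string> &field_name, json schema) {
  for (size_t i = 0; i < field_name.size(); i++) {
    try {
      if (i != field_name.size() - 1) {
        schema = schema.at(field_name[i]);  // throws if the key is absent or schema is not an object
      } else {
        std::string field_type = schema.at(field_name[i]).at("type").dump();
        return field_type.substr(1, field_type.length() - 2);  // strip the surrounding quotes from dump()
      }
    } catch (...) {
      return "";  // malformed schema or bad path: report "unknown type" instead of crashing
    }
  }
  return "";
}

int main() {
  json schema = json::parse(R"({"image": {"type": "bytes"}, "label": {"type": "int32"}})");  // made-up schema
  std::cout << TakeFieldTypeSketch({"label"}, schema) << "\n";    // prints: int32
  std::cout << TakeFieldTypeSketch({"missing"}, schema) << "\n";  // prints an empty line
}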
@@ -330,6 +335,9 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
   const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
   sqlite3_stmt *stmt = nullptr;
   if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
+    if (stmt) {
+      (void)sqlite3_finalize(stmt);
+    }
     MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
     return FAILED;
   }
@@ -342,29 +350,34 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
       int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
       if (field_type == "INTEGER") {
         if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
+          (void)sqlite3_finalize(stmt);
           MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
                         << ", field value: " << std::stoll(field_value);
           return FAILED;
         }
       } else if (field_type == "NUMERIC") {
         if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
+          (void)sqlite3_finalize(stmt);
           MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
                         << ", field value: " << std::stold(field_value);
           return FAILED;
         }
       } else if (field_type == "NULL") {
         if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
+          (void)sqlite3_finalize(stmt);
           MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL";
           return FAILED;
         }
       } else {
         if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
+          (void)sqlite3_finalize(stmt);
           MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value;
           return FAILED;
         }
       }
     }
     if (sqlite3_step(stmt) != SQLITE_DONE) {
+      (void)sqlite3_finalize(stmt);
       MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
       return FAILED;
     }
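For reference, a minimal standalone sketch (not MindSpore code) of the resource handling this hunk enforces in BindParameterExecuteSQL: once sqlite3_prepare_v2 has been called, every early-return error path finalizes the statement so the handle is not leaked. The table, column names, and helper name below are made up for illustration.

#include <sqlite3.h>
#include <cstdint>
#include <cstdio>
#include <string>

// Insert one row; on every failure path the prepared statement is finalized before returning.
bool InsertSample(sqlite3 *db, int64_t id, const std::string &name) {
  sqlite3_stmt *stmt = nullptr;
  const char *sql = "INSERT INTO samples (id, name) VALUES (?1, ?2);";  // hypothetical table
  if (sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr) != SQLITE_OK) {
    if (stmt != nullptr) {
      (void)sqlite3_finalize(stmt);  // release even a partially prepared statement
    }
    std::fprintf(stderr, "prepare failed: %s\n", sqlite3_errmsg(db));
    return false;
  }
  if (sqlite3_bind_int64(stmt, 1, id) != SQLITE_OK ||
      sqlite3_bind_text(stmt, 2, name.c_str(), -1, SQLITE_TRANSIENT) != SQLITE_OK) {
    (void)sqlite3_finalize(stmt);  // do not leak the statement on bind failure
    std::fprintf(stderr, "bind failed: %s\n", sqlite3_errmsg(db));
    return false;
  }
  if (sqlite3_step(stmt) != SQLITE_DONE) {
    (void)sqlite3_finalize(stmt);  // do not leak the statement on step failure
    std::fprintf(stderr, "step failed: %s\n", sqlite3_errmsg(db));
    return false;
  }
  return sqlite3_finalize(stmt) == SQLITE_OK;  // normal path also finalizes
}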
@@ -422,7 +435,12 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
   std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data;

   // current raw data page
-  std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first;
+  auto ret1 = shard_header_.GetPage(shard_no, raw_page_id);
+  if (ret1.second != SUCCESS) {
+    MS_LOG(ERROR) << "Get page failed";
+    return {FAILED, {}};
+  }
+  std::shared_ptr<Page> cur_raw_page = ret1.first;

   // related blob page
   vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
@@ -430,7 +448,17 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
   // pair: row_group id, offset in raw data page
   for (pair<int, int> blob_ids : row_group_list) {
     // get blob data page according to row_group id
-    std::shared_ptr<Page> cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first;
+    auto iter = blob_id_to_page_id.find(blob_ids.first);
+    if (iter == blob_id_to_page_id.end()) {
+      MS_LOG(ERROR) << "Convert blob id failed";
+      return {FAILED, {}};
+    }
+    auto ret2 = shard_header_.GetPage(shard_no, iter->second);
+    if (ret2.second != SUCCESS) {
+      MS_LOG(ERROR) << "Get page failed";
+      return {FAILED, {}};
+    }
+    std::shared_ptr<Page> cur_blob_page = ret2.first;

     // offset in current raw data page
     auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
@@ -618,7 +646,12 @@ void ShardIndexGenerator::DatabaseWriter() {
     std::map<int, int> blob_id_to_page_id;
     std::vector<int> raw_page_ids;
     for (uint64_t i = 0; i < total_pages; ++i) {
-      std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
+      auto ret = shard_header_.GetPage(shard_no, i);
+      if (ret.second != SUCCESS) {
+        write_success_ = false;
+        return;
+      }
+      std::shared_ptr<Page> cur_page = ret.first;
       if (cur_page->GetPageType() == "RAW_DATA") {
         raw_page_ids.push_back(i);
       } else if (cur_page->GetPageType() == "BLOB_DATA") {
@@ -340,6 +340,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
                                           int shard_id, const std::vector<std::string> &columns,
                                           std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
   for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
+    try {
       uint64_t group_id = std::stoull(labels[i][0]);
       uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
       uint64_t offset_end = std::stoull(labels[i][2]);
@@ -397,10 +398,20 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
       }
       (*col_val_ptr)[shard_id].emplace_back(construct_json);
     }
+    } catch (std::out_of_range &e) {
+      MS_LOG(ERROR) << "Out of range: " << e.what();
+      return FAILED;
+    } catch (std::invalid_argument &e) {
+      MS_LOG(ERROR) << "Invalid argument: " << e.what();
+      return FAILED;
+    } catch (...) {
+      MS_LOG(ERROR) << "Exception was caught while convert label to json.";
+      return FAILED;
+    }
   }

   return SUCCESS;
 }

 MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
                                           std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
@@ -961,9 +972,13 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
       num_samples = category_op->GetNumSamples(num_samples, num_classes);
       if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
         auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
-        if (tmp != 0) {
+        if (tmp != 0 && num_samples != -1) {
           num_samples = std::min(num_samples, tmp);
         }
+        if (-1 == num_samples) {
+          MS_LOG(ERROR) << "Number of samples exceeds the upper limit: " << std::numeric_limits<int64_t>::max();
+          return FAILED;
+        }
       }
     } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
       if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
@@ -39,7 +39,14 @@ MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
 int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (dataset_size == 0) return dataset_size;
   if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
-    return std::min(num_categories_, num_classes) * num_elements_;
+    num_classes = std::min(num_categories_, num_classes);
+    if (num_classes == 0) {
+      return 0;
+    }
+    if (num_elements_ > std::numeric_limits<int64_t>::max() / num_classes) {
+      return -1;
+    }
+    return num_classes * num_elements_;
   }
   return 0;
 }
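For reference, a minimal standalone sketch (not part of the commit) of the overflow guard the last hunk applies in ShardCategory::GetNumSamples: before multiplying, one factor is checked against std::numeric_limits<int64_t>::max() divided by the other, and -1 is returned as the "too large" sentinel that CountTotalRows now rejects. The helper name and the values in main are made up for illustration.

#include <cstdint>
#include <iostream>
#include <limits>

// Multiply two non-negative int64 counts, returning -1 if the product would overflow.
int64_t CheckedCountProduct(int64_t num_classes, int64_t num_elements) {
  if (num_classes == 0 || num_elements == 0) {
    return 0;
  }
  if (num_elements > std::numeric_limits<int64_t>::max() / num_classes) {
    return -1;  // caller treats -1 as "sample count exceeds the upper limit"
  }
  return num_classes * num_elements;
}

int main() {
  std::cout << CheckedCountProduct(10, 20) << "\n";                               // 200
  std::cout << CheckedCountProduct(INT64_C(1) << 40, INT64_C(1) << 40) << "\n";   // -1 (would overflow)
}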