!20589 [MD] fix security check in master
Merge pull request !20589 from liyong126/fix_security_check
This commit is contained in:
commit
7fe8eddc42
|
@ -106,22 +106,27 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str
|
|||
std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) {
|
||||
std::vector<std::string> field_name = StringSplit(field_path, kPoint);
|
||||
for (uint64_t i = 0; i < field_name.size(); i++) {
|
||||
if (i != field_name.size() - 1) {
|
||||
// Get type information from json schema
|
||||
schema = schema.at(field_name[i]);
|
||||
schema = schema.at("properties");
|
||||
} else {
|
||||
// standard root layer exist "properties" if type is "object"
|
||||
if (schema.find("properties") != schema.end()) {
|
||||
try {
|
||||
if (i != field_name.size() - 1) {
|
||||
// Get type information from json schema
|
||||
schema = schema.at(field_name[i]);
|
||||
schema = schema.at("properties");
|
||||
}
|
||||
schema = schema.at(field_name[i]);
|
||||
std::string field_type = schema.at("type").dump();
|
||||
if (field_type.length() <= 2) {
|
||||
return "";
|
||||
} else {
|
||||
return field_type.substr(1, field_type.length() - 2);
|
||||
// standard root layer exist "properties" if type is "object"
|
||||
if (schema.find("properties") != schema.end()) {
|
||||
schema = schema.at("properties");
|
||||
}
|
||||
schema = schema.at(field_name[i]);
|
||||
std::string field_type = schema.at("type").dump();
|
||||
if (field_type.length() <= 2) {
|
||||
return "";
|
||||
} else {
|
||||
return field_type.substr(1, field_type.length() - 2);
|
||||
}
|
||||
}
|
||||
} catch (...) {
|
||||
MS_LOG(WARNING) << "Exception occurred while get field type.";
|
||||
return "";
|
||||
}
|
||||
}
|
||||
return "";
|
||||
|
@ -330,6 +335,9 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
|
|||
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
|
||||
sqlite3_stmt *stmt = nullptr;
|
||||
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
|
||||
if (stmt) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
}
|
||||
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -342,29 +350,34 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
|
|||
int index = sqlite3_bind_parameter_index(stmt, common::SafeCStr(place_holder));
|
||||
if (field_type == "INTEGER") {
|
||||
if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
|
||||
<< ", field value: " << std::stoll(field_value);
|
||||
return FAILED;
|
||||
}
|
||||
} else if (field_type == "NUMERIC") {
|
||||
if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index
|
||||
<< ", field value: " << std::stold(field_value);
|
||||
return FAILED;
|
||||
}
|
||||
} else if (field_type == "NULL") {
|
||||
if (sqlite3_bind_null(stmt, index) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL";
|
||||
return FAILED;
|
||||
}
|
||||
} else {
|
||||
if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value;
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (sqlite3_step(stmt) != SQLITE_DONE) {
|
||||
(void)sqlite3_finalize(stmt);
|
||||
MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -422,7 +435,12 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
|
|||
std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data;
|
||||
|
||||
// current raw data page
|
||||
std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first;
|
||||
auto ret1 = shard_header_.GetPage(shard_no, raw_page_id);
|
||||
if (ret1.second != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Get page failed";
|
||||
return {FAILED, {}};
|
||||
}
|
||||
std::shared_ptr<Page> cur_raw_page = ret1.first;
|
||||
|
||||
// related blob page
|
||||
vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds();
|
||||
|
@ -430,7 +448,17 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int,
|
|||
// pair: row_group id, offset in raw data page
|
||||
for (pair<int, int> blob_ids : row_group_list) {
|
||||
// get blob data page according to row_group id
|
||||
std::shared_ptr<Page> cur_blob_page = shard_header_.GetPage(shard_no, blob_id_to_page_id.at(blob_ids.first)).first;
|
||||
auto iter = blob_id_to_page_id.find(blob_ids.first);
|
||||
if (iter == blob_id_to_page_id.end()) {
|
||||
MS_LOG(ERROR) << "Convert blob id failed";
|
||||
return {FAILED, {}};
|
||||
}
|
||||
auto ret2 = shard_header_.GetPage(shard_no, iter->second);
|
||||
if (ret2.second != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Get page failed";
|
||||
return {FAILED, {}};
|
||||
}
|
||||
std::shared_ptr<Page> cur_blob_page = ret2.first;
|
||||
|
||||
// offset in current raw data page
|
||||
auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second);
|
||||
|
@ -619,7 +647,12 @@ void ShardIndexGenerator::DatabaseWriter() {
|
|||
std::map<int, int> blob_id_to_page_id;
|
||||
std::vector<int> raw_page_ids;
|
||||
for (uint64_t i = 0; i < total_pages; ++i) {
|
||||
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
|
||||
auto ret = shard_header_.GetPage(shard_no, i);
|
||||
if (ret.second != SUCCESS) {
|
||||
write_success_ = false;
|
||||
return;
|
||||
}
|
||||
std::shared_ptr<Page> cur_page = ret.first;
|
||||
if (cur_page->GetPageType() == "RAW_DATA") {
|
||||
raw_page_ids.push_back(i);
|
||||
} else if (cur_page->GetPageType() == "BLOB_DATA") {
|
||||
|
|
|
@ -340,67 +340,78 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
|||
int shard_id, const std::vector<std::string> &columns,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
|
||||
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
|
||||
uint64_t group_id = std::stoull(labels[i][0]);
|
||||
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
|
||||
uint64_t offset_end = std::stoull(labels[i][2]);
|
||||
(*offset_ptr)[shard_id].emplace_back(
|
||||
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
|
||||
if (!all_in_index_) {
|
||||
int raw_page_id = std::stoi(labels[i][3]);
|
||||
uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
|
||||
uint64_t label_end = std::stoull(labels[i][5]);
|
||||
auto len = label_end - label_start;
|
||||
auto label_raw = std::vector<uint8_t>(len);
|
||||
auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
MS_LOG(ERROR) << "File seekg failed";
|
||||
fs->close();
|
||||
return FAILED;
|
||||
}
|
||||
try {
|
||||
uint64_t group_id = std::stoull(labels[i][0]);
|
||||
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
|
||||
uint64_t offset_end = std::stoull(labels[i][2]);
|
||||
(*offset_ptr)[shard_id].emplace_back(
|
||||
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
|
||||
if (!all_in_index_) {
|
||||
int raw_page_id = std::stoi(labels[i][3]);
|
||||
uint64_t label_start = std::stoull(labels[i][4]) + kInt64Len;
|
||||
uint64_t label_end = std::stoull(labels[i][5]);
|
||||
auto len = label_end - label_start;
|
||||
auto label_raw = std::vector<uint8_t>(len);
|
||||
auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg);
|
||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||
MS_LOG(ERROR) << "File seekg failed";
|
||||
fs->close();
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
MS_LOG(ERROR) << "File read failed";
|
||||
fs->close();
|
||||
return FAILED;
|
||||
}
|
||||
json label_json = json::from_msgpack(label_raw);
|
||||
json tmp;
|
||||
if (!columns.empty()) {
|
||||
for (auto &col : columns) {
|
||||
if (label_json.find(col) != label_json.end()) {
|
||||
tmp[col] = label_json[col];
|
||||
auto &io_read = fs->read(reinterpret_cast<char *>(&label_raw[0]), len);
|
||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||
MS_LOG(ERROR) << "File read failed";
|
||||
fs->close();
|
||||
return FAILED;
|
||||
}
|
||||
json label_json = json::from_msgpack(label_raw);
|
||||
json tmp;
|
||||
if (!columns.empty()) {
|
||||
for (auto &col : columns) {
|
||||
if (label_json.find(col) != label_json.end()) {
|
||||
tmp[col] = label_json[col];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tmp = label_json;
|
||||
}
|
||||
(*col_val_ptr)[shard_id].emplace_back(tmp);
|
||||
} else {
|
||||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct json "f1": value
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
|
||||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "int64") {
|
||||
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float32") {
|
||||
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float64") {
|
||||
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
|
||||
} else {
|
||||
construct_json[columns[j]] = std::string(labels[i][j + 3]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tmp = label_json;
|
||||
(*col_val_ptr)[shard_id].emplace_back(construct_json);
|
||||
}
|
||||
(*col_val_ptr)[shard_id].emplace_back(tmp);
|
||||
} else {
|
||||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct json "f1": value
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"];
|
||||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "int64") {
|
||||
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float32") {
|
||||
construct_json[columns[j]] = StringToNum<float>(labels[i][j + 3]);
|
||||
} else if (schema[columns[j]]["type"] == "float64") {
|
||||
construct_json[columns[j]] = StringToNum<double>(labels[i][j + 3]);
|
||||
} else {
|
||||
construct_json[columns[j]] = std::string(labels[i][j + 3]);
|
||||
}
|
||||
}
|
||||
(*col_val_ptr)[shard_id].emplace_back(construct_json);
|
||||
} catch (std::out_of_range &e) {
|
||||
MS_LOG(ERROR) << "Out of range: " << e.what();
|
||||
return FAILED;
|
||||
} catch (std::invalid_argument &e) {
|
||||
MS_LOG(ERROR) << "Invalid argument: " << e.what();
|
||||
return FAILED;
|
||||
} catch (...) {
|
||||
MS_LOG(ERROR) << "Exception was caught while convert label to json.";
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace mindrecord
|
||||
|
||||
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
||||
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||
|
@ -965,9 +976,13 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
|
|||
num_samples = category_op->GetNumSamples(num_samples, num_classes);
|
||||
if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
|
||||
auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
|
||||
if (tmp != 0) {
|
||||
if (tmp != 0 && num_samples != -1) {
|
||||
num_samples = std::min(num_samples, tmp);
|
||||
}
|
||||
if (-1 == num_samples) {
|
||||
MS_LOG(ERROR) << "Number of samples exceeds the upper limit: " << std::numeric_limits<int64_t>::max();
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<ShardSample>(op)) {
|
||||
if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
|
||||
|
|
|
@ -39,7 +39,14 @@ MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; }
|
|||
int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
|
||||
if (dataset_size == 0) return dataset_size;
|
||||
if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
|
||||
return std::min(num_categories_, num_classes) * num_elements_;
|
||||
num_classes = std::min(num_categories_, num_classes);
|
||||
if (num_classes == 0) {
|
||||
return 0;
|
||||
}
|
||||
if (num_elements_ > std::numeric_limits<int64_t>::max() / num_classes) {
|
||||
return -1;
|
||||
}
|
||||
return num_classes * num_elements_;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue