forked from mindspore-Ecosystem/mindspore
!14828 [MD] refactor pk sampler in minddataset
From: @liyong126 Reviewed-by: @liucunwei,@heleiwang Signed-off-by: @liucunwei
This commit is contained in:
commit
f1a28e17ce
|
@ -55,6 +55,8 @@
|
|||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#define API_PUBLIC __attribute__((visibility("default")))
|
||||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
using ROW_GROUPS =
|
||||
|
@ -65,7 +67,7 @@ using TASK_RETURN_CONTENT =
|
|||
std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>;
|
||||
const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode
|
||||
|
||||
class __attribute__((visibility("default"))) ShardReader {
|
||||
class API_PUBLIC ShardReader {
|
||||
public:
|
||||
ShardReader();
|
||||
|
||||
|
@ -203,7 +205,7 @@ class __attribute__((visibility("default"))) ShardReader {
|
|||
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
|
||||
|
||||
/// \brief get all classes
|
||||
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
|
||||
MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);
|
||||
|
||||
/// \brief get the size of blob data
|
||||
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
|
||||
|
@ -215,11 +217,12 @@ class __attribute__((visibility("default"))) ShardReader {
|
|||
private:
|
||||
/// \brief wrap up labels to json format
|
||||
MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
|
||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
|
||||
const std::vector<std::string> &columns, std::vector<std::vector<json>> &column_values);
|
||||
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||
int shard_id, const std::vector<std::string> &columns,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
|
||||
|
||||
/// \brief read all rows for specified columns
|
||||
ROW_GROUPS ReadAllRowGroup(std::vector<std::string> &columns);
|
||||
ROW_GROUPS ReadAllRowGroup(const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief read row meta by shard_id and sample_id
|
||||
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
|
||||
|
@ -227,8 +230,8 @@ class __attribute__((visibility("default"))) ShardReader {
|
|||
|
||||
/// \brief read all rows in one shard
|
||||
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
|
||||
std::vector<std::vector<json>> &column_values);
|
||||
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
|
||||
|
||||
/// \brief initialize reader
|
||||
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
|
||||
|
@ -243,8 +246,12 @@ class __attribute__((visibility("default"))) ShardReader {
|
|||
std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
|
||||
const std::pair<std::string, std::string> &criteria = {"", ""});
|
||||
|
||||
/// \brief get page id by category
|
||||
std::pair<MSRStatus, std::vector<uint64_t>> GetPagesByCategory(int shard_id,
|
||||
const std::pair<std::string, std::string> &criteria);
|
||||
/// \brief execute sqlite query with prepare statement
|
||||
MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector<std::vector<std::string>> &labels);
|
||||
MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
|
||||
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
|
||||
|
||||
/// \brief get column values
|
||||
std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,
|
||||
|
@ -257,8 +264,7 @@ class __attribute__((visibility("default"))) ShardReader {
|
|||
""});
|
||||
|
||||
/// \brief create category-applied task list
|
||||
MSRStatus CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||
const std::shared_ptr<ShardOperator> &op);
|
||||
MSRStatus CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);
|
||||
|
||||
/// \brief create task list in row-reader mode
|
||||
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||
|
@ -286,13 +292,15 @@ class __attribute__((visibility("default"))) ShardReader {
|
|||
int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets);
|
||||
|
||||
/// \brief get classes in one shard
|
||||
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
|
||||
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
|
||||
std::shared_ptr<std::set<std::string>> category_ptr);
|
||||
|
||||
/// \brief get number of classes
|
||||
int64_t GetNumClasses(const std::string &category_field);
|
||||
|
||||
/// \brief get meta of header
|
||||
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
|
||||
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path,
|
||||
std::shared_ptr<json> meta_data_ptr);
|
||||
|
||||
/// \brief extract uncompressed data based on column list
|
||||
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
|
||||
|
|
|
@ -51,7 +51,8 @@ ShardReader::ShardReader()
|
|||
lazy_load_(false),
|
||||
shard_sample_count_() {}
|
||||
|
||||
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
|
||||
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path,
|
||||
std::shared_ptr<json> meta_data_ptr) {
|
||||
if (!IsLegalFile(file_path)) {
|
||||
return {FAILED, {}};
|
||||
}
|
||||
|
@ -60,16 +61,16 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s
|
|||
return {FAILED, {}};
|
||||
}
|
||||
auto header = ret.second;
|
||||
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
||||
{"version", header["version"]}, {"index_fields", header["index_fields"]},
|
||||
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
|
||||
*meta_data_ptr = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
||||
{"version", header["version"]}, {"index_fields", header["index_fields"]},
|
||||
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
|
||||
return {SUCCESS, header["shard_addresses"]};
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
|
||||
std::string file_path = file_paths[0];
|
||||
json first_meta_data = json();
|
||||
auto ret = GetMeta(file_path, first_meta_data);
|
||||
auto first_meta_data_ptr = std::make_shared<json>();
|
||||
auto ret = GetMeta(file_path, first_meta_data_ptr);
|
||||
if (ret.first != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -91,12 +92,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
|
|||
return FAILED;
|
||||
}
|
||||
for (const auto &file : file_paths_) {
|
||||
json meta_data = json();
|
||||
auto ret1 = GetMeta(file, meta_data);
|
||||
auto meta_data_ptr = std::make_shared<json>();
|
||||
auto ret1 = GetMeta(file, meta_data_ptr);
|
||||
if (ret1.first != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (meta_data != first_meta_data) {
|
||||
if (*meta_data_ptr != *first_meta_data_ptr) {
|
||||
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
|
||||
return FAILED;
|
||||
}
|
||||
|
@ -140,7 +141,7 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
|
|||
header_size_ = shard_header_->GetHeaderSize();
|
||||
page_size_ = shard_header_->GetPageSize();
|
||||
// version < 3.0
|
||||
if (first_meta_data["version"] < kVersion) {
|
||||
if ((*first_meta_data_ptr)["version"] < kVersion) {
|
||||
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
|
||||
} else {
|
||||
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
|
||||
|
@ -314,14 +315,14 @@ MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
|
|||
|
||||
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
|
||||
std::shared_ptr<std::fstream> fs,
|
||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
|
||||
const std::vector<std::string> &columns,
|
||||
std::vector<std::vector<json>> &column_values) {
|
||||
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||
int shard_id, const std::vector<std::string> &columns,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
|
||||
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
|
||||
uint64_t group_id = std::stoull(labels[i][0]);
|
||||
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
|
||||
uint64_t offset_end = std::stoull(labels[i][2]);
|
||||
offsets[shard_id].emplace_back(
|
||||
(*offset_ptr)[shard_id].emplace_back(
|
||||
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
|
||||
if (!all_in_index_) {
|
||||
int raw_page_id = std::stoi(labels[i][3]);
|
||||
|
@ -353,7 +354,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
|||
} else {
|
||||
tmp = label_json;
|
||||
}
|
||||
column_values[shard_id].emplace_back(tmp);
|
||||
(*col_val_ptr)[shard_id].emplace_back(tmp);
|
||||
} else {
|
||||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
|
@ -373,7 +374,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
|||
construct_json[columns[j]] = std::string(labels[i][j + 3]);
|
||||
}
|
||||
}
|
||||
column_values[shard_id].emplace_back(construct_json);
|
||||
(*col_val_ptr)[shard_id].emplace_back(construct_json);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -381,8 +382,8 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
|||
}
|
||||
|
||||
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
|
||||
std::vector<std::vector<json>> &column_values) {
|
||||
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
|
||||
auto db = database_paths_[shard_id];
|
||||
std::vector<std::vector<std::string>> labels;
|
||||
char *errmsg = nullptr;
|
||||
|
@ -406,10 +407,11 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
|
|||
}
|
||||
}
|
||||
sqlite3_free(errmsg);
|
||||
return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values);
|
||||
return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr);
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
||||
MSRStatus ShardReader::GetAllClasses(const std::string &category_field,
|
||||
std::shared_ptr<std::set<std::string>> category_ptr) {
|
||||
std::map<std::string, uint64_t> index_columns;
|
||||
for (auto &field : GetShardHeader()->GetFields()) {
|
||||
index_columns[field.second] = field.first;
|
||||
|
@ -425,7 +427,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set
|
|||
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
||||
std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
|
||||
for (int x = 0; x < shard_count_; x++) {
|
||||
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories));
|
||||
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
|
||||
}
|
||||
|
||||
for (int x = 0; x < shard_count_; x++) {
|
||||
|
@ -434,8 +436,8 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql,
|
||||
std::set<std::string> &categories) {
|
||||
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
|
||||
std::shared_ptr<std::set<std::string>> category_ptr) {
|
||||
if (nullptr == db) {
|
||||
return;
|
||||
}
|
||||
|
@ -452,20 +454,22 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
|
|||
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
|
||||
std::lock_guard<std::mutex> lck(shard_locker_);
|
||||
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
|
||||
categories.emplace(columns[i][0]);
|
||||
category_ptr->emplace(columns[i][0]);
|
||||
}
|
||||
}
|
||||
|
||||
ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
|
||||
ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns) {
|
||||
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
||||
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
|
||||
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
|
||||
shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
|
||||
|
||||
if (all_in_index_) {
|
||||
for (unsigned int i = 0; i < columns.size(); ++i) {
|
||||
fields += ',';
|
||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
|
||||
if (ret.first != SUCCESS) {
|
||||
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
||||
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||
}
|
||||
fields += ret.second;
|
||||
}
|
||||
|
@ -477,27 +481,27 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
|
|||
|
||||
std::vector<std::thread> thread_read_db = std::vector<std::thread>(shard_count_);
|
||||
for (int x = 0; x < shard_count_; x++) {
|
||||
thread_read_db[x] =
|
||||
std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values));
|
||||
thread_read_db[x] = std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, offset_ptr, col_val_ptr);
|
||||
}
|
||||
|
||||
for (int x = 0; x < shard_count_; x++) {
|
||||
thread_read_db[x].join();
|
||||
}
|
||||
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
|
||||
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||
}
|
||||
|
||||
ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns,
|
||||
const uint32_t &shard_id, const uint32_t &sample_id) {
|
||||
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
||||
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
|
||||
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
|
||||
shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
|
||||
if (all_in_index_) {
|
||||
for (unsigned int i = 0; i < columns.size(); ++i) {
|
||||
fields += ',';
|
||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
|
||||
if (ret.first != SUCCESS) {
|
||||
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
||||
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||
}
|
||||
fields += ret.second;
|
||||
}
|
||||
|
@ -507,12 +511,12 @@ ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::
|
|||
|
||||
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
|
||||
|
||||
if (ReadAllRowsInShard(shard_id, sql, columns, offsets, column_values) != SUCCESS) {
|
||||
if (ReadAllRowsInShard(shard_id, sql, columns, offset_ptr, col_val_ptr) != SUCCESS) {
|
||||
MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed.";
|
||||
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
||||
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||
}
|
||||
|
||||
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
|
||||
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||
}
|
||||
|
||||
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
|
||||
|
@ -550,7 +554,10 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
|
|||
uint64_t page_length = page->GetPageSize();
|
||||
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
|
||||
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria);
|
||||
|
||||
if (image_offset.empty()) {
|
||||
return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::vector<std::vector<uint64_t>>(),
|
||||
std::vector<json>());
|
||||
}
|
||||
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria);
|
||||
if (status_labels.first != SUCCESS) {
|
||||
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
|
||||
|
@ -601,7 +608,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
|
|||
db = nullptr;
|
||||
return std::vector<std::vector<uint64_t>>();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << "records from index.";
|
||||
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << " records from index.";
|
||||
}
|
||||
std::vector<std::vector<uint64_t>> res;
|
||||
for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
|
||||
|
@ -614,6 +621,44 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
|
|||
return res;
|
||||
}
|
||||
|
||||
std::pair<MSRStatus, std::vector<uint64_t>> ShardReader::GetPagesByCategory(
|
||||
int shard_id, const std::pair<std::string, std::string> &criteria) {
|
||||
auto db = database_paths_[shard_id];
|
||||
|
||||
std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";
|
||||
|
||||
if (!criteria.first.empty()) {
|
||||
auto schema = shard_header_->GetSchemas()[0]->GetSchema();
|
||||
|
||||
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
|
||||
sql +=
|
||||
" AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
|
||||
} else {
|
||||
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
|
||||
criteria.second + "'";
|
||||
}
|
||||
}
|
||||
sql += ";";
|
||||
std::vector<std::vector<std::string>> page_ids;
|
||||
char *errmsg = nullptr;
|
||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
||||
sqlite3_free(errmsg);
|
||||
sqlite3_close(db);
|
||||
db = nullptr;
|
||||
return std::make_pair(FAILED, std::vector<uint64_t>());
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Get " << page_ids.size() << "pages from index.";
|
||||
}
|
||||
std::vector<uint64_t> res;
|
||||
for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
|
||||
res.emplace_back(std::stoull(page_ids[i][0]));
|
||||
}
|
||||
sqlite3_free(errmsg);
|
||||
return std::make_pair(SUCCESS, res);
|
||||
}
|
||||
|
||||
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
|
||||
std::vector<std::string> blob_fields;
|
||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||
|
@ -642,8 +687,8 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns)
|
|||
}
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria,
|
||||
std::vector<std::vector<std::string>> &labels) {
|
||||
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
|
||||
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
|
||||
sqlite3_stmt *stmt = nullptr;
|
||||
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
|
||||
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
|
||||
|
@ -661,7 +706,7 @@ MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criter
|
|||
for (int i = 0; i < ncols; i++) {
|
||||
tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
|
||||
}
|
||||
labels.push_back(tmp);
|
||||
labels_ptr->push_back(tmp);
|
||||
rc = sqlite3_step(stmt);
|
||||
}
|
||||
(void)sqlite3_finalize(stmt);
|
||||
|
@ -724,16 +769,16 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage(
|
|||
auto db = database_paths_[shard_id];
|
||||
std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " +
|
||||
std::to_string(page_id);
|
||||
std::vector<std::vector<std::string>> label_offsets;
|
||||
auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
|
||||
if (!criteria.first.empty()) {
|
||||
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
|
||||
if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) {
|
||||
if (QueryWithCriteria(db, sql, criteria.second, label_offset_ptr) == FAILED) {
|
||||
return {FAILED, {}};
|
||||
}
|
||||
} else {
|
||||
sql += ";";
|
||||
char *errmsg = nullptr;
|
||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &label_offsets, &errmsg);
|
||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
||||
sqlite3_free(errmsg);
|
||||
|
@ -741,11 +786,11 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage(
|
|||
db = nullptr;
|
||||
return {FAILED, {}};
|
||||
}
|
||||
MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index.";
|
||||
MS_LOG(DEBUG) << "Get " << label_offset_ptr->size() << " records from index.";
|
||||
sqlite3_free(errmsg);
|
||||
}
|
||||
// get labels from binary file
|
||||
return GetLabelsFromBinaryFile(shard_id, columns, label_offsets);
|
||||
return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr);
|
||||
}
|
||||
|
||||
std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int shard_id,
|
||||
|
@ -760,17 +805,17 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
|||
fields += columns[i] + "_" + std::to_string(schema_id);
|
||||
}
|
||||
if (fields.empty()) fields = "*";
|
||||
std::vector<std::vector<std::string>> labels;
|
||||
auto labels_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
|
||||
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);
|
||||
if (!criteria.first.empty()) {
|
||||
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria";
|
||||
if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) {
|
||||
if (QueryWithCriteria(db, sql, criteria.second, labels_ptr) == FAILED) {
|
||||
return {FAILED, {}};
|
||||
}
|
||||
} else {
|
||||
sql += ";";
|
||||
char *errmsg = nullptr;
|
||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
|
||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels_ptr.get(), &errmsg);
|
||||
if (rc != SQLITE_OK) {
|
||||
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
||||
sqlite3_free(errmsg);
|
||||
|
@ -778,13 +823,13 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
|||
db = nullptr;
|
||||
return {FAILED, {}};
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels.size()) << "records from index.";
|
||||
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels_ptr->size()) << " records from index.";
|
||||
}
|
||||
sqlite3_free(errmsg);
|
||||
}
|
||||
std::vector<json> ret;
|
||||
for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
|
||||
for (unsigned int i = 0; i < labels.size(); ++i) {
|
||||
for (unsigned int i = 0; i < labels_ptr->size(); ++i) ret.emplace_back(json{});
|
||||
for (unsigned int i = 0; i < labels_ptr->size(); ++i) {
|
||||
json construct_json;
|
||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||
// construct json "f1": value
|
||||
|
@ -792,15 +837,15 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
|||
|
||||
// convert the string to base type by schema
|
||||
if (schema[columns[j]]["type"] == "int32") {
|
||||
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
|
||||
construct_json[columns[j]] = StringToNum<int32_t>((*labels_ptr)[i][j]);
|
||||
} else if (schema[columns[j]]["type"] == "int64") {
|
||||
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
|
||||
construct_json[columns[j]] = StringToNum<int64_t>((*labels_ptr)[i][j]);
|
||||
} else if (schema[columns[j]]["type"] == "float32") {
|
||||
construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
|
||||
construct_json[columns[j]] = StringToNum<float>((*labels_ptr)[i][j]);
|
||||
} else if (schema[columns[j]]["type"] == "float64") {
|
||||
construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
|
||||
construct_json[columns[j]] = StringToNum<double>((*labels_ptr)[i][j]);
|
||||
} else {
|
||||
construct_json[columns[j]] = std::string(labels[i][j]);
|
||||
construct_json[columns[j]] = std::string((*labels_ptr)[i][j]);
|
||||
}
|
||||
}
|
||||
ret[i] = construct_json;
|
||||
|
@ -834,7 +879,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
|||
}
|
||||
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
||||
std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
|
||||
std::set<std::string> categories;
|
||||
auto category_ptr = std::make_shared<std::set<std::string>>();
|
||||
for (int x = 0; x < shard_count; x++) {
|
||||
sqlite3 *db = nullptr;
|
||||
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||
|
@ -843,13 +888,13 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
|||
<< sqlite3_errmsg(db);
|
||||
return -1;
|
||||
}
|
||||
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories));
|
||||
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
|
||||
}
|
||||
|
||||
for (int x = 0; x < shard_count; x++) {
|
||||
threads[x].join();
|
||||
}
|
||||
return categories.size();
|
||||
return category_ptr->size();
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
||||
|
@ -1008,8 +1053,7 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
|
|||
return SUCCESS;
|
||||
}
|
||||
|
||||
MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||
const std::shared_ptr<ShardOperator> &op) {
|
||||
MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
|
||||
CheckIfColumnInIndex(selected_columns_);
|
||||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||
auto categories = category_op->GetCategories();
|
||||
|
@ -1033,42 +1077,50 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
|
|||
MS_LOG(ERROR) << "Invalid parameter, num_categories must be greater than 0, but got " << num_elements;
|
||||
return FAILED;
|
||||
}
|
||||
std::set<std::string> categories_set;
|
||||
auto ret = GetAllClasses(category_field, categories_set);
|
||||
auto category_ptr = std::make_shared<std::set<std::string>>();
|
||||
auto ret = GetAllClasses(category_field, category_ptr);
|
||||
if (SUCCESS != ret) {
|
||||
return FAILED;
|
||||
}
|
||||
int i = 0;
|
||||
for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) {
|
||||
for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
|
||||
categories.emplace_back(category_field, *it);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// Generate task list, a task will create a batch
|
||||
std::vector<ShardTask> categoryTasks(categories.size());
|
||||
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
|
||||
int category_index = 0;
|
||||
for (const auto &rg : row_group_summary) {
|
||||
if (category_index >= num_elements) break;
|
||||
auto shard_id = std::get<0>(rg);
|
||||
auto group_id = std::get<1>(rg);
|
||||
|
||||
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_);
|
||||
if (SUCCESS != std::get<0>(details)) {
|
||||
for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
|
||||
auto res = GetPagesByCategory(shard_id, categories[categoryNo]);
|
||||
if (SUCCESS != res.first) {
|
||||
return FAILED;
|
||||
}
|
||||
auto offsets = std::get<4>(details);
|
||||
|
||||
auto number_of_rows = offsets.size();
|
||||
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
|
||||
if (category_index < num_elements) {
|
||||
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart],
|
||||
std::get<5>(details)[iStart]);
|
||||
category_index++;
|
||||
auto page_ids = res.second;
|
||||
for (const auto &page_id : page_ids) {
|
||||
if (category_index >= num_elements) break;
|
||||
const auto &page_t = shard_header_->GetPage(shard_id, page_id);
|
||||
const auto &page = page_t.first;
|
||||
auto group_id = page->GetPageTypeID();
|
||||
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_);
|
||||
if (SUCCESS != std::get<0>(details)) {
|
||||
return FAILED;
|
||||
}
|
||||
auto offsets = std::get<4>(details);
|
||||
|
||||
auto number_of_rows = offsets.size();
|
||||
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
|
||||
if (category_index < num_elements) {
|
||||
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
|
||||
std::get<4>(details)[iStart], std::get<5>(details)[iStart]);
|
||||
category_index++;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
||||
}
|
||||
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
|
||||
if (SUCCESS != (*category_op)(tasks_)) {
|
||||
|
@ -1189,7 +1241,7 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
|
|||
}
|
||||
}
|
||||
} else {
|
||||
if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) {
|
||||
if (SUCCESS != CreateTasksByCategory(operators[category_operator])) {
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue