forked from mindspore-Ecosystem/mindspore
!14828 [MD] refactor pk sampler in minddataset
From: @liyong126 Reviewed-by: @liucunwei,@heleiwang Signed-off-by: @liucunwei
This commit is contained in:
commit
f1a28e17ce
|
@ -55,6 +55,8 @@
|
||||||
#include "minddata/mindrecord/include/shard_shuffle.h"
|
#include "minddata/mindrecord/include/shard_shuffle.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
#define API_PUBLIC __attribute__((visibility("default")))
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
using ROW_GROUPS =
|
using ROW_GROUPS =
|
||||||
|
@ -65,7 +67,7 @@ using TASK_RETURN_CONTENT =
|
||||||
std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>;
|
std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>;
|
||||||
const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode
|
const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode
|
||||||
|
|
||||||
class __attribute__((visibility("default"))) ShardReader {
|
class API_PUBLIC ShardReader {
|
||||||
public:
|
public:
|
||||||
ShardReader();
|
ShardReader();
|
||||||
|
|
||||||
|
@ -203,7 +205,7 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
|
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
|
||||||
|
|
||||||
/// \brief get all classes
|
/// \brief get all classes
|
||||||
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
|
MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);
|
||||||
|
|
||||||
/// \brief get the size of blob data
|
/// \brief get the size of blob data
|
||||||
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
|
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
|
||||||
|
@ -215,11 +217,12 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
private:
|
private:
|
||||||
/// \brief wrap up labels to json format
|
/// \brief wrap up labels to json format
|
||||||
MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
|
MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
|
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||||
const std::vector<std::string> &columns, std::vector<std::vector<json>> &column_values);
|
int shard_id, const std::vector<std::string> &columns,
|
||||||
|
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
|
||||||
|
|
||||||
/// \brief read all rows for specified columns
|
/// \brief read all rows for specified columns
|
||||||
ROW_GROUPS ReadAllRowGroup(std::vector<std::string> &columns);
|
ROW_GROUPS ReadAllRowGroup(const std::vector<std::string> &columns);
|
||||||
|
|
||||||
/// \brief read row meta by shard_id and sample_id
|
/// \brief read row meta by shard_id and sample_id
|
||||||
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
|
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
|
||||||
|
@ -227,8 +230,8 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
|
|
||||||
/// \brief read all rows in one shard
|
/// \brief read all rows in one shard
|
||||||
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
|
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||||
std::vector<std::vector<json>> &column_values);
|
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
|
||||||
|
|
||||||
/// \brief initialize reader
|
/// \brief initialize reader
|
||||||
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
|
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
|
||||||
|
@ -243,8 +246,12 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
|
std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
|
||||||
const std::pair<std::string, std::string> &criteria = {"", ""});
|
const std::pair<std::string, std::string> &criteria = {"", ""});
|
||||||
|
|
||||||
|
/// \brief get page id by category
|
||||||
|
std::pair<MSRStatus, std::vector<uint64_t>> GetPagesByCategory(int shard_id,
|
||||||
|
const std::pair<std::string, std::string> &criteria);
|
||||||
/// \brief execute sqlite query with prepare statement
|
/// \brief execute sqlite query with prepare statement
|
||||||
MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector<std::vector<std::string>> &labels);
|
MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
|
||||||
|
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
|
||||||
|
|
||||||
/// \brief get column values
|
/// \brief get column values
|
||||||
std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,
|
std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,
|
||||||
|
@ -257,8 +264,7 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
""});
|
""});
|
||||||
|
|
||||||
/// \brief create category-applied task list
|
/// \brief create category-applied task list
|
||||||
MSRStatus CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);
|
||||||
const std::shared_ptr<ShardOperator> &op);
|
|
||||||
|
|
||||||
/// \brief create task list in row-reader mode
|
/// \brief create task list in row-reader mode
|
||||||
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
|
@ -286,13 +292,15 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets);
|
int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets);
|
||||||
|
|
||||||
/// \brief get classes in one shard
|
/// \brief get classes in one shard
|
||||||
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
|
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
|
||||||
|
std::shared_ptr<std::set<std::string>> category_ptr);
|
||||||
|
|
||||||
/// \brief get number of classes
|
/// \brief get number of classes
|
||||||
int64_t GetNumClasses(const std::string &category_field);
|
int64_t GetNumClasses(const std::string &category_field);
|
||||||
|
|
||||||
/// \brief get meta of header
|
/// \brief get meta of header
|
||||||
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
|
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path,
|
||||||
|
std::shared_ptr<json> meta_data_ptr);
|
||||||
|
|
||||||
/// \brief extract uncompressed data based on column list
|
/// \brief extract uncompressed data based on column list
|
||||||
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
|
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
|
||||||
|
|
|
@ -51,7 +51,8 @@ ShardReader::ShardReader()
|
||||||
lazy_load_(false),
|
lazy_load_(false),
|
||||||
shard_sample_count_() {}
|
shard_sample_count_() {}
|
||||||
|
|
||||||
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
|
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path,
|
||||||
|
std::shared_ptr<json> meta_data_ptr) {
|
||||||
if (!IsLegalFile(file_path)) {
|
if (!IsLegalFile(file_path)) {
|
||||||
return {FAILED, {}};
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
|
@ -60,16 +61,16 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s
|
||||||
return {FAILED, {}};
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
auto header = ret.second;
|
auto header = ret.second;
|
||||||
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
*meta_data_ptr = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
||||||
{"version", header["version"]}, {"index_fields", header["index_fields"]},
|
{"version", header["version"]}, {"index_fields", header["index_fields"]},
|
||||||
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
|
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
|
||||||
return {SUCCESS, header["shard_addresses"]};
|
return {SUCCESS, header["shard_addresses"]};
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
|
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
|
||||||
std::string file_path = file_paths[0];
|
std::string file_path = file_paths[0];
|
||||||
json first_meta_data = json();
|
auto first_meta_data_ptr = std::make_shared<json>();
|
||||||
auto ret = GetMeta(file_path, first_meta_data);
|
auto ret = GetMeta(file_path, first_meta_data_ptr);
|
||||||
if (ret.first != SUCCESS) {
|
if (ret.first != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
@ -91,12 +92,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
for (const auto &file : file_paths_) {
|
for (const auto &file : file_paths_) {
|
||||||
json meta_data = json();
|
auto meta_data_ptr = std::make_shared<json>();
|
||||||
auto ret1 = GetMeta(file, meta_data);
|
auto ret1 = GetMeta(file, meta_data_ptr);
|
||||||
if (ret1.first != SUCCESS) {
|
if (ret1.first != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
if (meta_data != first_meta_data) {
|
if (*meta_data_ptr != *first_meta_data_ptr) {
|
||||||
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
|
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
@ -140,7 +141,7 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
|
||||||
header_size_ = shard_header_->GetHeaderSize();
|
header_size_ = shard_header_->GetHeaderSize();
|
||||||
page_size_ = shard_header_->GetPageSize();
|
page_size_ = shard_header_->GetPageSize();
|
||||||
// version < 3.0
|
// version < 3.0
|
||||||
if (first_meta_data["version"] < kVersion) {
|
if ((*first_meta_data_ptr)["version"] < kVersion) {
|
||||||
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
|
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
|
||||||
} else {
|
} else {
|
||||||
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
|
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
|
||||||
|
@ -314,14 +315,14 @@ MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
|
||||||
|
|
||||||
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
|
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
|
||||||
std::shared_ptr<std::fstream> fs,
|
std::shared_ptr<std::fstream> fs,
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
|
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||||
const std::vector<std::string> &columns,
|
int shard_id, const std::vector<std::string> &columns,
|
||||||
std::vector<std::vector<json>> &column_values) {
|
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
|
||||||
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
|
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
|
||||||
uint64_t group_id = std::stoull(labels[i][0]);
|
uint64_t group_id = std::stoull(labels[i][0]);
|
||||||
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
|
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
|
||||||
uint64_t offset_end = std::stoull(labels[i][2]);
|
uint64_t offset_end = std::stoull(labels[i][2]);
|
||||||
offsets[shard_id].emplace_back(
|
(*offset_ptr)[shard_id].emplace_back(
|
||||||
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
|
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
|
||||||
if (!all_in_index_) {
|
if (!all_in_index_) {
|
||||||
int raw_page_id = std::stoi(labels[i][3]);
|
int raw_page_id = std::stoi(labels[i][3]);
|
||||||
|
@ -353,7 +354,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
||||||
} else {
|
} else {
|
||||||
tmp = label_json;
|
tmp = label_json;
|
||||||
}
|
}
|
||||||
column_values[shard_id].emplace_back(tmp);
|
(*col_val_ptr)[shard_id].emplace_back(tmp);
|
||||||
} else {
|
} else {
|
||||||
json construct_json;
|
json construct_json;
|
||||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||||
|
@ -373,7 +374,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
||||||
construct_json[columns[j]] = std::string(labels[i][j + 3]);
|
construct_json[columns[j]] = std::string(labels[i][j + 3]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
column_values[shard_id].emplace_back(construct_json);
|
(*col_val_ptr)[shard_id].emplace_back(construct_json);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -381,8 +382,8 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
|
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
|
||||||
std::vector<std::vector<json>> &column_values) {
|
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
|
||||||
auto db = database_paths_[shard_id];
|
auto db = database_paths_[shard_id];
|
||||||
std::vector<std::vector<std::string>> labels;
|
std::vector<std::vector<std::string>> labels;
|
||||||
char *errmsg = nullptr;
|
char *errmsg = nullptr;
|
||||||
|
@ -406,10 +407,11 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sqlite3_free(errmsg);
|
sqlite3_free(errmsg);
|
||||||
return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values);
|
return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
|
MSRStatus ShardReader::GetAllClasses(const std::string &category_field,
|
||||||
|
std::shared_ptr<std::set<std::string>> category_ptr) {
|
||||||
std::map<std::string, uint64_t> index_columns;
|
std::map<std::string, uint64_t> index_columns;
|
||||||
for (auto &field : GetShardHeader()->GetFields()) {
|
for (auto &field : GetShardHeader()->GetFields()) {
|
||||||
index_columns[field.second] = field.first;
|
index_columns[field.second] = field.first;
|
||||||
|
@ -425,7 +427,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set
|
||||||
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
||||||
std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
|
std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
|
||||||
for (int x = 0; x < shard_count_; x++) {
|
for (int x = 0; x < shard_count_; x++) {
|
||||||
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories));
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int x = 0; x < shard_count_; x++) {
|
for (int x = 0; x < shard_count_; x++) {
|
||||||
|
@ -434,8 +436,8 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql,
|
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
|
||||||
std::set<std::string> &categories) {
|
std::shared_ptr<std::set<std::string>> category_ptr) {
|
||||||
if (nullptr == db) {
|
if (nullptr == db) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -452,20 +454,22 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
|
||||||
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
|
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
|
||||||
std::lock_guard<std::mutex> lck(shard_locker_);
|
std::lock_guard<std::mutex> lck(shard_locker_);
|
||||||
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
|
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
|
||||||
categories.emplace(columns[i][0]);
|
category_ptr->emplace(columns[i][0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
|
ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns) {
|
||||||
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
|
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
|
||||||
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
|
shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||||
|
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
|
||||||
|
|
||||||
if (all_in_index_) {
|
if (all_in_index_) {
|
||||||
for (unsigned int i = 0; i < columns.size(); ++i) {
|
for (unsigned int i = 0; i < columns.size(); ++i) {
|
||||||
fields += ',';
|
fields += ',';
|
||||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
|
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
|
||||||
if (ret.first != SUCCESS) {
|
if (ret.first != SUCCESS) {
|
||||||
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||||
}
|
}
|
||||||
fields += ret.second;
|
fields += ret.second;
|
||||||
}
|
}
|
||||||
|
@ -477,27 +481,27 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
|
||||||
|
|
||||||
std::vector<std::thread> thread_read_db = std::vector<std::thread>(shard_count_);
|
std::vector<std::thread> thread_read_db = std::vector<std::thread>(shard_count_);
|
||||||
for (int x = 0; x < shard_count_; x++) {
|
for (int x = 0; x < shard_count_; x++) {
|
||||||
thread_read_db[x] =
|
thread_read_db[x] = std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, offset_ptr, col_val_ptr);
|
||||||
std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int x = 0; x < shard_count_; x++) {
|
for (int x = 0; x < shard_count_; x++) {
|
||||||
thread_read_db[x].join();
|
thread_read_db[x].join();
|
||||||
}
|
}
|
||||||
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
|
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns,
|
ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns,
|
||||||
const uint32_t &shard_id, const uint32_t &sample_id) {
|
const uint32_t &shard_id, const uint32_t &sample_id) {
|
||||||
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
|
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
|
||||||
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
|
shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||||
|
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
|
||||||
if (all_in_index_) {
|
if (all_in_index_) {
|
||||||
for (unsigned int i = 0; i < columns.size(); ++i) {
|
for (unsigned int i = 0; i < columns.size(); ++i) {
|
||||||
fields += ',';
|
fields += ',';
|
||||||
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
|
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
|
||||||
if (ret.first != SUCCESS) {
|
if (ret.first != SUCCESS) {
|
||||||
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||||
}
|
}
|
||||||
fields += ret.second;
|
fields += ret.second;
|
||||||
}
|
}
|
||||||
|
@ -507,12 +511,12 @@ ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::
|
||||||
|
|
||||||
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
|
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
|
||||||
|
|
||||||
if (ReadAllRowsInShard(shard_id, sql, columns, offsets, column_values) != SUCCESS) {
|
if (ReadAllRowsInShard(shard_id, sql, columns, offset_ptr, col_val_ptr) != SUCCESS) {
|
||||||
MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed.";
|
MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed.";
|
||||||
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
|
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
|
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
|
||||||
|
@ -550,7 +554,10 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
|
||||||
uint64_t page_length = page->GetPageSize();
|
uint64_t page_length = page->GetPageSize();
|
||||||
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
|
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
|
||||||
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria);
|
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria);
|
||||||
|
if (image_offset.empty()) {
|
||||||
|
return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::vector<std::vector<uint64_t>>(),
|
||||||
|
std::vector<json>());
|
||||||
|
}
|
||||||
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria);
|
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria);
|
||||||
if (status_labels.first != SUCCESS) {
|
if (status_labels.first != SUCCESS) {
|
||||||
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
|
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
|
||||||
|
@ -601,7 +608,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
|
||||||
db = nullptr;
|
db = nullptr;
|
||||||
return std::vector<std::vector<uint64_t>>();
|
return std::vector<std::vector<uint64_t>>();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << "records from index.";
|
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << " records from index.";
|
||||||
}
|
}
|
||||||
std::vector<std::vector<uint64_t>> res;
|
std::vector<std::vector<uint64_t>> res;
|
||||||
for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
|
for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
|
||||||
|
@ -614,6 +621,44 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<MSRStatus, std::vector<uint64_t>> ShardReader::GetPagesByCategory(
|
||||||
|
int shard_id, const std::pair<std::string, std::string> &criteria) {
|
||||||
|
auto db = database_paths_[shard_id];
|
||||||
|
|
||||||
|
std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";
|
||||||
|
|
||||||
|
if (!criteria.first.empty()) {
|
||||||
|
auto schema = shard_header_->GetSchemas()[0]->GetSchema();
|
||||||
|
|
||||||
|
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
|
||||||
|
sql +=
|
||||||
|
" AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
|
||||||
|
} else {
|
||||||
|
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
|
||||||
|
criteria.second + "'";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql += ";";
|
||||||
|
std::vector<std::vector<std::string>> page_ids;
|
||||||
|
char *errmsg = nullptr;
|
||||||
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg);
|
||||||
|
if (rc != SQLITE_OK) {
|
||||||
|
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
||||||
|
sqlite3_free(errmsg);
|
||||||
|
sqlite3_close(db);
|
||||||
|
db = nullptr;
|
||||||
|
return std::make_pair(FAILED, std::vector<uint64_t>());
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "Get " << page_ids.size() << "pages from index.";
|
||||||
|
}
|
||||||
|
std::vector<uint64_t> res;
|
||||||
|
for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
|
||||||
|
res.emplace_back(std::stoull(page_ids[i][0]));
|
||||||
|
}
|
||||||
|
sqlite3_free(errmsg);
|
||||||
|
return std::make_pair(SUCCESS, res);
|
||||||
|
}
|
||||||
|
|
||||||
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
|
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
|
||||||
std::vector<std::string> blob_fields;
|
std::vector<std::string> blob_fields;
|
||||||
for (auto &p : GetShardHeader()->GetSchemas()) {
|
for (auto &p : GetShardHeader()->GetSchemas()) {
|
||||||
|
@ -642,8 +687,8 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria,
|
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
|
||||||
std::vector<std::vector<std::string>> &labels) {
|
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
|
||||||
sqlite3_stmt *stmt = nullptr;
|
sqlite3_stmt *stmt = nullptr;
|
||||||
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
|
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
|
||||||
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
|
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
|
||||||
|
@ -661,7 +706,7 @@ MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criter
|
||||||
for (int i = 0; i < ncols; i++) {
|
for (int i = 0; i < ncols; i++) {
|
||||||
tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
|
tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
|
||||||
}
|
}
|
||||||
labels.push_back(tmp);
|
labels_ptr->push_back(tmp);
|
||||||
rc = sqlite3_step(stmt);
|
rc = sqlite3_step(stmt);
|
||||||
}
|
}
|
||||||
(void)sqlite3_finalize(stmt);
|
(void)sqlite3_finalize(stmt);
|
||||||
|
@ -724,16 +769,16 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage(
|
||||||
auto db = database_paths_[shard_id];
|
auto db = database_paths_[shard_id];
|
||||||
std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " +
|
std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " +
|
||||||
std::to_string(page_id);
|
std::to_string(page_id);
|
||||||
std::vector<std::vector<std::string>> label_offsets;
|
auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
|
||||||
if (!criteria.first.empty()) {
|
if (!criteria.first.empty()) {
|
||||||
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
|
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
|
||||||
if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) {
|
if (QueryWithCriteria(db, sql, criteria.second, label_offset_ptr) == FAILED) {
|
||||||
return {FAILED, {}};
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sql += ";";
|
sql += ";";
|
||||||
char *errmsg = nullptr;
|
char *errmsg = nullptr;
|
||||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &label_offsets, &errmsg);
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
||||||
sqlite3_free(errmsg);
|
sqlite3_free(errmsg);
|
||||||
|
@ -741,11 +786,11 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage(
|
||||||
db = nullptr;
|
db = nullptr;
|
||||||
return {FAILED, {}};
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index.";
|
MS_LOG(DEBUG) << "Get " << label_offset_ptr->size() << " records from index.";
|
||||||
sqlite3_free(errmsg);
|
sqlite3_free(errmsg);
|
||||||
}
|
}
|
||||||
// get labels from binary file
|
// get labels from binary file
|
||||||
return GetLabelsFromBinaryFile(shard_id, columns, label_offsets);
|
return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int shard_id,
|
std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int shard_id,
|
||||||
|
@ -760,17 +805,17 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
||||||
fields += columns[i] + "_" + std::to_string(schema_id);
|
fields += columns[i] + "_" + std::to_string(schema_id);
|
||||||
}
|
}
|
||||||
if (fields.empty()) fields = "*";
|
if (fields.empty()) fields = "*";
|
||||||
std::vector<std::vector<std::string>> labels;
|
auto labels_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
|
||||||
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);
|
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);
|
||||||
if (!criteria.first.empty()) {
|
if (!criteria.first.empty()) {
|
||||||
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria";
|
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria";
|
||||||
if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) {
|
if (QueryWithCriteria(db, sql, criteria.second, labels_ptr) == FAILED) {
|
||||||
return {FAILED, {}};
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
sql += ";";
|
sql += ";";
|
||||||
char *errmsg = nullptr;
|
char *errmsg = nullptr;
|
||||||
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels_ptr.get(), &errmsg);
|
||||||
if (rc != SQLITE_OK) {
|
if (rc != SQLITE_OK) {
|
||||||
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
|
||||||
sqlite3_free(errmsg);
|
sqlite3_free(errmsg);
|
||||||
|
@ -778,13 +823,13 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
||||||
db = nullptr;
|
db = nullptr;
|
||||||
return {FAILED, {}};
|
return {FAILED, {}};
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels.size()) << "records from index.";
|
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels_ptr->size()) << " records from index.";
|
||||||
}
|
}
|
||||||
sqlite3_free(errmsg);
|
sqlite3_free(errmsg);
|
||||||
}
|
}
|
||||||
std::vector<json> ret;
|
std::vector<json> ret;
|
||||||
for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
|
for (unsigned int i = 0; i < labels_ptr->size(); ++i) ret.emplace_back(json{});
|
||||||
for (unsigned int i = 0; i < labels.size(); ++i) {
|
for (unsigned int i = 0; i < labels_ptr->size(); ++i) {
|
||||||
json construct_json;
|
json construct_json;
|
||||||
for (unsigned int j = 0; j < columns.size(); ++j) {
|
for (unsigned int j = 0; j < columns.size(); ++j) {
|
||||||
// construct json "f1": value
|
// construct json "f1": value
|
||||||
|
@ -792,15 +837,15 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
|
||||||
|
|
||||||
// convert the string to base type by schema
|
// convert the string to base type by schema
|
||||||
if (schema[columns[j]]["type"] == "int32") {
|
if (schema[columns[j]]["type"] == "int32") {
|
||||||
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
|
construct_json[columns[j]] = StringToNum<int32_t>((*labels_ptr)[i][j]);
|
||||||
} else if (schema[columns[j]]["type"] == "int64") {
|
} else if (schema[columns[j]]["type"] == "int64") {
|
||||||
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
|
construct_json[columns[j]] = StringToNum<int64_t>((*labels_ptr)[i][j]);
|
||||||
} else if (schema[columns[j]]["type"] == "float32") {
|
} else if (schema[columns[j]]["type"] == "float32") {
|
||||||
construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
|
construct_json[columns[j]] = StringToNum<float>((*labels_ptr)[i][j]);
|
||||||
} else if (schema[columns[j]]["type"] == "float64") {
|
} else if (schema[columns[j]]["type"] == "float64") {
|
||||||
construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
|
construct_json[columns[j]] = StringToNum<double>((*labels_ptr)[i][j]);
|
||||||
} else {
|
} else {
|
||||||
construct_json[columns[j]] = std::string(labels[i][j]);
|
construct_json[columns[j]] = std::string((*labels_ptr)[i][j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ret[i] = construct_json;
|
ret[i] = construct_json;
|
||||||
|
@ -834,7 +879,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
||||||
}
|
}
|
||||||
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
|
||||||
std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
|
std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
|
||||||
std::set<std::string> categories;
|
auto category_ptr = std::make_shared<std::set<std::string>>();
|
||||||
for (int x = 0; x < shard_count; x++) {
|
for (int x = 0; x < shard_count; x++) {
|
||||||
sqlite3 *db = nullptr;
|
sqlite3 *db = nullptr;
|
||||||
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
|
||||||
|
@ -843,13 +888,13 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
|
||||||
<< sqlite3_errmsg(db);
|
<< sqlite3_errmsg(db);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories));
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int x = 0; x < shard_count; x++) {
|
for (int x = 0; x < shard_count; x++) {
|
||||||
threads[x].join();
|
threads[x].join();
|
||||||
}
|
}
|
||||||
return categories.size();
|
return category_ptr->size();
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
|
||||||
|
@ -1008,8 +1053,7 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
|
||||||
const std::shared_ptr<ShardOperator> &op) {
|
|
||||||
CheckIfColumnInIndex(selected_columns_);
|
CheckIfColumnInIndex(selected_columns_);
|
||||||
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
|
||||||
auto categories = category_op->GetCategories();
|
auto categories = category_op->GetCategories();
|
||||||
|
@ -1033,42 +1077,50 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
|
||||||
MS_LOG(ERROR) << "Invalid parameter, num_categories must be greater than 0, but got " << num_elements;
|
MS_LOG(ERROR) << "Invalid parameter, num_categories must be greater than 0, but got " << num_elements;
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
std::set<std::string> categories_set;
|
auto category_ptr = std::make_shared<std::set<std::string>>();
|
||||||
auto ret = GetAllClasses(category_field, categories_set);
|
auto ret = GetAllClasses(category_field, category_ptr);
|
||||||
if (SUCCESS != ret) {
|
if (SUCCESS != ret) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
int i = 0;
|
int i = 0;
|
||||||
for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) {
|
for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
|
||||||
categories.emplace_back(category_field, *it);
|
categories.emplace_back(category_field, *it);
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate task list, a task will create a batch
|
// Generate task list, a task will create a batch
|
||||||
std::vector<ShardTask> categoryTasks(categories.size());
|
std::vector<ShardTask> categoryTasks(categories.size());
|
||||||
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
|
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
|
||||||
int category_index = 0;
|
int category_index = 0;
|
||||||
for (const auto &rg : row_group_summary) {
|
for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
|
||||||
if (category_index >= num_elements) break;
|
auto res = GetPagesByCategory(shard_id, categories[categoryNo]);
|
||||||
auto shard_id = std::get<0>(rg);
|
if (SUCCESS != res.first) {
|
||||||
auto group_id = std::get<1>(rg);
|
|
||||||
|
|
||||||
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_);
|
|
||||||
if (SUCCESS != std::get<0>(details)) {
|
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
auto offsets = std::get<4>(details);
|
auto page_ids = res.second;
|
||||||
|
for (const auto &page_id : page_ids) {
|
||||||
auto number_of_rows = offsets.size();
|
if (category_index >= num_elements) break;
|
||||||
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
|
const auto &page_t = shard_header_->GetPage(shard_id, page_id);
|
||||||
if (category_index < num_elements) {
|
const auto &page = page_t.first;
|
||||||
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart],
|
auto group_id = page->GetPageTypeID();
|
||||||
std::get<5>(details)[iStart]);
|
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_);
|
||||||
category_index++;
|
if (SUCCESS != std::get<0>(details)) {
|
||||||
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
auto offsets = std::get<4>(details);
|
||||||
|
|
||||||
|
auto number_of_rows = offsets.size();
|
||||||
|
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
|
||||||
|
if (category_index < num_elements) {
|
||||||
|
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
|
||||||
|
std::get<4>(details)[iStart], std::get<5>(details)[iStart]);
|
||||||
|
category_index++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
|
|
||||||
}
|
}
|
||||||
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
|
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
|
||||||
if (SUCCESS != (*category_op)(tasks_)) {
|
if (SUCCESS != (*category_op)(tasks_)) {
|
||||||
|
@ -1189,7 +1241,7 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) {
|
if (SUCCESS != CreateTasksByCategory(operators[category_operator])) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue