!14828 [MD] refactor pk sampler in minddataset

From: @liyong126
Reviewed-by: @liucunwei,@heleiwang
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2021-04-10 16:52:19 +08:00 committed by Gitee
commit f1a28e17ce
2 changed files with 155 additions and 95 deletions

View File

@ -55,6 +55,8 @@
#include "minddata/mindrecord/include/shard_shuffle.h"
#include "utils/log_adapter.h"
#define API_PUBLIC __attribute__((visibility("default")))
namespace mindspore {
namespace mindrecord {
using ROW_GROUPS =
@ -65,7 +67,7 @@ using TASK_RETURN_CONTENT =
std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>;
const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode
class __attribute__((visibility("default"))) ShardReader {
class API_PUBLIC ShardReader {
public:
ShardReader();
@ -203,7 +205,7 @@ class __attribute__((visibility("default"))) ShardReader {
void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; }
/// \brief get all classes
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr);
/// \brief get the size of blob data
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
@ -215,11 +217,12 @@ class __attribute__((visibility("default"))) ShardReader {
private:
/// \brief wrap up labels to json format
MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs,
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
const std::vector<std::string> &columns, std::vector<std::vector<json>> &column_values);
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
/// \brief read all rows for specified columns
ROW_GROUPS ReadAllRowGroup(std::vector<std::string> &columns);
ROW_GROUPS ReadAllRowGroup(const std::vector<std::string> &columns);
/// \brief read row meta by shard_id and sample_id
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
@ -227,8 +230,8 @@ class __attribute__((visibility("default"))) ShardReader {
/// \brief read all rows in one shard
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
std::vector<std::vector<json>> &column_values);
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr);
/// \brief initialize reader
MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset);
@ -243,8 +246,12 @@ class __attribute__((visibility("default"))) ShardReader {
std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id,
const std::pair<std::string, std::string> &criteria = {"", ""});
/// \brief get the distinct blob page ids in one shard whose index rows match the category criteria
std::pair<MSRStatus, std::vector<uint64_t>> GetPagesByCategory(int shard_id,
const std::pair<std::string, std::string> &criteria);
/// \brief execute sqlite query with prepare statement
MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector<std::vector<std::string>> &labels);
MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr);
/// \brief get column values
std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns,
@ -257,8 +264,7 @@ class __attribute__((visibility("default"))) ShardReader {
""});
/// \brief create category-applied task list
MSRStatus CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::shared_ptr<ShardOperator> &op);
MSRStatus CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op);
/// \brief create task list in row-reader mode
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
@ -286,13 +292,15 @@ class __attribute__((visibility("default"))) ShardReader {
int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets);
/// \brief get classes in one shard
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories);
void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
std::shared_ptr<std::set<std::string>> category_ptr);
/// \brief get number of classes
int64_t GetNumClasses(const std::string &category_field);
/// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path,
std::shared_ptr<json> meta_data_ptr);
/// \brief extract uncompressed data based on column list
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);

View File

@ -51,7 +51,8 @@ ShardReader::ShardReader()
lazy_load_(false),
shard_sample_count_() {}
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path,
std::shared_ptr<json> meta_data_ptr) {
if (!IsLegalFile(file_path)) {
return {FAILED, {}};
}
@ -60,16 +61,16 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s
return {FAILED, {}};
}
auto header = ret.second;
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
{"version", header["version"]}, {"index_fields", header["index_fields"]},
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
*meta_data_ptr = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
{"version", header["version"]}, {"index_fields", header["index_fields"]},
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
return {SUCCESS, header["shard_addresses"]};
}
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
std::string file_path = file_paths[0];
json first_meta_data = json();
auto ret = GetMeta(file_path, first_meta_data);
auto first_meta_data_ptr = std::make_shared<json>();
auto ret = GetMeta(file_path, first_meta_data_ptr);
if (ret.first != SUCCESS) {
return FAILED;
}
@ -91,12 +92,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
return FAILED;
}
for (const auto &file : file_paths_) {
json meta_data = json();
auto ret1 = GetMeta(file, meta_data);
auto meta_data_ptr = std::make_shared<json>();
auto ret1 = GetMeta(file, meta_data_ptr);
if (ret1.first != SUCCESS) {
return FAILED;
}
if (meta_data != first_meta_data) {
if (*meta_data_ptr != *first_meta_data_ptr) {
MS_LOG(ERROR) << "Mindrecord files meta information is different.";
return FAILED;
}
@ -140,7 +141,7 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
// version < 3.0
if (first_meta_data["version"] < kVersion) {
if ((*first_meta_data_ptr)["version"] < kVersion) {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
} else {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
@ -314,14 +315,14 @@ MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
std::shared_ptr<std::fstream> fs,
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
const std::vector<std::string> &columns,
std::vector<std::vector<json>> &column_values) {
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
int shard_id, const std::vector<std::string> &columns,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
for (int i = 0; i < static_cast<int>(labels.size()); ++i) {
uint64_t group_id = std::stoull(labels[i][0]);
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len;
uint64_t offset_end = std::stoull(labels[i][2]);
offsets[shard_id].emplace_back(
(*offset_ptr)[shard_id].emplace_back(
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end});
if (!all_in_index_) {
int raw_page_id = std::stoi(labels[i][3]);
@ -353,7 +354,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
} else {
tmp = label_json;
}
column_values[shard_id].emplace_back(tmp);
(*col_val_ptr)[shard_id].emplace_back(tmp);
} else {
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
@ -373,7 +374,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
construct_json[columns[j]] = std::string(labels[i][j + 3]);
}
}
column_values[shard_id].emplace_back(construct_json);
(*col_val_ptr)[shard_id].emplace_back(construct_json);
}
}
@ -381,8 +382,8 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str
}
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
std::vector<std::vector<json>> &column_values) {
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr,
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) {
auto db = database_paths_[shard_id];
std::vector<std::vector<std::string>> labels;
char *errmsg = nullptr;
@ -406,10 +407,11 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
}
}
sqlite3_free(errmsg);
return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values);
return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr);
}
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
MSRStatus ShardReader::GetAllClasses(const std::string &category_field,
std::shared_ptr<std::set<std::string>> category_ptr) {
std::map<std::string, uint64_t> index_columns;
for (auto &field : GetShardHeader()->GetFields()) {
index_columns[field.second] = field.first;
@ -425,7 +427,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
std::vector<std::thread> threads = std::vector<std::thread>(shard_count_);
for (int x = 0; x < shard_count_; x++) {
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories));
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr);
}
for (int x = 0; x < shard_count_; x++) {
@ -434,8 +436,8 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set
return SUCCESS;
}
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql,
std::set<std::string> &categories) {
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql,
std::shared_ptr<std::set<std::string>> category_ptr) {
if (nullptr == db) {
return;
}
@ -452,20 +454,22 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index.";
std::lock_guard<std::mutex> lck(shard_locker_);
for (int i = 0; i < static_cast<int>(columns.size()); ++i) {
categories.emplace(columns[i][0]);
category_ptr->emplace(columns[i][0]);
}
}
ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns) {
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
shard_count_, std::vector<std::vector<uint64_t>>{});
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
if (all_in_index_) {
for (unsigned int i = 0; i < columns.size(); ++i) {
fields += ',';
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
if (ret.first != SUCCESS) {
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
}
fields += ret.second;
}
@ -477,27 +481,27 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
std::vector<std::thread> thread_read_db = std::vector<std::thread>(shard_count_);
for (int x = 0; x < shard_count_; x++) {
thread_read_db[x] =
std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values));
thread_read_db[x] = std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, offset_ptr, col_val_ptr);
}
for (int x = 0; x < shard_count_; x++) {
thread_read_db[x].join();
}
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr));
}
ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns,
const uint32_t &shard_id, const uint32_t &sample_id) {
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>(
shard_count_, std::vector<std::vector<uint64_t>>{});
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{});
if (all_in_index_) {
for (unsigned int i = 0; i < columns.size(); ++i) {
fields += ',';
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
if (ret.first != SUCCESS) {
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
}
fields += ret.second;
}
@ -507,12 +511,12 @@ ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
if (ReadAllRowsInShard(shard_id, sql, columns, offsets, column_values) != SUCCESS) {
if (ReadAllRowsInShard(shard_id, sql, columns, offset_ptr, col_val_ptr) != SUCCESS) {
MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed.";
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr));
}
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr));
}
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
@ -550,7 +554,10 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id,
uint64_t page_length = page->GetPageSize();
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_;
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria);
if (image_offset.empty()) {
return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::vector<std::vector<uint64_t>>(),
std::vector<json>());
}
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria);
if (status_labels.first != SUCCESS) {
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>());
@ -601,7 +608,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
db = nullptr;
return std::vector<std::vector<uint64_t>>();
} else {
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << "records from index.";
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << " records from index.";
}
std::vector<std::vector<uint64_t>> res;
for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0});
@ -614,6 +621,44 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int
return res;
}
std::pair<MSRStatus, std::vector<uint64_t>> ShardReader::GetPagesByCategory(
int shard_id, const std::pair<std::string, std::string> &criteria) {
auto db = database_paths_[shard_id];
std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 ";
if (!criteria.first.empty()) {
auto schema = shard_header_->GetSchemas()[0]->GetSchema();
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) {
sql +=
" AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second;
} else {
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" +
criteria.second + "'";
}
}
sql += ";";
std::vector<std::vector<std::string>> page_ids;
char *errmsg = nullptr;
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg);
if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
sqlite3_free(errmsg);
sqlite3_close(db);
db = nullptr;
return std::make_pair(FAILED, std::vector<uint64_t>());
} else {
MS_LOG(DEBUG) << "Get " << page_ids.size() << "pages from index.";
}
std::vector<uint64_t> res;
for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) {
res.emplace_back(std::stoull(page_ids[i][0]));
}
sqlite3_free(errmsg);
return std::make_pair(SUCCESS, res);
}
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() {
std::vector<std::string> blob_fields;
for (auto &p : GetShardHeader()->GetSchemas()) {
@ -642,8 +687,8 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns)
}
}
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria,
std::vector<std::vector<std::string>> &labels) {
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria,
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) {
sqlite3_stmt *stmt = nullptr;
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) {
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql;
@ -661,7 +706,7 @@ MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criter
for (int i = 0; i < ncols; i++) {
tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i)));
}
labels.push_back(tmp);
labels_ptr->push_back(tmp);
rc = sqlite3_step(stmt);
}
(void)sqlite3_finalize(stmt);
@ -724,16 +769,16 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage(
auto db = database_paths_[shard_id];
std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " +
std::to_string(page_id);
std::vector<std::vector<std::string>> label_offsets;
auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
if (!criteria.first.empty()) {
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria";
if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) {
if (QueryWithCriteria(db, sql, criteria.second, label_offset_ptr) == FAILED) {
return {FAILED, {}};
}
} else {
sql += ";";
char *errmsg = nullptr;
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &label_offsets, &errmsg);
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg);
if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
sqlite3_free(errmsg);
@ -741,11 +786,11 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage(
db = nullptr;
return {FAILED, {}};
}
MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index.";
MS_LOG(DEBUG) << "Get " << label_offset_ptr->size() << " records from index.";
sqlite3_free(errmsg);
}
// get labels from binary file
return GetLabelsFromBinaryFile(shard_id, columns, label_offsets);
return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr);
}
std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int shard_id,
@ -760,17 +805,17 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
fields += columns[i] + "_" + std::to_string(schema_id);
}
if (fields.empty()) fields = "*";
std::vector<std::vector<std::string>> labels;
auto labels_ptr = std::make_shared<std::vector<std::vector<std::string>>>();
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id);
if (!criteria.first.empty()) {
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria";
if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) {
if (QueryWithCriteria(db, sql, criteria.second, labels_ptr) == FAILED) {
return {FAILED, {}};
}
} else {
sql += ";";
char *errmsg = nullptr;
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg);
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels_ptr.get(), &errmsg);
if (rc != SQLITE_OK) {
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg;
sqlite3_free(errmsg);
@ -778,13 +823,13 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
db = nullptr;
return {FAILED, {}};
} else {
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels.size()) << "records from index.";
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels_ptr->size()) << " records from index.";
}
sqlite3_free(errmsg);
}
std::vector<json> ret;
for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{});
for (unsigned int i = 0; i < labels.size(); ++i) {
for (unsigned int i = 0; i < labels_ptr->size(); ++i) ret.emplace_back(json{});
for (unsigned int i = 0; i < labels_ptr->size(); ++i) {
json construct_json;
for (unsigned int j = 0; j < columns.size(); ++j) {
// construct json "f1": value
@ -792,15 +837,15 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int
// convert the string to base type by schema
if (schema[columns[j]]["type"] == "int32") {
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]);
construct_json[columns[j]] = StringToNum<int32_t>((*labels_ptr)[i][j]);
} else if (schema[columns[j]]["type"] == "int64") {
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]);
construct_json[columns[j]] = StringToNum<int64_t>((*labels_ptr)[i][j]);
} else if (schema[columns[j]]["type"] == "float32") {
construct_json[columns[j]] = StringToNum<float>(labels[i][j]);
construct_json[columns[j]] = StringToNum<float>((*labels_ptr)[i][j]);
} else if (schema[columns[j]]["type"] == "float64") {
construct_json[columns[j]] = StringToNum<double>(labels[i][j]);
construct_json[columns[j]] = StringToNum<double>((*labels_ptr)[i][j]);
} else {
construct_json[columns[j]] = std::string(labels[i][j]);
construct_json[columns[j]] = std::string((*labels_ptr)[i][j]);
}
}
ret[i] = construct_json;
@ -834,7 +879,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
}
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES";
std::vector<std::thread> threads = std::vector<std::thread>(shard_count);
std::set<std::string> categories;
auto category_ptr = std::make_shared<std::set<std::string>>();
for (int x = 0; x < shard_count; x++) {
sqlite3 *db = nullptr;
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr);
@ -843,13 +888,13 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) {
<< sqlite3_errmsg(db);
return -1;
}
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories));
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr);
}
for (int x = 0; x < shard_count; x++) {
threads[x].join();
}
return categories.size();
return category_ptr->size();
}
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset,
@ -1008,8 +1053,7 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) {
return SUCCESS;
}
MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::shared_ptr<ShardOperator> &op) {
MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) {
CheckIfColumnInIndex(selected_columns_);
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
auto categories = category_op->GetCategories();
@ -1033,42 +1077,50 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
MS_LOG(ERROR) << "Invalid parameter, num_categories must be greater than 0, but got " << num_elements;
return FAILED;
}
std::set<std::string> categories_set;
auto ret = GetAllClasses(category_field, categories_set);
auto category_ptr = std::make_shared<std::set<std::string>>();
auto ret = GetAllClasses(category_field, category_ptr);
if (SUCCESS != ret) {
return FAILED;
}
int i = 0;
for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) {
for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) {
categories.emplace_back(category_field, *it);
i++;
}
}
// Generate task list, a task will create a batch
std::vector<ShardTask> categoryTasks(categories.size());
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) {
int category_index = 0;
for (const auto &rg : row_group_summary) {
if (category_index >= num_elements) break;
auto shard_id = std::get<0>(rg);
auto group_id = std::get<1>(rg);
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_);
if (SUCCESS != std::get<0>(details)) {
for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) {
auto res = GetPagesByCategory(shard_id, categories[categoryNo]);
if (SUCCESS != res.first) {
return FAILED;
}
auto offsets = std::get<4>(details);
auto number_of_rows = offsets.size();
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
if (category_index < num_elements) {
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart],
std::get<5>(details)[iStart]);
category_index++;
auto page_ids = res.second;
for (const auto &page_id : page_ids) {
if (category_index >= num_elements) break;
const auto &page_t = shard_header_->GetPage(shard_id, page_id);
const auto &page = page_t.first;
auto group_id = page->GetPageTypeID();
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_);
if (SUCCESS != std::get<0>(details)) {
return FAILED;
}
auto offsets = std::get<4>(details);
auto number_of_rows = offsets.size();
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) {
if (category_index < num_elements) {
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id,
std::get<4>(details)[iStart], std::get<5>(details)[iStart]);
category_index++;
}
}
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
}
}
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
}
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
if (SUCCESS != (*category_op)(tasks_)) {
@ -1189,7 +1241,7 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
}
}
} else {
if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) {
if (SUCCESS != CreateTasksByCategory(operators[category_operator])) {
return FAILED;
}
}