forked from mindspore-Ecosystem/mindspore
implemented multi-thread index writer for mindrecord
num threads cannot be more than num shards minor fix clang style fix address review comments
This commit is contained in:
parent
b35046f559
commit
5637f80692
|
@ -85,14 +85,14 @@ class ShardIndexGenerator {
|
||||||
/// \param sql
|
/// \param sql
|
||||||
/// \param data
|
/// \param data
|
||||||
/// \return
|
/// \return
|
||||||
MSRStatus BindParamaterExecuteSQL(
|
MSRStatus BindParameterExecuteSQL(
|
||||||
sqlite3 *db, const std::string &sql,
|
sqlite3 *db, const std::string &sql,
|
||||||
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data);
|
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data);
|
||||||
|
|
||||||
INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail);
|
INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail);
|
||||||
|
|
||||||
MSRStatus ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
|
MSRStatus ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
|
||||||
const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
|
const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
|
||||||
|
|
||||||
MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name);
|
MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name);
|
||||||
|
|
||||||
|
@ -103,12 +103,16 @@ class ShardIndexGenerator {
|
||||||
void AddIndexFieldByRawData(const std::vector<json> &schema_detail,
|
void AddIndexFieldByRawData(const std::vector<json> &schema_detail,
|
||||||
std::vector<std::tuple<std::string, std::string, std::string>> &row_data);
|
std::vector<std::tuple<std::string, std::string, std::string>> &row_data);
|
||||||
|
|
||||||
|
void DatabaseWriter(); // worker thread
|
||||||
|
|
||||||
std::string file_path_;
|
std::string file_path_;
|
||||||
bool append_;
|
bool append_;
|
||||||
ShardHeader shard_header_;
|
ShardHeader shard_header_;
|
||||||
uint64_t page_size_;
|
uint64_t page_size_;
|
||||||
uint64_t header_size_;
|
uint64_t header_size_;
|
||||||
int schema_count_;
|
int schema_count_;
|
||||||
|
std::atomic_int task_;
|
||||||
|
std::atomic_bool write_success_;
|
||||||
std::vector<std::pair<uint64_t, std::string>> fields_;
|
std::vector<std::pair<uint64_t, std::string>> fields_;
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
#include "mindrecord/include/shard_index_generator.h"
|
#include "mindrecord/include/shard_index_generator.h"
|
||||||
#include "common/utils.h"
|
#include "common/utils.h"
|
||||||
|
@ -26,7 +27,13 @@ using mindspore::MsLogLevel::INFO;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
|
ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
|
||||||
: file_path_(file_path), append_(append), page_size_(0), header_size_(0), schema_count_(0) {}
|
: file_path_(file_path),
|
||||||
|
append_(append),
|
||||||
|
page_size_(0),
|
||||||
|
header_size_(0),
|
||||||
|
schema_count_(0),
|
||||||
|
task_(0),
|
||||||
|
write_success_(true) {}
|
||||||
|
|
||||||
MSRStatus ShardIndexGenerator::Build() {
|
MSRStatus ShardIndexGenerator::Build() {
|
||||||
ShardHeader header = ShardHeader();
|
ShardHeader header = ShardHeader();
|
||||||
|
@ -284,7 +291,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL(
|
||||||
return {SUCCESS, sql};
|
return {SUCCESS, sql};
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardIndexGenerator::BindParamaterExecuteSQL(
|
MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
|
||||||
sqlite3 *db, const std::string &sql,
|
sqlite3 *db, const std::string &sql,
|
||||||
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
|
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
|
||||||
sqlite3_stmt *stmt = nullptr;
|
sqlite3_stmt *stmt = nullptr;
|
||||||
|
@ -471,9 +478,9 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
|
||||||
return {SUCCESS, std::move(fields)};
|
return {SUCCESS, std::move(fields)};
|
||||||
}
|
}
|
||||||
|
|
||||||
MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
|
MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
|
||||||
const std::vector<int> &raw_page_ids,
|
const std::vector<int> &raw_page_ids,
|
||||||
const std::map<int, int> &blob_id_to_page_id) {
|
const std::map<int, int> &blob_id_to_page_id) {
|
||||||
// Add index data to database
|
// Add index data to database
|
||||||
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
|
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
|
||||||
if (shard_address.empty()) {
|
if (shard_address.empty()) {
|
||||||
|
@ -493,7 +500,7 @@ MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std:
|
||||||
if (data.first != SUCCESS) {
|
if (data.first != SUCCESS) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
if (BindParamaterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
|
if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
|
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
|
||||||
|
@ -514,37 +521,62 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
|
||||||
page_size_ = shard_header_.get_page_size();
|
page_size_ = shard_header_.get_page_size();
|
||||||
header_size_ = shard_header_.get_header_size();
|
header_size_ = shard_header_.get_header_size();
|
||||||
schema_count_ = shard_header_.get_schema_count();
|
schema_count_ = shard_header_.get_schema_count();
|
||||||
if (shard_header_.get_shard_count() <= kMaxShardCount) {
|
if (shard_header_.get_shard_count() > kMaxShardCount) {
|
||||||
// Create one database per shard
|
MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount;
|
||||||
for (int shard_no = 0; shard_no < shard_header_.get_shard_count(); ++shard_no) {
|
return FAILED;
|
||||||
// Create database
|
}
|
||||||
auto db = CreateDatabase(shard_no);
|
task_ = 0; // set two atomic vars to initial value
|
||||||
if (db.first != SUCCESS || db.second == nullptr) {
|
write_success_ = true;
|
||||||
return FAILED;
|
|
||||||
}
|
// spawn half the physical threads or total number of shards whichever is smaller
|
||||||
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
|
const unsigned int num_workers =
|
||||||
|
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
|
||||||
// Pre-processing page information
|
|
||||||
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
|
std::vector<std::thread> threads;
|
||||||
|
threads.reserve(num_workers);
|
||||||
std::map<int, int> blob_id_to_page_id;
|
|
||||||
std::vector<int> raw_page_ids;
|
for (size_t t = 0; t < threads.capacity(); t++) {
|
||||||
for (uint64_t i = 0; i < total_pages; ++i) {
|
threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
|
||||||
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
|
}
|
||||||
if (cur_page->get_page_type() == "RAW_DATA") {
|
|
||||||
raw_page_ids.push_back(i);
|
for (size_t t = 0; t < threads.capacity(); t++) {
|
||||||
} else if (cur_page->get_page_type() == "BLOB_DATA") {
|
threads[t].join();
|
||||||
blob_id_to_page_id[cur_page->get_page_type_id()] = i;
|
}
|
||||||
}
|
return write_success_ ? SUCCESS : FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ExcuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
|
void ShardIndexGenerator::DatabaseWriter() {
|
||||||
return FAILED;
|
int shard_no = task_++;
|
||||||
}
|
while (shard_no < shard_header_.get_shard_count()) {
|
||||||
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
|
auto db = CreateDatabase(shard_no);
|
||||||
}
|
if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
|
||||||
|
write_success_ = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
|
||||||
|
|
||||||
|
// Pre-processing page information
|
||||||
|
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
|
||||||
|
|
||||||
|
std::map<int, int> blob_id_to_page_id;
|
||||||
|
std::vector<int> raw_page_ids;
|
||||||
|
for (uint64_t i = 0; i < total_pages; ++i) {
|
||||||
|
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
|
||||||
|
if (cur_page->get_page_type() == "RAW_DATA") {
|
||||||
|
raw_page_ids.push_back(i);
|
||||||
|
} else if (cur_page->get_page_type() == "BLOB_DATA") {
|
||||||
|
blob_id_to_page_id[cur_page->get_page_type_id()] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
|
||||||
|
write_success_ = false;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
|
||||||
|
shard_no = task_++;
|
||||||
}
|
}
|
||||||
return SUCCESS;
|
|
||||||
}
|
}
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue