Implemented a multi-threaded index writer for mindrecord.
Follow-up fixes squashed in: the number of worker threads cannot be more than the number of shards; minor fix; clang style fix; addressed review comments.
This commit is contained in:
parent b35046f559
commit 5637f80692
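Note on the threading model introduced here: WriteToDatabase() no longer loops over the shards itself. It spawns up to min(hardware_concurrency() / 2 + 1, shard_count) worker threads, and each worker claims the next shard index from the atomic counter task_, writes that shard's index database in DatabaseWriter(), and sets the shared flag write_success_ to false on failure so the remaining workers stop early. Below is a minimal, self-contained sketch of that claim-by-atomic-counter pattern; it is not MindRecord code, and kShardCount, ProcessShard, and Worker are hypothetical stand-ins for the real members.

#include <algorithm>
#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

namespace {
constexpr int kShardCount = 4;          // hypothetical shard count
std::atomic_int task(0);                // next shard index to claim (mirrors task_)
std::atomic_bool write_success(true);   // cleared by any failing worker (mirrors write_success_)

// Hypothetical stand-in for writing one shard's index database.
bool ProcessShard(int shard_no) {
  std::cout << "writing index for shard " << shard_no << "\n";
  return true;
}

// Worker loop: claim a shard via the atomic counter, process it, repeat.
void Worker() {
  int shard_no = task++;
  while (shard_no < kShardCount) {
    if (!write_success || !ProcessShard(shard_no)) {
      write_success = false;  // signal the other workers to stop early
      return;
    }
    shard_no = task++;
  }
}
}  // namespace

int main() {
  // Same sizing rule as the diff: half the hardware threads plus one, capped by the shard count.
  const unsigned int num_workers =
      std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(kShardCount));
  std::vector<std::thread> threads;
  threads.reserve(num_workers);
  for (unsigned int t = 0; t < num_workers; ++t) {
    threads.emplace_back(Worker);
  }
  for (auto &th : threads) {
    th.join();
  }
  return write_success ? 0 : 1;
}

With this sizing rule the worker count never exceeds the shard count; for example, with hardware_concurrency() == 8 and 4 shards, min(8 / 2 + 1, 4) evaluates to 4 workers.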
@@ -85,14 +85,14 @@ class ShardIndexGenerator {
   /// \param sql
   /// \param data
   /// \return
-  MSRStatus BindParamaterExecuteSQL(
+  MSRStatus BindParameterExecuteSQL(
     sqlite3 *db, const std::string &sql,
     const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data);

   INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail);

-  MSRStatus ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
-                              const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
+  MSRStatus ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
+                               const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);

   MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name);

@@ -103,12 +103,16 @@ class ShardIndexGenerator {
   void AddIndexFieldByRawData(const std::vector<json> &schema_detail,
                               std::vector<std::tuple<std::string, std::string, std::string>> &row_data);

+  void DatabaseWriter();  // worker thread
+
   std::string file_path_;
   bool append_;
   ShardHeader shard_header_;
   uint64_t page_size_;
   uint64_t header_size_;
   int schema_count_;
+  std::atomic_int task_;
+  std::atomic_bool write_success_;
   std::vector<std::pair<uint64_t, std::string>> fields_;
 };
 }  // namespace mindrecord
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <thread>

 #include "mindrecord/include/shard_index_generator.h"
 #include "common/utils.h"
@@ -26,7 +27,13 @@ using mindspore::MsLogLevel::INFO;
 namespace mindspore {
 namespace mindrecord {
 ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
-    : file_path_(file_path), append_(append), page_size_(0), header_size_(0), schema_count_(0) {}
+    : file_path_(file_path),
+      append_(append),
+      page_size_(0),
+      header_size_(0),
+      schema_count_(0),
+      task_(0),
+      write_success_(true) {}

 MSRStatus ShardIndexGenerator::Build() {
   ShardHeader header = ShardHeader();
@@ -284,7 +291,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL(
   return {SUCCESS, sql};
 }

-MSRStatus ShardIndexGenerator::BindParamaterExecuteSQL(
+MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
   sqlite3 *db, const std::string &sql,
   const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
   sqlite3_stmt *stmt = nullptr;
@@ -471,9 +478,9 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
   return {SUCCESS, std::move(fields)};
 }

-MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
-                                                 const std::vector<int> &raw_page_ids,
-                                                 const std::map<int, int> &blob_id_to_page_id) {
+MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
+                                                  const std::vector<int> &raw_page_ids,
+                                                  const std::map<int, int> &blob_id_to_page_id) {
   // Add index data to database
   std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
   if (shard_address.empty()) {
@@ -493,7 +500,7 @@ MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std:
   if (data.first != SUCCESS) {
     return FAILED;
   }
-  if (BindParamaterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
+  if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
     return FAILED;
   }
   MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
@@ -514,37 +521,62 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
   page_size_ = shard_header_.get_page_size();
   header_size_ = shard_header_.get_header_size();
   schema_count_ = shard_header_.get_schema_count();
-  if (shard_header_.get_shard_count() <= kMaxShardCount) {
-    // Create one database per shard
-    for (int shard_no = 0; shard_no < shard_header_.get_shard_count(); ++shard_no) {
-      // Create database
-      auto db = CreateDatabase(shard_no);
-      if (db.first != SUCCESS || db.second == nullptr) {
-        return FAILED;
-      }
-      MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
-
-      // Pre-processing page information
-      auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
-
-      std::map<int, int> blob_id_to_page_id;
-      std::vector<int> raw_page_ids;
-      for (uint64_t i = 0; i < total_pages; ++i) {
-        std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
-        if (cur_page->get_page_type() == "RAW_DATA") {
-          raw_page_ids.push_back(i);
-        } else if (cur_page->get_page_type() == "BLOB_DATA") {
-          blob_id_to_page_id[cur_page->get_page_type_id()] = i;
-        }
-      }
-
-      if (ExcuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
-        return FAILED;
-      }
-      MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
-    }
+  if (shard_header_.get_shard_count() > kMaxShardCount) {
+    MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount;
+    return FAILED;
+  }
+  task_ = 0;  // set two atomic vars to initial value
+  write_success_ = true;
+
+  // spawn half the physical threads or total number of shards whichever is smaller
+  const unsigned int num_workers =
+    std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
+
+  std::vector<std::thread> threads;
+  threads.reserve(num_workers);
+
+  for (size_t t = 0; t < threads.capacity(); t++) {
+    threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
+  }
+
+  for (size_t t = 0; t < threads.capacity(); t++) {
+    threads[t].join();
+  }
+  return write_success_ ? SUCCESS : FAILED;
+}
+
+void ShardIndexGenerator::DatabaseWriter() {
+  int shard_no = task_++;
+  while (shard_no < shard_header_.get_shard_count()) {
+    auto db = CreateDatabase(shard_no);
+    if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
+      write_success_ = false;
+      return;
+    }
+
+    MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
+
+    // Pre-processing page information
+    auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
+
+    std::map<int, int> blob_id_to_page_id;
+    std::vector<int> raw_page_ids;
+    for (uint64_t i = 0; i < total_pages; ++i) {
+      std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
+      if (cur_page->get_page_type() == "RAW_DATA") {
+        raw_page_ids.push_back(i);
+      } else if (cur_page->get_page_type() == "BLOB_DATA") {
+        blob_id_to_page_id[cur_page->get_page_type_id()] = i;
+      }
+    }
+
+    if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
+      write_success_ = false;
+      return;
+    }
+    MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
+    shard_no = task_++;
   }
-  return SUCCESS;
 }
 }  // namespace mindrecord
 }  // namespace mindspore
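For context, a hedged caller-side sketch of how the generator is typically driven after this change, with error handling elided; the helper name BuildIndex and the append == false choice are assumptions for illustration, not part of this commit:

#include <string>

#include "mindrecord/include/shard_index_generator.h"

// Hypothetical helper: build the shard header, then write the per-shard index databases.
mindspore::mindrecord::MSRStatus BuildIndex(const std::string &file_path) {
  mindspore::mindrecord::ShardIndexGenerator generator(file_path, false);  // append == false (assumed)
  if (generator.Build() != mindspore::mindrecord::SUCCESS) {
    return mindspore::mindrecord::FAILED;
  }
  // WriteToDatabase() now fans out to the DatabaseWriter() worker threads added by this commit.
  return generator.WriteToDatabase();
}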