Implemented a multi-threaded index writer for mindrecord

The number of worker threads cannot exceed the number of shards

minor fix

clang style fix

address review comments
This commit is contained in:
Zirui Wu 2020-04-01 11:24:25 -04:00
parent b35046f559
commit 5637f80692
2 changed files with 75 additions and 39 deletions

View File

@ -85,14 +85,14 @@ class ShardIndexGenerator {
/// \param sql /// \param sql
/// \param data /// \param data
/// \return /// \return
MSRStatus BindParamaterExecuteSQL( MSRStatus BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql, sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data); const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data);
INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail); INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail);
MSRStatus ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db, MSRStatus ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id); const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id);
MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name);
@ -103,12 +103,16 @@ class ShardIndexGenerator {
void AddIndexFieldByRawData(const std::vector<json> &schema_detail, void AddIndexFieldByRawData(const std::vector<json> &schema_detail,
std::vector<std::tuple<std::string, std::string, std::string>> &row_data); std::vector<std::tuple<std::string, std::string, std::string>> &row_data);
void DatabaseWriter(); // worker thread
std::string file_path_; std::string file_path_;
bool append_; bool append_;
ShardHeader shard_header_; ShardHeader shard_header_;
uint64_t page_size_; uint64_t page_size_;
uint64_t header_size_; uint64_t header_size_;
int schema_count_; int schema_count_;
std::atomic_int task_;
std::atomic_bool write_success_;
std::vector<std::pair<uint64_t, std::string>> fields_; std::vector<std::pair<uint64_t, std::string>> fields_;
}; };
} // namespace mindrecord } // namespace mindrecord

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <thread>
#include "mindrecord/include/shard_index_generator.h" #include "mindrecord/include/shard_index_generator.h"
#include "common/utils.h" #include "common/utils.h"
@ -26,7 +27,13 @@ using mindspore::MsLogLevel::INFO;
namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append) ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool append)
: file_path_(file_path), append_(append), page_size_(0), header_size_(0), schema_count_(0) {} : file_path_(file_path),
append_(append),
page_size_(0),
header_size_(0),
schema_count_(0),
task_(0),
write_success_(true) {}
MSRStatus ShardIndexGenerator::Build() { MSRStatus ShardIndexGenerator::Build() {
ShardHeader header = ShardHeader(); ShardHeader header = ShardHeader();
@ -284,7 +291,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL(
return {SUCCESS, sql}; return {SUCCESS, sql};
} }
MSRStatus ShardIndexGenerator::BindParamaterExecuteSQL( MSRStatus ShardIndexGenerator::BindParameterExecuteSQL(
sqlite3 *db, const std::string &sql, sqlite3 *db, const std::string &sql,
const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) { const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) {
sqlite3_stmt *stmt = nullptr; sqlite3_stmt *stmt = nullptr;
@ -471,9 +478,9 @@ INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &s
return {SUCCESS, std::move(fields)}; return {SUCCESS, std::move(fields)};
} }
MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db, MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std::pair<MSRStatus, sqlite3 *> &db,
const std::vector<int> &raw_page_ids, const std::vector<int> &raw_page_ids,
const std::map<int, int> &blob_id_to_page_id) { const std::map<int, int> &blob_id_to_page_id) {
// Add index data to database // Add index data to database
std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); std::string shard_address = shard_header_.get_shard_address_by_id(shard_no);
if (shard_address.empty()) { if (shard_address.empty()) {
@ -493,7 +500,7 @@ MSRStatus ShardIndexGenerator::ExcuteTransaction(const int &shard_no, const std:
if (data.first != SUCCESS) { if (data.first != SUCCESS) {
return FAILED; return FAILED;
} }
if (BindParamaterExecuteSQL(db.second, sql.second, data.second) == FAILED) { if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) {
return FAILED; return FAILED;
} }
MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db.";
@ -514,37 +521,62 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() {
page_size_ = shard_header_.get_page_size(); page_size_ = shard_header_.get_page_size();
header_size_ = shard_header_.get_header_size(); header_size_ = shard_header_.get_header_size();
schema_count_ = shard_header_.get_schema_count(); schema_count_ = shard_header_.get_schema_count();
if (shard_header_.get_shard_count() <= kMaxShardCount) { if (shard_header_.get_shard_count() > kMaxShardCount) {
// Create one database per shard MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount;
for (int shard_no = 0; shard_no < shard_header_.get_shard_count(); ++shard_no) { return FAILED;
// Create database }
auto db = CreateDatabase(shard_no); task_ = 0; // set two atomic vars to initial value
if (db.first != SUCCESS || db.second == nullptr) { write_success_ = true;
return FAILED;
} // spawn half the physical threads or total number of shards whichever is smaller
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; const unsigned int num_workers =
std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count()));
// Pre-processing page information
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; std::vector<std::thread> threads;
threads.reserve(num_workers);
std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids; for (size_t t = 0; t < threads.capacity(); t++) {
for (uint64_t i = 0; i < total_pages; ++i) { threads.emplace_back(std::thread(&ShardIndexGenerator::DatabaseWriter, this));
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first; }
if (cur_page->get_page_type() == "RAW_DATA") {
raw_page_ids.push_back(i); for (size_t t = 0; t < threads.capacity(); t++) {
} else if (cur_page->get_page_type() == "BLOB_DATA") { threads[t].join();
blob_id_to_page_id[cur_page->get_page_type_id()] = i; }
} return write_success_ ? SUCCESS : FAILED;
} }
if (ExcuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { void ShardIndexGenerator::DatabaseWriter() {
return FAILED; int shard_no = task_++;
} while (shard_no < shard_header_.get_shard_count()) {
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully."; auto db = CreateDatabase(shard_no);
} if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) {
write_success_ = false;
return;
}
MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully.";
// Pre-processing page information
auto total_pages = shard_header_.GetLastPageId(shard_no) + 1;
std::map<int, int> blob_id_to_page_id;
std::vector<int> raw_page_ids;
for (uint64_t i = 0; i < total_pages; ++i) {
std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first;
if (cur_page->get_page_type() == "RAW_DATA") {
raw_page_ids.push_back(i);
} else if (cur_page->get_page_type() == "BLOB_DATA") {
blob_id_to_page_id[cur_page->get_page_type_id()] = i;
}
}
if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) {
write_success_ = false;
return;
}
MS_LOG(INFO) << "Generate index db for shard: " << shard_no << " successfully.";
shard_no = task_++;
} }
return SUCCESS;
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore