forked from mindspore-Ecosystem/mindspore
enhance: use multiple threads to load metadata in MindRecord
This commit is contained in:
parent
ecc9f00c3c
commit
cfc826feed
|
@ -21,6 +21,7 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/mindrecord/include/common/shard_utils.h"
|
||||
|
||||
|
@ -38,10 +39,16 @@ class ShardTask {
|
|||
|
||||
void MakePerm();
|
||||
|
||||
void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
||||
const json &label);
|
||||
inline void InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
|
||||
const json &label);
|
||||
|
||||
void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
||||
inline void InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label);
|
||||
|
||||
inline void InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
||||
|
||||
inline void InsertTask(const uint32_t &i,
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task);
|
||||
|
||||
void PopBack();
|
||||
|
||||
|
@ -56,12 +63,41 @@ class ShardTask {
|
|||
static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
|
||||
int64_t num_samples);
|
||||
|
||||
inline void ResizeTask(const uint32_t &size);
|
||||
|
||||
uint32_t categories;
|
||||
|
||||
std::vector<int> permutation_;
|
||||
|
||||
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
|
||||
};
|
||||
|
||||
// Appends a new task to the end of task_list_.
//
// task_type: kind of the task being recorded.
// shard_id, group_id: stored together as the (shard_id, group_id) pair of the task.
// offset: offset values stored with the task (copied).
// label: json label stored with the task (copied).
inline void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
                                  const json &label) {
  MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
  // Build the (shard_id, group_id) pair first, then construct the task in place at the back of the list.
  auto location = std::make_tuple(shard_id, group_id);
  task_list_.emplace_back(task_type, std::move(location), offset, label);
}
|
||||
|
||||
inline void ShardTask::InsertTask(const uint32_t &i, TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label) {
|
||||
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
|
||||
}
|
||||
|
||||
// Appends an already assembled task tuple to the end of task_list_.
//
// task is taken by value and moved into the list, so the caller's copy (if any) is left untouched.
// Tuple layout: <task type, (shard_id, group_id), offset values, json label>.
inline void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
  // Log before moving: the tuple is still fully populated at this point.
  MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
                << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
                << ", size of task_list_: " << task_list_.size() << ".";

  task_list_.emplace_back(std::move(task));
}
|
||||
|
||||
inline void ShardTask::InsertTask(const uint32_t &i,
|
||||
std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
|
||||
task_list_[i] = std::move(task);
|
||||
}
|
||||
|
||||
// Resizes task_list_ so that it holds exactly `size` entries (forwards to std::vector::resize).
inline void ShardTask::ResizeTask(const uint32_t &size) {
  task_list_.resize(size);
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
|
||||
#include "minddata/mindrecord/include/shard_distributed_sample.h"
|
||||
#include "minddata/mindrecord/include/shard_reader.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
@ -1036,15 +1039,37 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
|
|||
if (std::get<0>(ret) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
auto offsets = std::get<1>(ret);
|
||||
auto local_columns = std::get<2>(ret);
|
||||
auto &offsets = std::get<1>(ret);
|
||||
auto &local_columns = std::get<2>(ret);
|
||||
if (shard_count_ <= kMaxFileCount) {
|
||||
int sample_count = 0;
|
||||
for (int shard_id = 0; shard_id < shard_count_; shard_id++) {
|
||||
for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
|
||||
tasks_.InsertTask(TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
|
||||
std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]},
|
||||
local_columns[shard_id][i]);
|
||||
}
|
||||
sample_count += offsets[shard_id].size();
|
||||
}
|
||||
MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset.";
|
||||
|
||||
// Init the tasks_ size
|
||||
tasks_.ResizeTask(sample_count);
|
||||
|
||||
// Init the task threads; using a ThreadPool might be better
|
||||
std::vector<std::thread> init_tasks_thread(shard_count_);
|
||||
|
||||
uint32_t current_offset = 0;
|
||||
for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
|
||||
init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() {
|
||||
auto offset = current_offset;
|
||||
for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) {
|
||||
tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1],
|
||||
std::vector<uint64_t>{offsets[shard_id][i][2], offsets[shard_id][i][3]},
|
||||
local_columns[shard_id][i]);
|
||||
offset++;
|
||||
}
|
||||
});
|
||||
current_offset += offsets[shard_id].size();
|
||||
}
|
||||
|
||||
for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
|
||||
init_tasks_thread[shard_id].join();
|
||||
}
|
||||
} else {
|
||||
return FAILED;
|
||||
|
|
|
@ -44,21 +44,6 @@ void ShardTask::MakePerm() {
|
|||
}
|
||||
}
|
||||
|
||||
// Appends a new task to the end of task_list_.
//
// task_type: kind of the task being recorded.
// shard_id, group_id: stored together as the (shard_id, group_id) pair of the task.
// offset: offset values stored with the task (copied).
// label: json label stored with the task (copied).
void ShardTask::InsertTask(TaskType task_type, int shard_id, int group_id, const std::vector<uint64_t> &offset,
                           const json &label) {
  MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id
                << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
  // Build the (shard_id, group_id) pair first, then construct the task in place at the back of the list.
  auto location = std::make_tuple(shard_id, group_id);
  task_list_.emplace_back(task_type, std::move(location), offset, label);
}
|
||||
|
||||
// Appends an already assembled task tuple to the end of task_list_.
//
// task is taken by value and moved into the list.
// Tuple layout: <task type, (shard_id, group_id), offset values, json label>.
void ShardTask::InsertTask(std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> task) {
  // Log before moving: the tuple is still fully populated at this point.
  MS_LOG(DEBUG) << "Into insert task, shard_id: " << std::get<0>(std::get<1>(task))
                << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
                << ", size of task_list_: " << task_list_.size() << ".";

  task_list_.emplace_back(std::move(task));
}
|
||||
|
||||
// Removes the last task from task_list_.
//
// Calling std::vector::pop_back on an empty vector is undefined behaviour, so an empty list is
// left untouched instead of being popped.
void ShardTask::PopBack() {
  if (task_list_.empty()) {
    return;
  }
  task_list_.pop_back();
}
|
||||
|
||||
// Returns the number of tasks currently held in task_list_, converted to uint32_t.
uint32_t ShardTask::Size() const {
  const auto task_count = task_list_.size();
  return static_cast<uint32_t>(task_count);
}
|
||||
|
|
Loading…
Reference in New Issue