!48090 reduce memory for mindrecord in lazy mode

Merge pull request !48090 from guozhijian/reduce_mindrecord_memory
This commit is contained in:
i-robot 2023-01-30 03:24:40 +00:00 committed by Gitee
commit 232bc67d26
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 87 additions and 21 deletions

View File

@ -32,9 +32,20 @@ namespace mindrecord {
// The data struct is as below:
// 1. TaskType: kCommonTask / kPaddedTask
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
// 3. std::vector<uint64_t> : [blob_start, blob_end]
// 4. json : scalar_variable_fields
using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
// The data struct is as below:
// 1. TaskType: kCommonTask / kPaddedTask
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
using TaskInfo = std::tuple<TaskType, std::tuple<int, int>>;
// The data struct is as below: contain the meta info
// 3. std::vector<uint64_t> : [blob_start, blob_end]
// 4. json : scalar_variable_fields
using SampleMeta = std::tuple<std::vector<uint64_t>, json>;
class MINDRECORD_API ShardTaskList {
public:
ShardTaskList();
@ -72,9 +83,9 @@ class MINDRECORD_API ShardTaskList {
int64_t SizeOfRows() const;
ShardTask &GetTaskByID(int64_t id);
ShardTask GetTaskByID(int64_t id);
ShardTask &GetRandomTask();
ShardTask GetRandomTask();
int64_t GetTaskSampleByID(int64_t id);
@ -91,7 +102,16 @@ class MINDRECORD_API ShardTaskList {
std::vector<int64_t> sample_ids_; // The list of actual ids that were sampled
std::vector<ShardTask> task_list_; // The full list of tasks
// fast mode: [{TaskType, (shard_id, group_id(fast load))}, ...]
// lazy mode: [{TaskType, (shard_id, sample_id(lazy load))}, ...]
std::vector<TaskInfo> task_list_;
// fast mode: [{[blob_start, blob_end], json}, ...]
// lazy mode: none
std::vector<SampleMeta> sample_meta_list_;
// load type: fast mode or lazy mode
bool lazy_load_;
};
inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, int64_t id) {
@ -106,27 +126,45 @@ inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int grou
const std::vector<uint64_t> &offset, const json &label) {
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id));
if (lazy_load_ == false) {
sample_meta_list_.emplace_back(offset, label);
}
}
inline void ShardTaskList::InsertTask(const int64_t &i, TaskType task_type, int shard_id, int group_id,
const std::vector<uint64_t> &offset, const json &label) {
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id)};
if (lazy_load_ == false) {
sample_meta_list_[i] = {offset, label};
}
}
inline void ShardTaskList::InsertTask(ShardTask task) {
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
<< ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
<< ", size of task_list_: " << task_list_.size() << ".";
task_list_.push_back(std::move(task));
task_list_.push_back({std::get<0>(task), std::get<1>(task)});
if (lazy_load_ == false) {
sample_meta_list_.push_back({std::get<2>(task), std::get<3>(task)});
}
}
inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) { task_list_[i] = std::move(task); }
// Overwrite the task slot at index i. The combined ShardTask tuple is split:
// elements 0-1 (TaskType, (shard_id, group_id/sample_id)) go into task_list_,
// and elements 2-3 ([blob_start, blob_end], label json) go into
// sample_meta_list_ — but only in fast-load mode, since lazy mode does not
// keep per-sample meta in memory.
// NOTE(review): assumes i is a valid index into both vectors — no bounds check.
inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) {
task_list_[i] = {std::get<0>(task), std::get<1>(task)};
if (lazy_load_ == false) {
sample_meta_list_[i] = {std::get<2>(task), std::get<3>(task)};
}
}
inline void ShardTaskList::ResizeTask(const int64_t &size) { task_list_.resize(size); }
// Resize the task list to `size`. In fast-load mode the parallel
// sample_meta_list_ is resized as well so the two vectors stay the same
// length; in lazy mode sample_meta_list_ is unused and left untouched.
inline void ShardTaskList::ResizeTask(const int64_t &size) {
task_list_.resize(size);
if (lazy_load_ == false) {
sample_meta_list_.resize(size);
}
}
} // namespace mindrecord
} // namespace mindspore

View File

@ -110,6 +110,7 @@ Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_d
if (num_rows_ > LAZY_LOAD_THRESHOLD) {
lazy_load_ = true;
tasks_.lazy_load_ = true;
MS_LOG(WARNING)
<< "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
<< ", enable lazy load mode. If you want to speed up data loading, "

View File

@ -20,13 +20,15 @@
namespace mindspore {
namespace mindrecord {
ShardTaskList::ShardTaskList() : categories(1) {}
ShardTaskList::ShardTaskList() : categories(1), lazy_load_(false) {}
ShardTaskList::ShardTaskList(const ShardTaskList &other)
: categories(other.categories),
permutation_(other.permutation_),
sample_ids_(other.sample_ids_),
task_list_(other.task_list_) {}
task_list_(other.task_list_),
sample_meta_list_(other.sample_meta_list_),
lazy_load_(other.lazy_load_) {}
ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
ShardTaskList tmp(other);
@ -34,6 +36,8 @@ ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
permutation_.swap(tmp.permutation_);
sample_ids_.swap(tmp.sample_ids_);
task_list_.swap(tmp.task_list_);
sample_meta_list_.swap(tmp.sample_meta_list_);
lazy_load_ = tmp.lazy_load_;
return *this;
}
@ -68,22 +72,39 @@ void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_t
std::swap(orig_tasks.sample_ids_, new_tasks.sample_ids_);
}
void ShardTaskList::PopBack() { task_list_.pop_back(); }
// Remove the last task. In fast-load mode the matching sample meta entry is
// popped too, keeping task_list_ and sample_meta_list_ in sync.
// NOTE(review): assumes the lists are non-empty — pop_back on an empty
// vector is undefined behavior; verify callers guarantee this.
void ShardTaskList::PopBack() {
task_list_.pop_back();
if (lazy_load_ == false) {
sample_meta_list_.pop_back();
}
}
int64_t ShardTaskList::Size() const { return static_cast<int64_t>(task_list_.size()); }
int64_t ShardTaskList::SizeOfRows() const {
int64_t size_of_rows = 0;
if (task_list_.size() == 0) {
return static_cast<int64_t>(0);
return size_of_rows;
}
// 1 task is 1 page,blob index start from 2
auto sum_num_rows = [](int64_t x, ShardTask y) { return x + std::get<2>(y)[0]; };
int64_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
return nRows;
if (lazy_load_ == false) {
// 1 task corresponds to 1 page; the blob index starts from 2
auto sum_num_rows = [](int64_t x, SampleMeta y) { return x + std::get<0>(y)[0]; };
size_of_rows = std::accumulate(sample_meta_list_.begin(), sample_meta_list_.end(), 0, sum_num_rows);
} else {
MS_LOG(WARNING) << "In lazy load mode, size of rows will be " << size_of_rows << " which is not correctly.";
}
return size_of_rows;
}
ShardTask &ShardTaskList::GetTaskByID(int64_t id) { return task_list_[id]; }
// Rebuild a full ShardTask for task `id` by value (the old by-reference
// accessor is gone because the tuple is now stored in two parallel vectors).
// Fast mode: merge the stored SampleMeta back in. Lazy mode: the meta slots
// are returned empty (empty offset vector, empty json) since no meta is kept.
// NOTE(review): assumes id is a valid index — no bounds check.
ShardTask ShardTaskList::GetTaskByID(int64_t id) {
if (lazy_load_ == false) {
return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), std::get<0>(sample_meta_list_[id]),
std::get<1>(sample_meta_list_[id])};
} else {
return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), {}, json()};
}
}
int64_t ShardTaskList::GetTaskSampleByID(int64_t id) { return sample_ids_[id]; }
@ -93,10 +114,16 @@ int64_t ShardTaskList::GetRandomTaskID() {
return dis(gen);
}
ShardTask &ShardTaskList::GetRandomTask() {
ShardTask ShardTaskList::GetRandomTask() {
std::mt19937 gen = GetRandomDevice();
std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
return task_list_[dis(gen)];
size_t random = dis(gen);
if (lazy_load_ == false) {
return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), std::get<0>(sample_meta_list_[random]),
std::get<1>(sample_meta_list_[random])};
} else {
return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), {}, json()};
}
}
ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,