forked from mindspore-Ecosystem/mindspore
!48090 reduce memory for mindrecord in lazy mode
Merge pull request !48090 from guozhijian/reduce_mindrecord_memory
This commit is contained in:
commit
232bc67d26
|
@ -32,9 +32,20 @@ namespace mindrecord {
|
|||
// The data struct is as below:
|
||||
// 1. TaskType: kCommonTask / kPaddedTask
|
||||
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
|
||||
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
|
||||
// 3. std::vector<uint64_t> : [blob_start, blob_end]
|
||||
// 4. json : scalar_variable_fields
|
||||
using ShardTask = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;
|
||||
|
||||
// The data struct is as below:
|
||||
// 1. TaskType: kCommonTask / kPaddedTask
|
||||
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
|
||||
using TaskInfo = std::tuple<TaskType, std::tuple<int, int>>;
|
||||
|
||||
// The data struct is as below (it contains the sample meta info):
|
||||
// 3. std::vector<uint64_t> : [blob_start, blob_end]
|
||||
// 4. json : scalar_variable_fields
|
||||
using SampleMeta = std::tuple<std::vector<uint64_t>, json>;
|
||||
|
||||
class MINDRECORD_API ShardTaskList {
|
||||
public:
|
||||
ShardTaskList();
|
||||
|
@ -72,9 +83,9 @@ class MINDRECORD_API ShardTaskList {
|
|||
|
||||
int64_t SizeOfRows() const;
|
||||
|
||||
ShardTask &GetTaskByID(int64_t id);
|
||||
ShardTask GetTaskByID(int64_t id);
|
||||
|
||||
ShardTask &GetRandomTask();
|
||||
ShardTask GetRandomTask();
|
||||
|
||||
int64_t GetTaskSampleByID(int64_t id);
|
||||
|
||||
|
@ -91,7 +102,16 @@ class MINDRECORD_API ShardTaskList {
|
|||
|
||||
std::vector<int64_t> sample_ids_; // The list of actual ids that were sampled
|
||||
|
||||
std::vector<ShardTask> task_list_; // The full list of tasks
|
||||
// fast mode: [{TaskType, (shard_id, group_id(fast load))}, ...]
|
||||
// lazy mode: [{TaskType, (shard_id, sample_id(lazy load))}, ...]
|
||||
std::vector<TaskInfo> task_list_;
|
||||
|
||||
// fast mode: [{[blob_start, blob_end], json}, ...]
|
||||
// lazy mode: none
|
||||
std::vector<SampleMeta> sample_meta_list_;
|
||||
|
||||
// load type: fast mode or lazy mode
|
||||
bool lazy_load_;
|
||||
};
|
||||
|
||||
inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, int64_t id) {
|
||||
|
@ -106,27 +126,45 @@ inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int grou
|
|||
const std::vector<uint64_t> &offset, const json &label) {
|
||||
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
|
||||
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
||||
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label);
|
||||
task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id));
|
||||
if (lazy_load_ == false) {
|
||||
sample_meta_list_.emplace_back(offset, label);
|
||||
}
|
||||
}
|
||||
|
||||
inline void ShardTaskList::InsertTask(const int64_t &i, TaskType task_type, int shard_id, int group_id,
|
||||
const std::vector<uint64_t> &offset, const json &label) {
|
||||
MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id
|
||||
<< ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << ".";
|
||||
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label};
|
||||
task_list_[i] = {task_type, std::make_tuple(shard_id, group_id)};
|
||||
if (lazy_load_ == false) {
|
||||
sample_meta_list_[i] = {offset, label};
|
||||
}
|
||||
}
|
||||
|
||||
// Appends a complete ShardTask to the list.
// `task` is taken by value: its (TaskType, (shard_id, group_id)) part is
// appended to task_list_, and, only when lazy load is off, its
// (offset vector, json label) part is appended to sample_meta_list_.
inline void ShardTaskList::InsertTask(ShardTask task) {
  MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task))
                << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump()
                << ", size of task_list_: " << task_list_.size() << ".";

  task_list_.emplace_back(std::get<0>(task), std::get<1>(task));
  if (!lazy_load_) {
    // `task` is our own copy and is not used after this point, so move the
    // heavy members (offset vector and json label) instead of copying them.
    sample_meta_list_.emplace_back(std::move(std::get<2>(task)), std::move(std::get<3>(task)));
  }
}
|
||||
|
||||
inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) { task_list_[i] = std::move(task); }
|
||||
// Overwrites slot `i` with a complete ShardTask.
// The (TaskType, (shard_id, group_id)) part goes to task_list_[i]; the
// (offset vector, json label) part goes to sample_meta_list_[i] only when lazy
// load is off. No bounds check is performed on `i`.
inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) {
  task_list_[i] = TaskInfo(std::get<0>(task), std::get<1>(task));
  if (!lazy_load_) {
    // `task` is a by-value parameter that is not used afterwards: move the
    // offset vector and json label rather than copying them.
    sample_meta_list_[i] = SampleMeta(std::move(std::get<2>(task)), std::move(std::get<3>(task)));
  }
}
|
||||
|
||||
inline void ShardTaskList::ResizeTask(const int64_t &size) { task_list_.resize(size); }
|
||||
inline void ShardTaskList::ResizeTask(const int64_t &size) {
|
||||
task_list_.resize(size);
|
||||
if (lazy_load_ == false) {
|
||||
sample_meta_list_.resize(size);
|
||||
}
|
||||
}
|
||||
} // namespace mindrecord
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -110,6 +110,7 @@ Status ShardReader::Init(const std::vector<std::string> &file_paths, bool load_d
|
|||
|
||||
if (num_rows_ > LAZY_LOAD_THRESHOLD) {
|
||||
lazy_load_ = true;
|
||||
tasks_.lazy_load_ = true;
|
||||
MS_LOG(WARNING)
|
||||
<< "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
|
||||
<< ", enable lazy load mode. If you want to speed up data loading, "
|
||||
|
|
|
@ -20,13 +20,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace mindrecord {
|
||||
ShardTaskList::ShardTaskList() : categories(1) {}
|
||||
// Default constructor: initializes `categories` to 1 and turns lazy load off.
ShardTaskList::ShardTaskList() : categories(1), lazy_load_(false) {}
|
||||
|
||||
// Copy constructor: copies each listed member of `other`, including the
// per-sample meta list and the lazy load flag.
ShardTaskList::ShardTaskList(const ShardTaskList &other)
    : categories(other.categories),
      permutation_(other.permutation_),
      sample_ids_(other.sample_ids_),
      task_list_(other.task_list_),
      sample_meta_list_(other.sample_meta_list_),
      lazy_load_(other.lazy_load_) {}
|
||||
|
||||
ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
|
||||
ShardTaskList tmp(other);
|
||||
|
@ -34,6 +36,8 @@ ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) {
|
|||
permutation_.swap(tmp.permutation_);
|
||||
sample_ids_.swap(tmp.sample_ids_);
|
||||
task_list_.swap(tmp.task_list_);
|
||||
sample_meta_list_.swap(tmp.sample_meta_list_);
|
||||
lazy_load_ = tmp.lazy_load_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -68,22 +72,39 @@ void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_t
|
|||
std::swap(orig_tasks.sample_ids_, new_tasks.sample_ids_);
|
||||
}
|
||||
|
||||
void ShardTaskList::PopBack() { task_list_.pop_back(); }
|
||||
void ShardTaskList::PopBack() {
|
||||
task_list_.pop_back();
|
||||
if (lazy_load_ == false) {
|
||||
sample_meta_list_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the number of entries in task_list_, converted to int64_t.
int64_t ShardTaskList::Size() const { return static_cast<int64_t>(task_list_.size()); }
|
||||
|
||||
int64_t ShardTaskList::SizeOfRows() const {
|
||||
int64_t size_of_rows = 0;
|
||||
if (task_list_.size() == 0) {
|
||||
return static_cast<int64_t>(0);
|
||||
return size_of_rows;
|
||||
}
|
||||
|
||||
// 1 task is 1 page,blob index start from 2
|
||||
auto sum_num_rows = [](int64_t x, ShardTask y) { return x + std::get<2>(y)[0]; };
|
||||
int64_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows);
|
||||
return nRows;
|
||||
if (lazy_load_ == false) {
|
||||
// 1 task is 1 page,blob index start from 2
|
||||
auto sum_num_rows = [](int64_t x, SampleMeta y) { return x + std::get<0>(y)[0]; };
|
||||
size_of_rows = std::accumulate(sample_meta_list_.begin(), sample_meta_list_.end(), 0, sum_num_rows);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "In lazy load mode, size of rows will be " << size_of_rows << " which is not correctly.";
|
||||
}
|
||||
return size_of_rows;
|
||||
}
|
||||
|
||||
ShardTask &ShardTaskList::GetTaskByID(int64_t id) { return task_list_[id]; }
|
||||
// Builds and returns, by value, the ShardTask stored at index `id`.
// With lazy load off, the task info is combined with the matching entry of
// sample_meta_list_. With lazy load on, the offset vector is empty and the
// label is a default-constructed json. No bounds check is performed on `id`.
ShardTask ShardTaskList::GetTaskByID(int64_t id) {
  const TaskInfo &info = task_list_[id];
  if (lazy_load_) {
    return ShardTask(std::get<0>(info), std::get<1>(info), std::vector<uint64_t>(), json());
  }
  const SampleMeta &meta = sample_meta_list_[id];
  return ShardTask(std::get<0>(info), std::get<1>(info), std::get<0>(meta), std::get<1>(meta));
}
|
||||
|
||||
// Returns sample_ids_[id]. No bounds check is performed on `id`.
int64_t ShardTaskList::GetTaskSampleByID(int64_t id) { return sample_ids_[id]; }
|
||||
|
||||
|
@ -93,10 +114,16 @@ int64_t ShardTaskList::GetRandomTaskID() {
|
|||
return dis(gen);
|
||||
}
|
||||
|
||||
ShardTask &ShardTaskList::GetRandomTask() {
|
||||
// Picks an index uniformly from [0, task_list_.size() - 1] using the generator
// returned by GetRandomDevice(), and returns the ShardTask at that index by value.
// With lazy load on, the returned task has an empty offset vector and a
// default-constructed json label.
// NOTE(review): the list is not checked for emptiness before the distribution
// is built - confirm callers never invoke this on an empty list.
ShardTask ShardTaskList::GetRandomTask() {
  std::mt19937 gen = GetRandomDevice();
  std::uniform_int_distribution<> dis(0, task_list_.size() - 1);
  const size_t picked = dis(gen);

  const TaskInfo &info = task_list_[picked];
  if (lazy_load_) {
    return ShardTask(std::get<0>(info), std::get<1>(info), std::vector<uint64_t>(), json());
  }
  const SampleMeta &meta = sample_meta_list_[picked];
  return ShardTask(std::get<0>(info), std::get<1>(info), std::get<0>(meta), std::get<1>(meta));
}
|
||||
|
||||
ShardTaskList ShardTaskList::Combine(std::vector<ShardTaskList> &category_tasks, bool replacement, int64_t num_elements,
|
||||
|
|
Loading…
Reference in New Issue