diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task_list.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task_list.h index a5a32b0b3b9..a754499e21d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_task_list.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task_list.h @@ -32,9 +32,20 @@ namespace mindrecord { // The data struct is as below: // 1. TaskType: kCommonTask / kPaddedTask // 2. std::tuple : shard_id, group_id(fast load) / sample_id(lazy load) -// 3. std::vector, json>> : [blob_start, blob_end], scalar_variable_fields +// 3. std::vector : [blob_start, blob_end] +// 4. json : scalar_variable_fields using ShardTask = std::tuple, std::vector, json>; +// The data struct is as below: +// 1. TaskType: kCommonTask / kPaddedTask +// 2. std::tuple : shard_id, group_id(fast load) / sample_id(lazy load) +using TaskInfo = std::tuple>; + +// The data struct is as below: contain the meta info +// 3. std::vector : [blob_start, blob_end] +// 4. json : scalar_variable_fields +using SampleMeta = std::tuple, json>; + class MINDRECORD_API ShardTaskList { public: ShardTaskList(); @@ -72,9 +83,9 @@ class MINDRECORD_API ShardTaskList { int64_t SizeOfRows() const; - ShardTask &GetTaskByID(int64_t id); + ShardTask GetTaskByID(int64_t id); - ShardTask &GetRandomTask(); + ShardTask GetRandomTask(); int64_t GetTaskSampleByID(int64_t id); @@ -91,7 +102,16 @@ class MINDRECORD_API ShardTaskList { std::vector sample_ids_; // The list of actual ids that were sampled - std::vector task_list_; // The full list of tasks + // fast mode: [{TaskType, (shard_id, group_id(fast load))}, ...] + // lazy mode: [{TaskType, (shard_id, sample_id(lazy load))}, ...] + std::vector task_list_; + + // fast mode: [{[blob_start, blob_end], json}, ...] + // lazy mode: none + std::vector sample_meta_list_; + + // load type: fast mode or lazy mode + bool lazy_load_; }; inline void ShardTaskList::AssignTask(ShardTaskList &sourceTasks, int64_t id) { @@ -106,27 +126,45 @@ inline void ShardTaskList::InsertTask(TaskType task_type, int shard_id, int grou const std::vector &offset, const json &label) { MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; - task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id), offset, label); + task_list_.emplace_back(task_type, std::make_tuple(shard_id, group_id)); + if (lazy_load_ == false) { + sample_meta_list_.emplace_back(offset, label); + } } inline void ShardTaskList::InsertTask(const int64_t &i, TaskType task_type, int shard_id, int group_id, const std::vector &offset, const json &label) { MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << shard_id << ", group_id: " << group_id << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; - task_list_[i] = {task_type, std::make_tuple(shard_id, group_id), offset, label}; + task_list_[i] = {task_type, std::make_tuple(shard_id, group_id)}; + if (lazy_load_ == false) { + sample_meta_list_[i] = {offset, label}; + } } inline void ShardTaskList::InsertTask(ShardTask task) { MS_LOG(DEBUG) << "Insert task into task list, shard_id: " << std::get<0>(std::get<1>(task)) << ", group_id: " << std::get<1>(std::get<1>(task)) << ", label: " << std::get<3>(task).dump() << ", size of task_list_: " << task_list_.size() << "."; - - task_list_.push_back(std::move(task)); + task_list_.push_back({std::get<0>(task), std::get<1>(task)}); + if (lazy_load_ == false) { + sample_meta_list_.push_back({std::get<2>(task), std::get<3>(task)}); + } } -inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) { task_list_[i] = std::move(task); } +inline void ShardTaskList::InsertTask(const int64_t &i, ShardTask task) { + task_list_[i] = {std::get<0>(task), std::get<1>(task)}; + if (lazy_load_ == false) { + sample_meta_list_[i] = {std::get<2>(task), std::get<3>(task)}; + } +} -inline void ShardTaskList::ResizeTask(const int64_t &size) { task_list_.resize(size); } +inline void ShardTaskList::ResizeTask(const int64_t &size) { + task_list_.resize(size); + if (lazy_load_ == false) { + sample_meta_list_.resize(size); + } +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index 5de9903aea7..787b2f7e6be 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -110,6 +110,7 @@ Status ShardReader::Init(const std::vector &file_paths, bool load_d if (num_rows_ > LAZY_LOAD_THRESHOLD) { lazy_load_ = true; + tasks_.lazy_load_ = true; MS_LOG(WARNING) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD << ", enable lazy load mode. If you want to speed up data loading, " diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc index ea776c53356..89958536968 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc @@ -20,13 +20,15 @@ namespace mindspore { namespace mindrecord { -ShardTaskList::ShardTaskList() : categories(1) {} +ShardTaskList::ShardTaskList() : categories(1), lazy_load_(false) {} ShardTaskList::ShardTaskList(const ShardTaskList &other) : categories(other.categories), permutation_(other.permutation_), sample_ids_(other.sample_ids_), - task_list_(other.task_list_) {} + task_list_(other.task_list_), + sample_meta_list_(other.sample_meta_list_), + lazy_load_(other.lazy_load_) {} ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) { ShardTaskList tmp(other); @@ -34,6 +36,8 @@ ShardTaskList &ShardTaskList::operator=(const ShardTaskList &other) { permutation_.swap(tmp.permutation_); sample_ids_.swap(tmp.sample_ids_); task_list_.swap(tmp.task_list_); + sample_meta_list_.swap(tmp.sample_meta_list_); + lazy_load_ = tmp.lazy_load_; return *this; } @@ -68,22 +72,39 @@ void ShardTaskList::TaskListSwap(ShardTaskList &orig_tasks, ShardTaskList &new_t std::swap(orig_tasks.sample_ids_, new_tasks.sample_ids_); } -void ShardTaskList::PopBack() { task_list_.pop_back(); } +void ShardTaskList::PopBack() { + task_list_.pop_back(); + if (lazy_load_ == false) { + sample_meta_list_.pop_back(); + } +} int64_t ShardTaskList::Size() const { return static_cast(task_list_.size()); } int64_t ShardTaskList::SizeOfRows() const { + int64_t size_of_rows = 0; if (task_list_.size() == 0) { - return static_cast(0); + return size_of_rows; } - // 1 task is 1 page,blob index start from 2 - auto sum_num_rows = [](int64_t x, ShardTask y) { return x + std::get<2>(y)[0]; }; - int64_t nRows = std::accumulate(task_list_.begin(), task_list_.end(), 0, sum_num_rows); - return nRows; + if (lazy_load_ == false) { + // 1 task is 1 page,blob index start from 2 + auto sum_num_rows = [](int64_t x, SampleMeta y) { return x + std::get<0>(y)[0]; }; + size_of_rows = std::accumulate(sample_meta_list_.begin(), sample_meta_list_.end(), 0, sum_num_rows); + } else { + MS_LOG(WARNING) << "In lazy load mode, size of rows will be " << size_of_rows << " which is not correctly."; + } + return size_of_rows; } -ShardTask &ShardTaskList::GetTaskByID(int64_t id) { return task_list_[id]; } +ShardTask ShardTaskList::GetTaskByID(int64_t id) { + if (lazy_load_ == false) { + return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), std::get<0>(sample_meta_list_[id]), + std::get<1>(sample_meta_list_[id])}; + } else { + return {std::get<0>(task_list_[id]), std::get<1>(task_list_[id]), {}, json()}; + } +} int64_t ShardTaskList::GetTaskSampleByID(int64_t id) { return sample_ids_[id]; } @@ -93,10 +114,16 @@ int64_t ShardTaskList::GetRandomTaskID() { return dis(gen); } -ShardTask &ShardTaskList::GetRandomTask() { +ShardTask ShardTaskList::GetRandomTask() { std::mt19937 gen = GetRandomDevice(); std::uniform_int_distribution<> dis(0, task_list_.size() - 1); - return task_list_[dis(gen)]; + size_t random = dis(gen); + if (lazy_load_ == false) { + return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), std::get<0>(sample_meta_list_[random]), + std::get<1>(sample_meta_list_[random])}; + } else { + return {std::get<0>(task_list_[random]), std::get<1>(task_list_[random]), {}, json()}; + } } ShardTaskList ShardTaskList::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements,