diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
index 0b1ab4e3b01..d8ac5654711 100644
--- a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
+++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h
@@ -181,6 +181,9 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type);
 /// \brief get the max hardware concurrency
 /// \return max concurrency
 uint32_t GetMaxThreadNum();
+
+/// \brief the sample-count threshold above which lazy load mode is enabled
+const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
 }  // namespace mindrecord
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
index f48445121e9..eb891c228f3 100644
--- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
@@ -77,10 +77,13 @@ class __attribute__((visibility("default"))) ShardReader {
   /// \param[in] n_consumer number of threads when reading
   /// \param[in] selected_columns column list to be populated
   /// \param[in] operators operators applied to data, operator type is shuffle, sample or category
+  /// \param[in] num_padded the number of padded samples
+  /// \param[in] lazy_load enable lazy load mode to speed up initialization when the mindrecord dataset is very large
   /// \return MSRStatus the status of MSRStatus
   MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
                  const std::vector<std::string> &selected_columns = {},
-                 const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0);
+                 const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
+                 bool lazy_load = false);
 
   /// \brief open files and initialize reader, python API
   /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
@@ -218,6 +221,10 @@ class __attribute__((visibility("default"))) ShardReader {
   /// \brief read all rows for specified columns
   ROW_GROUPS ReadAllRowGroup(std::vector<std::string> &columns);
 
+  /// \brief read row meta by shard_id and sample_id
+  ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
+                                              const uint32_t &sample_id);
+
   /// \brief read all rows in one shard
   MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
                                std::vector<std::vector<std::vector<uint64_t>>> &offsets,
@@ -257,6 +264,10 @@ class __attribute__((visibility("default"))) ShardReader {
   MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                              const std::vector<std::shared_ptr<ShardOperator>> &operators);
 
+  /// \brief create task list in row-reader mode and lazy load mode
+  MSRStatus CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
+                                 const std::vector<std::shared_ptr<ShardOperator>> &operators);
+
   /// \brief create task list
   MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                         const std::vector<std::shared_ptr<ShardOperator>> &operators);
@@ -325,6 +336,15 @@ class __attribute__((visibility("default"))) ShardReader {
   // map of delivery
   std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
   // Delivery/Iterator mode end
+
+  // when true, the metadata in the index is not loaded during initialization
+  bool lazy_load_;
+
+  // cumulative sample count per shard (shard_id : inc_count), e.g.
+  // 0 : 15 - shard0 has 15 samples
+  // 1 : 41 - shard1 has 26 samples
+  // 2 : 58 - shard2 has 17 samples
+  std::vector<uint32_t> shard_sample_count_;
 };
 }  // namespace mindrecord
 }  // namespace mindspore
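Note that shard_sample_count_ stores cumulative (prefix-sum) counts rather than per-shard totals, which is what lets the lazy path recover both the owning shard and the shard-local index of any global sample id with one binary search. A minimal sketch of that mapping, assuming the cumulative layout documented above (GlobalToLocal is a hypothetical helper, not part of this patch):

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Map a global sample index to (shard_id, sample_id_in_shard) using the
    // cumulative counts, e.g. {15, 41, 58} for shards holding 15/26/17 samples.
    std::pair<uint32_t, uint32_t> GlobalToLocal(const std::vector<uint32_t> &shard_sample_count,
                                                uint32_t global_id) {
      // The first cumulative count strictly greater than global_id owns the sample.
      auto it = std::upper_bound(shard_sample_count.begin(), shard_sample_count.end(), global_id);
      uint32_t shard_id = static_cast<uint32_t>(it - shard_sample_count.begin());
      uint32_t offset = (shard_id == 0) ? 0 : shard_sample_count[shard_id - 1];
      return {shard_id, global_id - offset};  // e.g. global 20 -> shard 1, local 5
    }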
diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
index f1e59f538c3..395eda3a3de 100644
--- a/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
+++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_task.h
@@ -67,8 +67,13 @@ class __attribute__((visibility("default"))) ShardTask {
 
   uint32_t categories;
 
+  // All sample ids, used by the shuffle operation: [0, 1, 2, 3, ..., n-1, n]
   std::vector<int> permutation_;
 
+  // Each task entry is structured as follows:
+  // 1. TaskType: kCommonTask / kPaddedTask
+  // 2. std::tuple<int, int> : shard_id, group_id (fast load) / sample_id (lazy load)
+  // 3. std::vector<uint64_t>, json : [blob_start, blob_end], scalar variable fields
   std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
 };
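The comment block above is the contract both task-creation paths rely on: fast load fills the blob offsets and scalar fields eagerly, while lazy load stores a sample id in place of a group id and leaves the rest empty. A sketch of one entry in each mode (illustrative values only; the local TaskType and Task alias stand in for the real mindrecord types, and nlohmann::json is assumed as used throughout the codebase):

    #include <cstdint>
    #include <tuple>
    #include <vector>
    #include "nlohmann/json.hpp"
    using json = nlohmann::json;

    enum class TaskType { kCommonTask, kPaddedTask };  // mirrors mindrecord's TaskType

    using Task = std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>;

    // Fast load: inner tuple is (shard_id, group_id); blob offsets and scalar
    // fields are resolved at task-creation time.
    Task fast = std::make_tuple(TaskType::kCommonTask, std::make_tuple(0, 2),
                                std::vector<uint64_t>{1024, 4096}, json{{"file_name", "a.jpg"}});

    // Lazy load: inner tuple is (shard_id, sample_id); offsets and fields stay
    // empty until ConsumerOneTask fetches them from the index.
    Task lazy = std::make_tuple(TaskType::kCommonTask, std::make_tuple(0, 7),
                                std::vector<uint64_t>{}, json());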
diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
index e9162bfebaf..4c4840a66d9 100644
--- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
+++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
@@ -47,7 +47,9 @@ ShardReader::ShardReader()
       num_rows_(0),
       total_blob_size_(0),
       task_id_(0),
-      deliver_id_(0) {}
+      deliver_id_(0),
+      lazy_load_(false),
+      shard_sample_count_() {}
 
 std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
   if (!IsLegalFile(file_path)) {
@@ -148,6 +150,16 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) {
   for (const auto &rg : row_group_summary) {
     num_rows_ += std::get<3>(rg);
   }
+
+  if (num_rows_ > LAZY_LOAD_THRESHOLD) {
+    lazy_load_ = true;
+    MS_LOG(WARNING) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
+                    << ", so lazy load mode is enabled. To speed up data loading, it is recommended to "
+                    << "save multiple samples into one record when creating the mindrecord file, so that "
+                    << "fast load mode can be used; remember to adjust the batch size "
+                    << "to match the number of samples per record.";
+  }
+
   auto disk_size = page_size_ * row_group_summary.size();
   auto compression_size = shard_header_->GetCompressionSize();
   total_blob_size_ = disk_size + compression_size;
@@ -270,6 +282,7 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
     return row_group_summary;
   }
   if (shard_count <= kMaxFileCount) {
+    uint32_t total_count = 0;
     for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
       // return -1 when page's size equals to 0.
       auto last_page_id = shard_header_->GetLastPageId(shard_id);
@@ -285,8 +298,10 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() {
           return std::vector<std::tuple<int, int, int, uint64_t>>();
         }
         uint64_t number_of_rows = page->GetEndRowID() - start_row_id;
+        total_count += number_of_rows;
         row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows);
       }
+      shard_sample_count_.push_back(total_count);
     }
   }
   return row_group_summary;
 }
@@ -472,6 +487,34 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
   return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
 }
 
+ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns,
+                                                         const uint32_t &shard_id, const uint32_t &sample_id) {
+  std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
+  std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
+  std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
+  if (all_in_index_) {
+    for (unsigned int i = 0; i < columns.size(); ++i) {
+      fields += ',';
+      auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
+      if (ret.first != SUCCESS) {
+        return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
+      }
+      fields += ret.second;
+    }
+  } else {  // fetch raw data from the raw page when some field is not in the index
+    fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
+  }
+
+  std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
+
+  if (ReadAllRowsInShard(shard_id, sql, columns, offsets, column_values) != SUCCESS) {
+    MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed.";
+    return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
+  }
+
+  return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
+}
+
 ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
   const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
   if (SUCCESS != ret.first) {
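Unlike ReadAllRowGroup, the statement ReadRowGroupByShardIDAndSampleID builds is keyed on ROW_ID, so a single row of metadata is pulled from the index per call. When some selected field is not covered by the index (the else branch above), the query that reaches SQLite looks like this for, say, sample 7:

    SELECT ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END, PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END FROM INDEXES WHERE ROW_ID = 7

When all fields are indexed, one column per selected field (named by ShardIndexGenerator::GenerateFieldName) is appended to the SELECT list instead.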
@@ -868,7 +911,10 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths,
 
 MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
                             const std::vector<std::string> &selected_columns,
-                            const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded) {
+                            const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded,
+                            bool lazy_load) {
+  lazy_load_ = lazy_load;
+
   // Open file and set header by ShardReader
   auto ret = Init(file_paths, load_dataset);
   if (SUCCESS != ret) {
@@ -1077,6 +1123,44 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
   return SUCCESS;
 }
 
+MSRStatus ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
+                                            const std::vector<std::shared_ptr<ShardOperator>> &operators) {
+  CheckIfColumnInIndex(selected_columns_);
+
+  if (shard_count_ <= kMaxFileCount) {
+    uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
+    MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset.";
+
+    // Init the size of tasks_ so the worker threads can insert by index
+    tasks_.ResizeTask(sample_count);
+
+    // Init the task threads; using a thread pool might be better
+    std::vector<std::thread> init_tasks_thread(shard_count_);
+
+    for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
+      // the offset indicates where this shard's samples start globally
+      uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];
+
+      // the count indicates the number of samples in this shard
+      uint32_t shard_count =
+        shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
+      init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
+        for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
+          // here "i - current_offset" indicates the sample id within the shard
+          tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
+        }
+      });
+    }
+
+    for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
+      init_tasks_thread[shard_id].join();
+    }
+  } else {
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
 MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
                                    const std::vector<std::shared_ptr<ShardOperator>> &operators) {
   int category_operator = -1;
@@ -1088,9 +1172,17 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
     }
   }
+
   if (-1 == category_operator) {
-    if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) {
-      return FAILED;
+    if (lazy_load_ == false) {
+      if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) {
+        return FAILED;
+      }
+    } else {
+      if (SUCCESS != CreateLazyTasksByRow(row_group_summary, operators)) {
+        return FAILED;
+      }
     }
     if (num_padded_ > 0) {
       for (int i = 0; i < num_padded_; ++i) {
         tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
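With the example counts documented in shard_reader.h (shard_sample_count_ = {15, 41, 58}), CreateLazyTasksByRow splits the work as follows: shard 0 gets current_offset 0 and shard_count 15 (it fills tasks 0..14), shard 1 gets current_offset 15 and shard_count 41 - 15 = 26 (tasks 15..40), and shard 2 gets current_offset 41 and shard_count 58 - 41 = 17 (tasks 41..57). Because tasks_ is resized to the full sample count before the threads start, each thread writes a disjoint index range of tasks_.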
@@ -1123,6 +1215,12 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
                           std::make_pair(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
   }
 
+  uint32_t shard_id = 0;
+  uint32_t group_id = 0;
+  uint32_t blob_start = 0;
+  uint32_t blob_end = 0;
+  json var_fields;
+
   // Pick up task from task list
   auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
@@ -1133,9 +1231,33 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
                           std::make_pair(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
   }
 
-  auto shard_id = std::get<0>(std::get<1>(task));
-  auto group_id = std::get<1>(std::get<1>(task));
-  auto addr = std::get<2>(task);
+  shard_id = std::get<0>(std::get<1>(task));  // shard id
+
+  if (lazy_load_ == false) {
+    group_id = std::get<1>(std::get<1>(task));  // group id
+    blob_start = std::get<2>(task)[0];          // blob start
+    blob_end = std::get<2>(task)[1];            // blob end
+    var_fields = std::get<3>(task);             // scalar variable fields
+  } else {
+    // get the scalar variable fields by sample id
+    uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));
+
+    // read the row meta from the index
+    auto row_meta = ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, sample_id_in_shard);
+    if (std::get<0>(row_meta) != SUCCESS) {
+      return std::make_pair(
+        FAILED, std::make_pair(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
+    }
+    auto &offsets = std::get<1>(row_meta);
+    auto &local_columns = std::get<2>(row_meta);
+
+    group_id = offsets[shard_id][0][1];       // group id
+    blob_start = offsets[shard_id][0][2];     // blob start
+    blob_end = offsets[shard_id][0][3];       // blob end
+    var_fields = local_columns[shard_id][0];  // scalar variable fields
+  }
+
+  // read the blob from the data file
   const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
   if (SUCCESS != ret.first) {
     return std::make_pair(FAILED,
@@ -1144,8 +1266,8 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
   const std::shared_ptr<Page> &page = ret.second;
 
   // Pack image list
-  std::vector<uint8_t> images(addr[1] - addr[0]);
-  auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0];
+  std::vector<uint8_t> images(blob_end - blob_start);
+  auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + blob_start;
   auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
   if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
@@ -1156,7 +1278,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
   }
 
   auto &io_read =
-    file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), addr[1] - addr[0]);
+    file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
   if (!io_read.good() || io_read.fail() || io_read.bad()) {
     MS_LOG(ERROR) << "File read failed";
     file_streams_random_[consumer_id][shard_id]->close();
@@ -1166,7 +1288,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
 
   // Deliver batch data to output map
   std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
-  batch.emplace_back(std::move(images), std::move(std::get<3>(task)));
+  batch.emplace_back(std::move(images), std::move(var_fields));
 
   return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch)));
 }
diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
index 7b56f5e18f9..a7102ee9188 100644
--- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
+++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
@@ -70,6 +70,30 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
   dataset.Close();
 }
 
+TEST_F(TestShardReader, TestShardReaderLazyLoad) {
+  MS_LOG(INFO) << FormatInfo("Test read imageNet");
+  std::string file_name = "./imagenet.shard01";
+  auto column_list = std::vector<std::string>{"file_name"};
+
+  ShardReader dataset;
+  dataset.Open({file_name}, true, 4, column_list, {}, 0, true);
+  dataset.Launch();
+
+  uint32_t count = 0;
+  while (true) {
+    auto x = dataset.GetNext();
+    if (x.empty()) break;
+    for (auto &j : x) {
+      for (auto &item : std::get<1>(j).items()) {
+        MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
+      }
+    }
+    count++;
+  }
+  ASSERT_TRUE(count == 10);
+  dataset.Close();
+}
+
 TEST_F(TestShardReader, TestShardReaderSample) {
   MS_LOG(INFO) << FormatInfo("Test read imageNet");
   std::string file_name = "./imagenet.shard01";
@@ -91,6 +115,31 @@ TEST_F(TestShardReader, TestShardReaderSample) {
     }
   }
   dataset.Close();
+}
+
+TEST_F(TestShardReader, TestShardReaderLazyLoadDistributed) {
+  MS_LOG(INFO) << FormatInfo("Test read imageNet");
+  std::string file_name = "./imagenet.shard01";
+  auto column_list = std::vector<std::string>{"file_name"};
+
+  std::vector<std::shared_ptr<ShardOperator>> ops;
+  ops.push_back(std::make_shared<ShardSample>(1, 8));
+  ShardReader dataset;
+  dataset.Open({file_name}, true, 4, column_list, ops, 0, true);
+  dataset.Launch();
+
+  uint32_t count = 0;
+  while (true) {
+    auto x = dataset.GetNext();
+    if (x.empty()) break;
+    for (auto &j : x) {
+      for (auto &item : std::get<1>(j).items()) {
+        MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
+      }
+    }
+    count++;
+  }
+  ASSERT_TRUE(count == 2);
   dataset.Close();
 }
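Both new tests drive the lazy path explicitly by passing lazy_load = true as the final argument of Open(). Outside of tests the same path is also taken automatically: Init() sets lazy_load_ itself whenever it counts more than LAZY_LOAD_THRESHOLD rows, regardless of the flag passed to Open().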