impl mindrecord lazy load

This commit is contained in:
jonyguo 2020-12-10 18:57:22 +08:00
parent 05ec9352f3
commit f97d03f695
5 changed files with 211 additions and 12 deletions

View File

@ -181,6 +181,9 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis
/// \brief get the max hardware concurrency
/// \return max concurrency
uint32_t GetMaxThreadNum();
/// \brief the max number of samples to enable lazy load
const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
} // namespace mindrecord
} // namespace mindspore

View File

@ -77,10 +77,13 @@ class __attribute__((visibility("default"))) ShardReader {
/// \param[in] n_consumer number of threads when reading
/// \param[in] selected_columns column list to be populated
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
/// \param[in] num_padded the number of padded samples
/// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization
/// \return MSRStatus the status of MSRStatus
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
const std::vector<std::string> &selected_columns = {},
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0);
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
bool lazy_load = false);
/// \brief open files and initialize reader, python API
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
@ -218,6 +221,10 @@ class __attribute__((visibility("default"))) ShardReader {
/// \brief read all rows for specified columns
ROW_GROUPS ReadAllRowGroup(std::vector<std::string> &columns);
/// \brief read row meta by shard_id and sample_id
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
const uint32_t &sample_id);
/// \brief read all rows in one shard
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
@ -257,6 +264,10 @@ class __attribute__((visibility("default"))) ShardReader {
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
/// \brief create task list in row-reader mode and lazy mode
MSRStatus CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
/// \brief crate task list
MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators);
@ -325,6 +336,15 @@ class __attribute__((visibility("default"))) ShardReader {
// map of delivery
std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
// Delivery/Iterator mode end
// all metadata in the index is not loaded during initialization
bool lazy_load_;
// indicate shard_id : inc_count
// 0 : 15 - shard0 has 15 samples
// 1 : 41 - shard1 has 26 samples
// 2 : 58 - shard2 has 17 samples
std::vector<uint32_t> shard_sample_count_;
};
} // namespace mindrecord
} // namespace mindspore

View File

@ -67,8 +67,13 @@ class __attribute__((visibility("default"))) ShardTask {
uint32_t categories;
// The total sample ids which used to shuffle operation. The ids like: [0, 1, 2, 3, ..., n-1, n]
std::vector<int> permutation_;
// The data struct is as below:
// 1. TaskType: kCommonTask / kPaddedTask
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
};

View File

@ -47,7 +47,9 @@ ShardReader::ShardReader()
num_rows_(0),
total_blob_size_(0),
task_id_(0),
deliver_id_(0) {}
deliver_id_(0),
lazy_load_(false),
shard_sample_count_() {}
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
if (!IsLegalFile(file_path)) {
@ -148,6 +150,16 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
for (const auto &rg : row_group_summary) {
num_rows_ += std::get<3>(rg);
}
if (num_rows_ > LAZY_LOAD_THRESHOLD) {
lazy_load_ = true;
MS_LOG(WARNING) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
<< ", enable lazy load mode. If you want to speed up data loading, "
<< "it is recommended that you save multiple samples into one record when creating mindrecord file,"
<< " so that you can enable fast loading mode, and don't forget to adjust your batch size "
<< "according to the current samples.";
}
auto disk_size = page_size_ * row_group_summary.size();
auto compression_size = shard_header_->GetCompressionSize();
total_blob_size_ = disk_size + compression_size;
@ -270,6 +282,7 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
return row_group_summary;
}
if (shard_count <= kMaxFileCount) {
uint32_t total_count = 0;
for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
// return -1 when page's size equals to 0.
auto last_page_id = shard_header_->GetLastPageId(shard_id);
@ -285,8 +298,10 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
return std::vector<std::tuple<int, int, int, uint64_t>>();
}
uint64_t number_of_rows = page->GetEndRowID() - start_row_id;
total_count += number_of_rows;
row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows);
}
shard_sample_count_.push_back(total_count);
}
}
return row_group_summary;
@ -472,6 +487,34 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
}
ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns,
const uint32_t &shard_id, const uint32_t &sample_id) {
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
if (all_in_index_) {
for (unsigned int i = 0; i < columns.size(); ++i) {
fields += ',';
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
if (ret.first != SUCCESS) {
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
}
fields += ret.second;
}
} else { // fetch raw data from Raw page while some field is not index.
fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
}
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
if (ReadAllRowsInShard(shard_id, sql, columns, offsets, column_values) != SUCCESS) {
MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed.";
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
}
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
}
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
if (SUCCESS != ret.first) {
@ -868,7 +911,10 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
const std::vector<std::string> &selected_columns,
const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded) {
const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded,
bool lazy_load) {
lazy_load_ = lazy_load;
// Open file and set header by ShardReader
auto ret = Init(file_paths, load_dataset);
if (SUCCESS != ret) {
@ -1077,6 +1123,44 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
return SUCCESS;
}
MSRStatus ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
CheckIfColumnInIndex(selected_columns_);
if (shard_count_ <= kMaxFileCount) {
uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset.";
// Init the tasks_ size
tasks_.ResizeTask(sample_count);
// Init the task threads, maybe use ThreadPool is better
std::vector<std::thread> init_tasks_thread(shard_count_);
for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
// the offset indicate the shard start
uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];
// the count indicate the number of samples in the shard
uint32_t shard_count =
shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
// here "i - current_offset" indicate the sample id in the shard
tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
}
});
}
for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
init_tasks_thread[shard_id].join();
}
} else {
return FAILED;
}
return SUCCESS;
}
MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
int category_operator = -1;
@ -1088,9 +1172,17 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
}
}
if (-1 == category_operator) {
if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) {
return FAILED;
if (lazy_load_ == false) {
if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) {
return FAILED;
}
} else {
if (SUCCESS != CreateLazyTasksByRow(row_group_summary, operators)) {
return FAILED;
}
}
// need padded sample to the task
if (num_padded_ > 0) {
for (int i = 0; i < num_padded_; ++i) {
tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
@ -1123,6 +1215,12 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
std::make_pair(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
}
uint32_t shard_id = 0;
uint32_t group_id = 0;
uint32_t blob_start = 0;
uint32_t blob_end = 0;
json var_fields;
// Pick up task from task list
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
@ -1133,9 +1231,33 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
std::make_pair(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
}
auto shard_id = std::get<0>(std::get<1>(task));
auto group_id = std::get<1>(std::get<1>(task));
auto addr = std::get<2>(task);
shard_id = std::get<0>(std::get<1>(task)); // shard id
if (lazy_load_ == false) {
group_id = std::get<1>(std::get<1>(task)); // group id
blob_start = std::get<2>(task)[0]; // blob start
blob_end = std::get<2>(task)[1]; // blob end
var_fields = std::get<3>(task); // scalar variable field
} else {
// get scalar variable fields by sample id
uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));
// read the meta from index
auto row_meta = ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, sample_id_in_shard);
if (std::get<0>(row_meta) != SUCCESS) {
return std::make_pair(
FAILED, std::make_pair(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
}
auto &offsets = std::get<1>(row_meta);
auto &local_columns = std::get<2>(row_meta);
group_id = offsets[shard_id][0][1]; // group_id
blob_start = offsets[shard_id][0][2]; // blob start
blob_end = offsets[shard_id][0][3]; // blob end
var_fields = local_columns[shard_id][0]; // scalar variable field
}
// read the blob from data file
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
if (SUCCESS != ret.first) {
return std::make_pair(FAILED,
@ -1144,8 +1266,8 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
const std::shared_ptr<Page> &page = ret.second;
// Pack image list
std::vector<uint8_t> images(addr[1] - addr[0]);
auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0];
std::vector<uint8_t> images(blob_end - blob_start);
auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + blob_start;
auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
@ -1156,7 +1278,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
}
auto &io_read =
file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), addr[1] - addr[0]);
file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
if (!io_read.good() || io_read.fail() || io_read.bad()) {
MS_LOG(ERROR) << "File read failed";
file_streams_random_[consumer_id][shard_id]->close();
@ -1166,7 +1288,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
batch.emplace_back(std::move(images), std::move(std::get<3>(task)));
batch.emplace_back(std::move(images), std::move(var_fields));
return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch)));
}

View File

@ -70,6 +70,30 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
dataset.Close();
}
TEST_F(TestShardReader, TestShardReaderLazyLoad) {
MS_LOG(INFO) << FormatInfo("Test read imageNet");
std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"file_name"};
ShardReader dataset;
dataset.Open({file_name}, true, 4, column_list, {}, 0, true);
dataset.Launch();
uint32_t count = 0;
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
}
}
count++;
}
ASSERT_TRUE(count == 10);
dataset.Close();
}
TEST_F(TestShardReader, TestShardReaderSample) {
MS_LOG(INFO) << FormatInfo("Test read imageNet");
std::string file_name = "./imagenet.shard01";
@ -91,6 +115,31 @@ TEST_F(TestShardReader, TestShardReaderSample) {
}
}
dataset.Close();
}
TEST_F(TestShardReader, TestShardReaderLazyLoadDistributed) {
MS_LOG(INFO) << FormatInfo("Test read imageNet");
std::string file_name = "./imagenet.shard01";
auto column_list = std::vector<std::string>{"file_name"};
std::vector<std::shared_ptr<ShardOperator>> ops;
ops.push_back(std::make_shared<ShardSample>(1, 8));
ShardReader dataset;
dataset.Open({file_name}, true, 4, column_list, ops, 0, true);
dataset.Launch();
uint32_t count = 0;
while (true) {
auto x = dataset.GetNext();
if (x.empty()) break;
for (auto &j : x) {
for (auto &item : std::get<1>(j).items()) {
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
}
}
count++;
}
ASSERT_TRUE(count == 2);
dataset.Close();
}