forked from mindspore-Ecosystem/mindspore
impl mindrecord lazy load
This commit is contained in:
parent
05ec9352f3
commit
f97d03f695
|
@ -181,6 +181,9 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis
|
||||||
/// \brief get the max hardware concurrency
|
/// \brief get the max hardware concurrency
|
||||||
/// \return max concurrency
|
/// \return max concurrency
|
||||||
uint32_t GetMaxThreadNum();
|
uint32_t GetMaxThreadNum();
|
||||||
|
|
||||||
|
/// \brief the max number of samples to enable lazy load
|
||||||
|
const uint32_t LAZY_LOAD_THRESHOLD = 5000000;
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -77,10 +77,13 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
/// \param[in] n_consumer number of threads when reading
|
/// \param[in] n_consumer number of threads when reading
|
||||||
/// \param[in] selected_columns column list to be populated
|
/// \param[in] selected_columns column list to be populated
|
||||||
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
/// \param[in] operators operators applied to data, operator type is shuffle, sample or category
|
||||||
|
/// \param[in] num_padded the number of padded samples
|
||||||
|
/// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization
|
||||||
/// \return MSRStatus the status of MSRStatus
|
/// \return MSRStatus the status of MSRStatus
|
||||||
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
|
MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4,
|
||||||
const std::vector<std::string> &selected_columns = {},
|
const std::vector<std::string> &selected_columns = {},
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0);
|
const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0,
|
||||||
|
bool lazy_load = false);
|
||||||
|
|
||||||
/// \brief open files and initialize reader, python API
|
/// \brief open files and initialize reader, python API
|
||||||
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
/// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list
|
||||||
|
@ -218,6 +221,10 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
/// \brief read all rows for specified columns
|
/// \brief read all rows for specified columns
|
||||||
ROW_GROUPS ReadAllRowGroup(std::vector<std::string> &columns);
|
ROW_GROUPS ReadAllRowGroup(std::vector<std::string> &columns);
|
||||||
|
|
||||||
|
/// \brief read row meta by shard_id and sample_id
|
||||||
|
ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id,
|
||||||
|
const uint32_t &sample_id);
|
||||||
|
|
||||||
/// \brief read all rows in one shard
|
/// \brief read all rows in one shard
|
||||||
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns,
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
|
std::vector<std::vector<std::vector<uint64_t>>> &offsets,
|
||||||
|
@ -257,6 +264,10 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
||||||
|
|
||||||
|
/// \brief create task list in row-reader mode and lazy mode
|
||||||
|
MSRStatus CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
|
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
||||||
|
|
||||||
/// \brief crate task list
|
/// \brief crate task list
|
||||||
MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
const std::vector<std::shared_ptr<ShardOperator>> &operators);
|
||||||
|
@ -325,6 +336,15 @@ class __attribute__((visibility("default"))) ShardReader {
|
||||||
// map of delivery
|
// map of delivery
|
||||||
std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
|
std::unordered_map<int, std::shared_ptr<std::vector<std::tuple<std::vector<uint8_t>, json>>>> delivery_map_;
|
||||||
// Delivery/Iterator mode end
|
// Delivery/Iterator mode end
|
||||||
|
|
||||||
|
// all metadata in the index is not loaded during initialization
|
||||||
|
bool lazy_load_;
|
||||||
|
|
||||||
|
// indicate shard_id : inc_count
|
||||||
|
// 0 : 15 - shard0 has 15 samples
|
||||||
|
// 1 : 41 - shard1 has 26 samples
|
||||||
|
// 2 : 58 - shard2 has 17 samples
|
||||||
|
std::vector<uint32_t> shard_sample_count_;
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -67,8 +67,13 @@ class __attribute__((visibility("default"))) ShardTask {
|
||||||
|
|
||||||
uint32_t categories;
|
uint32_t categories;
|
||||||
|
|
||||||
|
// The total sample ids which used to shuffle operation. The ids like: [0, 1, 2, 3, ..., n-1, n]
|
||||||
std::vector<int> permutation_;
|
std::vector<int> permutation_;
|
||||||
|
|
||||||
|
// The data struct is as below:
|
||||||
|
// 1. TaskType: kCommonTask / kPaddedTask
|
||||||
|
// 2. std::tuple<int, int> : shard_id, group_id(fast load) / sample_id(lazy load)
|
||||||
|
// 3. std::vector<uint64_t>, json>> : [blob_start, blob_end], scalar_variable_fields
|
||||||
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
|
std::vector<std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json>> task_list_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,9 @@ ShardReader::ShardReader()
|
||||||
num_rows_(0),
|
num_rows_(0),
|
||||||
total_blob_size_(0),
|
total_blob_size_(0),
|
||||||
task_id_(0),
|
task_id_(0),
|
||||||
deliver_id_(0) {}
|
deliver_id_(0),
|
||||||
|
lazy_load_(false),
|
||||||
|
shard_sample_count_() {}
|
||||||
|
|
||||||
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
|
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) {
|
||||||
if (!IsLegalFile(file_path)) {
|
if (!IsLegalFile(file_path)) {
|
||||||
|
@ -148,6 +150,16 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
|
||||||
for (const auto &rg : row_group_summary) {
|
for (const auto &rg : row_group_summary) {
|
||||||
num_rows_ += std::get<3>(rg);
|
num_rows_ += std::get<3>(rg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (num_rows_ > LAZY_LOAD_THRESHOLD) {
|
||||||
|
lazy_load_ = true;
|
||||||
|
MS_LOG(WARNING) << "The number of samples is larger than " << LAZY_LOAD_THRESHOLD
|
||||||
|
<< ", enable lazy load mode. If you want to speed up data loading, "
|
||||||
|
<< "it is recommended that you save multiple samples into one record when creating mindrecord file,"
|
||||||
|
<< " so that you can enable fast loading mode, and don't forget to adjust your batch size "
|
||||||
|
<< "according to the current samples.";
|
||||||
|
}
|
||||||
|
|
||||||
auto disk_size = page_size_ * row_group_summary.size();
|
auto disk_size = page_size_ * row_group_summary.size();
|
||||||
auto compression_size = shard_header_->GetCompressionSize();
|
auto compression_size = shard_header_->GetCompressionSize();
|
||||||
total_blob_size_ = disk_size + compression_size;
|
total_blob_size_ = disk_size + compression_size;
|
||||||
|
@ -270,6 +282,7 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
|
||||||
return row_group_summary;
|
return row_group_summary;
|
||||||
}
|
}
|
||||||
if (shard_count <= kMaxFileCount) {
|
if (shard_count <= kMaxFileCount) {
|
||||||
|
uint32_t total_count = 0;
|
||||||
for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
|
for (int shard_id = 0; shard_id < shard_count; ++shard_id) {
|
||||||
// return -1 when page's size equals to 0.
|
// return -1 when page's size equals to 0.
|
||||||
auto last_page_id = shard_header_->GetLastPageId(shard_id);
|
auto last_page_id = shard_header_->GetLastPageId(shard_id);
|
||||||
|
@ -285,8 +298,10 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
|
||||||
return std::vector<std::tuple<int, int, int, uint64_t>>();
|
return std::vector<std::tuple<int, int, int, uint64_t>>();
|
||||||
}
|
}
|
||||||
uint64_t number_of_rows = page->GetEndRowID() - start_row_id;
|
uint64_t number_of_rows = page->GetEndRowID() - start_row_id;
|
||||||
|
total_count += number_of_rows;
|
||||||
row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows);
|
row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows);
|
||||||
}
|
}
|
||||||
|
shard_sample_count_.push_back(total_count);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return row_group_summary;
|
return row_group_summary;
|
||||||
|
@ -472,6 +487,34 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) {
|
||||||
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
|
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns,
|
||||||
|
const uint32_t &shard_id, const uint32_t &sample_id) {
|
||||||
|
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END";
|
||||||
|
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{});
|
||||||
|
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{});
|
||||||
|
if (all_in_index_) {
|
||||||
|
for (unsigned int i = 0; i < columns.size(); ++i) {
|
||||||
|
fields += ',';
|
||||||
|
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]));
|
||||||
|
if (ret.first != SUCCESS) {
|
||||||
|
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
||||||
|
}
|
||||||
|
fields += ret.second;
|
||||||
|
}
|
||||||
|
} else { // fetch raw data from Raw page while some field is not index.
|
||||||
|
fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END ";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id);
|
||||||
|
|
||||||
|
if (ReadAllRowsInShard(shard_id, sql, columns, offsets, column_values) != SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed.";
|
||||||
|
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values));
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values));
|
||||||
|
}
|
||||||
|
|
||||||
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
|
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) {
|
||||||
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
|
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
|
||||||
if (SUCCESS != ret.first) {
|
if (SUCCESS != ret.first) {
|
||||||
|
@ -868,7 +911,10 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
|
||||||
|
|
||||||
MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
|
MSRStatus ShardReader::Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer,
|
||||||
const std::vector<std::string> &selected_columns,
|
const std::vector<std::string> &selected_columns,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded) {
|
const std::vector<std::shared_ptr<ShardOperator>> &operators, int num_padded,
|
||||||
|
bool lazy_load) {
|
||||||
|
lazy_load_ = lazy_load;
|
||||||
|
|
||||||
// Open file and set header by ShardReader
|
// Open file and set header by ShardReader
|
||||||
auto ret = Init(file_paths, load_dataset);
|
auto ret = Init(file_paths, load_dataset);
|
||||||
if (SUCCESS != ret) {
|
if (SUCCESS != ret) {
|
||||||
|
@ -1077,6 +1123,44 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i
|
||||||
return SUCCESS;
|
return SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSRStatus ShardReader::CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
|
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||||
|
CheckIfColumnInIndex(selected_columns_);
|
||||||
|
|
||||||
|
if (shard_count_ <= kMaxFileCount) {
|
||||||
|
uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1];
|
||||||
|
MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset.";
|
||||||
|
|
||||||
|
// Init the tasks_ size
|
||||||
|
tasks_.ResizeTask(sample_count);
|
||||||
|
|
||||||
|
// Init the task threads, maybe use ThreadPool is better
|
||||||
|
std::vector<std::thread> init_tasks_thread(shard_count_);
|
||||||
|
|
||||||
|
for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
|
||||||
|
// the offset indicate the shard start
|
||||||
|
uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1];
|
||||||
|
|
||||||
|
// the count indicate the number of samples in the shard
|
||||||
|
uint32_t shard_count =
|
||||||
|
shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1];
|
||||||
|
init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() {
|
||||||
|
for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) {
|
||||||
|
// here "i - current_offset" indicate the sample id in the shard
|
||||||
|
tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) {
|
||||||
|
init_tasks_thread[shard_id].join();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary,
|
||||||
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
const std::vector<std::shared_ptr<ShardOperator>> &operators) {
|
||||||
int category_operator = -1;
|
int category_operator = -1;
|
||||||
|
@ -1088,9 +1172,17 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (-1 == category_operator) {
|
if (-1 == category_operator) {
|
||||||
if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) {
|
if (lazy_load_ == false) {
|
||||||
return FAILED;
|
if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (SUCCESS != CreateLazyTasksByRow(row_group_summary, operators)) {
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// need padded sample to the task
|
||||||
if (num_padded_ > 0) {
|
if (num_padded_ > 0) {
|
||||||
for (int i = 0; i < num_padded_; ++i) {
|
for (int i = 0; i < num_padded_; ++i) {
|
||||||
tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
|
tasks_.InsertTask(TaskType::kPaddedTask, 0, 0, {}, json());
|
||||||
|
@ -1123,6 +1215,12 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
||||||
std::make_pair(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
|
std::make_pair(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t shard_id = 0;
|
||||||
|
uint32_t group_id = 0;
|
||||||
|
uint32_t blob_start = 0;
|
||||||
|
uint32_t blob_end = 0;
|
||||||
|
json var_fields;
|
||||||
|
|
||||||
// Pick up task from task list
|
// Pick up task from task list
|
||||||
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
|
auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]);
|
||||||
|
|
||||||
|
@ -1133,9 +1231,33 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
||||||
std::make_pair(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
|
std::make_pair(TaskType::kPaddedTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto shard_id = std::get<0>(std::get<1>(task));
|
shard_id = std::get<0>(std::get<1>(task)); // shard id
|
||||||
auto group_id = std::get<1>(std::get<1>(task));
|
|
||||||
auto addr = std::get<2>(task);
|
if (lazy_load_ == false) {
|
||||||
|
group_id = std::get<1>(std::get<1>(task)); // group id
|
||||||
|
blob_start = std::get<2>(task)[0]; // blob start
|
||||||
|
blob_end = std::get<2>(task)[1]; // blob end
|
||||||
|
var_fields = std::get<3>(task); // scalar variable field
|
||||||
|
} else {
|
||||||
|
// get scalar variable fields by sample id
|
||||||
|
uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task));
|
||||||
|
|
||||||
|
// read the meta from index
|
||||||
|
auto row_meta = ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, sample_id_in_shard);
|
||||||
|
if (std::get<0>(row_meta) != SUCCESS) {
|
||||||
|
return std::make_pair(
|
||||||
|
FAILED, std::make_pair(TaskType::kCommonTask, std::vector<std::tuple<std::vector<uint8_t>, json>>()));
|
||||||
|
}
|
||||||
|
auto &offsets = std::get<1>(row_meta);
|
||||||
|
auto &local_columns = std::get<2>(row_meta);
|
||||||
|
|
||||||
|
group_id = offsets[shard_id][0][1]; // group_id
|
||||||
|
blob_start = offsets[shard_id][0][2]; // blob start
|
||||||
|
blob_end = offsets[shard_id][0][3]; // blob end
|
||||||
|
var_fields = local_columns[shard_id][0]; // scalar variable field
|
||||||
|
}
|
||||||
|
|
||||||
|
// read the blob from data file
|
||||||
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
|
const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id);
|
||||||
if (SUCCESS != ret.first) {
|
if (SUCCESS != ret.first) {
|
||||||
return std::make_pair(FAILED,
|
return std::make_pair(FAILED,
|
||||||
|
@ -1144,8 +1266,8 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
||||||
const std::shared_ptr<Page> &page = ret.second;
|
const std::shared_ptr<Page> &page = ret.second;
|
||||||
|
|
||||||
// Pack image list
|
// Pack image list
|
||||||
std::vector<uint8_t> images(addr[1] - addr[0]);
|
std::vector<uint8_t> images(blob_end - blob_start);
|
||||||
auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0];
|
auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + blob_start;
|
||||||
|
|
||||||
auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
|
auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg);
|
||||||
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) {
|
||||||
|
@ -1156,7 +1278,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
||||||
}
|
}
|
||||||
|
|
||||||
auto &io_read =
|
auto &io_read =
|
||||||
file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), addr[1] - addr[0]);
|
file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast<char *>(&images[0]), blob_end - blob_start);
|
||||||
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
if (!io_read.good() || io_read.fail() || io_read.bad()) {
|
||||||
MS_LOG(ERROR) << "File read failed";
|
MS_LOG(ERROR) << "File read failed";
|
||||||
file_streams_random_[consumer_id][shard_id]->close();
|
file_streams_random_[consumer_id][shard_id]->close();
|
||||||
|
@ -1166,7 +1288,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
|
||||||
|
|
||||||
// Deliver batch data to output map
|
// Deliver batch data to output map
|
||||||
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
|
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
|
||||||
batch.emplace_back(std::move(images), std::move(std::get<3>(task)));
|
batch.emplace_back(std::move(images), std::move(var_fields));
|
||||||
|
|
||||||
return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch)));
|
return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,6 +70,30 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
|
||||||
dataset.Close();
|
dataset.Close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TestShardReader, TestShardReaderLazyLoad) {
|
||||||
|
MS_LOG(INFO) << FormatInfo("Test read imageNet");
|
||||||
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
auto column_list = std::vector<std::string>{"file_name"};
|
||||||
|
|
||||||
|
ShardReader dataset;
|
||||||
|
dataset.Open({file_name}, true, 4, column_list, {}, 0, true);
|
||||||
|
dataset.Launch();
|
||||||
|
|
||||||
|
uint32_t count = 0;
|
||||||
|
while (true) {
|
||||||
|
auto x = dataset.GetNext();
|
||||||
|
if (x.empty()) break;
|
||||||
|
for (auto &j : x) {
|
||||||
|
for (auto &item : std::get<1>(j).items()) {
|
||||||
|
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(count == 10);
|
||||||
|
dataset.Close();
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TestShardReader, TestShardReaderSample) {
|
TEST_F(TestShardReader, TestShardReaderSample) {
|
||||||
MS_LOG(INFO) << FormatInfo("Test read imageNet");
|
MS_LOG(INFO) << FormatInfo("Test read imageNet");
|
||||||
std::string file_name = "./imagenet.shard01";
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
@ -91,6 +115,31 @@ TEST_F(TestShardReader, TestShardReaderSample) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dataset.Close();
|
dataset.Close();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TestShardReader, TestShardReaderLazyLoadDistributed) {
|
||||||
|
MS_LOG(INFO) << FormatInfo("Test read imageNet");
|
||||||
|
std::string file_name = "./imagenet.shard01";
|
||||||
|
auto column_list = std::vector<std::string>{"file_name"};
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<ShardOperator>> ops;
|
||||||
|
ops.push_back(std::make_shared<ShardSample>(1, 8));
|
||||||
|
ShardReader dataset;
|
||||||
|
dataset.Open({file_name}, true, 4, column_list, ops, 0, true);
|
||||||
|
dataset.Launch();
|
||||||
|
|
||||||
|
uint32_t count = 0;
|
||||||
|
while (true) {
|
||||||
|
auto x = dataset.GetNext();
|
||||||
|
if (x.empty()) break;
|
||||||
|
for (auto &j : x) {
|
||||||
|
for (auto &item : std::get<1>(j).items()) {
|
||||||
|
MS_LOG(INFO) << "key: " << item.key() << ", value: " << item.value().dump();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
count++;
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(count == 2);
|
||||||
dataset.Close();
|
dataset.Close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue