!35369 accelerate reading tfrecord file

Merge pull request !35369 from luoyang/fix-tf
This commit is contained in:
i-robot 2022-06-18 02:39:29 +00:00 committed by Gitee
commit c20e2e6083
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 61 additions and 11 deletions

View File

@ -198,7 +198,7 @@
**参数:**
- **estimate** (bool) - 如果 `estimate` 为 False将返回数据集第一条数据的shape。
否则将遍历整个数据集以获取数据集的真实shape信息其中动态变化的维度将被标记为-1可用于动态shape数据集场景
否则将遍历整个数据集以获取数据集的真实shape信息其中动态变化的维度将被标记为-1可用于动态shape数据集场景默认值False
**返回:**

View File

@ -167,12 +167,54 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
return Status::OK();
}
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
std::vector<std::string> file(1, it.value());
int64_t num = CountTotalRowsSectioned(file, 0, 1);
filename_numrows_[it.value()] = num;
num_rows_ += num;
// Do row count in parallel
std::mutex thread_mutex;
auto rowcount_task = [this, &thread_mutex](int st, int ed) {
auto iter = filename_index_->begin();
auto end_iter = filename_index_->begin();
std::advance(iter, st);
std::advance(end_iter, ed);
for (; iter != end_iter; ++iter) {
std::vector<std::string> file(1, iter.value());
int64_t num = CountTotalRowsSectioned(file, 0, 1);
std::lock_guard<std::mutex> lock(thread_mutex);
filename_numrows_[iter.value()] = num;
num_rows_ += num;
}
};
std::vector<std::future<void>> async_tasks;
int32_t threads = GlobalContext::config_manager()->num_cpu_threads();
// constrain the workers
int32_t kThreadCount = 8;
threads = threads < kThreadCount ? threads : kThreadCount;
if (threads > filename_index_->size()) {
threads = filename_index_->size();
}
CHECK_FAIL_RETURN_SYNTAX_ERROR(
threads > 0,
"Invalid threads number, TFRecordDataset should own more than 0 thread, but got " + std::to_string(threads) + ".");
int64_t chunk_size = filename_index_->size() / threads;
int64_t remainder = filename_index_->size() % threads;
int64_t begin = 0;
int64_t end = begin;
for (int i = 0; i < threads; i++) {
end += chunk_size;
if (remainder > 0) {
end++;
remainder--;
}
async_tasks.emplace_back(std::async(std::launch::async, rowcount_task, begin, end));
begin = end;
}
// Wait until all tasks have been finished
for (int i = 0; i < async_tasks.size(); i++) {
async_tasks[i].get();
}
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
if (num_rows_per_shard_ == 0) {
std::stringstream ss;

View File

@ -115,7 +115,11 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
int64_t num_rows = 0;
// First, get the number of rows in the dataset
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
// constrain the workers
int32_t kThreadCount = 8;
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files, thread_count));
// Add the shuffle op after this op
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
@ -147,16 +151,20 @@ Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
return Status::OK();
}
int64_t num_rows;
constexpr int64_t kThreadCount = 8;
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
// constrain the workers
int32_t kThreadCount = 8;
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
// By default, TFRecord will do file-based sharding. But when cache is injected, it will be row-based sharding.
if (!shard_equal_rows_ && !IsCached()) {
// Data will be sharded by file
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, kThreadCount, estimate));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, thread_count, estimate));
} else {
// Data will be sharded by row
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, kThreadCount, estimate));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, thread_count, estimate));
num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0)));
}
*dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows;

View File

@ -1556,7 +1556,7 @@ class Dataset:
Args:
estimate (bool): If `estimate` is False, will return the shapes of first data row.
Otherwise, will iterate the whole dataset and return the estimated shapes of data row,
where dynamic shape is marked as -1 (used in dynamic data shapes scenario).
where dynamic shape is marked as -1 (used in dynamic data shapes scenario). Default: False.
Returns:
list, list of shapes of each column.