forked from mindspore-Ecosystem/mindspore
!35369 accelerate reading tfrecord file
Merge pull request !35369 from luoyang/fix-tf
This commit is contained in:
commit
c20e2e6083
|
@ -198,7 +198,7 @@
|
|||
**参数:**
|
||||
|
||||
- **estimate** (bool) - 如果 `estimate` 为 False,将返回数据集第一条数据的shape。
|
||||
否则将遍历整个数据集以获取数据集的真实shape信息,其中动态变化的维度将被标记为-1(可用于动态shape数据集场景)。
|
||||
否则将遍历整个数据集以获取数据集的真实shape信息,其中动态变化的维度将被标记为-1(可用于动态shape数据集场景),默认值:False。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
|
|
@ -167,12 +167,54 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
std::vector<std::string> file(1, it.value());
|
||||
int64_t num = CountTotalRowsSectioned(file, 0, 1);
|
||||
filename_numrows_[it.value()] = num;
|
||||
num_rows_ += num;
|
||||
// Do row count in parallel
|
||||
std::mutex thread_mutex;
|
||||
auto rowcount_task = [this, &thread_mutex](int st, int ed) {
|
||||
auto iter = filename_index_->begin();
|
||||
auto end_iter = filename_index_->begin();
|
||||
std::advance(iter, st);
|
||||
std::advance(end_iter, ed);
|
||||
for (; iter != end_iter; ++iter) {
|
||||
std::vector<std::string> file(1, iter.value());
|
||||
int64_t num = CountTotalRowsSectioned(file, 0, 1);
|
||||
std::lock_guard<std::mutex> lock(thread_mutex);
|
||||
filename_numrows_[iter.value()] = num;
|
||||
num_rows_ += num;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::future<void>> async_tasks;
|
||||
int32_t threads = GlobalContext::config_manager()->num_cpu_threads();
|
||||
// constrain the workers
|
||||
int32_t kThreadCount = 8;
|
||||
threads = threads < kThreadCount ? threads : kThreadCount;
|
||||
|
||||
if (threads > filename_index_->size()) {
|
||||
threads = filename_index_->size();
|
||||
}
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(
|
||||
threads > 0,
|
||||
"Invalid threads number, TFRecordDataset should own more than 0 thread, but got " + std::to_string(threads) + ".");
|
||||
|
||||
int64_t chunk_size = filename_index_->size() / threads;
|
||||
int64_t remainder = filename_index_->size() % threads;
|
||||
int64_t begin = 0;
|
||||
int64_t end = begin;
|
||||
for (int i = 0; i < threads; i++) {
|
||||
end += chunk_size;
|
||||
if (remainder > 0) {
|
||||
end++;
|
||||
remainder--;
|
||||
}
|
||||
async_tasks.emplace_back(std::async(std::launch::async, rowcount_task, begin, end));
|
||||
begin = end;
|
||||
}
|
||||
|
||||
// Wait until all tasks have been finished
|
||||
for (int i = 0; i < async_tasks.size(); i++) {
|
||||
async_tasks[i].get();
|
||||
}
|
||||
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
||||
if (num_rows_per_shard_ == 0) {
|
||||
std::stringstream ss;
|
||||
|
|
|
@ -115,7 +115,11 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
|
||||
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
|
||||
// constrain the workers
|
||||
int32_t kThreadCount = 8;
|
||||
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files, thread_count));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||
|
@ -147,16 +151,20 @@ Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows;
|
||||
constexpr int64_t kThreadCount = 8;
|
||||
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
|
||||
// constrain the workers
|
||||
int32_t kThreadCount = 8;
|
||||
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
|
||||
|
||||
// By default, TFRecord will do file-based sharding. But when cache is injected, it will be row-based sharding.
|
||||
if (!shard_equal_rows_ && !IsCached()) {
|
||||
// Data will be sharded by file
|
||||
std::vector<std::string> shard_file_list;
|
||||
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, kThreadCount, estimate));
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, thread_count, estimate));
|
||||
} else {
|
||||
// Data will be sharded by row
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, kThreadCount, estimate));
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, thread_count, estimate));
|
||||
num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0)));
|
||||
}
|
||||
*dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows;
|
||||
|
|
|
@ -1556,7 +1556,7 @@ class Dataset:
|
|||
Args:
|
||||
estimate (bool): If `estimate` is False, will return the shapes of first data row.
|
||||
Otherwise, will iterate the whole dataset and return the estimated shapes of data row,
|
||||
where dynamic shape is marked as -1 (used in dynamic data shapes scenario).
|
||||
where dynamic shape is marked as -1 (used in dynamic data shapes scenario). Default: False.
|
||||
|
||||
Returns:
|
||||
list, list of shapes of each column.
|
||||
|
|
Loading…
Reference in New Issue