!39379 revert acceleration of tf CalculateNumRowsPerShard

Merge pull request !39379 from luoyang/revert_tf
This commit is contained in:
i-robot 2022-08-04 01:28:45 +00:00 committed by Gitee
commit 3cd07c93a6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 9 additions and 59 deletions

View File

@ -152,54 +152,12 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
return Status::OK();
}
// Do row count in parallel
std::mutex thread_mutex;
auto rowcount_task = [this, &thread_mutex](int st, int ed) {
auto iter = filename_index_->begin();
auto end_iter = filename_index_->begin();
std::advance(iter, st);
std::advance(end_iter, ed);
for (; iter != end_iter; ++iter) {
std::vector<std::string> file(1, iter.value());
int64_t num = CountTotalRowsSectioned(file, 0, 1);
std::lock_guard<std::mutex> lock(thread_mutex);
filename_numrows_[iter.value()] = num;
num_rows_ += num;
}
};
std::vector<std::future<void>> async_tasks;
int32_t threads = GlobalContext::config_manager()->num_cpu_threads();
// constrain the workers
int32_t kThreadCount = 8;
threads = threads < kThreadCount ? threads : kThreadCount;
if (threads > filename_index_->size()) {
threads = filename_index_->size();
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
std::vector<std::string> file(1, it.value());
int64_t num = CountTotalRowsSectioned(file, 0, 1);
filename_numrows_[it.value()] = num;
num_rows_ += num;
}
CHECK_FAIL_RETURN_SYNTAX_ERROR(
threads > 0,
"Invalid threads number, TFRecordDataset should own more than 0 thread, but got " + std::to_string(threads) + ".");
int64_t chunk_size = filename_index_->size() / threads;
int64_t remainder = filename_index_->size() % threads;
int64_t begin = 0;
int64_t end = begin;
for (int i = 0; i < threads; i++) {
end += chunk_size;
if (remainder > 0) {
end++;
remainder--;
}
async_tasks.emplace_back(std::async(std::launch::async, rowcount_task, begin, end));
begin = end;
}
// Wait until all tasks have been finished
for (int i = 0; i < async_tasks.size(); i++) {
async_tasks[i].get();
}
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
if (num_rows_per_shard_ == 0) {
std::stringstream ss;

View File

@ -112,11 +112,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
int64_t num_rows = 0;
// First, get the number of rows in the dataset
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
// constrain the workers
int32_t kThreadCount = 8;
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files, thread_count));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
// Add the shuffle op after this op
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
@ -148,20 +144,16 @@ Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
return Status::OK();
}
int64_t num_rows;
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
// constrain the workers
int32_t kThreadCount = 8;
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
constexpr int64_t kThreadCount = 8;
// By default, TFRecord will do file-based sharding. But when cache is injected, it will be row-based sharding.
if (!shard_equal_rows_ && !IsCached()) {
// Data will be sharded by file
std::vector<std::string> shard_file_list;
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, thread_count, estimate));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, kThreadCount, estimate));
} else {
// Data will be sharded by row
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, thread_count, estimate));
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, kThreadCount, estimate));
num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0)));
}
*dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows;