forked from mindspore-Ecosystem/mindspore
!39379 revert acceleration of tf CalculateNumRowsPerShard
Merge pull request !39379 from luoyang/revert_tf
This commit is contained in:
commit
3cd07c93a6
|
@ -152,54 +152,12 @@ Status TFReaderOp::CalculateNumRowsPerShard() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Do row count in parallel
|
||||
std::mutex thread_mutex;
|
||||
auto rowcount_task = [this, &thread_mutex](int st, int ed) {
|
||||
auto iter = filename_index_->begin();
|
||||
auto end_iter = filename_index_->begin();
|
||||
std::advance(iter, st);
|
||||
std::advance(end_iter, ed);
|
||||
for (; iter != end_iter; ++iter) {
|
||||
std::vector<std::string> file(1, iter.value());
|
||||
int64_t num = CountTotalRowsSectioned(file, 0, 1);
|
||||
std::lock_guard<std::mutex> lock(thread_mutex);
|
||||
filename_numrows_[iter.value()] = num;
|
||||
num_rows_ += num;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::future<void>> async_tasks;
|
||||
int32_t threads = GlobalContext::config_manager()->num_cpu_threads();
|
||||
// constrain the workers
|
||||
int32_t kThreadCount = 8;
|
||||
threads = threads < kThreadCount ? threads : kThreadCount;
|
||||
|
||||
if (threads > filename_index_->size()) {
|
||||
threads = filename_index_->size();
|
||||
for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) {
|
||||
std::vector<std::string> file(1, it.value());
|
||||
int64_t num = CountTotalRowsSectioned(file, 0, 1);
|
||||
filename_numrows_[it.value()] = num;
|
||||
num_rows_ += num;
|
||||
}
|
||||
CHECK_FAIL_RETURN_SYNTAX_ERROR(
|
||||
threads > 0,
|
||||
"Invalid threads number, TFRecordDataset should own more than 0 thread, but got " + std::to_string(threads) + ".");
|
||||
|
||||
int64_t chunk_size = filename_index_->size() / threads;
|
||||
int64_t remainder = filename_index_->size() % threads;
|
||||
int64_t begin = 0;
|
||||
int64_t end = begin;
|
||||
for (int i = 0; i < threads; i++) {
|
||||
end += chunk_size;
|
||||
if (remainder > 0) {
|
||||
end++;
|
||||
remainder--;
|
||||
}
|
||||
async_tasks.emplace_back(std::async(std::launch::async, rowcount_task, begin, end));
|
||||
begin = end;
|
||||
}
|
||||
|
||||
// Wait until all tasks have been finished
|
||||
for (int i = 0; i < async_tasks.size(); i++) {
|
||||
async_tasks[i].get();
|
||||
}
|
||||
|
||||
num_rows_per_shard_ = static_cast<int64_t>(std::ceil(num_rows_ * 1.0 / num_devices_));
|
||||
if (num_rows_per_shard_ == 0) {
|
||||
std::stringstream ss;
|
||||
|
|
|
@ -112,11 +112,7 @@ Status TFRecordNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
|
|||
int64_t num_rows = 0;
|
||||
|
||||
// First, get the number of rows in the dataset
|
||||
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
|
||||
// constrain the workers
|
||||
int32_t kThreadCount = 8;
|
||||
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files, thread_count));
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, sorted_dir_files));
|
||||
|
||||
// Add the shuffle op after this op
|
||||
RETURN_IF_NOT_OK(AddShuffleOp(sorted_dir_files.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op));
|
||||
|
@ -148,20 +144,16 @@ Status TFRecordNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
|
|||
return Status::OK();
|
||||
}
|
||||
int64_t num_rows;
|
||||
int32_t thread_count = GlobalContext::config_manager()->num_cpu_threads();
|
||||
// constrain the workers
|
||||
int32_t kThreadCount = 8;
|
||||
thread_count = thread_count < kThreadCount ? thread_count : kThreadCount;
|
||||
|
||||
constexpr int64_t kThreadCount = 8;
|
||||
// By default, TFRecord will do file-based sharding. But when cache is injected, it will be row-based sharding.
|
||||
if (!shard_equal_rows_ && !IsCached()) {
|
||||
// Data will be sharded by file
|
||||
std::vector<std::string> shard_file_list;
|
||||
RETURN_IF_NOT_OK(GetShardFileList(&shard_file_list));
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, thread_count, estimate));
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, shard_file_list, kThreadCount, estimate));
|
||||
} else {
|
||||
// Data will be sharded by row
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, thread_count, estimate));
|
||||
RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, dataset_files_, kThreadCount, estimate));
|
||||
num_rows = static_cast<int64_t>(ceil(num_rows / (num_shards_ * 1.0)));
|
||||
}
|
||||
*dataset_size = num_samples_ > 0 ? std::min(num_rows, num_samples_) : num_rows;
|
||||
|
|
Loading…
Reference in New Issue