!20322 add fallback strategy

Merge pull request !20322 from zhujingxuan/TopK
This commit is contained in:
i-robot 2021-07-16 07:30:48 +00:00 committed by Gitee
commit 6654305346
1 changed files with 18 additions and 5 deletions

View File

@ -50,13 +50,26 @@ void TopKCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const st
std::vector<size_t> idx(inner_size_);
auto base_input = i * inner_size_;
std::iota(idx.begin(), idx.end(), base_input);
std::nth_element(idx.begin(), idx.begin() + SizeToLong(k_num), idx.end(),
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
auto base_output = i * k_num;
if (sorted_) {
std::stable_sort(idx.begin(), idx.begin() + SizeToLong(k_num),
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
constexpr float fraction = 0.5;
const size_t threshold = inner_size_ * fraction;
// fall back to stable_sort
if (k_num > threshold) {
stable_sort(idx.begin(), idx.end(),
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
} else {
nth_element(idx.begin(), idx.begin() + SizeToLong(k_num), idx.end(),
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
stable_sort(idx.begin(), idx.begin() + SizeToLong(k_num),
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
}
} else {
nth_element(idx.begin(), idx.begin() + SizeToLong(k_num), idx.end(),
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
}
auto base_output = i * k_num;
for (size_t j = 0; j < k_num; ++j) {
indices[base_output + j] = SizeToInt(idx[j]) - SizeToInt(base_input);
output[base_output + j] = input[idx[j]];