forked from mindspore-Ecosystem/mindspore
!20322 add fallback strategy
Merge pull request !20322 from zhujingxuan/TopK
This commit is contained in:
commit
6654305346
|
@ -50,13 +50,26 @@ void TopKCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const st
|
|||
std::vector<size_t> idx(inner_size_);
|
||||
auto base_input = i * inner_size_;
|
||||
std::iota(idx.begin(), idx.end(), base_input);
|
||||
std::nth_element(idx.begin(), idx.begin() + SizeToLong(k_num), idx.end(),
|
||||
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
||||
auto base_output = i * k_num;
|
||||
|
||||
if (sorted_) {
|
||||
std::stable_sort(idx.begin(), idx.begin() + SizeToLong(k_num),
|
||||
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
||||
constexpr float fraction = 0.5;
|
||||
const size_t threshold = inner_size_ * fraction;
|
||||
// fall back to stable_sort
|
||||
if (k_num > threshold) {
|
||||
stable_sort(idx.begin(), idx.end(),
|
||||
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
||||
} else {
|
||||
nth_element(idx.begin(), idx.begin() + SizeToLong(k_num), idx.end(),
|
||||
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
||||
stable_sort(idx.begin(), idx.begin() + SizeToLong(k_num),
|
||||
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
||||
}
|
||||
} else {
|
||||
nth_element(idx.begin(), idx.begin() + SizeToLong(k_num), idx.end(),
|
||||
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
|
||||
}
|
||||
|
||||
auto base_output = i * k_num;
|
||||
for (size_t j = 0; j < k_num; ++j) {
|
||||
indices[base_output + j] = SizeToInt(idx[j]) - SizeToInt(base_input);
|
||||
output[base_output + j] = input[idx[j]];
|
||||
|
|
Loading…
Reference in New Issue