fix kernel choose logic about sort_gpu_kernel

This commit is contained in:
yeyunpeng2020 2023-02-08 16:54:29 +08:00
parent a7f81dcf82
commit 88dab2f3f3
1 changed files with 12 additions and 11 deletions

View File

@ -53,17 +53,6 @@ class SortGpuKernelMod : public NativeGpuKernelMod {
input_shape_ = inputs[0]->GetShapeVector();
use_fast_ = input_shape_[axis_] <= sort_dim_thres_;
if (use_fast_) {
return fast_sort_kernel_->Resize(base_operator, inputs, outputs, inputsOnHost);
} else {
if (!old_kernel_support_) {
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
MS_LOG(ERROR) << "Only support input datatype in [float16, float32] for sort kernel, but got "
<< kernel_attr.GetInputAttr(0).dtype << " in KernelAttr.";
return KRET_RESIZE_FAILED;
}
}
auto kernel_name = base_operator->GetPrim()->name();
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
if (is_null_input_) {
@ -95,6 +84,18 @@ class SortGpuKernelMod : public NativeGpuKernelMod {
return KRET_RESIZE_FAILED;
}
use_fast_ = input_shape_[axis_] > 0 && input_shape_[axis_] <= sort_dim_thres_;
if (use_fast_) {
return fast_sort_kernel_->Resize(base_operator, inputs, outputs, inputsOnHost);
} else {
if (!old_kernel_support_) {
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
MS_LOG(ERROR) << "Only support input datatype in [float16, float32] for sort kernel, but got "
<< kernel_attr.GetInputAttr(0).dtype << " in KernelAttr.";
return KRET_RESIZE_FAILED;
}
}
perm_.resize(input_rank_);
std::iota(perm_.begin(), perm_.end(), 0);
std::swap(perm_[input_rank_ - 1], perm_[axis_]);