forked from mindspore-Ecosystem/mindspore
fix kernel choose logic about sort_gpu_kernel
This commit is contained in:
parent
a7f81dcf82
commit
88dab2f3f3
|
@ -53,17 +53,6 @@ class SortGpuKernelMod : public NativeGpuKernelMod {
|
|||
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
|
||||
use_fast_ = input_shape_[axis_] <= sort_dim_thres_;
|
||||
if (use_fast_) {
|
||||
return fast_sort_kernel_->Resize(base_operator, inputs, outputs, inputsOnHost);
|
||||
} else {
|
||||
if (!old_kernel_support_) {
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
MS_LOG(ERROR) << "Only support input datatype in [float16, float32] for sort kernel, but got "
|
||||
<< kernel_attr.GetInputAttr(0).dtype << " in KernelAttr.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
}
|
||||
auto kernel_name = base_operator->GetPrim()->name();
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
|
||||
if (is_null_input_) {
|
||||
|
@ -95,6 +84,18 @@ class SortGpuKernelMod : public NativeGpuKernelMod {
|
|||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
|
||||
use_fast_ = input_shape_[axis_] > 0 && input_shape_[axis_] <= sort_dim_thres_;
|
||||
if (use_fast_) {
|
||||
return fast_sort_kernel_->Resize(base_operator, inputs, outputs, inputsOnHost);
|
||||
} else {
|
||||
if (!old_kernel_support_) {
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
MS_LOG(ERROR) << "Only support input datatype in [float16, float32] for sort kernel, but got "
|
||||
<< kernel_attr.GetInputAttr(0).dtype << " in KernelAttr.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
perm_.resize(input_rank_);
|
||||
std::iota(perm_.begin(), perm_.end(), 0);
|
||||
std::swap(perm_[input_rank_ - 1], perm_[axis_]);
|
||||
|
|
Loading…
Reference in New Issue