!48348 修复median CPU算子当输入为空算子时错误输出的问题

Merge pull request !48348 from wangtongyu6/fix_median_bug
This commit is contained in:
i-robot 2023-02-09 10:19:39 +00:00 committed by Gitee
commit 9950d262c0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 30 additions and 1 deletions

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/median_cpu_kernel.h"
#include <functional>
#include <algorithm>
#include <type_traits>
@ -57,6 +58,16 @@ int MedianCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
input_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
input_dim_ = input_shape_.size();
input_num_elements_ = 1;
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
std::vector<size_t> src_shape;
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(src_shape), LongToSize);
size_t input_element_num = std::accumulate(src_shape.begin(), src_shape.end(), size_t(1), std::multiplies<size_t>());
is_null_input_ = (input_element_num == 0);
if (is_null_input_) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "', input tensor[0] got 'shapes[" << kIndex0 << "]' is "
<< input_element_num;
return KRET_OK;
}
if (global_median_) {
if (axis_ != 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', when 'global_median' is True, the 'axis' must be 0, but got "
@ -107,6 +118,9 @@ const std::vector<std::pair<KernelAttr, MedianCpuKernelMod::KernelRunFunc>> &Med
template <typename T>
bool MedianCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (is_null_input_) {
return true;
}
if (global_median_ == false) {
return MedianCompute<T>(inputs, outputs);
} else {

View File

@ -63,6 +63,7 @@ class MedianCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<M
size_t input_num_elements_;
size_t output_num_elements_;
size_t input_dim_;
bool is_null_input_;
template <typename T>
bool GlobalMedianCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>

View File

@ -16,6 +16,7 @@
#include "plugin/device/cpu/kernel/median_grad_cpu_kernel.h"
#include <functional>
#include <algorithm>
#include <type_traits>
@ -55,7 +56,16 @@ int MedianGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
std::vector<size_t> src_shape;
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(src_shape), LongToSize);
size_t input_element_num = std::accumulate(src_shape.begin(), src_shape.end(), size_t(1), std::multiplies<size_t>());
is_null_input_ = (input_element_num == 0);
if (is_null_input_) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "', input tensor[0] got 'shapes[" << kIndex0 << "]' is "
<< input_element_num;
return KRET_OK;
}
input0_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
input1_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
input2_shape_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
@ -138,6 +148,9 @@ template <typename T1, typename T2>
bool MedianGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (is_null_input_) {
return true;
}
if (global_median_ == false) {
return MedianGradCompute<T1, T2>(inputs, outputs);
} else {

View File

@ -66,6 +66,7 @@ class MedianGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelp
size_t input2_dim_;
size_t input0_num_elements_;
size_t input1_num_elements_;
bool is_null_input_;
template <typename T1, typename T2>
bool GlobalMedianGradCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
template <typename T1, typename T2>