forked from mindspore-Ecosystem/mindspore
!48348 修复median CPU算子当输入为空算子时错误输出的问题
Merge pull request !48348 from wangtongyu6/fix_median_bug
This commit is contained in:
commit
9950d262c0
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "plugin/device/cpu/kernel/median_cpu_kernel.h"
|
||||
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
|
||||
|
@ -57,6 +58,16 @@ int MedianCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::
|
|||
input_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
|
||||
input_dim_ = input_shape_.size();
|
||||
input_num_elements_ = 1;
|
||||
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
std::vector<size_t> src_shape;
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(src_shape), LongToSize);
|
||||
size_t input_element_num = std::accumulate(src_shape.begin(), src_shape.end(), size_t(1), std::multiplies<size_t>());
|
||||
is_null_input_ = (input_element_num == 0);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "For '" << kernel_name_ << "', input tensor[0] got 'shapes[" << kIndex0 << "]' is "
|
||||
<< input_element_num;
|
||||
return KRET_OK;
|
||||
}
|
||||
if (global_median_) {
|
||||
if (axis_ != 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', when 'global_median' is True, the 'axis' must be 0, but got "
|
||||
|
@ -107,6 +118,9 @@ const std::vector<std::pair<KernelAttr, MedianCpuKernelMod::KernelRunFunc>> &Med
|
|||
template <typename T>
|
||||
bool MedianCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
if (global_median_ == false) {
|
||||
return MedianCompute<T>(inputs, outputs);
|
||||
} else {
|
||||
|
|
|
@ -63,6 +63,7 @@ class MedianCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<M
|
|||
size_t input_num_elements_;
|
||||
size_t output_num_elements_;
|
||||
size_t input_dim_;
|
||||
bool is_null_input_;
|
||||
template <typename T>
|
||||
bool GlobalMedianCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "plugin/device/cpu/kernel/median_grad_cpu_kernel.h"
|
||||
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
|
||||
|
@ -55,7 +56,16 @@ int MedianGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const s
|
|||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
std::vector<size_t> src_shape;
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(src_shape), LongToSize);
|
||||
size_t input_element_num = std::accumulate(src_shape.begin(), src_shape.end(), size_t(1), std::multiplies<size_t>());
|
||||
is_null_input_ = (input_element_num == 0);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "For '" << kernel_name_ << "', input tensor[0] got 'shapes[" << kIndex0 << "]' is "
|
||||
<< input_element_num;
|
||||
return KRET_OK;
|
||||
}
|
||||
input0_shape_ = inputs[kIndex0]->GetDeviceShapeAdaptively();
|
||||
input1_shape_ = inputs[kIndex1]->GetDeviceShapeAdaptively();
|
||||
input2_shape_ = inputs[kIndex2]->GetDeviceShapeAdaptively();
|
||||
|
@ -138,6 +148,9 @@ template <typename T1, typename T2>
|
|||
bool MedianGradCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
if (global_median_ == false) {
|
||||
return MedianGradCompute<T1, T2>(inputs, outputs);
|
||||
} else {
|
||||
|
|
|
@ -66,6 +66,7 @@ class MedianGradCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelp
|
|||
size_t input2_dim_;
|
||||
size_t input0_num_elements_;
|
||||
size_t input1_num_elements_;
|
||||
bool is_null_input_;
|
||||
template <typename T1, typename T2>
|
||||
bool GlobalMedianGradCompute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
|
||||
template <typename T1, typename T2>
|
||||
|
|
Loading…
Reference in New Issue