forked from mindspore-Ecosystem/mindspore
!47517 fix reducesum bug
Merge pull request !47517 from 李林杰/0104_fix_reducesum_by_lh
This commit is contained in:
commit
14bef4c082
|
@ -41,6 +41,15 @@ const char *const kReduceSum = "ReduceSum";
|
|||
} \
|
||||
break; \
|
||||
}
|
||||
// Expands to a `case (DTYPE)` label for a switch statement: it calls ReduceSumDedupAxes<TYPE>(CTX),
// and if the returned status is not KERNEL_STATUS_OK it logs an error and returns that status from
// the enclosing function; otherwise it breaks out of the switch.
#define REDUCESUM_DEDUP_AXES(DTYPE, TYPE, CTX) \
|
||||
case (DTYPE): { \
|
||||
uint32_t result = ReduceSumDedupAxes<TYPE>(CTX); \
|
||||
if (result != KERNEL_STATUS_OK) { \
|
||||
KERNEL_LOG_ERROR("ReduceSum kernel deduplicate axes failed."); \
|
||||
return result; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace aicpu {
|
||||
|
@ -48,6 +57,16 @@ uint32_t ReduceSumCpuKernel::Compute(CpuKernelContext &ctx) {
|
|||
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kReduceSumInputNum, kReduceSumOutputNum), "[%s] check input and output failed.",
|
||||
kReduceSum);
|
||||
KERNEL_HANDLE_ERROR(ReduceSumCheck(ctx), "[%s] check params failed.", kReduceSum);
|
||||
|
||||
// Dispatch on the data type of the axes tensor (input 1) so that ReduceSumDedupAxes reads the
// axes buffer with the matching element type (int32_t or int64_t).
auto axes_type = ctx.Input(1)->GetDataType();
|
||||
switch (axes_type) {
|
||||
REDUCESUM_DEDUP_AXES(DT_INT32, int32_t, ctx)
|
||||
REDUCESUM_DEDUP_AXES(DT_INT64, int64_t, ctx)
|
||||
default:
|
||||
// The format string needs a %s, otherwise the DTypeStr argument is passed but never printed.
KERNEL_LOG_ERROR("ReduceSum kernel axes data type [%s] not support.", DTypeStr(axes_type).c_str());
|
||||
return KERNEL_STATUS_PARAM_INVALID;
|
||||
}
|
||||
|
||||
auto input_data_type = ctx.Input(0)->GetDataType();
|
||||
switch (input_data_type) {
|
||||
REDUCESUM_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
|
||||
|
@ -69,7 +88,7 @@ uint32_t ReduceSumCpuKernel::Compute(CpuKernelContext &ctx) {
|
|||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCheck(CpuKernelContext &ctx) const {
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCheck(const CpuKernelContext &ctx) const {
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get input failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.");
|
||||
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get output failed.");
|
||||
|
@ -81,7 +100,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumCheck(CpuKernelContext &ctx) const {
|
|||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
template <typename T>
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCompute(CpuKernelContext &ctx) {
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCompute(const CpuKernelContext &ctx) {
|
||||
std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
|
@ -89,8 +108,8 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute(CpuKernelContext &ctx) {
|
|||
output_data[0] = input_data[0];
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
if (axes_data == nullptr) {
|
||||
|
||||
if (axes_.empty()) {
|
||||
int64_t data_num = ctx.Input(0)->NumElements();
|
||||
auto accumulator = static_cast<T>(0);
|
||||
for (int64_t i = 0; i < data_num; i++) {
|
||||
|
@ -99,11 +118,10 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute(CpuKernelContext &ctx) {
|
|||
output_data[0] = accumulator;
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
std::vector<int64_t> axes;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumDedupAxes(ctx, axes), "ReduceSum deduplicate failed.");
|
||||
|
||||
int64_t output_num = ctx.Output(0)->NumElements();
|
||||
uint32_t axes_idx = 0;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumOneAxes<T>(input_data, input_shape, output_data, output_num, axes, axes_idx),
|
||||
KERNEL_HANDLE_ERROR(ReduceSumOneAxes<T>(input_data, input_shape, output_data, output_num, axes_, axes_idx),
|
||||
"Reduce sum compute failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
@ -141,7 +159,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumOneAxes(const T *input_data, std::vector<i
|
|||
return result;
|
||||
}
|
||||
template <typename T, typename T2>
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCompute2(CpuKernelContext &ctx) {
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumCompute2(const CpuKernelContext &ctx) {
|
||||
std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
|
||||
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
|
||||
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
|
||||
|
@ -149,9 +167,9 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute2(CpuKernelContext &ctx) {
|
|||
output_data[0] = std::complex<T2>(input_data[0].real(), input_data[0].imag());
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
|
||||
int64_t input_num = ctx.Input(0)->NumElements();
|
||||
if (axes_data == nullptr) {
|
||||
if (axes_.empty()) {
|
||||
auto accumulator_real = static_cast<T2>(0);
|
||||
auto accumulator_imag = static_cast<T2>(0);
|
||||
for (int64_t i = 0; i < input_num; i++) {
|
||||
|
@ -161,12 +179,11 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute2(CpuKernelContext &ctx) {
|
|||
output_data[0] = std::complex<T2>(accumulator_real, accumulator_imag);
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
std::vector<int64_t> axes;
|
||||
KERNEL_HANDLE_ERROR(ReduceSumDedupAxes(ctx, axes), "ReduceSum deduplicate failed.");
|
||||
|
||||
int64_t output_num = ctx.Output(0)->NumElements();
|
||||
uint32_t axes_idx = 0;
|
||||
KERNEL_HANDLE_ERROR(
|
||||
(ReduceSumOneAxes2<T, T2>(input_data, input_num, input_shape, output_data, output_num, axes, axes_idx)),
|
||||
(ReduceSumOneAxes2<T, T2>(input_data, input_num, input_shape, output_data, output_num, axes_, axes_idx)),
|
||||
"Reduce sum compute failed.");
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
@ -218,9 +235,11 @@ uint32_t ReduceSumCpuKernel::ReduceSumOneAxes2(const T *input_data, int64_t inpu
|
|||
}
|
||||
return result;
|
||||
}
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx, std::vector<int64_t> &axes) {
|
||||
|
||||
template <typename T1>
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx) {
|
||||
int32_t rank = ctx.Input(0)->GetTensorShape()->GetDims();
|
||||
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
|
||||
auto axes_data = reinterpret_cast<T1 *>(ctx.Input(1)->GetData());
|
||||
int64_t axes_num = ctx.Input(1)->NumElements();
|
||||
for (int64_t i = 0; i < axes_num; i++) {
|
||||
int32_t axis = axes_data[i];
|
||||
|
@ -229,13 +248,13 @@ uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx, std::vect
|
|||
if (axis < 0) {
|
||||
axis += rank;
|
||||
}
|
||||
axes.push_back(axis);
|
||||
axes_.push_back(axis);
|
||||
}
|
||||
int64_t j = 1;
|
||||
while (j < axes_num) {
|
||||
std::vector<int64_t>::iterator iter = find(axes.begin(), axes.begin() + j, axes[j]);
|
||||
if (iter != axes.begin() + j) {
|
||||
axes.erase(iter);
|
||||
std::vector<int64_t>::iterator iter = find(axes_.begin(), axes_.begin() + j, axes_[j]);
|
||||
if (iter != axes_.begin() + j) {
|
||||
axes_.erase(iter);
|
||||
axes_num--;
|
||||
} else {
|
||||
j++;
|
||||
|
@ -243,6 +262,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx, std::vect
|
|||
}
|
||||
return KERNEL_STATUS_OK;
|
||||
}
|
||||
|
||||
uint32_t ReduceSumCpuKernel::ReduceSumParseAxes(std::vector<int64_t> &input_shape, std::vector<int64_t> &axes,
|
||||
uint32_t &axes_idx, int64_t &inner, int64_t &outer,
|
||||
int64_t &depth) const {
|
||||
|
|
|
@ -28,26 +28,29 @@ class ReduceSumCpuKernel : public CpuKernel {
|
|||
uint32_t Compute(CpuKernelContext &ctx) override;
|
||||
|
||||
private:
|
||||
uint32_t ReduceSumCheck(CpuKernelContext &ctx) const;
|
||||
uint32_t ReduceSumCheck(const CpuKernelContext &ctx) const;
|
||||
|
||||
template <typename T>
|
||||
uint32_t ReduceSumCompute(CpuKernelContext &ctx);
|
||||
uint32_t ReduceSumCompute(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T>
|
||||
uint32_t ReduceSumOneAxes(const T *input_data, std::vector<int64_t> &input_shape, T *output_data, int64_t output_num,
|
||||
std::vector<int64_t> &axes, uint32_t &axes_idx);
|
||||
|
||||
template <typename T, typename T2>
|
||||
uint32_t ReduceSumCompute2(CpuKernelContext &ctx);
|
||||
uint32_t ReduceSumCompute2(const CpuKernelContext &ctx);
|
||||
|
||||
template <typename T, typename T2>
|
||||
uint32_t ReduceSumOneAxes2(const T *input_data, int64_t input_num, std::vector<int64_t> input_shape, T *output_data,
|
||||
int64_t output_num, std::vector<int64_t> &axes, uint32_t &axes_idx);
|
||||
|
||||
uint32_t ReduceSumDedupAxes(CpuKernelContext &ctx, std::vector<int64_t> &axes);
|
||||
template <typename T1>
|
||||
uint32_t ReduceSumDedupAxes(CpuKernelContext &ctx);
|
||||
|
||||
uint32_t ReduceSumParseAxes(std::vector<int64_t> &input_shape, std::vector<int64_t> &axes, uint32_t &axes_idx,
|
||||
int64_t &inner, int64_t &outer, int64_t &depth) const;
|
||||
|
||||
std::vector<int64_t> axes_;
|
||||
};
|
||||
} // namespace aicpu
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue