!47517 fix reducesum bug

Merge pull request !47517 from 李林杰/0104_fix_reducesum_by_lh
This commit is contained in:
i-robot 2023-01-05 02:09:33 +00:00 committed by Gitee
commit 14bef4c082
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 47 additions and 24 deletions

View File

@ -41,6 +41,15 @@ const char *const kReduceSum = "ReduceSum";
} \
break; \
}
#define REDUCESUM_DEDUP_AXES(DTYPE, TYPE, CTX) \
case (DTYPE): { \
uint32_t result = ReduceSumDedupAxes<TYPE>(CTX); \
if (result != KERNEL_STATUS_OK) { \
KERNEL_LOG_ERROR("ReduceSum kernel deduplicate axes failed."); \
return result; \
} \
break; \
}
} // namespace
namespace aicpu {
@ -48,6 +57,16 @@ uint32_t ReduceSumCpuKernel::Compute(CpuKernelContext &ctx) {
KERNEL_HANDLE_ERROR(NormalCheck(ctx, kReduceSumInputNum, kReduceSumOutputNum), "[%s] check input and output failed.",
kReduceSum);
KERNEL_HANDLE_ERROR(ReduceSumCheck(ctx), "[%s] check params failed.", kReduceSum);
auto axes_type = ctx.Input(1)->GetDataType();
switch (axes_type) {
REDUCESUM_DEDUP_AXES(DT_INT32, int32_t, ctx)
REDUCESUM_DEDUP_AXES(DT_INT64, int64_t, ctx)
default:
KERNEL_LOG_ERROR("ReduceSum kernel axes data type not support.", DTypeStr(axes_type).c_str());
return KERNEL_STATUS_PARAM_INVALID;
}
auto input_data_type = ctx.Input(0)->GetDataType();
switch (input_data_type) {
REDUCESUM_COMPUTE_CASE(DT_FLOAT16, Eigen::half, ctx)
@ -69,7 +88,7 @@ uint32_t ReduceSumCpuKernel::Compute(CpuKernelContext &ctx) {
}
return KERNEL_STATUS_OK;
}
uint32_t ReduceSumCpuKernel::ReduceSumCheck(CpuKernelContext &ctx) const {
uint32_t ReduceSumCpuKernel::ReduceSumCheck(const CpuKernelContext &ctx) const {
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get input failed.");
KERNEL_CHECK_NULLPTR(ctx.Input(0)->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID, "Get input tensor shape failed.");
KERNEL_CHECK_NULLPTR(ctx.Output(0)->GetData(), KERNEL_STATUS_PARAM_INVALID, "get output failed.");
@ -81,7 +100,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumCheck(CpuKernelContext &ctx) const {
return KERNEL_STATUS_OK;
}
template <typename T>
uint32_t ReduceSumCpuKernel::ReduceSumCompute(CpuKernelContext &ctx) {
uint32_t ReduceSumCpuKernel::ReduceSumCompute(const CpuKernelContext &ctx) {
std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
@ -89,8 +108,8 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute(CpuKernelContext &ctx) {
output_data[0] = input_data[0];
return KERNEL_STATUS_OK;
}
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
if (axes_data == nullptr) {
if (axes_.empty()) {
int64_t data_num = ctx.Input(0)->NumElements();
auto accumulator = static_cast<T>(0);
for (int64_t i = 0; i < data_num; i++) {
@ -99,11 +118,10 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute(CpuKernelContext &ctx) {
output_data[0] = accumulator;
return KERNEL_STATUS_OK;
}
std::vector<int64_t> axes;
KERNEL_HANDLE_ERROR(ReduceSumDedupAxes(ctx, axes), "ReduceSum deduplicate failed.");
int64_t output_num = ctx.Output(0)->NumElements();
uint32_t axes_idx = 0;
KERNEL_HANDLE_ERROR(ReduceSumOneAxes<T>(input_data, input_shape, output_data, output_num, axes, axes_idx),
KERNEL_HANDLE_ERROR(ReduceSumOneAxes<T>(input_data, input_shape, output_data, output_num, axes_, axes_idx),
"Reduce sum compute failed.");
return KERNEL_STATUS_OK;
}
@ -141,7 +159,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumOneAxes(const T *input_data, std::vector<i
return result;
}
template <typename T, typename T2>
uint32_t ReduceSumCpuKernel::ReduceSumCompute2(CpuKernelContext &ctx) {
uint32_t ReduceSumCpuKernel::ReduceSumCompute2(const CpuKernelContext &ctx) {
std::vector<int64_t> input_shape = ctx.Input(0)->GetTensorShape()->GetDimSizes();
auto input_data = reinterpret_cast<T *>(ctx.Input(0)->GetData());
auto output_data = reinterpret_cast<T *>(ctx.Output(0)->GetData());
@ -149,9 +167,9 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute2(CpuKernelContext &ctx) {
output_data[0] = std::complex<T2>(input_data[0].real(), input_data[0].imag());
return KERNEL_STATUS_OK;
}
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
int64_t input_num = ctx.Input(0)->NumElements();
if (axes_data == nullptr) {
if (axes_.empty()) {
auto accumulator_real = static_cast<T2>(0);
auto accumulator_imag = static_cast<T2>(0);
for (int64_t i = 0; i < input_num; i++) {
@ -161,12 +179,11 @@ uint32_t ReduceSumCpuKernel::ReduceSumCompute2(CpuKernelContext &ctx) {
output_data[0] = std::complex<T2>(accumulator_real, accumulator_imag);
return KERNEL_STATUS_OK;
}
std::vector<int64_t> axes;
KERNEL_HANDLE_ERROR(ReduceSumDedupAxes(ctx, axes), "ReduceSum deduplicate failed.");
int64_t output_num = ctx.Output(0)->NumElements();
uint32_t axes_idx = 0;
KERNEL_HANDLE_ERROR(
(ReduceSumOneAxes2<T, T2>(input_data, input_num, input_shape, output_data, output_num, axes, axes_idx)),
(ReduceSumOneAxes2<T, T2>(input_data, input_num, input_shape, output_data, output_num, axes_, axes_idx)),
"Reduce sum compute failed.");
return KERNEL_STATUS_OK;
}
@ -218,9 +235,11 @@ uint32_t ReduceSumCpuKernel::ReduceSumOneAxes2(const T *input_data, int64_t inpu
}
return result;
}
uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx, std::vector<int64_t> &axes) {
template <typename T1>
uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx) {
int32_t rank = ctx.Input(0)->GetTensorShape()->GetDims();
auto axes_data = reinterpret_cast<int32_t *>(ctx.Input(1)->GetData());
auto axes_data = reinterpret_cast<T1 *>(ctx.Input(1)->GetData());
int64_t axes_num = ctx.Input(1)->NumElements();
for (int64_t i = 0; i < axes_num; i++) {
int32_t axis = axes_data[i];
@ -229,13 +248,13 @@ uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx, std::vect
if (axis < 0) {
axis += rank;
}
axes.push_back(axis);
axes_.push_back(axis);
}
int64_t j = 1;
while (j < axes_num) {
std::vector<int64_t>::iterator iter = find(axes.begin(), axes.begin() + j, axes[j]);
if (iter != axes.begin() + j) {
axes.erase(iter);
std::vector<int64_t>::iterator iter = find(axes_.begin(), axes_.begin() + j, axes_[j]);
if (iter != axes_.begin() + j) {
axes_.erase(iter);
axes_num--;
} else {
j++;
@ -243,6 +262,7 @@ uint32_t ReduceSumCpuKernel::ReduceSumDedupAxes(CpuKernelContext &ctx, std::vect
}
return KERNEL_STATUS_OK;
}
uint32_t ReduceSumCpuKernel::ReduceSumParseAxes(std::vector<int64_t> &input_shape, std::vector<int64_t> &axes,
uint32_t &axes_idx, int64_t &inner, int64_t &outer,
int64_t &depth) const {

View File

@ -28,26 +28,29 @@ class ReduceSumCpuKernel : public CpuKernel {
uint32_t Compute(CpuKernelContext &ctx) override;
private:
uint32_t ReduceSumCheck(CpuKernelContext &ctx) const;
uint32_t ReduceSumCheck(const CpuKernelContext &ctx) const;
template <typename T>
uint32_t ReduceSumCompute(CpuKernelContext &ctx);
uint32_t ReduceSumCompute(const CpuKernelContext &ctx);
template <typename T>
uint32_t ReduceSumOneAxes(const T *input_data, std::vector<int64_t> &input_shape, T *output_data, int64_t output_num,
std::vector<int64_t> &axes, uint32_t &axes_idx);
template <typename T, typename T2>
uint32_t ReduceSumCompute2(CpuKernelContext &ctx);
uint32_t ReduceSumCompute2(const CpuKernelContext &ctx);
template <typename T, typename T2>
uint32_t ReduceSumOneAxes2(const T *input_data, int64_t input_num, std::vector<int64_t> input_shape, T *output_data,
int64_t output_num, std::vector<int64_t> &axes, uint32_t &axes_idx);
uint32_t ReduceSumDedupAxes(CpuKernelContext &ctx, std::vector<int64_t> &axes);
template <typename T1>
uint32_t ReduceSumDedupAxes(CpuKernelContext &ctx);
uint32_t ReduceSumParseAxes(std::vector<int64_t> &input_shape, std::vector<int64_t> &axes, uint32_t &axes_idx,
int64_t &inner, int64_t &outer, int64_t &depth) const;
std::vector<int64_t> axes_;
};
} // namespace aicpu
#endif