From ad9f90dc1e23ac418e11939a976dd8515cf7a0f7 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Wed, 14 Jul 2021 11:38:59 +0800 Subject: [PATCH] add_reducesum_fp16 --- .../cpu/nnacl/fp16/reduce_fp16.c | 40 +++++++++++++++++++ .../cpu/nnacl/fp16/reduce_fp16.h | 2 + .../runtime/kernel/arm/fp16/reduce_fp16.cc | 9 +++-- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.c index 5ece339cd41..e77d040399f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.c @@ -61,3 +61,43 @@ int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t } return NNACL_OK; } + +int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num) { + int stride = UP_DIV(outer_size, thread_num); + int start = stride * tid; + int end = MSMIN(outer_size, start + stride); + int num = end - start; +#ifdef ENABLE_NEON + int block_c8 = inner_size - inner_size % C8NUM; +#endif + + int src_stride = axis_size * inner_size; + src_data += start * src_stride; + dst_data += start * inner_size; + + for (int i = 0; i < num; i++, src_data += src_stride, dst_data += inner_size) { + int j = 0; +#ifdef ENABLE_NEON + for (; j < block_c8; j += C8NUM) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float16x8_t tmp = {0, 0, 0, 0, 0, 0, 0, 0}; + for (int k = 0; k < axis_size; k++) { + tmp = vaddq_f16(tmp, vld1q_f16(inner_src + k * inner_size)); + } + vst1q_f16(inner_dst, tmp); + } +#endif + for (; j < inner_size; j++) { + const float16_t *inner_src = src_data + j; + float16_t *inner_dst = dst_data + j; + float tmp = 0.0f; + for (int k = 0; k < axis_size; k++) { + tmp += inner_src[k * inner_size]; + } + *inner_dst = tmp; + } + } + return NNACL_OK; +} diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h index c080882c971..f11b6751f7e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h @@ -26,6 +26,8 @@ int ReduceMeanFp16(const int outer_size, const int inner_size, const int axis_si float16_t *dst_data, const int tid, const int thread_num); int ReduceMaxFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, int tid, int thread_num); +int ReduceSumFp16(int outer_size, int inner_size, int axis_size, const float16_t *src_data, float16_t *dst_data, + int tid, int thread_num); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc index d5620d72997..5af2c51d44e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.cc @@ -49,6 +49,9 @@ int ReduceFp16CPUKernel::Init() { case static_cast(ReduceMode_ReduceMax): reducer_ = ReduceMaxFp16; break; + case static_cast(ReduceMode_ReduceSum): + reducer_ = ReduceSumFp16; + break; default: MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_; return RET_ERROR; @@ -142,11 +145,9 @@ int ReduceFp16CPUKernel::MallocTmpBuffer() { kernel::InnerKernel *CpuReduceFp16KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const lite::Context *ctx, const kernel::KernelKey &desc) { - MS_ASSERT(opParameter != nullptr); - MS_ASSERT(desc.type == schema::PrimitiveType_ReduceFusion); - auto reduce_param = reinterpret_cast(opParameter); - if (reduce_param->mode_ != ReduceMode_ReduceMean && reduce_param->mode_ != ReduceMode_ReduceMax) { + if (reduce_param->mode_ != ReduceMode_ReduceMean && reduce_param->mode_ != ReduceMode_ReduceMax && + reduce_param->mode_ != ReduceMode_ReduceSum) { MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << reduce_param->mode_; return nullptr; }