diff --git a/mindspore/lite/nnacl/fp16/pooling_fp16.c b/mindspore/lite/nnacl/fp16/pooling_fp16.c index 5dad6783f00..288c17d2a03 100644 --- a/mindspore/lite/nnacl/fp16/pooling_fp16.c +++ b/mindspore/lite/nnacl/fp16/pooling_fp16.c @@ -23,6 +23,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa int pad_h = pooling_param->pad_u_; int win_w = pooling_param->window_w_; int win_h = pooling_param->window_h_; + int window = win_w * win_h; int channel = pooling_param->input_channel_; int c8 = channel / C8NUM; int c8_res = channel % C8NUM; @@ -54,7 +55,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa int real_win_h_start = MSMAX(0, -in_h_index); int real_win_h_end = MSMIN(win_h, in_h - in_h_index); - int resl_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_start = MSMAX(0, -in_w_index); int real_win_w_end = MSMIN(win_w, in_w - in_w_index); for (int j = 0; j < c8; j++) { @@ -67,7 +68,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa #endif int real_count = 0; for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; #ifdef ENABLE_NEON tmp_avg = vaddq_f16(tmp_avg, vld1q_f16(input_ptr + in_offset)); @@ -79,6 +80,9 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa ++real_count; } // win_w loop } // win_h loop + if (pooling_param->avg_mode_ == 1) { + real_count = window; + } #ifdef ENABLE_NEON vst1q_f16(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f16(real_count)); #else @@ -99,7 +103,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa #endif int real_count = 0; for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; #ifdef ENABLE_NEON tmp_avg = vadd_f16(tmp_avg, vld1_f16(input_ptr + in_offset)); @@ -111,6 +115,9 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa ++real_count; } // win_w loop } // win_h loop + if (pooling_param->avg_mode_ == 1) { + real_count = window; + } #ifdef ENABLE_NEON vst1_f16(output_ptr + out_channel_offset, tmp_avg / vdup_n_f16(real_count)); #else @@ -127,7 +134,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa float16_t tmp_avg = 0; int real_count = 0; for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; tmp_avg += *(input_ptr + in_offset); ++real_count; @@ -139,6 +146,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa } // out_plane loop } // out_batch loop } + void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id) { int stride_w = pooling_param->stride_w_; int stride_h = pooling_param->stride_h_; @@ -176,7 +184,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa int real_win_h_start = MSMAX(0, -in_h_index); int real_win_h_end = MSMIN(win_h, in_h - in_h_index); - int resl_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_start = MSMAX(0, -in_w_index); int real_win_w_end = MSMIN(win_w, in_w - in_w_index); for (int j = 0; j < c8; j++) { @@ -188,7 +196,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa float16_t tmp_max[8]{-FLT_MAX}; #endif for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; #ifdef ENABLE_NEON tmp_max = vmaxq_f16(tmp_max, vld1q_f16(input_ptr + in_offset)); @@ -218,7 +226,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa float16_t tmp_max[4]{-FLT_MAX}; #endif for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; #ifdef ENABLE_NEON tmp_max = vmax_f16(tmp_max, vld1_f16(input_ptr + in_offset)); @@ -244,7 +252,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa int out_channel_offset = out_plane_offset + k; float16_t tmp_max = -FLT_MAX; for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); } // win_w loop diff --git a/mindspore/lite/nnacl/fp32/pooling.c b/mindspore/lite/nnacl/fp32/pooling.c index 5ce7f18797b..a2b8eda43b4 100644 --- a/mindspore/lite/nnacl/fp32/pooling.c +++ b/mindspore/lite/nnacl/fp32/pooling.c @@ -26,7 +26,7 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo int win_w = pooling_param->window_w_; int win_h = pooling_param->window_h_; int channel = pooling_param->input_channel_; - int c4 = UP_DIV(channel, C4NUM); /* oc && ic */ + int c4 = channel / C4NUM; /* oc && ic */ int in_w = pooling_param->input_w_; int in_h = pooling_param->input_h_; int output_w = pooling_param->output_w_; @@ -35,6 +35,7 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo int out_plane = output_w * output_h; int out_tile_count = UP_DIV(out_plane, TILE_NUM); int thread_num = pooling_param->thread_num_; + int window = win_w * win_h; #ifdef ENABLE_NEON float32x4_t min_value = vdupq_n_f32(minf); @@ -59,10 +60,10 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo int real_win_h_start = MSMAX(0, -in_h_index); int real_win_h_end = MSMIN(win_h, in_h - in_h_index); - int resl_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_start = MSMAX(0, -in_w_index); int real_win_w_end = MSMIN(win_w, in_w - in_w_index); - for (int ci = 0; ci < c4 - 1; ci++) { + for (int ci = 0; ci < c4; ci++) { const float *src_c_ptr = src_plane_ptr + ci * C4NUM; float *dst_c_ptr = dst_plane_ptr + ci * C4NUM; #ifdef ENABLE_NEON @@ -75,7 +76,7 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo #endif int real_count = 0; for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; #ifdef ENABLE_NEON tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(src_win_ptr)); @@ -88,6 +89,9 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo ++real_count; } // win_w loop } // win_h loop + if (pooling_param->avg_mode_ == 1) { + real_count = window; + } #ifdef ENABLE_NEON tmp_avg = tmp_avg / vdupq_n_f32(real_count); tmp_avg = vmaxq_f32(tmp_avg, min_value); @@ -112,19 +116,22 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo dst_c_ptr[3] = tmp_avg4; #endif } // ic4-1 loop - int channel_s = (c4 - 1) * C4NUM; + int channel_s = c4 * C4NUM; for (int ci = channel_s; ci < channel; ci++) { const float *src_c_ptr = src_plane_ptr + ci; float *dst_c_ptr = dst_plane_ptr + ci; float tmp_avg = 0; int real_count = 0; for (int h = real_win_h_start; h < real_win_h_end; h++) { - for (int w = resl_win_w_start; w < real_win_w_end; w++) { + for (int w = real_win_w_start; w < real_win_w_end; w++) { const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel; tmp_avg += src_win_ptr[0]; ++real_count; } // win_w loop } // win_h loop + if (pooling_param->avg_mode_ == 1) { + real_count = window; + } tmp_avg = tmp_avg / (float)real_count; tmp_avg = fmax(tmp_avg, minf); tmp_avg = fmin(tmp_avg, maxf); @@ -152,7 +159,7 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo int out_plane = output_w * output_h; int out_tile_count = UP_DIV(out_plane, TILE_NUM); int thread_num = pooling_param->thread_num_; - int c4 = UP_DIV(channel, C4NUM); /* oc && ic */ + int c4 = channel / C4NUM; /* oc && ic */ #ifdef ENABLE_NEON float32x4_t min_value = vdupq_n_f32(minf); @@ -177,10 +184,10 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo int real_win_h_start = MSMAX(0, -in_h_index); int real_win_h_end = MSMIN(win_h, in_h - in_h_index); - int resl_win_w_start = MSMAX(0, -in_w_index); + int real_win_w_start = MSMAX(0, -in_w_index); int real_win_w_end = MSMIN(win_w, in_w - in_w_index); - for (int ci = 0; ci < c4 - 1; ci++) { + for (int ci = 0; ci < c4; ci++) { const float *src_c_ptr = src_plane_ptr + ci * C4NUM; float *dst_c_ptr = dst_plane_ptr + ci * C4NUM; #ifdef ENABLE_NEON @@ -193,7 +200,7 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo #endif for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { - for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; #ifdef ENABLE_NEON tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr)); @@ -224,14 +231,14 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo dst_c_ptr[3] = tmp_max4; #endif } // ic4-1 loop - int channel_s = (c4 - 1) * C4NUM; + int channel_s = c4 * C4NUM; for (int ci = channel_s; ci < channel; ci++) { float *dst_c_ptr = dst_plane_ptr + ci; const float *src_c_ptr = src_plane_ptr + ci; float tmp_max = -FLT_MAX; for (int kh = real_win_h_start; kh < real_win_h_end; kh++) { - for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) { + for (int kw = real_win_w_start; kw < real_win_w_end; kw++) { const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel; tmp_max = fmax(tmp_max, src_win_ptr[0]); } // win_w loop diff --git a/mindspore/lite/nnacl/pooling_parameter.h b/mindspore/lite/nnacl/pooling_parameter.h index e3d7a239db3..b0839c92ec8 100644 --- a/mindspore/lite/nnacl/pooling_parameter.h +++ b/mindspore/lite/nnacl/pooling_parameter.h @@ -46,6 +46,7 @@ typedef struct PoolingParameter { int stride_w_; int stride_h_; int thread_num_; + int avg_mode_; bool global_; bool quantize_; } PoolingParameter; diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 1e4c1de0099..ccaaf1222eb 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -302,6 +302,7 @@ table Pooling { padRight: int; roundMode: RoundMode; activationType: ActivationType = 0; + avgMode: int = 0; } table DepthwiseConv2D { diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index 2e39deb4339..5cdefe8dc6b 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -37,6 +37,7 @@ int Pooling::GetPadLeft() const { return this->primitive_->value.AsPooling()->pa int Pooling::GetPadRight() const { return this->primitive_->value.AsPooling()->padRight; } int Pooling::GetRoundMode() const { return this->primitive_->value.AsPooling()->roundMode; } int Pooling::GetActivationType() const { return this->primitive_->value.AsPooling()->activationType; } +int Pooling::GetAvgMode() const { return this->primitive_->value.AsPooling()->avgMode; } void Pooling::SetFormat(int format) { this->primitive_->value.AsPooling()->format = (schema::Format)format; } void Pooling::SetPoolingMode(int pooling_mode) { @@ -58,6 +59,7 @@ void Pooling::SetRoundMode(int round_mode) { void Pooling::SetActivationType(int activation_type) { this->primitive_->value.AsPooling()->activationType = (schema::ActivationType)activation_type; } +void Pooling::SetAvgMode(int avg_mode) { this->primitive_->value.AsPooling()->avgMode = avg_mode; } int Pooling::UnPackAttr(const Primitive &prim, const std::vector &inputs) { if (this->primitive_ == nullptr) { @@ -93,6 +95,8 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector &in attr->format = schema::Format::Format_NUM_OF_FORMAT; } + attr->avgMode = 1; + auto pad_mode = GetValue(prim.GetAttr("padding")); if (pad_mode == "VALID") { attr->padMode = schema::PadMode_VALID; @@ -135,6 +139,7 @@ int Pooling::GetPadLeft() const { return this->primitive_->value_as_Pooling()->p int Pooling::GetPadRight() const { return this->primitive_->value_as_Pooling()->padRight(); } int Pooling::GetRoundMode() const { return this->primitive_->value_as_Pooling()->roundMode(); } int Pooling::GetActivationType() const { return this->primitive_->value_as_Pooling()->activationType(); } +int Pooling::GetAvgMode() const { return this->primitive_->value_as_Pooling()->avgMode(); } int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); @@ -144,10 +149,10 @@ int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers MS_LOG(ERROR) << "value_as_Pooling return nullptr"; return RET_ERROR; } - auto val_offset = - schema::CreatePooling(*fbb, attr->format(), attr->poolingMode(), attr->global(), attr->windowW(), attr->windowH(), - attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(), - attr->padLeft(), attr->padRight(), attr->roundMode(), attr->activationType()); + auto val_offset = schema::CreatePooling(*fbb, attr->format(), attr->poolingMode(), attr->global(), attr->windowW(), + attr->windowH(), attr->strideW(), attr->strideH(), attr->padMode(), + attr->padUp(), attr->padDown(), attr->padLeft(), attr->padRight(), + attr->roundMode(), attr->activationType(), attr->avgMode()); auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Pooling, val_offset.o); fbb->Finish(prim_offset); return RET_OK; diff --git a/mindspore/lite/src/ops/pooling.h b/mindspore/lite/src/ops/pooling.h index 8dd73c6c9d8..c9f6840dfeb 100644 --- a/mindspore/lite/src/ops/pooling.h +++ b/mindspore/lite/src/ops/pooling.h @@ -45,6 +45,7 @@ class Pooling : public PrimitiveC { void SetPadRight(int pad_right); void SetRoundMode(int round_mode); void SetActivationType(int activation_type); + void SetAvgMode(int avg_mode); int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Pooling() = default; @@ -66,6 +67,7 @@ class Pooling : public PrimitiveC { int GetPadRight() const; int GetRoundMode() const; int GetActivationType() const; + int GetAvgMode() const; int PadUp() const; int PadDown() const; diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 53af8819fc7..b7ff742dc8f 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -293,6 +293,7 @@ OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primiti pooling_param->pad_r_ = pooling_lite_primitive->PadRight(); pooling_param->stride_w_ = pooling_primitive->GetStrideW(); pooling_param->stride_h_ = pooling_primitive->GetStrideH(); + pooling_param->avg_mode_ = pooling_primitive->GetAvgMode(); auto is_global = pooling_primitive->GetGlobal(); pooling_param->global_ = is_global;