forked from OSSInnovation/mindspore
!6480 [MS][LITE]solve avgMode problem
Merge pull request !6480 from fuzhiye/wino
This commit is contained in:
commit
7d2da469a9
|
@ -23,6 +23,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
int pad_h = pooling_param->pad_u_;
|
||||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int window = win_w * win_h;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int c8 = channel / C8NUM;
|
||||
int c8_res = channel % C8NUM;
|
||||
|
@ -54,7 +55,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
|
||||
int real_win_h_start = MSMAX(0, -in_h_index);
|
||||
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
|
||||
int resl_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
|
||||
for (int j = 0; j < c8; j++) {
|
||||
|
@ -67,7 +68,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
#endif
|
||||
int real_count = 0;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_avg = vaddq_f16(tmp_avg, vld1q_f16(input_ptr + in_offset));
|
||||
|
@ -79,6 +80,9 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
++real_count;
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
if (pooling_param->avg_mode_ == 1) {
|
||||
real_count = window;
|
||||
}
|
||||
#ifdef ENABLE_NEON
|
||||
vst1q_f16(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f16(real_count));
|
||||
#else
|
||||
|
@ -99,7 +103,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
#endif
|
||||
int real_count = 0;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_avg = vadd_f16(tmp_avg, vld1_f16(input_ptr + in_offset));
|
||||
|
@ -111,6 +115,9 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
++real_count;
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
if (pooling_param->avg_mode_ == 1) {
|
||||
real_count = window;
|
||||
}
|
||||
#ifdef ENABLE_NEON
|
||||
vst1_f16(output_ptr + out_channel_offset, tmp_avg / vdup_n_f16(real_count));
|
||||
#else
|
||||
|
@ -127,7 +134,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
float16_t tmp_avg = 0;
|
||||
int real_count = 0;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_avg += *(input_ptr + in_offset);
|
||||
++real_count;
|
||||
|
@ -139,6 +146,7 @@ void AvgPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
} // out_plane loop
|
||||
} // out_batch loop
|
||||
}
|
||||
|
||||
void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
|
||||
int stride_w = pooling_param->stride_w_;
|
||||
int stride_h = pooling_param->stride_h_;
|
||||
|
@ -176,7 +184,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
|
||||
int real_win_h_start = MSMAX(0, -in_h_index);
|
||||
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
|
||||
int resl_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
|
||||
for (int j = 0; j < c8; j++) {
|
||||
|
@ -188,7 +196,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
float16_t tmp_max[8]{-FLT_MAX};
|
||||
#endif
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f16(tmp_max, vld1q_f16(input_ptr + in_offset));
|
||||
|
@ -218,7 +226,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
float16_t tmp_max[4]{-FLT_MAX};
|
||||
#endif
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmax_f16(tmp_max, vld1_f16(input_ptr + in_offset));
|
||||
|
@ -244,7 +252,7 @@ void MaxPoolingFp16(const float16_t *input_ptr, float16_t *output_ptr, PoolingPa
|
|||
int out_channel_offset = out_plane_offset + k;
|
||||
float16_t tmp_max = -FLT_MAX;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_max = fmax(tmp_max, *(input_ptr + in_offset));
|
||||
} // win_w loop
|
||||
|
|
|
@ -26,7 +26,7 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
int win_w = pooling_param->window_w_;
|
||||
int win_h = pooling_param->window_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int c4 = UP_DIV(channel, C4NUM); /* oc && ic */
|
||||
int c4 = channel / C4NUM; /* oc && ic */
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
int output_w = pooling_param->output_w_;
|
||||
|
@ -35,6 +35,7 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
int window = win_w * win_h;
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t min_value = vdupq_n_f32(minf);
|
||||
|
@ -59,10 +60,10 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
|
||||
int real_win_h_start = MSMAX(0, -in_h_index);
|
||||
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
|
||||
int resl_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
|
||||
for (int ci = 0; ci < c4 - 1; ci++) {
|
||||
for (int ci = 0; ci < c4; ci++) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
|
@ -75,7 +76,7 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
#endif
|
||||
int real_count = 0;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(src_win_ptr));
|
||||
|
@ -88,6 +89,9 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
++real_count;
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
if (pooling_param->avg_mode_ == 1) {
|
||||
real_count = window;
|
||||
}
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_avg = tmp_avg / vdupq_n_f32(real_count);
|
||||
tmp_avg = vmaxq_f32(tmp_avg, min_value);
|
||||
|
@ -112,19 +116,22 @@ void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
dst_c_ptr[3] = tmp_avg4;
|
||||
#endif
|
||||
} // ic4-1 loop
|
||||
int channel_s = (c4 - 1) * C4NUM;
|
||||
int channel_s = c4 * C4NUM;
|
||||
for (int ci = channel_s; ci < channel; ci++) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
float tmp_avg = 0;
|
||||
int real_count = 0;
|
||||
for (int h = real_win_h_start; h < real_win_h_end; h++) {
|
||||
for (int w = resl_win_w_start; w < real_win_w_end; w++) {
|
||||
for (int w = real_win_w_start; w < real_win_w_end; w++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
||||
tmp_avg += src_win_ptr[0];
|
||||
++real_count;
|
||||
} // win_w loop
|
||||
} // win_h loop
|
||||
if (pooling_param->avg_mode_ == 1) {
|
||||
real_count = window;
|
||||
}
|
||||
tmp_avg = tmp_avg / (float)real_count;
|
||||
tmp_avg = fmax(tmp_avg, minf);
|
||||
tmp_avg = fmin(tmp_avg, maxf);
|
||||
|
@ -152,7 +159,7 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
int out_plane = output_w * output_h;
|
||||
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
||||
int thread_num = pooling_param->thread_num_;
|
||||
int c4 = UP_DIV(channel, C4NUM); /* oc && ic */
|
||||
int c4 = channel / C4NUM; /* oc && ic */
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
float32x4_t min_value = vdupq_n_f32(minf);
|
||||
|
@ -177,10 +184,10 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
|
||||
int real_win_h_start = MSMAX(0, -in_h_index);
|
||||
int real_win_h_end = MSMIN(win_h, in_h - in_h_index);
|
||||
int resl_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_start = MSMAX(0, -in_w_index);
|
||||
int real_win_w_end = MSMIN(win_w, in_w - in_w_index);
|
||||
|
||||
for (int ci = 0; ci < c4 - 1; ci++) {
|
||||
for (int ci = 0; ci < c4; ci++) {
|
||||
const float *src_c_ptr = src_plane_ptr + ci * C4NUM;
|
||||
float *dst_c_ptr = dst_plane_ptr + ci * C4NUM;
|
||||
#ifdef ENABLE_NEON
|
||||
|
@ -193,7 +200,7 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
#endif
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
#ifdef ENABLE_NEON
|
||||
tmp_max = vmaxq_f32(tmp_max, vld1q_f32(src_win_ptr));
|
||||
|
@ -224,14 +231,14 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
|
|||
dst_c_ptr[3] = tmp_max4;
|
||||
#endif
|
||||
} // ic4-1 loop
|
||||
int channel_s = (c4 - 1) * C4NUM;
|
||||
int channel_s = c4 * C4NUM;
|
||||
for (int ci = channel_s; ci < channel; ci++) {
|
||||
float *dst_c_ptr = dst_plane_ptr + ci;
|
||||
const float *src_c_ptr = src_plane_ptr + ci;
|
||||
float tmp_max = -FLT_MAX;
|
||||
|
||||
for (int kh = real_win_h_start; kh < real_win_h_end; kh++) {
|
||||
for (int kw = resl_win_w_start; kw < real_win_w_end; kw++) {
|
||||
for (int kw = real_win_w_start; kw < real_win_w_end; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + ((in_h_index + kh) * in_w + in_w_index + kw) * channel;
|
||||
tmp_max = fmax(tmp_max, src_win_ptr[0]);
|
||||
} // win_w loop
|
||||
|
|
|
@ -46,6 +46,7 @@ typedef struct PoolingParameter {
|
|||
int stride_w_;
|
||||
int stride_h_;
|
||||
int thread_num_;
|
||||
int avg_mode_;
|
||||
bool global_;
|
||||
bool quantize_;
|
||||
} PoolingParameter;
|
||||
|
|
|
@ -302,6 +302,7 @@ table Pooling {
|
|||
padRight: int;
|
||||
roundMode: RoundMode;
|
||||
activationType: ActivationType = 0;
|
||||
avgMode: int = 0;
|
||||
}
|
||||
|
||||
table DepthwiseConv2D {
|
||||
|
|
|
@ -37,6 +37,7 @@ int Pooling::GetPadLeft() const { return this->primitive_->value.AsPooling()->pa
|
|||
int Pooling::GetPadRight() const { return this->primitive_->value.AsPooling()->padRight; }
|
||||
int Pooling::GetRoundMode() const { return this->primitive_->value.AsPooling()->roundMode; }
|
||||
int Pooling::GetActivationType() const { return this->primitive_->value.AsPooling()->activationType; }
|
||||
int Pooling::GetAvgMode() const { return this->primitive_->value.AsPooling()->avgMode; }
|
||||
|
||||
void Pooling::SetFormat(int format) { this->primitive_->value.AsPooling()->format = (schema::Format)format; }
|
||||
void Pooling::SetPoolingMode(int pooling_mode) {
|
||||
|
@ -58,6 +59,7 @@ void Pooling::SetRoundMode(int round_mode) {
|
|||
void Pooling::SetActivationType(int activation_type) {
|
||||
this->primitive_->value.AsPooling()->activationType = (schema::ActivationType)activation_type;
|
||||
}
|
||||
void Pooling::SetAvgMode(int avg_mode) { this->primitive_->value.AsPooling()->avgMode = avg_mode; }
|
||||
|
||||
int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
|
@ -93,6 +95,8 @@ int Pooling::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &in
|
|||
attr->format = schema::Format::Format_NUM_OF_FORMAT;
|
||||
}
|
||||
|
||||
attr->avgMode = 1;
|
||||
|
||||
auto pad_mode = GetValue<std::string>(prim.GetAttr("padding"));
|
||||
if (pad_mode == "VALID") {
|
||||
attr->padMode = schema::PadMode_VALID;
|
||||
|
@ -135,6 +139,7 @@ int Pooling::GetPadLeft() const { return this->primitive_->value_as_Pooling()->p
|
|||
int Pooling::GetPadRight() const { return this->primitive_->value_as_Pooling()->padRight(); }
|
||||
int Pooling::GetRoundMode() const { return this->primitive_->value_as_Pooling()->roundMode(); }
|
||||
int Pooling::GetActivationType() const { return this->primitive_->value_as_Pooling()->activationType(); }
|
||||
int Pooling::GetAvgMode() const { return this->primitive_->value_as_Pooling()->avgMode(); }
|
||||
|
||||
int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
|
@ -144,10 +149,10 @@ int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
|
|||
MS_LOG(ERROR) << "value_as_Pooling return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset =
|
||||
schema::CreatePooling(*fbb, attr->format(), attr->poolingMode(), attr->global(), attr->windowW(), attr->windowH(),
|
||||
attr->strideW(), attr->strideH(), attr->padMode(), attr->padUp(), attr->padDown(),
|
||||
attr->padLeft(), attr->padRight(), attr->roundMode(), attr->activationType());
|
||||
auto val_offset = schema::CreatePooling(*fbb, attr->format(), attr->poolingMode(), attr->global(), attr->windowW(),
|
||||
attr->windowH(), attr->strideW(), attr->strideH(), attr->padMode(),
|
||||
attr->padUp(), attr->padDown(), attr->padLeft(), attr->padRight(),
|
||||
attr->roundMode(), attr->activationType(), attr->avgMode());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Pooling, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
|
|
|
@ -45,6 +45,7 @@ class Pooling : public PrimitiveC {
|
|||
void SetPadRight(int pad_right);
|
||||
void SetRoundMode(int round_mode);
|
||||
void SetActivationType(int activation_type);
|
||||
void SetAvgMode(int avg_mode);
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
Pooling() = default;
|
||||
|
@ -66,6 +67,7 @@ class Pooling : public PrimitiveC {
|
|||
int GetPadRight() const;
|
||||
int GetRoundMode() const;
|
||||
int GetActivationType() const;
|
||||
int GetAvgMode() const;
|
||||
|
||||
int PadUp() const;
|
||||
int PadDown() const;
|
||||
|
|
|
@ -293,6 +293,7 @@ OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primiti
|
|||
pooling_param->pad_r_ = pooling_lite_primitive->PadRight();
|
||||
pooling_param->stride_w_ = pooling_primitive->GetStrideW();
|
||||
pooling_param->stride_h_ = pooling_primitive->GetStrideH();
|
||||
pooling_param->avg_mode_ = pooling_primitive->GetAvgMode();
|
||||
|
||||
auto is_global = pooling_primitive->GetGlobal();
|
||||
pooling_param->global_ = is_global;
|
||||
|
|
Loading…
Reference in New Issue