From 2931920fea3978896a083e0d8687c2066613d627 Mon Sep 17 00:00:00 2001 From: songhonglei413 Date: Tue, 25 Aug 2020 10:16:49 +0800 Subject: [PATCH] op_pooling+relu --- mindspore/lite/nnacl/fp32/pooling.c | 447 ++++++++++++++++++ mindspore/lite/nnacl/fp32/pooling.h | 8 + mindspore/lite/nnacl/pooling_parameter.h | 1 + mindspore/lite/schema/ops.fbs | 1 + mindspore/lite/src/ops/pooling.cc | 5 + mindspore/lite/src/ops/pooling.h | 4 +- mindspore/lite/src/populate_parameter.cc | 8 + .../src/runtime/kernel/arm/fp32/pooling.cc | 22 +- 8 files changed, 493 insertions(+), 3 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/pooling.c b/mindspore/lite/nnacl/fp32/pooling.c index 543ccfb132..e533eec5e0 100644 --- a/mindspore/lite/nnacl/fp32/pooling.c +++ b/mindspore/lite/nnacl/fp32/pooling.c @@ -171,6 +171,7 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1)); tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2)); tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3)); + #endif } } // win_w loop @@ -206,3 +207,449 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo } // out_plane loop } // out_batch loop } + +void AvgPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int c4 = UP_DIV(channel, C4NUM); + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; +#ifdef ENABLE_NEON + float32x4_t zeros = vdupq_n_f32(0); +#endif + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_avg = vdupq_n_f32(0); +#else + float tmp_avg1 = 0; + float tmp_avg2 = 0; + float tmp_avg3 = 0; + float tmp_avg4 = 0; +#endif + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset)); +#else + tmp_avg1 += *(input_ptr + in_offset); + tmp_avg2 += *(input_ptr + in_offset + 1); + tmp_avg3 += *(input_ptr + in_offset + 2); + tmp_avg4 += *(input_ptr + in_offset + 3); +#endif + ++real_count; + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_avg = vmaxq_f32(tmp_avg, zeros); + vst1q_f32(output_ptr + out_channel_offset, tmp_avg / vdupq_n_f32(real_count)); +#else + tmp_avg1 = fmax(tmp_avg1, 0); + tmp_avg2 = fmax(tmp_avg2, 0); + tmp_avg3 = fmax(tmp_avg3, 0); + tmp_avg4 = fmax(tmp_avg4, 0); + + *(output_ptr + out_channel_offset) = tmp_avg1 / (float)real_count; + *(output_ptr + out_channel_offset + 1) = tmp_avg2 / (float)real_count; + *(output_ptr + out_channel_offset + 2) = tmp_avg3 / (float)real_count; + *(output_ptr + out_channel_offset + 3) = tmp_avg4 / (float)real_count; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + tmp_avg = fmax(tmp_avg, 0); + *(output_ptr + out_channel_offset) = tmp_avg / (float)real_count; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} + +void MaxPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c4 = UP_DIV(channel, C4NUM); + // input channel is equal to output channel + +#ifdef ENABLE_NEON + float32x4_t zeros = vdupq_n_f32(0); +#endif + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX); +#else + float tmp_max1 = -FLT_MAX; + float tmp_max2 = -FLT_MAX; + float tmp_max3 = -FLT_MAX; + float tmp_max4 = -FLT_MAX; +#endif + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_f32(tmp_max, vld1q_f32(input_ptr + in_offset)); +#else + tmp_max1 = fmax(tmp_max1, *(input_ptr + in_offset)); + tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1)); + tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2)); + tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3)); + +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmaxq_f32(tmp_max, zeros); + vst1q_f32(output_ptr + out_channel_offset, tmp_max); +#else + // relu: + tmp_max1 = fmax(tmp_max1, 0); + tmp_max2 = fmax(tmp_max2, 0); + tmp_max3 = fmax(tmp_max3, 0); + tmp_max4 = fmax(tmp_max4, 0); + + *(output_ptr + out_channel_offset) = tmp_max1; + *(output_ptr + out_channel_offset + 1) = tmp_max2; + *(output_ptr + out_channel_offset + 2) = tmp_max3; + *(output_ptr + out_channel_offset + 3) = tmp_max4; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_max = -FLT_MAX; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); + tmp_max = fmax(tmp_max, 0); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_max; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} + +void AvgPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int c4 = UP_DIV(channel, C4NUM); + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + // input channel is equal to output channel + +#ifdef ENABLE_NEON + float32x4_t zeros = vdupq_n_f32(0); + float32x4_t bounds = vdupq_n_f32(6); +#endif + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_avg = vdupq_n_f32(0); +#else + float tmp_avg1 = 0; + float tmp_avg2 = 0; + float tmp_avg3 = 0; + float tmp_avg4 = 0; +#endif + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_avg = vaddq_f32(tmp_avg, vld1q_f32(input_ptr + in_offset)); +#else + tmp_avg1 += *(input_ptr + in_offset); + tmp_avg2 += *(input_ptr + in_offset + 1); + tmp_avg3 += *(input_ptr + in_offset + 2); + tmp_avg4 += *(input_ptr + in_offset + 3); +#endif + ++real_count; + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + float32x4_t count = vdupq_n_f32(real_count); + tmp_avg = vdivq_f32(tmp_avg, count); + tmp_avg = vmaxq_f32(tmp_avg, zeros); + tmp_avg = vminq_f32(tmp_avg, bounds); + vst1q_f32(output_ptr + out_channel_offset, tmp_avg); +#else + tmp_avg1 /= (float)real_count; + tmp_avg2 /= (float)real_count; + tmp_avg3 /= (float)real_count; + tmp_avg4 /= (float)real_count; + tmp_avg1 = fmax(tmp_avg1, 0); + tmp_avg2 = fmax(tmp_avg2, 0); + tmp_avg3 = fmax(tmp_avg3, 0); + tmp_avg4 = fmax(tmp_avg4, 0); + tmp_avg1 = fmin(tmp_avg1, 6); + tmp_avg2 = fmin(tmp_avg2, 6); + tmp_avg3 = fmin(tmp_avg3, 6); + tmp_avg4 = fmin(tmp_avg4, 6); + + *(output_ptr + out_channel_offset) = tmp_avg1; + *(output_ptr + out_channel_offset + 1) = tmp_avg2; + *(output_ptr + out_channel_offset + 2) = tmp_avg3; + *(output_ptr + out_channel_offset + 3) = tmp_avg4; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_avg = 0; + int real_count = 0; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_avg += *(input_ptr + in_offset); + ++real_count; + } + } // win_w loop + } // win_h loop + tmp_avg /= (float)real_count; + tmp_avg = fmax(tmp_avg, 0); + tmp_avg = fmin(tmp_avg, 6); + *(output_ptr + out_channel_offset) = tmp_avg; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} + +void MaxPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = pooling_param->thread_num_; + int c4 = UP_DIV(channel, C4NUM); + // input channel is equal to output channel + +#ifdef ENABLE_NEON + float32x4_t zeros = vdupq_n_f32(0); + float32x4_t bounds = vdupq_n_f32(6); +#endif + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c4 - 1; j++) { + int in_channel_offset = in_batch_offset + j * C4NUM; + int out_channel_offset = out_plane_offset + j * C4NUM; +#ifdef ENABLE_NEON + float32x4_t tmp_max = vdupq_n_f32(-FLT_MAX); +#else + float tmp_max1 = -FLT_MAX; + float tmp_max2 = -FLT_MAX; + float tmp_max3 = -FLT_MAX; + float tmp_max4 = -FLT_MAX; +#endif + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_f32(tmp_max, vld1q_f32(input_ptr + in_offset)); +#else + tmp_max1 = fmax(tmp_max1, *(input_ptr + in_offset)); + tmp_max2 = fmax(tmp_max2, *(input_ptr + in_offset + 1)); + tmp_max3 = fmax(tmp_max3, *(input_ptr + in_offset + 2)); + tmp_max4 = fmax(tmp_max4, *(input_ptr + in_offset + 3)); + +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + tmp_max = vmaxq_f32(tmp_max, zeros); + tmp_max = vminq_f32(tmp_max, bounds); + vst1q_f32(output_ptr + out_channel_offset, tmp_max); +#else + tmp_max1 = fmax(tmp_max1, 0); + tmp_max2 = fmax(tmp_max2, 0); + tmp_max3 = fmax(tmp_max3, 0); + tmp_max4 = fmax(tmp_max4, 0); + tmp_max1 = fmin(tmp_max1, 6); + tmp_max2 = fmin(tmp_max2, 6); + tmp_max3 = fmin(tmp_max3, 6); + tmp_max4 = fmin(tmp_max4, 6); + + *(output_ptr + out_channel_offset) = tmp_max1; + *(output_ptr + out_channel_offset + 1) = tmp_max2; + *(output_ptr + out_channel_offset + 2) = tmp_max3; + *(output_ptr + out_channel_offset + 3) = tmp_max4; +#endif + } // ic4-1 loop + int channel_s = (c4 - 1) * C4NUM; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + float tmp_max = -FLT_MAX; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = fmax(tmp_max, *(input_ptr + in_offset)); + tmp_max = fmax(tmp_max, 0); + tmp_max = fmin(tmp_max, 6); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = tmp_max; + } // channel_res loop + } // real_cal_num loop + } // out_plane loop + } // out_batch loop +} diff --git a/mindspore/lite/nnacl/fp32/pooling.h b/mindspore/lite/nnacl/fp32/pooling.h index c8c90fca64..ae62f97390 100644 --- a/mindspore/lite/nnacl/fp32/pooling.h +++ b/mindspore/lite/nnacl/fp32/pooling.h @@ -30,6 +30,14 @@ extern "C" { void AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +void AvgPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPoolingRelu(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +void AvgPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); + +void MaxPoolingRelu6(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/pooling_parameter.h b/mindspore/lite/nnacl/pooling_parameter.h index 42205af006..644d3d4c47 100644 --- a/mindspore/lite/nnacl/pooling_parameter.h +++ b/mindspore/lite/nnacl/pooling_parameter.h @@ -44,6 +44,7 @@ typedef struct PoolingParameter { int stride_w_; int stride_h_; int thread_num_; + ActType act_type_; } PoolingParameter; #endif // MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_ diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 78c6140e43..6977278622 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -290,6 +290,7 @@ table Pooling { padLeft: int; padRight: int; roundMode: RoundMode; + activationType: ActivationType = 0; } table DepthwiseConv2D { diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index ffdddfece2..80a3d728af 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -36,6 +36,7 @@ int Pooling::GetPadDown() const { return this->primitive_->value.AsPooling()->pa int Pooling::GetPadLeft() const { return this->primitive_->value.AsPooling()->padLeft; } int Pooling::GetPadRight() const { return this->primitive_->value.AsPooling()->padRight; } int Pooling::GetRoundMode() const { return this->primitive_->value.AsPooling()->roundMode; } +int Pooling::GetActivationType() const { return this->primitive_->value.AsPooling()->activationType; } void Pooling::SetFormat(int format) { this->primitive_->value.AsPooling()->format = (schema::Format)format; } void Pooling::SetPoolingMode(int pooling_mode) { @@ -54,6 +55,9 @@ void Pooling::SetPadRight(int pad_right) { this->primitive_->value.AsPooling()-> void Pooling::SetRoundMode(int round_mode) { this->primitive_->value.AsPooling()->roundMode = (schema::RoundMode)round_mode; } +void Pooling::SetActivationType(int activation_type) { + this->primitive_->value.AsPooling()->activationType = (schema::ActivationType)activation_type; +} int Pooling::UnPackAttr(const Primitive &prim, const std::vector &inputs) { if (this->primitive_ == nullptr) { @@ -130,6 +134,7 @@ int Pooling::GetPadDown() const { return this->primitive_->value_as_Pooling()->p int Pooling::GetPadLeft() const { return this->primitive_->value_as_Pooling()->padLeft(); } int Pooling::GetPadRight() const { return this->primitive_->value_as_Pooling()->padRight(); } int Pooling::GetRoundMode() const { return this->primitive_->value_as_Pooling()->roundMode(); } +int Pooling::GetActivationType() const { return this->primitive_->value_as_Pooling()->activationType(); } #endif diff --git a/mindspore/lite/src/ops/pooling.h b/mindspore/lite/src/ops/pooling.h index 6b7d3c10fa..b6aed86409 100644 --- a/mindspore/lite/src/ops/pooling.h +++ b/mindspore/lite/src/ops/pooling.h @@ -44,6 +44,7 @@ class Pooling : public PrimitiveC { void SetPadLeft(int pad_left); void SetPadRight(int pad_right); void SetRoundMode(int round_mode); + void SetActivationType(int activation_type); #else explicit Pooling(schema::Primitive *primitive) : PrimitiveC(primitive) {} #endif @@ -61,6 +62,7 @@ class Pooling : public PrimitiveC { int GetPadLeft() const; int GetPadRight() const; int GetRoundMode() const; + int GetActivationType() const; int PadUp() const; int PadDown() const; @@ -74,7 +76,7 @@ class Pooling : public PrimitiveC { int pad_d_ = 0; int pad_l_ = 0; int pad_r_ = 0; -}; +}; // namespace lite } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 9106496e82..c035c087da 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -314,6 +314,14 @@ OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primiti pooling_param->round_ceil_ = false; break; } + + if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU) { + pooling_param->act_type_ = ActType_Relu; + } else if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU6) { + pooling_param->act_type_ = ActType_Relu6; + } else { + pooling_param->act_type_ = ActType_No; + } return reinterpret_cast(pooling_param); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc index 8ea453593f..61009096d5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc @@ -53,9 +53,27 @@ int PoolingCPUKernel::RunImpl(int task_id) { auto input_ptr = reinterpret_cast(in_tensors_.at(kInputIndex)->Data()); auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->Data()); if (pooling_param_->max_pooling_) { - MaxPooling(input_ptr, output_ptr, pooling_param_, task_id); + switch (pooling_param_->act_type_) { + case ActType_Relu: + MaxPoolingRelu(input_ptr, output_ptr, pooling_param_, task_id); + break; + case ActType_Relu6: + MaxPoolingRelu6(input_ptr, output_ptr, pooling_param_, task_id); + break; + default: + MaxPooling(input_ptr, output_ptr, pooling_param_, task_id); + } } else { - AvgPooling(input_ptr, output_ptr, pooling_param_, task_id); + switch (pooling_param_->act_type_) { + case ActType_Relu: + AvgPoolingRelu(input_ptr, output_ptr, pooling_param_, task_id); + break; + case ActType_Relu6: + AvgPoolingRelu6(input_ptr, output_ptr, pooling_param_, task_id); + break; + default: + AvgPooling(input_ptr, output_ptr, pooling_param_, task_id); + } } return RET_OK; }