forked from mindspore-Ecosystem/mindspore
AvgPooling optimization: NC4HW4 -> NHWC
This commit is contained in:
parent
7295eca19d
commit
0baf9d53b6
|
@ -140,6 +140,243 @@ int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int AvgPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
|
||||
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
|
||||
int output_w = pooling_param->output_w_, output_h = pooling_param->output_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
|
||||
int out_plane = output_w * output_h;
|
||||
int in_plane = in_w * in_h;
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
const int c_tile = C8NUM;
|
||||
const int once_calc_num = 2;
|
||||
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
const int c_tile = C4NUM;
|
||||
const int once_calc_num = 1;
|
||||
#else
|
||||
const int c_tile = 1;
|
||||
const int once_calc_num = 1;
|
||||
#endif
|
||||
|
||||
const int c_xtile = once_calc_num * c_tile;
|
||||
|
||||
int cur_c = (channel / c_xtile) * c_xtile;
|
||||
int last_c_size = channel - cur_c;
|
||||
|
||||
int less_out_plane = out_plane * last_c_size;
|
||||
int calc_tile = UP_DIV(less_out_plane, pooling_param->thread_num_);
|
||||
|
||||
int index_begin = task_id * calc_tile;
|
||||
int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane;
|
||||
|
||||
int c_start = (index_begin / out_plane) + cur_c;
|
||||
int index_less = index_begin % out_plane;
|
||||
int h_start = index_less / output_h;
|
||||
int w_start = index_less % output_h;
|
||||
|
||||
int c_end = (index_end / out_plane) + cur_c;
|
||||
index_less = index_end % out_plane;
|
||||
int h_end = index_less / output_h;
|
||||
int w_end = index_less % output_h;
|
||||
|
||||
int c = c_start;
|
||||
int h = h_start;
|
||||
int w = w_start;
|
||||
|
||||
int in_w_cx_line = in_w * last_c_size;
|
||||
const float *src_c_ptr = src_b_ptr + c * in_plane;
|
||||
for (; c < channel; c += c_xtile) {
|
||||
for (; h < output_h; h++) {
|
||||
int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0);
|
||||
int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h);
|
||||
|
||||
for (; w < output_w; w++) {
|
||||
MS_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK);
|
||||
float tmp_avg = 0.0;
|
||||
|
||||
int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0);
|
||||
int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w);
|
||||
|
||||
int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start);
|
||||
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
|
||||
|
||||
for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) {
|
||||
const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line;
|
||||
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
|
||||
const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c);
|
||||
tmp_avg += cur_input_index[0];
|
||||
}
|
||||
}
|
||||
|
||||
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
|
||||
tmp_avg = tmp_avg / (float)real_count;
|
||||
tmp_avg = fminf(tmp_avg, maxf);
|
||||
dst_c_ptr[0] = tmp_avg;
|
||||
}
|
||||
w = 0;
|
||||
}
|
||||
h = 0;
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int AvgPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
|
||||
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
|
||||
int output_w = pooling_param->output_w_, output_h = pooling_param->output_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
|
||||
int out_plane = output_w * output_h;
|
||||
int in_plane = in_w * in_h;
|
||||
NNACL_CHECK_ZERO_RETURN_ERR(output_w);
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
const int c_tile = C8NUM;
|
||||
const int once_calc_num = 2;
|
||||
MS_FLOAT32X8 min_value_8 = MS_MOV256_F32(minf);
|
||||
MS_FLOAT32X8 max_value_8 = MS_MOV256_F32(maxf);
|
||||
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
const int c_tile = C4NUM;
|
||||
const int once_calc_num = 1;
|
||||
MS_FLOAT32X4 min_value = MS_MOVQ_F32(minf);
|
||||
MS_FLOAT32X4 max_value = MS_MOVQ_F32(maxf);
|
||||
#else
|
||||
const int c_tile = 1;
|
||||
const int once_calc_num = 1;
|
||||
#endif
|
||||
|
||||
int in_w_cx_line = in_w * c_tile;
|
||||
const int c_xtile = once_calc_num * c_tile;
|
||||
int c_tile_num = channel / c_xtile;
|
||||
int all_out_plane = out_plane * c_tile_num;
|
||||
int calc_tile = UP_DIV(all_out_plane, pooling_param->thread_num_);
|
||||
|
||||
int index_begin = task_id * calc_tile;
|
||||
int index_end = (index_begin + calc_tile) < all_out_plane ? (index_begin + calc_tile) : all_out_plane;
|
||||
|
||||
int c_start = (index_begin / out_plane) * c_xtile;
|
||||
int index_less = index_begin % out_plane;
|
||||
int h_start = index_less / output_h;
|
||||
int w_start = index_less % output_h;
|
||||
|
||||
int c_end = (index_end / out_plane) * c_xtile;
|
||||
index_less = index_end % out_plane;
|
||||
int h_end = index_less / output_h;
|
||||
int w_end = index_less % output_h;
|
||||
|
||||
int c = c_start;
|
||||
int h = h_start;
|
||||
int w = w_start;
|
||||
for (; c < channel; c += c_xtile) {
|
||||
const float *src_c_ptr = src_b_ptr + c * in_plane;
|
||||
for (; h < output_h; h++) {
|
||||
int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0);
|
||||
int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h);
|
||||
|
||||
for (; w < output_w; w++) {
|
||||
MS_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK);
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
MS_FLOAT32X8 tmp_avg = MS_MOV256_F32(0);
|
||||
MS_FLOAT32X8 tmp_avg2 = MS_MOV256_F32(0);
|
||||
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
MS_FLOAT32X4 tmp_avg = MS_MOVQ_F32(0);
|
||||
#else
|
||||
float tmp_avg = 0;
|
||||
#endif
|
||||
|
||||
int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0);
|
||||
int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w);
|
||||
|
||||
int real_count = (cur_index_in_w_end - cur_index_in_w_start) * (cur_index_in_h_end - cur_index_in_h_start);
|
||||
MS_CHECK_TRUE_RET(real_count != 0, NNACL_ERR);
|
||||
|
||||
for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) {
|
||||
const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line;
|
||||
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
|
||||
const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * c_tile;
|
||||
#ifdef ENABLE_AVX
|
||||
tmp_avg = MS_ADD256_F32(tmp_avg, MS_LD256_F32(cur_input_index));
|
||||
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
tmp_avg = MS_ADDQ_F32(tmp_avg, MS_LDQ_F32(cur_input_index));
|
||||
#else
|
||||
tmp_avg += cur_input_index[0];
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef ENABLE_AVX
|
||||
const float *src_c2_ptr_h_line = src_c_ptr_h_line + c_tile * in_plane;
|
||||
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
|
||||
const float *cur_input_index = src_c2_ptr_h_line + cur_index_in_w * c_tile;
|
||||
|
||||
tmp_avg2 = MS_ADD256_F32(tmp_avg2, MS_LD256_F32(cur_input_index));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
|
||||
#ifdef ENABLE_AVX
|
||||
float *dst_c2_ptr = dst_c_ptr + c_tile;
|
||||
|
||||
tmp_avg = MS_DIV256_F32(tmp_avg, MS_MOV256_F32(real_count));
|
||||
tmp_avg = MS_MAX256_F32(tmp_avg, min_value_8);
|
||||
tmp_avg = MS_MIN256_F32(tmp_avg, max_value_8);
|
||||
MS_ST256_F32(dst_c_ptr, tmp_avg);
|
||||
|
||||
tmp_avg2 = MS_DIV256_F32(tmp_avg2, MS_MOV256_F32(real_count));
|
||||
tmp_avg2 = MS_MAX256_F32(tmp_avg2, min_value_8);
|
||||
tmp_avg2 = MS_MIN256_F32(tmp_avg2, max_value_8);
|
||||
MS_ST256_F32(dst_c2_ptr, tmp_avg2);
|
||||
#elif defined(ENABLE_NEON) || defined(ENABLE_SSE)
|
||||
tmp_avg = MS_DIVQ_F32(tmp_avg, MS_MOVQ_F32(real_count));
|
||||
tmp_avg = MS_MAXQ_F32(tmp_avg, min_value);
|
||||
tmp_avg = MS_MINQ_F32(tmp_avg, max_value);
|
||||
MS_STQ_F32(dst_c_ptr, tmp_avg);
|
||||
#else
|
||||
tmp_avg = tmp_avg / (float)real_count;
|
||||
tmp_avg = fmaxf(tmp_avg, minf);
|
||||
tmp_avg = fminf(tmp_avg, maxf);
|
||||
dst_c_ptr[0] = tmp_avg;
|
||||
#endif
|
||||
}
|
||||
w = 0;
|
||||
}
|
||||
h = 0;
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
int output_w = pooling_param->output_w_;
|
||||
int output_h = pooling_param->output_h_;
|
||||
int channel = pooling_param->input_channel_;
|
||||
int output_batch = pooling_param->output_batch_;
|
||||
|
||||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
|
||||
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
|
||||
int ret = AvgPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
|
||||
if (ret != NNACL_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = AvgPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
|
||||
if (ret != NNACL_OK) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int MaxPoolingBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
|
||||
|
@ -249,7 +486,7 @@ int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int MaxPoolingFormNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
|
||||
int MaxPoolingFromNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
|
||||
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
|
||||
|
@ -276,36 +513,61 @@ int MaxPoolingFormNC4HW4ToNHWCLessC(const float *src_b_ptr, float *dst_b_ptr, co
|
|||
int cur_c = (channel / c_xtile) * c_xtile;
|
||||
int last_c_size = channel - cur_c;
|
||||
|
||||
int calc_tile = UP_DIV(out_plane, pooling_param->thread_num_);
|
||||
int less_out_plane = out_plane * last_c_size;
|
||||
int calc_tile = UP_DIV(less_out_plane, pooling_param->thread_num_);
|
||||
|
||||
int index_begin = task_id * calc_tile;
|
||||
int index_end = (index_begin + calc_tile) < out_plane ? (index_begin + calc_tile) : out_plane;
|
||||
int index_end = (index_begin + calc_tile) < less_out_plane ? (index_begin + calc_tile) : less_out_plane;
|
||||
|
||||
for (int c = cur_c; c < channel; c++) {
|
||||
for (int index = index_begin; index < index_end; index++) {
|
||||
const float *src_c_ptr = src_b_ptr + c * in_plane;
|
||||
int h = index / output_h;
|
||||
int w = index % output_h;
|
||||
int c_start = (index_begin / out_plane) + cur_c;
|
||||
int index_less = index_begin % out_plane;
|
||||
int h_start = index_less / output_h;
|
||||
int w_start = index_less % output_h;
|
||||
|
||||
float tmp_max = -FLT_MAX;
|
||||
for (int kh = 0; kh < win_h; kh++) {
|
||||
for (int kw = 0; kw < win_w; kw++) {
|
||||
const float *src_win_ptr = src_c_ptr + (kh + win_h * h) * in_w * last_c_size + (kw + win_w * w) * last_c_size;
|
||||
tmp_max = fmaxf(tmp_max, src_win_ptr[0]);
|
||||
int c_end = (index_end / out_plane) + cur_c;
|
||||
index_less = index_end % out_plane;
|
||||
int h_end = index_less / output_h;
|
||||
int w_end = index_less % output_h;
|
||||
|
||||
int c = c_start;
|
||||
int h = h_start;
|
||||
int w = w_start;
|
||||
|
||||
int in_w_cx_line = in_w * last_c_size;
|
||||
const float *src_c_ptr = src_b_ptr + cur_c * in_plane;
|
||||
for (; c < channel; c++) {
|
||||
for (; h < output_h; h++) {
|
||||
int cur_index_in_h_start = MSMAX(h * pooling_param->stride_h_ - pooling_param->pad_d_, 0);
|
||||
int cur_index_in_h_end = MSMIN(cur_index_in_h_start + win_h, in_h);
|
||||
|
||||
for (; w < output_w; w++) {
|
||||
MS_CHECK_TRUE_RET((c < c_end || h < h_end || w < w_end), NNACL_OK);
|
||||
float tmp_max = -FLT_MAX;
|
||||
|
||||
int cur_index_in_w_start = MSMAX(w * pooling_param->stride_w_ - pooling_param->pad_l_, 0);
|
||||
int cur_index_in_w_end = MSMIN(cur_index_in_w_start + win_w, in_w);
|
||||
|
||||
for (int cur_index_in_h = cur_index_in_h_start; cur_index_in_h < cur_index_in_h_end; cur_index_in_h++) {
|
||||
const float *src_c_ptr_h_line = src_c_ptr + cur_index_in_h * in_w_cx_line;
|
||||
for (int cur_index_in_w = cur_index_in_w_start; cur_index_in_w < cur_index_in_w_end; cur_index_in_w++) {
|
||||
const float *cur_input_index = src_c_ptr_h_line + cur_index_in_w * last_c_size + (c - cur_c);
|
||||
tmp_max = fmaxf(tmp_max, cur_input_index[0]);
|
||||
}
|
||||
}
|
||||
|
||||
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
|
||||
tmp_max = fmaxf(tmp_max, minf);
|
||||
tmp_max = fminf(tmp_max, maxf);
|
||||
dst_c_ptr[0] = tmp_max;
|
||||
}
|
||||
|
||||
float *dst_c_ptr = dst_b_ptr + h * output_w * channel + w * channel + c;
|
||||
tmp_max = fmaxf(tmp_max, minf);
|
||||
tmp_max = fminf(tmp_max, maxf);
|
||||
dst_c_ptr[0] = tmp_max;
|
||||
w = 0;
|
||||
}
|
||||
h = 0;
|
||||
}
|
||||
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int MaxPoolingFormNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
|
||||
int MaxPoolingFromNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_, in_h = pooling_param->input_h_;
|
||||
int win_w = pooling_param->window_w_, win_h = pooling_param->window_h_;
|
||||
|
@ -426,7 +688,7 @@ int MaxPoolingFormNC4HW4ToNHWCBatch(const float *src_b_ptr, float *dst_b_ptr, co
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int MaxPoolingFormNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
|
||||
int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf) {
|
||||
int in_w = pooling_param->input_w_;
|
||||
int in_h = pooling_param->input_h_;
|
||||
|
@ -438,12 +700,12 @@ int MaxPoolingFormNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const
|
|||
for (int batch = 0; batch < output_batch; batch++) {
|
||||
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
|
||||
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
|
||||
int ret = MaxPoolingFormNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
|
||||
int ret = MaxPoolingFromNC4HW4ToNHWCBatch(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
|
||||
if (ret != NNACL_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = MaxPoolingFormNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
|
||||
ret = MaxPoolingFromNC4HW4ToNHWCLessC(src_b_ptr, dst_b_ptr, pooling_param, task_id, minf, maxf);
|
||||
if (ret != NNACL_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -29,9 +29,11 @@ extern "C" {
|
|||
#endif
|
||||
int AvgPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf);
|
||||
int AvgPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf);
|
||||
int MaxPooling(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param, int task_id,
|
||||
float minf, float maxf);
|
||||
int MaxPoolingFormNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
|
||||
int MaxPoolingFromNC4HW4ToNHWC(const float *input_ptr, float *output_ptr, const PoolingParameter *pooling_param,
|
||||
int task_id, float minf, float maxf);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -67,10 +67,9 @@ int PoolingCPUKernel::RunImpl(int task_id) const {
|
|||
|
||||
if (in_tensors_[0]->format() == NC4HW4) {
|
||||
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
|
||||
ret = MaxPoolingFormNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
ret = MaxPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
} else {
|
||||
// ret = AvgPoolingFormNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
MS_LOG(ERROR) << "Do not support NC4HW4 AvgPooling input format.";
|
||||
ret = AvgPoolingFromNC4HW4ToNHWC(input_ptr, output_ptr, pooling_param_, task_id, minf, maxf);
|
||||
}
|
||||
} else if (in_tensors_[0]->format() == NHWC) {
|
||||
if (pooling_param_->pool_mode_ == PoolMode_MaxPool) {
|
||||
|
|
Loading…
Reference in New Issue