add 1d f(2,3) support for 3x3 dw conv

This commit is contained in:
lixian 2021-03-15 09:06:07 +08:00
parent dca301eabf
commit aec6dfd513
8 changed files with 418 additions and 295 deletions

View File

@ -18,6 +18,7 @@
#include "nnacl/common_func.h"
#include "nnacl/fp32/common_func_fp32.h"
#include "nnacl/fp32/winograd_transform.h"
#include "nnacl/intrinsics/ms_simd_instructions.h"
#ifdef ENABLE_ARM64
#include <arm_neon.h>
#endif
@ -337,260 +338,373 @@ bool CheckConvDwUse3X3(const ConvParameter *conv_param) {
in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_);
}
void ConvDw3x3BorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
int in_kh_step, int in_kw_step, int channel, bool relu, bool relu6) {
for (int c = 0; c < channel; c += C4NUM) {
for (int i = 0; i < C4NUM; i++) {
dst[i] = 0;
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) {
return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 &&
conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 &&
conv_param->input_channel_ == conv_param->output_channel_ &&
conv_param->output_h_ / thread_num >= 4; // better had more than 4 rows for each thread
}
void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) {
MS_FLOAT32X4 v0, v1, v2, v3;
v0 = MS_MOVQ_F32(0.0f);
int ic = 0;
for (; ic < channel - 3; ic += 4) {
v1 = MS_LDQ_F32(src + ic);
v2 = MS_LDQ_F32(src + channel + ic);
v3 = MS_LDQ_F32(src + 2 * channel + ic);
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
MS_STQ_F32(line + lw * ic, b0);
MS_STQ_F32(line + lw * ic + 4, b1);
MS_STQ_F32(line + lw * ic + 8, b2);
MS_STQ_F32(line + lw * ic + 12, b3);
}
if (ic < channel) {
float *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 4, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 12, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float d1 = src[i + ic];
float d2 = src[i + ic + channel];
float d3 = src[i + ic + 2 * channel];
remain_line[i] = 0.0f - d2;
remain_line[i + 4] = d1 + d2;
remain_line[i + 8] = d2 - d1;
remain_line[i + 12] = d3 - d1;
}
const float *src_kh = src;
const float *weight_kh = weight;
for (int kh = 0; kh < height; kh++) {
const float *src_kw = src_kh;
const float *weight_kw = weight_kh;
for (int kw = 0; kw < width; kw++) {
for (int i = 0; i < C4NUM; i++) {
dst[i] += src_kw[c + i] * weight_kw[c + i];
}
src_kw += in_kw_step;
weight_kw += channel;
} // kernel_w loop
src_kh += in_kh_step;
weight_kh += 3 * channel;
} // kernel_h loop
for (int i = 0; i < C4NUM; i++) {
dst[i] += bias[c + i];
dst[i] = (relu) ? (MSMAX(0, dst[i])) : (dst[i]);
dst[i] = (relu6) ? (MSMIN(6, MSMAX(0, dst[i]))) : (dst[i]);
}
dst += C4NUM;
}
}
#ifndef ENABLE_ARM64
void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
int in_kw_step, int channel, bool relu, bool relu6) {
ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, relu, relu6);
}
void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
int in_kw_step, int channel, bool relu, bool relu6) {
ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, relu, relu6);
}
void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
int in_kw_step, int channel, bool relu, bool relu6) {
ConvDw3x3BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, relu, relu6);
}
#endif
void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
int input_row_size = conv_param->input_w_ * conv_param->input_channel_;
int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_;
int output_row_size = conv_param->output_w_ * conv_param->output_channel_;
int in_kh_step = sliding->in_kh_step_;
int in_kw_step = sliding->in_kw_step_;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
for (int b = 0; b < conv_param->output_batch_; b++) {
const float *input_batch =
input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
float *output_batch = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
// top
const float *input = input_batch;
const float *weight = weight_data + weight_row_size + conv_param->input_channel_;
float *output = output_batch;
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
input += (conv_param->stride_w_ - 1) * conv_param->input_channel_;
weight = weight_data + weight_row_size;
output += conv_param->output_channel_;
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
relu6);
input += conv_param->stride_w_ * conv_param->input_channel_;
output += conv_param->output_channel_;
void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) {
MS_FLOAT32X4 v0, v1, v2, v3;
int ic = 0;
for (; ic < channel - 3; ic += 4) {
v0 = MS_LDQ_F32(src + ic);
v1 = MS_LDQ_F32(src + channel + ic);
v2 = MS_LDQ_F32(src + 2 * channel + ic);
v3 = MS_LDQ_F32(src + 3 * channel + ic);
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
MS_STQ_F32(line + lw * ic, b0);
MS_STQ_F32(line + lw * ic + 4, b1);
MS_STQ_F32(line + lw * ic + 8, b2);
MS_STQ_F32(line + lw * ic + 12, b3);
}
if (ic < channel) {
float *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 4, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 12, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float d0 = src[i + ic];
float d1 = src[i + ic + channel];
float d2 = src[i + ic + 2 * channel];
float d3 = src[i + ic + 3 * channel];
remain_line[i] = d0 - d2;
remain_line[i + 4] = d1 + d2;
remain_line[i + 8] = d2 - d1;
remain_line[i + 12] = d3 - d1;
}
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
// left
input = input_batch + (conv_param->stride_h_ - 1) * input_row_size;
weight = weight_data + conv_param->input_channel_;
output = output_batch + output_row_size;
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
relu6);
input += conv_param->stride_h_ * input_row_size;
output += output_row_size;
}
// right
input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ +
(conv_param->stride_h_ - 1) * input_row_size;
weight = weight_data;
output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_;
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
relu6);
input += conv_param->stride_h_ * input_row_size;
output += output_row_size;
}
// bottom
input = input_batch + (conv_param->input_h_ - 2) * input_row_size;
weight = weight_data + conv_param->input_channel_;
output = output_batch + (conv_param->output_h_ - 1) * output_row_size;
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_;
weight = weight_data;
output += conv_param->output_channel_;
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
relu6);
input += conv_param->stride_w_ * conv_param->input_channel_;
output += conv_param->output_channel_;
}
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
}
}
void ConvDw3x3InitBuffer(float *buffer, const float *input, const ConvParameter *conv_param, int block_input_h,
int block_input_w) {
for (int h = 0; h < block_input_h; h++) {
const float *src = input;
for (int w = 0; w < block_input_w; w++) {
memcpy(buffer, src, 64 * sizeof(float));
src += conv_param->input_channel_;
buffer += 64;
void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) {
MS_FLOAT32X4 v0, v1, v2, v3;
int ic = 0;
v3 = MS_MOVQ_F32(0.0f);
for (; ic < channel - 3; ic += 4) {
v0 = MS_LDQ_F32(src + ic);
v1 = MS_LDQ_F32(src + channel + ic);
v2 = MS_LDQ_F32(src + 2 * channel + ic);
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
MS_STQ_F32(line + lw * ic, b0);
MS_STQ_F32(line + lw * ic + 4, b1);
MS_STQ_F32(line + lw * ic + 8, b2);
MS_STQ_F32(line + lw * ic + 12, b3);
}
if (ic < channel) {
float *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 4, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 12, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float d0 = src[i + ic];
float d1 = src[i + ic + channel];
float d2 = src[i + ic + 2 * channel];
remain_line[i] = d0 - d2;
remain_line[i + 4] = d1 + d2;
remain_line[i + 8] = d2 - d1;
remain_line[i + 12] = 0.0f - d1;
}
input += conv_param->input_w_ * conv_param->input_channel_;
}
}
void ConvDw3x3Window(float *output, const float *buffer, const float *weight, const float *bias, int col_size,
int row_size, int channel, int output_h, int output_w, int stride, bool relu, bool relu6) {
for (int w = 0; w < output_w; w++) {
for (int i = 0; i < C4NUM; i++) {
output[i] = bias[i];
void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) {
MS_FLOAT32X4 v0, v1, v2;
int ic = 0;
v2 = MS_MOVQ_F32(0.0f);
for (; ic < channel - 3; ic += 4) {
v0 = MS_LDQ_F32(src + ic);
v1 = MS_LDQ_F32(src + channel + ic);
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
MS_STQ_F32(line + lw * ic, v0);
MS_STQ_F32(line + lw * ic + 4, v1);
MS_STQ_F32(line + lw * ic + 8, b2);
memset(line + lw * ic + 12, 0, 16);
}
if (ic < channel) {
float *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 4, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 12, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float d0 = src[i + ic];
float d1 = src[i + ic + channel];
remain_line[i] = d0;
remain_line[i + 4] = d1;
remain_line[i + 8] = 0.0f - d1;
}
const float *src_kh = buffer;
const float *weight_kh = weight;
for (int kh = 0; kh < 3; kh++) {
const float *src_kw = src_kh;
const float *weight_kw = weight_kh;
for (int kw = 0; kw < 3; kw++) {
for (int c = 0; c < C4NUM; c++) {
output[c] += src_kw[c] * weight_kw[c];
}
src_kw += col_size;
weight_kw += channel;
}
}
void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) {
float *line0 = lines[0];
float *line1 = lines[1];
float *line2 = lines[2];
int c4 = UP_ROUND(channel, C4NUM);
int lw = UP_DIV(width, C2NUM) * C4NUM;
memset(line0, 0, c4 * lw * sizeof(float));
ConvDw3x3RowLeft(src, line1, lw, channel);
ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
int ow = 2;
for (; ow < width - 2; ow += 2) {
ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
}
int remain = width - ow;
if (remain == 2) {
ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
} else if (remain == 1) {
ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
}
}
void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) {
float *line0 = lines[0];
float *line1 = lines[1];
float *line2 = lines[2];
int lw = UP_DIV(width, C2NUM) * C4NUM;
ConvDw3x3RowLeft(src - width * channel, line0, lw, channel);
ConvDw3x3RowLeft(src, line1, lw, channel);
ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
int ow = 2;
for (; ow < width - 2; ow += 2) {
ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
}
int remain = width - ow;
if (remain == 2) {
ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
} else if (remain == 1) {
ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
}
}
void ConvDw3x3Row(const float *src, float **lines, int width, int channel) {
float *tmp = lines[0];
lines[0] = lines[1];
lines[1] = lines[2];
lines[2] = tmp;
int c4 = UP_ROUND(channel, C4NUM);
int lw = UP_DIV(width, C2NUM) * C4NUM;
memset(tmp, 0, c4 * lw * sizeof(float));
ConvDw3x3RowLeft(src, tmp, lw, channel);
int ow = 2;
for (; ow < width - 2; ow += 2) {
ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
}
int remain = width - ow;
if (remain == 2) {
ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
} else if (remain == 1) {
ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
}
}
void ConvDw3x3Bottom(float **lines, int width, int channel) {
float *tmp = lines[0];
lines[0] = lines[1];
lines[1] = lines[2];
lines[2] = tmp;
int c4 = UP_ROUND(channel, C4NUM);
memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float));
}
void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel,
bool relu, bool relu6) {
int channel = ori_channel;
float *line0 = lines[0];
float *line1 = lines[1];
float *line2 = lines[2];
for (; channel > 0; channel -= 4) {
MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data);
bias_data += 4;
MS_FLOAT32X4 g00 = MS_LDQ_F32(weight);
MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4);
MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8);
MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12);
MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16);
MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20);
MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24);
MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28);
MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32);
MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36);
MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40);
MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44);
weight += 48;
float *cur_dst = dst;
int ow = 0;
for (; ow < width - 1; ow += 2) {
MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03);
line0 += 16;
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13);
line1 += 16;
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23);
line2 += 16;
MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2));
res0 = MS_ADDQ_F32(res0, bias);
res1 = MS_ADDQ_F32(res1, bias);
if (relu || relu6) {
res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f));
}
src_kh += row_size;
weight_kh += 3 * channel;
}
for (int i = 0; i < C4NUM; i++) {
output[i] = (relu) ? (MSMAX(0, output[i])) : (output[i]);
output[i] = (relu6) ? (MSMIN(6, MSMAX(0, output[i]))) : (output[i]);
}
output += channel;
buffer += col_size * stride;
}
}
void ConvDw3x3Block(float *output, const float *buffer, const float *weight, const float *bias, int start_c, int end_c,
int col_size, int row_size, int channel, int output_h, int output_w, int stride, bool relu,
bool relu6) {
for (; start_c <= end_c - C4NUM; start_c += C4NUM) {
#ifdef ENABLE_ARM64
if (stride == 1) {
ConvDw3x3Stride1(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6);
} else {
ConvDw3x3Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6);
}
#else
ConvDw3x3Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, stride, relu, relu6);
#endif
output += C4NUM;
buffer += C4NUM;
weight += C4NUM;
bias += C4NUM;
}
}
void ConvDw3x3Row(float *output, float *buffer, const float *input, const float *weight, const float *bias,
const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w,
int block_input_h, int block_input_w) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const int ih_offset = 64 * block_input_w;
int w = start_w;
if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) {
for (; w <= end_w - block_output_w; w += block_output_w) {
float *output_ptr = output;
const float *input_ptr = input;
const float *weight_ptr = weight;
const float *bias_ptr = bias;
int c = 0;
for (; c <= conv_param->output_channel_ - 64; c += 64) {
ConvDw3x3InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
ConvDw3x3Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
block_output_h, block_output_w, conv_param->stride_h_, relu, relu6);
output_ptr += 64;
input_ptr += 64;
weight_ptr += 64;
bias_ptr += 64;
if (relu6) {
res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f));
}
// left channel
ConvDw3x3Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
conv_param->input_channel_, block_output_h, block_output_w, conv_param->stride_h_, relu, relu6);
output += block_output_w * conv_param->input_channel_;
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
if (channel >= 4) {
MS_STQ_F32(cur_dst, res0);
MS_STQ_F32(cur_dst + ori_channel, res1);
} else {
for (int i = 0; i < channel; i++) {
cur_dst[i] = res0[i];
cur_dst[ori_channel + i] = res1[i];
}
}
cur_dst += 2 * ori_channel;
}
}
// left width
int left_width = end_w - w;
if (left_width > 0) {
ConvDw3x3Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_,
conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h,
left_width, conv_param->stride_h_, relu, relu6);
if (ow < width) {
MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
line0 += 16;
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
line1 += 16;
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
line2 += 16;
MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
res0 = MS_ADDQ_F32(res0, bias);
if (relu || relu6) {
res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
}
if (relu6) {
res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
}
if (channel >= 4) {
MS_STQ_F32(cur_dst, res0);
} else {
for (int i = 0; i < channel; i++) {
cur_dst[i] = res0[i];
}
}
}
dst += 4;
}
}
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id) {
int output_h = sliding->bottom_ - sliding->top_;
int step_oh = UP_DIV(output_h, conv_param->thread_num_);
int start_oh = step_oh * task_id + sliding->top_;
int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_);
int start_ow = sliding->left_;
int end_ow = sliding->right_;
const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) {
int units = UP_DIV(conv_param->output_w_, C2NUM);
int c4 = UP_ROUND(conv_param->input_channel_, C4NUM);
int line = conv_param->input_channel_ * conv_param->input_w_;
const int block_output_h = 1;
int block_output_w = conv_param->stride_w_ == 1 ? 30 : 14;
const int block_input_h = 3;
int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
for (int b = 0; b < conv_param->output_batch_; b++) {
int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_;
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ +
start_ih * conv_param->input_w_ * conv_param->input_channel_ +
start_iw * conv_param->input_channel_;
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ +
start_oh * conv_param->output_w_ * conv_param->output_channel_ +
start_ow * conv_param->output_channel_;
for (int oh = start_oh; oh < end_oh; oh++) {
ConvDw3x3Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h,
block_output_w, block_input_h, block_input_w);
src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_;
dst += conv_param->output_w_ * conv_param->output_channel_;
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
float *line0 = buffer;
float *line1 = buffer + units * c4 * C4NUM;
float *line2 = buffer + units * c4 * C8NUM;
float *lines[3] = {line0, line1, line2};
int oh = start_oh;
if (oh == 0) {
// input trans
ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_);
} else {
// input trans
ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_);
}
// dst calc and trans
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
relu, relu6);
for (oh = start_oh + 1; oh < end_oh - 1; oh++) {
// input trans
ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
// dst calc and trans
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
relu, relu6);
}
if (oh == conv_param->output_h_ - 1) {
// input trans
ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_);
} else {
// input trans
ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
}
// dst calc and trans
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
relu, relu6);
}
}
#endif
/*conv depthwise indirect buffer fp32 begin*/
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) {

View File

@ -47,12 +47,6 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
bool CheckConvDwUse3X3(const ConvParameter *conv_param);
void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding);
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param);
void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
@ -74,6 +68,13 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const
size_t output_width, size_t input_stride, size_t relu, size_t relu6);
#endif
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh);
bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num);
#endif
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
int output_width, int input_stride, bool relu, bool relu6, int kernel);

View File

@ -632,3 +632,23 @@ inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_st
_mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma);
}
#endif
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) {
// nchw to nc4hw4 with 1D F(2,3)
for (int i = 0; i < channel; i++) {
float *src_kernel = (float *)src + i * 9;
float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4;
for (int y = 0; y < 3; y++) {
float g0 = src_kernel[3 * y];
float g1 = src_kernel[3 * y + 1];
float g2 = src_kernel[3 * y + 2];
dst_kernel[16 * y] = g0;
dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2);
dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2);
dst_kernel[16 * y + 12] = g2;
}
}
}
#endif

View File

@ -44,6 +44,10 @@ void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, i
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
int block_index);
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel);
#endif
// Transpose 8X8 Fp32 block data
typedef void (*Transpose8X8Fp32Func)(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride);
#ifdef ENABLE_ARM64

View File

@ -32,7 +32,6 @@
#define MS_ADDQ_EPI32 vaddq_s32
#define MS_MOVQ_F32 vmovq_n_f32
#define MS_MOVQ_EPI32 vmovq_n_s32
#define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32.
#define MS_SUBQ_F32 vsubq_f32
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
#define MS_STQ_F32 vst1q_f32
@ -76,7 +75,6 @@ inline static float32x4_t vrecp(float32x4_t v) {
#define MS_ADD256_EPI32 _mm256_add_epi32
#define MS_MOV256_F32 _mm256_set1_ps
#define MS_MOV256_EPI32 _mm256_set1_epi32
#define MS_DUP256_F32 _mm256_load_ps1 // It is recommended to replace with MS_MOV256_F32.
#define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3))
#define MS_ST256_F32 _mm256_storeu_ps
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
@ -109,7 +107,6 @@ inline static float32x4_t vrecp(float32x4_t v) {
#define MS_ADDQ_EPI32 _mm_add_epi32
#define MS_MOVQ_F32 _mm_set1_ps
#define MS_MOVQ_EPI32 _mm_set1_epi32
#define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32.
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
#define MS_STQ_F32 _mm_storeu_ps
#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)

View File

@ -21,6 +21,7 @@
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@ -354,8 +355,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel = nullptr;
if (opParameter != nullptr && opParameter->infer_flag_) {
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
if (CheckConvDw1DWinograd(conv_param, ctx->thread_num_)) {
kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx);
}
#endif
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
if (CheckConvDwUseIndirectBuffer(conv_param)) {
if (kernel == nullptr && CheckConvDwUseIndirectBuffer(conv_param)) {
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx);
}
#endif
@ -367,7 +373,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
}
return kernel;
}
} // namespace mindspore::kernel
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,

View File

@ -18,8 +18,10 @@
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
@ -28,10 +30,6 @@ ConvolutionDepthwise3x3CPUKernel::~ConvolutionDepthwise3x3CPUKernel() {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (sliding_ != nullptr) {
delete sliding_;
sliding_ = nullptr;
}
}
int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
@ -39,22 +37,26 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
auto weight_tensor = in_tensors_[kWeightIndex];
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
int channel = weight_tensor->Batch();
int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();
int c4 = UP_ROUND(channel, C4NUM);
int pack_weight_size = c4 * C12NUM;
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
}
PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), channel);
PackWeightConvDw3x3Fp32(origin_weight, packed_weight_, channel);
bias_data_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
bias_data_ = reinterpret_cast<float *>(malloc(c4 * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
}
memset(bias_data_, 0, channel * sizeof(float));
memset(bias_data_, 0, c4 * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_[kBiasIndex];
auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
@ -65,11 +67,6 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
}
int ConvolutionDepthwise3x3CPUKernel::Init() {
sliding_ = new (std::nothrow) SlidingWindowParam;
if (sliding_ == nullptr) {
MS_LOG(ERROR) << "new sliding window param failed.";
return RET_ERROR;
}
auto ret = InitWeightBias();
if (ret != 0) {
MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 InitWeightBias failed.";
@ -83,15 +80,19 @@ int ConvolutionDepthwise3x3CPUKernel::Init() {
int ConvolutionDepthwise3x3CPUKernel::ReSize() {
ConvolutionBaseCPUKernel::Init();
InitSlidingParamConvDw(sliding_, conv_param_, conv_param_->input_channel_);
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) {
auto buffer = buffer_ + 64 * 10 * 10 * task_id;
int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units
int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM);
auto buffer = buffer_ + C12NUM * c4 * units * task_id;
int step_oh = UP_DIV(conv_param_->output_h_, conv_param_->thread_num_);
int start_oh = step_oh * task_id;
int end_oh = MSMIN(start_oh + step_oh, conv_param_->output_h_);
ConvDw3x3(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
sliding_, task_id);
start_oh, end_oh);
return RET_OK;
}
@ -105,25 +106,18 @@ int ConvDw3x3Run(void *cdata, int task_id) {
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::InitBuffer() {
int buffer_size = 64 * 10 * 10 * conv_param_->thread_num_;
buffer_ = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size * sizeof(float)));
if (buffer_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwise3x3CPUKernel::Run() {
auto ret = InitBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Depthwise int8 ReSize error!";
return ret;
int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units
int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM);
int buffer_size = units * c4 * C12NUM * conv_param_->thread_num_;
buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(buffer_size * sizeof(float)));
if (buffer_ == nullptr) {
MS_LOG(ERROR) << "ConvDw3x3Run failed to allocate buffer";
return RET_MEMORY_FAILED;
}
if (IsTrain() && is_trainable()) {
PackWeight();
InitWeightBias();
}
auto input_tensor = in_tensors_.at(kInputIndex);
@ -132,32 +126,21 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
auto output_tensor = out_tensors_.at(kOutputIndex);
output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());
if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 ||
sliding_->right_ < conv_param_->output_w_) {
ConvDw3x3Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_, sliding_);
}
ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_);
auto ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_);
ctx_->allocator->Free(buffer_);
if (ret != RET_OK) {
context_->allocator->Free(buffer_);
MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]";
return RET_ERROR;
}
context_->allocator->Free(buffer_);
return RET_OK;
}
void ConvolutionDepthwise3x3CPUKernel::PackWeight() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
weight_tensor->Batch());
}
int ConvolutionDepthwise3x3CPUKernel::Eval() {
LiteKernel::Eval();
if (is_trainable()) {
PackWeight();
InitWeightBias();
}
return RET_OK;
}
} // namespace mindspore::kernel
#endif

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
@ -39,14 +40,11 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
int Eval() override;
private:
void PackWeight();
int InitBuffer();
SlidingWindowParam *sliding_ = nullptr;
float *packed_weight_ = nullptr;
float *input_ptr_ = nullptr;
float *output_ptr_ = nullptr;
float *buffer_ = nullptr;
};
} // namespace mindspore::kernel
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_