forked from mindspore-Ecosystem/mindspore
Add 1D F(2,3) Winograd support for 3x3 depthwise convolution
This commit is contained in:
parent
dca301eabf
commit
aec6dfd513
|
@ -18,6 +18,7 @@
|
|||
#include "nnacl/common_func.h"
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
#include "nnacl/fp32/winograd_transform.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
#ifdef ENABLE_ARM64
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
@ -337,260 +338,373 @@ bool CheckConvDwUse3X3(const ConvParameter *conv_param) {
|
|||
in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_);
|
||||
}
|
||||
|
||||
void ConvDw3x3BorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
|
||||
int in_kh_step, int in_kw_step, int channel, bool relu, bool relu6) {
|
||||
for (int c = 0; c < channel; c += C4NUM) {
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
dst[i] = 0;
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) {
|
||||
return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 &&
|
||||
conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
|
||||
conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 &&
|
||||
conv_param->input_channel_ == conv_param->output_channel_ &&
|
||||
conv_param->output_h_ / thread_num >= 4; // better had more than 4 rows for each thread
|
||||
}
|
||||
|
||||
void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) {
|
||||
MS_FLOAT32X4 v0, v1, v2, v3;
|
||||
v0 = MS_MOVQ_F32(0.0f);
|
||||
int ic = 0;
|
||||
for (; ic < channel - 3; ic += 4) {
|
||||
v1 = MS_LDQ_F32(src + ic);
|
||||
v2 = MS_LDQ_F32(src + channel + ic);
|
||||
v3 = MS_LDQ_F32(src + 2 * channel + ic);
|
||||
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
|
||||
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
|
||||
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
|
||||
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
|
||||
MS_STQ_F32(line + lw * ic, b0);
|
||||
MS_STQ_F32(line + lw * ic + 4, b1);
|
||||
MS_STQ_F32(line + lw * ic + 8, b2);
|
||||
MS_STQ_F32(line + lw * ic + 12, b3);
|
||||
}
|
||||
if (ic < channel) {
|
||||
float *remain_line = line + ic * lw;
|
||||
memset(remain_line, 0, 16);
|
||||
memset(remain_line + 4, 0, 16);
|
||||
memset(remain_line + 8, 0, 16);
|
||||
memset(remain_line + 12, 0, 16);
|
||||
for (int i = 0; i < channel - ic; i++) {
|
||||
float d1 = src[i + ic];
|
||||
float d2 = src[i + ic + channel];
|
||||
float d3 = src[i + ic + 2 * channel];
|
||||
remain_line[i] = 0.0f - d2;
|
||||
remain_line[i + 4] = d1 + d2;
|
||||
remain_line[i + 8] = d2 - d1;
|
||||
remain_line[i + 12] = d3 - d1;
|
||||
}
|
||||
const float *src_kh = src;
|
||||
const float *weight_kh = weight;
|
||||
for (int kh = 0; kh < height; kh++) {
|
||||
const float *src_kw = src_kh;
|
||||
const float *weight_kw = weight_kh;
|
||||
for (int kw = 0; kw < width; kw++) {
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
dst[i] += src_kw[c + i] * weight_kw[c + i];
|
||||
}
|
||||
src_kw += in_kw_step;
|
||||
weight_kw += channel;
|
||||
} // kernel_w loop
|
||||
src_kh += in_kh_step;
|
||||
weight_kh += 3 * channel;
|
||||
} // kernel_h loop
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
dst[i] += bias[c + i];
|
||||
dst[i] = (relu) ? (MSMAX(0, dst[i])) : (dst[i]);
|
||||
dst[i] = (relu6) ? (MSMIN(6, MSMAX(0, dst[i]))) : (dst[i]);
|
||||
}
|
||||
dst += C4NUM;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef ENABLE_ARM64
|
||||
// Computes one corner output pixel: only a 2x2 sub-window of the 3x3 kernel
// overlaps valid input (the padded row and column contribute zero).
void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
                     int in_kw_step, int channel, bool relu, bool relu6) {
  ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, relu, relu6);
}
|
||||
|
||||
// Computes one top/bottom border pixel: 2 valid kernel rows x 3 valid columns
// (one kernel row falls in the padded region and is skipped).
void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
                       int in_kw_step, int channel, bool relu, bool relu6) {
  ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, relu, relu6);
}
|
||||
|
||||
// Computes one left/right border pixel: 3 valid kernel rows x 2 valid columns
// (one kernel column falls in the padded region and is skipped).
void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step,
                         int in_kw_step, int channel, bool relu, bool relu6) {
  ConvDw3x3BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, relu, relu6);
}
|
||||
#endif
|
||||
|
||||
void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
||||
int input_row_size = conv_param->input_w_ * conv_param->input_channel_;
|
||||
int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_;
|
||||
int output_row_size = conv_param->output_w_ * conv_param->output_channel_;
|
||||
int in_kh_step = sliding->in_kh_step_;
|
||||
int in_kw_step = sliding->in_kw_step_;
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
const float *input_batch =
|
||||
input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
|
||||
float *output_batch = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
|
||||
// top
|
||||
const float *input = input_batch;
|
||||
const float *weight = weight_data + weight_row_size + conv_param->input_channel_;
|
||||
float *output = output_batch;
|
||||
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
|
||||
input += (conv_param->stride_w_ - 1) * conv_param->input_channel_;
|
||||
weight = weight_data + weight_row_size;
|
||||
output += conv_param->output_channel_;
|
||||
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
|
||||
ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
|
||||
relu6);
|
||||
input += conv_param->stride_w_ * conv_param->input_channel_;
|
||||
output += conv_param->output_channel_;
|
||||
void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) {
|
||||
MS_FLOAT32X4 v0, v1, v2, v3;
|
||||
int ic = 0;
|
||||
for (; ic < channel - 3; ic += 4) {
|
||||
v0 = MS_LDQ_F32(src + ic);
|
||||
v1 = MS_LDQ_F32(src + channel + ic);
|
||||
v2 = MS_LDQ_F32(src + 2 * channel + ic);
|
||||
v3 = MS_LDQ_F32(src + 3 * channel + ic);
|
||||
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
|
||||
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
|
||||
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
|
||||
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
|
||||
MS_STQ_F32(line + lw * ic, b0);
|
||||
MS_STQ_F32(line + lw * ic + 4, b1);
|
||||
MS_STQ_F32(line + lw * ic + 8, b2);
|
||||
MS_STQ_F32(line + lw * ic + 12, b3);
|
||||
}
|
||||
if (ic < channel) {
|
||||
float *remain_line = line + ic * lw;
|
||||
memset(remain_line, 0, 16);
|
||||
memset(remain_line + 4, 0, 16);
|
||||
memset(remain_line + 8, 0, 16);
|
||||
memset(remain_line + 12, 0, 16);
|
||||
for (int i = 0; i < channel - ic; i++) {
|
||||
float d0 = src[i + ic];
|
||||
float d1 = src[i + ic + channel];
|
||||
float d2 = src[i + ic + 2 * channel];
|
||||
float d3 = src[i + ic + 3 * channel];
|
||||
remain_line[i] = d0 - d2;
|
||||
remain_line[i + 4] = d1 + d2;
|
||||
remain_line[i + 8] = d2 - d1;
|
||||
remain_line[i + 12] = d3 - d1;
|
||||
}
|
||||
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
|
||||
|
||||
// left
|
||||
input = input_batch + (conv_param->stride_h_ - 1) * input_row_size;
|
||||
weight = weight_data + conv_param->input_channel_;
|
||||
output = output_batch + output_row_size;
|
||||
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
|
||||
ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
|
||||
relu6);
|
||||
input += conv_param->stride_h_ * input_row_size;
|
||||
output += output_row_size;
|
||||
}
|
||||
|
||||
// right
|
||||
input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ +
|
||||
(conv_param->stride_h_ - 1) * input_row_size;
|
||||
weight = weight_data;
|
||||
output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_;
|
||||
for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) {
|
||||
ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
|
||||
relu6);
|
||||
input += conv_param->stride_h_ * input_row_size;
|
||||
output += output_row_size;
|
||||
}
|
||||
|
||||
// bottom
|
||||
input = input_batch + (conv_param->input_h_ - 2) * input_row_size;
|
||||
weight = weight_data + conv_param->input_channel_;
|
||||
output = output_batch + (conv_param->output_h_ - 1) * output_row_size;
|
||||
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
|
||||
input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_;
|
||||
weight = weight_data;
|
||||
output += conv_param->output_channel_;
|
||||
for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) {
|
||||
ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu,
|
||||
relu6);
|
||||
input += conv_param->stride_w_ * conv_param->input_channel_;
|
||||
output += conv_param->output_channel_;
|
||||
}
|
||||
ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6);
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3InitBuffer(float *buffer, const float *input, const ConvParameter *conv_param, int block_input_h,
|
||||
int block_input_w) {
|
||||
for (int h = 0; h < block_input_h; h++) {
|
||||
const float *src = input;
|
||||
for (int w = 0; w < block_input_w; w++) {
|
||||
memcpy(buffer, src, 64 * sizeof(float));
|
||||
src += conv_param->input_channel_;
|
||||
buffer += 64;
|
||||
void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) {
|
||||
MS_FLOAT32X4 v0, v1, v2, v3;
|
||||
int ic = 0;
|
||||
v3 = MS_MOVQ_F32(0.0f);
|
||||
for (; ic < channel - 3; ic += 4) {
|
||||
v0 = MS_LDQ_F32(src + ic);
|
||||
v1 = MS_LDQ_F32(src + channel + ic);
|
||||
v2 = MS_LDQ_F32(src + 2 * channel + ic);
|
||||
MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2);
|
||||
MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2);
|
||||
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
|
||||
MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1);
|
||||
MS_STQ_F32(line + lw * ic, b0);
|
||||
MS_STQ_F32(line + lw * ic + 4, b1);
|
||||
MS_STQ_F32(line + lw * ic + 8, b2);
|
||||
MS_STQ_F32(line + lw * ic + 12, b3);
|
||||
}
|
||||
if (ic < channel) {
|
||||
float *remain_line = line + ic * lw;
|
||||
memset(remain_line, 0, 16);
|
||||
memset(remain_line + 4, 0, 16);
|
||||
memset(remain_line + 8, 0, 16);
|
||||
memset(remain_line + 12, 0, 16);
|
||||
for (int i = 0; i < channel - ic; i++) {
|
||||
float d0 = src[i + ic];
|
||||
float d1 = src[i + ic + channel];
|
||||
float d2 = src[i + ic + 2 * channel];
|
||||
remain_line[i] = d0 - d2;
|
||||
remain_line[i + 4] = d1 + d2;
|
||||
remain_line[i + 8] = d2 - d1;
|
||||
remain_line[i + 12] = 0.0f - d1;
|
||||
}
|
||||
input += conv_param->input_w_ * conv_param->input_channel_;
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3Window(float *output, const float *buffer, const float *weight, const float *bias, int col_size,
|
||||
int row_size, int channel, int output_h, int output_w, int stride, bool relu, bool relu6) {
|
||||
for (int w = 0; w < output_w; w++) {
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
output[i] = bias[i];
|
||||
void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) {
|
||||
MS_FLOAT32X4 v0, v1, v2;
|
||||
int ic = 0;
|
||||
v2 = MS_MOVQ_F32(0.0f);
|
||||
for (; ic < channel - 3; ic += 4) {
|
||||
v0 = MS_LDQ_F32(src + ic);
|
||||
v1 = MS_LDQ_F32(src + channel + ic);
|
||||
MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1);
|
||||
MS_STQ_F32(line + lw * ic, v0);
|
||||
MS_STQ_F32(line + lw * ic + 4, v1);
|
||||
MS_STQ_F32(line + lw * ic + 8, b2);
|
||||
memset(line + lw * ic + 12, 0, 16);
|
||||
}
|
||||
if (ic < channel) {
|
||||
float *remain_line = line + ic * lw;
|
||||
memset(remain_line, 0, 16);
|
||||
memset(remain_line + 4, 0, 16);
|
||||
memset(remain_line + 8, 0, 16);
|
||||
memset(remain_line + 12, 0, 16);
|
||||
for (int i = 0; i < channel - ic; i++) {
|
||||
float d0 = src[i + ic];
|
||||
float d1 = src[i + ic + channel];
|
||||
remain_line[i] = d0;
|
||||
remain_line[i + 4] = d1;
|
||||
remain_line[i + 8] = 0.0f - d1;
|
||||
}
|
||||
const float *src_kh = buffer;
|
||||
const float *weight_kh = weight;
|
||||
for (int kh = 0; kh < 3; kh++) {
|
||||
const float *src_kw = src_kh;
|
||||
const float *weight_kw = weight_kh;
|
||||
for (int kw = 0; kw < 3; kw++) {
|
||||
for (int c = 0; c < C4NUM; c++) {
|
||||
output[c] += src_kw[c] * weight_kw[c];
|
||||
}
|
||||
src_kw += col_size;
|
||||
weight_kw += channel;
|
||||
}
|
||||
}
|
||||
|
||||
// Prepares the three transformed line buffers for the topmost output row of
// the 1D F(2,3) Winograd pipeline. line0 is zero-filled because the row above
// the image is padding; line1/line2 receive the input-transformed first and
// second image rows. Left/middle/right/single row transforms handle the
// horizontal border cases; `width` here is the output width in pixels.
void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  int c4 = UP_ROUND(channel, C4NUM);
  // lw: per-channel stride of one transformed line (4 floats per 2-pixel unit)
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  memset(line0, 0, c4 * lw * sizeof(float));  // padded row above the image
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);
  int ow = 2;
  // Interior 2-pixel units: each consumes a 4-wide input patch starting at ow-1.
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  // Tail: 2 remaining output pixels need the right-border transform,
  // 1 remaining pixel needs the single-pixel transform.
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}
|
||||
|
||||
// Prepares the three transformed line buffers for an interior output row:
// transforms the row above src, the row at src, and the row below src
// (row stride is width * channel). Same left/middle/right/single border
// handling as ConvDw3x3InitTop, but no zero-filled padding line is needed.
void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) {
  float *line0 = lines[0];
  float *line1 = lines[1];
  float *line2 = lines[2];
  // lw: per-channel stride of one transformed line (4 floats per 2-pixel unit)
  int lw = UP_DIV(width, C2NUM) * C4NUM;
  ConvDw3x3RowLeft(src - width * channel, line0, lw, channel);  // row above
  ConvDw3x3RowLeft(src, line1, lw, channel);
  ConvDw3x3RowLeft(src + width * channel, line2, lw, channel);  // row below
  int ow = 2;
  for (; ow < width - 2; ow += 2) {
    ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
  // Tail handling mirrors ConvDw3x3InitTop: 2 pixels -> right transform,
  // 1 pixel -> single transform.
  int remain = width - ow;
  if (remain == 2) {
    ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  } else if (remain == 1) {
    ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel);
    ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel);
  }
}
|
||||
|
||||
void ConvDw3x3Row(const float *src, float **lines, int width, int channel) {
|
||||
float *tmp = lines[0];
|
||||
lines[0] = lines[1];
|
||||
lines[1] = lines[2];
|
||||
lines[2] = tmp;
|
||||
int c4 = UP_ROUND(channel, C4NUM);
|
||||
int lw = UP_DIV(width, C2NUM) * C4NUM;
|
||||
memset(tmp, 0, c4 * lw * sizeof(float));
|
||||
ConvDw3x3RowLeft(src, tmp, lw, channel);
|
||||
int ow = 2;
|
||||
for (; ow < width - 2; ow += 2) {
|
||||
ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
|
||||
}
|
||||
int remain = width - ow;
|
||||
if (remain == 2) {
|
||||
ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
|
||||
} else if (remain == 1) {
|
||||
ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel);
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3Bottom(float **lines, int width, int channel) {
|
||||
float *tmp = lines[0];
|
||||
lines[0] = lines[1];
|
||||
lines[1] = lines[2];
|
||||
lines[2] = tmp;
|
||||
int c4 = UP_ROUND(channel, C4NUM);
|
||||
memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float));
|
||||
}
|
||||
|
||||
void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel,
|
||||
bool relu, bool relu6) {
|
||||
int channel = ori_channel;
|
||||
float *line0 = lines[0];
|
||||
float *line1 = lines[1];
|
||||
float *line2 = lines[2];
|
||||
for (; channel > 0; channel -= 4) {
|
||||
MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data);
|
||||
bias_data += 4;
|
||||
MS_FLOAT32X4 g00 = MS_LDQ_F32(weight);
|
||||
MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4);
|
||||
MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8);
|
||||
MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12);
|
||||
MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16);
|
||||
MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20);
|
||||
MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24);
|
||||
MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28);
|
||||
MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32);
|
||||
MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36);
|
||||
MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40);
|
||||
MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44);
|
||||
weight += 48;
|
||||
float *cur_dst = dst;
|
||||
int ow = 0;
|
||||
for (; ow < width - 1; ow += 2) {
|
||||
MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
|
||||
MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
|
||||
MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
|
||||
MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03);
|
||||
line0 += 16;
|
||||
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
|
||||
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
|
||||
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
|
||||
acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13);
|
||||
line1 += 16;
|
||||
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
|
||||
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
|
||||
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
|
||||
acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23);
|
||||
line2 += 16;
|
||||
MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
|
||||
MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2));
|
||||
res0 = MS_ADDQ_F32(res0, bias);
|
||||
res1 = MS_ADDQ_F32(res1, bias);
|
||||
if (relu || relu6) {
|
||||
res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
|
||||
res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f));
|
||||
}
|
||||
src_kh += row_size;
|
||||
weight_kh += 3 * channel;
|
||||
}
|
||||
for (int i = 0; i < C4NUM; i++) {
|
||||
output[i] = (relu) ? (MSMAX(0, output[i])) : (output[i]);
|
||||
output[i] = (relu6) ? (MSMIN(6, MSMAX(0, output[i]))) : (output[i]);
|
||||
}
|
||||
output += channel;
|
||||
buffer += col_size * stride;
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the 3x3 depthwise kernel over whole 4-channel groups in
// [start_c, end_c); any remainder of fewer than C4NUM channels is left for
// the caller to handle. On ARM64 it dispatches to stride-specialized kernels
// (defined elsewhere); otherwise it uses the generic C window implementation.
void ConvDw3x3Block(float *output, const float *buffer, const float *weight, const float *bias, int start_c, int end_c,
                    int col_size, int row_size, int channel, int output_h, int output_w, int stride, bool relu,
                    bool relu6) {
  for (; start_c <= end_c - C4NUM; start_c += C4NUM) {
#ifdef ENABLE_ARM64
    if (stride == 1) {
      ConvDw3x3Stride1(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6);
    } else {
      ConvDw3x3Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6);
    }
#else
    ConvDw3x3Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, stride, relu, relu6);
#endif
    // Advance every pointer to the next 4-channel group.
    output += C4NUM;
    buffer += C4NUM;
    weight += C4NUM;
    bias += C4NUM;
  }
}
|
||||
|
||||
void ConvDw3x3Row(float *output, float *buffer, const float *input, const float *weight, const float *bias,
|
||||
const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w,
|
||||
int block_input_h, int block_input_w) {
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
const int ih_offset = 64 * block_input_w;
|
||||
int w = start_w;
|
||||
if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) {
|
||||
for (; w <= end_w - block_output_w; w += block_output_w) {
|
||||
float *output_ptr = output;
|
||||
const float *input_ptr = input;
|
||||
const float *weight_ptr = weight;
|
||||
const float *bias_ptr = bias;
|
||||
int c = 0;
|
||||
for (; c <= conv_param->output_channel_ - 64; c += 64) {
|
||||
ConvDw3x3InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
|
||||
ConvDw3x3Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
|
||||
block_output_h, block_output_w, conv_param->stride_h_, relu, relu6);
|
||||
output_ptr += 64;
|
||||
input_ptr += 64;
|
||||
weight_ptr += 64;
|
||||
bias_ptr += 64;
|
||||
if (relu6) {
|
||||
res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
|
||||
res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f));
|
||||
}
|
||||
// left channel
|
||||
ConvDw3x3Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
|
||||
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
|
||||
conv_param->input_channel_, block_output_h, block_output_w, conv_param->stride_h_, relu, relu6);
|
||||
output += block_output_w * conv_param->input_channel_;
|
||||
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
|
||||
if (channel >= 4) {
|
||||
MS_STQ_F32(cur_dst, res0);
|
||||
MS_STQ_F32(cur_dst + ori_channel, res1);
|
||||
} else {
|
||||
for (int i = 0; i < channel; i++) {
|
||||
cur_dst[i] = res0[i];
|
||||
cur_dst[ori_channel + i] = res1[i];
|
||||
}
|
||||
}
|
||||
cur_dst += 2 * ori_channel;
|
||||
}
|
||||
}
|
||||
// left width
|
||||
int left_width = end_w - w;
|
||||
if (left_width > 0) {
|
||||
ConvDw3x3Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_,
|
||||
conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h,
|
||||
left_width, conv_param->stride_h_, relu, relu6);
|
||||
if (ow < width) {
|
||||
MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00);
|
||||
MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01);
|
||||
MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02);
|
||||
line0 += 16;
|
||||
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10);
|
||||
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11);
|
||||
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12);
|
||||
line1 += 16;
|
||||
acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20);
|
||||
acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21);
|
||||
acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22);
|
||||
line2 += 16;
|
||||
MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1));
|
||||
res0 = MS_ADDQ_F32(res0, bias);
|
||||
if (relu || relu6) {
|
||||
res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f));
|
||||
}
|
||||
if (relu6) {
|
||||
res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f));
|
||||
}
|
||||
if (channel >= 4) {
|
||||
MS_STQ_F32(cur_dst, res0);
|
||||
} else {
|
||||
for (int i = 0; i < channel; i++) {
|
||||
cur_dst[i] = res0[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
dst += 4;
|
||||
}
|
||||
}
|
||||
|
||||
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
|
||||
const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
|
||||
int task_id) {
|
||||
int output_h = sliding->bottom_ - sliding->top_;
|
||||
int step_oh = UP_DIV(output_h, conv_param->thread_num_);
|
||||
int start_oh = step_oh * task_id + sliding->top_;
|
||||
int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_);
|
||||
int start_ow = sliding->left_;
|
||||
int end_ow = sliding->right_;
|
||||
const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) {
|
||||
int units = UP_DIV(conv_param->output_w_, C2NUM);
|
||||
int c4 = UP_ROUND(conv_param->input_channel_, C4NUM);
|
||||
int line = conv_param->input_channel_ * conv_param->input_w_;
|
||||
|
||||
const int block_output_h = 1;
|
||||
int block_output_w = conv_param->stride_w_ == 1 ? 30 : 14;
|
||||
const int block_input_h = 3;
|
||||
int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3;
|
||||
bool relu = conv_param->act_type_ == ActType_Relu;
|
||||
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
||||
|
||||
for (int b = 0; b < conv_param->output_batch_; b++) {
|
||||
int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_;
|
||||
int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ +
|
||||
start_ih * conv_param->input_w_ * conv_param->input_channel_ +
|
||||
start_iw * conv_param->input_channel_;
|
||||
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ +
|
||||
start_oh * conv_param->output_w_ * conv_param->output_channel_ +
|
||||
start_ow * conv_param->output_channel_;
|
||||
|
||||
for (int oh = start_oh; oh < end_oh; oh++) {
|
||||
ConvDw3x3Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h,
|
||||
block_output_w, block_input_h, block_input_w);
|
||||
src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_;
|
||||
dst += conv_param->output_w_ * conv_param->output_channel_;
|
||||
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
|
||||
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
|
||||
float *line0 = buffer;
|
||||
float *line1 = buffer + units * c4 * C4NUM;
|
||||
float *line2 = buffer + units * c4 * C8NUM;
|
||||
float *lines[3] = {line0, line1, line2};
|
||||
int oh = start_oh;
|
||||
if (oh == 0) {
|
||||
// input trans
|
||||
ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_);
|
||||
} else {
|
||||
// input trans
|
||||
ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_);
|
||||
}
|
||||
// dst calc and trans
|
||||
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
|
||||
relu, relu6);
|
||||
for (oh = start_oh + 1; oh < end_oh - 1; oh++) {
|
||||
// input trans
|
||||
ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
|
||||
// dst calc and trans
|
||||
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
|
||||
relu, relu6);
|
||||
}
|
||||
if (oh == conv_param->output_h_ - 1) {
|
||||
// input trans
|
||||
ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_);
|
||||
} else {
|
||||
// input trans
|
||||
ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
|
||||
}
|
||||
// dst calc and trans
|
||||
ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
|
||||
relu, relu6);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/*conv depthwise indirect buffer fp32 begin*/
|
||||
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) {
|
||||
|
|
|
@ -47,12 +47,6 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
|
|||
|
||||
bool CheckConvDwUse3X3(const ConvParameter *conv_param);
|
||||
|
||||
void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
|
||||
const ConvParameter *conv_param, const SlidingWindowParam *sliding);
|
||||
|
||||
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
|
||||
const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id);
|
||||
|
||||
bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param);
|
||||
|
||||
void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param,
|
||||
|
@ -74,6 +68,13 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const
|
|||
size_t output_width, size_t input_stride, size_t relu, size_t relu6);
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data,
|
||||
const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh);
|
||||
|
||||
bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num);
|
||||
#endif
|
||||
|
||||
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
|
||||
int output_width, int input_stride, bool relu, bool relu6, int kernel);
|
||||
|
||||
|
|
|
@ -632,3 +632,23 @@ inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_st
|
|||
_mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
// Packs 3x3 depthwise weights from NCHW into NC4HW4 layout while applying the
// 1D F(2,3) Winograd weight transform along the kernel-width axis:
// each 3-tap row (g0, g1, g2) expands to 4 taps
//   [g0, (g0 + g1 + g2) / 2, (g0 - g1 + g2) / 2, g2].
// Output layout: per 4-channel group, 48 floats = 3 kernel rows x 4
// transformed taps x 4 channels (channel index is the innermost stride).
// src: channel * 9 floats (NCHW 3x3 kernels); dst: UP_ROUND(channel, 4) * 48
// floats, which the caller must have allocated (untouched lanes of a partial
// final group are left as-is).
void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) {
  // nchw to nc4hw4 with 1D F(2,3)
  for (int i = 0; i < channel; i++) {
    // Keep the source pointer const-qualified instead of casting constness away.
    const float *src_kernel = (const float *)src + i * 9;
    float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4;
    for (int y = 0; y < 3; y++) {
      float g0 = src_kernel[3 * y];
      float g1 = src_kernel[3 * y + 1];
      float g2 = src_kernel[3 * y + 2];

      dst_kernel[16 * y] = g0;
      dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2);
      dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2);
      dst_kernel[16 * y + 12] = g2;
    }
  }
}
|
||||
#endif
|
||||
|
|
|
@ -44,6 +44,10 @@ void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, i
|
|||
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
|
||||
int block_index);
|
||||
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel);
|
||||
#endif
|
||||
|
||||
// Transpose 8X8 Fp32 block data
|
||||
typedef void (*Transpose8X8Fp32Func)(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride);
|
||||
#ifdef ENABLE_ARM64
|
||||
|
|
|
@ -32,7 +32,6 @@
|
|||
#define MS_ADDQ_EPI32 vaddq_s32
|
||||
#define MS_MOVQ_F32 vmovq_n_f32
|
||||
#define MS_MOVQ_EPI32 vmovq_n_s32
|
||||
#define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32.
|
||||
#define MS_SUBQ_F32 vsubq_f32
|
||||
#define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
|
||||
#define MS_STQ_F32 vst1q_f32
|
||||
|
@ -76,7 +75,6 @@ inline static float32x4_t vrecp(float32x4_t v) {
|
|||
#define MS_ADD256_EPI32 _mm256_add_epi32
|
||||
#define MS_MOV256_F32 _mm256_set1_ps
|
||||
#define MS_MOV256_EPI32 _mm256_set1_epi32
|
||||
#define MS_DUP256_F32 _mm256_load_ps1 // It is recommended to replace with MS_MOV256_F32.
|
||||
#define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3))
|
||||
#define MS_ST256_F32 _mm256_storeu_ps
|
||||
#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
|
||||
|
@ -109,7 +107,6 @@ inline static float32x4_t vrecp(float32x4_t v) {
|
|||
#define MS_ADDQ_EPI32 _mm_add_epi32
|
||||
#define MS_MOVQ_F32 _mm_set1_ps
|
||||
#define MS_MOVQ_EPI32 _mm_set1_epi32
|
||||
#define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32.
|
||||
#define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
|
||||
#define MS_STQ_F32 _mm_storeu_ps
|
||||
#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h"
|
||||
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h"
|
||||
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h"
|
||||
#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -354,8 +355,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
if (opParameter != nullptr && opParameter->infer_flag_) {
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
if (CheckConvDw1DWinograd(conv_param, ctx->thread_num_)) {
|
||||
kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx);
|
||||
}
|
||||
#endif
|
||||
#if defined(ENABLE_ARM64) || defined(ENABLE_AVX)
|
||||
if (CheckConvDwUseIndirectBuffer(conv_param)) {
|
||||
if (kernel == nullptr && CheckConvDwUseIndirectBuffer(conv_param)) {
|
||||
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx);
|
||||
}
|
||||
#endif
|
||||
|
@ -367,7 +373,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_INFER_INVALID;
|
||||
using mindspore::lite::RET_MEMORY_FAILED;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
@ -28,10 +30,6 @@ ConvolutionDepthwise3x3CPUKernel::~ConvolutionDepthwise3x3CPUKernel() {
|
|||
free(packed_weight_);
|
||||
packed_weight_ = nullptr;
|
||||
}
|
||||
if (sliding_ != nullptr) {
|
||||
delete sliding_;
|
||||
sliding_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
|
||||
|
@ -39,22 +37,26 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
|
|||
auto weight_tensor = in_tensors_[kWeightIndex];
|
||||
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
|
||||
int channel = weight_tensor->Batch();
|
||||
int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width();
|
||||
int c4 = UP_ROUND(channel, C4NUM);
|
||||
int pack_weight_size = c4 * C12NUM;
|
||||
|
||||
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
|
||||
if (packed_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), channel);
|
||||
PackWeightConvDw3x3Fp32(origin_weight, packed_weight_, channel);
|
||||
|
||||
bias_data_ = reinterpret_cast<float *>(malloc(channel * sizeof(float)));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
bias_data_ = reinterpret_cast<float *>(malloc(c4 * sizeof(float)));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
memset(bias_data_, 0, channel * sizeof(float));
|
||||
memset(bias_data_, 0, c4 * sizeof(float));
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
auto bias_tensor = in_tensors_[kBiasIndex];
|
||||
auto ori_bias = reinterpret_cast<float *>(bias_tensor->MutableData());
|
||||
|
@ -65,11 +67,6 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() {
|
|||
}
|
||||
|
||||
int ConvolutionDepthwise3x3CPUKernel::Init() {
|
||||
sliding_ = new (std::nothrow) SlidingWindowParam;
|
||||
if (sliding_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new sliding window param failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = InitWeightBias();
|
||||
if (ret != 0) {
|
||||
MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 InitWeightBias failed.";
|
||||
|
@ -83,15 +80,19 @@ int ConvolutionDepthwise3x3CPUKernel::Init() {
|
|||
|
||||
int ConvolutionDepthwise3x3CPUKernel::ReSize() {
|
||||
ConvolutionBaseCPUKernel::Init();
|
||||
InitSlidingParamConvDw(sliding_, conv_param_, conv_param_->input_channel_);
|
||||
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) {
|
||||
auto buffer = buffer_ + 64 * 10 * 10 * task_id;
|
||||
int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units
|
||||
int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM);
|
||||
auto buffer = buffer_ + C12NUM * c4 * units * task_id;
|
||||
int step_oh = UP_DIV(conv_param_->output_h_, conv_param_->thread_num_);
|
||||
int start_oh = step_oh * task_id;
|
||||
int end_oh = MSMIN(start_oh + step_oh, conv_param_->output_h_);
|
||||
ConvDw3x3(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_,
|
||||
sliding_, task_id);
|
||||
start_oh, end_oh);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -105,25 +106,18 @@ int ConvDw3x3Run(void *cdata, int task_id) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwise3x3CPUKernel::InitBuffer() {
|
||||
int buffer_size = 64 * 10 * 10 * conv_param_->thread_num_;
|
||||
buffer_ = reinterpret_cast<float *>(context_->allocator->Malloc(buffer_size * sizeof(float)));
|
||||
if (buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc buffer failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConvolutionDepthwise3x3CPUKernel::Run() {
|
||||
auto ret = InitBuffer();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Depthwise int8 ReSize error!";
|
||||
return ret;
|
||||
int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units
|
||||
int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM);
|
||||
int buffer_size = units * c4 * C12NUM * conv_param_->thread_num_;
|
||||
buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(buffer_size * sizeof(float)));
|
||||
if (buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "ConvDw3x3Run failed to allocate buffer";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
|
||||
if (IsTrain() && is_trainable()) {
|
||||
PackWeight();
|
||||
InitWeightBias();
|
||||
}
|
||||
|
||||
auto input_tensor = in_tensors_.at(kInputIndex);
|
||||
|
@ -132,32 +126,21 @@ int ConvolutionDepthwise3x3CPUKernel::Run() {
|
|||
auto output_tensor = out_tensors_.at(kOutputIndex);
|
||||
output_ptr_ = reinterpret_cast<float *>(output_tensor->data_c());
|
||||
|
||||
if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 ||
|
||||
sliding_->right_ < conv_param_->output_w_) {
|
||||
ConvDw3x3Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<float *>(bias_data_), conv_param_, sliding_);
|
||||
}
|
||||
ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_);
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_);
|
||||
ctx_->allocator->Free(buffer_);
|
||||
if (ret != RET_OK) {
|
||||
context_->allocator->Free(buffer_);
|
||||
MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
context_->allocator->Free(buffer_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ConvolutionDepthwise3x3CPUKernel::PackWeight() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
auto origin_weight = reinterpret_cast<float *>(weight_tensor->MutableData());
|
||||
PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(),
|
||||
weight_tensor->Batch());
|
||||
}
|
||||
|
||||
int ConvolutionDepthwise3x3CPUKernel::Eval() {
|
||||
LiteKernel::Eval();
|
||||
if (is_trainable()) {
|
||||
PackWeight();
|
||||
InitWeightBias();
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
#endif
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
|
||||
|
||||
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/convolution_base.h"
|
||||
|
@ -39,14 +40,11 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel {
|
|||
int Eval() override;
|
||||
|
||||
private:
|
||||
void PackWeight();
|
||||
int InitBuffer();
|
||||
SlidingWindowParam *sliding_ = nullptr;
|
||||
float *packed_weight_ = nullptr;
|
||||
float *input_ptr_ = nullptr;
|
||||
float *output_ptr_ = nullptr;
|
||||
float *buffer_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_
|
||||
|
|
Loading…
Reference in New Issue