!18497 depthwise fp16 3x3 winograd

Merge pull request !18497 from zhaozhenlong/lite/opt/fp16_depthwise_3x3
This commit is contained in:
i-robot 2021-06-23 01:28:11 +00:00 committed by Gitee
commit 4982424ccf
13 changed files with 631 additions and 20 deletions

View File

@ -30,6 +30,371 @@ void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const floa
}
#endif
#ifdef ENABLE_ARM
static void ConvDw3x3RowLeftFp16(const float16_t *src, float16_t *line, int lw, int channel) {
MS_FLOAT16X8 v0, v1, v2, v3;
v0 = MS_MOVQ_F16((float16_t)0.0);
int ic = 0;
for (; ic < channel - 7; ic += 8) {
v1 = MS_LDQ_F16(src + ic);
v2 = MS_LDQ_F16(src + channel + ic);
v3 = MS_LDQ_F16(src + 2 * channel + ic);
MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2);
MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2);
MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1);
MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1);
MS_STQ_F16(line + lw * ic, b0);
MS_STQ_F16(line + lw * ic + 8, b1);
MS_STQ_F16(line + lw * ic + 16, b2);
MS_STQ_F16(line + lw * ic + 24, b3);
}
if (ic < channel) {
float16_t *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 16, 0, 16);
memset(remain_line + 24, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float16_t d1 = src[i + ic];
float16_t d2 = src[i + ic + channel];
float16_t d3 = src[i + ic + 2 * channel];
remain_line[i] = (float16_t)0.0 - d2;
remain_line[i + 8] = d1 + d2;
remain_line[i + 16] = d2 - d1;
remain_line[i + 24] = d3 - d1;
}
}
}
static void ConvDw3x3RowMiddleFp16(const float16_t *src, float16_t *line, int lw, int channel) {
MS_FLOAT16X8 v0, v1, v2, v3;
int ic = 0;
for (; ic < channel - 7; ic += 8) {
v0 = MS_LDQ_F16(src + ic);
v1 = MS_LDQ_F16(src + channel + ic);
v2 = MS_LDQ_F16(src + 2 * channel + ic);
v3 = MS_LDQ_F16(src + 3 * channel + ic);
MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2);
MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2);
MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1);
MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1);
MS_STQ_F16(line + lw * ic, b0);
MS_STQ_F16(line + lw * ic + 8, b1);
MS_STQ_F16(line + lw * ic + 16, b2);
MS_STQ_F16(line + lw * ic + 24, b3);
}
if (ic < channel) {
float16_t *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 16, 0, 16);
memset(remain_line + 24, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float16_t d0 = src[i + ic];
float16_t d1 = src[i + ic + channel];
float16_t d2 = src[i + ic + 2 * channel];
float16_t d3 = src[i + ic + 3 * channel];
remain_line[i] = d0 - d2;
remain_line[i + 8] = d1 + d2;
remain_line[i + 16] = d2 - d1;
remain_line[i + 24] = d3 - d1;
}
}
}
static void ConvDw3x3RowRightFp16(const float16_t *src, float16_t *line, int lw, int channel) {
MS_FLOAT16X8 v0, v1, v2, v3;
int ic = 0;
v3 = MS_MOVQ_F16((float16_t)0.0);
for (; ic < channel - 7; ic += 8) {
v0 = MS_LDQ_F16(src + ic);
v1 = MS_LDQ_F16(src + channel + ic);
v2 = MS_LDQ_F16(src + 2 * channel + ic);
MS_FLOAT16X8 b0 = MS_SUBQ_F16(v0, v2);
MS_FLOAT16X8 b1 = MS_ADDQ_F16(v1, v2);
MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1);
MS_FLOAT16X8 b3 = MS_SUBQ_F16(v3, v1);
MS_STQ_F16(line + lw * ic, b0);
MS_STQ_F16(line + lw * ic + 8, b1);
MS_STQ_F16(line + lw * ic + 16, b2);
MS_STQ_F16(line + lw * ic + 24, b3);
}
if (ic < channel) {
float16_t *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 16, 0, 16);
memset(remain_line + 24, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float16_t d0 = src[i + ic];
float16_t d1 = src[i + ic + channel];
float16_t d2 = src[i + ic + 2 * channel];
remain_line[i] = d0 - d2;
remain_line[i + 8] = d1 + d2;
remain_line[i + 16] = d2 - d1;
remain_line[i + 24] = (float16_t)0.0 - d1;
}
}
}
static void ConvDw3x3RowSingleFp16(const float16_t *src, float16_t *line, int lw, int channel) {
MS_FLOAT16X8 v0, v1, v2;
int ic = 0;
v2 = MS_MOVQ_F16((float16_t)0.0);
for (; ic < channel - 7; ic += 8) {
v0 = MS_LDQ_F16(src + ic);
v1 = MS_LDQ_F16(src + channel + ic);
MS_FLOAT16X8 b2 = MS_SUBQ_F16(v2, v1);
MS_STQ_F16(line + lw * ic, v0);
MS_STQ_F16(line + lw * ic + 8, v1);
MS_STQ_F16(line + lw * ic + 16, b2);
memset(line + lw * ic + 24, 0, 16);
}
if (ic < channel) {
float16_t *remain_line = line + ic * lw;
memset(remain_line, 0, 16);
memset(remain_line + 8, 0, 16);
memset(remain_line + 16, 0, 16);
memset(remain_line + 24, 0, 16);
for (int i = 0; i < channel - ic; i++) {
float16_t d0 = src[i + ic];
float16_t d1 = src[i + ic + channel];
remain_line[i] = d0;
remain_line[i + 8] = d1;
remain_line[i + 16] = (float16_t)0.0 - d1;
}
}
}
static void ConvDw3x3InitTopFp16(const float16_t *src, float16_t **lines, int width, int channel) {
float16_t *line0 = lines[0];
float16_t *line1 = lines[1];
float16_t *line2 = lines[2];
int c8 = UP_ROUND(channel, C8NUM);
int lw = UP_DIV(width, C2NUM) * C4NUM;
memset(line0, 0, c8 * lw * sizeof(float16_t));
ConvDw3x3RowLeftFp16(src, line1, lw, channel);
ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel);
int ow = 2;
for (; ow < width - 2; ow += 2) {
ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel);
ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel);
}
int remain = width - ow;
if (remain == 2) {
ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel);
ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel);
} else if (remain == 1) {
ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel);
ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel);
}
}
static void ConvDw3x3InitRowFp16(const float16_t *src, float16_t **lines, int width, int channel) {
float16_t *line0 = lines[0];
float16_t *line1 = lines[1];
float16_t *line2 = lines[2];
int lw = UP_DIV(width, C2NUM) * C4NUM;
ConvDw3x3RowLeftFp16(src - width * channel, line0, lw, channel);
ConvDw3x3RowLeftFp16(src, line1, lw, channel);
ConvDw3x3RowLeftFp16(src + width * channel, line2, lw, channel);
int ow = 2;
for (; ow < width - 2; ow += 2) {
ConvDw3x3RowMiddleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel);
ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel);
ConvDw3x3RowMiddleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel);
}
int remain = width - ow;
if (remain == 2) {
ConvDw3x3RowRightFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel);
ConvDw3x3RowRightFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel);
ConvDw3x3RowRightFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel);
} else if (remain == 1) {
ConvDw3x3RowSingleFp16(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 8, lw, channel);
ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, line1 + 2 * ow * 8, lw, channel);
ConvDw3x3RowSingleFp16(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 8, lw, channel);
}
}
static void ConvDw3x3RowFp16(const float16_t *src, float16_t **lines, int width, int channel) {
float16_t *tmp = lines[0];
lines[0] = lines[1];
lines[1] = lines[2];
lines[2] = tmp;
int c8 = UP_ROUND(channel, C8NUM);
int lw = UP_DIV(width, C2NUM) * C4NUM;
memset(tmp, 0, c8 * lw * sizeof(float16_t));
ConvDw3x3RowLeftFp16(src, tmp, lw, channel);
int ow = 2;
for (; ow < width - 2; ow += 2) {
ConvDw3x3RowMiddleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel);
}
int remain = width - ow;
if (remain == 2) {
ConvDw3x3RowRightFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel);
} else if (remain == 1) {
ConvDw3x3RowSingleFp16(src + (ow - 1) * channel, tmp + 2 * ow * 8, lw, channel);
}
}
static void ConvDw3x3BottomFp16(float16_t **lines, int width, int channel) {
float16_t *tmp = lines[0];
lines[0] = lines[1];
lines[1] = lines[2];
lines[2] = tmp;
int c8 = UP_ROUND(channel, C8NUM);
memset(tmp, 0, UP_DIV(width, C2NUM) * c8 * C4NUM * sizeof(float16_t));
}
void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data,
int width, int ori_channel, bool relu, bool relu6) {
int channel = ori_channel;
float16_t *line0 = lines[0];
float16_t *line1 = lines[1];
float16_t *line2 = lines[2];
for (; channel > 0; channel -= 8) {
MS_FLOAT16X8 bias = MS_LDQ_F16(bias_data);
bias_data += 8;
MS_FLOAT16X8 g00 = MS_LDQ_F16(weight);
MS_FLOAT16X8 g01 = MS_LDQ_F16(weight + 8);
MS_FLOAT16X8 g02 = MS_LDQ_F16(weight + 16);
MS_FLOAT16X8 g03 = MS_LDQ_F16(weight + 24);
MS_FLOAT16X8 g10 = MS_LDQ_F16(weight + 32);
MS_FLOAT16X8 g11 = MS_LDQ_F16(weight + 40);
MS_FLOAT16X8 g12 = MS_LDQ_F16(weight + 48);
MS_FLOAT16X8 g13 = MS_LDQ_F16(weight + 56);
MS_FLOAT16X8 g20 = MS_LDQ_F16(weight + 64);
MS_FLOAT16X8 g21 = MS_LDQ_F16(weight + 72);
MS_FLOAT16X8 g22 = MS_LDQ_F16(weight + 80);
MS_FLOAT16X8 g23 = MS_LDQ_F16(weight + 88);
weight += 96;
float16_t *cur_dst = dst;
int ow = 0;
for (; ow < width - 1; ow += 2) {
MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00);
MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01);
MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02);
MS_FLOAT16X8 acc3 = MS_MULQ_F16(MS_LDQ_F16(line0 + 24), g03);
line0 += 32;
acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10);
acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11);
acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12);
acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line1 + 24), g13);
line1 += 32;
acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20);
acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21);
acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22);
acc3 = MS_FMAQ_F16(acc3, MS_LDQ_F16(line2 + 24), g23);
line2 += 32;
MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1));
MS_FLOAT16X8 res1 = MS_ADDQ_F16(acc1, MS_SUBQ_F16(acc3, acc2));
res0 = MS_ADDQ_F16(res0, bias);
res1 = MS_ADDQ_F16(res1, bias);
if (relu || relu6) {
res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0));
res1 = MS_MAXQ_F16(res1, MS_MOVQ_F16((float16_t)0.0));
}
if (relu6) {
res0 = MS_MINQ_F16(res0, MS_MOVQ_F16((float16_t)6.0));
res1 = MS_MINQ_F16(res1, MS_MOVQ_F16((float16_t)6.0));
}
if (channel >= 8) {
MS_STQ_F16(cur_dst, res0);
MS_STQ_F16(cur_dst + ori_channel, res1);
} else {
for (int i = 0; i < channel; i++) {
cur_dst[i] = res0[i];
cur_dst[ori_channel + i] = res1[i];
}
}
cur_dst += 2 * ori_channel;
}
if (ow < width) {
MS_FLOAT16X8 acc0 = MS_MULQ_F16(MS_LDQ_F16(line0), g00);
MS_FLOAT16X8 acc1 = MS_MULQ_F16(MS_LDQ_F16(line0 + 8), g01);
MS_FLOAT16X8 acc2 = MS_MULQ_F16(MS_LDQ_F16(line0 + 16), g02);
line0 += 32;
acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line1), g10);
acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line1 + 8), g11);
acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line1 + 16), g12);
line1 += 32;
acc0 = MS_FMAQ_F16(acc0, MS_LDQ_F16(line2), g20);
acc1 = MS_FMAQ_F16(acc1, MS_LDQ_F16(line2 + 8), g21);
acc2 = MS_FMAQ_F16(acc2, MS_LDQ_F16(line2 + 16), g22);
line2 += 32;
MS_FLOAT16X8 res0 = MS_ADDQ_F16(acc0, MS_ADDQ_F16(acc2, acc1));
res0 = MS_ADDQ_F16(res0, bias);
if (relu || relu6) {
res0 = MS_MAXQ_F16(res0, MS_MOVQ_F16((float16_t)0.0));
}
if (relu6) {
res0 = MS_MINQ_F16(res0, MS_MOVQ_F16((float16_t)6.0));
}
if (channel >= 8) {
MS_STQ_F16(cur_dst, res0);
} else {
for (int i = 0; i < channel; i++) {
cur_dst[i] = res0[i];
}
}
}
dst += 8;
}
}
void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) {
int units = UP_DIV(conv_param->output_w_, C2NUM);
int c8 = UP_ROUND(conv_param->input_channel_, C8NUM);
int line = conv_param->input_channel_ * conv_param->input_w_;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
for (int b = 0; b < conv_param->output_batch_; b++) {
const float16_t *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
float16_t *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
float16_t *line0 = buffer;
float16_t *line1 = buffer + units * c8 * C4NUM;
float16_t *line2 = buffer + units * c8 * C4NUM * 2;
float16_t *lines[3] = {line0, line1, line2};
int oh = start_oh;
if (oh == 0) {
// input trans
ConvDw3x3InitTopFp16(src, lines, conv_param->output_w_, conv_param->input_channel_);
} else {
// input trans
ConvDw3x3InitRowFp16(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_);
}
// dst calc and trans
ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
relu, relu6);
for (oh = start_oh + 1; oh < end_oh - 1; oh++) {
// input trans
ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
// dst calc and trans
ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_,
conv_param->input_channel_, relu, relu6);
}
if (oh == conv_param->output_h_ - 1) {
// input trans
ConvDw3x3BottomFp16(lines, conv_param->output_w_, conv_param->input_channel_);
} else {
// input trans
ConvDw3x3RowFp16(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_);
}
// dst calc and trans
ConvDw3x3LineFp16(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_,
relu, relu6);
}
}
#endif
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, int task_id) {
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);

View File

@ -50,6 +50,14 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id);
#ifdef ENABLE_ARM
void ConvDw3x3LineFp16(float16_t *dst, float16_t **lines, const float16_t *weight, const float16_t *bias_data,
int width, int ori_channel, bool relu, bool relu6);
void ConvDw3x3Fp16(float16_t *output_data, float16_t *buffer, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh);
#endif
#ifdef __cplusplus
}
#endif

View File

@ -17,6 +17,26 @@
#include "nnacl/fp16/pack_fp16.h"
#include <string.h>
#ifdef ENABLE_ARM
void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel) {
// nchw to nc8hw8 with 1D F(2,3)
for (int i = 0; i < channel; i++) {
float16_t *src_kernel = (float16_t *)src + i * 9;
float16_t *dst_kernel = (float16_t *)dst + (i / 8) * 96 + i % 8;
for (int y = 0; y < 3; y++) {
float16_t g0 = src_kernel[3 * y];
float16_t g1 = src_kernel[3 * y + 1];
float16_t g2 = src_kernel[3 * y + 2];
dst_kernel[32 * y] = g0;
dst_kernel[32 * y + 8] = (float16_t)0.5 * (g0 + g1 + g2);
dst_kernel[32 * y + 16] = (float16_t)0.5 * (g0 - g1 + g2);
dst_kernel[32 * y + 24] = g2;
}
}
}
#endif
void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float16_t *packed_input, int real_cal_num,
int block_index) {
// input format : nhwc

View File

@ -81,6 +81,10 @@ void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_strid
void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
#endif
#ifdef ENABLE_ARM
void PackWeightConvDw3x3Fp16(const void *src, void *dst, int channel);
#endif
#ifdef __cplusplus
}
#endif

View File

@ -99,6 +99,9 @@ static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) {
#define MS_MAXQ_F16 vmaxq_f16
#define MS_LDQ_F16 vld1q_f16
#define MS_ADDQ_F16 vaddq_f16
#define MS_SUBQ_F16 vsubq_f16
#define MS_MULQ_F16 vmulq_f16
#define MS_FMAQ_F16 vfmaq_f16
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src));

View File

@ -22,6 +22,7 @@
#include "src/runtime/kernel/arm/fp16/group_convolution_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_depthwise_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_depthwise_slidewindow_fp16.h"
#include "src/runtime/kernel/arm/fp16/convolution_depthwise_3x3_fp16.h"
#include "src/runtime/kernel/arm/base/group_convolution_creator.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -117,6 +118,17 @@ kernel::InnerKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *
MS_ASSERT(opParameter != nullptr);
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::InnerKernel *kernel = nullptr;
#if defined(ENABLE_ARM)
if (CheckConvDw1DWinograd(conv_param, ctx->thread_num_)) {
kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3Fp16CPUKernel(opParameter, inputs, outputs, ctx);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
free(opParameter);
return nullptr;
}
return kernel;
}
#endif
if (conv_param->input_channel_ < 32) {
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseSWFp16CPUKernel(opParameter, inputs, outputs, ctx);
} else {

View File

@ -0,0 +1,150 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef ENABLE_ARM
#include "src/runtime/kernel/arm/fp16/convolution_depthwise_3x3_fp16.h"
#include "include/errorcode.h"
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/conv_depthwise_fp16.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
ConvolutionDepthwise3x3Fp16CPUKernel::~ConvolutionDepthwise3x3Fp16CPUKernel() {
if (packed_weight_ != nullptr) {
free(packed_weight_);
packed_weight_ = nullptr;
}
}
int ConvolutionDepthwise3x3Fp16CPUKernel::InitWeightBias() {
// init weight: k, h, w, c; k == group == output_channel, c == 1
auto weight_tensor = in_tensors_[kWeightIndex];
auto origin_weight = reinterpret_cast<float16_t *>(weight_tensor->MutableData());
int channel = weight_tensor->Batch();
int c8 = UP_ROUND(channel, C8NUM);
int pack_weight_size = c8 * C12NUM;
if (packed_weight_ == nullptr) {
packed_weight_ = reinterpret_cast<float16_t *>(malloc(pack_weight_size * sizeof(float16_t)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
}
PackWeightConvDw3x3Fp16(origin_weight, packed_weight_, channel);
if (bias_data_ == nullptr) {
bias_data_ = reinterpret_cast<float16_t *>(malloc(c8 * sizeof(float16_t)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return RET_ERROR;
}
}
memset(bias_data_, 0, c8 * sizeof(float16_t));
if (in_tensors_.size() == kInputSize2) {
auto bias_tensor = in_tensors_[kBiasIndex];
auto ori_bias = reinterpret_cast<float16_t *>(bias_tensor->MutableData());
memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(float16_t));
}
return RET_OK;
}
int ConvolutionDepthwise3x3Fp16CPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != 0) {
MS_LOG(ERROR) << "Convolution depthwise 3x3 fp16 InitWeightBias failed.";
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int ConvolutionDepthwise3x3Fp16CPUKernel::ReSize() {
ConvolutionBaseCPUKernel::Init();
conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_);
return RET_OK;
}
int ConvolutionDepthwise3x3Fp16CPUKernel::Execute(int task_id) {
int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units
int c8 = UP_ROUND(conv_param_->input_channel_, C8NUM);
auto buffer = buffer_ + C12NUM * c8 * units * task_id;
int step_oh = UP_DIV(conv_param_->output_h_, conv_param_->thread_num_);
int start_oh = step_oh * task_id;
int end_oh = MSMIN(start_oh + step_oh, conv_param_->output_h_);
ConvDw3x3Fp16(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast<float16_t *>(bias_data_), conv_param_,
start_oh, end_oh);
return RET_OK;
}
int ConvDw3x3Fp16Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto conv_dw = reinterpret_cast<ConvolutionDepthwise3x3Fp16CPUKernel *>(cdata);
auto ret = conv_dw->Execute(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionDepthwise3x3Run error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwise3x3Fp16CPUKernel::Run() {
if (IsTrainable() && (IsTrain() || IsRepack())) {
auto ret = InitWeightBias();
if (ret != 0) {
MS_LOG(ERROR) << "Convolution depthwise fp16 repack weight failure";
return RET_ERROR;
}
is_repack_ = false;
}
int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units
int c8 = UP_ROUND(conv_param_->input_channel_, C8NUM);
int buffer_size = units * c8 * C12NUM * conv_param_->thread_num_;
buffer_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(buffer_size * sizeof(float16_t)));
if (buffer_ == nullptr) {
MS_LOG(ERROR) << "ConvDw3x3Fp16Run failed to allocate buffer";
return RET_MEMORY_FAILED;
}
auto input_tensor = in_tensors_.at(kInputIndex);
input_ptr_ = reinterpret_cast<float16_t *>(input_tensor->data_c());
auto output_tensor = out_tensors_.at(kOutputIndex);
output_ptr_ = reinterpret_cast<float16_t *>(output_tensor->data_c());
auto ret = ParallelLaunch(this->context_, ConvDw3x3Fp16Run, this, conv_param_->thread_num_);
ctx_->allocator->Free(buffer_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionDepthwise3x3Fp16CPUKernel::Eval() {
if (IsTrainable()) {
is_repack_ = true;
}
return InnerKernel::Eval();
}
} // namespace mindspore::kernel
#endif

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP16_H_
#ifdef ENABLE_ARM
#include <vector>
#include "src/inner_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv_depthwise_fp32.h"
namespace mindspore::kernel {
class ConvolutionDepthwise3x3Fp16CPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionDepthwise3x3Fp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
~ConvolutionDepthwise3x3Fp16CPUKernel() override;
int Init() override;
int ReSize() override;
int Run() override;
int InitWeightBias();
int Execute(int task_id);
int Eval() override;
private:
float16_t *packed_weight_ = nullptr;
float16_t *input_ptr_ = nullptr;
float16_t *output_ptr_ = nullptr;
float16_t *buffer_ = nullptr;
};
} // namespace mindspore::kernel
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP16_H_

View File

@ -118,7 +118,6 @@ int ConvolutionDelegateCPUKernel::ReSize() {
MS_LOG(ERROR) << "Selecting execute kernel failed for conv_kernel, got a nullptr.";
return RET_ERROR;
}
// conv_kernel_->set_name(this->name_);
}
FreeCopiedData();
return conv_kernel_->ReSize();

View File

@ -37,7 +37,7 @@ ml_ocr_sfz_add_final_0325 0.1
ml_hardware_pose 2
ml_bank_recog 0.1
2012_ATLANTA_10class_20190131_v4.0 12
mnet 9
mnet 12
recognition 10
ml_face_landmark 1
model_hebing_3branch 40
@ -52,7 +52,7 @@ hiai_video_seg 1
hiai_semantic_seg 3
hiai_human_seg 28
hiai_face_recognition_1 10
hiai_cpu_face_detect 4
hiai_cpu_face_detect 4.5
hiai_cpu_face_attr 12
hiai_face_attr1 12
# mtk_detect-mbv1-shortcut-400-400_nopostprocess_simplified: precision is 5%
@ -76,9 +76,9 @@ ml_video_edit_hairSeg_have_imageProcessLayer_interpTo145 0.5
hdc_age_medium 6
hdc_contour_pose_128 0.5
hdc_emotion 0.5
hdc_fivembnet 0.5
hdc_fivembnet 1
hdc_isface 0.5
hdc_mobilenetface 7.5
hdc_mobilenetface 8.5
hdc_retinaface 14
hdc_resnet 7
ml_video_edit_detect 2.5

View File

@ -55,7 +55,7 @@ hdc_resnet_1w_class.onnx 6
gts_text_detection.onnx;1;1,224,224,3 10
hdc_Face_Emotion_MTI_Aesthetic.onnx 144
ml_video_edit_imitate_filter.onnx 120
ml_facedetector.onnx 3
ml_facedetector.onnx 6
ml_ei_facedetection.onnx 2
#ml_video_edit_art_generate.onnx #mul operator overflows, not suitable for fp16
#ml_voice_detect.onnx #conv operator overflows, not suitable for fp16
@ -79,7 +79,7 @@ mtk_face_recognition_v2.onnx 2.5
ml_2012_ocr_detection_tmp.onnx 0.5
Harmony_Voiceprint_resnet18.onnx;1;1,150,40,1 4.5
bloom_hongmo_detection_tmp.onnx 0.5
Q_face_recognition.onnx 2
Q_face_recognition.onnx 3
ml_video_edit_enhance_update_tmp.onnx 0.5
Q888_face_recognition.onnx 3.5
Q888_iris_detect.onnx 0.5

View File

@ -6,10 +6,10 @@
ml_vision_guide_detection1.pb 0.5
ml_vision_guide_detection3.pb 0.5
ml_video_edit_generate_filter.pb 2
ml_ocr_jk.pb 0.7
ml_ocr_jk.pb 0.8
# The accumulated error causes the threshold to be exceeded
ml_ocr_latin.pb 12
scan_hms_angle.pb 2.5
scan_hms_angle.pb 7
scan_hms_detect.pb 2.5
ml_face_openclose.pb;1;1,32,32,3 0.5
ml_object_detect.pb;1;1,288,288,3 2
@ -23,7 +23,7 @@ hiai_PoseEstimation_Pcm.pb 0.5
hiai_cn_recognize_modify_padv2.pb;1;1,32,512,1 27
hiai_model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 17.1
# The output of mtk_model_ckpt.pb has small value
mtk_model_ckpt.pb 19
mtk_model_ckpt.pb 19.5
mtk_age_gender.pb 0.5
# The Difference of output node divided by 0 results in cumulative deviation
mtk_model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 10

View File

@ -52,10 +52,10 @@ ml_ei_landmark.tflite 3
mnist.tflite 4
mobilenet.tflite 0.1
resnet.tflite 120
scan_hms_angle1.tflite 4
scan_hms_angle1.tflite 6
scan_hms_detect.tflite 12
hiai_latin_ocr.tflite 45
hiai_latin_ocr_1.tflite 13
hiai_latin_ocr_1.tflite 14.5
ml_ocr_jk.tflite 2
nasnet_mobile.tflite 3
nasnet_large.tflite 3
@ -113,14 +113,14 @@ mnasnet_1.0_224_1_metadata_1.tflite 6
mnasnet_1.0_96_1_metadata_1.tflite 6
lite-model_on_device_vision_classifier_popular_us_products_V1_1.tflite 16
lite-model_on_device_vision_classifier_popular_wine_V1_1.tflite 80
posenet_mobilenet_float_075_1_default_1.tflite 45
posenet_mobilenet_float_075_1_default_1.tflite 49
deeplabv3_1_default_1.tflite 6
lite-model_deeplabv3-mobilenetv2_dm05-float16_1_default_1.tflite 13
lite-model_deeplabv3-mobilenetv2-float16_1_default_1.tflite 60
lite-model_east-text-detector_fp16_1.tflite 60
lite-model_cartoongan_fp16_1.tflite 3
lite-model_arbitrary-image-stylization-inceptionv3_fp16_predict_1.tflite 6
gts_detect_5k_tf115.tflite 6
gts_detect_5k_tf115.tflite 9.5
mtk_isface.tflite 0.2
mtk_landmark.tflite 0.1
mtk_new_detect.tflite 3
@ -143,7 +143,7 @@ ml_vision_guide_detection3_pb2tflite.tflite 0.5
ml_vision_guide_detection1_pb2tflite.tflite 0.5
ml_pic_shopping_pb2tflite.tflite 95
ml_ocr_jk_pb2tflite.tflite 0.5
ml_ocr_latin_pb2tflite.tflite 11
ml_ocr_latin_pb2tflite.tflite 11.5
scan_hms_angle_pb2tflite.tflite 2.5
scan_hms_detect_pb2tflite.tflite 1.5
ml_location.tflite 0.5
@ -154,19 +154,19 @@ ml_object_detect_pb2tflite.tflite 1.5
lite-model_on_device_vision_classifier_landmarks_classifier_africa_V1_1.tflite 10
lite-model_on_device_vision_classifier_landmarks_classifier_north_america_V1_1.tflite 19
lite-model_on_device_vision_classifier_landmarks_classifier_asia_V1_1.tflite 25
lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V1_1.tflite 10
lite-model_on_device_vision_classifier_landmarks_classifier_oceania_antarctica_V1_1.tflite 11
lite-model_on_device_vision_classifier_landmarks_classifier_europe_V1_1.tflite 32
lite-model_on_device_vision_classifier_landmarks_classifier_south_america_V1_1.tflite 14
ml_ei_landmark_pb2tflite.tflite 2
unet_mbv2_05_104pts.tflite 15
hiai_AADB_HADB_MBV2_model_f16.tflite 2.5
unet_mbv2_05_104pts.tflite 17
hiai_AADB_HADB_MBV2_model_f16.tflite 3.5
hiai_AADB_HADB_MBV2_model_fp32.tflite 4.5
mtk_age_gender_fp16.tflite 26
hiai_detect_curve_model_float32.tflite 9
Q_language_model_hrmini_Q4_b4_17w.tflite 3.5
lite-model_aiy_vision_classifier_food_V1_1.tflite 42
lite-model_aiy_vision_classifier_food_V1_1.tflite 47.5
lite-model_disease-classification_1.tflite 70
lite-model_models_mushroom-identification_v1_1.tflite 4.5
lite-model_models_mushroom-identification_v1_1.tflite 5
smartreply_1_default_1.tflite 0.5
text_classification.tflite 0.5
Q_AADB_HADB_MBV2_model.tflite 5