optimize winograd

This commit is contained in:
fuzhiye 2020-08-24 09:25:02 +08:00
parent 600263ccfe
commit d34c620dce
48 changed files with 486 additions and 432 deletions

View File

@ -47,7 +47,7 @@
/////////////////////////////////////////////////////////////////////////////////
//
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
// int row, int col, size_t stride, size_t writeNhwc, size_t writeC4)
// int row, int col, size_t stride, size_t writeNhwc, size_t WriteWino)
// x0: a
// x1: b
// x2: c
@ -64,11 +64,20 @@ MatmulFloatNeon64Opt:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldr x9, [sp, #8]
ldr x14, [sp, #16]
mov w18, #32 // sizeof(float) * 8
mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x11, x3 // bias flag
mov x18, #4
ldr x17, [sp]
cbz x14, NoWinoSteps
mul x8, x7, x17
mov x11, #8
mul x11, x11, x17
mul x8, x8, x18
mul x11, x11, x18
NoWinoSteps:
mul x17, x17, x18
L1:
@ -79,7 +88,6 @@ L1:
L2:
mov x16, x1 // reload rhs ptr
mov w13, w5 // reload depth
mov x14, x3 // reload bias ptr
dup v8.4s, wzr
dup v9.4s, wzr
dup v10.4s, wzr
@ -173,9 +181,10 @@ LoopEnd:
fmla v31.4s, v4.4s, v2.s[3]
Bias:
cbz x11, Activation
ld1 {v0.4s}, [x14], #16
ld1 {v1.4s}, [x14], #16
cbz x3, Activation
ld1 {v0.4s}, [x3], #16
ld1 {v1.4s}, [x3]
sub x3, x3, #16
fadd v8.4s, v8.4s, v0.4s
fadd v9.4s, v9.4s, v1.4s
fadd v10.4s, v10.4s, v0.4s
@ -265,10 +274,8 @@ Relu:
fmax v31.4s, v31.4s, v3.4s
Write:
ldr w8, [sp, #8]
cbz w8, WriteC8
ldr w8, [sp, #16]
cbnz w8, WriteC4
cbnz x14, WriteWino
cbz x9, WriteC8
cmp w7, #1
beq Write1
cmp w7, #2
@ -721,39 +728,26 @@ Write7:
st1 {v31.s}[2], [x16], x17
b WriteEnd
WriteC8:
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x2], #64
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], #64
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64
b WriteEnd
WriteC4:
st1 {v8.8h}, [x2], #16
st1 {v10.8h}, [x2], #16
st1 {v12.8h}, [x2], #16
st1 {v14.8h}, [x2], #16
st1 {v16.8h}, [x2], #16
st1 {v18.8h}, [x2], #16
st1 {v20.8h}, [x2], #16
st1 {v22.8h}, [x2], #16
st1 {v24.8h}, [x2], #16
st1 {v26.8h}, [x2], #16
st1 {v28.8h}, [x2], #16
st1 {v30.8h}, [x2], #16
add x18, x2, x17
st1 {v9.8h}, [x18], #16
st1 {v11.8h}, [x18], #16
st1 {v13.8h}, [x18], #16
st1 {v15.8h}, [x18], #16
st1 {v17.8h}, [x18], #16
st1 {v19.8h}, [x18], #16
st1 {v21.8h}, [x18], #16
st1 {v23.8h}, [x18], #16
st1 {v25.8h}, [x18], #16
st1 {v27.8h}, [x18], #16
st1 {v29.8h}, [x18], #16
st1 {v31.8h}, [x18], #16
WriteWino:
st1 {v8.4s, v9.4s}, [x18], x8
st1 {v10.4s, v11.4s}, [x18], x8
st1 {v12.4s, v13.4s}, [x18], x8
st1 {v14.4s, v15.4s}, [x18], x8
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
st1 {v24.4s, v25.4s}, [x18], x8
st1 {v26.4s, v27.4s}, [x18], x8
st1 {v28.4s, v29.4s}, [x18], x8
st1 {v30.4s, v31.4s}, [x18], x8
b WriteEnd
Write8:
st1 {v8.4s, v9.4s}, [x18], x17
@ -798,15 +792,15 @@ WriteEnd:
End2:
subs w7, w7, #8 // rhs col - 8
add x1, x1, x15 // rhs ptr + stride
cbz x3, NoBiasStep
add x3, x3, #32 // bias ptr + stride
ldr w8, [sp, #8]
cbz w8, NoDstStep
ldr w8, [sp, #16]
cbnz w8, C4DstStep
NoBiasStep:
cbnz x14, WinoDstStep
cbz x9, NoDstStep
add x2, x2, #32 // dst ptr + stride
b NoDstStep
C4DstStep:
add x2, x2, x17
WinoDstStep:
add x2, x2, x11
NoDstStep:
bgt L1

View File

@ -32,8 +32,6 @@ typedef struct ConvParameter {
int stride_w_;
int dilation_h_;
int dilation_w_;
int pad_h_;
int pad_w_;
int pad_u_;
int pad_d_;
int pad_l_;
@ -51,8 +49,7 @@ typedef struct ConvParameter {
int thread_num_;
int input_unit_;
int output_unit_;
bool is_relu_;
bool is_relu6_;
ActType act_type_;
} ConvParameter;
typedef struct SlidingWindowParam {

View File

@ -53,16 +53,18 @@ void DepthwiseBorderPixelFp16(float16_t *dst, const float16_t *src, const float1
void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param,
const SlidingWindowParam *sliding) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float16_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float16_t *src_h = src + ih * sliding->in_h_step_;
float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float16_t *src_w = src_h + iw * sliding->block_channel_;
@ -72,11 +74,10 @@ void DepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float16_t *
#ifdef ENABLE_ARM64
ConvDwFp16Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_ * sizeof(float16_t), sliding->in_kw_step_ * sizeof(float16_t),
conv_param->kernel_w_ * C8NUM * sizeof(float16_t), conv_param->is_relu_, conv_param->is_relu6_);
conv_param->kernel_w_ * C8NUM * sizeof(float16_t), relu, relu6);
#else
DepthwiseBorderPixelFp16(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C8NUM, relu, relu6);
#endif
dst_kernel += sliding->block_channel_;
} // width loop
@ -139,6 +140,8 @@ void DepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float16_t *
void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const float16_t *src = input_data;
float16_t *dst = output_data;
for (int b = 0; b < conv_param->output_batch_; b++) {
@ -157,8 +160,8 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
conv_param->output_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
float16_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
@ -166,12 +169,12 @@ void ConvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const flo
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float16_t),
sliding->block_channel_ * sizeof(float16_t), sliding->in_sh_step_ * sizeof(float16_t),
sliding->in_sw_step_ * sizeof(float16_t), sliding->in_kh_step_ * sizeof(float16_t),
sliding->in_kw_step_ * sizeof(float16_t), conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kw_step_ * sizeof(float16_t), relu, relu6);
#else
DepthwiseCenterFp16(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_,
sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_,
sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, relu, relu6);
#endif
}
} // output C8 loop
@ -210,14 +213,14 @@ void DeconvDepthwiseBorderFp16(float16_t *dst, const float16_t *src, const float
const SlidingWindowParam *sliding) {
const float16_t *src_h = src + top * sliding->out_h_step_;
for (int ih = top; ih < bottom; ih++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
float16_t *dst_h = dst + oh * sliding->in_h_step_;
const float16_t *src_kernel = src_h + left * sliding->block_channel_;
for (int iw = left; iw < right; iw++) {
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
float16_t *dst_w = dst_h + ow * sliding->block_channel_;
@ -282,12 +285,14 @@ void DeconvDepthwiseCenterFp16(float16_t *dst, const float16_t *src, const float
void DeconvDepthwisePostFuncFp16(float16_t *dst, const float16_t *bias, int block_channel,
const ConvParameter *conv_param) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float16_t *dst_k = dst;
for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
for (int c = 0; c < C8NUM; c++) {
dst_k[c] += bias[c];
dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
}
dst_k += block_channel;
}
@ -315,8 +320,8 @@ void DeconvDwC8Fp16(float16_t *output_data, const float16_t *input_data, const f
conv_param->input_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
float16_t *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const float16_t *in_t =
src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

View File

@ -173,16 +173,18 @@ void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight
void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float16_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float16_t *src_h = src + ih * sliding->in_h_step_;
float16_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float16_t *src_w = src_h + iw * sliding->ic4_channel_;
@ -192,7 +194,7 @@ void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight,
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_,
sliding->ic4_channel_, conv_param->is_relu_, conv_param->is_relu6_);
sliding->ic4_channel_, relu, relu6);
dst_kernel += sliding->block_channel_;
} // width loop
@ -273,6 +275,8 @@ void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight,
void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data,
float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param,
SlidingWindowParam *slidingWindow_param) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
int oc4_res = conv_param->output_channel_ % C4NUM;
const float16_t *src = input_data;
float16_t *dst;
@ -299,8 +303,8 @@ void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, con
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float16_t *in_t =
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
@ -310,7 +314,7 @@ void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, con
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_,
slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_,
slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
slidingWindow_param->in_kw_step_, relu, relu6);
}
} // output C4 loop
src += slidingWindow_param->in_step_;
@ -330,8 +334,8 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
bool relu = conv_param->is_relu_;
bool relu6 = conv_param->is_relu6_;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
int thread_count = conv_param->thread_num_;
const int tile_n = 16;
int output_count = out_h * out_w;

View File

@ -73,8 +73,8 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
@ -112,7 +112,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
} /*ih*/
} /*oc8*/
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
PostConvFuncFp16C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
return NNACL_OK;
}

View File

@ -21,14 +21,14 @@
void Conv1x1InputPackFp16(const float16_t *src, float16_t *dst, ConvParameter *conv_param) {
/* support nhwc */
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
if (src_h < 0 || src_h >= conv_param->input_h_) {
continue;
}
const float16_t *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_;
float16_t *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_;
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
if (src_w < 0 || src_w >= conv_param->input_w_) {
continue;
}
@ -46,8 +46,8 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;

View File

@ -230,8 +230,8 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_l_;
int pad_h = conv_param->pad_u_;
int ic8 = UP_DIV(input_channel, C8NUM);
if (out_w_block == 0) {
return;
@ -576,8 +576,8 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
int ic8 = UP_DIV(in_channel, C8NUM);
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int input_h = conv_param->input_h_;
int input_w = conv_param->input_w_;
if (out_w_block_num == 0) {

View File

@ -18,6 +18,7 @@
#include <string.h>
#include "nnacl/fp32/common_func.h"
#include "nnacl/winograd_transform.h"
#include "nnacl/fp32/matmul.h"
void SWBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic4, bool is_relu, bool is_relu6) {
@ -57,16 +58,18 @@ void SWBorderPixel(float *dst, const float *src, const float *weight, const floa
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
int ic4 = sliding->ic4_channel_ / C4NUM;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float *src_h = src + ih * sliding->in_h_step_;
float *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float *src_w = src_h + iw * sliding->ic4_channel_;
@ -75,8 +78,8 @@ void SWBorder(float *dst, const float *src, const float *weight, const float *bi
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4, relu,
relu6);
dst_kernel += sliding->block_channel_;
} // width loop
@ -144,6 +147,8 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param) {
int ic4 = slidingWindow_param->ic4_channel_ / C4NUM;
int oc4_res = conv_param->output_channel_ % C4NUM;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const float *src = input_data;
float *dst = NULL;
if (oc4_res == 0) {
@ -169,28 +174,26 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float *in_t =
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
float *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
slidingWindow_param->left_ * slidingWindow_param->block_channel_;
#ifdef ENABLE_ARM64
ConvSwFp32Center(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
conv_param->kernel_w_, slidingWindow_param->out_h_step_ * sizeof(float),
slidingWindow_param->block_channel_ * sizeof(float), ic4,
slidingWindow_param->in_sh_step_ * sizeof(float),
slidingWindow_param->in_sw_step_ * sizeof(float),
slidingWindow_param->in_kh_step_ * sizeof(float),
slidingWindow_param->in_kw_step_ * sizeof(float),
conv_param->is_relu_, conv_param->is_relu6_);
ConvSwFp32Center(
out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
slidingWindow_param->out_h_step_ * sizeof(float), slidingWindow_param->block_channel_ * sizeof(float), ic4,
slidingWindow_param->in_sh_step_ * sizeof(float), slidingWindow_param->in_sw_step_ * sizeof(float),
slidingWindow_param->in_kh_step_ * sizeof(float), slidingWindow_param->in_kw_step_ * sizeof(float), relu,
relu6);
#else
SWCenter(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
slidingWindow_param->in_sh_step_, slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
slidingWindow_param->in_kw_step_, relu, relu6);
#endif
}
} // output C4 loop
@ -219,6 +222,8 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
int kernel_plane = kernel_h * kernel_w;
int unit_size = kernel_plane * ic4 * C4NUM;
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
// we accumulate 4 channels per time for input blocks
int conv_depth = kernel_h * kernel_w;
@ -240,11 +245,11 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
if (real_cal_num == TILE_NUM) {
float *gemm_output = output_data + out_offset;
gemm_func(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
conv_param->is_relu_, conv_param->is_relu6_);
relu, relu6);
} else {
// res part
gemm_func(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0,
0, conv_param->is_relu_, conv_param->is_relu6_);
0, relu, relu6);
memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float));
}
}
@ -264,34 +269,42 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int output_tile_count = UP_DIV(output_count, C12NUM);
int out_channel = conv_param->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);
int input_unit_square = input_unit * input_unit;
size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float);
float *trans_input = buffer_list[0];
float *gemm_out = buffer_list[1];
float *tmp_out_data = buffer_list[2];
float *tmp_data = buffer_list[3];
int trans_input_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
int gemm_out_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
float *col_buffer = buffer_list[4];
int trans_input_offset = C12NUM * input_unit_square * ic4 * C4NUM;
int gemm_out_offset = C12NUM * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C4NUM;
int col_buffer_offset = C12NUM * ic4 * C4NUM;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * TILE_NUM;
int cal_num = output_count - thread_id * TILE_NUM;
cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num;
int out_tile_index = thread_id * C12NUM;
int cal_num = output_count - thread_id * C12NUM;
cal_num = cal_num > C12NUM ? C12NUM : cal_num;
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
input_trans_func);
// step 3 : gemm
gemm_func(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, trans_weight, NULL,
input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0);
float *src_ptr = trans_input + task_id * trans_input_offset;
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
C12NUM, oc8 * C8NUM, input_unit_square, 2);
}
// step 4 : output transform
WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
@ -442,18 +455,21 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int output_tile_count = UP_DIV(output_count, C12NUM);
const int input_unit_square = 4 * 4;
float *tile_buffer = buffer_list[0];
float *block_unit_buffer = buffer_list[1];
float *tmp_dst_buffer = buffer_list[2];
float *nc4hw4_out = buffer_list[3];
int tile_buffer_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
float *col_buffer = buffer_list[4];
int tile_buffer_offset = C12NUM * input_unit_square * ic4 * C4NUM;
int block_unit_buffer_offset = input_unit_square * C4NUM;
int tmp_dst_buffer_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
int tmp_dst_buffer_offset = C12NUM * input_unit_square * oc8 * C8NUM;
int col_buffer_offset = C12NUM * ic4 * C4NUM;
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
@ -461,15 +477,20 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
int start_index = thread_id * C12NUM;
int real_cal_num = (output_count - start_index) < C12NUM ? (output_count - start_index) : C12NUM;
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
gemm_func(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM,
oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0);
float *src_ptr = tile_buffer + task_id * tile_buffer_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, C12NUM, oc8 * C8NUM, input_unit_square, 2);
}
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);

View File

@ -38,13 +38,15 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
int h_start = h_step * task_id;
int h_end = MSMIN(h_start + h_step, conv_param->output_h_);
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
for (int b = 0; b < conv_param->output_batch_; b++) {
const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_;
float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_;
for (int oh = h_start; oh < h_end; oh++) {
float *dst_data = dst + oh * conv_param->output_w_ * conv_param->output_channel_;
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih_origin = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih_origin, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih_origin, conv_param->dilation_h_));
@ -60,13 +62,13 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
int in_sw_step = conv_param->stride_w_ * conv_param->input_channel_;
for (int kw = 0; kw < conv_param->kernel_w_; kw++) {
int out_w_start = MSMAX(
0, (conv_param->pad_w_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_w_ -
0, (conv_param->pad_l_ - conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) / conv_param->stride_w_);
int out_w_end = MSMIN(conv_param->output_w_, (conv_param->input_w_ + conv_param->pad_l_ -
conv_param->dilation_w_ * kw + conv_param->stride_w_ - 1) /
conv_param->stride_w_);
float *dst_w = dst_data + out_w_start * conv_param->output_channel_;
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_w_ + conv_param->dilation_w_ * kw;
int iw_origin = (out_w_start * conv_param->stride_w_) - conv_param->pad_l_ + conv_param->dilation_w_ * kw;
const float *src_kw = src_kh + iw_origin * conv_param->input_channel_;
int num_pixels = out_w_end - out_w_start;
@ -75,10 +77,10 @@ void ConvDw(float *output_data, const float *input_data, const float *weight_dat
weight_kh += conv_param->output_channel_;
}
}
if (conv_param->is_relu_) {
if (relu) {
ReluFp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
if (conv_param->is_relu6_) {
if (relu6) {
Relu6Fp32(dst_data, dst_data, conv_param->output_w_ * conv_param->output_channel_);
}
}
@ -91,16 +93,16 @@ void InitSlidingParam(SlidingWindowParam *sliding, const ConvParameter *conv_par
int top = 0;
int bottom = conv_param->output_h_;
for (; left * conv_param->stride_w_ < conv_param->pad_w_; left++) {
for (; left * conv_param->stride_w_ < conv_param->pad_l_; left++) {
}
for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_w_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
for (; (right - 1) * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_ * conv_param->dilation_w_ >
conv_param->input_w_ &&
right > left;
right--) {
}
for (; top * conv_param->stride_h_ < conv_param->pad_h_; top++) {
for (; top * conv_param->stride_h_ < conv_param->pad_u_; top++) {
}
for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_h_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
for (; (bottom - 1) * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_ * conv_param->dilation_h_ >
conv_param->input_h_ &&
bottom > top;
bottom--) {
@ -181,16 +183,18 @@ void DepthwiseBorderPixel(float *dst, const float *src, const float *weight, con
void DepthwiseBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom,
int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const float *src_h = src + ih * sliding->in_h_step_;
float *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const float *src_w = src_h + iw * sliding->block_channel_;
@ -201,11 +205,10 @@ void DepthwiseBorder(float *dst, const float *src, const float *weight, const fl
#ifdef ENABLE_ARM64
ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
conv_param->kernel_w_ * C4NUM * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_);
conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
#else
DepthwiseBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_ * C4NUM, relu, relu6);
#endif
dst_kernel += sliding->block_channel_;
} // width loop
@ -259,6 +262,8 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl
// conv depthwise fp32: sliding window
void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weight_data, const float *bias_data,
const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
const float *src = input_data;
float *dst = output_data;
for (int b = 0; b < conv_param->output_batch_; b++) {
@ -277,8 +282,8 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
conv_param->output_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
@ -286,12 +291,12 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float),
sliding->in_kw_step_ * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_kw_step_ * sizeof(float), relu, relu6);
#else
DepthwiseCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_,
conv_param->is_relu_, conv_param->is_relu6_);
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, relu,
relu6);
#endif
}
} // output C4 loop
@ -454,11 +459,11 @@ void ConvDw3x3Fp32InputTrans(const float *input_data, float *trans_input, float
memset(trans_input, 0, out_h_block * out_h_block * 16 * C4NUM * sizeof(float));
for (int oh = 0; oh < out_h_block; oh++) {
int ih = oh * 2 - conv_param->pad_h_;
int ih = oh * 2 - conv_param->pad_u_;
int real_h_start = ih > 0 ? 0 : -ih;
int real_h_end = (ih + input_unit) < conv_param->input_h_ ? input_unit : (conv_param->input_h_ - ih);
for (int ow = 0; ow < out_w_block; ow++) {
int iw = ow * 2 - conv_param->pad_w_;
int iw = ow * 2 - conv_param->pad_l_;
int real_w_start = iw > 0 ? 0 : -iw;
int real_w_end = (iw + input_unit) < conv_param->input_w_ ? input_unit : (conv_param->input_w_ - iw);
@ -642,6 +647,8 @@ void ConvDw3x3Fp32OutputUnit(float *src_buf, float *dst_output, const float *bia
void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const float *bias, int out_h_block,
int out_w_block, const ConvParameter *conv_param) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
bool h_in_range = true;
for (int oh = 0; oh < out_h_block; oh++) {
@ -661,8 +668,8 @@ void ConvDw3x3Fp32OutputTrans(float *trans_buffer, float *output_data, const flo
float *buf_ow = buf_oh + ow * 16 * C4NUM;
float *output_ow = output_oh + real_ow * oc4 * C4NUM;
ConvDw3x3Fp32OutputUnit(buf_ow, output_ow, bias, oc4 * C4NUM, conv_param->output_w_, h_in_range, w_in_range,
conv_param->is_relu_, conv_param->is_relu6_);
ConvDw3x3Fp32OutputUnit(buf_ow, output_ow, bias, oc4 * C4NUM, conv_param->output_w_, h_in_range, w_in_range, relu,
relu6);
}
}
}
@ -727,14 +734,14 @@ void DeconvDepthwiseBorder(float *dst, const float *src, const float *weight, in
const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
const float *src_h = src + top * sliding->out_h_step_;
for (int ih = top; ih < bottom; ih++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
float *dst_h = dst + oh * sliding->in_h_step_;
const float *src_kernel = src_h + left * sliding->block_channel_;
for (int iw = left; iw < right; iw++) {
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
float *dst_w = dst_h + ow * sliding->block_channel_;
@ -790,12 +797,14 @@ void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, in
#endif
void DeconvDepthwisePostFunc(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) {
bool relu = conv_param->act_type_ == ActType_Relu;
bool relu6 = conv_param->act_type_ == ActType_Relu6;
float *dst_k = dst;
for (int k = 0; k < conv_param->output_h_ * conv_param->output_w_; k++) {
for (int c = 0; c < C4NUM; c++) {
dst_k[c] += bias[c];
dst_k[c] = (conv_param->is_relu_) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (conv_param->is_relu6_) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
dst_k[c] = (relu) ? (MSMAX(0, dst_k[c])) : (dst_k[c]);
dst_k[c] = (relu6) ? (MSMIN(6, MSMAX(0, dst_k[c]))) : (dst_k[c]);
}
dst_k += block_channel;
}
@ -821,8 +830,8 @@ void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *we
conv_param->input_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

View File

@ -57,8 +57,8 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
@ -97,7 +97,7 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *
} /*ih*/
} /*oc8*/
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
conv_param->is_relu6_);
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
return NNACL_OK;
}

View File

@ -356,7 +356,7 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
dst[ci] = value;
}
}
} else {
} else if (out_type == OutType_C8) {
/* col8-major * row8-major => col12x8-major */
int col_8 = UP_ROUND(col, C8NUM);
int row_12 = UP_ROUND(row, C12NUM);
@ -364,9 +364,7 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
for (int c = 0; c < col_8; c++) {
int r12div = r / C12NUM, r12mod = r % C12NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
size_t ci = (out_type == OutType_C4) ? (c4div * C4NUM * row_12 + r * C4NUM + c4mod)
: (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
@ -379,6 +377,25 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
dst[ci] = value;
}
}
} else {
for (int i = 0; i < row; ++i) {
int src_r_offset = i;
int dst_r_offset = i * col * stride;
for (int j = 0; j < col; ++j) {
int c8div = j / 8, c8mod = j % 8;
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float value = 0;
for (int d = 0; d < deep; ++d) {
size_t ai = src_r_offset + d * row;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[j];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
}
return;
}
@ -387,7 +404,7 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_C4));
(int)(out_type == OutType_TileC8));
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif

View File

@ -20,9 +20,9 @@
static int is_a_ge_zero_and_a_lt_b(int a, int b) { return (unsigned)(a) < (unsigned)(b); }
void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int stride_h = conv_param->stride_h_;
@ -72,9 +72,9 @@ void im2col_hwc(const float *in_data, float *data_col, ConvParameter *conv_param
// output matrix is (kernel_h*kernel_w*channels)X(output_h*output_w)
void im2row_hwc(const float *in_data, float *data_row, ConvParameter *conv_param) {
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_w_;
const int pad_left = /*conv_param->pad_l_*/ conv_param->pad_l_;
// const int pad_right = /*conv_param->pad_r_*/conv_param->pad_w_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_h_;
const int pad_up = /*conv_param->pad_u_*/ conv_param->pad_u_;
// const int pad_down = /*conv_param->pad_d/*/conv_param->pad_h_;
const int stride_h = conv_param->stride_h_;

View File

@ -68,14 +68,14 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
bool per_channel) {
int8_t *dst_h = dst + top * sliding->out_h_step_;
for (int oh = top; oh < bottom; oh++) {
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
const int16_t *src_h = src + ih * sliding->in_h_step_;
int8_t *dst_kernel = dst_h + left * sliding->block_channel_;
for (int ow = left; ow < right; ow++) {
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
const int16_t *src_w = src_h + iw * sliding->block_channel_;
@ -186,8 +186,8 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
per_channel);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#ifdef ENABLE_ARM64
@ -241,14 +241,14 @@ void DeconvDepthwiseBorderInt8(int32_t *dst, const int16_t *src, const int16_t *
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
const int16_t *src_h = src + top * sliding->out_h_step_;
for (int ih = top; ih < bottom; ih++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int start_kh = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
int32_t *dst_h = dst + oh * sliding->in_h_step_;
const int16_t *src_kernel = src_h + left * sliding->block_channel_;
for (int iw = left; iw < right; iw++) {
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int start_kw = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
int32_t *dst_w = dst_h + ow * C4NUM;
@ -341,8 +341,8 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in
conv_param->input_w_, conv_param, sliding);
if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) {
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_;
int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_;
int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const int16_t *in_t =
src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;

View File

@ -33,8 +33,8 @@ int DeConvPostInt8C8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
@ -88,8 +88,8 @@ int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8
for (int ih = 0; ih < conv_param->input_h_; ih++) {
for (int iw = 0; iw < conv_param->input_w_; iw++) {
int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
int oh = ih * conv_param->stride_h_ - conv_param->pad_u_;
int ow = iw * conv_param->stride_w_ - conv_param->pad_l_;
int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));

View File

@ -29,9 +29,7 @@ typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst,
typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_C4 = 2 } OutType;
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType;
typedef struct MatMulParameter {
OpParameter op_parameter_;

View File

@ -25,7 +25,6 @@
#define C8NUM 8
#define C12NUM 12
#define C16NUM 16
#define BLOCK 4
#define TILE_NUM 8
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
@ -62,4 +61,6 @@ typedef struct OpParameter {
int thread_num_;
} OpParameter;
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
#endif // MINDSPORE_LITE_NNACL_OP_BASE_H_

View File

@ -158,14 +158,14 @@ void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_pa
char *src = (char *)src_ptr;
char *dst = (char *)dst_ptr;
for (int dst_h = 0; dst_h < conv_param->output_h_; dst_h++) {
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_h_;
int src_h = dst_h * conv_param->stride_h_ - conv_param->pad_u_;
if (src_h < 0 || src_h >= conv_param->input_h_) {
continue;
}
const char *src_h_ptr = src + src_h * conv_param->input_w_ * conv_param->input_channel_ * data_size;
char *dst_h_ptr = dst + dst_h * conv_param->output_w_ * conv_param->input_channel_ * data_size;
for (int dst_w = 0; dst_w < conv_param->output_w_; dst_w++) {
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_w_;
int src_w = dst_w * conv_param->stride_w_ - conv_param->pad_l_;
if (src_w < 0 || src_w >= conv_param->input_w_) {
continue;
}
@ -296,8 +296,8 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
@ -348,8 +348,8 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
@ -419,8 +419,8 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;

View File

@ -24,8 +24,8 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int input_h = conv_param->input_h_;
int input_w = conv_param->input_w_;
if (out_w_block_num == 0) {
@ -42,7 +42,7 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
int src_plane_offset = ic4 * C4NUM * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * C4NUM;
int dst_plane_offset = c * C4NUM * ic4;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float));
@ -67,8 +67,8 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
}
}
// input transform
int dst_ic4_offset = dst_plane_offset + ic * TILE_NUM * C4NUM;
size_t dst_step = ic4 * C4NUM * TILE_NUM;
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
size_t dst_step = C12NUM * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
input_trans_func(tmp_data, trans_input_ptr, C4NUM, dst_step);
}
@ -86,6 +86,7 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
int output_h_unit_block = UP_DIV(output_h, output_unit);
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
int input_unit = conv_param->input_unit_;
if (output_unit_num == 0) {
return;
@ -93,17 +94,19 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
for (int i = 0; i < cal_num; i++) {
int dst_x_s = out_tile_index % output_unit_num;
int dst_y_s = out_tile_index / output_unit_num;
int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit;
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_w_unit_block * output_unit);
for (int j = 0; j < oc4; j++) {
int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM;
int c8_block = j / 2;
int c8_res = j % 2;
int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM;
int dst_oc4_offset =
dst_tile_offset + j * C4NUM * output_h_unit_block * output_w_unit_block * output_unit * output_unit;
const float *src_ptr = gemm_out + src_oc4_offset;
const float *bias_ptr = bias_data + j * C4NUM;
float *dst_ptr = tmp_out_data + dst_oc4_offset;
output_trans_func(src_ptr, dst_ptr, bias_ptr, C4NUM, output_w_unit_block * output_unit);
output_trans_func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
}
out_tile_index++;
}
@ -283,8 +286,8 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_l_;
int pad_h = conv_param->pad_u_;
int ic4 = UP_DIV(input_channel, C4NUM);
const int input_unit = 4;
if (out_w_block == 0) {
@ -300,7 +303,7 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa
int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y);
int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x);
int dst_plane_offset = cal_id * C4NUM;
int dst_plane_offset = cal_id * C4NUM * ic4;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float));
@ -326,8 +329,8 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa
}
// input transform
int dst_ic4_offset = dst_plane_offset + ic * TILE_NUM * C4NUM;
size_t dst_step = ic4 * C4NUM * TILE_NUM;
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
size_t dst_step = C12NUM * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step);
}
@ -336,8 +339,8 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa
void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, int kernel_plane,
int oc_block) {
const int input_unit = 4;
int dst_step = iC4 * C4NUM * oc_block;
int oc_plane_block = UP_DIV(output_channel, oc_block);
int dst_step = iC4 * C4NUM * oc_block * oc_plane_block;
if (oc_block == 0) {
return;
}
@ -345,7 +348,7 @@ void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4
int oc_block_num = o / oc_block;
int oc_block_rem = o % oc_block;
int src_oc_offset = o * iC4 * C4NUM * kernel_plane;
int dst_oc_offset = oc_block_num * oc_block * iC4 * C4NUM * input_unit * input_unit + oc_block_rem;
int dst_oc_offset = oc_block_num * oc_block * iC4 * C4NUM + oc_block_rem;
for (int i = 0; i < iC4; i++) {
float *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM;
float *dst_ic4_ptr = trans_weight + dst_oc_offset + i * oc_block * C4NUM;
@ -559,24 +562,24 @@ void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float
float32x4_t bias_ptr = vld1q_f32(bias_data);
float32x4_t s00 = vld1q_f32(gemm_out);
float32x4_t s01 = vld1q_f32(gemm_out + 4);
float32x4_t s02 = vld1q_f32(gemm_out + 8);
float32x4_t s03 = vld1q_f32(gemm_out + 12);
float32x4_t s01 = vld1q_f32(gemm_out + 8);
float32x4_t s02 = vld1q_f32(gemm_out + 16);
float32x4_t s03 = vld1q_f32(gemm_out + 24);
float32x4_t s10 = vld1q_f32(gemm_out + 16);
float32x4_t s11 = vld1q_f32(gemm_out + 20);
float32x4_t s12 = vld1q_f32(gemm_out + 24);
float32x4_t s13 = vld1q_f32(gemm_out + 28);
float32x4_t s10 = vld1q_f32(gemm_out + 32);
float32x4_t s11 = vld1q_f32(gemm_out + 40);
float32x4_t s12 = vld1q_f32(gemm_out + 48);
float32x4_t s13 = vld1q_f32(gemm_out + 56);
float32x4_t s20 = vld1q_f32(gemm_out + 32);
float32x4_t s21 = vld1q_f32(gemm_out + 36);
float32x4_t s22 = vld1q_f32(gemm_out + 40);
float32x4_t s23 = vld1q_f32(gemm_out + 44);
float32x4_t s20 = vld1q_f32(gemm_out + 64);
float32x4_t s21 = vld1q_f32(gemm_out + 72);
float32x4_t s22 = vld1q_f32(gemm_out + 80);
float32x4_t s23 = vld1q_f32(gemm_out + 88);
float32x4_t s30 = vld1q_f32(gemm_out + 48);
float32x4_t s31 = vld1q_f32(gemm_out + 52);
float32x4_t s32 = vld1q_f32(gemm_out + 56);
float32x4_t s33 = vld1q_f32(gemm_out + 60);
float32x4_t s30 = vld1q_f32(gemm_out + 96);
float32x4_t s31 = vld1q_f32(gemm_out + 104);
float32x4_t s32 = vld1q_f32(gemm_out + 112);
float32x4_t s33 = vld1q_f32(gemm_out + 120);
float32x4_t t00 = vaddq_f32(vaddq_f32(s00, s10), s20);
float32x4_t t01 = vaddq_f32(vaddq_f32(s01, s11), s21);
@ -609,24 +612,24 @@ void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float
const float *bias_ptr = bias_data + i;
float s00 = local_ptr[0];
float s01 = (local_ptr + 4)[0];
float s02 = (local_ptr + 8)[0];
float s03 = (local_ptr + 12)[0];
float s01 = (local_ptr + 8)[0];
float s02 = (local_ptr + 16)[0];
float s03 = (local_ptr + 24)[0];
float s10 = (local_ptr + 16)[0];
float s11 = (local_ptr + 20)[0];
float s12 = (local_ptr + 24)[0];
float s13 = (local_ptr + 28)[0];
float s10 = (local_ptr + 32)[0];
float s11 = (local_ptr + 40)[0];
float s12 = (local_ptr + 48)[0];
float s13 = (local_ptr + 56)[0];
float s20 = (local_ptr + 32)[0];
float s21 = (local_ptr + 36)[0];
float s22 = (local_ptr + 40)[0];
float s23 = (local_ptr + 44)[0];
float s20 = (local_ptr + 64)[0];
float s21 = (local_ptr + 72)[0];
float s22 = (local_ptr + 80)[0];
float s23 = (local_ptr + 88)[0];
float s30 = (local_ptr + 48)[0];
float s31 = (local_ptr + 52)[0];
float s32 = (local_ptr + 56)[0];
float s33 = (local_ptr + 60)[0];
float s30 = (local_ptr + 96)[0];
float s31 = (local_ptr + 104)[0];
float s32 = (local_ptr + 112)[0];
float s33 = (local_ptr + 120)[0];
float t00 = s00 + s10 + s20;
float t01 = s01 + s11 + s21;
@ -663,6 +666,7 @@ void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const fl
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
const int input_unit = 4;
if (out_w_block == 0) {
return;
@ -670,11 +674,13 @@ void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const fl
for (int i = 0; i < real_cal_num; i++) {
int out_w_index = (start_index + i) % out_w_block;
int out_h_index = (start_index + i) / out_w_block;
int src_tile_offset = i * oc4 * C4NUM * input_unit * input_unit;
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
int dst_tile_offset = C4NUM * (out_w_index * OUPUT_UNIT + out_h_index * OUPUT_UNIT * output_w);
for (int j = 0; j < oc4; j++) {
int src_oc4_offset = src_tile_offset + j * input_unit * input_unit * C4NUM;
int c8_block = j / 2;
int c8_res = j % 2;
int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM;
int dst_oc4_offset = dst_tile_offset + j * C4NUM * output_h * output_w;
const float *src_ptr = gemm_out + src_oc4_offset;
const float *bias_ptr = bias_data + j * C4NUM;
@ -864,8 +870,8 @@ void Conv3x3Uint8InputTransform(const int16_t *input_data, int16_t *trans_input,
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_w_;
int pad_h = conv_param->pad_h_;
int pad_w = conv_param->pad_l_;
int pad_h = conv_param->pad_u_;
ConvQuantArg quant_arg = conv_param->conv_quant_arg_;
int input_zp = quant_arg.input_quant_args_[0].zp_;
const int ic8 = UP_DIV(input_channel, C8NUM);
@ -1221,9 +1227,9 @@ void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, i
int32x4_t ls;
int32x4_t rs;
if ((conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
out_multiplier = vld1q_s32(quant_multiplier);
ls = vld1q_s32(left_shift);
rs = vld1q_s32(right_shift);
out_multiplier = vld1q_s32(quant_multiplier + oc_start);
ls = vld1q_s32(left_shift + oc_start);
rs = vld1q_s32(right_shift + oc_start);
} else {
out_multiplier = vdupq_n_s32(quant_multiplier[0]);
ls = vdupq_n_s32(left_shift[0]);

View File

@ -4649,43 +4649,41 @@ void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float
// Utilize cost model to compute performance gain.
// If the gain is greater than got from Im2col, winograd algorithm will be chosen.
int SelectOutputUnit(ConvParameter *conv_param) {
int input_batch = conv_param->input_batch_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int in_c = conv_param->input_channel_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int out_plane = out_h * out_w;
int out_h = conv_param->output_h_;
int out_c = conv_param->output_channel_;
int unit2 = UP_DIV(out_w * out_h, C12NUM * conv_param->op_parameter_.thread_num_);
int max_out_unit = (int)(sqrtf((float)unit2));
max_out_unit = max_out_unit < MAX_UNIT ? MAX_UNIT : max_out_unit;
max_out_unit = max_out_unit > MIN_UNIT ? max_out_unit : MIN_UNIT;
int max_unit = sqrt((float)(out_plane));
max_unit = max_unit > MIN_UNIT ? max_unit : MIN_UNIT;
max_unit = max_unit < MAX_UNIT ? max_unit : MAX_UNIT;
int output_unit = 1;
float ratio = 0.0f;
// cost of conventional convolution multiplications
float ori_cost = out_plane * out_channel * in_channel * kernel_h * kernel_w;
int unit = 0;
float max_rate = 0.0f;
float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w;
for (int u = MIN_UNIT; u < max_unit; u++) {
int input_unit = u + kernel_h - 1;
if (input_unit != 4 && input_unit != 8) {
for (int i = MIN_UNIT; i <= max_out_unit; ++i) {
int input_unit = i + kernel_w - 1;
OutputTransformUnitFunc output_trans_func = GetOutputTransFunc(input_unit, i);
if (output_trans_func == NULL) {
continue;
}
// don't count filter transform cost, because it can be processed once offline.
const float input_trans_unit_cost = 2 * input_unit * input_unit * input_unit * in_channel;
float gemm_unit_cost = input_unit * input_unit * in_channel * out_channel;
float output_trans_unit_cost = input_unit * u * (u + input_unit) * out_channel;
// equation (23) in papar
float winograd_cost = (input_trans_unit_cost + gemm_unit_cost + output_trans_unit_cost) *
(UP_DIV(out_w, u) * (UP_DIV(out_h, u))) * input_batch;
float reduce_rate = ori_cost / winograd_cost;
if (reduce_rate > ratio && reduce_rate > 1) {
ratio = reduce_rate;
output_unit = u;
float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f;
float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) *
UP_DIV(out_w, i) * UP_DIV(out_h, i);
float reduce_rate = common_cost / wino_cost - penalty;
if (reduce_rate > max_rate) {
max_rate = reduce_rate;
unit = i;
}
}
if (max_rate < 1.0f) {
return 1;
}
// If output_unit is 1, then it is conventional convolution
return output_unit;
return unit;
}
InputTransformUnitFunc GetInputTransFunc(int input_unit) {
@ -4719,17 +4717,6 @@ void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *con
*output_unit = SelectOutputUnit(conv_param);
if (*output_unit > 1) {
*use_winograd = true;
int input_unit = conv_param->kernel_h_ + *output_unit - 1;
input_trans_func = GetInputTransFunc(input_unit);
if (input_trans_func == NULL) {
*use_winograd = false;
}
output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
if (output_trans_func == NULL) {
*use_winograd = false;
}
} else {
*use_winograd = false;
}
} else {
*use_winograd = false;

View File

@ -376,10 +376,18 @@ void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output
*output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(stride_h));
auto pad_h_all = ((*output_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - input_h);
auto pad_w_all = ((*output_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - input_w);
pad_u_ = pad_h_all / 2;
pad_d_ = pad_h_all - pad_u_;
pad_l_ = pad_w_all / 2;
pad_r_ = pad_w_all - pad_l_;
if (pad_h_all < 0) {
pad_u_ = pad_d_ = 0;
} else {
pad_u_ = pad_h_all / 2;
pad_d_ = pad_h_all - pad_u_;
}
if (pad_w_all < 0) {
pad_l_ = pad_r_ = 0;
} else {
pad_l_ = pad_w_all / 2;
pad_r_ = pad_w_all - pad_l_;
}
} else {
*output_w = std::ceil((static_cast<float>(input_w) + pad_l_ + pad_r_ -
(static_cast<float>(kernel_w) - 1) * static_cast<float>(dilate_w)) /

View File

@ -126,14 +126,12 @@ int DeConv2D::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vecto
output->set_shape(out_shape);
if (pad_mode == schema::PadMode_SAME) {
pad_h_ = ((input_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - output_h) / 2;
pad_w_ = ((input_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - output_w) / 2;
pad_u_ = ((input_h - 1) * stride_h + (kernel_h - 1) * dilate_h + 1 - output_h) / 2;
pad_l_ = ((input_w - 1) * stride_w + (kernel_w - 1) * dilate_w + 1 - output_w) / 2;
} else if (pad_mode == schema::PadMode_VALID) {
pad_h_ = 0;
pad_w_ = 0;
pad_u_ = 0;
pad_l_ = 0;
} else if (pad_mode == schema::PadMode_CAFFE) {
pad_h_ = pad_u_;
pad_w_ = pad_l_;
} else {
MS_LOG(ERROR) << "unsupported pad mode for deconv";
}

View File

@ -74,16 +74,12 @@ class DeConv2D : public PrimitiveC {
int PadDown() const { return this->pad_d_; }
int PadLeft() const { return this->pad_l_; }
int PadRight() const { return this->pad_r_; }
int PadH() const { return this->pad_h_; }
int PadW() const { return this->pad_w_; }
protected:
int pad_u_ = 0;
int pad_d_ = 0;
int pad_l_ = 0;
int pad_r_ = 0;
int pad_h_ = 0;
int pad_w_ = 0;
};
} // namespace lite
} // namespace mindspore

View File

@ -170,10 +170,18 @@ int Pooling::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
output_h = std::ceil(static_cast<float>(input_h) / static_cast<float>(GetStrideH()));
auto pad_h_all = ((output_h - 1) * GetStrideH() + (window_h - 1) + 1 - input_h);
auto pad_w_all = ((output_w - 1) * GetStrideW() + (window_w - 1) + 1 - input_w);
pad_u_ = pad_h_all / 2;
pad_d_ = pad_h_all - pad_u_;
pad_l_ = pad_w_all / 2;
pad_r_ = pad_w_all - pad_l_;
if (pad_h_all < 0) {
pad_u_ = pad_d_ = 0;
} else {
pad_u_ = pad_h_all / 2;
pad_d_ = pad_h_all - pad_u_;
}
if (pad_w_all < 0) {
pad_l_ = pad_r_ = 0;
} else {
pad_l_ = pad_w_all / 2;
pad_r_ = pad_w_all - pad_l_;
}
} else {
auto round_mode = (schema::RoundMode)GetRoundMode();
if (round_mode == schema::RoundMode_FLOOR) {

View File

@ -376,8 +376,6 @@ OpParameter *PopulateConvParameter(const mindspore::lite::PrimitiveC *primitive)
conv_param->pad_d_ = conv2d_lite_primitive->PadDown();
conv_param->pad_l_ = conv2d_lite_primitive->PadLeft();
conv_param->pad_r_ = conv2d_lite_primitive->PadRight();
conv_param->pad_h_ = conv2d_lite_primitive->PadUp();
conv_param->pad_w_ = conv2d_lite_primitive->PadLeft();
conv_param->dilation_h_ = conv_primitive->GetDilateH();
conv_param->dilation_w_ = conv_primitive->GetDilateW();
conv_param->input_channel_ = conv_primitive->GetChannelIn();
@ -386,16 +384,13 @@ OpParameter *PopulateConvParameter(const mindspore::lite::PrimitiveC *primitive)
auto act_type = conv_primitive->GetActivationType();
switch (act_type) {
case schema::ActivationType_RELU:
conv_param->is_relu_ = true;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = true;
conv_param->act_type_ = ActType_Relu6;
break;
default:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_No;
break;
}
return reinterpret_cast<OpParameter *>(conv_param);
@ -422,23 +417,18 @@ OpParameter *PopulateConvDwParameter(const mindspore::lite::PrimitiveC *primitiv
conv_param->pad_d_ = convdw_lite_primitive->PadDown();
conv_param->pad_l_ = convdw_lite_primitive->PadLeft();
conv_param->pad_r_ = convdw_lite_primitive->PadRight();
conv_param->pad_h_ = convdw_lite_primitive->PadUp();
conv_param->pad_w_ = convdw_lite_primitive->PadLeft();
conv_param->dilation_h_ = conv_primitive->GetDilateH();
conv_param->dilation_w_ = conv_primitive->GetDilateW();
auto act_type = conv_primitive->GetActivationType();
switch (act_type) {
case schema::ActivationType_RELU:
conv_param->is_relu_ = true;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = true;
conv_param->act_type_ = ActType_Relu6;
break;
default:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_No;
break;
}
return reinterpret_cast<OpParameter *>(conv_param);
@ -464,23 +454,18 @@ OpParameter *PopulateDeconvDwParameter(const mindspore::lite::PrimitiveC *primit
conv_param->pad_d_ = deconvdw_lite_primitive->PadDown();
conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft();
conv_param->pad_r_ = deconvdw_lite_primitive->PadRight();
conv_param->pad_h_ = deconvdw_lite_primitive->PadUp();
conv_param->pad_w_ = deconvdw_lite_primitive->PadLeft();
conv_param->dilation_h_ = conv_primitive->GetDilateH();
conv_param->dilation_w_ = conv_primitive->GetDilateW();
auto act_type = conv_primitive->GetActivationType();
switch (act_type) {
case schema::ActivationType_RELU:
conv_param->is_relu_ = true;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = true;
conv_param->act_type_ = ActType_Relu6;
break;
default:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_No;
break;
}
return reinterpret_cast<OpParameter *>(conv_param);
@ -506,23 +491,18 @@ OpParameter *PopulateDeconvParameter(const mindspore::lite::PrimitiveC *primitiv
conv_param->pad_d_ = deconv_lite_primitive->PadDown();
conv_param->pad_l_ = deconv_lite_primitive->PadLeft();
conv_param->pad_r_ = deconv_lite_primitive->PadRight();
conv_param->pad_h_ = deconv_lite_primitive->PadH();
conv_param->pad_w_ = deconv_lite_primitive->PadW();
conv_param->dilation_h_ = conv_primitive->GetDilateH();
conv_param->dilation_w_ = conv_primitive->GetDilateW();
auto act_type = conv_primitive->GetActivationType();
switch (act_type) {
case schema::ActivationType_RELU:
conv_param->is_relu_ = true;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_Relu;
break;
case schema::ActivationType_RELU6:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = true;
conv_param->act_type_ = ActType_Relu6;
break;
default:
conv_param->is_relu_ = false;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_No;
break;
}
return reinterpret_cast<OpParameter *>(conv_param);

View File

@ -322,10 +322,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
return ret;
}
// now only consider per tensor for output
CalculateActivationRangeQuantized(
conv_param_->is_relu_, conv_param_->is_relu6_, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.output_quant_args_[0].scale_, &conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);
bool relu = conv_param_->act_type_ == ActType_Relu;
bool relu6 = conv_param_->act_type_ == ActType_Relu6;
CalculateActivationRangeQuantized(relu, relu6, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.output_quant_args_[0].scale_,
&conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);
return RET_OK;
}

View File

@ -38,8 +38,7 @@ int Convolution1x1FP16CPUKernel::InitMatmulParam() {
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
matmul_param_->act_type_ = conv_param_->act_type_;
return RET_OK;
}
@ -57,7 +56,7 @@ Convolution1x1FP16CPUKernel::~Convolution1x1FP16CPUKernel() {
}
int Convolution1x1FP16CPUKernel::InitConv1x1Param() {
pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 ||
pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
conv_param_->stride_w_ != 1);
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));

View File

@ -237,8 +237,8 @@ int Convolution3x3FP16CPUKernel::Run() {
}
// get real output
bool relu = conv_param_->is_relu_;
bool relu6 = conv_param_->is_relu6_;
bool relu = conv_param_->act_type_ == ActType_Relu;
bool relu6 = conv_param_->act_type_ == ActType_Relu6;
if (relu) {
UnPack3x3ReluOutputFp16(tmp_out_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_);

View File

@ -391,10 +391,10 @@ int ConvolutionWinogradFP16CPUKernel::Run() {
}
// get real output
if (conv_param_->is_relu_) {
if (conv_param_->act_type_ == ActType_Relu) {
UnPackWinogradReluOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else if (conv_param_->is_relu6_) {
} else if (conv_param_->act_type_ == ActType_Relu6) {
UnPackWinogradRelu6OutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else {

View File

@ -232,34 +232,31 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
conv_param->input_h_ = inputs.front()->Height();
conv_param->input_w_ = inputs.front()->Width();
conv_param->input_channel_ = inputs.front()->Channel();
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
conv_param->output_channel_ = outputs.front()->Channel();
conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
bool use_winograd = false;
bool use_sw = false;
int out_unit;
InputTransformUnitFunc input_trans_func = nullptr;
OutputTransformUnitFunc output_trans_func = nullptr;
if (primitive != nullptr && primitive->GetInferFlag()) {
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
use_sw = CheckIfUseSlideWindow(conv_param);
}
kernel::LiteKernel *kernel;
if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else if (use_winograd) {
kernel =
new (std::nothrow) kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
} else if (use_sw) {
kernel = new (std::nothrow) kernel::ConvolutionSWCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
if (kernel_h == 3 && kernel_w == 3 && out_unit == 2) {
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow)
kernel::ConvolutionWinogradCPUKernel(op_parameter, inputs, outputs, ctx, primitive, out_unit);
}
} else {
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
}

View File

@ -65,8 +65,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = (conv_param_->is_relu6_) ? ActType_Relu6 : ActType_No;
matmul_param_->act_type_ = (conv_param_->is_relu_) ? ActType_Relu : matmul_param_->act_type_;
matmul_param_->act_type_ = conv_param_->act_type_;
return;
}
@ -98,7 +97,7 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
}
int Convolution1x1CPUKernel::InitConv1x1Param() {
pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 ||
pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
conv_param_->stride_w_ != 1);
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));

View File

@ -94,7 +94,9 @@ int Convolution3x3CPUKernel::InitWeightBias() {
}
int Convolution3x3CPUKernel::InitTmpBuffer() {
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
int oC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM);
const int k_plane = 16;
MS_ASSERT(ctx_->allocator != nullptr);
@ -105,13 +107,20 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
size_t tmp_dst_buffer_size = thread_count_ * TILE_NUM * k_plane * oC4 * C4NUM * sizeof(float);
size_t tmp_dst_buffer_size = thread_count_ * C12NUM * k_plane * oC8 * C8NUM * sizeof(float);
tmp_dst_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tmp_dst_buffer_size));
if (tmp_dst_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed.";
return RET_ERROR;
}
size_t col_buffer_size = thread_count_ * C12NUM * C4NUM * ic4 * sizeof(float);
col_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(col_buffer_size));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;
}
size_t nc4hw4_out_size =
oC4 * C4NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float);
nc4hw4_out_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(nc4hw4_out_size));
@ -124,6 +133,7 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
tmp_buffer_address_list_[1] = block_unit_buffer_;
tmp_buffer_address_list_[2] = tmp_dst_buffer_;
tmp_buffer_address_list_[3] = nc4hw4_out_;
tmp_buffer_address_list_[4] = col_buffer_;
return RET_OK;
}
@ -182,7 +192,7 @@ int Convolution3x3CPUKernel::ReSize() {
}
memset(nhwc4_input_, 0, nhwc4_input_size);
size_t tile_buffer_size = thread_count_ * TILE_NUM * C16NUM * iC4 * C4NUM * sizeof(float);
size_t tile_buffer_size = thread_count_ * C12NUM * C16NUM * iC4 * C4NUM * sizeof(float);
tile_buffer_ = reinterpret_cast<float *>(malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile buffer failed.";
@ -237,8 +247,8 @@ int Convolution3x3CPUKernel::Run() {
return RET_ERROR;
}
auto is_relu = conv_param_->is_relu_;
auto is_relu6 = conv_param_->is_relu6_;
auto is_relu = conv_param_->act_type_ == ActType_Relu;
auto is_relu6 = conv_param_->act_type_ == ActType_Relu6;
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
if (is_relu) {
PackNC4HW4ToNHWCReluFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,

View File

@ -60,14 +60,19 @@ class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel {
ctx_->allocator->Free(nc4hw4_out_);
nc4hw4_out_ = nullptr;
}
if (col_buffer_ != nullptr) {
ctx_->allocator->Free(col_buffer_);
col_buffer_ = nullptr;
}
}
float *transformed_filter_addr_ = nullptr;
float *tile_buffer_ = nullptr;
float *block_unit_buffer_ = nullptr;
float *tmp_dst_buffer_ = nullptr;
float *col_buffer_ = nullptr;
float *nc4hw4_out_ = nullptr;
TmpBufferAddress tmp_buffer_address_list_[4];
TmpBufferAddress tmp_buffer_address_list_[5];
GEMM_FUNC_FP32 gemm_func_ = nullptr;
};
void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num);

View File

@ -76,7 +76,7 @@ int WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int input_oz_offset = i * kernel_unit * kernel_unit * channel_in;
int output_oz_offset = out_c_block * strides[1] * input_unit * input_unit + out_c_res;
int output_oz_offset = out_c_block * strides[1] + out_c_res;
for (int j = 0; j < channel_in; j++) {
int ic4_block = j / C4NUM;
int ic4_res = j % C4NUM;
@ -93,7 +93,7 @@ int WinogradFilterTransform(const float *weight_data, Matrix *trans_weight, int
MatrixMultiply(tmp_data, matrix_gt_data, trans_out_data, input_unit, kernel_unit, input_unit, row);
for (int z = 0; z < input_unit_square; z++) {
int output_xy_offset = output_iz_offset + z * strides[1];
int output_xy_offset = output_iz_offset + z * strides[0];
*(trans_weight_data + output_xy_offset) = trans_out_data[z];
}
}
@ -151,7 +151,7 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
int ConvolutionWinogradCPUKernel::MallocFilterMatrix(int oc_block, int oc_block_num) {
int channel_in = conv_param_->input_channel_;
int ic4 = UP_DIV(channel_in, BLOCK);
int ic4 = UP_DIV(channel_in, C4NUM);
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float);
@ -196,10 +196,12 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
int output_h = conv_param_->output_h_;
int output_w = conv_param_->output_w_;
int oc4 = UP_DIV(channel_out, C4NUM);
int oc8 = UP_DIV(channel_out, C8NUM);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
MS_ASSERT(ctx_->allocator != nullptr);
gemm_out_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(thread_count_ * TILE_NUM * input_unit_ * input_unit_ * oc4 * C4NUM * sizeof(float)));
ctx_->allocator->Malloc(thread_count_ * C12NUM * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float)));
if (gemm_out_ == nullptr) {
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
return RET_ERROR;
@ -222,10 +224,18 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
col_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * C12NUM * ic4 * C4NUM * sizeof(float)));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;
}
tmp_buffer_address_list_[0] = trans_input_;
tmp_buffer_address_list_[1] = gemm_out_;
tmp_buffer_address_list_[2] = tmp_out_data_;
tmp_buffer_address_list_[3] = tmp_data_;
tmp_buffer_address_list_[4] = col_buffer_;
return RET_OK;
}
@ -306,7 +316,7 @@ int ConvolutionWinogradCPUKernel::ReSize() {
}
memset(nhwc4_input_, 0, nhwc4_input_size);
size_t tile_buffer_size = thread_count_ * TILE_NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
size_t tile_buffer_size = thread_count_ * C12NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
trans_input_ = reinterpret_cast<float *>(malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
@ -370,10 +380,10 @@ int ConvolutionWinogradCPUKernel::Run() {
// get real output
auto out_tensor = out_tensors_.front();
auto out_data = reinterpret_cast<float *>(out_tensor->Data());
if (conv_param_->is_relu_) {
if (conv_param_->act_type_ == ActType_Relu) {
UnPackWinogradReluOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else if (conv_param_->is_relu6_) {
} else if (conv_param_->act_type_ == ActType_Relu6) {
UnPackWinogradRelu6Output(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
} else {

View File

@ -66,6 +66,10 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
ctx_->allocator->Free(tmp_out_data_);
tmp_out_data_ = nullptr;
}
if (col_buffer_ != nullptr) {
ctx_->allocator->Free(col_buffer_);
col_buffer_ = nullptr;
}
}
int kernel_unit_;
int input_unit_;
@ -74,6 +78,7 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
float *trans_input_ = nullptr;
float *gemm_out_ = nullptr;
float *tmp_out_data_ = nullptr;
float *col_buffer_ = nullptr;
Matrix *trans_weight_ = nullptr;
InputTransformUnitFunc input_trans_func_;
OutputTransformUnitFunc output_trans_func_;

View File

@ -146,7 +146,7 @@ int Convolution1x1Int8CPUKernel::Init() {
}
int Convolution1x1Int8CPUKernel::InitParam() {
pre_trans_input_ = (conv_param_->pad_h_ != 0 || conv_param_->pad_w_ != 0 || conv_param_->stride_h_ != 1 ||
pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
conv_param_->stride_w_ != 1);
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;

View File

@ -36,7 +36,7 @@ int Conv2dTransposeOpenCLKernel::Init() {
MS_LOG(ERROR) << "only support kh=kw=2 and stride_h=stride_w=2.";
return RET_ERROR;
}
if (param->pad_h_ != 0 || param->pad_w_ != 0) {
if (param->pad_u_ != 0 || param->pad_l_ != 0) {
MS_LOG(ERROR) << "only support pad =0.";
return RET_ERROR;
}
@ -170,7 +170,7 @@ int Conv2dTransposeOpenCLKernel::Run() {
int co = out_tensors_[0]->Channel();
int kh = param->kernel_h_;
int kw = param->kernel_w_;
int pad = param->pad_h_;
int pad = param->pad_u_;
int oh = out_tensors_[0]->Height();
int ow = out_tensors_[0]->Width();
int h = in_tensors_[0]->Height();

View File

@ -382,9 +382,9 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolution() {
" }\n\n";
code += " FLT4 out0_c4_bias = out0_c4 + bias[co_slice];\n";
if (param->is_relu_) {
if (param->act_type_ == ActType_Relu) {
code += " out0_c4_bias = max(out0_c4_bias, (FLT4)(0.0f));\n";
} else if (param->is_relu6_) {
} else if (param->act_type_ == ActType_Relu6) {
code += " out0_c4_bias = clamp(out0_c4_bias, (FLT4)(0.0f), (FLT4)(6.0f));\n";
}
@ -609,9 +609,9 @@ std::string ConvolutionOpenCLKernel::CodeGenWinograd36To4x4() {
" acc += bias[slice];\n";
auto param = reinterpret_cast<ConvParameter *>(op_parameter_);
if (param->is_relu_) {
if (param->act_type_ == ActType_Relu) {
code += " acc = max(acc, (float4)(0.0f));\n";
} else if (param->is_relu6_) {
} else if (param->act_type_ == ActType_Relu6) {
code += " acc = clamp(acc, (float4)(0.0f), (float4)(6.0f));\n";
}

View File

@ -163,7 +163,7 @@ int DepthwiseConv2dOpenCLKernel::Run() {
float relu_clip1 = 6.0;
cl_int2 kernel_size = {parameter->kernel_h_, parameter->kernel_w_};
cl_int2 stride = {parameter->stride_h_, parameter->stride_w_};
cl_int2 padding = {-parameter->pad_h_, -parameter->pad_w_};
cl_int2 padding = {-parameter->pad_u_, -parameter->pad_l_};
cl_int2 dilation = {parameter->dilation_h_, parameter->dilation_w_};
cl_int4 src_size = {in_tensors_[0]->Width(), in_tensors_[0]->Height(), (cl_int)CI4, in_tensors_[0]->Batch()};
cl_int4 dst_size = {(cl_int)out_tensors_[0]->Width(), (cl_int)out_tensors_[0]->Height(), (cl_int)CO4,

View File

@ -47,8 +47,8 @@ void InitConvParamPack(ConvParameter *conv_param) {
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
}
TEST_F(TestPack, PackInputFp32) {

View File

@ -50,8 +50,8 @@ void InitConvParamGroup1Fp16(ConvParameter *conv_param) {
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
conv_param->thread_num_ = 1;
}
@ -75,8 +75,8 @@ void InitConvParamGroup2Fp16(ConvParameter *conv_param) {
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
conv_param->thread_num_ = 1;
}

View File

@ -50,7 +50,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack1) {
conv_param->output_h_ = 4;
conv_param->output_w_ = 5;
conv_param->stride_h_ = conv_param->stride_w_ = 4;
conv_param->pad_h_ = conv_param->pad_w_ = 2;
conv_param->pad_u_ = conv_param->pad_l_ = 2;
float out[20] = {0};
Conv1x1InputPack(in, out, conv_param, sizeof(float));
@ -91,7 +91,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack2) {
conv_param->output_h_ = 7;
conv_param->output_w_ = 4;
conv_param->stride_h_ = conv_param->stride_w_ = 3;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->pad_u_ = conv_param->pad_l_ = 0;
float out[28] = {0};
Conv1x1InputPack(in, out, conv_param, sizeof(float));
@ -105,7 +105,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack3) {
conv_param->input_h_ = conv_param->input_w_ = 3;
conv_param->output_h_ = conv_param->output_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->pad_u_ = conv_param->pad_l_ = 1;
float in[] = {1.6767339, 12.25904, 19.018835, 3.0790641, -9.252135, -8.685675, 3.6115494, 3.2282279, 17.025112,
-5.052577, 12.750252, 12.701241, -8.9477215, -9.080522, 19.03931, -6.501229, -4.122992, 9.540845};
@ -124,7 +124,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack4) {
conv_param->input_h_ = conv_param->input_w_ = 3;
conv_param->output_h_ = conv_param->output_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->pad_u_ = conv_param->pad_l_ = 1;
float in[] = {4.1795, 13.142, -3.593, 16.505, 19.899, 8.5562, 19.969, -6.235, -2.380, -9.027, 9.5542,
18.974, 23.622, 8.3608, 47.325, -14.36, 15.370, 4.3049, -0.784, 37.925, -0.081, 6.1298,
0.6721, -1.517, 37.998, 13.719, 11.029, 1.7127, -1.770, 41.903, 9.0560, 14.988, 3.1866,
@ -281,8 +281,8 @@ int Conv1x1TestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<l
conv_param->kernel_h_ = conv_param->kernel_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->is_relu_ = conv_param->is_relu6_ = false;
conv_param->pad_u_ = conv_param->pad_l_ = 1;
conv_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}
@ -348,9 +348,8 @@ int Conv1x1TestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<l
conv_param->kernel_h_ = conv_param->kernel_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 1;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->is_relu_ = false;
conv_param->is_relu6_ = false;
conv_param->pad_u_ = conv_param->pad_l_ = 0;
conv_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}

View File

@ -47,8 +47,8 @@ void InitConvDwParam(ConvParameter *conv_param) {
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
}
void InitConvDwCreator(std::vector<lite::tensor::Tensor *> *inputs, std::vector<lite::tensor::Tensor *> *outputs,

View File

@ -468,7 +468,7 @@ int DeConvTestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<li
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->pad_u_ = conv_param->pad_l_ = 1;
return out_t->ElementsNum();
}
@ -537,7 +537,7 @@ int DeConvTestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<li
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->pad_u_ = conv_param->pad_l_ = 1;
return out_t->ElementsNum();
}
@ -616,7 +616,7 @@ int DeConvTestInit3(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<li
conv_param->kernel_h_ = conv_param->kernel_w_ = 2;
conv_param->stride_h_ = conv_param->stride_w_ = 3;
conv_param->dilation_h_ = conv_param->dilation_w_ = 2;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->pad_u_ = conv_param->pad_l_ = 0;
return out_t->ElementsNum();
}
@ -685,8 +685,8 @@ int DeConvTestInit4(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<li
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 1;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->is_relu_ = conv_param->is_relu6_ = false;
conv_param->pad_u_ = conv_param->pad_l_ = 0;
conv_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}

View File

@ -52,12 +52,11 @@ void InitConvParamGroup1FP32(ConvParameter *conv_param) {
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
conv_param->group_ = 1;
conv_param->is_relu_ = false;
conv_param->is_relu6_ = false;
conv_param->act_type_ = ActType_No;
conv_param->thread_num_ = 1;
}

View File

@ -34,7 +34,7 @@ TEST_F(TestConv1x1Int8, Input1x1PrePack1) {
conv_param->input_h_ = conv_param->input_w_ = 3;
conv_param->output_h_ = conv_param->output_w_ = 3;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->pad_u_ = conv_param->pad_l_ = 1;
int8_t in[] = {4, 13, -3, 16, 19, 8, 19, -6, -2, -9, 9, 18, 23, 8, 47, -14, 15, 4,
-0, 37, -0, 6, 0, -1, 37, 13, 11, 1, -1, 41, 9, 14, 3, 0, 8, 9,
14, -14, -8, -8, -8, 7, 19, 17, 13, 3, 9, 18, -1, -0, 18, 0, 4, -2};
@ -61,7 +61,7 @@ TEST_F(TestConv1x1Int8, Input1x1PrePack2) {
conv_param->output_h_ = 4;
conv_param->output_w_ = 5;
conv_param->stride_h_ = conv_param->stride_w_ = 4;
conv_param->pad_h_ = conv_param->pad_w_ = 2;
conv_param->pad_u_ = conv_param->pad_l_ = 2;
int8_t out[20] = {0};
Conv1x1InputPack(in, out, conv_param, sizeof(int8_t));
@ -111,8 +111,8 @@ int Conv1x1Int8TestInit1_perchannel(std::vector<lite::tensor::Tensor *> *inputs_
conv_param->kernel_h_ = conv_param->kernel_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 1;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->is_relu_ = conv_param->is_relu6_ = false;
conv_param->pad_u_ = conv_param->pad_l_ = 0;
conv_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}
@ -178,8 +178,8 @@ int Conv1x1Int8TestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vect
conv_param->kernel_h_ = conv_param->kernel_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 1;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->is_relu_ = conv_param->is_relu6_ = false;
conv_param->pad_u_ = conv_param->pad_l_ = 0;
conv_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}
@ -253,8 +253,8 @@ int Conv1x1Int8TestInit2(std::vector<lite::tensor::Tensor *> *inputs_, std::vect
conv_param->kernel_h_ = conv_param->kernel_w_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 1;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
conv_param->pad_h_ = conv_param->pad_w_ = 0;
conv_param->is_relu_ = conv_param->is_relu6_ = false;
conv_param->pad_u_ = conv_param->pad_l_ = 0;
conv_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}

View File

@ -343,7 +343,7 @@ int DeConvInt8TestInit1(std::vector<lite::tensor::Tensor *> *inputs_, std::vecto
PackNCHWToNHWCInt8(co_nchw, *correct, out_t->Batch(), out_t->Width() * out_t->Height(), out_t->Channel());
conv_param->kernel_h_ = conv_param->kernel_w_ = 3;
conv_param->pad_h_ = conv_param->pad_w_ = 1;
conv_param->pad_u_ = conv_param->pad_l_ = 1;
conv_param->stride_h_ = conv_param->stride_w_ = 2;
conv_param->dilation_h_ = conv_param->dilation_w_ = 1;
return out_t->ElementsNum();

View File

@ -119,8 +119,8 @@ void RunTestCase(const std::vector<int> shape, const std::vector<std::string> fi
opParameter->kernel_w_ = kw;
opParameter->stride_h_ = 2;
opParameter->stride_w_ = 2;
opParameter->pad_h_ = pad;
opParameter->pad_w_ = pad;
opParameter->pad_u_ = pad;
opParameter->pad_l_ = pad;
opParameter->input_channel_ = ci;
opParameter->output_channel_ = co;
auto op_kernel_ptr = std::make_unique<kernel::Conv2dTransposeOpenCLKernel>(

View File

@ -169,8 +169,8 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNC4HW4Fp32) {
conv_param->stride_w_ = 1;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 0;
conv_param->pad_w_ = 0;
conv_param->pad_u_ = 0;
conv_param->pad_l_ = 0;
}
// nhwc
@ -214,8 +214,8 @@ TEST_F(TestConvolutionDwOpenCL, PadNC4HW4Fp32) {
conv_param->stride_w_ = 1;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
}
// nhwc
@ -286,8 +286,8 @@ TEST_F(TestConvolutionDwOpenCL, NoPadNHWC4Fp32) {
conv_param->stride_w_ = 1;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 0;
conv_param->pad_w_ = 0;
conv_param->pad_u_ = 0;
conv_param->pad_l_ = 0;
}
// nhwc
@ -331,8 +331,8 @@ TEST_F(TestConvolutionDwOpenCL, PadNHWC4Fp32) {
conv_param->stride_w_ = 1;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
}
// nhwc
@ -405,8 +405,8 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwNoPadFp32) {
conv_param->stride_w_ = 1;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 0;
conv_param->pad_w_ = 0;
conv_param->pad_u_ = 0;
conv_param->pad_l_ = 0;
}
// nhwc
@ -529,8 +529,8 @@ TEST_F(TestConvolutionDwOpenCL, ConvDwPadFp32) {
conv_param->stride_w_ = 1;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
conv_param->pad_h_ = 1;
conv_param->pad_w_ = 1;
conv_param->pad_u_ = 1;
conv_param->pad_l_ = 1;
}
// nhwc
@ -724,8 +724,8 @@ TEST_F(TestConvolutionDwOpenCL, ProfilingMobilenetv2) {
conv_param->kernel_w_ = filter_shape[i][2];
conv_param->stride_h_ = conv_param->output_h_ / conv_param->input_h_;
conv_param->stride_w_ = conv_param->output_w_ / conv_param->input_w_;
conv_param->pad_h_ = (conv_param->kernel_h_ - 1) / 2;
conv_param->pad_w_ = (conv_param->kernel_w_ - 1) / 2;
conv_param->pad_u_ = (conv_param->kernel_h_ - 1) / 2;
conv_param->pad_l_ = (conv_param->kernel_w_ - 1) / 2;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
}
@ -774,8 +774,8 @@ TEST_F(TestConvolutionDwOpenCL, Buffer2Image) {
conv_param->kernel_w_ = filter_shape[2];
conv_param->stride_h_ = conv_param->output_h_ / conv_param->input_h_;
conv_param->stride_w_ = conv_param->output_w_ / conv_param->input_w_;
conv_param->pad_h_ = (conv_param->kernel_h_ - 1) / 2;
conv_param->pad_w_ = (conv_param->kernel_w_ - 1) / 2;
conv_param->pad_u_ = (conv_param->kernel_h_ - 1) / 2;
conv_param->pad_l_ = (conv_param->kernel_w_ - 1) / 2;
conv_param->dilation_h_ = 1;
conv_param->dilation_w_ = 1;
}