!5182 optimization for winograd matmul

Merge pull request !5182 from lixian/master
This commit is contained in:
mindspore-ci-bot 2020-08-25 22:49:07 +08:00 committed by Gitee
commit bd12a37d4f
4 changed files with 154 additions and 6 deletions

View File

@ -0,0 +1,144 @@
#ifdef __aarch64__
.text
.align 5
.global MatmulFloatNeon64OptRemain
#ifndef __APPLE__
.type MatmulFloatNeon64OptRemain, %function
#endif
// void MatmulFloatNeon64(const float *a, const float *b, float *c, int depth
// int row, int col, size_t stride)
// x0: a
// x1: b
// x2: c
// x3: depth
// x4: row
// x5: col
// x6: stride
// only for winograd
MatmulFloatNeon64OptRemain:
mov x18, #32 // sizeof(float) * 8
mul x9, x3, x18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x18, #4
mul x8, x5, x6
mov x11, #8
mul x11, x11, x6
mul x8, x8, x18
mul x11, x11, x18
cmp x4, #4
ble LoopH4
LoopH8:
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
LoopW8:
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
LoopD8:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]
fmla v24.4s, v3.4s, v1.s[0]
fmla v26.4s, v3.4s, v1.s[1]
fmla v28.4s, v3.4s, v1.s[2]
fmla v30.4s, v3.4s, v1.s[3]
fmla v25.4s, v4.4s, v1.s[0]
fmla v27.4s, v4.4s, v1.s[1]
fmla v29.4s, v4.4s, v1.s[2]
fmla v31.4s, v4.4s, v1.s[3]
subs w13, w13, #1
bgt LoopD8
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
st1 {v24.4s, v25.4s}, [x18], x8
st1 {v26.4s, v27.4s}, [x18], x8
st1 {v28.4s, v29.4s}, [x18], x8
st1 {v30.4s, v31.4s}, [x18], x8
subs x10, x10, #8 // lhs row - 8
bgt LoopW8
subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11
bgt LoopH8
ret
LoopH4:
mov x10, x4 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
LoopW4:
mov x16, x1 // reload rhs ptr
mov x13, x3 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
LoopD4:
ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48
ld1 {v3.4s, v4.4s}, [x16], #32
fmla v16.4s, v3.4s, v0.s[0]
fmla v18.4s, v3.4s, v0.s[1]
fmla v20.4s, v3.4s, v0.s[2]
fmla v22.4s, v3.4s, v0.s[3]
fmla v17.4s, v4.4s, v0.s[0]
fmla v19.4s, v4.4s, v0.s[1]
fmla v21.4s, v4.4s, v0.s[2]
fmla v23.4s, v4.4s, v0.s[3]
subs x13, x13, #1
bgt LoopD4
st1 {v16.4s, v17.4s}, [x18], x8
st1 {v18.4s, v19.4s}, [x18], x8
st1 {v20.4s, v21.4s}, [x18], x8
st1 {v22.4s, v23.4s}, [x18], x8
subs x10, x10, #4 // lhs row - 4
bgt LoopW4
subs x5, x5, #8 // rhs col - 8
add x1, x1, x9 // rhs ptr + stride
add x2, x2, x11
bgt LoopH4
ret
#endif

View File

@ -303,7 +303,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
for (int i = 0; i < input_unit_square; ++i) { for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM, MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
C12NUM, oc8 * C8NUM, input_unit_square, 2); cal_num, oc8 * C8NUM, input_unit_square, 2);
} }
// step 4 : output transform // step 4 : output transform
@ -489,9 +489,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
for (int i = 0; i < input_unit_square; ++i) { for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, C12NUM, oc8 * C8NUM, input_unit_square, 2); ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
} }
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset, Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param); bias_data, start_index, real_cal_num, out_w_block, conv_param);
} }

View File

@ -386,7 +386,7 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float value = 0; float value = 0;
for (int d = 0; d < deep; ++d) { for (int d = 0; d < deep; ++d) {
size_t ai = src_r_offset + d * row; size_t ai = src_r_offset + d * C12NUM;
size_t bi = c8div * deep * 8 + d * 8 + c8mod; size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi]; value = value + a[ai] * b[bi];
} }
@ -403,8 +403,12 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type) { int col, size_t stride, int out_type) {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc), if (out_type == 2 && row <= 8) {
(int)(out_type == OutType_TileC8)); MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride);
} else {
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_TileC8));
}
#else #else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif #endif

View File

@ -39,6 +39,7 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi
int col, size_t stride, bool write_nhwc); int col, size_t stride, bool write_nhwc);
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t write_nhwc, size_t write_c4); int col, size_t stride, size_t write_nhwc, size_t write_c4);
void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
} }