forked from mindspore-Ecosystem/mindspore
!5182 optimization for winograd matmul
Merge pull request !5182 from lixian/master
commit bd12a37d4f
@@ -0,0 +1,144 @@
#ifdef __aarch64__
    .text
    .align 5
    .global MatmulFloatNeon64OptRemain
#ifndef __APPLE__
    .type MatmulFloatNeon64OptRemain, %function
#endif

// void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth,
//                                 int row, int col, size_t stride)
// x0: a
// x1: b
// x2: c
// x3: depth
// x4: row
// x5: col
// x6: stride
// only for winograd
MatmulFloatNeon64OptRemain:
    mov x18, #32                // sizeof(float) * 8
    mul x9, x3, x18             // block stride of lhs/rhs: sizeof(float) * 8 * depth
    mov x18, #4                 // sizeof(float)
    mul x8, x5, x6              // dst row stride: col * stride
    mov x11, #8
    mul x11, x11, x6            // dst col-block stride: 8 * stride
    mul x8, x8, x18             // convert to bytes
    mul x11, x11, x18           // convert to bytes

    cmp x4, #4
    ble LoopH4

LoopH8:
    mov x10, x4                 // reload lhs row
    mov x12, x0                 // reload lhs ptr
    mov x18, x2                 // reload dst ptr

LoopW8:
    mov x16, x1                 // reload rhs ptr
    mov x13, x3                 // reload depth
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
    dup v19.4s, wzr
    dup v20.4s, wzr
    dup v21.4s, wzr
    dup v22.4s, wzr
    dup v23.4s, wzr
    dup v24.4s, wzr
    dup v25.4s, wzr
    dup v26.4s, wzr
    dup v27.4s, wzr
    dup v28.4s, wzr
    dup v29.4s, wzr
    dup v30.4s, wzr
    dup v31.4s, wzr

LoopD8:
    ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48   // 12-wide packed lhs; 8 lanes used here
    ld1 {v3.4s, v4.4s}, [x16], #32          // 8 packed rhs columns
    fmla v16.4s, v3.4s, v0.s[0]
    fmla v18.4s, v3.4s, v0.s[1]
    fmla v20.4s, v3.4s, v0.s[2]
    fmla v22.4s, v3.4s, v0.s[3]
    fmla v17.4s, v4.4s, v0.s[0]
    fmla v19.4s, v4.4s, v0.s[1]
    fmla v21.4s, v4.4s, v0.s[2]
    fmla v23.4s, v4.4s, v0.s[3]
    fmla v24.4s, v3.4s, v1.s[0]
    fmla v26.4s, v3.4s, v1.s[1]
    fmla v28.4s, v3.4s, v1.s[2]
    fmla v30.4s, v3.4s, v1.s[3]
    fmla v25.4s, v4.4s, v1.s[0]
    fmla v27.4s, v4.4s, v1.s[1]
    fmla v29.4s, v4.4s, v1.s[2]
    fmla v31.4s, v4.4s, v1.s[3]

    subs w13, w13, #1
    bgt LoopD8

    st1 {v16.4s, v17.4s}, [x18], x8
    st1 {v18.4s, v19.4s}, [x18], x8
    st1 {v20.4s, v21.4s}, [x18], x8
    st1 {v22.4s, v23.4s}, [x18], x8
    st1 {v24.4s, v25.4s}, [x18], x8
    st1 {v26.4s, v27.4s}, [x18], x8
    st1 {v28.4s, v29.4s}, [x18], x8
    st1 {v30.4s, v31.4s}, [x18], x8

    subs x10, x10, #8           // lhs row - 8
    bgt LoopW8

    subs x5, x5, #8             // rhs col - 8
    add x1, x1, x9              // rhs ptr + stride
    add x2, x2, x11             // dst ptr + 8 columns
    bgt LoopH8

    ret

LoopH4:
    mov x10, x4                 // reload lhs row
    mov x12, x0                 // reload lhs ptr
    mov x18, x2                 // reload dst ptr

LoopW4:
    mov x16, x1                 // reload rhs ptr
    mov x13, x3                 // reload depth
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
    dup v19.4s, wzr
    dup v20.4s, wzr
    dup v21.4s, wzr
    dup v22.4s, wzr
    dup v23.4s, wzr

LoopD4:
    ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48   // 12-wide packed lhs; 4 lanes used here
    ld1 {v3.4s, v4.4s}, [x16], #32
    fmla v16.4s, v3.4s, v0.s[0]
    fmla v18.4s, v3.4s, v0.s[1]
    fmla v20.4s, v3.4s, v0.s[2]
    fmla v22.4s, v3.4s, v0.s[3]
    fmla v17.4s, v4.4s, v0.s[0]
    fmla v19.4s, v4.4s, v0.s[1]
    fmla v21.4s, v4.4s, v0.s[2]
    fmla v23.4s, v4.4s, v0.s[3]

    subs x13, x13, #1
    bgt LoopD4

    st1 {v16.4s, v17.4s}, [x18], x8
    st1 {v18.4s, v19.4s}, [x18], x8
    st1 {v20.4s, v21.4s}, [x18], x8
    st1 {v22.4s, v23.4s}, [x18], x8

    subs x10, x10, #4           // lhs row - 4
    bgt LoopW4

    subs x5, x5, #8             // rhs col - 8
    add x1, x1, x9              // rhs ptr + stride
    add x2, x2, x11             // dst ptr + 8 columns
    bgt LoopH4
    ret
#endif
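For reference, here is a minimal C model of what this remainder kernel computes; it is a sketch, not part of the commit. It assumes the lhs is packed 12 rows wide by RowMajor2Col12Major (hence the 48-byte advance per depth step, of which only `row` lanes are consumed), the rhs is packed in 8-column blocks of depth * 8 floats, and the output uses the TileC8 layout implied by the store strides x8 and x11 in the prologue and by the scalar fallback's index arithmetic further down.

#include <stddef.h>

// Reference sketch of MatmulFloatNeon64OptRemain; the dispatch below
// guarantees row <= 8, and col is a multiple of 8 (oc8 * C8NUM at the
// winograd call sites).
void MatmulFloatNeon64OptRemainRef(const float *a, const float *b, float *c,
                                   int depth, int row, int col, size_t stride) {
  for (int r = 0; r < row; ++r) {      // output rows (winograd tiles)
    for (int j = 0; j < col; ++j) {    // output columns (channels)
      float value = 0;
      for (int d = 0; d < depth; ++d) {
        // lhs depth stride: 12 floats (48 bytes); rhs: 8 floats (32 bytes)
        value += a[d * 12 + r] * b[(j / 8) * (size_t)depth * 8 + d * 8 + (j % 8)];
      }
      // TileC8 output: rows are col * stride floats apart (x8), each
      // 8-column block occupies 8 * stride floats (x11)
      c[(size_t)r * col * stride + (j / 8) * 8 * stride + (j % 8)] = value;
    }
  }
}

Calling this with the winograd arguments (depth = ic4 * C4NUM, col = oc8 * C8NUM, stride = input_unit_square, c = dst_ptr + i * C8NUM) reproduces what the two store loops above write.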

@@ -303,7 +303,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
   for (int i = 0; i < input_unit_square; ++i) {
     RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
     MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
-              C12NUM, oc8 * C8NUM, input_unit_square, 2);
+              cal_num, oc8 * C8NUM, input_unit_square, 2);
   }

   // step 4 : output transform
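The functional change here: the `row` argument of MatMulOpt is now the number of tiles actually computed in this block (`cal_num`) rather than the fixed C12NUM, so a small trailing block can be routed to the new remainder kernel by the dispatch change further down. A hedged illustration with hypothetical tile counts:

// Hypothetical walk over winograd output tiles, C12NUM (12) at a time.
// With 100 tiles the blocks carry 12, 12, ..., 12 and finally 4 rows;
// only that last call (cal_num = 4 <= 8, out_type == 2) takes the new
// MatmulFloatNeon64OptRemain path inside MatMulOpt.
int tile_count = 100;  // hypothetical
for (int start = 0; start < tile_count; start += 12) {
  int cal_num = tile_count - start < 12 ? tile_count - start : 12;
  // ... pack cal_num tiles, then call MatMulOpt(..., cal_num, ..., 2) ...
}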
@@ -489,9 +489,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
   for (int i = 0; i < input_unit_square; ++i) {
     RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
     MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
-             ic4 * C4NUM, C12NUM, oc8 * C8NUM, input_unit_square, 2);
+             ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
   }

   Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
                              bias_data, start_index, real_cal_num, out_w_block, conv_param);
 }

@@ -386,7 +386,7 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
       size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
       float value = 0;
       for (int d = 0; d < deep; ++d) {
-        size_t ai = src_r_offset + d * row;
+        size_t ai = src_r_offset + d * C12NUM;
         size_t bi = c8div * deep * 8 + d * 8 + c8mod;
         value = value + a[ai] * b[bi];
       }
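This fallback fix mirrors the packed lhs layout: within a 12-row block produced by RowMajor2Col12Major, the depth stride is always C12NUM floats, even when fewer than 12 rows are valid, so stepping by `row` mis-indexed remainder blocks. A minimal sketch of the packing as assumed here (not necessarily the library's exact implementation):

// Element (r, d) of a row-major [row][col] matrix lands at d * 12 + (r % 12)
// inside its 12-row block, so advancing d always moves 12 floats, no matter
// how many of the block's rows are actually valid.
void RowMajor2Col12MajorSketch(const float *src, float *dst, int row, int col) {
  for (int r = 0; r < row; ++r) {
    for (int d = 0; d < col; ++d) {
      dst[(r / 12) * col * 12 + d * 12 + (r % 12)] = src[r * col + d];
    }
  }
}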

@@ -403,8 +403,12 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
 void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
                int col, size_t stride, int out_type) {
 #ifdef ENABLE_ARM64
-  MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
-                       (int)(out_type == OutType_TileC8));
+  if (out_type == 2 && row <= 8) {
+    MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride);
+  } else {
+    MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
+                         (int)(out_type == OutType_TileC8));
+  }
 #else
   MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
 #endif
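A note on the literal 2 in the new branch: the winograd callers above pass it where OutType_TileC8 semantics are expected, and the else branch still tests `out_type == OutType_TileC8`, so the condition presumably reads as follows (enum values are an assumption; the diff itself only shows the literal):

// Assumed enum values, not taken from this diff.
typedef enum { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType;

// Remainder condition: TileC8 output and at most 8 of the 12 packed rows valid.
int UsesRemainderKernel(int out_type, int row) {
  return out_type == OutType_TileC8 && row <= 8;
}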

@@ -39,6 +39,7 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi
                        int col, size_t stride, bool write_nhwc);
 void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                           int col, size_t stride, size_t write_nhwc, size_t write_c4);
+void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
 #endif
 #ifdef __cplusplus
 }