diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S new file mode 100644 index 00000000000..dd42b492453 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulFp32OptRemain.S @@ -0,0 +1,144 @@ +#ifdef __aarch64__ + .text + .align 5 + .global MatmulFloatNeon64OptRemain +#ifndef __APPLE__ + .type MatmulFloatNeon64OptRemain, %function +#endif + +// void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, +// int row, int col, size_t stride) +// x0: a +// x1: b +// x2: c +// x3: depth +// x4: row +// x5: col +// x6: stride +// only for winograd; uses x17 as scratch instead of x18 (x18 is the AAPCS64 platform register, reserved on Apple/Windows) +MatmulFloatNeon64OptRemain: + mov x17, #32 // sizeof(float) * 8 + mul x9, x3, x17 // block stride of lhs/rhs: sizeof(float) * 8 * depth + mov x17, #4 + mul x8, x5, x6 + mov x11, #8 + mul x11, x11, x6 + mul x8, x8, x17 + mul x11, x11, x17 + + cmp x4, #4 + ble LoopH4 + + LoopH8: + mov x10, x4 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x17, x2 // reload dst ptr + + LoopW8: + mov x16, x1 // reload rhs ptr + mov x13, x3 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + LoopD8: + ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 + ld1 {v3.4s, v4.4s}, [x16], #32 + fmla v16.4s, v3.4s, v0.s[0] + fmla v18.4s, v3.4s, v0.s[1] + fmla v20.4s, v3.4s, v0.s[2] + fmla v22.4s, v3.4s, v0.s[3] + fmla v17.4s, v4.4s, v0.s[0] + fmla v19.4s, v4.4s, v0.s[1] + fmla v21.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + fmla v24.4s, v3.4s, v1.s[0] + fmla v26.4s, v3.4s, v1.s[1] + fmla v28.4s, v3.4s, v1.s[2] + fmla v30.4s, v3.4s, v1.s[3] + fmla v25.4s, v4.4s, v1.s[0] + fmla v27.4s, v4.4s, v1.s[1] + fmla v29.4s, v4.4s, v1.s[2] + fmla v31.4s, v4.4s, v1.s[3] + + subs w13, w13, #1 + bgt LoopD8 + + st1 
{v16.4s, v17.4s}, [x17], x8 + st1 {v18.4s, v19.4s}, [x17], x8 + st1 {v20.4s, v21.4s}, [x17], x8 + st1 {v22.4s, v23.4s}, [x17], x8 + st1 {v24.4s, v25.4s}, [x17], x8 + st1 {v26.4s, v27.4s}, [x17], x8 + st1 {v28.4s, v29.4s}, [x17], x8 + st1 {v30.4s, v31.4s}, [x17], x8 + + subs x10, x10, #8 // lhs row - 8 + bgt LoopW8 + + subs x5, x5, #8 // rhs col - 8 + add x1, x1, x9 // rhs ptr + stride + add x2, x2, x11 + bgt LoopH8 + + ret + + LoopH4: + mov x10, x4 // reload lhs row + mov x12, x0 // reload lhs ptr + mov x17, x2 // reload dst ptr + + LoopW4: + mov x16, x1 // reload rhs ptr + mov x13, x3 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopD4: + ld1 {v0.4s, v1.4s, v2.4s}, [x12], #48 + ld1 {v3.4s, v4.4s}, [x16], #32 + fmla v16.4s, v3.4s, v0.s[0] + fmla v18.4s, v3.4s, v0.s[1] + fmla v20.4s, v3.4s, v0.s[2] + fmla v22.4s, v3.4s, v0.s[3] + fmla v17.4s, v4.4s, v0.s[0] + fmla v19.4s, v4.4s, v0.s[1] + fmla v21.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + + subs x13, x13, #1 + bgt LoopD4 + + st1 {v16.4s, v17.4s}, [x17], x8 + st1 {v18.4s, v19.4s}, [x17], x8 + st1 {v20.4s, v21.4s}, [x17], x8 + st1 {v22.4s, v23.4s}, [x17], x8 + + subs x10, x10, #4 // lhs row - 4 + bgt LoopW4 + + subs x5, x5, #8 // rhs col - 8 + add x1, x1, x9 // rhs ptr + stride + add x2, x2, x11 + bgt LoopH4 + ret +#endif diff --git a/mindspore/lite/nnacl/fp32/conv.c b/mindspore/lite/nnacl/fp32/conv.c index ec5fdad0bb0..504ab0670f0 100644 --- a/mindspore/lite/nnacl/fp32/conv.c +++ b/mindspore/lite/nnacl/fp32/conv.c @@ -303,7 +303,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ for (int i = 0; i < input_unit_square; ++i) { RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM, - C12NUM, oc8 * C8NUM, 
input_unit_square, 2); + cal_num, oc8 * C8NUM, input_unit_square, 2); } // step 4 : output transform @@ -489,9 +489,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat for (int i = 0; i < input_unit_square; ++i) { RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, - ic4 * C4NUM, C12NUM, oc8 * C8NUM, input_unit_square, 2); + ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2); } - Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset, bias_data, start_index, real_cal_num, out_w_block, conv_param); } diff --git a/mindspore/lite/nnacl/fp32/matmul.c b/mindspore/lite/nnacl/fp32/matmul.c index dd5d2979144..c44daeb2dfe 100644 --- a/mindspore/lite/nnacl/fp32/matmul.c +++ b/mindspore/lite/nnacl/fp32/matmul.c @@ -386,7 +386,7 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; float value = 0; for (int d = 0; d < deep; ++d) { - size_t ai = src_r_offset + d * row; + size_t ai = src_r_offset + d * C12NUM; size_t bi = c8div * deep * 8 + d * 8 + c8mod; value = value + a[ai] * b[bi]; } @@ -403,8 +403,12 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col, size_t stride, int out_type) { #ifdef ENABLE_ARM64 - MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc), - (int)(out_type == OutType_TileC8)); + if (out_type == OutType_TileC8 && row <= 8) { + MatmulFloatNeon64OptRemain(a, b, c, deep, row, col, stride); + } else { + MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc), + (int)(out_type == OutType_TileC8)); + } 
#else MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type); #endif diff --git a/mindspore/lite/nnacl/fp32/matmul.h b/mindspore/lite/nnacl/fp32/matmul.h index 9a2dc7de14a..61759b60be9 100644 --- a/mindspore/lite/nnacl/fp32/matmul.h +++ b/mindspore/lite/nnacl/fp32/matmul.h @@ -39,6 +39,7 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi int col, size_t stride, bool write_nhwc); void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col, size_t stride, size_t write_nhwc, size_t write_c4); +void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride); #endif #ifdef __cplusplus }