forked from mindspore-Ecosystem/mindspore
add c4 output support for fp32 matmul kernel
This commit is contained in:
parent
75af54647f
commit
99cc44d25c
|
@ -47,7 +47,7 @@
|
|||
/////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
|
||||
// int row, int col, int stride, bool write_nhwc)
|
||||
// int row, int col, size_t stride, size_t writeNhwc, size_t writeC4)
|
||||
// x0: a
|
||||
// x1: b
|
||||
// x2: c
|
||||
|
@ -57,7 +57,7 @@
|
|||
// w6: row
|
||||
// w7: col
|
||||
// w17: stride
|
||||
// w13: writeC8
|
||||
// w13: c8_nhwc_c4
|
||||
|
||||
MatmulFloatNeon64Opt:
|
||||
sub sp, sp, #128
|
||||
|
@ -209,8 +209,8 @@ Activation:
|
|||
b Write
|
||||
|
||||
Relu6:
|
||||
mov w8, #6
|
||||
dup v2.4s, w8
|
||||
mov w13, #6
|
||||
dup v2.4s, w13
|
||||
scvtf v2.4s, v2.4s
|
||||
fmin v8.4s, v8.4s, v2.4s
|
||||
fmin v9.4s, v9.4s, v2.4s
|
||||
|
@ -265,8 +265,10 @@ Relu:
|
|||
fmax v31.4s, v31.4s, v3.4s
|
||||
|
||||
Write:
|
||||
ldrb w13, [sp, #8]
|
||||
cbz w13, WriteC8
|
||||
ldr w8, [sp, #8]
|
||||
cbz w8, WriteC8
|
||||
ldr w8, [sp, #16]
|
||||
cbnz w8, WriteC4
|
||||
cmp w7, #1
|
||||
beq Write1
|
||||
cmp w7, #2
|
||||
|
@ -726,6 +728,33 @@ WriteC8:
|
|||
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x2], #64
|
||||
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x2], #64
|
||||
b WriteEnd
|
||||
WriteC4:
|
||||
st1 {v8.8h}, [x2], #16
|
||||
st1 {v10.8h}, [x2], #16
|
||||
st1 {v12.8h}, [x2], #16
|
||||
st1 {v14.8h}, [x2], #16
|
||||
st1 {v16.8h}, [x2], #16
|
||||
st1 {v18.8h}, [x2], #16
|
||||
st1 {v20.8h}, [x2], #16
|
||||
st1 {v22.8h}, [x2], #16
|
||||
st1 {v24.8h}, [x2], #16
|
||||
st1 {v26.8h}, [x2], #16
|
||||
st1 {v28.8h}, [x2], #16
|
||||
st1 {v30.8h}, [x2], #16
|
||||
add x18, x2, x17
|
||||
st1 {v9.8h}, [x18], #16
|
||||
st1 {v11.8h}, [x18], #16
|
||||
st1 {v13.8h}, [x18], #16
|
||||
st1 {v15.8h}, [x18], #16
|
||||
st1 {v17.8h}, [x18], #16
|
||||
st1 {v19.8h}, [x18], #16
|
||||
st1 {v21.8h}, [x18], #16
|
||||
st1 {v23.8h}, [x18], #16
|
||||
st1 {v25.8h}, [x18], #16
|
||||
st1 {v27.8h}, [x18], #16
|
||||
st1 {v29.8h}, [x18], #16
|
||||
st1 {v31.8h}, [x18], #16
|
||||
b WriteEnd
|
||||
Write8:
|
||||
st1 {v8.4s, v9.4s}, [x18], x17
|
||||
cmp w10, #1
|
||||
|
@ -770,9 +799,14 @@ End2:
|
|||
subs w7, w7, #8 // rhs col - 8
|
||||
add x1, x1, x15 // rhs ptr + stride
|
||||
add x3, x3, #32 // bias ptr + stride
|
||||
ldrb w13, [sp, #8]
|
||||
cbz w13, NoDstStep
|
||||
ldr w8, [sp, #8]
|
||||
cbz w8, NoDstStep
|
||||
ldr w8, [sp, #16]
|
||||
cbnz w8, C4DstStep
|
||||
add x2, x2, #32 // dst ptr + stride
|
||||
b NoDstStep
|
||||
C4DstStep:
|
||||
add x2, x2, x17
|
||||
NoDstStep:
|
||||
bgt L1
|
||||
|
||||
|
|
|
@ -370,8 +370,8 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
|
|||
}
|
||||
|
||||
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, bool write_nhwc) {
|
||||
if (write_nhwc) {
|
||||
int col, size_t stride, size_t writeNhwc, size_t writeC4) {
|
||||
if (writeNhwc != 0) {
|
||||
/* col8-major * row8-major => col-major */
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
|
@ -404,10 +404,10 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType
|
|||
}
|
||||
|
||||
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
|
||||
int col, int stride, bool write_nhwc) {
|
||||
int col, size_t stride, size_t writeNhwc, size_t writeC4) {
|
||||
#ifdef ENABLE_ARM64
|
||||
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
|
||||
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, writeNhwc, writeC4);
|
||||
#else
|
||||
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
|
||||
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, writeNhwc, writeC4);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ extern "C" {
|
|||
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
|
||||
int stride, bool write_nhwc);
|
||||
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
|
||||
int col, int stride, bool write_nhwc);
|
||||
int col, size_t stride, size_t writeNhwc, size_t writeC4);
|
||||
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
|
||||
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
|
||||
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
|
||||
|
@ -38,7 +38,7 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
|
|||
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, size_t stride, bool write_nhwc);
|
||||
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||
int col, size_t stride, bool write_nhwc);
|
||||
int col, size_t stride, size_t writeNhwc, size_t writeC4);
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -145,7 +145,7 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
|
|||
|
||||
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
|
||||
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
|
||||
matmul_param_->row_, cur_oc, matmul_param_->col_, true);
|
||||
matmul_param_->row_, cur_oc, matmul_param_->col_, 1, 0);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue