!4399 matmul optimize

Merge pull request !4399 from ling/fc
This commit is contained in:
mindspore-ci-bot 2020-08-15 09:21:32 +08:00 committed by Gitee
commit 9afdf9be9b
10 changed files with 460 additions and 115 deletions

View File

@ -31,10 +31,6 @@ Convolution1x1CPUKernel::~Convolution1x1CPUKernel() {
free(pack_input_);
pack_input_ = nullptr;
}
if (pack_output_ != nullptr) {
free(pack_output_);
pack_output_ = nullptr;
}
if (pre_trans_input_ && input_ptr_ != nullptr) {
free(input_ptr_);
input_ptr_ = nullptr;
@ -112,13 +108,6 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
return RET_MEMORY_FAILED;
}
memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float));
pack_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
if (pack_output_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!";
return RET_MEMORY_FAILED;
}
memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float));
return RET_OK;
}
@ -157,7 +146,7 @@ int Convolution1x1CPUKernel::Init() {
}
int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
@ -165,23 +154,12 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_,
matmul_param_->deep_, matmul_param_->row_8_, cur_oc);
output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
matmul_param_->row_, cur_oc, matmul_param_->col_, true);
return RET_OK;
}
int Convolution1x1CPUKernel::DoConv1x1Post(int task_id) {
int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
float *src = pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_;
float *dst = output_ptr_ + task_id * thread_stride_;
Row8x8Major2RowMajor(src, dst, matmul_param_->row_, cur_oc, matmul_param_->col_);
return RET_OK;
}
int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
auto error_code = conv1x1->DoConv1x1(task_id);
@ -192,12 +170,6 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
return RET_OK;
}
int Convolution1x1Post(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
conv1x1->DoConv1x1Post(task_id);
return RET_OK;
}
int Convolution1x1CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
@ -216,8 +188,6 @@ int Convolution1x1CPUKernel::Run() {
MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]";
return RET_ERROR;
}
LiteBackendParallelLaunch(Convolution1x1Post, this, thread_count_);
}
return RET_OK;
}

View File

@ -46,7 +46,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
public:
int DoConv1x1(int task_id);
int DoConv1x1Post(int task_id);
private:
int InitConv1x1Param();
@ -61,7 +60,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
int thread_stride_ = 0;
float *weight_ptr_ = nullptr;
float *pack_input_ = nullptr;
float *pack_output_ = nullptr;
float *input_ptr_ = nullptr;
float *output_ptr_ = nullptr;
};

View File

@ -152,7 +152,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No,
matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_);
matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, matmul_param_->col_, false);
return RET_OK;
}

View File

@ -104,7 +104,7 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) {
MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_,
bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_,
cur_oc * 8);
cur_oc * 8, 0, false);
return RET_OK;
}

View File

@ -77,7 +77,7 @@ int MatmulCPUKernel::RunImpl(int task_id) {
}
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8);
MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
return RET_OK;
}

View File

@ -640,7 +640,7 @@ IndirectGemmStart:
add x15, x15, x7
str s30, [x15]
add x0, x0, #4
b WriteEnd
b WriteEndHalf
Write2:
dup s17, v16.s[1]
stp s16, s17, [x15]
@ -666,7 +666,7 @@ IndirectGemmStart:
dup s31, v30.s[1]
stp s30, s31, [x15]
add x0, x0, #8
b WriteEnd
b WriteEndHalf
Write3:
add x17, x15, #8
dup s17, v16.s[1]

View File

@ -27,7 +27,7 @@
// accumulators 8x8 block
//
///////////////////////////////////////////////////////////////////////////////
//OptLoopMul4 RM 1x8 block
//OptLoopMul4 RM 4x8 block
// /--------------------------------------------\
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|
@ -46,7 +46,8 @@
// accumulators 8x8 block
/////////////////////////////////////////////////////////////////////////////////
//
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
// int row, int col, int stride, bool write_nhwc)
// x0: a
// x1: b
// x2: c
@ -55,30 +56,30 @@
// w5: depth
// w6: row
// w7: col
// w17: stride
// w13: writeC8
MatmulFloatNeon64:
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
mov w9, #0 // rm col offset
mov w10, #0 // lm row offset
mov w18, #32 // sizeof(float)*8
mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth
mov x11, x3 // bias flag
mov w18, #32 // sizeof(float) * 8
mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
mov x11, x3 // bias flag
mov x18, #4
ldr x17, [sp]
mul x17, x17, x18
L1:
cmp w9, w7
beq End1
mov w10, w6 // reload lhs row
mov x12, x0 // reload lhs ptr
mov x18, x2 // reload dst ptr
mov w10, #0 // reset lm row offset
mov x12, x0 // reload lm ptr
L2:
cmp w10, w6
beq End2
mov x16, x1 // reload rm ptr
mov w13, w5 // reload depth
mov x14, x3 // reload bias ptr
mov x16, x1 // reload rhs ptr
mov w13, w5 // reload depth
mov x14, x3 // reload bias ptr
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
@ -96,10 +97,10 @@ L2:
dup v30.4s, wzr
dup v31.4s, wzr
OptLoopMul4:
cmp w13, #4
blt CommLoopMul
OptLoopMul4:
ld1 {v0.4s, v1.4s}, [x12], #32
ld1 {v8.4s, v9.4s}, [x16], #32
fmla v16.4s, v8.4s, v0.s[0]
@ -172,13 +173,14 @@ OptLoopMul4:
fmla v29.4s, v15.4s, v7.s[2]
fmla v30.4s, v14.4s, v7.s[3]
fmla v31.4s, v15.4s, v7.s[3]
subs w13, w13, #4
b OptLoopMul4
sub w13, w13, #4
cmp w13, #0
ble Bias
cmp w13, #4
bge OptLoopMul4
CommLoopMul:
cmp w13, #1
blt Bias
ld1 {v0.4s, v1.4s}, [x12], #32
ld1 {v2.4s, v3.4s}, [x16], #32
fmla v16.4s, v2.4s, v0.s[0]
@ -197,8 +199,9 @@ CommLoopMul:
fmla v29.4s, v3.4s, v1.s[2]
fmla v30.4s, v2.4s, v1.s[3]
fmla v31.4s, v3.4s, v1.s[3]
subs w13, w13, #1
b CommLoopMul
bgt CommLoopMul
Bias:
cbz x11, Activation
@ -226,7 +229,8 @@ Activation:
beq Relu6
cmp w4, #1
beq Relu
b TransToOut
b Write
Relu6:
mov w8, #6
dup v15.4s, w8
@ -247,6 +251,7 @@ Relu6:
fmin v29.4s, v29.4s, v15.4s
fmin v30.4s, v30.4s, v15.4s
fmin v31.4s, v31.4s, v15.4s
Relu:
dup v14.4s, wzr
fmax v16.4s, v16.4s, v14.4s
@ -266,7 +271,317 @@ Relu:
fmax v30.4s, v30.4s, v14.4s
fmax v31.4s, v31.4s, v14.4s
TransToOut:
Write:
ldrb w13, [sp, #8]
cbz w13, WriteC8
cmp w7, #1
beq Write1
cmp w7, #2
beq Write2
cmp w7, #3
beq Write3
cmp w7, #4
beq Write4
cmp w7, #5
beq Write5
cmp w7, #6
beq Write6
cmp w7, #7
beq Write7
b Write8
Write1:
str s16, [x18]
cmp w10, #1
beq WriteEnd
add x18, x18, x17
str s18, [x18]
cmp w10, #2
beq WriteEnd
add x18, x18, x17
str s20, [x18]
cmp w10, #3
beq WriteEnd
add x18, x18, x17
str s22, [x18]
cmp w10, #4
beq WriteEnd
add x18, x18, x17
str s24, [x18]
cmp w10, #5
beq WriteEnd
add x18, x18, x17
str s26, [x18]
cmp w10, #6
beq WriteEnd
add x18, x18, x17
str s28, [x18]
cmp w10, #7
beq WriteEnd
add x18, x18, x17
str s30, [x18]
add x18, x18, x17
b WriteEnd
Write2:
dup s17, v16.s[1]
stp s16, s17, [x18]
cmp w10, #1
beq WriteEnd
add x18, x18, x17
dup s19, v18.s[1]
stp s18, s19, [x18]
cmp w10, #2
beq WriteEnd
add x18, x18, x17
dup s21, v20.s[1]
stp s20, s21, [x18]
cmp w10, #3
beq WriteEnd
add x18, x18, x17
dup s23, v22.s[1]
stp s22, s23, [x18]
cmp w10, #4
beq WriteEnd
add x18, x18, x17
dup s25, v24.s[1]
stp s24, s25, [x18]
cmp w10, #5
beq WriteEnd
add x18, x18, x17
dup s27, v26.s[1]
stp s26, s27, [x18]
cmp w10, #6
beq WriteEnd
add x18, x18, x17
dup s29, v28.s[1]
stp s28, s29, [x18]
cmp w10, #7
beq WriteEnd
add x18, x18, x17
dup s31, v30.s[1]
stp s30, s31, [x18]
add x18, x18, x17
b WriteEnd
Write3:
add x13, x18, #8
dup s17, v16.s[1]
stp s16, s17, [x18]
add x18, x18, x17
st1 {v16.s}[2], [x13], x17
cmp w10, #1
beq WriteEnd
dup s19, v18.s[1]
stp s18, s19, [x18]
add x18, x18, x17
st1 {v18.s}[2], [x13], x17
cmp w10, #2
beq WriteEnd
dup s21, v20.s[1]
stp s20, s21, [x18]
add x18, x18, x17
st1 {v20.s}[2], [x13], x17
cmp w10, #3
beq WriteEnd
dup s23, v22.s[1]
stp s22, s23, [x18]
add x18, x18, x17
st1 {v22.s}[2], [x13], x17
cmp w10, #4
beq WriteEnd
dup s25, v24.s[1]
stp s24, s25, [x18]
add x18, x18, x17
st1 {v24.s}[2], [x13], x17
cmp w10, #5
beq WriteEnd
dup s27, v26.s[1]
stp s26, s27, [x18]
add x18, x18, x17
st1 {v26.s}[2], [x13], x17
cmp w10, #6
beq WriteEnd
dup s29, v28.s[1]
stp s28, s29, [x18]
add x18, x18, x17
st1 {v28.s}[2], [x13], x17
cmp w10, #7
beq WriteEnd
dup s31, v30.s[1]
stp s30, s31, [x18]
add x18, x18, x17
st1 {v30.s}[2], [x13]
b WriteEnd
Write4:
st1 {v16.4s}, [x18], x17
cmp w10, #1
beq WriteEnd
st1 {v18.4s}, [x18], x17
cmp w10, #2
beq WriteEnd
st1 {v20.4s}, [x18], x17
cmp w10, #3
beq WriteEnd
st1 {v22.4s}, [x18], x17
cmp w10, #4
beq WriteEnd
st1 {v24.4s}, [x18], x17
cmp w10, #5
beq WriteEnd
st1 {v26.4s}, [x18], x17
cmp w10, #6
beq WriteEnd
st1 {v28.4s}, [x18], x17
cmp w10, #7
beq WriteEnd
st1 {v30.4s}, [x18], x17
b WriteEnd
Write5:
add x13, x18, #16
st1 {v16.4s}, [x18], x17
str s17, [x13]
cmp w10, #1
beq WriteEnd
add x13, x13, x17
st1 {v18.4s}, [x18], x17
str s19, [x13]
cmp w10, #2
beq WriteEnd
add x13, x13, x17
st1 {v20.4s}, [x18], x17
str s21, [x13]
cmp w10, #3
beq WriteEnd
add x13, x13, x17
st1 {v22.4s}, [x18], x17
str s23, [x13]
cmp w10, #4
beq WriteEnd
add x13, x13, x17
st1 {v24.4s}, [x18], x17
str s25, [x13]
cmp w10, #5
beq WriteEnd
add x13, x13, x17
st1 {v26.4s}, [x18], x17
str s27, [x13]
cmp w10, #6
beq WriteEnd
add x13, x13, x17
st1 {v28.4s}, [x18], x17
str s29, [x13]
cmp w10, #7
beq WriteEnd
add x13, x13, x17
st1 {v30.4s}, [x18], x17
str s31, [x13]
b WriteEnd
Write6:
add x13, x18, #16
st1 {v16.4s}, [x18], x17
dup s16, v17.s[1]
stp s17, s16, [x13]
cmp w10, #1
beq WriteEnd
add x13, x13, x17
st1 {v18.4s}, [x18], x17
dup s18, v19.s[1]
stp s19, s18, [x13]
cmp w10, #2
beq WriteEnd
add x13, x13, x17
st1 {v20.4s}, [x18], x17
dup s20, v21.s[1]
stp s21, s20, [x13]
cmp w10, #3
beq WriteEnd
add x13, x13, x17
st1 {v22.4s}, [x18], x17
dup s22, v23.s[1]
stp s23, s22, [x13]
cmp w10, #4
beq WriteEnd
add x13, x13, x17
st1 {v24.4s}, [x18], x17
dup s24, v25.s[1]
stp s25, s24, [x13]
cmp w10, #5
beq WriteEnd
add x13, x13, x17
st1 {v26.4s}, [x18], x17
dup s26, v27.s[1]
stp s27, s26, [x13]
cmp w10, #6
beq WriteEnd
add x13, x13, x17
st1 {v28.4s}, [x18], x17
dup s28, v29.s[1]
stp s29, s28, [x13]
cmp w10, #7
beq WriteEnd
add x13, x13, x17
st1 {v30.4s}, [x18], x17
dup s30, v31.s[1]
stp s31, s30, [x13]
b WriteEnd
Write7:
add x13, x18, #16
add x16, x18, #24
st1 {v16.4s}, [x18], x17
dup s16, v17.s[1]
stp s17, s16, [x13]
add x13, x13, x17
st1 {v17.s}[2], [x16], x17
cmp w10, #1
beq WriteEnd
st1 {v18.4s}, [x18], x17
dup s18, v19.s[1]
stp s19, s18, [x13]
add x13, x13, x17
st1 {v19.s}[2], [x16], x17
cmp w10, #2
beq WriteEnd
st1 {v20.4s}, [x18], x17
dup s20, v21.s[1]
stp s21, s20, [x13]
add x13, x13, x17
st1 {v21.s}[2], [x16], x17
cmp w10, #3
beq WriteEnd
st1 {v22.4s}, [x18], x17
dup s22, v23.s[1]
stp s23, s22, [x13]
add x13, x13, x17
st1 {v23.s}[2], [x16], x17
cmp w10, #4
beq WriteEnd
st1 {v24.4s}, [x18], x17
dup s24, v25.s[1]
stp s25, s24, [x13]
add x13, x13, x17
st1 {v25.s}[2], [x16], x17
cmp w10, #5
beq WriteEnd
st1 {v26.4s}, [x18], x17
dup s26, v27.s[1]
stp s27, s26, [x13]
add x13, x13, x17
st1 {v27.s}[2], [x16], x17
cmp w10, #6
beq WriteEnd
st1 {v28.4s}, [x18], x17
dup s28, v29.s[1]
stp s29, s28, [x13]
add x13, x13, x17
st1 {v29.s}[2], [x16], x17
cmp w10, #7
beq WriteEnd
st1 {v30.4s}, [x18], x17
dup s30, v31.s[1]
stp s31, s30, [x13]
add x13, x13, x17
st1 {v31.s}[2], [x16], x17
b WriteEnd
WriteC8:
st1 {v16.4s}, [x2], #16
st1 {v17.4s}, [x2], #16
st1 {v18.4s}, [x2], #16
@ -283,19 +598,48 @@ TransToOut:
st1 {v29.4s}, [x2], #16
st1 {v30.4s}, [x2], #16
st1 {v31.4s}, [x2], #16
b WriteEnd
Write8:
st1 {v16.4s, v17.4s}, [x18], x17
cmp w10, #1
beq WriteEnd
st1 {v18.4s, v19.4s}, [x18], x17
cmp w10, #2
beq WriteEnd
st1 {v20.4s, v21.4s}, [x18], x17
cmp w10, #3
beq WriteEnd
st1 {v22.4s, v23.4s}, [x18], x17
cmp w10, #4
beq WriteEnd
st1 {v24.4s, v25.4s}, [x18], x17
cmp w10, #5
beq WriteEnd
st1 {v26.4s, v27.4s}, [x18], x17
cmp w10, #6
beq WriteEnd
st1 {v28.4s, v29.4s}, [x18], x17
cmp w10, #7
beq WriteEnd
st1 {v30.4s, v31.4s}, [x18], x17
add w10, w10, #8 // lm row offset + 8
b L2
WriteEnd:
subs w10, w10, #8 // lhs row - 8
bgt L2
End2:
add w9, w9, #8 // rm col offset + 8
add x1, x1, x15 // rm ptr + stride
add x3, x3, x18 // bias ptr + stride
b L1
subs w7, w7, #8 // rhs col - 8
add x1, x1, x15 // rhs ptr + stride
add x3, x3, #32 // bias ptr + stride
ldrb w13, [sp, #8]
cbz w13, NoDstStep
add x2, x2, #32 // dst ptr + stride
NoDstStep:
bgt L1
End1:
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret
#endif
#endif

View File

@ -221,34 +221,57 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
return;
}
void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
int col_8_) {
/* col8-major * row8-major => col8x8-major */
for (int row = 0; row < row_8_; row++) {
for (int col = 0; col < col_8_; col++) {
int r8div = row / 8, r8mod = row % 8;
int c8div = col / 8, c8mod = col % 8;
size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, bool write_nhwc) {
if (write_nhwc) {
/* col8-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r8div = r / 8, r8mod = r % 8;
int c8div = c / 8, c8mod = c % 8;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
} else {
/* col8-major * row8-major => col8x8-major */
int col_8 = UP_ROUND(col, C8NUM);
int row_8 = UP_ROUND(row, C8NUM);
for (int r = 0; r < row_8; r++) {
for (int c = 0; c < col_8; c++) {
int r8div = r / 8, r8mod = r % 8;
int c8div = c / 8, c8mod = c % 8;
size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
if (bias != NULL) value += bias[col];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
c[ci] = value;
}
}
return;
}
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
int col_8_) {
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
int stride, bool write_nhwc) {
#ifdef __aarch64__
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
#else
MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
#endif
}

View File

@ -26,13 +26,14 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col);
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
int stride, bool write_nhwc);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
#ifdef __aarch64__
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col);
int col, size_t stride, bool write_nhwc);
#endif
#ifdef __cplusplus
}

View File

@ -370,26 +370,35 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) {
conv1x1->Run();
CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
/* running warm up */
for (int i = 0; i < 0; i++) {
conv1x1->Run();
auto ptr = reinterpret_cast<float *>(outputs_[0]->Data());
bool first = true;
for (int i = 0; i < total_size; i++) {
if (fabs(ptr[i] - correct[i]) > 0.001 && first) {
printf("%d %f %f\n", i, ptr[i], correct[i]);
first = false;
}
}
/* running time cost */
int loop_count = 1;
auto time_start = mindspore::lite::GetTimeUs();
for (int i = 0; i < loop_count; i++) {
conv1x1->Run();
}
auto time_end = mindspore::lite::GetTimeUs();
auto cost = time_end - time_start;
uint64_t time_avg = cost / loop_count;
printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
delete conv_param;
delete conv1x1;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
// /* running warm up */
// for (int i = 0; i < 0; i++) {
// conv1x1->Run();
// }
//
// /* running time cost */
// int loop_count = 1;
// auto time_start = mindspore::lite::GetTimeUs();
// for (int i = 0; i < loop_count; i++) {
// conv1x1->Run();
// }
// auto time_end = mindspore::lite::GetTimeUs();
// auto cost = time_end - time_start;
// uint64_t time_avg = cost / loop_count;
// printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
//
// delete conv_param;
// delete conv1x1;
// for (auto t : inputs_) delete t;
// for (auto t : outputs_) delete t;
// free(correct);
}
} // namespace mindspore