forked from mindspore-Ecosystem/mindspore
commit 9afdf9be9b
@@ -31,10 +31,6 @@ Convolution1x1CPUKernel::~Convolution1x1CPUKernel() {
    free(pack_input_);
    pack_input_ = nullptr;
  }
  if (pack_output_ != nullptr) {
    free(pack_output_);
    pack_output_ = nullptr;
  }
  if (pre_trans_input_ && input_ptr_ != nullptr) {
    free(input_ptr_);
    input_ptr_ = nullptr;

@@ -112,13 +108,6 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
    return RET_MEMORY_FAILED;
  }
  memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float));

  pack_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
  if (pack_output_ == nullptr) {
    MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!";
    return RET_MEMORY_FAILED;
  }
  memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float));
  return RET_OK;
}

@@ -157,7 +146,7 @@ int Convolution1x1CPUKernel::Init() {
}

int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_);
  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
  if (cur_oc <= 0) {
    return RET_OK;
  }

@@ -165,23 +154,12 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
  auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;

  MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
         pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_,
         matmul_param_->deep_, matmul_param_->row_8_, cur_oc);
         output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
         matmul_param_->row_, cur_oc, matmul_param_->col_, true);

  return RET_OK;
}

int Convolution1x1CPUKernel::DoConv1x1Post(int task_id) {
  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
  if (cur_oc <= 0) {
    return RET_OK;
  }
  float *src = pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_;
  float *dst = output_ptr_ + task_id * thread_stride_;
  Row8x8Major2RowMajor(src, dst, matmul_param_->row_, cur_oc, matmul_param_->col_);
  return RET_OK;
}

int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
  auto error_code = conv1x1->DoConv1x1(task_id);

@@ -192,12 +170,6 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  return RET_OK;
}

int Convolution1x1Post(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
  auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
  conv1x1->DoConv1x1Post(task_id);
  return RET_OK;
}

int Convolution1x1CPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {

@@ -216,8 +188,6 @@ int Convolution1x1CPUKernel::Run() {
    MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]";
    return RET_ERROR;
  }

  LiteBackendParallelLaunch(Convolution1x1Post, this, thread_count_);
  }
  return RET_OK;
}
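Note on the hunks above: the new DoConv1x1 call hands MatMul the final output pointer (output_ptr_ + task_id * thread_stride_) with the output channel count (matmul_param_->col_) as the row stride and write_nhwc set to true, so the per-task result lands directly in the NHWC output. That is why DoConv1x1Post, Convolution1x1Post, and the pack_output_ buffer are removed. A minimal standalone sketch of that addressing follows; it is not MindSpore code, and the sizes and stored values are placeholders.

/* Standalone sketch: each task owns a slice of output columns and writes its
 * matmul result directly into the row-major (NHWC) output, using the full
 * column count as the row stride, so no col8x8-major -> row-major repack is
 * needed afterwards. */
#include <stdio.h>

#define ROW 3
#define COL 5
#define THREAD_STRIDE 2

int main(void) {
  float out[ROW * COL] = {0};
  for (int task_id = 0; task_id * THREAD_STRIDE < COL; task_id++) {
    int oc_start = task_id * THREAD_STRIDE;
    int cur_oc = COL - oc_start < THREAD_STRIDE ? COL - oc_start : THREAD_STRIDE;
    float *dst = out + oc_start; /* analogue of output_ptr_ + task_id * thread_stride_ */
    for (int r = 0; r < ROW; r++) {
      for (int c = 0; c < cur_oc; c++) {
        dst[r * COL + c] = 100.0f * task_id + 10.0f * r + c; /* stand-in for the matmul result */
      }
    }
  }
  for (int r = 0; r < ROW; r++) {
    for (int c = 0; c < COL; c++) printf("%6.1f ", out[r * COL + c]);
    printf("\n");
  }
  return 0;
}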
@@ -46,7 +46,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {

 public:
  int DoConv1x1(int task_id);
  int DoConv1x1Post(int task_id);

 private:
  int InitConv1x1Param();

@@ -61,7 +60,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
  int thread_stride_ = 0;
  float *weight_ptr_ = nullptr;
  float *pack_input_ = nullptr;
  float *pack_output_ = nullptr;
  float *input_ptr_ = nullptr;
  float *output_ptr_ = nullptr;
};
@@ -152,7 +152,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {

  MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
         tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No,
         matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_);
         matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, matmul_param_->col_, false);

  return RET_OK;
}
@@ -104,7 +104,7 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) {
  MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
         c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_,
         bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_,
         cur_oc * 8);
         cur_oc * 8, 0, false);
  return RET_OK;
}
@@ -77,7 +77,7 @@ int MatmulCPUKernel::RunImpl(int task_id) {
  }
  auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
  auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
  MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8);
  MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
  return RET_OK;
}
@@ -640,7 +640,7 @@ IndirectGemmStart:
    add x15, x15, x7
    str s30, [x15]
    add x0, x0, #4
    b WriteEnd
    b WriteEndHalf
Write2:
    dup s17, v16.s[1]
    stp s16, s17, [x15]

@@ -666,7 +666,7 @@ IndirectGemmStart:
    dup s31, v30.s[1]
    stp s30, s31, [x15]
    add x0, x0, #8
    b WriteEnd
    b WriteEndHalf
Write3:
    add x17, x15, #8
    dup s17, v16.s[1]
@@ -27,7 +27,7 @@
// accumulators 8x8 block
//
///////////////////////////////////////////////////////////////////////////////
//OptLoopMul4 RM 1x8 block
//OptLoopMul4 RM 4x8 block
// /--------------------------------------------\
// |v8.s[0] ... v8.s[3] v9.s[0] ... v9.s[3] |
// |v10.s[0] ... v10.s[3] v11.s[0] ... v11.s[3]|

@@ -46,7 +46,8 @@
// accumulators 8x8 block
/////////////////////////////////////////////////////////////////////////////////
//
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
//                        int row, int col, int stride, bool write_nhwc)
// x0: a
// x1: b
// x2: c
@@ -55,30 +56,30 @@
// w5: depth
// w6: row
// w7: col
// w17: stride
// w13: writeC8

MatmulFloatNeon64:
    sub sp, sp, #128
    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

    mov w9, #0 // rm col offset
    mov w10, #0 // lm row offset
    mov w18, #32 // sizeof(float)*8
    mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth
    mov x11, x3 // bias flag
    mov w18, #32 // sizeof(float) * 8
    mul w15, w5, w18 // block stride of lhs/rhs: sizeof(float) * 8 * depth
    mov x11, x3 // bias flag
    mov x18, #4
    ldr x17, [sp]
    mul x17, x17, x18

L1:
    cmp w9, w7
    beq End1
    mov w10, w6 // reload lhs row
    mov x12, x0 // reload lhs ptr
    mov x18, x2 // reload dst ptr

    mov w10, #0 // reset lm row offset
    mov x12, x0 // reload lm ptr
L2:
    cmp w10, w6
    beq End2

    mov x16, x1 // reload rm ptr
    mov w13, w5 // reload depth
    mov x14, x3 // reload bias ptr
    mov x16, x1 // reload rhs ptr
    mov w13, w5 // reload depth
    mov x14, x3 // reload bias ptr
    dup v16.4s, wzr
    dup v17.4s, wzr
    dup v18.4s, wzr
@@ -96,10 +97,10 @@ L2:
    dup v30.4s, wzr
    dup v31.4s, wzr

OptLoopMul4:
    cmp w13, #4
    blt CommLoopMul

OptLoopMul4:
    ld1 {v0.4s, v1.4s}, [x12], #32
    ld1 {v8.4s, v9.4s}, [x16], #32
    fmla v16.4s, v8.4s, v0.s[0]
@@ -172,13 +173,14 @@ OptLoopMul4:
    fmla v29.4s, v15.4s, v7.s[2]
    fmla v30.4s, v14.4s, v7.s[3]
    fmla v31.4s, v15.4s, v7.s[3]
    subs w13, w13, #4
    b OptLoopMul4

    sub w13, w13, #4
    cmp w13, #0
    ble Bias
    cmp w13, #4
    bge OptLoopMul4

CommLoopMul:
    cmp w13, #1
    blt Bias

    ld1 {v0.4s, v1.4s}, [x12], #32
    ld1 {v2.4s, v3.4s}, [x16], #32
    fmla v16.4s, v2.4s, v0.s[0]
@@ -197,8 +199,9 @@ CommLoopMul:
    fmla v29.4s, v3.4s, v1.s[2]
    fmla v30.4s, v2.4s, v1.s[3]
    fmla v31.4s, v3.4s, v1.s[3]

    subs w13, w13, #1
    b CommLoopMul
    bgt CommLoopMul

Bias:
    cbz x11, Activation
@@ -226,7 +229,8 @@ Activation:
    beq Relu6
    cmp w4, #1
    beq Relu
    b TransToOut
    b Write

Relu6:
    mov w8, #6
    dup v15.4s, w8
@@ -247,6 +251,7 @@ Relu6:
    fmin v29.4s, v29.4s, v15.4s
    fmin v30.4s, v30.4s, v15.4s
    fmin v31.4s, v31.4s, v15.4s

Relu:
    dup v14.4s, wzr
    fmax v16.4s, v16.4s, v14.4s
@@ -266,7 +271,317 @@ Relu:
    fmax v30.4s, v30.4s, v14.4s
    fmax v31.4s, v31.4s, v14.4s

TransToOut:
Write:
    ldrb w13, [sp, #8]
    cbz w13, WriteC8
    cmp w7, #1
    beq Write1
    cmp w7, #2
    beq Write2
    cmp w7, #3
    beq Write3
    cmp w7, #4
    beq Write4
    cmp w7, #5
    beq Write5
    cmp w7, #6
    beq Write6
    cmp w7, #7
    beq Write7
    b Write8

Write1:
    str s16, [x18]
    cmp w10, #1
    beq WriteEnd
    add x18, x18, x17
    str s18, [x18]
    cmp w10, #2
    beq WriteEnd
    add x18, x18, x17
    str s20, [x18]
    cmp w10, #3
    beq WriteEnd
    add x18, x18, x17
    str s22, [x18]
    cmp w10, #4
    beq WriteEnd
    add x18, x18, x17
    str s24, [x18]
    cmp w10, #5
    beq WriteEnd
    add x18, x18, x17
    str s26, [x18]
    cmp w10, #6
    beq WriteEnd
    add x18, x18, x17
    str s28, [x18]
    cmp w10, #7
    beq WriteEnd
    add x18, x18, x17
    str s30, [x18]
    add x18, x18, x17
    b WriteEnd
Write2:
    dup s17, v16.s[1]
    stp s16, s17, [x18]
    cmp w10, #1
    beq WriteEnd
    add x18, x18, x17
    dup s19, v18.s[1]
    stp s18, s19, [x18]
    cmp w10, #2
    beq WriteEnd
    add x18, x18, x17
    dup s21, v20.s[1]
    stp s20, s21, [x18]
    cmp w10, #3
    beq WriteEnd
    add x18, x18, x17
    dup s23, v22.s[1]
    stp s22, s23, [x18]
    cmp w10, #4
    beq WriteEnd
    add x18, x18, x17
    dup s25, v24.s[1]
    stp s24, s25, [x18]
    cmp w10, #5
    beq WriteEnd
    add x18, x18, x17
    dup s27, v26.s[1]
    stp s26, s27, [x18]
    cmp w10, #6
    beq WriteEnd
    add x18, x18, x17
    dup s29, v28.s[1]
    stp s28, s29, [x18]
    cmp w10, #7
    beq WriteEnd
    add x18, x18, x17
    dup s31, v30.s[1]
    stp s30, s31, [x18]
    add x18, x18, x17
    b WriteEnd
Write3:
    add x13, x18, #8
    dup s17, v16.s[1]
    stp s16, s17, [x18]
    add x18, x18, x17
    st1 {v16.s}[2], [x13], x17
    cmp w10, #1
    beq WriteEnd
    dup s19, v18.s[1]
    stp s18, s19, [x18]
    add x18, x18, x17
    st1 {v18.s}[2], [x13], x17
    cmp w10, #2
    beq WriteEnd
    dup s21, v20.s[1]
    stp s20, s21, [x18]
    add x18, x18, x17
    st1 {v20.s}[2], [x13], x17
    cmp w10, #3
    beq WriteEnd
    dup s23, v22.s[1]
    stp s22, s23, [x18]
    add x18, x18, x17
    st1 {v22.s}[2], [x13], x17
    cmp w10, #4
    beq WriteEnd
    dup s25, v24.s[1]
    stp s24, s25, [x18]
    add x18, x18, x17
    st1 {v24.s}[2], [x13], x17
    cmp w10, #5
    beq WriteEnd
    dup s27, v26.s[1]
    stp s26, s27, [x18]
    add x18, x18, x17
    st1 {v26.s}[2], [x13], x17
    cmp w10, #6
    beq WriteEnd
    dup s29, v28.s[1]
    stp s28, s29, [x18]
    add x18, x18, x17
    st1 {v28.s}[2], [x13], x17
    cmp w10, #7
    beq WriteEnd
    dup s31, v30.s[1]
    stp s30, s31, [x18]
    add x18, x18, x17
    st1 {v30.s}[2], [x13]
    b WriteEnd
Write4:
    st1 {v16.4s}, [x18], x17
    cmp w10, #1
    beq WriteEnd
    st1 {v18.4s}, [x18], x17
    cmp w10, #2
    beq WriteEnd
    st1 {v20.4s}, [x18], x17
    cmp w10, #3
    beq WriteEnd
    st1 {v22.4s}, [x18], x17
    cmp w10, #4
    beq WriteEnd
    st1 {v24.4s}, [x18], x17
    cmp w10, #5
    beq WriteEnd
    st1 {v26.4s}, [x18], x17
    cmp w10, #6
    beq WriteEnd
    st1 {v28.4s}, [x18], x17
    cmp w10, #7
    beq WriteEnd
    st1 {v30.4s}, [x18], x17
    b WriteEnd
Write5:
    add x13, x18, #16
    st1 {v16.4s}, [x18], x17
    str s17, [x13]
    cmp w10, #1
    beq WriteEnd
    add x13, x13, x17
    st1 {v18.4s}, [x18], x17
    str s19, [x13]
    cmp w10, #2
    beq WriteEnd
    add x13, x13, x17
    st1 {v20.4s}, [x18], x17
    str s21, [x13]
    cmp w10, #3
    beq WriteEnd
    add x13, x13, x17
    st1 {v22.4s}, [x18], x17
    str s23, [x13]
    cmp w10, #4
    beq WriteEnd
    add x13, x13, x17
    st1 {v24.4s}, [x18], x17
    str s25, [x13]
    cmp w10, #5
    beq WriteEnd
    add x13, x13, x17
    st1 {v26.4s}, [x18], x17
    str s27, [x13]
    cmp w10, #6
    beq WriteEnd
    add x13, x13, x17
    st1 {v28.4s}, [x18], x17
    str s29, [x13]
    cmp w10, #7
    beq WriteEnd
    add x13, x13, x17
    st1 {v30.4s}, [x18], x17
    str s31, [x13]
    b WriteEnd
Write6:
    add x13, x18, #16
    st1 {v16.4s}, [x18], x17
    dup s16, v17.s[1]
    stp s17, s16, [x13]
    cmp w10, #1
    beq WriteEnd
    add x13, x13, x17
    st1 {v18.4s}, [x18], x17
    dup s18, v19.s[1]
    stp s19, s18, [x13]
    cmp w10, #2
    beq WriteEnd
    add x13, x13, x17
    st1 {v20.4s}, [x18], x17
    dup s20, v21.s[1]
    stp s21, s20, [x13]
    cmp w10, #3
    beq WriteEnd
    add x13, x13, x17
    st1 {v22.4s}, [x18], x17
    dup s22, v23.s[1]
    stp s23, s22, [x13]
    cmp w10, #4
    beq WriteEnd
    add x13, x13, x17
    st1 {v24.4s}, [x18], x17
    dup s24, v25.s[1]
    stp s25, s24, [x13]
    cmp w10, #5
    beq WriteEnd
    add x13, x13, x17
    st1 {v26.4s}, [x18], x17
    dup s26, v27.s[1]
    stp s27, s26, [x13]
    cmp w10, #6
    beq WriteEnd
    add x13, x13, x17
    st1 {v28.4s}, [x18], x17
    dup s28, v29.s[1]
    stp s29, s28, [x13]
    cmp w10, #7
    beq WriteEnd
    add x13, x13, x17
    st1 {v30.4s}, [x18], x17
    dup s30, v31.s[1]
    stp s31, s30, [x13]
    b WriteEnd
Write7:
    add x13, x18, #16
    add x16, x18, #24
    st1 {v16.4s}, [x18], x17
    dup s16, v17.s[1]
    stp s17, s16, [x13]
    add x13, x13, x17
    st1 {v17.s}[2], [x16], x17
    cmp w10, #1
    beq WriteEnd
    st1 {v18.4s}, [x18], x17
    dup s18, v19.s[1]
    stp s19, s18, [x13]
    add x13, x13, x17
    st1 {v19.s}[2], [x16], x17
    cmp w10, #2
    beq WriteEnd
    st1 {v20.4s}, [x18], x17
    dup s20, v21.s[1]
    stp s21, s20, [x13]
    add x13, x13, x17
    st1 {v21.s}[2], [x16], x17
    cmp w10, #3
    beq WriteEnd
    st1 {v22.4s}, [x18], x17
    dup s22, v23.s[1]
    stp s23, s22, [x13]
    add x13, x13, x17
    st1 {v23.s}[2], [x16], x17
    cmp w10, #4
    beq WriteEnd
    st1 {v24.4s}, [x18], x17
    dup s24, v25.s[1]
    stp s25, s24, [x13]
    add x13, x13, x17
    st1 {v25.s}[2], [x16], x17
    cmp w10, #5
    beq WriteEnd
    st1 {v26.4s}, [x18], x17
    dup s26, v27.s[1]
    stp s27, s26, [x13]
    add x13, x13, x17
    st1 {v27.s}[2], [x16], x17
    cmp w10, #6
    beq WriteEnd
    st1 {v28.4s}, [x18], x17
    dup s28, v29.s[1]
    stp s29, s28, [x13]
    add x13, x13, x17
    st1 {v29.s}[2], [x16], x17
    cmp w10, #7
    beq WriteEnd
    st1 {v30.4s}, [x18], x17
    dup s30, v31.s[1]
    stp s31, s30, [x13]
    add x13, x13, x17
    st1 {v31.s}[2], [x16], x17
    b WriteEnd
WriteC8:
    st1 {v16.4s}, [x2], #16
    st1 {v17.4s}, [x2], #16
    st1 {v18.4s}, [x2], #16
|
@ -283,19 +598,48 @@ TransToOut:
|
|||
st1 {v29.4s}, [x2], #16
|
||||
st1 {v30.4s}, [x2], #16
|
||||
st1 {v31.4s}, [x2], #16
|
||||
b WriteEnd
|
||||
Write8:
|
||||
st1 {v16.4s, v17.4s}, [x18], x17
|
||||
cmp w10, #1
|
||||
beq WriteEnd
|
||||
st1 {v18.4s, v19.4s}, [x18], x17
|
||||
cmp w10, #2
|
||||
beq WriteEnd
|
||||
st1 {v20.4s, v21.4s}, [x18], x17
|
||||
cmp w10, #3
|
||||
beq WriteEnd
|
||||
st1 {v22.4s, v23.4s}, [x18], x17
|
||||
cmp w10, #4
|
||||
beq WriteEnd
|
||||
st1 {v24.4s, v25.4s}, [x18], x17
|
||||
cmp w10, #5
|
||||
beq WriteEnd
|
||||
st1 {v26.4s, v27.4s}, [x18], x17
|
||||
cmp w10, #6
|
||||
beq WriteEnd
|
||||
st1 {v28.4s, v29.4s}, [x18], x17
|
||||
cmp w10, #7
|
||||
beq WriteEnd
|
||||
st1 {v30.4s, v31.4s}, [x18], x17
|
||||
|
||||
add w10, w10, #8 // lm row offset + 8
|
||||
b L2
|
||||
WriteEnd:
|
||||
subs w10, w10, #8 // lhs row - 8
|
||||
bgt L2
|
||||
|
||||
End2:
|
||||
add w9, w9, #8 // rm col offset + 8
|
||||
add x1, x1, x15 // rm ptr + stride
|
||||
add x3, x3, x18 // bias ptr + stride
|
||||
b L1
|
||||
subs w7, w7, #8 // rhs col - 8
|
||||
add x1, x1, x15 // rhs ptr + stride
|
||||
add x3, x3, #32 // bias ptr + stride
|
||||
ldrb w13, [sp, #8]
|
||||
cbz w13, NoDstStep
|
||||
add x2, x2, #32 // dst ptr + stride
|
||||
NoDstStep:
|
||||
bgt L1
|
||||
|
||||
End1:
|
||||
sub sp, sp, #128
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ret
|
||||
#endif
|
||||
#endif
|
|
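For readers following the new loop control above (subs w7, w7, #8 over column blocks in End2, subs w10, w10, #8 over row blocks in WriteEnd, with Write1..Write7 chosen for column tails and the per-row cmp w10, #N checks cutting off row tails), the sketch below is a rough C rendering of that 8x8 tiling. It is an illustration of the control flow only, not the actual kernel; the names and sizes are made up.

/* Rough C rendering of the 8x8 tiling and tail handling used by the
 * assembly above (illustration only; not the real kernel). */
#include <stdio.h>

#define ROW 10 /* arbitrary sizes with tails in both dimensions */
#define COL 11

int main(void) {
  int blocks = 0;
  for (int col_left = COL, c0 = 0; col_left > 0; col_left -= 8, c0 += 8) {   /* outer loop: L1 */
    int cols = col_left < 8 ? col_left : 8;   /* picks Write1..Write8 in the asm */
    for (int row_left = ROW, r0 = 0; row_left > 0; row_left -= 8, r0 += 8) { /* inner loop: L2 */
      int rows = row_left < 8 ? row_left : 8; /* the cmp w10, #N / beq WriteEnd checks */
      printf("block at (%d,%d): %d rows x %d cols\n", r0, c0, rows, cols);
      blocks++;
    }
  }
  printf("%d blocks\n", blocks);
  return 0;
}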
@@ -221,34 +221,57 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
  return;
}

void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
               int col_8_) {
  /* col8-major * row8-major => col8x8-major */
  for (int row = 0; row < row_8_; row++) {
    for (int col = 0; col < col_8_; col++) {
      int r8div = row / 8, r8mod = row % 8;
      int c8div = col / 8, c8mod = col % 8;
      size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod;
      float value = 0;
      for (int d = 0; d < deep; d++) {
        size_t ai = r8div * deep * 8 + d * 8 + r8mod;
        size_t bi = c8div * deep * 8 + d * 8 + c8mod;
        value = value + a[ai] * b[bi];
void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
               int col, int stride, bool write_nhwc) {
  if (write_nhwc) {
    /* col8-major * row8-major => col-major */
    for (int r = 0; r < row; r++) {
      for (int c = 0; c < col; c++) {
        int r8div = r / 8, r8mod = r % 8;
        int c8div = c / 8, c8mod = c % 8;
        size_t ci = r * stride + c;
        float value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = r8div * deep * 8 + d * 8 + r8mod;
          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
          value = value + a[ai] * b[bi];
        }
        if (bias != NULL) value += bias[c];
        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
        if (act_type != ActType_No) value = MSMAX(0.0f, value);
        dst[ci] = value;
      }
    }
  } else {
    /* col8-major * row8-major => col8x8-major */
    int col_8 = UP_ROUND(col, C8NUM);
    int row_8 = UP_ROUND(row, C8NUM);
    for (int r = 0; r < row_8; r++) {
      for (int c = 0; c < col_8; c++) {
        int r8div = r / 8, r8mod = r % 8;
        int c8div = c / 8, c8mod = c % 8;
        size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
        float value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = r8div * deep * 8 + d * 8 + r8mod;
          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
          value = value + a[ai] * b[bi];
        }
        if (bias != NULL) value += bias[c];
        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
        if (act_type != ActType_No) value = MSMAX(0.0f, value);
        dst[ci] = value;
      }
      if (bias != NULL) value += bias[col];
      if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
      if (act_type != ActType_No) value = MSMAX(0.0f, value);
      c[ci] = value;
    }
  }
  return;
}

void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
            int col_8_) {
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
            int stride, bool write_nhwc) {
#ifdef __aarch64__
  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
#else
  MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
  MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
#endif
}
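The C fallback above now supports two output layouts: with write_nhwc, element (r, c) is stored at r * stride + c, while the col8x8-major path keeps the blocked index c8div * row_8 * 8 + r * 8 + c8mod; in both cases the left matrix is read in col8-major and the right matrix in row8-major order. The following self-contained sketch checks that index arithmetic against a naive matmul; it packs plain row-major inputs itself and is an illustration only, not the library code.

/* Self-contained check of the packed-index arithmetic used above
 * (illustration only): pack row-major A and B into the col8-major /
 * row8-major layouts, multiply with the same ai/bi/ci formulas, and
 * compare with a naive matmul. */
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define UP8(x) (((x) + 7) / 8 * 8)

int main(void) {
  const int row = 3, col = 5, deep = 4;
  float A[3 * 4], B[4 * 5], ref[3 * 5], out[3 * 5];
  for (int i = 0; i < row * deep; i++) A[i] = (float)(i % 7) - 3.0f;
  for (int i = 0; i < deep * col; i++) B[i] = (float)(i % 5) * 0.5f;

  /* pack: A (row x deep) -> col8-major, B (deep x col) -> row8-major */
  float *a8 = calloc(UP8(row) * deep, sizeof(float));
  float *b8 = calloc(UP8(col) * deep, sizeof(float));
  for (int r = 0; r < row; r++)
    for (int d = 0; d < deep; d++) a8[(r / 8) * deep * 8 + d * 8 + r % 8] = A[r * deep + d];
  for (int d = 0; d < deep; d++)
    for (int c = 0; c < col; c++) b8[(c / 8) * deep * 8 + d * 8 + c % 8] = B[d * col + c];

  /* same index formulas as the write_nhwc branch, stride = col */
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      float v = 0;
      for (int d = 0; d < deep; d++)
        v += a8[(r / 8) * deep * 8 + d * 8 + r % 8] * b8[(c / 8) * deep * 8 + d * 8 + c % 8];
      out[r * col + c] = v;
    }
  }

  /* naive reference */
  float max_diff = 0;
  for (int r = 0; r < row; r++)
    for (int c = 0; c < col; c++) {
      float v = 0;
      for (int d = 0; d < deep; d++) v += A[r * deep + d] * B[d * col + c];
      ref[r * col + c] = v;
      float diff = fabsf(out[r * col + c] - ref[r * col + c]);
      if (diff > max_diff) max_diff = diff;
    }
  printf("max diff vs naive matmul: %g\n", max_diff);
  free(a8);
  free(b8);
  return 0;
}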
@@ -26,13 +26,14 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col);
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
            int stride, bool write_nhwc);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
#ifdef __aarch64__
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                       int col);
                       int col, size_t stride, bool write_nhwc);
#endif
#ifdef __cplusplus
}
@@ -370,26 +370,35 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) {
  conv1x1->Run();
  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);

  /* running warm up */
  for (int i = 0; i < 0; i++) {
    conv1x1->Run();
  auto ptr = reinterpret_cast<float *>(outputs_[0]->Data());
  bool first = true;
  for (int i = 0; i < total_size; i++) {
    if (fabs(ptr[i] - correct[i]) > 0.001 && first) {
      printf("%d %f %f\n", i, ptr[i], correct[i]);
      first = false;
    }
  }

  /* running time cost */
  int loop_count = 1;
  auto time_start = mindspore::lite::GetTimeUs();
  for (int i = 0; i < loop_count; i++) {
    conv1x1->Run();
  }
  auto time_end = mindspore::lite::GetTimeUs();
  auto cost = time_end - time_start;
  uint64_t time_avg = cost / loop_count;
  printf("1x1 average time : %f ms\n", time_avg / 1000.0f);

  delete conv_param;
  delete conv1x1;
  for (auto t : inputs_) delete t;
  for (auto t : outputs_) delete t;
  free(correct);
  // /* running warm up */
  // for (int i = 0; i < 0; i++) {
  //   conv1x1->Run();
  // }
  //
  // /* running time cost */
  // int loop_count = 1;
  // auto time_start = mindspore::lite::GetTimeUs();
  // for (int i = 0; i < loop_count; i++) {
  //   conv1x1->Run();
  // }
  // auto time_end = mindspore::lite::GetTimeUs();
  // auto cost = time_end - time_start;
  // uint64_t time_avg = cost / loop_count;
  // printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
  //
  // delete conv_param;
  // delete conv1x1;
  // for (auto t : inputs_) delete t;
  // for (auto t : outputs_) delete t;
  // free(correct);
}
}  // namespace mindspore