add matmul fp32 kernel on arm32

This commit is contained in:
lixian 2020-09-11 20:44:39 +08:00
parent 939737c017
commit 902f08be82
16 changed files with 715 additions and 31 deletions

View File

@ -0,0 +1,367 @@
#ifdef ENABLE_ARM32
    .text
    .align 5
    .global MatmulFloatNeon32Opt
#ifndef __APPLE__
    .type MatmulFloatNeon32Opt, %function
#endif

// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
//                           int row, int col, size_t stride, size_t writeNhwc, size_t writeWino)
//
// fp32 matmul micro-kernel working on tiles of 4 lhs rows x 8 rhs columns.
// Each depth step reads 4 floats from a and 8 floats from b, so a is expected to be packed in
// 4-row blocks and b in 8-column blocks. The optional bias (8 floats per column block) is added
// to every row, then the activation selected by act_type is applied (1: relu, 2: relu6,
// anything else: none) and the tile is written in one of three layouts:
//   writeWino != 0 : 8 floats per row, rows are stride * col floats apart
//   writeNhwc != 0 : only the valid rows/columns are written, rows are stride floats apart
//   otherwise (C8) : the 4 x 8 tile is written contiguously, tiles follow each other
//
// Register use:
//   r0: a (walks through the packed lhs)      r1: b (current rhs block)
//   r2: c (current dst position)              r3: bias (NULL means no bias)
//   r4: scratch dst pointer                   r5: depth counter
//   r6: remaining rows                        r7: remaining cols
//   r8: stride * sizeof(float)                r10: wino dst step per column block
//   r11: wino dst step per row                r12: rhs step per column block
//   lr: scratch
//
// Stack frame after the prologue. sp stays at the bottom of the saved registers so that nothing
// this function reads or writes lives below sp (memory below sp may be overwritten
// asynchronously, for example by a signal frame):
//   [sp, #0]  a           [sp, #4]  b (advanced per column block)
//   [sp, #8]  c (advanced per column block)
//   [sp, #12] bias        [sp, #16] .. [sp, #44] saved r4-r8, r10, r11, lr
//   [sp, #48] act_type    [sp, #52] depth       [sp, #56] row
//   [sp, #60] col         [sp, #64] stride      [sp, #68] writeNhwc
//   [sp, #72] writeWino
MatmulFloatNeon32Opt:
    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
    // q4-q7 are never used below, so only core registers are pushed. r0-r3 are pushed as well to
    // give a, b, c reloadable stack slots.
    push {r0-r8, r10, r11, lr}

    ldr r5, [sp, #52]           // depth
    ldr r7, [sp, #60]           // col
    ldr r8, [sp, #64]           // stride
    mov lr, #32                 // sizeof(float) * 8
    mul r12, r5, lr             // block stride of rhs: sizeof(float) * 8 * depth
    ldr lr, [sp, #72]           // writeWino
    cmp lr, #0
    beq NoWinoSteps
    mov lr, #4
    mul r11, r7, r8             // stride * col * sizeof(float)
    mul r11, r11, lr
    mov lr, #32
    mul r10, r8, lr             // stride * 8 * sizeof(float)
NoWinoSteps:
    mov lr, #4
    mul r8, r8, lr              // stride * sizeof(float)

LoopCol:
    ldr r6, [sp, #56]           // reload lhs row
    ldr r0, [sp]                // reload lhs ptr
    ldr r2, [sp, #8]            // reload dst ptr

LoopRow:
    ldr r1, [sp, #4]            // reload rhs ptr
    ldr r5, [sp, #52]           // reload depth
    // q8/q9: row 0, q10/q11: row 1, q12/q13: row 2, q14/q15: row 3 (cols 0-3 / cols 4-7)
    veor q8, q8, q8
    veor q9, q9, q9
    veor q10, q10, q10
    veor q11, q11, q11
    veor q12, q12, q12
    veor q13, q13, q13
    veor q14, q14, q14
    veor q15, q15, q15

LoopDepth:
    vld1.32 {q0}, [r0]!         // 4 lhs values, one per row
    vld1.32 {q1, q2}, [r1]!     // 8 rhs values, one per column
    vmla.f32 q8, q1, d0[0]
    vmla.f32 q9, q2, d0[0]
    vmla.f32 q10, q1, d0[1]
    vmla.f32 q11, q2, d0[1]
    vmla.f32 q12, q1, d1[0]
    vmla.f32 q13, q2, d1[0]
    vmla.f32 q14, q1, d1[1]
    vmla.f32 q15, q2, d1[1]
    subs r5, r5, #1
    bne LoopDepth

Bias:
    cmp r3, #0
    beq Activation
    vld1.32 {q0}, [r3]!
    vld1.32 {q1}, [r3]
    sub r3, r3, #16             // undo the post-increment, bias is reused for every row block
    vadd.f32 q8, q8, q0
    vadd.f32 q9, q9, q1
    vadd.f32 q10, q10, q0
    vadd.f32 q11, q11, q1
    vadd.f32 q12, q12, q0
    vadd.f32 q13, q13, q1
    vadd.f32 q14, q14, q0
    vadd.f32 q15, q15, q1

Activation:
    ldr lr, [sp, #48]           // act_type
    cmp lr, #2
    beq Relu6
    cmp lr, #1
    beq Relu
    b Write

Relu6:
    vmov.i32 q2, #6
    vcvt.f32.s32 q2, q2
    vmin.f32 q8, q8, q2
    vmin.f32 q9, q9, q2
    vmin.f32 q10, q10, q2
    vmin.f32 q11, q11, q2
    vmin.f32 q12, q12, q2
    vmin.f32 q13, q13, q2
    vmin.f32 q14, q14, q2
    vmin.f32 q15, q15, q2
    // relu6 falls through into relu for the lower clamp

Relu:
    veor q3, q3, q3
    vmax.f32 q8, q8, q3
    vmax.f32 q9, q9, q3
    vmax.f32 q10, q10, q3
    vmax.f32 q11, q11, q3
    vmax.f32 q12, q12, q3
    vmax.f32 q13, q13, q3
    vmax.f32 q14, q14, q3
    vmax.f32 q15, q15, q3

Write:
    ldr lr, [sp, #72]           // writeWino
    cmp lr, #0
    bne WriteWino
    ldr lr, [sp, #68]           // writeNhwc
    cmp lr, #0
    beq WriteC8
    // nhwc: dispatch on the number of columns left in this block
    cmp r7, #1
    beq Write1
    cmp r7, #2
    beq Write2
    cmp r7, #3
    beq Write3
    cmp r7, #4
    beq Write4
    cmp r7, #5
    beq Write5
    cmp r7, #6
    beq Write6
    cmp r7, #7
    beq Write7
    b Write8

Write1:
    vst1.32 d16[0], [r2]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    vst1.32 d20[0], [r2]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    vst1.32 d24[0], [r2]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    vst1.32 d28[0], [r2]
    add r2, r2, r8
    b WriteEnd

Write2:
    vst1.32 d16, [r2]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    vst1.32 d20, [r2]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    vst1.32 d24, [r2]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    vst1.32 d28, [r2]
    add r2, r2, r8
    b WriteEnd

Write3:
    add r4, r2, #8
    vst1.32 d16, [r2]
    vst1.32 d17[0], [r4]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 d20, [r2]
    vst1.32 d21[0], [r4]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 d24, [r2]
    vst1.32 d25[0], [r4]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 d28, [r2]
    vst1.32 d29[0], [r4]
    add r2, r2, r8
    b WriteEnd

Write4:
    vst1.32 q8, [r2]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    vst1.32 q10, [r2]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    vst1.32 q12, [r2]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    vst1.32 q14, [r2]
    add r2, r2, r8
    b WriteEnd

Write5:
    add r4, r2, #16
    vst1.32 q8, [r2]
    vst1.32 d18[0], [r4]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 q10, [r2]
    vst1.32 d22[0], [r4]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 q12, [r2]
    vst1.32 d26[0], [r4]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 q14, [r2]
    vst1.32 d30[0], [r4]
    add r2, r2, r8
    b WriteEnd

Write6:
    add r4, r2, #16
    vst1.32 q8, [r2]
    vst1.32 d18, [r4]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 q10, [r2]
    vst1.32 d22, [r4]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 q12, [r2]
    vst1.32 d26, [r4]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    vst1.32 q14, [r2]
    vst1.32 d30, [r4]
    add r2, r2, r8
    b WriteEnd

Write7:
    add lr, r2, #24
    add r4, r2, #16
    vst1.32 q8, [r2]
    vst1.32 d18, [r4]
    vst1.32 d19[0], [lr]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    add lr, lr, r8
    vst1.32 q10, [r2]
    vst1.32 d22, [r4]
    vst1.32 d23[0], [lr]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    add lr, lr, r8
    vst1.32 q12, [r2]
    vst1.32 d26, [r4]
    vst1.32 d27[0], [lr]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    add r4, r4, r8
    add lr, lr, r8
    vst1.32 q14, [r2]
    vst1.32 d30, [r4]
    vst1.32 d31[0], [lr]
    add r2, r2, r8
    b WriteEnd

WriteC8:
    vst1.32 {q8, q9}, [r2]!
    vst1.32 {q10, q11}, [r2]!
    vst1.32 {q12, q13}, [r2]!
    vst1.32 {q14, q15}, [r2]!
    // LoopCol reloads r2 from its stack slot and no dst step is applied for the C8 layout,
    // so persist the advanced pointer; otherwise every column block would overwrite the first.
    str r2, [sp, #8]
    b WriteEnd

WriteWino:
    vst1.32 {q8, q9}, [r2]
    add r2, r2, r11
    vst1.32 {q10, q11}, [r2]
    add r2, r2, r11
    vst1.32 {q12, q13}, [r2]
    add r2, r2, r11
    vst1.32 {q14, q15}, [r2]
    add r2, r2, r11
    b WriteEnd

Write8:
    vst1.32 {q8, q9}, [r2]
    cmp r6, #1
    beq WriteEnd
    add r2, r2, r8
    vst1.32 {q10, q11}, [r2]
    cmp r6, #2
    beq WriteEnd
    add r2, r2, r8
    vst1.32 {q12, q13}, [r2]
    cmp r6, #3
    beq WriteEnd
    add r2, r2, r8
    vst1.32 {q14, q15}, [r2]
    add r2, r2, r8

WriteEnd:
    cmp r6, #4
    ble LoopRowEnd
    sub r6, r6, #4              // lhs row - 4
    b LoopRow

LoopRowEnd:
    ldr r1, [sp, #4]
    add r1, r1, r12             // rhs ptr + stride
    str r1, [sp, #4]
    cmp r3, #0
    beq NoBiasStep
    add r3, r3, #32             // bias ptr + stride

NoBiasStep:
    ldr lr, [sp, #72]           // writeWino
    cmp lr, #0
    bne WinoDstStep
    ldr lr, [sp, #68]           // writeNhwc
    cmp lr, #0
    beq NoDstStep               // C8: dst slot was already updated in WriteC8
    ldr r2, [sp, #8]
    add r2, r2, #32             // dst ptr + stride
    str r2, [sp, #8]
    b NoDstStep

WinoDstStep:
    ldr r2, [sp, #8]
    add r2, r2, r10
    str r2, [sp, #8]

NoDstStep:
    cmp r7, #8
    ble LoopColEnd
    sub r7, r7, #8              // rhs col - 8
    b LoopCol

LoopColEnd:
    pop {r0-r8, r10, r11, pc}
#endif

View File

@ -112,7 +112,8 @@ void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight
}
}
#endif
// #ifndef ENABLE_ARM32
#ifndef ENABLE_ARM32
void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step,
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu,
size_t relu6) {
@ -155,7 +156,7 @@ void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight
}
}
}
// #endif
#endif
int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }

View File

@ -270,7 +270,12 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, C12NUM);
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
int tile_num = 12;
#endif
int output_tile_count = UP_DIV(output_count, tile_num);
int out_channel = conv_param->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);
@ -281,19 +286,19 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
float *tmp_out_data = buffer_list[2];
float *tmp_data = buffer_list[3];
float *col_buffer = buffer_list[4];
int trans_input_offset = C12NUM * input_unit_square * ic4 * C4NUM;
int gemm_out_offset = C12NUM * input_unit_square * oc8 * C8NUM;
int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C4NUM;
int col_buffer_offset = C12NUM * ic4 * C4NUM;
int col_buffer_offset = tile_num * ic4 * C4NUM;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * C12NUM;
int cal_num = output_count - thread_id * C12NUM;
cal_num = cal_num > C12NUM ? C12NUM : cal_num;
int out_tile_index = thread_id * tile_num;
int cal_num = output_count - thread_id * tile_num;
cal_num = cal_num > tile_num ? tile_num : cal_num;
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
in_func);
@ -302,7 +307,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM32
RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM);
#else
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
#endif
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
cal_num, oc8 * C8NUM, input_unit_square, 2);
}
@ -460,7 +469,12 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, C12NUM);
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
int tile_num = 12;
#endif
int output_tile_count = UP_DIV(output_count, tile_num);
const int input_unit_square = 4 * 4;
float *tile_buffer = buffer_list[0];
@ -468,10 +482,10 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
float *tmp_dst_buffer = buffer_list[2];
float *nc4hw4_out = buffer_list[3];
float *col_buffer = buffer_list[4];
int tile_buffer_offset = C12NUM * input_unit_square * ic4 * C4NUM;
int tile_buffer_offset = tile_num * input_unit_square * ic4 * C4NUM;
int block_unit_buffer_offset = input_unit_square * C4NUM;
int tmp_dst_buffer_offset = C12NUM * input_unit_square * oc8 * C8NUM;
int col_buffer_offset = C12NUM * ic4 * C4NUM;
int tmp_dst_buffer_offset = tile_num * input_unit_square * oc8 * C8NUM;
int col_buffer_offset = tile_num * ic4 * C4NUM;
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
@ -479,8 +493,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
int start_index = thread_id * C12NUM;
int real_cal_num = (output_count - start_index) < C12NUM ? (output_count - start_index) : C12NUM;
int start_index = thread_id * tile_num;
int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
@ -489,7 +503,11 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM32
RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM);
#else
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
#endif
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
}

View File

@ -40,7 +40,12 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_ROUND(output_channel, C8NUM);
int in_plane12 = UP_ROUND(input_plane, C12NUM);
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
int tile_num = 12;
#endif
int in_plane12 = UP_ROUND(input_plane, tile_num);
int src_iw_stride = C8NUM;
int src_ih_stride = conv_param->input_w_ * C8NUM;
int src_kw_stride = in_plane12 * C8NUM;

View File

@ -16,6 +16,18 @@
#include "nnacl/fp32/matmul.h"
/*
 * Repacks a row-major row x col matrix into blocks of 4 columns.
 *
 * Element (r, c) of src_ptr is stored at
 *   dst_ptr[(c / 4) * 4 * row + r * 4 + (c % 4)],
 * i.e. each group of 4 consecutive columns becomes a contiguous row x 4 panel.
 * When col is not a multiple of 4, the unused lanes of the last panel are left untouched.
 */
void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col) {
  const int panel_size = 4 * row; /* floats occupied by one 4-column panel */
  for (int r = 0; r < row; r++) {
    const float *src_row = src_ptr + r * col;
    float *dst_row = dst_ptr + r * 4;
    for (int c = 0; c < col; c++) {
      const int panel = c / 4;
      const int lane = c % 4;
      dst_row[panel * panel_size + lane] = src_row[c];
    }
  }
}
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) {
float *src = src_ptr + r * col;
@ -115,6 +127,61 @@ void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#elif ENABLE_ARM32
size_t stride = col * sizeof(float);
asm volatile(
"mov r10, %[src_c]\n"
"mov r12, %[dst_c]\n"
"vld1.32 {q0}, [r10], %[stride]\n"
"vld1.32 {q3}, [r10], %[stride]\n"
"vld1.32 {q10}, [r10], %[stride]\n"
"vld1.32 {q13}, [r10], %[stride]\n"
"vtrn.32 d0, d6\n"
"vtrn.32 d1, d7\n"
"vtrn.32 d20, d26\n"
"vtrn.32 d21, d27\n"
"vld1.32 {q1}, [r10], %[stride]\n"
"vld1.32 {q8}, [r10], %[stride]\n"
"vld1.32 {q11}, [r10], %[stride]\n"
"vld1.32 {q14}, [r10], %[stride]\n"
"vswp d1, d20\n"
"vswp d7, d26\n"
"vld1.32 {q2}, [r10], %[stride]\n"
"vld1.32 {q9}, [r10], %[stride]\n"
"vld1.32 {q12}, [r10], %[stride]\n"
"vld1.32 {q15}, [r10], %[stride]\n"
"vtrn.32 d2, d16\n"
"vtrn.32 d3, d17\n"
"vtrn.32 d22, d28\n"
"vtrn.32 d23, d29\n"
"vswp d3, d22\n"
"vswp d17, d28\n"
"vtrn.32 d4, d18\n"
"vtrn.32 d5, d19\n"
"vtrn.32 d24, d30\n"
"vtrn.32 d25, d31\n"
"vswp d5, d24\n"
"vswp d19, d30\n"
"vst1.32 {q0, q1}, [r12]!\n"
"vst1.32 {q2, q3}, [r12]!\n"
"vst1.32 {q8, q9}, [r12]!\n"
"vst1.32 {q10, q11}, [r12]!\n"
"vst1.32 {q12, q13}, [r12]!\n"
"vst1.32 {q14, q15}, [r12]!\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#else
for (int tr = 0; tr < C12NUM; tr++) {
for (int tc = 0; tc < C4NUM; tc++) {
@ -242,6 +309,75 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
return;
}
/*
 * Repacks a row-major row x col matrix into blocks of 4 rows stored column-major.
 *
 * For every full block of 4 rows, element (r, c) is written to
 *   dst_ptr[(r / 4) * 4 * col + c * 4 + (r % 4)].
 * The remaining (row % 4) rows are scattered into the leading lanes of one more block;
 * the other lanes of that block are not written by this function.
 *
 * src_ptr: row-major input, row * col floats.
 * dst_ptr: output buffer, must hold at least UP_ROUND(row, 4) * col floats.
 */
void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
  size_t row4 = row / C4NUM * C4NUM;
  size_t col4 = col / C4NUM * C4NUM;
  float *src_r = src_ptr;
  float *dst_r = dst_ptr;

  size_t ri = 0;
  for (; ri < row4; ri += C4NUM) {
    size_t ci = 0;
    for (; ci < col4; ci += C4NUM) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C4NUM;

      /* 4x4 row-major to col-major */
#ifdef ENABLE_ARM32
      size_t stride = col * sizeof(float);
      /* The asm reads src_c and writes dst_c through r10/r12, which the compiler cannot see
       * from the operand list; the "memory" clobber keeps it from caching or reordering
       * accesses to those buffers across this statement. */
      asm volatile(
        "mov r10, %[src_c]\n"
        "mov r12, %[dst_c]\n"

        "vld1.32 {q0}, [r10], %[stride]\n"
        "vld1.32 {q1}, [r10], %[stride]\n"
        "vld1.32 {q2}, [r10], %[stride]\n"
        "vld1.32 {q3}, [r10], %[stride]\n"

        "vtrn.32 d0, d2\n"
        "vtrn.32 d1, d3\n"
        "vtrn.32 d4, d6\n"
        "vtrn.32 d5, d7\n"

        "vswp d1, d4\n"
        "vswp d3, d6\n"

        "vst1.32 {q0}, [r12]!\n"
        "vst1.32 {q1}, [r12]!\n"
        "vst1.32 {q2}, [r12]!\n"
        "vst1.32 {q3}, [r12]!\n"

        :
        : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
        : "r10", "r12", "q0", "q1", "q2", "q3", "memory");
#else
      for (int tr = 0; tr < C4NUM; tr++) {
        for (int tc = 0; tc < C4NUM; tc++) {
          dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
        }
      }
#endif
    }
    /* leftover columns of this 4-row block, one column (4 values) at a time */
    for (; ci < col; ci++) {
      float *src_c = src_r + ci;
      float *dst_c = dst_r + ci * C4NUM;
      for (size_t i = 0; i < C4NUM; i++) {
        dst_c[i] = src_c[i * col];
      }
    }
    src_r += C4NUM * col;
    dst_r += C4NUM * col;
  }
  /* leftover rows: each one fills a single lane of the last block */
  for (; ri < row; ri++) {
    for (size_t i = 0; i < col; i++) {
      dst_r[i * C4NUM] = src_r[i];
    }
    src_r += col;
    dst_r += 1;
  }
  return;
}
void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride,
size_t data_lenth) {
size_t copy_size = col * data_lenth;
@ -418,6 +554,9 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_TileC8));
}
#elif ENABLE_ARM32
MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_TileC8));
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif

View File

@ -29,8 +29,10 @@ extern "C" {
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type);
void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
@ -40,6 +42,9 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t write_nhwc, size_t write_c4);
void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
#elif ENABLE_ARM32
void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t write_nhwc, size_t write_c4);
#endif
#ifdef __cplusplus
}

View File

@ -1223,6 +1223,78 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
#elif ENABLE_ARM32
size_t srcStride = channel * sizeof(float);
size_t dstStride = plane * sizeof(float);
asm volatile(
"mov r10, %[src_ptr]\n"
"mov r12, %[dst_ptr]\n"
"vld1.32 {q0, q1}, [r10], %[srcStride]\n"
"vld1.32 {q2, q3}, [r10], %[srcStride]\n"
"vtrn.32 d0, d4\n"
"vtrn.32 d1, d5\n"
"vtrn.32 d2, d6\n"
"vtrn.32 d3, d7\n"
"vld1.32 {q4, q5}, [r10], %[srcStride]\n"
"vld1.32 {q6, q7}, [r10], %[srcStride]\n"
"vtrn.32 d8, d12\n"
"vtrn.32 d9, d13\n"
"vtrn.32 d10, d14\n"
"vtrn.32 d11, d15\n"
"vld1.32 {q8, q9}, [r10], %[srcStride]\n"
"vld1.32 {q10, q11}, [r10], %[srcStride]\n"
"vswp d1, d8\n"
"vswp d3, d10\n"
"vswp d5, d12\n"
"vswp d7, d14\n"
"vtrn.32 d16, d20\n"
"vtrn.32 d17, d21\n"
"vtrn.32 d18, d22\n"
"vtrn.32 d19, d23\n"
"vld1.32 {q12, q13}, [r10], %[srcStride]\n"
"vld1.32 {q14, q15}, [r10], %[srcStride]\n"
"vtrn.32 d24, d28\n"
"vtrn.32 d25, d29\n"
"vtrn.32 d26, d30\n"
"vtrn.32 d27, d31\n"
"vswp d17, d24\n"
"vswp d19, d26\n"
"vswp d21, d28\n"
"vswp d23, d30\n"
"add r10, r12, #16\n"
"vst1.32 {q0}, [r12], %[dstStride]\n"
"vst1.32 {q8}, [r10], %[dstStride]\n"
"vst1.32 {q2}, [r12], %[dstStride]\n"
"vst1.32 {q10}, [r10], %[dstStride]\n"
"vst1.32 {q4}, [r12], %[dstStride]\n"
"vst1.32 {q12}, [r10], %[dstStride]\n"
"vst1.32 {q6}, [r12], %[dstStride]\n"
"vst1.32 {q14}, [r10], %[dstStride]\n"
"vst1.32 {q1}, [r12], %[dstStride]\n"
"vst1.32 {q9}, [r10], %[dstStride]\n"
"vst1.32 {q3}, [r12], %[dstStride]\n"
"vst1.32 {q11}, [r10], %[dstStride]\n"
"vst1.32 {q5}, [r12], %[dstStride]\n"
"vst1.32 {q13}, [r10], %[dstStride]\n"
"vst1.32 {q7}, [r12], %[dstStride]\n"
"vst1.32 {q15}, [r10], %[dstStride]\n"
:
:
[ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
: "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15");
#else
for (int tr = 0; tr < C8NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {

View File

@ -67,8 +67,13 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
}
}
// input transform
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
int tile_num = 12;
#endif
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
size_t dst_step = C12NUM * ic4 * C4NUM;
size_t dst_step = tile_num * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
func(tmp_data, trans_input_ptr, C4NUM, dst_step);
// GeneralInputTransformUnit(tmp_data, trans_input_ptr, matrix_b, matrix_bt, C4NUM, dst_step, input_unit);
@ -331,8 +336,13 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa
}
// input transform
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
int tile_num = 12;
#endif
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
size_t dst_step = C12NUM * ic4 * C4NUM;
size_t dst_step = tile_num * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step);
}

View File

@ -26,15 +26,13 @@ if (PLATFORM_ARM64)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()
#[[
if (PLATFORM_ARM32)
# assembly
file(GLOB ASSEMBLY_SRC nnacl/assembly/arm32/*.s
nnacl/assembly/arm32/*.S
file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../../../nnacl/assembly/arm32/*.s
${CMAKE_CURRENT_SOURCE_DIR}/../../../../nnacl/assembly/arm32/*.S
)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()
]]
add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC})

View File

@ -59,6 +59,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
matmul_param_->col_ = conv_param_->output_channel_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = conv_param_->act_type_;
@ -120,8 +121,11 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) {
} else {
input_ptr_ = src_input;
}
#ifdef ENABLE_ARM32
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#endif
return;
}
@ -169,8 +173,13 @@ int Convolution1x1CPUKernel::Run() {
auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
#ifdef ENABLE_ARM32
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
#endif
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
return RET_MEMORY_FAILED;

View File

@ -95,7 +95,12 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
const int k_plane = 16;
MS_ASSERT(ctx_->allocator != nullptr);
size_t tile_buffer_size = thread_count_ * C12NUM * C16NUM * ic4 * C4NUM * sizeof(float);
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
int tile_num = 12;
#endif
size_t tile_buffer_size = thread_count_ * tile_num * C16NUM * ic4 * C4NUM * sizeof(float);
tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile buffer failed.";
@ -109,14 +114,14 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
size_t tmp_dst_buffer_size = thread_count_ * C12NUM * k_plane * oC8 * C8NUM * sizeof(float);
size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float);
tmp_dst_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tmp_dst_buffer_size));
if (tmp_dst_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed.";
return RET_ERROR;
}
size_t col_buffer_size = thread_count_ * C12NUM * C4NUM * ic4 * sizeof(float);
size_t col_buffer_size = thread_count_ * tile_num * C4NUM * ic4 * sizeof(float);
col_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(col_buffer_size));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";

View File

@ -150,9 +150,14 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
int oc4 = UP_DIV(channel_out, C4NUM);
int oc8 = UP_DIV(channel_out, C8NUM);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
#ifdef ENABLE_ARM32
int tile_num = 4;
#else
int tile_num = 12;
#endif
MS_ASSERT(ctx_->allocator != nullptr);
size_t tile_buffer_size = thread_count_ * C12NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
@ -160,7 +165,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
}
gemm_out_ = reinterpret_cast<float *>(
ctx_->allocator->Malloc(thread_count_ * C12NUM * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float)));
ctx_->allocator->Malloc(thread_count_ * tile_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float)));
if (gemm_out_ == nullptr) {
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
return RET_ERROR;
@ -184,7 +189,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
}
col_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * C12NUM * ic4 * C4NUM * sizeof(float)));
reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * tile_num * ic4 * C4NUM * sizeof(float)));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;

View File

@ -85,6 +85,7 @@ int DeConvolutionCPUKernel::InitParam() {
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
@ -112,10 +113,17 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
return RET_OK;
}
#ifdef ENABLE_ARM32
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, OutType_C8);
#else
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, OutType_C8);
#endif
DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
@ -159,15 +167,25 @@ int DeConvolutionCPUKernel::InitRunBuf() {
return RET_NULL_PTR;
}
#ifdef ENABLE_ARM32
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float)));
#else
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float)));
#endif
if (tmp_buffer_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
return RET_NULL_PTR;
}
#ifdef ENABLE_ARM32
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
#endif
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
return RET_ERROR;

View File

@ -49,6 +49,7 @@ int FullconnectionCPUKernel::ReSize() {
fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
@ -59,11 +60,19 @@ int FullconnectionCPUKernel::ReSize() {
memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float));
}
#ifdef ENABLE_ARM32
a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_4_ * fc_param_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
return RET_MEMORY_FAILED;
}
memset(a_c12_ptr_, 0, fc_param_->row_4_ * fc_param_->deep_ * sizeof(float));
#else
a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_12_ * fc_param_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
return RET_MEMORY_FAILED;
}
memset(a_c12_ptr_, 0, fc_param_->row_12_ * fc_param_->deep_ * sizeof(float));
#endif
b_r8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
@ -87,7 +96,11 @@ int FullconnectionCPUKernel::Init() {
}
void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
#ifdef ENABLE_ARM32
RowMajor2Col4Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
#else
RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
#endif
}
void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {

View File

@ -62,17 +62,27 @@ int MatmulCPUKernel::ReSize() {
params_->row_ = c_shape[c_shape.size() - 2];
params_->col_ = c_shape[c_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
params_->col_8_ = UP_ROUND(params_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
#ifdef ENABLE_ARM32
a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_4_ * params_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(a_c12_ptr_, 0, params_->row_4_ * params_->deep_ * sizeof(float));
#else
a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_12_ * params_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float));
#endif
b_r8_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_8_ * params_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
@ -106,12 +116,21 @@ int MatmulCPUKernel::ReSize() {
void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
for (int i = 0; i < params_->batch; i++) {
float *src = src_ptr + i * params_->deep_ * params_->row_;
#ifdef ENABLE_ARM32
float *dst = dst_ptr + i * params_->deep_ * params_->row_4_;
if (params_->a_transpose_) {
RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col4Major(src, dst, params_->row_, params_->deep_);
}
#else
float *dst = dst_ptr + i * params_->deep_ * params_->row_12_;
if (params_->a_transpose_) {
RowMajor2Row12Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col12Major(src, dst, params_->row_, params_->deep_);
}
#endif
}
return;
}

View File

@ -79,7 +79,7 @@ if (PLATFORM_ARM64)
${TEST_ASSEMBLY_SRC}
)
endif()
#[[
if (PLATFORM_ARM32)
# assembly
file(GLOB TEST_ASSEMBLY_SRC
@ -91,7 +91,7 @@ if (PLATFORM_ARM32)
${TEST_ASSEMBLY_SRC}
)
endif()
]]
if (ENABLE_FP16)
file(GLOB KERNEL_OP_FP16_SRC
${LITE_DIR}/src/runtime/kernel/arm/fp16/*.cc