!3930 fix fp32 kernel bugs on arm32 and ReluFp32

Merge pull request !3930 from lixian/master
This commit is contained in:
mindspore-ci-bot 2020-08-04 16:13:25 +08:00 committed by Gitee
commit 6a1e6b01f7
3 changed files with 98 additions and 95 deletions

View File

@ -17,16 +17,18 @@
IndirectGemmFp32_8x4:
.macro INIT_BIAS
veor q10, q10, q10
veor q8, q8, q8
cmp r3, #0
beq InitBias
vld1.32 q10, [r3]
vld1.32 {q8}, [r3]
InitBias:
vmov q11, q10
vmov q12, q10
vmov q13, q10
vmov q14, q10
vmov q15, q10
vmov q9, q8
vmov q10, q8
vmov q11, q8
vmov q12, q8
vmov q13, q8
vmov q14, q8
vmov q15, q8
.endm
// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr"
@ -36,7 +38,7 @@ IndirectGemmFp32_8x4:
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r4-r8, r10, r11, lr}
vpush {q4-q7}
add sp, sp, #160
add sp, sp, #96
ldr r4, [sp]
ldr r5, [sp, #4]
@ -66,8 +68,8 @@ IndirectGemmFp32_8x4:
// load weight
vld1.32 {q4, q5}, [r2]!
// step for output 1-2
vmul.f32 q8, q4, d0[0]
vmul.f32 q9, q4, d2[0]
vmla.f32 q8, q4, d0[0]
vmla.f32 q9, q4, d2[0]
vmla.f32 q8, q5, d0[1]
vmla.f32 q9, q5, d2[1]
vld1.32 {q6, q7}, [r2]!
@ -158,31 +160,31 @@ IndirectGemmFp32_8x4:
bne Relu
b WriteStart
Relu6:
vmov.i32 q14, #6
vcvt.f32.s32 q14, q14
vmin.f32 q0, q0, q14
vmin.f32 q1, q1, q14
vmin.f32 q2, q2, q14
vmin.f32 q3, q3, q14
vmin.f32 q4, q4, q14
vmin.f32 q5, q5, q14
vmin.f32 q6, q6, q14
vmin.f32 q7, q15, q14
vmov.i32 q7, #6
vcvt.f32.s32 q7, q7
vmin.f32 q8, q8, q7
vmin.f32 q9, q9, q7
vmin.f32 q10, q10, q7
vmin.f32 q11, q11, q7
vmin.f32 q12, q12, q7
vmin.f32 q13, q13, q7
vmin.f32 q14, q14, q7
vmin.f32 q15, q15, q7
Relu:
veor q7, q7, q7
vmax.f32 q0, q8, q7
vmax.f32 q1, q9, q7
vmax.f32 q2, q10, q7
vmax.f32 q3, q11, q7
vmax.f32 q4, q12, q7
vmax.f32 q5, q13, q7
vmax.f32 q6, q14, q7
vmax.f32 q8, q8, q7
vmax.f32 q9, q9, q7
vmax.f32 q10, q10, q7
vmax.f32 q11, q11, q7
vmax.f32 q12, q12, q7
vmax.f32 q13, q13, q7
vmax.f32 q14, q14, q7
vmax.f32 q15, q15, q7
WriteStart:
ldr r10, [sp, #20]
cmp r10, #0
bne WriteC4
bne Write4
cmp r6, #1
beq Write1
cmp r6, #2
@ -191,98 +193,91 @@ IndirectGemmFp32_8x4:
beq Write3
b Write4
Write1:
vst1.32 d0[0], [r11]
vst1.32 d16[0], [r11]
add r11, r11, r7
vst1.32 d2[0], [r11]
vst1.32 d18[0], [r11]
add r11, r11, r7
vst1.32 d4[0], [r11]
vst1.32 d20[0], [r11]
add r11, r11, r7
vst1.32 d6[0], [r11]
vst1.32 d22[0], [r11]
add r11, r11, r7
vst1.32 d8[0], [r11]
vst1.32 d24[0], [r11]
add r11, r11, r7
vst1.32 d10[0], [r11]
vst1.32 d26[0], [r11]
add r11, r11, r7
vst1.32 d12[0], [r11]
vst1.32 d28[0], [r11]
add r11, r11, r7
vst1.32 d30[0], [r11]
add r11, r11, r7
vst1.32 d14[0], [r11]
add r0, r0, #4
b WriteEnd
Write2:
vst1.32 d0, [r11]
vst1.32 d16, [r11]
add r11, r11, r7
vst1.32 d2, [r11]
vst1.32 d18, [r11]
add r11, r11, r7
vst1.32 d4, [r11]
vst1.32 d20, [r11]
add r11, r11, r7
vst1.32 d6, [r11]
vst1.32 d22, [r11]
add r11, r11, r7
vst1.32 d8, [r11]
vst1.32 d24, [r11]
add r11, r11, r7
vst1.32 d10, [r11]
vst1.32 d26, [r11]
add r11, r11, r7
vst1.32 d12, [r11]
vst1.32 d28, [r11]
add r11, r11, r7
vst1.32 d30, [r11]
add r11, r11, r7
vst1.32 d14, [r11]
add r0, r0, #8
b WriteEnd
Write3:
add r12, r11, #8
vst1.32 d0, [r11]
add lr, r11, #8
vst1.32 d16, [r11]
add r11, r11, r7
vst1.32 d1[0], [r12]
add r12, r12, r7
vst1.32 d2, [r11]
vst1.32 d17[0], [lr]
add lr, lr, r7
vst1.32 d18, [r11]
add r11, r11, r7
vst1.32 d3[0], [r12]
add r12, r12, r7
vst1.32 d4, [r11]
vst1.32 d19[0], [lr]
add lr, lr, r7
vst1.32 d20, [r11]
add r11, r11, r7
vst1.32 d5[0], [r12]
add r12, r12, r7
vst1.32 d6, [r11]
vst1.32 d21[0], [lr]
add lr, lr, r7
vst1.32 d22, [r11]
add r11, r11, r7
vst1.32 d7[0], [r12]
add r12, r12, r7
vst1.32 d8, [r11]
vst1.32 d23[0], [lr]
add lr, lr, r7
vst1.32 d24, [r11]
add r11, r11, r7
vst1.32 d9[0], [r12]
add r12, r12, r7
vst1.32 d10, [r11]
vst1.32 d25[0], [lr]
add lr, lr, r7
vst1.32 d26, [r11]
add r11, r11, r7
vst1.32 d11[0], [r12]
add r12, r12, r7
vst1.32 d12, [r11]
vst1.32 d27[0], [lr]
add lr, lr, r7
vst1.32 d28, [r11]
add r11, r11, r7
vst1.32 d13[0], [r12]
add r12, r12, r7
vst1.32 d14, [r11]
vst1.32 d15[0], [r12]
vst1.32 d29[0], [lr]
add lr, lr, r7
vst1.32 d30, [r11]
add r11, r11, r7
vst1.32 d31[0], [lr]
add lr, lr, r7
add r0, r0, #12
b WriteEnd
WriteC4:
vst1.32 q0, [r11], r7
vst1.32 q1, [r11], r7
vst1.32 q2, [r11], r7
vst1.32 q3, [r11], r7
vst1.32 q4, [r11], r7
vst1.32 q5, [r11], r7
vst1.32 q6, [r11], r7
vst1.32 q7, [r11]
add r0, r0, #16
b WriteEnd
Write4:
// prefetching is not prefered while writing results in spite of cache missings
// you could try prfm pstl2vst1.32m
// you could try pld
// there are almost no benefits observed though
vst1.32 q0, [r11], r7
vst1.32 q1, [r11], r7
vst1.32 q2, [r11], r7
vst1.32 q3, [r11], r7
vst1.32 q4, [r11], r7
vst1.32 q5, [r11], r7
vst1.32 q6, [r11], r7
vst1.32 q7, [r11]
vst1.32 {q8}, [r11], r7
vst1.32 {q9}, [r11], r7
vst1.32 {q10}, [r11], r7
vst1.32 {q11}, [r11], r7
vst1.32 {q12}, [r11], r7
vst1.32 {q13}, [r11], r7
vst1.32 {q14}, [r11], r7
vst1.32 {q15}, [r11], r7
add r0, r0, #16
WriteEnd:
@ -290,14 +285,17 @@ IndirectGemmFp32_8x4:
subs r8, r8, #1
bne LoopKsize
subs r6, r6, #4
cmp r6, #4
ble LoopOcEnd
sub r6, r6, #4
cmp r3, #0
beq NoStepFowrard
add r3, r3, #16
NoStepFowrard:
bgt LoopOc
b LoopOc
add sp, sp, #160
LoopOcEnd:
sub sp, sp, #96
vpop {q4-q7}
pop {r4-r8, r10, r11, pc}
#endif

View File

@ -31,7 +31,7 @@ IndirectGemmInt8_2x4:
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r4-r8, r10, r11, lr}
vpush {q4-q7}
add sp, sp, #160
add sp, sp, #96
ldr r4, [sp]
ldr r5, [sp, #4]
@ -226,14 +226,17 @@ IndirectGemmInt8_2x4:
subs r8, r8, #1
bne LoopKsize
subs r6, r6, #4
cmp r6, #4
ble LoopOcEnd
sub r6, r6, #4
cmp r3, #0
beq NoStepFowrard
add r3, r3, #16
NoStepFowrard:
bgt LoopOc
b LoopOc
add sp, sp, #160
LoopOcEnd:
sub sp, sp, #96
vpop {q4-q7}
pop {r4-r8, r10, r11, pc}
#endif

View File

@ -159,6 +159,7 @@ void ReluFp32(float *data, int ele_num) {
float32x4_t relu_data = vld1q_f32(data + index);
float32x4_t zero_data = vdupq_n_f32(0);
relu_data = vmaxq_f32(relu_data, zero_data);
vst1q_f32(data + index, relu_data);
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
@ -181,6 +182,7 @@ void Relu6Fp32(float *data, int ele_num) {
float32x4_t six_data = vdupq_n_f32(6);
relu6_data = vmaxq_f32(relu6_data, zero_data);
relu6_data = vminq_f32(relu6_data, six_data);
vst1q_f32(data + index, relu6_data);
#else
data[index] = data[index] < 0 ? 0 : data[index];
data[index] = data[index] > 6 ? 6 : data[index];