forked from OSSInnovation/mindspore
!3930 fix fp32 kernel bugs on arm32 and ReluFp32
Merge pull request !3930 from lixian/master
This commit is contained in:
commit
6a1e6b01f7
|
@ -17,16 +17,18 @@
|
|||
IndirectGemmFp32_8x4:
|
||||
|
||||
// INIT_BIAS: seed the eight NEON accumulators q8-q15 with one common vector.
//   In:    r3 = address to load the initial (bias) vector from, or 0 for none
//   Out:   q8-q15 each hold the 128-bit value loaded from [r3],
//          or all-zero bits when r3 == 0
//   Clobb: condition flags (cmp)
// The earlier q10-based sequence (veor/vld1 into q10, then vmov q11-q15, q10)
// has been dropped: every register it wrote was overwritten by the q8-based
// copies below, so the final register state is unchanged.
// NOTE(review): InitBias is a plain (non-local) label, so this macro can be
// expanded only once per object file; use a \@-suffixed label if a second
// expansion is ever needed.
.macro INIT_BIAS
    veor q8, q8, q8             // default: accumulators start at zero
    cmp r3, #0
    beq InitBias                // r3 == 0: keep the zero vector
    vld1.32 {q8}, [r3]          // q8 = four 32-bit lanes read from [r3]
InitBias:
    vmov q9, q8                 // replicate q8 into the other accumulators
    vmov q10, q8
    vmov q11, q8
    vmov q12, q8
    vmov q13, q8
    vmov q14, q8
    vmov q15, q8
.endm
|
||||
|
||||
// at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr"
|
||||
|
@ -36,7 +38,7 @@ IndirectGemmFp32_8x4:
|
|||
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
|
||||
push {r4-r8, r10, r11, lr}
|
||||
vpush {q4-q7}
|
||||
add sp, sp, #160
|
||||
add sp, sp, #96
|
||||
|
||||
ldr r4, [sp]
|
||||
ldr r5, [sp, #4]
|
||||
|
@ -66,8 +68,8 @@ IndirectGemmFp32_8x4:
|
|||
// load weight
|
||||
vld1.32 {q4, q5}, [r2]!
|
||||
// step for output 1-2
|
||||
vmul.f32 q8, q4, d0[0]
|
||||
vmul.f32 q9, q4, d2[0]
|
||||
vmla.f32 q8, q4, d0[0]
|
||||
vmla.f32 q9, q4, d2[0]
|
||||
vmla.f32 q8, q5, d0[1]
|
||||
vmla.f32 q9, q5, d2[1]
|
||||
vld1.32 {q6, q7}, [r2]!
|
||||
|
@ -158,31 +160,31 @@ IndirectGemmFp32_8x4:
|
|||
bne Relu
|
||||
b WriteStart
|
||||
Relu6:
|
||||
vmov.i32 q14, #6
|
||||
vcvt.f32.s32 q14, q14
|
||||
vmin.f32 q0, q0, q14
|
||||
vmin.f32 q1, q1, q14
|
||||
vmin.f32 q2, q2, q14
|
||||
vmin.f32 q3, q3, q14
|
||||
vmin.f32 q4, q4, q14
|
||||
vmin.f32 q5, q5, q14
|
||||
vmin.f32 q6, q6, q14
|
||||
vmin.f32 q7, q15, q14
|
||||
vmov.i32 q7, #6
|
||||
vcvt.f32.s32 q7, q7
|
||||
vmin.f32 q8, q8, q7
|
||||
vmin.f32 q9, q9, q7
|
||||
vmin.f32 q10, q10, q7
|
||||
vmin.f32 q11, q11, q7
|
||||
vmin.f32 q12, q12, q7
|
||||
vmin.f32 q13, q13, q7
|
||||
vmin.f32 q14, q14, q7
|
||||
vmin.f32 q15, q15, q7
|
||||
Relu:
|
||||
veor q7, q7, q7
|
||||
vmax.f32 q0, q8, q7
|
||||
vmax.f32 q1, q9, q7
|
||||
vmax.f32 q2, q10, q7
|
||||
vmax.f32 q3, q11, q7
|
||||
vmax.f32 q4, q12, q7
|
||||
vmax.f32 q5, q13, q7
|
||||
vmax.f32 q6, q14, q7
|
||||
vmax.f32 q8, q8, q7
|
||||
vmax.f32 q9, q9, q7
|
||||
vmax.f32 q10, q10, q7
|
||||
vmax.f32 q11, q11, q7
|
||||
vmax.f32 q12, q12, q7
|
||||
vmax.f32 q13, q13, q7
|
||||
vmax.f32 q14, q14, q7
|
||||
vmax.f32 q15, q15, q7
|
||||
|
||||
WriteStart:
|
||||
ldr r10, [sp, #20]
|
||||
cmp r10, #0
|
||||
bne WriteC4
|
||||
bne Write4
|
||||
cmp r6, #1
|
||||
beq Write1
|
||||
cmp r6, #2
|
||||
|
@ -191,98 +193,91 @@ IndirectGemmFp32_8x4:
|
|||
beq Write3
|
||||
b Write4
|
||||
Write1:
|
||||
vst1.32 d0[0], [r11]
|
||||
vst1.32 d16[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d2[0], [r11]
|
||||
vst1.32 d18[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d4[0], [r11]
|
||||
vst1.32 d20[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d6[0], [r11]
|
||||
vst1.32 d22[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d8[0], [r11]
|
||||
vst1.32 d24[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d10[0], [r11]
|
||||
vst1.32 d26[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d12[0], [r11]
|
||||
vst1.32 d28[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d30[0], [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d14[0], [r11]
|
||||
add r0, r0, #4
|
||||
b WriteEnd
|
||||
Write2:
|
||||
vst1.32 d0, [r11]
|
||||
vst1.32 d16, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d2, [r11]
|
||||
vst1.32 d18, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d4, [r11]
|
||||
vst1.32 d20, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d6, [r11]
|
||||
vst1.32 d22, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d8, [r11]
|
||||
vst1.32 d24, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d10, [r11]
|
||||
vst1.32 d26, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d12, [r11]
|
||||
vst1.32 d28, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d30, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d14, [r11]
|
||||
add r0, r0, #8
|
||||
b WriteEnd
|
||||
Write3:
|
||||
add r12, r11, #8
|
||||
vst1.32 d0, [r11]
|
||||
add lr, r11, #8
|
||||
vst1.32 d16, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d1[0], [r12]
|
||||
add r12, r12, r7
|
||||
vst1.32 d2, [r11]
|
||||
vst1.32 d17[0], [lr]
|
||||
add lr, lr, r7
|
||||
vst1.32 d18, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d3[0], [r12]
|
||||
add r12, r12, r7
|
||||
vst1.32 d4, [r11]
|
||||
vst1.32 d19[0], [lr]
|
||||
add lr, lr, r7
|
||||
vst1.32 d20, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d5[0], [r12]
|
||||
add r12, r12, r7
|
||||
vst1.32 d6, [r11]
|
||||
vst1.32 d21[0], [lr]
|
||||
add lr, lr, r7
|
||||
vst1.32 d22, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d7[0], [r12]
|
||||
add r12, r12, r7
|
||||
vst1.32 d8, [r11]
|
||||
vst1.32 d23[0], [lr]
|
||||
add lr, lr, r7
|
||||
vst1.32 d24, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d9[0], [r12]
|
||||
add r12, r12, r7
|
||||
vst1.32 d10, [r11]
|
||||
vst1.32 d25[0], [lr]
|
||||
add lr, lr, r7
|
||||
vst1.32 d26, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d11[0], [r12]
|
||||
add r12, r12, r7
|
||||
vst1.32 d12, [r11]
|
||||
vst1.32 d27[0], [lr]
|
||||
add lr, lr, r7
|
||||
vst1.32 d28, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d13[0], [r12]
|
||||
add r12, r12, r7
|
||||
vst1.32 d14, [r11]
|
||||
vst1.32 d15[0], [r12]
|
||||
vst1.32 d29[0], [lr]
|
||||
add lr, lr, r7
|
||||
vst1.32 d30, [r11]
|
||||
add r11, r11, r7
|
||||
vst1.32 d31[0], [lr]
|
||||
add lr, lr, r7
|
||||
add r0, r0, #12
|
||||
b WriteEnd
|
||||
WriteC4:
|
||||
vst1.32 q0, [r11], r7
|
||||
vst1.32 q1, [r11], r7
|
||||
vst1.32 q2, [r11], r7
|
||||
vst1.32 q3, [r11], r7
|
||||
vst1.32 q4, [r11], r7
|
||||
vst1.32 q5, [r11], r7
|
||||
vst1.32 q6, [r11], r7
|
||||
vst1.32 q7, [r11]
|
||||
add r0, r0, #16
|
||||
b WriteEnd
|
||||
Write4:
|
||||
// prefetching is not preferred while writing results, in spite of cache misses
|
||||
// you could try prfm pstl2vst1.32m
|
||||
// you could try pld
|
||||
// there are almost no benefits observed though
|
||||
vst1.32 q0, [r11], r7
|
||||
vst1.32 q1, [r11], r7
|
||||
vst1.32 q2, [r11], r7
|
||||
vst1.32 q3, [r11], r7
|
||||
vst1.32 q4, [r11], r7
|
||||
vst1.32 q5, [r11], r7
|
||||
vst1.32 q6, [r11], r7
|
||||
vst1.32 q7, [r11]
|
||||
vst1.32 {q8}, [r11], r7
|
||||
vst1.32 {q9}, [r11], r7
|
||||
vst1.32 {q10}, [r11], r7
|
||||
vst1.32 {q11}, [r11], r7
|
||||
vst1.32 {q12}, [r11], r7
|
||||
vst1.32 {q13}, [r11], r7
|
||||
vst1.32 {q14}, [r11], r7
|
||||
vst1.32 {q15}, [r11], r7
|
||||
add r0, r0, #16
|
||||
|
||||
WriteEnd:
|
||||
|
@ -290,14 +285,17 @@ IndirectGemmFp32_8x4:
|
|||
subs r8, r8, #1
|
||||
bne LoopKsize
|
||||
|
||||
subs r6, r6, #4
|
||||
cmp r6, #4
|
||||
ble LoopOcEnd
|
||||
sub r6, r6, #4
|
||||
cmp r3, #0
|
||||
beq NoStepFowrard
|
||||
add r3, r3, #16
|
||||
NoStepFowrard:
|
||||
bgt LoopOc
|
||||
b LoopOc
|
||||
|
||||
add sp, sp, #160
|
||||
LoopOcEnd:
|
||||
sub sp, sp, #96
|
||||
vpop {q4-q7}
|
||||
pop {r4-r8, r10, r11, pc}
|
||||
#endif
|
||||
|
|
|
@ -31,7 +31,7 @@ IndirectGemmInt8_2x4:
|
|||
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
|
||||
push {r4-r8, r10, r11, lr}
|
||||
vpush {q4-q7}
|
||||
add sp, sp, #160
|
||||
add sp, sp, #96
|
||||
|
||||
ldr r4, [sp]
|
||||
ldr r5, [sp, #4]
|
||||
|
@ -226,14 +226,17 @@ IndirectGemmInt8_2x4:
|
|||
subs r8, r8, #1
|
||||
bne LoopKsize
|
||||
|
||||
subs r6, r6, #4
|
||||
cmp r6, #4
|
||||
ble LoopOcEnd
|
||||
sub r6, r6, #4
|
||||
cmp r3, #0
|
||||
beq NoStepFowrard
|
||||
add r3, r3, #16
|
||||
NoStepFowrard:
|
||||
bgt LoopOc
|
||||
b LoopOc
|
||||
|
||||
add sp, sp, #160
|
||||
LoopOcEnd:
|
||||
sub sp, sp, #96
|
||||
vpop {q4-q7}
|
||||
pop {r4-r8, r10, r11, pc}
|
||||
#endif
|
||||
|
|
|
@ -159,6 +159,7 @@ void ReluFp32(float *data, int ele_num) {
|
|||
float32x4_t relu_data = vld1q_f32(data + index);
|
||||
float32x4_t zero_data = vdupq_n_f32(0);
|
||||
relu_data = vmaxq_f32(relu_data, zero_data);
|
||||
vst1q_f32(data + index, relu_data);
|
||||
#else
|
||||
data[index] = data[index] < 0 ? 0 : data[index];
|
||||
data[index + 1] = data[index + 1] < 0 ? 0 : data[index + 1];
|
||||
|
@ -181,6 +182,7 @@ void Relu6Fp32(float *data, int ele_num) {
|
|||
float32x4_t six_data = vdupq_n_f32(6);
|
||||
relu6_data = vmaxq_f32(relu6_data, zero_data);
|
||||
relu6_data = vminq_f32(relu6_data, six_data);
|
||||
vst1q_f32(data + index, relu6_data);
|
||||
#else
|
||||
data[index] = data[index] < 0 ? 0 : data[index];
|
||||
data[index] = data[index] > 6 ? 6 : data[index];
|
||||
|
|
Loading…
Reference in New Issue