forked from mindspore-Ecosystem/mindspore
fp16 bug fix
This commit is contained in:
parent
679985adb8
commit
f18d7cf435
2
build.sh
2
build.sh
|
@ -607,7 +607,7 @@ build_lite()
|
|||
cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \
|
||||
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \
|
||||
-DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
|
||||
-DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
|
||||
-DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DENABLE_FP16="on" \
|
||||
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
|
||||
-DSUPPORT_GPU=${LOCAL_LITE_ENABLE_GPU} -DSUPPORT_NPU=${LOCAL_LITE_ENABLE_NPU} -DENABLE_V0=on \
|
||||
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
|
||||
|
|
|
@ -68,7 +68,7 @@ add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC})
|
|||
add_dependencies(nnacl fbs_src)
|
||||
add_dependencies(nnacl_mid fbs_src)
|
||||
|
||||
########################### arm64 build optimize library ########################
|
||||
if(PLATFORM_ARM64)
|
||||
########################### arm fp16 build optimize library ########################
|
||||
if(ENABLE_FP16)
|
||||
add_subdirectory(${NNACL_DIR}/optimize)
|
||||
endif()
|
||||
|
|
|
@ -30,7 +30,7 @@ typedef struct ArgElement {
|
|||
int8_t i8_data_;
|
||||
int32_t i_data_;
|
||||
float f_data_;
|
||||
#ifdef ENABLE_ARM64
|
||||
#ifdef ENABLE_ARM
|
||||
float16_t f16_data_;
|
||||
#endif
|
||||
} data_;
|
||||
|
|
|
@ -0,0 +1,602 @@
|
|||
#ifdef ENABLE_ARM32
|
||||
#include "nnacl/assembly_global.h"
|
||||
.text
|
||||
.align 5
|
||||
.global MatMul12x8A32Fp16
|
||||
#ifndef __APPLE__
|
||||
.type MatMul12x8A32Fp16, %function
|
||||
#endif
|
||||
|
||||
// void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
// int deep, int row, int col, int stride, bool write_mode);
|
||||
// r0: a
|
||||
// r1: b
|
||||
// r2: dst
|
||||
// r3: bias
|
||||
// #0: act_type
|
||||
// #4: depth
|
||||
// #8: row
|
||||
// #12: col
|
||||
// #16: stride
|
||||
// #20: writeNhwc/writeWino
|
||||
|
||||
asm_function MatMul12x8A32Fp16
|
||||
// r13(sp) and r15(pc) can not be used!!
|
||||
// r9 r4 is tmp register
|
||||
// r4-r11 and q4-q7 (d8-d15) must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
|
||||
// Prologue: save registers, then load the stack-passed arguments.
// r3 (bias) is pushed together with r4-r11 and lr: 10 words = 40 bytes.
push {r3-r11, lr}
|
||||
// q4-q7 occupy 4 x 16 = 64 bytes.
vpush {q4-q7}
|
||||
// Undo the 40 + 64 = 104 byte adjustment so that sp points at the caller's
// stack arguments again; the saved registers now live below sp.
// NOTE(review): memory below sp can be overwritten by asynchronous stack
// users (e.g. signal handlers) - confirm this is acceptable on the targets.
add sp, sp, #104
|
||||
|
||||
// Stack arguments: [sp] is act_type (read later), then deep, row, col,
// stride and the write mode.
ldr r5, [sp, #4]
|
||||
ldr r6, [sp, #8]
|
||||
ldr r7, [sp, #12]
|
||||
ldr r8, [sp, #16]
|
||||
ldr lr, [sp, #20]
|
||||
|
||||
mov r10, r1 // b
|
||||
mov r11, r0 // a
|
||||
mov r12, r2 // dst
|
||||
|
||||
cmp lr, #2
|
||||
bne NoWinograd
|
||||
mul r4, r8, r7 // stride * col
|
||||
add r4, r4, r4 // r4 * sizeof(float16_t)
|
||||
mov r9, #16
|
||||
mul r9, r8, r9 // stride * 8 * sizeof(float16_t)
|
||||
NoWinograd:
|
||||
add r8, r8, r8 // stride * sizeof(float16_t)
|
||||
|
||||
a .req r0
|
||||
weight .req r1
|
||||
dst .req r2
|
||||
bias .req r3
|
||||
depth .req r5
|
||||
row .req r6
|
||||
col .req r7
|
||||
stride .req r8
|
||||
b_tmp .req r10
|
||||
a_tmp .req r11
|
||||
dst_tmp .req r12
|
||||
|
||||
// STORE_12x8 p1: store one output row held in p1 at dst, then advance dst
// by stride (already scaled to bytes). Call sites pass q4-q15, i.e. 8
// halfwords per row.
.macro STORE_12x8 p1
|
||||
vst1.16 {\p1}, [dst]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_12x7 p1, p2, p3: store a 7-halfword output row, then advance dst by
// stride. Call sites pass a D register (4 halfwords, at dst), a 32-bit lane
// (2 halfwords, at dst + 8 bytes) and a 16-bit lane (1 halfword, at
// dst + 12 bytes). Uses r4 and r9 as scratch pointers.
.macro STORE_12x7 p1, p2, p3
|
||||
add r4, dst, #8
|
||||
add r9, dst, #12
|
||||
vst1.16 {\p1}, [dst]
|
||||
vst1.32 {\p2}, [r4]
|
||||
vst1.16 {\p3}, [r9]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_12x6 p1, p2: store a 6-halfword output row, then advance dst by
// stride. Call sites pass a D register (4 halfwords, at dst) and a 32-bit
// lane (2 halfwords, at dst + 8 bytes). Uses r4 as a scratch pointer.
.macro STORE_12x6 p1, p2
|
||||
add r4, dst, #8
|
||||
vst1.16 {\p1}, [dst]
|
||||
vst1.32 {\p2}, [r4]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_12x5 p1, p2: store a 5-halfword output row, then advance dst by
// stride. Call sites pass a D register (4 halfwords, at dst) and a 16-bit
// lane (1 halfword, at dst + 8 bytes). Uses r4 as a scratch pointer.
.macro STORE_12x5 p1, p2
|
||||
add r4, dst, #8
|
||||
vst1.16 {\p1}, [dst]
|
||||
vst1.16 {\p2}, [r4]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_12x4 p1: store a 4-halfword output row at dst, then advance dst by
// stride. Call sites pass a D register (the low half of each accumulator).
.macro STORE_12x4 p1
|
||||
vst1.16 {\p1}, [dst]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_12x3 p1, p2: store a 3-halfword output row, then advance dst by
// stride. Call sites pass a 32-bit lane (2 halfwords, at dst) and a 16-bit
// lane (1 halfword, at dst + 4 bytes). Uses r4 as a scratch pointer.
.macro STORE_12x3 p1, p2
|
||||
add r4, dst, #4
|
||||
vst1.32 {\p1}, [dst]
|
||||
vst1.16 {\p2}, [r4]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_12x2 p1: store a 2-halfword output row as one 32-bit lane at dst,
// then advance dst by stride.
.macro STORE_12x2 p1
|
||||
vst1.32 {\p1}, [dst]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_12x1 p1: store a single halfword (16-bit lane) at dst, then advance
// dst by stride.
.macro STORE_12x1 p1
|
||||
vst1.16 {\p1}, [dst]
|
||||
add dst, dst, stride
|
||||
.endm
|
||||
|
||||
// STORE_C8 p1, p2: row-limited variant of the 8-halfword row store.
// Stores p1 at dst, advances dst by stride, and branches to WriteEnd when
// row equals p2. Call sites pass p2 = #1..#12 in order, so the sequence
// stops after the last remaining row.
.macro STORE_C8 p1, p2
|
||||
vst1.16 {\p1}, [dst]
|
||||
cmp row, \p2
|
||||
add dst, dst, stride
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
// STORE_C7 p1, p2, p3, p4: row-limited variant of the 7-halfword row store.
// Stores p1 at dst, the 32-bit lane p2 at dst + 8 bytes and the 16-bit lane
// p3 at dst + 12 bytes, advances dst by stride, and branches to WriteEnd
// when row equals p4. Uses r4 and r9 as scratch pointers.
.macro STORE_C7 p1, p2, p3, p4
|
||||
add r4, dst, #8
|
||||
add r9, dst, #12
|
||||
vst1.16 {\p1}, [dst]
|
||||
vst1.32 {\p2}, [r4]
|
||||
vst1.16 {\p3}, [r9]
|
||||
add dst, dst, stride
|
||||
cmp row, \p4
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
// STORE_C6 p1, p2, p3: row-limited variant of the 6-halfword row store.
// Stores p1 at dst and the 32-bit lane p2 at dst + 8 bytes, advances dst by
// stride, and branches to WriteEnd when row equals p3. Uses r4 as scratch.
.macro STORE_C6 p1, p2, p3
|
||||
add r4, dst, #8
|
||||
vst1.16 {\p1}, [dst]
|
||||
vst1.32 {\p2}, [r4]
|
||||
add dst, dst, stride
|
||||
cmp row, \p3
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
// STORE_C5 p1, p2, p3: row-limited variant of the 5-halfword row store.
// Stores p1 at dst and the 16-bit lane p2 at dst + 8 bytes, advances dst by
// stride, and branches to WriteEnd when row equals p3. Uses r4 as scratch.
.macro STORE_C5 p1, p2, p3
|
||||
add r4, dst, #8
|
||||
vst1.16 {\p1}, [dst]
|
||||
vst1.16 {\p2}, [r4]
|
||||
add dst, dst, stride
|
||||
cmp row, \p3
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
// STORE_C4 p1, p2: row-limited variant of the 4-halfword row store.
// Stores p1 at dst, advances dst by stride, and branches to WriteEnd when
// row equals p2.
.macro STORE_C4 p1, p2
|
||||
vst1.16 {\p1}, [dst]
|
||||
cmp row, \p2
|
||||
add dst, dst, stride
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
// STORE_C3 p1, p2, p3: row-limited variant of the 3-halfword row store.
// Stores the 32-bit lane p1 at dst and the 16-bit lane p2 at dst + 4 bytes,
// advances dst by stride, and branches to WriteEnd when row equals p3.
// Uses r4 as scratch.
.macro STORE_C3 p1, p2, p3
|
||||
add r4, dst, #4
|
||||
vst1.32 {\p1}, [dst]
|
||||
vst1.16 {\p2}, [r4]
|
||||
add dst, dst, stride
|
||||
cmp row, \p3
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
// STORE_C2 p1, p2: row-limited variant of the 2-halfword row store.
// Stores the 32-bit lane p1 at dst, advances dst by stride, and branches to
// WriteEnd when row equals p2.
.macro STORE_C2 p1, p2
|
||||
vst1.32 {\p1}, [dst]
|
||||
add dst, dst, stride
|
||||
cmp row, \p2
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
// STORE_C1 p1, p2: row-limited variant of the 1-halfword row store.
// Stores the 16-bit lane p1 at dst, advances dst by stride, and branches to
// WriteEnd when row equals p2.
.macro STORE_C1 p1, p2
|
||||
vst1.16 {\p1}, [dst]
|
||||
add dst, dst, stride
|
||||
cmp row, \p2
|
||||
beq WriteEnd
|
||||
.endm
|
||||
|
||||
LoopRow12:
|
||||
ldr bias, [sp, #-40]
|
||||
LoopCol8:
|
||||
mov dst, dst_tmp
|
||||
mov a, a_tmp
|
||||
ldr depth, [sp, #4]
|
||||
veor q4, q4, q4
|
||||
veor q5, q5, q5
|
||||
veor q6, q6, q6
|
||||
veor q7, q7, q7
|
||||
veor q8, q8, q8
|
||||
veor q9, q9, q9
|
||||
veor q10, q10, q10
|
||||
veor q11, q11, q11
|
||||
veor q12, q12, q12
|
||||
veor q13, q13, q13
|
||||
veor q14, q14, q14
|
||||
veor q15, q15, q15
|
||||
// Inner loop over depth: one iteration consumes 12 halfwords from a and 8
// halfwords from weight, and accumulates a 12x8 tile in q4-q15.
LoopDepth:
|
||||
// q0 + d2 hold 12 fp16 values from a; the pointer is post-incremented.
vld1.16 {q0, d2}, [a]!
|
||||
// q2 holds 8 fp16 values from weight; the pointer is post-incremented.
vld1.16 {q2}, [weight]!
|
||||
// Each accumulator (one output row) gets q2 scaled by one lane of a.
vmla.f16 q4, q2, d0[0]
|
||||
vmla.f16 q5, q2, d0[1]
|
||||
vmla.f16 q6, q2, d0[2]
|
||||
vmla.f16 q7, q2, d0[3]
|
||||
vmla.f16 q8, q2, d1[0]
|
||||
vmla.f16 q9, q2, d1[1]
|
||||
vmla.f16 q10, q2, d1[2]
|
||||
vmla.f16 q11, q2, d1[3]
|
||||
vmla.f16 q12, q2, d2[0]
|
||||
vmla.f16 q13, q2, d2[1]
|
||||
vmla.f16 q14, q2, d2[2]
|
||||
vmla.f16 q15, q2, d2[3]
|
||||
|
||||
// depth counts down to zero; it is reloaded from the stack in LoopCol8.
subs depth, depth, #1
|
||||
bne LoopDepth
|
||||
|
||||
Bias:
|
||||
cmp bias, #0
|
||||
beq Activation
|
||||
vld1.16 {q0}, [bias]!
|
||||
vadd.f16 q4, q4, q0
|
||||
vadd.f16 q5, q5, q0
|
||||
vadd.f16 q6, q6, q0
|
||||
vadd.f16 q7, q7, q0
|
||||
vadd.f16 q8, q8, q0
|
||||
vadd.f16 q9, q9, q0
|
||||
vadd.f16 q10, q10, q0
|
||||
vadd.f16 q11, q11, q0
|
||||
vadd.f16 q12, q12, q0
|
||||
vadd.f16 q13, q13, q0
|
||||
vadd.f16 q14, q14, q0
|
||||
vadd.f16 q15, q15, q0
|
||||
|
||||
// Dispatch on act_type, the first stack-passed argument ([sp]):
// 3 selects Relu6, 1 selects Relu, any other value skips activation.
Activation:
|
||||
ldr lr, [sp]
|
||||
cmp lr, #3
|
||||
beq Relu6
|
||||
cmp lr, #1
|
||||
beq Relu
|
||||
b Write
|
||||
|
||||
// Relu6 upper clamp: limit every accumulator row (q4-q15) to at most 6.0.
// 0x4600 is 6.0 in IEEE half precision, replicated into all 8 lanes of q2.
// Execution then falls through into Relu, which applies the lower bound 0.
// All twelve rows must use vmin; the first four previously used vadd, which
// added 6.0 to rows 0-3 instead of clamping them.
Relu6:
|
||||
vmov.i16 q2, #0x4600
|
||||
vmin.f16 q4, q4, q2
|
||||
vmin.f16 q5, q5, q2
|
||||
vmin.f16 q6, q6, q2
|
||||
vmin.f16 q7, q7, q2
|
||||
vmin.f16 q8, q8, q2
|
||||
vmin.f16 q9, q9, q2
|
||||
vmin.f16 q10, q10, q2
|
||||
vmin.f16 q11, q11, q2
|
||||
vmin.f16 q12, q12, q2
|
||||
vmin.f16 q13, q13, q2
|
||||
vmin.f16 q14, q14, q2
|
||||
vmin.f16 q15, q15, q2
|
||||
|
||||
Relu:
|
||||
veor q3, q3, q3
|
||||
vmax.f16 q4, q4, q3
|
||||
vmax.f16 q5, q5, q3
|
||||
vmax.f16 q6, q6, q3
|
||||
vmax.f16 q7, q7, q3
|
||||
vmax.f16 q8, q8, q3
|
||||
vmax.f16 q9, q9, q3
|
||||
vmax.f16 q10, q10, q3
|
||||
vmax.f16 q11, q11, q3
|
||||
vmax.f16 q12, q12, q3
|
||||
vmax.f16 q13, q13, q3
|
||||
vmax.f16 q14, q14, q3
|
||||
vmax.f16 q15, q15, q3
|
||||
|
||||
Write:
|
||||
ldr lr, [sp, #20]
|
||||
cmp lr, #2
|
||||
beq WriteWinograd
|
||||
cmp row, #12
|
||||
bge Write12xCol
|
||||
b WriteRowxCol
|
||||
|
||||
WriteWinograd:
|
||||
vst1.16 {q4}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q5}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q6}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q7}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q8}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q9}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q10}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q11}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q12}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q13}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q14}, [dst]
|
||||
add dst, dst, r4
|
||||
vst1.16 {q15}, [dst]
|
||||
add dst_tmp, dst_tmp, r9
|
||||
b WriteEnd
|
||||
Write12xCol:
|
||||
cmp col, #8
|
||||
bge Write12x8
|
||||
cmp col, #1
|
||||
beq Write12x1
|
||||
cmp col, #2
|
||||
beq Write12x2
|
||||
cmp col, #3
|
||||
beq Write12x3
|
||||
cmp col, #4
|
||||
beq Write12x4
|
||||
cmp col, #5
|
||||
beq Write12x5
|
||||
cmp col, #6
|
||||
beq Write12x6
|
||||
b Write12x7
|
||||
|
||||
WriteRowxCol:
|
||||
cmp col, #8
|
||||
bge WriteRowx8
|
||||
cmp col, #1
|
||||
beq WriteRowx1
|
||||
cmp col, #2
|
||||
beq WriteRowx2
|
||||
cmp col, #3
|
||||
beq WriteRowx3
|
||||
cmp col, #4
|
||||
beq WriteRowx4
|
||||
cmp col, #5
|
||||
beq WriteRowx5
|
||||
cmp col, #6
|
||||
beq WriteRowx6
|
||||
b WriteRowx7
|
||||
|
||||
Write12x8:
|
||||
STORE_12x8 q4
|
||||
STORE_12x8 q5
|
||||
STORE_12x8 q6
|
||||
STORE_12x8 q7
|
||||
STORE_12x8 q8
|
||||
STORE_12x8 q9
|
||||
STORE_12x8 q10
|
||||
STORE_12x8 q11
|
||||
STORE_12x8 q12
|
||||
STORE_12x8 q13
|
||||
STORE_12x8 q14
|
||||
STORE_12x8 q15
|
||||
b WriteEnd
|
||||
WriteRowx8:
|
||||
STORE_C8 q4, #1
|
||||
STORE_C8 q5, #2
|
||||
STORE_C8 q6, #3
|
||||
STORE_C8 q7, #4
|
||||
STORE_C8 q8, #5
|
||||
STORE_C8 q9, #6
|
||||
STORE_C8 q10, #7
|
||||
STORE_C8 q11, #8
|
||||
STORE_C8 q12, #9
|
||||
STORE_C8 q13, #10
|
||||
STORE_C8 q14, #11
|
||||
STORE_C8 q15, #12
|
||||
b WriteEnd
|
||||
|
||||
Write12x1:
|
||||
STORE_12x1 d8[0]
|
||||
STORE_12x1 d10[0]
|
||||
STORE_12x1 d12[0]
|
||||
STORE_12x1 d14[0]
|
||||
STORE_12x1 d16[0]
|
||||
STORE_12x1 d18[0]
|
||||
STORE_12x1 d20[0]
|
||||
STORE_12x1 d22[0]
|
||||
STORE_12x1 d24[0]
|
||||
STORE_12x1 d26[0]
|
||||
STORE_12x1 d28[0]
|
||||
STORE_12x1 d30[0]
|
||||
b WriteEnd
|
||||
WriteRowx1:
|
||||
STORE_C1 d8[0], #1
|
||||
STORE_C1 d10[0], #2
|
||||
STORE_C1 d12[0], #3
|
||||
STORE_C1 d14[0], #4
|
||||
STORE_C1 d16[0], #5
|
||||
STORE_C1 d18[0], #6
|
||||
STORE_C1 d20[0], #7
|
||||
STORE_C1 d22[0], #8
|
||||
STORE_C1 d24[0], #9
|
||||
STORE_C1 d26[0], #10
|
||||
STORE_C1 d28[0], #11
|
||||
STORE_C1 d30[0], #12
|
||||
b WriteEnd
|
||||
|
||||
Write12x2:
|
||||
STORE_12x2 d8[0]
|
||||
STORE_12x2 d10[0]
|
||||
STORE_12x2 d12[0]
|
||||
STORE_12x2 d14[0]
|
||||
STORE_12x2 d16[0]
|
||||
STORE_12x2 d18[0]
|
||||
STORE_12x2 d20[0]
|
||||
STORE_12x2 d22[0]
|
||||
STORE_12x2 d24[0]
|
||||
STORE_12x2 d26[0]
|
||||
STORE_12x2 d28[0]
|
||||
STORE_12x2 d30[0]
|
||||
b WriteEnd
|
||||
WriteRowx2:
|
||||
STORE_C2 d8[0], #1
|
||||
STORE_C2 d10[0], #2
|
||||
STORE_C2 d12[0], #3
|
||||
STORE_C2 d14[0], #4
|
||||
STORE_C2 d16[0], #5
|
||||
STORE_C2 d18[0], #6
|
||||
STORE_C2 d20[0], #7
|
||||
STORE_C2 d22[0], #8
|
||||
STORE_C2 d24[0], #9
|
||||
STORE_C2 d26[0], #10
|
||||
STORE_C2 d28[0], #11
|
||||
STORE_C2 d30[0], #12
|
||||
b WriteEnd
|
||||
|
||||
Write12x3:
|
||||
STORE_12x3 d8[0], d8[2]
|
||||
STORE_12x3 d10[0], d10[2]
|
||||
STORE_12x3 d12[0], d12[2]
|
||||
STORE_12x3 d14[0], d14[2]
|
||||
STORE_12x3 d16[0], d16[2]
|
||||
STORE_12x3 d18[0], d18[2]
|
||||
STORE_12x3 d20[0], d20[2]
|
||||
STORE_12x3 d22[0], d22[2]
|
||||
STORE_12x3 d24[0], d24[2]
|
||||
STORE_12x3 d26[0], d26[2]
|
||||
STORE_12x3 d28[0], d28[2]
|
||||
STORE_12x3 d30[0], d30[2]
|
||||
b WriteEnd
|
||||
WriteRowx3:
|
||||
STORE_C3 d8[0], d8[2], #1
|
||||
STORE_C3 d10[0], d10[2], #2
|
||||
STORE_C3 d12[0], d12[2], #3
|
||||
STORE_C3 d14[0], d14[2], #4
|
||||
STORE_C3 d16[0], d16[2], #5
|
||||
STORE_C3 d18[0], d18[2], #6
|
||||
STORE_C3 d20[0], d20[2], #7
|
||||
STORE_C3 d22[0], d22[2], #8
|
||||
STORE_C3 d24[0], d24[2], #9
|
||||
STORE_C3 d26[0], d26[2], #10
|
||||
STORE_C3 d28[0], d28[2], #11
|
||||
STORE_C3 d30[0], d30[2], #12
|
||||
b WriteEnd
|
||||
|
||||
Write12x4:
|
||||
STORE_12x4 d8
|
||||
STORE_12x4 d10
|
||||
STORE_12x4 d12
|
||||
STORE_12x4 d14
|
||||
STORE_12x4 d16
|
||||
STORE_12x4 d18
|
||||
STORE_12x4 d20
|
||||
STORE_12x4 d22
|
||||
STORE_12x4 d24
|
||||
STORE_12x4 d26
|
||||
STORE_12x4 d28
|
||||
STORE_12x4 d30
|
||||
b WriteEnd
|
||||
WriteRowx4:
|
||||
STORE_C4 d8, #1
|
||||
STORE_C4 d10, #2
|
||||
STORE_C4 d12, #3
|
||||
STORE_C4 d14, #4
|
||||
STORE_C4 d16, #5
|
||||
STORE_C4 d18, #6
|
||||
STORE_C4 d20, #7
|
||||
STORE_C4 d22, #8
|
||||
STORE_C4 d24, #9
|
||||
STORE_C4 d26, #10
|
||||
STORE_C4 d28, #11
|
||||
STORE_C4 d30, #12
|
||||
b WriteEnd
|
||||
|
||||
Write12x5:
|
||||
STORE_12x5 d8, d9[0]
|
||||
STORE_12x5 d10, d11[0]
|
||||
STORE_12x5 d12, d13[0]
|
||||
STORE_12x5 d14, d15[0]
|
||||
STORE_12x5 d16, d17[0]
|
||||
STORE_12x5 d18, d19[0]
|
||||
STORE_12x5 d20, d21[0]
|
||||
STORE_12x5 d22, d23[0]
|
||||
STORE_12x5 d24, d25[0]
|
||||
STORE_12x5 d26, d27[0]
|
||||
STORE_12x5 d28, d29[0]
|
||||
STORE_12x5 d30, d31[0]
|
||||
b WriteEnd
|
||||
WriteRowx5:
|
||||
STORE_C5 d8, d9[0], #1
|
||||
STORE_C5 d10, d11[0], #2
|
||||
STORE_C5 d12, d13[0], #3
|
||||
STORE_C5 d14, d15[0], #4
|
||||
STORE_C5 d16, d17[0], #5
|
||||
STORE_C5 d18, d19[0], #6
|
||||
STORE_C5 d20, d21[0], #7
|
||||
STORE_C5 d22, d23[0], #8
|
||||
STORE_C5 d24, d25[0], #9
|
||||
STORE_C5 d26, d27[0], #10
|
||||
STORE_C5 d28, d29[0], #11
|
||||
STORE_C5 d30, d31[0], #12
|
||||
b WriteEnd
|
||||
|
||||
Write12x6:
|
||||
STORE_12x6 d8, d9[0]
|
||||
STORE_12x6 d10, d11[0]
|
||||
STORE_12x6 d12, d13[0]
|
||||
STORE_12x6 d14, d15[0]
|
||||
STORE_12x6 d16, d17[0]
|
||||
STORE_12x6 d18, d19[0]
|
||||
STORE_12x6 d20, d21[0]
|
||||
STORE_12x6 d22, d23[0]
|
||||
STORE_12x6 d24, d25[0]
|
||||
STORE_12x6 d26, d27[0]
|
||||
STORE_12x6 d28, d29[0]
|
||||
STORE_12x6 d30, d31[0]
|
||||
b WriteEnd
|
||||
WriteRowx6:
|
||||
STORE_C6 d8, d9[0], #1
|
||||
STORE_C6 d10, d11[0], #2
|
||||
STORE_C6 d12, d13[0], #3
|
||||
STORE_C6 d14, d15[0], #4
|
||||
STORE_C6 d16, d17[0], #5
|
||||
STORE_C6 d18, d19[0], #6
|
||||
STORE_C6 d20, d21[0], #7
|
||||
STORE_C6 d22, d23[0], #8
|
||||
STORE_C6 d24, d25[0], #9
|
||||
STORE_C6 d26, d27[0], #10
|
||||
STORE_C6 d28, d29[0], #11
|
||||
STORE_C6 d30, d31[0], #12
|
||||
b WriteEnd
|
||||
|
||||
Write12x7:
|
||||
STORE_12x7 d8, d9[0], d9[2]
|
||||
STORE_12x7 d10, d11[0], d11[2]
|
||||
STORE_12x7 d12, d13[0], d13[2]
|
||||
STORE_12x7 d14, d15[0], d15[2]
|
||||
STORE_12x7 d16, d17[0], d17[2]
|
||||
STORE_12x7 d18, d19[0], d19[2]
|
||||
STORE_12x7 d20, d21[0], d21[2]
|
||||
STORE_12x7 d22, d23[0], d23[2]
|
||||
STORE_12x7 d24, d25[0], d25[2]
|
||||
STORE_12x7 d26, d27[0], d27[2]
|
||||
STORE_12x7 d28, d29[0], d29[2]
|
||||
STORE_12x7 d30, d31[0], d31[2]
|
||||
b WriteEnd
|
||||
WriteRowx7:
|
||||
STORE_C7 d8, d9[0], d9[2], #1
|
||||
STORE_C7 d10, d11[0], d11[2], #2
|
||||
STORE_C7 d12, d13[0], d13[2], #3
|
||||
STORE_C7 d14, d15[0], d15[2], #4
|
||||
STORE_C7 d16, d17[0], d17[2], #5
|
||||
STORE_C7 d18, d19[0], d19[2], #6
|
||||
STORE_C7 d20, d21[0], d21[2], #7
|
||||
STORE_C7 d22, d23[0], d23[2], #8
|
||||
STORE_C7 d24, d25[0], d25[2], #9
|
||||
STORE_C7 d26, d27[0], d27[2], #10
|
||||
STORE_C7 d28, d29[0], d29[2], #11
|
||||
STORE_C7 d30, d31[0], d31[2], #12
|
||||
b WriteEnd
|
||||
|
||||
WriteEnd:
|
||||
cmp col, #8
|
||||
ble LoopColEnd
|
||||
sub col, col, #8
|
||||
ldr lr, [sp, #20]
|
||||
cmp lr, #2
|
||||
beq LoopCol8
|
||||
add dst_tmp, dst_tmp, #16
|
||||
b LoopCol8
|
||||
LoopColEnd:
|
||||
cmp row, #12
|
||||
ble LoopRowEnd
|
||||
sub row, row, #12
|
||||
mov a_tmp, a
|
||||
mov weight, b_tmp
|
||||
ldr lr, [sp, #20]
|
||||
cmp lr, #2
|
||||
beq WinogradDst
|
||||
ldr lr, [sp, #12]
|
||||
sub lr, lr, col
|
||||
add lr, lr, lr // col *= 2
|
||||
sub dst_tmp, dst, lr
|
||||
b LoopRow
|
||||
WinogradDst:
|
||||
add dst_tmp, dst, r9
|
||||
LoopRow:
|
||||
mov dst, dst_tmp
|
||||
ldr col, [sp, #12]
|
||||
b LoopRow12
|
||||
// Epilogue: move sp back down by 104 bytes (40 bytes of core registers plus
// 64 bytes of q4-q7) so that it points at the saved registers again, then
// restore them. Popping into pc returns to the caller.
LoopRowEnd:
|
||||
sub sp, sp, #104
|
||||
vpop {q4-q7}
|
||||
pop {r3-r11, pc}
|
||||
#endif
|
|
@ -1,667 +0,0 @@
|
|||
#ifdef ENABLE_ARM64
|
||||
#include "nnacl/assembly_global.h"
|
||||
|
||||
.text
|
||||
.align 5
|
||||
|
||||
// void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias,
|
||||
// size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6);
|
||||
// x0: output, x1: input, x2: weight, x3: bias, x4: step, x5: ic4, x6: oc8, x7: offset,
|
||||
// x8:mode, x9: writeC4, x10:relu, x11: relu6
|
||||
// compute 8 channel for 16 outputs
|
||||
asm_function IndirectGemmFp16_16x8
|
||||
|
||||
// INIT_BIAS: initialise the 16 accumulators v16-v31.
// v16 is zeroed; if x3 (bias) is non-zero, 8 halfwords are loaded from it
// instead. v16 is then copied into v17-v31.
// NOTE(review): the macro defines the plain label InitBias, so expanding it
// more than once in a file would presumably produce a duplicate-label error.
.macro INIT_BIAS
|
||||
dup v16.4s, wzr
|
||||
cbz x3, InitBias
|
||||
ld1 {v16.8h}, [x3]
|
||||
InitBias:
|
||||
mov v17.16b, v16.16b
|
||||
mov v18.16b, v16.16b
|
||||
mov v19.16b, v16.16b
|
||||
mov v20.16b, v16.16b
|
||||
mov v21.16b, v16.16b
|
||||
mov v22.16b, v16.16b
|
||||
mov v23.16b, v16.16b
|
||||
mov v24.16b, v16.16b
|
||||
mov v25.16b, v16.16b
|
||||
mov v26.16b, v16.16b
|
||||
mov v27.16b, v16.16b
|
||||
mov v28.16b, v16.16b
|
||||
mov v29.16b, v16.16b
|
||||
mov v30.16b, v16.16b
|
||||
mov v31.16b, v16.16b
|
||||
.endm
|
||||
|
||||
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
|
||||
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
|
||||
// x19 ~ x29 should be also preserved
|
||||
// whereas our coding style do not permit such amount of parameters
|
||||
sub sp, sp, #144
|
||||
// performance between storing 4 registers at the same time and separately storing them on in-order cores
|
||||
// is not tested yet
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
stp x19, x20, [sp], #16
|
||||
|
||||
ldr x8, [sp, #0]
|
||||
ldr x9, [sp, #8]
|
||||
ldr x10, [sp, #16]
|
||||
ldr x11, [sp, #24]
|
||||
|
||||
cbnz x8, IndirectGemmStart
|
||||
// step is one for common convolution, where ic8 should multiply by kernel size
|
||||
// step is (a+b-1) for F(a,b) in winograd
|
||||
mul x5, x4, x5
|
||||
mov x4, #1
|
||||
|
||||
IndirectGemmStart:
|
||||
|
||||
LoopOc:
|
||||
|
||||
mov x14, x4
|
||||
mov x12, x1
|
||||
|
||||
LoopKsize:
|
||||
|
||||
mov x15, x0
|
||||
INIT_BIAS
|
||||
// load input for output 1-8
|
||||
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64
|
||||
// load weight
|
||||
ld1 {v8.8h, v9.8h}, [x2], #32
|
||||
// first 2 steps for output 1 and 3
|
||||
fmla v16.8h, v8.8h, v0.h[0]
|
||||
fmla v18.8h, v8.8h, v1.h[0]
|
||||
fmla v16.8h, v9.8h, v0.h[1]
|
||||
fmla v18.8h, v9.8h, v1.h[1]
|
||||
// load weight
|
||||
ld1 {v10.8h, v11.8h}, [x2], #32
|
||||
// first 2 steps for output 2 and 4
|
||||
fmla v17.8h, v8.8h, v0.h[4]
|
||||
fmla v19.8h, v8.8h, v1.h[4]
|
||||
fmla v17.8h, v9.8h, v0.h[5]
|
||||
fmla v19.8h, v9.8h, v1.h[5]
|
||||
// load input for output 9-16
|
||||
// input cache should be refreshed after loading
|
||||
// ATTENTION: advance is preferred, but advancing too much may lead to invalid prefetching
|
||||
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64
|
||||
// last 2 steps for output 1 and 3
|
||||
fmla v16.8h, v10.8h, v0.h[2]
|
||||
fmla v18.8h, v10.8h, v1.h[2]
|
||||
fmla v16.8h, v11.8h, v0.h[3]
|
||||
fmla v18.8h, v11.8h, v1.h[3]
|
||||
|
||||
// check if ic4=1
|
||||
subs x13, x5, #1
|
||||
beq LoopIcEnd
|
||||
|
||||
LoopIc:
|
||||
// last 2 steps for output 2 and 4
|
||||
fmla v17.8h, v10.8h, v0.h[6]
|
||||
fmla v19.8h, v10.8h, v1.h[6]
|
||||
fmla v17.8h, v11.8h, v0.h[7]
|
||||
fmla v19.8h, v11.8h, v1.h[7]
|
||||
// steps for output 5-8
|
||||
fmla v20.8h, v8.8h, v2.h[0]
|
||||
fmla v22.8h, v8.8h, v3.h[0]
|
||||
fmla v20.8h, v9.8h, v2.h[1]
|
||||
fmla v22.8h, v9.8h, v3.h[1]
|
||||
fmla v21.8h, v8.8h, v2.h[4]
|
||||
fmla v23.8h, v8.8h, v3.h[4]
|
||||
fmla v21.8h, v9.8h, v2.h[5]
|
||||
fmla v23.8h, v9.8h, v3.h[5]
|
||||
fmla v20.8h, v10.8h, v2.h[2]
|
||||
fmla v22.8h, v10.8h, v3.h[2]
|
||||
fmla v20.8h, v11.8h, v2.h[3]
|
||||
fmla v22.8h, v11.8h, v3.h[3]
|
||||
fmla v21.8h, v10.8h, v2.h[6]
|
||||
fmla v23.8h, v10.8h, v3.h[6]
|
||||
fmla v21.8h, v11.8h, v2.h[7]
|
||||
fmla v23.8h, v11.8h, v3.h[7]
|
||||
// load input for output 1-8
|
||||
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64
|
||||
// steps for output 9-12
|
||||
fmla v24.8h, v8.8h, v4.h[0]
|
||||
fmla v26.8h, v8.8h, v5.h[0]
|
||||
fmla v24.8h, v9.8h, v4.h[1]
|
||||
fmla v26.8h, v9.8h, v5.h[1]
|
||||
fmla v25.8h, v8.8h, v4.h[4]
|
||||
fmla v27.8h, v8.8h, v5.h[4]
|
||||
fmla v25.8h, v9.8h, v4.h[5]
|
||||
fmla v27.8h, v9.8h, v5.h[5]
|
||||
fmla v24.8h, v10.8h, v4.h[2]
|
||||
fmla v26.8h, v10.8h, v5.h[2]
|
||||
fmla v24.8h, v11.8h, v4.h[3]
|
||||
fmla v26.8h, v11.8h, v5.h[3]
|
||||
fmla v25.8h, v10.8h, v4.h[6]
|
||||
fmla v27.8h, v10.8h, v5.h[6]
|
||||
fmla v25.8h, v11.8h, v4.h[7]
|
||||
fmla v27.8h, v11.8h, v5.h[7]
|
||||
// steps for output 13-16
|
||||
fmla v28.8h, v8.8h, v6.h[0]
|
||||
fmla v30.8h, v8.8h, v7.h[0]
|
||||
fmla v28.8h, v9.8h, v6.h[1]
|
||||
fmla v30.8h, v9.8h, v7.h[1]
|
||||
fmla v29.8h, v8.8h, v6.h[4]
|
||||
fmla v31.8h, v8.8h, v7.h[4]
|
||||
fmla v29.8h, v9.8h, v6.h[5]
|
||||
fmla v31.8h, v9.8h, v7.h[5]
|
||||
// load weight
|
||||
ld1 {v8.8h, v9.8h}, [x2], #32
|
||||
fmla v28.8h, v10.8h, v6.h[2]
|
||||
fmla v30.8h, v10.8h, v7.h[2]
|
||||
fmla v28.8h, v11.8h, v6.h[3]
|
||||
fmla v30.8h, v11.8h, v7.h[3]
|
||||
fmla v29.8h, v10.8h, v6.h[6]
|
||||
fmla v31.8h, v10.8h, v7.h[6]
|
||||
fmla v29.8h, v11.8h, v6.h[7]
|
||||
fmla v31.8h, v11.8h, v7.h[7]
|
||||
// load weight
|
||||
ld1 {v10.8h, v11.8h}, [x2], #32
|
||||
// first 2 steps for output 1-4
|
||||
fmla v16.8h, v8.8h, v0.h[0]
|
||||
fmla v18.8h, v8.8h, v1.h[0]
|
||||
fmla v16.8h, v9.8h, v0.h[1]
|
||||
fmla v18.8h, v9.8h, v1.h[1]
|
||||
fmla v17.8h, v8.8h, v0.h[4]
|
||||
fmla v19.8h, v8.8h, v1.h[4]
|
||||
fmla v17.8h, v9.8h, v0.h[5]
|
||||
fmla v19.8h, v9.8h, v1.h[5]
|
||||
// load input for output 9-16
|
||||
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64
|
||||
// last 2 steps for output 1 and 3
|
||||
fmla v16.8h, v10.8h, v0.h[2]
|
||||
fmla v18.8h, v10.8h, v1.h[2]
|
||||
fmla v16.8h, v11.8h, v0.h[3]
|
||||
fmla v18.8h, v11.8h, v1.h[3]
|
||||
|
||||
subs x13, x13, #1
|
||||
bne LoopIc
|
||||
|
||||
LoopIcEnd:
|
||||
fmla v17.8h, v10.8h, v0.h[6]
|
||||
fmla v19.8h, v10.8h, v1.h[6]
|
||||
fmla v17.8h, v11.8h, v0.h[7]
|
||||
fmla v19.8h, v11.8h, v1.h[7]
|
||||
// steps for output 5-8
|
||||
fmla v20.8h, v8.8h, v2.h[0]
|
||||
fmla v22.8h, v8.8h, v3.h[0]
|
||||
fmla v20.8h, v9.8h, v2.h[1]
|
||||
fmla v22.8h, v9.8h, v3.h[1]
|
||||
fmla v21.8h, v8.8h, v2.h[4]
|
||||
fmla v23.8h, v8.8h, v3.h[4]
|
||||
fmla v21.8h, v9.8h, v2.h[5]
|
||||
fmla v23.8h, v9.8h, v3.h[5]
|
||||
fmla v20.8h, v10.8h, v2.h[2]
|
||||
fmla v22.8h, v10.8h, v3.h[2]
|
||||
fmla v20.8h, v11.8h, v2.h[3]
|
||||
fmla v22.8h, v11.8h, v3.h[3]
|
||||
fmla v21.8h, v10.8h, v2.h[6]
|
||||
fmla v23.8h, v10.8h, v3.h[6]
|
||||
fmla v21.8h, v11.8h, v2.h[7]
|
||||
fmla v23.8h, v11.8h, v3.h[7]
|
||||
// steps for output 9-12
|
||||
fmla v24.8h, v8.8h, v4.h[0]
|
||||
fmla v26.8h, v8.8h, v5.h[0]
|
||||
fmla v24.8h, v9.8h, v4.h[1]
|
||||
fmla v26.8h, v9.8h, v5.h[1]
|
||||
fmla v25.8h, v8.8h, v4.h[4]
|
||||
fmla v27.8h, v8.8h, v5.h[4]
|
||||
fmla v25.8h, v9.8h, v4.h[5]
|
||||
fmla v27.8h, v9.8h, v5.h[5]
|
||||
fmla v24.8h, v10.8h, v4.h[2]
|
||||
fmla v26.8h, v10.8h, v5.h[2]
|
||||
fmla v24.8h, v11.8h, v4.h[3]
|
||||
fmla v26.8h, v11.8h, v5.h[3]
|
||||
fmla v25.8h, v10.8h, v4.h[6]
|
||||
fmla v27.8h, v10.8h, v5.h[6]
|
||||
fmla v25.8h, v11.8h, v4.h[7]
|
||||
fmla v27.8h, v11.8h, v5.h[7]
|
||||
// steps for output 13-16
|
||||
fmla v28.8h, v8.8h, v6.h[0]
|
||||
fmla v30.8h, v8.8h, v7.h[0]
|
||||
fmla v28.8h, v9.8h, v6.h[1]
|
||||
fmla v30.8h, v9.8h, v7.h[1]
|
||||
fmla v29.8h, v8.8h, v6.h[4]
|
||||
fmla v31.8h, v8.8h, v7.h[4]
|
||||
fmla v29.8h, v9.8h, v6.h[5]
|
||||
fmla v31.8h, v9.8h, v7.h[5]
|
||||
fmla v28.8h, v10.8h, v6.h[2]
|
||||
fmla v30.8h, v10.8h, v7.h[2]
|
||||
fmla v28.8h, v11.8h, v6.h[3]
|
||||
fmla v30.8h, v11.8h, v7.h[3]
|
||||
fmla v29.8h, v10.8h, v6.h[6]
|
||||
fmla v31.8h, v10.8h, v7.h[6]
|
||||
fmla v29.8h, v11.8h, v6.h[7]
|
||||
fmla v31.8h, v11.8h, v7.h[7]
|
||||
|
||||
cbnz x11, Relu6
|
||||
cbnz x10, Relu
|
||||
b WriteStart
|
||||
Relu6:
|
||||
movi v9.8h, #0x46, lsl #8
|
||||
fmin v16.8h, v16.8h, v9.8h
|
||||
fmin v17.8h, v17.8h, v9.8h
|
||||
fmin v18.8h, v18.8h, v9.8h
|
||||
fmin v19.8h, v19.8h, v9.8h
|
||||
fmin v20.8h, v20.8h, v9.8h
|
||||
fmin v21.8h, v21.8h, v9.8h
|
||||
fmin v22.8h, v22.8h, v9.8h
|
||||
fmin v23.8h, v23.8h, v9.8h
|
||||
fmin v24.8h, v24.8h, v9.8h
|
||||
fmin v25.8h, v25.8h, v9.8h
|
||||
fmin v26.8h, v26.8h, v9.8h
|
||||
fmin v27.8h, v27.8h, v9.8h
|
||||
fmin v28.8h, v28.8h, v9.8h
|
||||
fmin v29.8h, v29.8h, v9.8h
|
||||
fmin v30.8h, v30.8h, v9.8h
|
||||
fmin v31.8h, v31.8h, v9.8h
|
||||
Relu:
|
||||
dup v8.4s, wzr
|
||||
fmax v16.8h, v16.8h, v8.8h
|
||||
fmax v17.8h, v17.8h, v8.8h
|
||||
fmax v18.8h, v18.8h, v8.8h
|
||||
fmax v19.8h, v19.8h, v8.8h
|
||||
fmax v20.8h, v20.8h, v8.8h
|
||||
fmax v21.8h, v21.8h, v8.8h
|
||||
fmax v22.8h, v22.8h, v8.8h
|
||||
fmax v23.8h, v23.8h, v8.8h
|
||||
fmax v24.8h, v24.8h, v8.8h
|
||||
fmax v25.8h, v25.8h, v8.8h
|
||||
fmax v26.8h, v26.8h, v8.8h
|
||||
fmax v27.8h, v27.8h, v8.8h
|
||||
fmax v28.8h, v28.8h, v8.8h
|
||||
fmax v29.8h, v29.8h, v8.8h
|
||||
fmax v30.8h, v30.8h, v8.8h
|
||||
fmax v31.8h, v31.8h, v8.8h
|
||||
|
||||
WriteStart:
|
||||
cbnz x9, Write8
|
||||
cmp x6, #1
|
||||
beq Write1
|
||||
cmp x6, #2
|
||||
beq Write2
|
||||
cmp x6, #3
|
||||
beq Write3
|
||||
cmp x6, #4
|
||||
beq Write4
|
||||
cmp x6, #5
|
||||
beq Write5
|
||||
cmp x6, #6
|
||||
beq Write6
|
||||
cmp x6, #7
|
||||
beq Write7
|
||||
b Write8
|
||||
// prefetching is not preferred while writing results in spite of cache missing
|
||||
// you could try prfm pstl2strm
|
||||
// there are almost no benefits observed though
|
||||
Write1:
|
||||
str h16, [x15]
|
||||
add x15, x15, x7
|
||||
str h17, [x15]
|
||||
add x15, x15, x7
|
||||
str h18, [x15]
|
||||
add x15, x15, x7
|
||||
str h19, [x15]
|
||||
add x15, x15, x7
|
||||
str h20, [x15]
|
||||
add x15, x15, x7
|
||||
str h21, [x15]
|
||||
add x15, x15, x7
|
||||
str h22, [x15]
|
||||
add x15, x15, x7
|
||||
str h23, [x15]
|
||||
add x15, x15, x7
|
||||
str h24, [x15]
|
||||
add x15, x15, x7
|
||||
str h25, [x15]
|
||||
add x15, x15, x7
|
||||
str h26, [x15]
|
||||
add x15, x15, x7
|
||||
str h27, [x15]
|
||||
add x15, x15, x7
|
||||
str h28, [x15]
|
||||
add x15, x15, x7
|
||||
str h29, [x15]
|
||||
add x15, x15, x7
|
||||
str h30, [x15]
|
||||
add x15, x15, x7
|
||||
str h31, [x15]
|
||||
add x0, x0, #2
|
||||
b WriteEnd
|
||||
Write2:
|
||||
add x17, x15, #2
|
||||
st1 {v16.h}[0], [x15], x7
|
||||
st1 {v16.h}[1], [x17], x7
|
||||
st1 {v17.h}[0], [x15], x7
|
||||
st1 {v17.h}[1], [x17], x7
|
||||
st1 {v18.h}[0], [x15], x7
|
||||
st1 {v18.h}[1], [x17], x7
|
||||
st1 {v19.h}[0], [x15], x7
|
||||
st1 {v19.h}[1], [x17], x7
|
||||
st1 {v20.h}[0], [x15], x7
|
||||
st1 {v20.h}[1], [x17], x7
|
||||
st1 {v21.h}[0], [x15], x7
|
||||
st1 {v21.h}[1], [x17], x7
|
||||
st1 {v22.h}[0], [x15], x7
|
||||
st1 {v22.h}[1], [x17], x7
|
||||
st1 {v23.h}[0], [x15], x7
|
||||
st1 {v23.h}[1], [x17], x7
|
||||
st1 {v24.h}[0], [x15], x7
|
||||
st1 {v24.h}[1], [x17], x7
|
||||
st1 {v25.h}[0], [x15], x7
|
||||
st1 {v25.h}[1], [x17], x7
|
||||
st1 {v26.h}[0], [x15], x7
|
||||
st1 {v26.h}[1], [x17], x7
|
||||
st1 {v27.h}[0], [x15], x7
|
||||
st1 {v27.h}[1], [x17], x7
|
||||
st1 {v28.h}[0], [x15], x7
|
||||
st1 {v28.h}[1], [x17], x7
|
||||
st1 {v29.h}[0], [x15], x7
|
||||
st1 {v29.h}[1], [x17], x7
|
||||
st1 {v30.h}[0], [x15], x7
|
||||
st1 {v30.h}[1], [x17], x7
|
||||
st1 {v31.h}[0], [x15]
|
||||
st1 {v31.h}[1], [x17]
|
||||
add x0, x0, #4
|
||||
b WriteEnd
|
||||
Write3:
|
||||
add x17, x15, #4
|
||||
add x16, x15, #2
|
||||
st1 {v16.h}[0], [x15], x7
|
||||
st1 {v16.h}[1], [x16], x7
|
||||
st1 {v16.h}[2], [x17], x7
|
||||
st1 {v17.h}[0], [x15], x7
|
||||
st1 {v17.h}[1], [x16], x7
|
||||
st1 {v17.h}[2], [x17], x7
|
||||
st1 {v18.h}[0], [x15], x7
|
||||
st1 {v18.h}[1], [x16], x7
|
||||
st1 {v18.h}[2], [x17], x7
|
||||
st1 {v19.h}[0], [x15], x7
|
||||
st1 {v19.h}[1], [x16], x7
|
||||
st1 {v19.h}[2], [x17], x7
|
||||
st1 {v20.h}[0], [x15], x7
|
||||
st1 {v20.h}[1], [x16], x7
|
||||
st1 {v20.h}[2], [x17], x7
|
||||
st1 {v21.h}[0], [x15], x7
|
||||
st1 {v21.h}[1], [x16], x7
|
||||
st1 {v21.h}[2], [x17], x7
|
||||
st1 {v22.h}[0], [x15], x7
|
||||
st1 {v22.h}[1], [x16], x7
|
||||
st1 {v22.h}[2], [x17], x7
|
||||
st1 {v23.h}[0], [x15], x7
|
||||
st1 {v23.h}[1], [x16], x7
|
||||
st1 {v23.h}[2], [x17], x7
|
||||
st1 {v24.h}[0], [x15], x7
|
||||
st1 {v24.h}[1], [x16], x7
|
||||
st1 {v24.h}[2], [x17], x7
|
||||
st1 {v25.h}[0], [x15], x7
|
||||
st1 {v25.h}[1], [x16], x7
|
||||
st1 {v25.h}[2], [x17], x7
|
||||
st1 {v26.h}[0], [x15], x7
|
||||
st1 {v26.h}[1], [x16], x7
|
||||
st1 {v26.h}[2], [x17], x7
|
||||
st1 {v27.h}[0], [x15], x7
|
||||
st1 {v27.h}[1], [x16], x7
|
||||
st1 {v27.h}[2], [x17], x7
|
||||
st1 {v28.h}[0], [x15], x7
|
||||
st1 {v28.h}[1], [x16], x7
|
||||
st1 {v28.h}[2], [x17], x7
|
||||
st1 {v29.h}[0], [x15], x7
|
||||
st1 {v29.h}[1], [x16], x7
|
||||
st1 {v29.h}[2], [x17], x7
|
||||
st1 {v30.h}[0], [x15], x7
|
||||
st1 {v30.h}[1], [x16], x7
|
||||
st1 {v30.h}[2], [x17], x7
|
||||
st1 {v31.h}[0], [x15]
|
||||
st1 {v31.h}[1], [x16]
|
||||
st1 {v31.h}[2], [x17]
|
||||
add x0, x0, #6
|
||||
b WriteEnd
|
||||
Write4:
|
||||
st1 {v16.4h}, [x15], x7
|
||||
st1 {v17.4h}, [x15], x7
|
||||
st1 {v18.4h}, [x15], x7
|
||||
st1 {v19.4h}, [x15], x7
|
||||
st1 {v20.4h}, [x15], x7
|
||||
st1 {v21.4h}, [x15], x7
|
||||
st1 {v22.4h}, [x15], x7
|
||||
st1 {v23.4h}, [x15], x7
|
||||
st1 {v24.4h}, [x15], x7
|
||||
st1 {v25.4h}, [x15], x7
|
||||
st1 {v26.4h}, [x15], x7
|
||||
st1 {v27.4h}, [x15], x7
|
||||
st1 {v28.4h}, [x15], x7
|
||||
st1 {v29.4h}, [x15], x7
|
||||
st1 {v30.4h}, [x15], x7
|
||||
st1 {v31.4h}, [x15]
|
||||
add x0, x0, #8
|
||||
b WriteEnd
|
||||
Write5:
|
||||
add x17, x15, #8
|
||||
st1 {v16.4h}, [x15], x7
|
||||
st1 {v16.h}[4], [x17], x7
|
||||
st1 {v17.4h}, [x15], x7
|
||||
st1 {v17.h}[4], [x17], x7
|
||||
st1 {v18.4h}, [x15], x7
|
||||
st1 {v18.h}[4], [x17], x7
|
||||
st1 {v19.4h}, [x15], x7
|
||||
st1 {v19.h}[4], [x17], x7
|
||||
st1 {v20.4h}, [x15], x7
|
||||
st1 {v20.h}[4], [x17], x7
|
||||
st1 {v21.4h}, [x15], x7
|
||||
st1 {v21.h}[4], [x17], x7
|
||||
st1 {v22.4h}, [x15], x7
|
||||
st1 {v22.h}[4], [x17], x7
|
||||
st1 {v23.4h}, [x15], x7
|
||||
st1 {v23.h}[4], [x17], x7
|
||||
st1 {v24.4h}, [x15], x7
|
||||
st1 {v24.h}[4], [x17], x7
|
||||
st1 {v25.4h}, [x15], x7
|
||||
st1 {v25.h}[4], [x17], x7
|
||||
st1 {v26.4h}, [x15], x7
|
||||
st1 {v26.h}[4], [x17], x7
|
||||
st1 {v27.4h}, [x15], x7
|
||||
st1 {v27.h}[4], [x17], x7
|
||||
st1 {v28.4h}, [x15], x7
|
||||
st1 {v28.h}[4], [x17], x7
|
||||
st1 {v29.4h}, [x15], x7
|
||||
st1 {v29.h}[4], [x17], x7
|
||||
st1 {v30.4h}, [x15], x7
|
||||
st1 {v30.h}[4], [x17], x7
|
||||
st1 {v31.4h}, [x15]
|
||||
st1 {v31.h}[4], [x17]
|
||||
add x0, x0, #10
|
||||
b WriteEnd
|
||||
Write6:
|
||||
add x17, x15, #8
|
||||
add x16, x15, #10
|
||||
st1 {v16.4h}, [x15], x7
|
||||
ins v0.s[0], v16.s[2]
|
||||
st1 {v0.h}[0], [x17], x7
|
||||
st1 {v0.h}[1], [x16], x7
|
||||
st1 {v17.4h}, [x15], x7
|
||||
ins v1.s[0], v17.s[2]
|
||||
st1 {v1.h}[0], [x17], x7
|
||||
st1 {v1.h}[1], [x16], x7
|
||||
st1 {v18.4h}, [x15], x7
|
||||
ins v2.s[0], v18.s[2]
|
||||
st1 {v2.h}[0], [x17], x7
|
||||
st1 {v2.h}[1], [x16], x7
|
||||
st1 {v19.4h}, [x15], x7
|
||||
ins v3.s[0], v19.s[2]
|
||||
st1 {v3.h}[0], [x17], x7
|
||||
st1 {v3.h}[1], [x16], x7
|
||||
st1 {v20.4h}, [x15], x7
|
||||
ins v4.s[0], v20.s[2]
|
||||
st1 {v4.h}[0], [x17], x7
|
||||
st1 {v4.h}[1], [x16], x7
|
||||
st1 {v21.4h}, [x15], x7
|
||||
ins v5.s[0], v21.s[2]
|
||||
st1 {v5.h}[0], [x17], x7
|
||||
st1 {v5.h}[1], [x16], x7
|
||||
st1 {v22.4h}, [x15], x7
|
||||
ins v6.s[0], v22.s[2]
|
||||
st1 {v6.h}[0], [x17], x7
|
||||
st1 {v6.h}[1], [x16], x7
|
||||
st1 {v23.4h}, [x15], x7
|
||||
ins v7.s[0], v23.s[2]
|
||||
st1 {v7.h}[0], [x17], x7
|
||||
st1 {v7.h}[1], [x16], x7
|
||||
st1 {v24.4h}, [x15], x7
|
||||
ins v8.s[0], v24.s[2]
|
||||
st1 {v8.h}[0], [x17], x7
|
||||
st1 {v8.h}[1], [x16], x7
|
||||
st1 {v25.4h}, [x15], x7
|
||||
ins v9.s[0], v25.s[2]
|
||||
st1 {v9.h}[0], [x17], x7
|
||||
st1 {v9.h}[1], [x16], x7
|
||||
st1 {v26.4h}, [x15], x7
|
||||
ins v10.s[0], v26.s[2]
|
||||
st1 {v10.h}[0], [x17], x7
|
||||
st1 {v10.h}[1], [x16], x7
|
||||
st1 {v27.4h}, [x15], x7
|
||||
ins v11.s[0], v27.s[2]
|
||||
st1 {v11.h}[0], [x17], x7
|
||||
st1 {v11.h}[1], [x16], x7
|
||||
st1 {v28.4h}, [x15], x7
|
||||
ins v12.s[0], v28.s[2]
|
||||
st1 {v12.h}[0], [x17], x7
|
||||
st1 {v12.h}[1], [x16], x7
|
||||
st1 {v29.4h}, [x15], x7
|
||||
ins v13.s[0], v29.s[2]
|
||||
st1 {v13.h}[0], [x17], x7
|
||||
st1 {v13.h}[1], [x16], x7
|
||||
st1 {v30.4h}, [x15], x7
|
||||
ins v14.s[0], v30.s[2]
|
||||
st1 {v14.h}[0], [x17], x7
|
||||
st1 {v14.h}[1], [x16], x7
|
||||
st1 {v31.4h}, [x15]
|
||||
// Last row of the 6-column write: copy halfword lanes 4-5 of v31 into v15 and store
// them from v15. v14 still holds lanes 4-5 of the previous row (v30) and must not be reused.
ins v15.s[0], v31.s[2]
|
||||
st1 {v15.h}[0], [x17]
|
||||
st1 {v15.h}[1], [x16]
|
||||
add x0, x0, #12
|
||||
b WriteEnd
|
||||
Write7:
|
||||
add x17, x15, #8
|
||||
add x19, x15, #10
|
||||
add x16, x15, #12
|
||||
st1 {v16.4h}, [x15], x7
|
||||
ins v0.s[0], v16.s[2]
|
||||
st1 {v0.h}[0], [x17], x7
|
||||
st1 {v0.h}[1], [x19], x7
|
||||
st1 {v16.h}[6], [x16], x7
|
||||
st1 {v17.4h}, [x15], x7
|
||||
ins v1.s[0], v17.s[2]
|
||||
st1 {v1.h}[0], [x17], x7
|
||||
st1 {v1.h}[1], [x19], x7
|
||||
st1 {v17.h}[6], [x16], x7
|
||||
st1 {v18.4h}, [x15], x7
|
||||
ins v2.s[0], v18.s[2]
|
||||
st1 {v2.h}[0], [x17], x7
|
||||
st1 {v2.h}[1], [x19], x7
|
||||
st1 {v18.h}[6], [x16], x7
|
||||
st1 {v19.4h}, [x15], x7
|
||||
ins v3.s[0], v19.s[2]
|
||||
st1 {v3.h}[0], [x17], x7
|
||||
st1 {v3.h}[1], [x19], x7
|
||||
st1 {v19.h}[6], [x16], x7
|
||||
st1 {v20.4h}, [x15], x7
|
||||
ins v4.s[0], v20.s[2]
|
||||
st1 {v4.h}[0], [x17], x7
|
||||
st1 {v4.h}[1], [x19], x7
|
||||
st1 {v20.h}[6], [x16], x7
|
||||
st1 {v21.4h}, [x15], x7
|
||||
ins v5.s[0], v21.s[2]
|
||||
st1 {v5.h}[0], [x17], x7
|
||||
st1 {v5.h}[1], [x19], x7
|
||||
st1 {v21.h}[6], [x16], x7
|
||||
st1 {v22.4h}, [x15], x7
|
||||
ins v6.s[0], v22.s[2]
|
||||
st1 {v6.h}[0], [x17], x7
|
||||
st1 {v6.h}[1], [x19], x7
|
||||
st1 {v22.h}[6], [x16], x7
|
||||
st1 {v23.4h}, [x15], x7
|
||||
ins v7.s[0], v23.s[2]
|
||||
st1 {v7.h}[0], [x17], x7
|
||||
st1 {v7.h}[1], [x19], x7
|
||||
st1 {v23.h}[6], [x16], x7
|
||||
st1 {v24.4h}, [x15], x7
|
||||
ins v8.s[0], v24.s[2]
|
||||
st1 {v8.h}[0], [x17], x7
|
||||
st1 {v8.h}[1], [x19], x7
|
||||
st1 {v24.h}[6], [x16], x7
|
||||
st1 {v25.4h}, [x15], x7
|
||||
ins v9.s[0], v25.s[2]
|
||||
st1 {v9.h}[0], [x17], x7
|
||||
st1 {v9.h}[1], [x19], x7
|
||||
st1 {v25.h}[6], [x16], x7
|
||||
st1 {v26.4h}, [x15], x7
|
||||
ins v10.s[0], v26.s[2]
|
||||
st1 {v10.h}[0], [x17], x7
|
||||
st1 {v10.h}[1], [x19], x7
|
||||
st1 {v26.h}[6], [x16], x7
|
||||
st1 {v27.4h}, [x15], x7
|
||||
ins v11.s[0], v27.s[2]
|
||||
st1 {v11.h}[0], [x17], x7
|
||||
st1 {v11.h}[1], [x19], x7
|
||||
st1 {v27.h}[6], [x16], x7
|
||||
st1 {v28.4h}, [x15], x7
|
||||
ins v12.s[0], v28.s[2]
|
||||
st1 {v12.h}[0], [x17], x7
|
||||
st1 {v12.h}[1], [x19], x7
|
||||
st1 {v28.h}[6], [x16], x7
|
||||
st1 {v29.4h}, [x15], x7
|
||||
ins v13.s[0], v29.s[2]
|
||||
st1 {v13.h}[0], [x17], x7
|
||||
st1 {v13.h}[1], [x19], x7
|
||||
st1 {v29.h}[6], [x16], x7
|
||||
st1 {v30.4h}, [x15], x7
|
||||
ins v14.s[0], v30.s[2]
|
||||
st1 {v14.h}[0], [x17], x7
|
||||
st1 {v14.h}[1], [x19], x7
|
||||
st1 {v30.h}[6], [x16], x7
|
||||
st1 {v31.4h}, [x15]
|
||||
ins v15.s[0], v31.s[2]
|
||||
st1 {v15.h}[0], [x17]
|
||||
st1 {v15.h}[1], [x19]
|
||||
st1 {v31.h}[6], [x16]
|
||||
add x0, x0, #14
|
||||
b WriteEnd
|
||||
Write8:
|
||||
st1 {v16.8h}, [x15], x7
|
||||
st1 {v17.8h}, [x15], x7
|
||||
st1 {v18.8h}, [x15], x7
|
||||
st1 {v19.8h}, [x15], x7
|
||||
st1 {v20.8h}, [x15], x7
|
||||
st1 {v21.8h}, [x15], x7
|
||||
st1 {v22.8h}, [x15], x7
|
||||
st1 {v23.8h}, [x15], x7
|
||||
st1 {v24.8h}, [x15], x7
|
||||
st1 {v25.8h}, [x15], x7
|
||||
st1 {v26.8h}, [x15], x7
|
||||
st1 {v27.8h}, [x15], x7
|
||||
st1 {v28.8h}, [x15], x7
|
||||
st1 {v29.8h}, [x15], x7
|
||||
st1 {v30.8h}, [x15], x7
|
||||
st1 {v31.8h}, [x15]
|
||||
add x0, x0, #16
|
||||
|
||||
WriteEnd:
|
||||
subs x14, x14, #1
|
||||
bne LoopKsize
|
||||
|
||||
subs x6, x6, #8
|
||||
cbz x3, NoStepForward
|
||||
add x3, x3, #16
|
||||
NoStepForward:
|
||||
bgt LoopOc
|
||||
|
||||
sub sp, sp, #144
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
|
||||
ldp x19, x20, [sp], #16
|
||||
ret
|
||||
#endif
|
||||
|
|
@ -29,7 +29,7 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
|||
}
|
||||
#endif
|
||||
for (; offset < ele_num; offset++) {
|
||||
dst[offset] = src[offset] < 0 ? 0 : src[offset];
|
||||
dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
@ -47,14 +47,24 @@ int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
|
|||
}
|
||||
#endif
|
||||
for (; offset < ele_num; offset++) {
|
||||
dst[offset] = data[offset] < 0 ? 0 : data[offset];
|
||||
dst[offset] = dst[offset] > 6 ? 6 : dst[offset];
|
||||
dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset];
|
||||
dst[offset] = dst[offset] > 6.0f ? 6.0f : dst[offset];
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
int i = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
// Vectorize only the complete 8-lane groups. Rounding up would make the last iteration
// load and store past ele_num; the scalar loop below handles the remainder.
int ele_c8 = ele_num / C8NUM * C8NUM;
|
||||
for (; i < ele_c8; i += C8NUM) {
|
||||
float16x8_t src_tmp = vld1q_f16(src + i);
|
||||
float16x8_t mul_tmp = vmulq_n_f16(src_tmp, alpha);
|
||||
// Per-lane mask: all ones where src > 0, so the select keeps src there and alpha * src elsewhere.
uint16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f));
|
||||
vst1q_f16(dst + i, vbslq_f16(mask, src_tmp, mul_tmp));
|
||||
}
|
||||
#endif
|
||||
for (; i < ele_num; ++i) {
|
||||
dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha);
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -62,12 +72,12 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha
|
|||
|
||||
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
int i = 0;
|
||||
#ifdef ENABLE_ARM64
|
||||
#ifdef ENABLE_NEON
|
||||
int count = (ele_num / C4NUM) * C4NUM;
|
||||
for (; i < count; i += C4NUM) {
|
||||
float32x4_t tmp;
|
||||
simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp);
|
||||
vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp))));
|
||||
vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp))));
|
||||
}
|
||||
#endif
|
||||
for (; i < ele_num; ++i) {
|
||||
|
@ -79,9 +89,9 @@ int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
|||
}
|
||||
|
||||
float16_t TanhOptFp16(float16_t src) {
|
||||
if (src > 5.0) {
|
||||
if (src > 5.0f) {
|
||||
return 1.0f;
|
||||
} else if (src < -5.0) {
|
||||
} else if (src < -5.0f) {
|
||||
return -1.0f;
|
||||
} else {
|
||||
float square = src * src;
|
||||
|
@ -93,7 +103,7 @@ float16_t TanhOptFp16(float16_t src) {
|
|||
|
||||
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
int i = 0;
|
||||
#ifdef ENABLE_ARM64
|
||||
#ifdef ENABLE_NEON
|
||||
static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f},
|
||||
{17325.0f, 17325.0f, 17325.0f, 17325.0f},
|
||||
{135135.0f, 135135.0f, 135135.0f, 135135.0f},
|
||||
|
@ -112,7 +122,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
|||
float32x4_t b = vaddq_f32(
|
||||
vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
|
||||
paramv[2]);
|
||||
vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one)));
|
||||
vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one)));
|
||||
}
|
||||
#endif
|
||||
for (; i < ele_num; ++i) {
|
||||
|
@ -130,7 +140,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
|||
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
for (int i = 0; i < ele_num; ++i) {
|
||||
float16_t in = src[i];
|
||||
float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6);
|
||||
float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f);
|
||||
dst[i] = in * relu6 / (float16_t)6.0f;
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
@ -181,26 +191,26 @@ int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate)
|
|||
for (; i < C8; i += C8NUM) {
|
||||
float16x8_t in = vld1q_f16(src + i);
|
||||
float16x8_t res =
|
||||
0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in));
|
||||
0.5f * in * (1.0f + MS_TANHX8_F16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * in * in) * in));
|
||||
vst1q_f16(dst + i, res);
|
||||
}
|
||||
#endif
|
||||
for (; i < length; i++) {
|
||||
dst[i] =
|
||||
0.5 * src[i] *
|
||||
(1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i]));
|
||||
0.5f * src[i] *
|
||||
(1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i]));
|
||||
}
|
||||
} else {
|
||||
#ifdef ENABLE_NEON
|
||||
// Only full 8-lane groups go through the vector loop; rounding up would touch elements
// beyond length. The scalar loop after #endif processes the tail.
int C8 = length / C8NUM * C8NUM;
|
||||
for (; i < C8; i += C8NUM) {
|
||||
float16x8_t in = vld1q_f16(src + i);
|
||||
const float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f));
|
||||
const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f));
|
||||
vst1q_f16(dst + i, res);
|
||||
}
|
||||
#endif
|
||||
for (; i < length; i++) {
|
||||
dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
|
||||
dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f));
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
|
|
|
@ -16,11 +16,9 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/int8/fixed_point.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -569,7 +569,7 @@ int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *
|
|||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin0 = vld1q_f16(input0 + index);
|
||||
float16x8_t vin1 = vld1q_f16(input1 + index);
|
||||
float16x8_t vout = vdivq_f16(vin0, vin1);
|
||||
float16x8_t vout = MS_DIVQ_F16(vin0, vin1);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
|
@ -591,7 +591,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_
|
|||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin1 = vld1q_f16(input1 + index);
|
||||
float16x8_t vout = vdivq_f16(vin0_opt, vin1);
|
||||
float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
|
@ -606,7 +606,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_
|
|||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin0 = vld1q_f16(input0 + index);
|
||||
float16x8_t vout = vdivq_f16(vin0, vin1_opt);
|
||||
float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
|
@ -624,7 +624,7 @@ int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16
|
|||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin0 = vld1q_f16(input0 + index);
|
||||
float16x8_t vin1 = vld1q_f16(input1 + index);
|
||||
float16x8_t vout = vdivq_f16(vin0, vin1);
|
||||
float16x8_t vout = MS_DIVQ_F16(vin0, vin1);
|
||||
vout = vmaxq_f16(vout, zeros);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
|
@ -652,7 +652,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa
|
|||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin1 = vld1q_f16(input1 + index);
|
||||
float16x8_t vout = vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros);
|
||||
float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
|
@ -670,7 +670,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa
|
|||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin0 = vld1q_f16(input0 + index);
|
||||
float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros);
|
||||
float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
|
@ -689,7 +689,7 @@ int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float1
|
|||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin0 = vld1q_f16(input0 + index);
|
||||
float16x8_t vin1 = vld1q_f16(input1 + index);
|
||||
float16x8_t vout = vdivq_f16(vin0, vin1);
|
||||
float16x8_t vout = MS_DIVQ_F16(vin0, vin1);
|
||||
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
|
@ -716,7 +716,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo
|
|||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin1 = vld1q_f16(input1 + index);
|
||||
float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros), bounds);
|
||||
float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
|
@ -733,7 +733,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo
|
|||
#ifdef ENABLE_NEON
|
||||
for (; index <= element_size - 8; index += C8NUM) {
|
||||
float16x8_t vin0 = vld1q_f16(input0 + index);
|
||||
float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros), bounds);
|
||||
float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -16,10 +16,8 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/base/arithmetic_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
|
|
|
@ -80,7 +80,7 @@ int ElementLogicalNotFp16(float16_t *input, float16_t *output, int element_size)
|
|||
|
||||
int ElementRoundFp16(float16_t *input, float16_t *output, int element_size) {
|
||||
for (int i = 0; i < element_size; i++) {
|
||||
output[i] = round(input[i]);
|
||||
output[i] = roundf(input[i]);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size) {
|
|||
|
||||
int ElementCeilFp16(float16_t *input, float16_t *output, int number) {
|
||||
for (int i = 0; i < number; ++i) {
|
||||
output[i] = ceil(input[i]);
|
||||
output[i] = ceilf(input[i]);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -16,10 +16,8 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -26,7 +26,7 @@ void BatchNormFp16(const float16_t *input, const void *mean, const void *varianc
|
|||
|
||||
for (int i = 0; i < cur_unit; i++) {
|
||||
for (int c = 0; c < param->channel_; c++) {
|
||||
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
|
||||
float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_);
|
||||
if (variance_sqrt != 0) {
|
||||
output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
|
|||
|
||||
for (int i = 0; i < cur_unit; i++) {
|
||||
for (int c = 0; c < param->channel_; c++) {
|
||||
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
|
||||
float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_);
|
||||
if (variance_sqrt != 0) {
|
||||
float16_t norm_val =
|
||||
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
|
||||
|
|
|
@ -13,12 +13,9 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_
|
||||
#ifndef MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/batchnorm_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -34,4 +31,4 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
|
|||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_
|
||||
#endif // MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#ifndef MINDSPORE_NNACL_CAST_FP16_H_
|
||||
#define MINDSPORE_NNACL_CAST_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -56,3 +56,17 @@ void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const floa
|
|||
PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type);
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM82_A32
|
||||
void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod,
|
||||
size_t plane_size, size_t plane_stride, size_t relu_type) {
|
||||
// TODO(fun): function
|
||||
return;
|
||||
}
|
||||
|
||||
void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
|
||||
size_t plane_size, size_t stride, size_t relu_type) {
|
||||
// TODO(fun): function
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -16,9 +16,6 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/constant_of_shape_parameter.h"
|
||||
|
@ -27,7 +24,7 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_FP16
|
||||
inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) {
|
||||
for (int i = start; i < end; i++) {
|
||||
output[i] = value;
|
||||
|
|
|
@ -18,6 +18,18 @@
|
|||
#include <string.h>
|
||||
#include "nnacl/fp16/activation_fp16.h"
|
||||
|
||||
#ifdef ENABLE_ARM82_A32
|
||||
// C implementation of the depthwise-convolution row accumulate, compiled only when
// ENABLE_ARM82_A32 is defined.
// For each of num_pixels pixels it adds weight_ptr[c] * input_ptr[c] into consecutive
// output elements for c in [0, output_channel), then advances input_ptr by input_step
// elements. The output is accumulated into (+=), not overwritten, and output_ptr is
// advanced by one element per channel processed.
void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels,
|
||||
size_t output_channel, size_t input_step) {
|
||||
for (int i = 0; i < num_pixels; i++) {
|
||||
for (int c = 0; c < output_channel; c++) {
|
||||
*output_ptr++ += weight_ptr[c] * input_ptr[c];
|
||||
}
|
||||
|
||||
input_ptr += input_step;
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
||||
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
|
||||
const float16_t *bias_data, const ConvParameter *conv_param, int task_id) {
|
||||
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
|
||||
|
@ -57,7 +69,6 @@ void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float
|
|||
|
||||
const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_;
|
||||
int num_pixels = out_w_end - out_w_start;
|
||||
|
||||
ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
|
||||
weight_kh += conv_param->output_channel_;
|
||||
}
|
||||
|
|
|
@ -23,9 +23,9 @@
|
|||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#ifdef ENABLE_ARM64
|
||||
void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels,
|
||||
size_t input_channel, size_t input_step);
|
||||
#ifdef ENABLE_ARM64
|
||||
void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
|
||||
size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu,
|
||||
size_t relu6);
|
||||
|
|
|
@ -31,7 +31,7 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
|
|||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#ifndef ENABLE_NEON
|
||||
#ifndef ENABLE_ARM64
|
||||
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu,
|
||||
size_t relu6) {
|
||||
|
@ -124,7 +124,11 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we
|
|||
// fp16 convolution common (im2col+gemm)
|
||||
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
|
||||
float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) {
|
||||
#ifdef ENABLE_ARM64
|
||||
const int tile_n = 16;
|
||||
#else
|
||||
const int tile_n = 12;
|
||||
#endif
|
||||
int out_channel = conv_param->output_channel_;
|
||||
int output_count = conv_param->output_h_ * conv_param->output_w_;
|
||||
int output_tile_count = UP_DIV(output_count, tile_n);
|
||||
|
@ -144,7 +148,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
|
|||
Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
|
||||
|
||||
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
|
||||
#endif
|
||||
MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
|
||||
real_cal_num, out_channel, out_channel, OutType_Nhwc);
|
||||
}
|
||||
|
@ -155,7 +163,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
|
|||
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
|
||||
float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
|
||||
InputTransFp16Func in_func, OutputTransFp16Func out_func) {
|
||||
#ifdef ENABLE_ARM64
|
||||
const int tile_num = 16;
|
||||
#else
|
||||
const int tile_num = 12;
|
||||
#endif
|
||||
int in_channel = conv_param->input_channel_;
|
||||
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
|
||||
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
|
||||
|
@ -194,7 +206,11 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
|
|||
float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
|
||||
float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
|
||||
#endif
|
||||
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
|
||||
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
typedef float16_t *TmpBufferAddressFp16;
|
||||
typedef float16_t *MatricesFp16;
|
||||
|
||||
#ifndef ENABLE_NEON
|
||||
#ifndef ENABLE_ARM64
|
||||
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
|
||||
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
|
||||
size_t relu6);
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_CROP_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_CROP_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/crop_parameter.h"
|
||||
|
||||
|
|
|
@ -53,11 +53,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
|
|||
int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride;
|
||||
float16_t *tmp_dst = dst_ptr + dst_index;
|
||||
const float16_t *tmp_src = src_ptr + src_index;
|
||||
#ifdef DEBUG_CODE
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
tmp_dst[i] += tmp_src[i];
|
||||
}
|
||||
#else
|
||||
#ifdef ENABLE_ARM64
|
||||
asm volatile(
|
||||
"mov x0, %[tmp_src] \n"
|
||||
"mov x1, %[tmp_dst] \n"
|
||||
|
@ -72,6 +68,10 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
|
|||
:
|
||||
: [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst)
|
||||
: "x0", "x1", "v0", "v1");
|
||||
#else
|
||||
for (int i = 0; i < C8NUM; i++) {
|
||||
tmp_dst[i] += tmp_src[i];
|
||||
}
|
||||
#endif
|
||||
} /*kw*/
|
||||
} /*kh*/
|
||||
|
|
|
@ -47,6 +47,7 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride,
|
|||
size_t cuont8 = count / C8NUM * C8NUM;
|
||||
int i = 0;
|
||||
for (; i < cuont8; i += C8NUM) {
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t src_step = src_stride * sizeof(float16_t);
|
||||
size_t dst_step = dst_stride * sizeof(float16_t);
|
||||
asm volatile(
|
||||
|
@ -93,7 +94,9 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride,
|
|||
:
|
||||
: [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
|
||||
: "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
|
||||
|
||||
#else
|
||||
// TODO(fun): arm32
|
||||
#endif
|
||||
src_ptr += C8NUM * src_stride;
|
||||
dst_ptr += C8NUM * dst_stride;
|
||||
}
|
||||
|
@ -373,3 +376,23 @@ void DeconvWgPostFp16(float16_t *tile_out, float16_t *nc4hw4_output, ConvParamet
|
|||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM82_A32
|
||||
void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
|
||||
size_t length) {
|
||||
// TODO(fun): function
|
||||
return;
|
||||
}
|
||||
|
||||
void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
|
||||
size_t length) {
|
||||
// TODO(fun): function
|
||||
return;
|
||||
}
|
||||
|
||||
void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t ic4, size_t cal_num,
|
||||
size_t oc4) {
|
||||
// TODO(fun): function
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_NNACL_FP16_EXP_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "nnacl/fp16/instance_norm_fp16.h"
|
||||
#include <math.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
|
||||
int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data,
|
||||
const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) {
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#define MINDSPORE_NNACL_FP16_INSTANCE_NORM_H_
|
||||
|
||||
#include "nnacl/instance_norm_parameter.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "nnacl/fp16/arithmetic_fp16.h"
|
||||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
|
||||
void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) {
|
||||
for (int i = 0; i < batch; i++) {
|
||||
|
@ -121,7 +122,7 @@ int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float1
|
|||
for (; index <= element_size - 8; index += 8) {
|
||||
float16x8_t vin0 = vld1q_f16(input0 + index);
|
||||
float16x8_t vout = vld1q_f16(output + index);
|
||||
vout = vfmaq_n_f16(vout, vin0, input1);
|
||||
vout = MS_FMAQ_N_F16(vout, vin0, input1);
|
||||
vst1q_f16(output + index, vout);
|
||||
}
|
||||
for (; index < element_size; index++) {
|
||||
|
|
|
@ -226,24 +226,43 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row,
|
|||
return;
|
||||
}
|
||||
|
||||
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, int stride, bool write_nhwc) {
|
||||
if (write_nhwc) {
|
||||
/* col16-major * row8-major => col-major */
|
||||
void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, int stride, int write_mode) {
|
||||
if (write_mode == OutType_Nhwc) {
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r16div = r / 16, r16mod = r % 16;
|
||||
int c8div = c / 8, c8mod = c % 8;
|
||||
size_t ci = r * stride + c;
|
||||
float16_t value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r16div * deep * 16 + d * 16 + r16mod;
|
||||
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
ADD_BIAS(value, bias, c)
|
||||
DO_RELU(value, act_type)
|
||||
DO_RELU6(value, act_type)
|
||||
dst[ci] = value;
|
||||
}
|
||||
}
|
||||
} else if (write_mode == OutType_C8) {
|
||||
int col_8 = UP_ROUND(col, C8NUM);
|
||||
int row_16 = UP_ROUND(row, C16NUM);
|
||||
for (int r = 0; r < row_16; r++) {
|
||||
for (int c = 0; c < col_8; c++) {
|
||||
int r16div = r / C16NUM, r16mod = r % C16NUM;
|
||||
int c8div = c / C8NUM, c8mod = c % C8NUM;
|
||||
size_t ci = r * stride + c;
|
||||
float value = 0;
|
||||
size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod);
|
||||
float16_t value = 0;
|
||||
for (int d = 0; d < deep; d++) {
|
||||
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
|
||||
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
if (bias != NULL) value += bias[c];
|
||||
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
|
||||
if (act_type != ActType_No) value = MSMAX(0.0f, value);
|
||||
ADD_BIAS(value, bias, c)
|
||||
DO_RELU(value, act_type)
|
||||
DO_RELU6(value, act_type)
|
||||
dst[ci] = value;
|
||||
}
|
||||
}
|
||||
|
@ -254,37 +273,119 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl
|
|||
for (int j = 0; j < col; ++j) {
|
||||
int c8div = j / 8, c8mod = j % 8;
|
||||
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
|
||||
float value = 0;
|
||||
float16_t value = 0;
|
||||
for (int d = 0; d < deep; ++d) {
|
||||
size_t ai = src_r_offset + d * C16NUM;
|
||||
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
|
||||
value = value + a[ai] * b[bi];
|
||||
}
|
||||
if (bias != NULL) value += bias[j];
|
||||
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
|
||||
if (act_type != ActType_No) value = MSMAX(0.0f, value);
|
||||
ADD_BIAS(value, bias, j)
|
||||
DO_RELU(value, act_type)
|
||||
DO_RELU6(value, act_type)
|
||||
dst[ci] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scalar reference fp16 matmul: dst = act(a * b + bias).
//
// Indexing used below (taken directly from the address arithmetic):
//   a[(r / 12) * deep * 12 + d * 12 + (r % 12)]  -- rows grouped in tiles of 12
//   b[(c / 8) * deep * 8 + d * 8 + (c % 8)]      -- columns grouped in tiles of 8
//
// write_mode selects how dst is addressed:
//   OutType_Nhwc : dst[r * stride + c]
//   OutType_C8   : dst[(c / 8) * 8 * row_12 + r * 8 + (c % 8)], iterating over the
//                  12-/8-aligned row and column counts
//   otherwise    : dst[i * col * stride + (j / 8) * 8 * stride + (j % 8)]
//
// bias may be NULL. Accumulation is done in float16_t, in increasing d order.
void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
                    int deep, int row, int col, int stride, int write_mode) {
  if (write_mode == OutType_Nhwc) {
    for (int r = 0; r < row; r++) {
      // Offset of row r inside its 12-row tile of a; independent of c and d.
      const int a_base = (r / 12) * deep * 12 + (r % 12);
      for (int c = 0; c < col; c++) {
        // Offset of column c inside its 8-column tile of b; independent of d.
        const int b_base = (c / 8) * deep * 8 + (c % 8);
        size_t ci = r * stride + c;
        float16_t value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = a_base + d * 12;
          size_t bi = b_base + d * 8;
          value = value + a[ai] * b[bi];
        }
        ADD_BIAS(value, bias, c)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
    return;
  }

  if (write_mode == OutType_C8) {
    const int col_8 = UP_ROUND(col, C8NUM);
    const int row_12 = UP_ROUND(row, C12NUM);
    for (int r = 0; r < row_12; r++) {
      const int a_base = (r / C12NUM) * deep * C12NUM + (r % C12NUM);
      for (int c = 0; c < col_8; c++) {
        const int c8div = c / C8NUM;
        const int c8mod = c % C8NUM;
        const int b_base = c8div * deep * C8NUM + c8mod;
        size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
        float16_t value = 0;
        for (int d = 0; d < deep; d++) {
          size_t ai = a_base + d * C12NUM;
          size_t bi = b_base + d * C8NUM;
          value = value + a[ai] * b[bi];
        }
        // NOTE(review): c runs up to col_8 here, so bias is indexed past col when col is
        // not a multiple of 8 -- presumably bias is padded to 8; confirm against callers.
        ADD_BIAS(value, bias, c)
        DO_RELU(value, act_type)
        DO_RELU6(value, act_type)
        dst[ci] = value;
      }
    }
    return;
  }

  // Remaining write modes.
  // NOTE(review): a is indexed as a[i + d * C12NUM] with no per-tile offset, which looks
  // like it assumes row <= C12NUM -- confirm against callers.
  for (int i = 0; i < row; ++i) {
    const int src_r_offset = i;
    const int dst_r_offset = i * col * stride;
    for (int j = 0; j < col; ++j) {
      const int c8div = j / 8;
      const int c8mod = j % 8;
      const int b_base = c8div * deep * 8 + c8mod;
      size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
      float16_t value = 0;
      for (int d = 0; d < deep; ++d) {
        size_t ai = src_r_offset + d * C12NUM;
        size_t bi = b_base + d * 8;
        value = value + a[ai] * b[bi];
      }
      ADD_BIAS(value, bias, j)
      DO_RELU(value, act_type)
      DO_RELU6(value, act_type)
      dst[ci] = value;
    }
  }
}
|
||||
|
||||
// Dispatches an fp16 matmul to the platform kernel.
//   ARM64     : OutType_C8 uses MatmulFp16Neon64 (last argument false); every other
//               out_type uses MatmulFp16Neon64Opt with out_type forwarded.
//   otherwise : MatMul12x8A32Fp16 handles every out_type.
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                int depth, int row, int col, int stride, int out_type) {
#ifdef ENABLE_ARM64
  if (out_type != OutType_C8) {
    MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
    return;
  }
  MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false);
#else
  // The A32 kernel takes the write mode itself, so no branching on out_type is needed.
  MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
#endif
}
|
||||
|
||||
#ifdef ENABLE_ARM82_A32
// Scalar fp16 vector-matrix product for ARMv8.2 A32 builds:
//   c[ci] = act(sum_di a[di] * b[ci * depth + di] + bias[ci])   for ci in [0, col)
//
// This used to be an empty stub, which left c unwritten whenever MatVecMulFp16 was
// called on A32.
//
// a        : input vector, depth elements
// b        : weights -- NOTE(review): assumed to be stored as col rows of depth
//            contiguous elements (not tile-packed); confirm against the caller
//            and MatVecMulFp16Neon64
// c        : output vector, col elements
// bias     : optional, col elements; may be NULL
// act_type : ActType_Relu / ActType_Relu6 apply the matching clamp; other values
//            leave the result unchanged
void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
                      int depth, int col) {
  for (int ci = 0; ci < col; ci++) {
    // Accumulate in fp32 and narrow once when storing.
    float value = 0;
    for (int di = 0; di < depth; di++) {
      value += a[di] * b[ci * depth + di];
    }
    if (bias != NULL) {
      value += bias[ci];
    }
    if (act_type == ActType_Relu6) {
      value = MSMIN(6.0f, value);
    }
    if (act_type == ActType_Relu || act_type == ActType_Relu6) {
      value = MSMAX(0.0f, value);
    }
    c[ci] = value;
  }
}
#endif
|
||||
|
||||
// Dispatches an fp16 vector-matrix product to the platform kernel:
// MatVecMulFp16Neon64 on ARM64, MatVecMulA32Fp16 otherwise.
void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
                   int depth, int col) {
  // Both kernels take the activation as a plain int.
  const int act = (int)act_type;
#ifdef ENABLE_ARM64
  MatVecMulFp16Neon64(a, b, c, bias, act, depth, col);
#else
  MatVecMulA32Fp16(a, b, c, bias, act, depth, col);
#endif
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) {
|
||||
size_t stride = col * 2;
|
||||
asm volatile(
|
||||
|
@ -392,6 +493,7 @@ static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_
|
|||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
|
||||
"v31");
|
||||
}
|
||||
#endif
|
||||
|
||||
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
||||
size_t row_up_16 = UP_ROUND(row, C16NUM);
|
||||
|
@ -442,6 +544,54 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
|
|||
return;
|
||||
}
|
||||
|
||||
// Repacks a row-major row x col fp16 matrix into 12-row tiles.
//
// Within each tile, element (r, c) is written to tile[c * C12NUM + (r % C12NUM)], and
// consecutive tiles are C12NUM * col elements apart. When row is not a multiple of
// C12NUM, the missing rows of the last tile are filled with zeros, so dst_ptr is
// written for UP_ROUND(row, C12NUM) * col elements.
void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
  const size_t full_rows = row / C12NUM * C12NUM;  // rows covered by complete 12-row tiles
  const size_t full_cols = col / C8NUM * C8NUM;    // columns covered by complete 8-column groups
  const size_t padded_rows = UP_ROUND(row, C12NUM);

  size_t r = 0;
  // Complete tiles: transpose 12x8 sub-blocks, then finish leftover columns one at a time.
  for (; r < full_rows; r += C12NUM) {
    const float16_t *tile_src = src_ptr + r * col;
    float16_t *tile_dst = dst_ptr + r * col;
    size_t c = 0;
    for (; c < full_cols; c += C8NUM) {
      const float16_t *blk_src = tile_src + c;
      float16_t *blk_dst = tile_dst + c * C12NUM;
#ifdef ENABLE_ARM82_A32
      // Destination stride is 24 bytes, i.e. one 12-element fp16 output row.
      Transpose12x8A32Fp16(blk_src, blk_dst, col * sizeof(float16_t), 24);
#else
      for (int tr = 0; tr < C12NUM; tr++) {
        for (int tc = 0; tc < C8NUM; tc++) {
          blk_dst[tc * C12NUM + tr] = blk_src[tr * col + tc];
        }
      }
#endif
    }
    for (; c < col; c++) {
      const float16_t *col_src = tile_src + c;
      float16_t *col_dst = tile_dst + c * C12NUM;
      for (size_t k = 0; k < C12NUM; k++) {
        col_dst[k] = col_src[k * col];
      }
    }
  }

  // Partial last tile: each remaining source row becomes one lane of the tile.
  float16_t *tail_dst = dst_ptr + full_rows * col;
  for (; r < row; r++) {
    const float16_t *row_src = src_ptr + r * col;
    float16_t *lane_dst = tail_dst + (r - full_rows);
    for (size_t c = 0; c < col; ++c) {
      lane_dst[c * C12NUM] = row_src[c];
    }
  }

  // Zero the lanes of the last tile that have no source row.
  for (; r < padded_rows; r++) {
    float16_t *lane_dst = tail_dst + (r - full_rows);
    for (size_t c = 0; c < col; c++) {
      lane_dst[c * C12NUM] = 0;
    }
  }
}
|
||||
|
||||
void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
|
||||
if (is_fp32_src) {
|
||||
const float *fp32_src = (const float *)src;
|
||||
|
|
|
@ -19,18 +19,51 @@
|
|||
|
||||
#include <float.h>
|
||||
#include <string.h>
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_ARM64
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/matmul_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/fp16/pack_fp16.h"
|
||||
|
||||
// Statement-like helpers shared by the scalar fp16 matmul reference kernels.
//
// Each macro expands to a single compound statement, so it can be used with or without
// a trailing semicolon and cannot capture a following `else`. Arguments are
// parenthesized so expressions can be passed safely. `value` must be an lvalue.

// value += bias[c] when a bias pointer is supplied.
#define ADD_BIAS(value, bias, c)         \
  {                                      \
    if ((bias) != NULL) {                \
      (value) = (value) + (bias)[(c)];   \
    }                                    \
  }

// Clamp value to [0, +inf) for ActType_Relu.
#define DO_RELU(value, act_type)         \
  {                                      \
    if ((act_type) == ActType_Relu) {    \
      (value) = MSMAX(0.0f, (value));    \
    }                                    \
  }

// Clamp value to [0, 6] for ActType_Relu6 (upper bound first, then lower bound).
#define DO_RELU6(value, act_type)        \
  {                                      \
    if ((act_type) == ActType_Relu6) {   \
      (value) = MSMIN(6.0f, (value));    \
      (value) = MSMAX(0.0f, (value));    \
    }                                    \
  }
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, int stride, bool write_nhwc);
|
||||
void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, int stride, int write_mode);
|
||||
|
||||
void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, int stride, int write_mode);
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
|
||||
|
||||
void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
|
||||
|
||||
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
int depth, int col);
|
||||
|
||||
#elif ENABLE_ARM82_A32
|
||||
void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
|
||||
int deep, int row, int col, int stride, int write_mode);
|
||||
|
||||
void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
int depth, int col);
|
||||
#endif
|
||||
|
||||
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
|
||||
int depth, int row, int col, int stride, int out_type);
|
||||
|
@ -42,14 +75,7 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row,
|
|||
|
||||
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
|
||||
|
||||
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
|
||||
|
||||
void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
|
||||
|
||||
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
|
||||
int depth, int col);
|
||||
void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
|
||||
|
||||
void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matri
|
|||
for (int i = 0; i < m; ++i) {
|
||||
for (int j = 0; j < n; ++j) {
|
||||
for (int y = 0; y < in_channel; ++y) {
|
||||
float16_t tmp = 0;
|
||||
float tmp = 0;
|
||||
for (int z = 0; z < k; ++z) {
|
||||
tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
|
||||
}
|
||||
|
|
|
@ -160,129 +160,35 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
|
|||
}
|
||||
|
||||
void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel) {
|
||||
int hw16 = plane / C16NUM * C16NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
// Transpose16x8 in arm64
|
||||
const int hw_tile = C16NUM;
|
||||
#else
|
||||
// Transpose8x8 in others
|
||||
const int hw_tile = C8NUM;
|
||||
#endif
|
||||
int hw_align = plane / hw_tile * hw_tile;
|
||||
int c8 = channel / C8NUM * C8NUM;
|
||||
int batch = plane * channel;
|
||||
for (int n = 0; n < batches; n++) {
|
||||
const float16_t *src_batch = (const float16_t *)src + n * batch;
|
||||
float16_t *dst_batch = (float16_t *)dst + n * batch;
|
||||
int hw = 0;
|
||||
for (; hw < hw16; hw += C16NUM) {
|
||||
for (; hw < hw_align; hw += hw_tile) {
|
||||
int c = 0;
|
||||
for (; c < c8; c += C8NUM) {
|
||||
const float16_t *src_ptr = src_batch + hw * channel + c;
|
||||
float16_t *dst_ptr = dst_batch + c * plane + hw;
|
||||
#ifdef ENABLE_ARM64
|
||||
size_t srcStride = channel * sizeof(float16_t);
|
||||
size_t dstStride = plane * sizeof(float16_t);
|
||||
asm volatile(
|
||||
"mov x10, %[src_ptr]\n"
|
||||
"mov x11, %[dst_ptr]\n"
|
||||
|
||||
"ld1 {v0.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v1.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v2.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v3.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v4.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v5.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v6.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v7.8h}, [x10], %[srcStride]\n"
|
||||
|
||||
"zip1 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip1 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip1 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip1 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"ld1 {v8.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v9.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v10.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v11.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v12.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v13.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v14.8h}, [x10], %[srcStride]\n"
|
||||
"ld1 {v15.8h}, [x10], %[srcStride]\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip1 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip1 v17.8h, v10.8h, v11.8h\n"
|
||||
"zip1 v18.8h, v12.8h, v13.8h\n"
|
||||
"zip1 v19.8h, v14.8h, v15.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"add x10, x11, #16\n"
|
||||
"st1 {v24.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v28.8h}, [x10], %[dstStride]\n"
|
||||
"st1 {v26.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v30.8h}, [x10], %[dstStride]\n"
|
||||
"st1 {v25.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v29.8h}, [x10], %[dstStride]\n"
|
||||
"st1 {v27.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v31.8h}, [x10], %[dstStride]\n"
|
||||
|
||||
"zip2 v16.8h, v0.8h, v1.8h\n"
|
||||
"zip2 v17.8h, v2.8h, v3.8h\n"
|
||||
"zip2 v18.8h, v4.8h, v5.8h\n"
|
||||
"zip2 v19.8h, v6.8h, v7.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v24.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v25.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v26.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v27.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"zip2 v16.8h, v8.8h, v9.8h\n"
|
||||
"zip2 v17.8h, v10.8h, v11.8h\n"
|
||||
"zip2 v18.8h, v12.8h, v13.8h\n"
|
||||
"zip2 v19.8h, v14.8h, v15.8h\n"
|
||||
|
||||
"trn1 v20.4s, v16.4s, v17.4s\n"
|
||||
"trn2 v21.4s, v16.4s, v17.4s\n"
|
||||
"trn1 v22.4s, v18.4s, v19.4s\n"
|
||||
"trn2 v23.4s, v18.4s, v19.4s\n"
|
||||
|
||||
"trn1 v28.2d, v20.2d, v22.2d\n"
|
||||
"trn2 v29.2d, v20.2d, v22.2d\n"
|
||||
"trn1 v30.2d, v21.2d, v23.2d\n"
|
||||
"trn2 v31.2d, v21.2d, v23.2d\n"
|
||||
|
||||
"st1 {v24.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v28.8h}, [x10], %[dstStride]\n"
|
||||
"st1 {v26.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v30.8h}, [x10], %[dstStride]\n"
|
||||
"st1 {v25.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v29.8h}, [x10], %[dstStride]\n"
|
||||
"st1 {v27.8h}, [x11], %[dstStride]\n"
|
||||
"st1 {v31.8h}, [x10], %[dstStride]\n"
|
||||
:
|
||||
:
|
||||
[ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
|
||||
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
||||
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
|
||||
"v30", "v31");
|
||||
size_t src_stride = channel * sizeof(float16_t);
|
||||
size_t dst_stride = plane * sizeof(float16_t);
|
||||
Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride);
|
||||
#elif defined(ENABLE_ARM82_A32)
|
||||
size_t src_stride = channel * sizeof(float16_t);
|
||||
size_t dst_stride = plane * sizeof(float16_t);
|
||||
Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride);
|
||||
#else
|
||||
for (int tr = 0; tr < C16NUM; tr++) {
|
||||
for (int tr = 0; tr < hw_tile; tr++) {
|
||||
for (int tc = 0; tc < C8NUM; tc++) {
|
||||
dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
|
||||
}
|
||||
|
@ -292,7 +198,7 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int
|
|||
for (; c < channel; c++) {
|
||||
const float16_t *src_ptr = src_batch + hw * channel + c;
|
||||
float16_t *dst_ptr = dst_batch + c * plane + hw;
|
||||
for (size_t i = 0; i < C16NUM; i++) {
|
||||
for (size_t i = 0; i < hw_tile; i++) {
|
||||
dst_ptr[i] = src_ptr[i * channel];
|
||||
}
|
||||
}
|
||||
|
@ -305,7 +211,6 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int
|
|||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
|
@ -565,3 +470,246 @@ void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, i
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM82_A32
|
||||
// Transposes an 8x8 block of fp16 values using ARM32 NEON.
//
// src        : first element of the source block
// dst        : first element of the destination block
// src_stride : distance in BYTES between consecutive source rows
// dst_stride : distance in BYTES between consecutive destination rows
//
// Eight 8-element rows are loaded into q0,q2,...,q14. Each 4x4 quadrant is transposed
// with vtrn.16 / vtrn.32, the off-diagonal quadrants are exchanged with vswp, and the
// eight result rows are stored.
inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride) {
  asm volatile(
    "mov r10, %[src]\n"
    "mov r12, %[dst]\n"
    "vld1.16 {q0}, [r10], %[src_stride]\n"
    "vld1.16 {q2}, [r10], %[src_stride]\n"
    "vld1.16 {q4}, [r10], %[src_stride]\n"
    "vld1.16 {q6}, [r10], %[src_stride]\n"

    "vtrn.16 d0, d4\n"
    "vtrn.16 d1, d5\n"
    "vtrn.16 d8, d12\n"
    "vtrn.16 d9, d13\n"

    "vld1.16 {q8}, [r10], %[src_stride]\n"
    "vld1.16 {q10}, [r10], %[src_stride]\n"
    "vld1.16 {q12}, [r10], %[src_stride]\n"
    "vld1.16 {q14}, [r10], %[src_stride]\n"

    "vtrn.32 d0, d8\n"
    "vtrn.32 d4, d12\n"
    "vtrn.32 d1, d9\n"
    "vtrn.32 d5, d13\n"

    "vtrn.16 d16, d20\n"
    "vtrn.16 d17, d21\n"
    "vtrn.16 d24, d28\n"
    "vtrn.16 d25, d29\n"

    "vtrn.32 d16, d24\n"
    "vtrn.32 d20, d28\n"
    "vtrn.32 d17, d25\n"
    "vtrn.32 d21, d29\n"

    "vswp d1, d16\n"
    "vswp d5, d20\n"
    "vswp d9, d24\n"
    "vswp d13, d28\n"

    "vst1.16 {q0}, [r12], %[dst_stride]\n"
    "vst1.16 {q2}, [r12], %[dst_stride]\n"
    "vst1.16 {q4}, [r12], %[dst_stride]\n"
    "vst1.16 {q6}, [r12], %[dst_stride]\n"

    "vst1.16 {q8}, [r12], %[dst_stride]\n"
    "vst1.16 {q10}, [r12], %[dst_stride]\n"
    "vst1.16 {q12}, [r12], %[dst_stride]\n"
    "vst1.16 {q14}, [r12], %[dst_stride]\n"

    :
    : [ dst ] "r"(dst), [ src ] "r"(src), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
    // "memory" is required: the asm reads *src and writes *dst through plain register
    // operands, so the compiler must not cache or reorder memory accesses around it.
    : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
      "q15", "memory");
}
|
||||
|
||||
// Transposes a 12-row x 8-column block of fp16 values into 8 rows of 12 using ARM32 NEON.
//
// src_c      : first element of the source block
// dst_c      : first element of the destination block
// src_stride : distance in BYTES between consecutive source rows
// dst_stride : distance in BYTES between consecutive destination rows
//
// Source rows 0-3 go to q0/q2/q4/q6, rows 4-7 to q8/q10/q12/q14 and rows 8-11 to
// q1/q3/q5/q7. Each group of four rows is transposed with vtrn.16 / vtrn.32, and vswp
// moves the pieces so every output row occupies three consecutive d registers, which
// are written by one 3-register vst1.16 (12 halfwords).
inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) {
  asm volatile(
    "mov r10, %[src_c]\n"
    "mov r12, %[dst_c]\n"

    "vld1.16 {q0}, [r10], %[src_stride]\n"
    "vld1.16 {q2}, [r10], %[src_stride]\n"
    "vld1.16 {q4}, [r10], %[src_stride]\n"
    "vld1.16 {q6}, [r10], %[src_stride]\n"

    "vtrn.16 d0, d4\n"
    "vtrn.16 d1, d5\n"
    "vtrn.16 d8, d12\n"
    "vtrn.16 d9, d13\n"

    "vld1.16 {q8}, [r10], %[src_stride]\n"
    "vld1.16 {q10}, [r10], %[src_stride]\n"
    "vld1.16 {q12}, [r10], %[src_stride]\n"
    "vld1.16 {q14}, [r10], %[src_stride]\n"

    "vtrn.32 d0, d8\n"
    "vtrn.32 d4, d12\n"
    "vtrn.32 d1, d9\n"
    "vtrn.32 d5, d13\n"

    "vtrn.16 d16, d20\n"
    "vtrn.16 d17, d21\n"
    "vtrn.16 d24, d28\n"
    "vtrn.16 d25, d29\n"

    "vld1.16 {q1}, [r10], %[src_stride]\n"
    "vld1.16 {q3}, [r10], %[src_stride]\n"
    "vld1.16 {q5}, [r10], %[src_stride]\n"
    "vld1.16 {q7}, [r10], %[src_stride]\n"

    "vtrn.32 d16, d24\n"
    "vtrn.32 d20, d28\n"
    "vtrn.32 d17, d25\n"
    "vtrn.32 d21, d29\n"

    "vswp d1, d16\n"
    "vswp d5, d20\n"
    "vswp d9, d24\n"
    "vswp d13, d28\n"

    "vtrn.16 d2, d6\n"
    "vtrn.16 d3, d7\n"
    "vtrn.16 d10, d14\n"
    "vtrn.16 d11, d15\n"

    "vtrn.32 d2, d10\n"
    "vtrn.32 d6, d14\n"
    "vtrn.32 d3, d11\n"
    "vtrn.32 d7, d15\n"

    "vst1.16 {q0, d2}, [r12], %[dst_stride]\n"
    "vst1.16 {q2, d6}, [r12], %[dst_stride]\n"
    "vst1.16 {q4, d10}, [r12], %[dst_stride]\n"
    "vst1.16 {q6, d14}, [r12], %[dst_stride]\n"

    "vswp d3, d18\n"
    "vswp d7, d22\n"
    "vswp d11, d26\n"
    "vswp d15, d30\n"

    "vst1.16 {q8, d18}, [r12], %[dst_stride]\n"
    "vst1.16 {q10, d22}, [r12], %[dst_stride]\n"
    "vst1.16 {q12, d26}, [r12], %[dst_stride]\n"
    "vst1.16 {q14, d30}, [r12], %[dst_stride]\n"

    :
    : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
    // "memory" is required: the asm reads *src_c and writes *dst_c through plain register
    // operands, so the compiler must not cache or reorder memory accesses around it.
    : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
      "q15", "memory");
}
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
// Transposes a 16-row x 8-column block of fp16 values into 8 rows of 16 using AArch64 NEON.
//
// src_ptr    : first element of the source block
// dst_ptr    : first element of the destination block
// src_stride : distance in BYTES between consecutive source rows
// dst_stride : distance in BYTES between consecutive destination rows
//
// Sixteen 8-element rows are loaded into v0-v15. Rows 0-7 and rows 8-15 are each
// transposed with zip1/zip2 + trn1/trn2. Every output row is written as two 8-halfword
// stores: the part coming from rows 0-7 through x11, and the part coming from rows 8-15
// through x10, which starts 16 bytes after dst_ptr.
inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
  asm volatile(
    "mov x10, %[src_ptr]\n"
    "mov x11, %[dst_ptr]\n"

    "ld1 {v0.8h}, [x10], %[src_stride]\n"
    "ld1 {v1.8h}, [x10], %[src_stride]\n"
    "ld1 {v2.8h}, [x10], %[src_stride]\n"
    "ld1 {v3.8h}, [x10], %[src_stride]\n"
    "ld1 {v4.8h}, [x10], %[src_stride]\n"
    "ld1 {v5.8h}, [x10], %[src_stride]\n"
    "ld1 {v6.8h}, [x10], %[src_stride]\n"
    "ld1 {v7.8h}, [x10], %[src_stride]\n"

    "zip1 v16.8h, v0.8h, v1.8h\n"
    "zip1 v17.8h, v2.8h, v3.8h\n"
    "zip1 v18.8h, v4.8h, v5.8h\n"
    "zip1 v19.8h, v6.8h, v7.8h\n"

    "ld1 {v8.8h}, [x10], %[src_stride]\n"
    "ld1 {v9.8h}, [x10], %[src_stride]\n"
    "ld1 {v10.8h}, [x10], %[src_stride]\n"
    "ld1 {v11.8h}, [x10], %[src_stride]\n"
    "ld1 {v12.8h}, [x10], %[src_stride]\n"
    "ld1 {v13.8h}, [x10], %[src_stride]\n"
    "ld1 {v14.8h}, [x10], %[src_stride]\n"
    "ld1 {v15.8h}, [x10], %[src_stride]\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v24.2d, v20.2d, v22.2d\n"
    "trn2 v25.2d, v20.2d, v22.2d\n"
    "trn1 v26.2d, v21.2d, v23.2d\n"
    "trn2 v27.2d, v21.2d, v23.2d\n"

    "zip1 v16.8h, v8.8h, v9.8h\n"
    "zip1 v17.8h, v10.8h, v11.8h\n"
    "zip1 v18.8h, v12.8h, v13.8h\n"
    "zip1 v19.8h, v14.8h, v15.8h\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v28.2d, v20.2d, v22.2d\n"
    "trn2 v29.2d, v20.2d, v22.2d\n"
    "trn1 v30.2d, v21.2d, v23.2d\n"
    "trn2 v31.2d, v21.2d, v23.2d\n"

    // x10 is reused as the second store pointer, 16 bytes into each output row.
    "add x10, x11, #16\n"
    "st1 {v24.8h}, [x11], %[dst_stride]\n"
    "st1 {v28.8h}, [x10], %[dst_stride]\n"
    "st1 {v26.8h}, [x11], %[dst_stride]\n"
    "st1 {v30.8h}, [x10], %[dst_stride]\n"
    "st1 {v25.8h}, [x11], %[dst_stride]\n"
    "st1 {v29.8h}, [x10], %[dst_stride]\n"
    "st1 {v27.8h}, [x11], %[dst_stride]\n"
    "st1 {v31.8h}, [x10], %[dst_stride]\n"

    "zip2 v16.8h, v0.8h, v1.8h\n"
    "zip2 v17.8h, v2.8h, v3.8h\n"
    "zip2 v18.8h, v4.8h, v5.8h\n"
    "zip2 v19.8h, v6.8h, v7.8h\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v24.2d, v20.2d, v22.2d\n"
    "trn2 v25.2d, v20.2d, v22.2d\n"
    "trn1 v26.2d, v21.2d, v23.2d\n"
    "trn2 v27.2d, v21.2d, v23.2d\n"

    "zip2 v16.8h, v8.8h, v9.8h\n"
    "zip2 v17.8h, v10.8h, v11.8h\n"
    "zip2 v18.8h, v12.8h, v13.8h\n"
    "zip2 v19.8h, v14.8h, v15.8h\n"

    "trn1 v20.4s, v16.4s, v17.4s\n"
    "trn2 v21.4s, v16.4s, v17.4s\n"
    "trn1 v22.4s, v18.4s, v19.4s\n"
    "trn2 v23.4s, v18.4s, v19.4s\n"

    "trn1 v28.2d, v20.2d, v22.2d\n"
    "trn2 v29.2d, v20.2d, v22.2d\n"
    "trn1 v30.2d, v21.2d, v23.2d\n"
    "trn2 v31.2d, v21.2d, v23.2d\n"

    "st1 {v24.8h}, [x11], %[dst_stride]\n"
    "st1 {v28.8h}, [x10], %[dst_stride]\n"
    "st1 {v26.8h}, [x11], %[dst_stride]\n"
    "st1 {v30.8h}, [x10], %[dst_stride]\n"
    "st1 {v25.8h}, [x11], %[dst_stride]\n"
    "st1 {v29.8h}, [x10], %[dst_stride]\n"
    "st1 {v27.8h}, [x11], %[dst_stride]\n"
    "st1 {v31.8h}, [x10], %[dst_stride]\n"
    :
    : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
    // "memory" is required: the asm reads *src_ptr and writes *dst_ptr through plain
    // register operands, so the compiler must not cache or reorder memory accesses around it.
    : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
      "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
      "v31", "memory");
}
|
||||
#endif
|
||||
|
|
|
@ -17,11 +17,9 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_PACK_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_PACK_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
@ -72,6 +70,17 @@ void PackNHWCFp16ToC8HWN8Fp16(float16_t *src, float16_t *dst, int batch, int pla
|
|||
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);
|
||||
|
||||
void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
|
||||
|
||||
#ifdef ENABLE_ARM82_A32
|
||||
void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
|
||||
|
||||
void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -16,9 +16,6 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_PAD_FP16_H_
|
||||
#define MINDSPORE_NNACL_FP16_PAD_FP16_H_
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/fp32/pad_fp32.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
@ -18,10 +18,8 @@
|
|||
#define MINDSPORE_NNACL_FP16_POOLING_FP16_H_
|
||||
|
||||
#include <math.h>
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "nnacl/fp16/power_fp16.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
#if defined(ENABLE_NEON)
|
||||
#if defined(ENABLE_ARM64)
|
||||
float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) {
|
||||
int tmp = (int)(*(float16_t *)exponent);
|
||||
int exp = abs(tmp);
|
||||
|
@ -53,23 +53,23 @@ float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) {
|
|||
void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale,
|
||||
float shift) {
|
||||
PowerScalarFunFp16 PowerScalarFunFp16_ = NULL;
|
||||
#if defined(ENABLE_NEON)
|
||||
#if defined(ENABLE_ARM64)
|
||||
PowerSimdFunFp16 PowerSimdFunFp16_ = NULL;
|
||||
#endif
|
||||
|
||||
if (CheckInteger(*exponent)) {
|
||||
#if defined(ENABLE_NEON)
|
||||
#if defined(ENABLE_ARM64)
|
||||
PowerSimdFunFp16_ = OptimizedPowerSimdFp16;
|
||||
#endif
|
||||
PowerScalarFunFp16_ = OptimizedPowerScalarFp16;
|
||||
} else {
|
||||
#if defined(ENABLE_NEON)
|
||||
#if defined(ENABLE_ARM64)
|
||||
PowerSimdFunFp16_ = StdPowerSimdFp16;
|
||||
#endif
|
||||
PowerScalarFunFp16_ = StdPowerScalarFp16;
|
||||
}
|
||||
int i = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_ARM64
|
||||
int len_c8 = UP_ROUND(len, C8NUM);
|
||||
float16x8_t scale_8 = vmovq_n_f16(scale);
|
||||
float16x8_t shift_8 = vmovq_n_f16(shift);
|
||||
|
@ -87,7 +87,7 @@ void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_
|
|||
float shift) {
|
||||
int i = 0;
|
||||
PowerScalarFunFp16 PowerScalarFunFp16_ = NULL;
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_ARM64
|
||||
int len_c8 = UP_ROUND(len, C8NUM);
|
||||
float16x8_t scale_8 = vmovq_n_f16(scale);
|
||||
float16x8_t shift_8 = vmovq_n_f16(shift);
|
||||
|
|
|
@ -19,9 +19,10 @@
|
|||
|
||||
#include <math.h>
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/power_parameter.h"
|
||||
|
||||
#if defined(ENABLE_NEON)
|
||||
#if defined(ENABLE_ARM64)
|
||||
typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent);
|
||||
#endif
|
||||
typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent);
|
||||
|
@ -36,7 +37,7 @@ static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) {
|
|||
return powf(x, *(float16_t *)exponent);
|
||||
}
|
||||
|
||||
#if defined(ENABLE_NEON)
|
||||
#if defined(ENABLE_ARM64)
|
||||
static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) {
|
||||
float16x8_t result;
|
||||
result[0] = powf(x[0], *(float16_t *)exponent);
|
||||
|
|
|
@ -18,10 +18,7 @@
|
|||
#define MINDSPORE_NNACL_FP16_QUANTDTYPECAST_FP16_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -19,9 +19,6 @@
|
|||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/reduce_parameter.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -18,10 +18,9 @@
|
|||
#define MINDSPORE_NNACL_SCALE_FP16_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/scale.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -40,7 +40,7 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe
|
|||
}
|
||||
}
|
||||
int k = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_ARM64
|
||||
int count2 = (channel / C8NUM) * C8NUM;
|
||||
for (; k < count2; k += C8NUM) {
|
||||
float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k);
|
||||
|
@ -58,9 +58,9 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe
|
|||
void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
|
||||
int cur_batch_offset = 0;
|
||||
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
|
||||
float16_t sum = 0;
|
||||
float16_t sum = 0.0f;
|
||||
int j = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_ARM64
|
||||
float16x8_t sum8 = vdupq_n_f16(0);
|
||||
int count = (channel / C8NUM) * C8NUM;
|
||||
for (; j < count; j += C8NUM) {
|
||||
|
@ -72,7 +72,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel)
|
|||
sum += src[cur_batch_offset + j];
|
||||
}
|
||||
int k = 0;
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_ARM64
|
||||
const float16_t div = 1.0f / sum;
|
||||
for (; k < count; k += C8NUM) {
|
||||
vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div));
|
||||
|
@ -117,7 +117,7 @@ void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *s
|
|||
}
|
||||
for (int j = 0; j < input_shape[axis]; j++) {
|
||||
int axis_offset = inner_offset + j * inner_size;
|
||||
output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data);
|
||||
output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data);
|
||||
sum_data[k + sum_outter_offset] += output_ptr[axis_offset];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,10 +18,9 @@
|
|||
#define MINDSPORE_NNACL_FP16_SOFTMAX_FP16_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/softmax_parameter.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -18,10 +18,8 @@
|
|||
#define MINDSPORE_NNACL_FP16_TRANSPOSE_FP16_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
#include "nnacl/transpose.h"
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
|
@ -16,562 +16,15 @@
|
|||
|
||||
#include "nnacl/fp16/winograd_transform_fp16.h"
|
||||
|
||||
// for fp16 convolution 3x3 filter/input/output transform F(4,3)
|
||||
void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) {
|
||||
float16x8_t d00 = vld1q_f16(tmp_data);
|
||||
float16x8_t d01 = vld1q_f16(tmp_data + 8);
|
||||
float16x8_t d02 = vld1q_f16(tmp_data + 2 * 8);
|
||||
float16x8_t d03 = vld1q_f16(tmp_data + 3 * 8);
|
||||
float16x8_t d04 = vld1q_f16(tmp_data + 4 * 8);
|
||||
float16x8_t d05 = vld1q_f16(tmp_data + 5 * 8);
|
||||
|
||||
float16x8_t d10 = vld1q_f16(tmp_data + 6 * 8);
|
||||
float16x8_t d11 = vld1q_f16(tmp_data + 7 * 8);
|
||||
float16x8_t d12 = vld1q_f16(tmp_data + 8 * 8);
|
||||
float16x8_t d13 = vld1q_f16(tmp_data + 9 * 8);
|
||||
float16x8_t d14 = vld1q_f16(tmp_data + 10 * 8);
|
||||
float16x8_t d15 = vld1q_f16(tmp_data + 11 * 8);
|
||||
|
||||
float16x8_t d20 = vld1q_f16(tmp_data + 12 * 8);
|
||||
float16x8_t d21 = vld1q_f16(tmp_data + 13 * 8);
|
||||
float16x8_t d22 = vld1q_f16(tmp_data + 14 * 8);
|
||||
float16x8_t d23 = vld1q_f16(tmp_data + 15 * 8);
|
||||
float16x8_t d24 = vld1q_f16(tmp_data + 16 * 8);
|
||||
float16x8_t d25 = vld1q_f16(tmp_data + 17 * 8);
|
||||
|
||||
float16x8_t d30 = vld1q_f16(tmp_data + 18 * 8);
|
||||
float16x8_t d31 = vld1q_f16(tmp_data + 19 * 8);
|
||||
float16x8_t d32 = vld1q_f16(tmp_data + 20 * 8);
|
||||
float16x8_t d33 = vld1q_f16(tmp_data + 21 * 8);
|
||||
float16x8_t d34 = vld1q_f16(tmp_data + 22 * 8);
|
||||
float16x8_t d35 = vld1q_f16(tmp_data + 23 * 8);
|
||||
|
||||
float16x8_t d40 = vld1q_f16(tmp_data + 24 * 8);
|
||||
float16x8_t d41 = vld1q_f16(tmp_data + 25 * 8);
|
||||
float16x8_t d42 = vld1q_f16(tmp_data + 26 * 8);
|
||||
float16x8_t d43 = vld1q_f16(tmp_data + 27 * 8);
|
||||
float16x8_t d44 = vld1q_f16(tmp_data + 28 * 8);
|
||||
float16x8_t d45 = vld1q_f16(tmp_data + 29 * 8);
|
||||
|
||||
float16x8_t d50 = vld1q_f16(tmp_data + 30 * 8);
|
||||
float16x8_t d51 = vld1q_f16(tmp_data + 31 * 8);
|
||||
float16x8_t d52 = vld1q_f16(tmp_data + 32 * 8);
|
||||
float16x8_t d53 = vld1q_f16(tmp_data + 33 * 8);
|
||||
float16x8_t d54 = vld1q_f16(tmp_data + 34 * 8);
|
||||
float16x8_t d55 = vld1q_f16(tmp_data + 35 * 8);
|
||||
|
||||
float16x8_t t00 = vaddq_f16(vsubq_f16(vmulq_n_f16(d00, 4), vmulq_n_f16(d20, 5)), d40);
|
||||
float16x8_t t01 = vaddq_f16(vsubq_f16(vmulq_n_f16(d01, 4), vmulq_n_f16(d21, 5)), d41);
|
||||
float16x8_t t02 = vaddq_f16(vsubq_f16(vmulq_n_f16(d02, 4), vmulq_n_f16(d22, 5)), d42);
|
||||
float16x8_t t03 = vaddq_f16(vsubq_f16(vmulq_n_f16(d03, 4), vmulq_n_f16(d23, 5)), d43);
|
||||
float16x8_t t04 = vaddq_f16(vsubq_f16(vmulq_n_f16(d04, 4), vmulq_n_f16(d24, 5)), d44);
|
||||
float16x8_t t05 = vaddq_f16(vsubq_f16(vmulq_n_f16(d05, 4), vmulq_n_f16(d25, 5)), d45);
|
||||
|
||||
float16x8_t t10 = vaddq_f16(vaddq_f16(d30, d40), vmulq_n_f16(vaddq_f16(d10, d20), -4));
|
||||
float16x8_t t11 = vaddq_f16(vaddq_f16(d31, d41), vmulq_n_f16(vaddq_f16(d11, d21), -4));
|
||||
float16x8_t t12 = vaddq_f16(vaddq_f16(d32, d42), vmulq_n_f16(vaddq_f16(d12, d22), -4));
|
||||
float16x8_t t13 = vaddq_f16(vaddq_f16(d33, d43), vmulq_n_f16(vaddq_f16(d13, d23), -4));
|
||||
float16x8_t t14 = vaddq_f16(vaddq_f16(d34, d44), vmulq_n_f16(vaddq_f16(d14, d24), -4));
|
||||
float16x8_t t15 = vaddq_f16(vaddq_f16(d35, d45), vmulq_n_f16(vaddq_f16(d15, d25), -4));
|
||||
|
||||
float16x8_t t20 = vaddq_f16(vsubq_f16(d40, d30), vmulq_n_f16(vsubq_f16(d10, d20), 4));
|
||||
float16x8_t t21 = vaddq_f16(vsubq_f16(d41, d31), vmulq_n_f16(vsubq_f16(d11, d21), 4));
|
||||
float16x8_t t22 = vaddq_f16(vsubq_f16(d42, d32), vmulq_n_f16(vsubq_f16(d12, d22), 4));
|
||||
float16x8_t t23 = vaddq_f16(vsubq_f16(d43, d33), vmulq_n_f16(vsubq_f16(d13, d23), 4));
|
||||
float16x8_t t24 = vaddq_f16(vsubq_f16(d44, d34), vmulq_n_f16(vsubq_f16(d14, d24), 4));
|
||||
float16x8_t t25 = vaddq_f16(vsubq_f16(d45, d35), vmulq_n_f16(vsubq_f16(d15, d25), 4));
|
||||
|
||||
float16x8_t t30 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d30, d10), 2));
|
||||
float16x8_t t31 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d31, d11), 2));
|
||||
float16x8_t t32 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d32, d12), 2));
|
||||
float16x8_t t33 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d33, d13), 2));
|
||||
float16x8_t t34 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d34, d14), 2));
|
||||
float16x8_t t35 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d35, d15), 2));
|
||||
|
||||
float16x8_t t40 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d10, d30), 2));
|
||||
float16x8_t t41 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d11, d31), 2));
|
||||
float16x8_t t42 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d12, d32), 2));
|
||||
float16x8_t t43 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d13, d33), 2));
|
||||
float16x8_t t44 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d14, d34), 2));
|
||||
float16x8_t t45 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d15, d35), 2));
|
||||
|
||||
float16x8_t t50 = vaddq_f16(vsubq_f16(vmulq_n_f16(d10, 4), vmulq_n_f16(d30, 5)), d50);
|
||||
float16x8_t t51 = vaddq_f16(vsubq_f16(vmulq_n_f16(d11, 4), vmulq_n_f16(d31, 5)), d51);
|
||||
float16x8_t t52 = vaddq_f16(vsubq_f16(vmulq_n_f16(d12, 4), vmulq_n_f16(d32, 5)), d52);
|
||||
float16x8_t t53 = vaddq_f16(vsubq_f16(vmulq_n_f16(d13, 4), vmulq_n_f16(d33, 5)), d53);
|
||||
float16x8_t t54 = vaddq_f16(vsubq_f16(vmulq_n_f16(d14, 4), vmulq_n_f16(d34, 5)), d54);
|
||||
float16x8_t t55 = vaddq_f16(vsubq_f16(vmulq_n_f16(d15, 4), vmulq_n_f16(d35, 5)), d55);
|
||||
|
||||
float16x8_t m00 = vaddq_f16(vsubq_f16(vmulq_n_f16(t00, 4), vmulq_n_f16(t02, 5)), t04);
|
||||
float16x8_t m01 = vaddq_f16(vaddq_f16(t03, t04), vmulq_n_f16(vaddq_f16(t01, t02), -4));
|
||||
float16x8_t m02 = vaddq_f16(vsubq_f16(t04, t03), vmulq_n_f16(vsubq_f16(t01, t02), 4));
|
||||
float16x8_t m03 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t03, t01), 2));
|
||||
float16x8_t m04 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t01, t03), 2));
|
||||
float16x8_t m05 = vaddq_f16(vsubq_f16(vmulq_n_f16(t01, 4), vmulq_n_f16(t03, 5)), t05);
|
||||
|
||||
float16x8_t m10 = vaddq_f16(vsubq_f16(vmulq_n_f16(t10, 4), vmulq_n_f16(t12, 5)), t14);
|
||||
float16x8_t m11 = vaddq_f16(vaddq_f16(t13, t14), vmulq_n_f16(vaddq_f16(t11, t12), -4));
|
||||
float16x8_t m12 = vaddq_f16(vsubq_f16(t14, t13), vmulq_n_f16(vsubq_f16(t11, t12), 4));
|
||||
float16x8_t m13 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t13, t11), 2));
|
||||
float16x8_t m14 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t11, t13), 2));
|
||||
float16x8_t m15 = vaddq_f16(vsubq_f16(vmulq_n_f16(t11, 4), vmulq_n_f16(t13, 5)), t15);
|
||||
|
||||
float16x8_t m20 = vaddq_f16(vsubq_f16(vmulq_n_f16(t20, 4), vmulq_n_f16(t22, 5)), t24);
|
||||
float16x8_t m21 = vaddq_f16(vaddq_f16(t23, t24), vmulq_n_f16(vaddq_f16(t21, t22), -4));
|
||||
float16x8_t m22 = vaddq_f16(vsubq_f16(t24, t23), vmulq_n_f16(vsubq_f16(t21, t22), 4));
|
||||
float16x8_t m23 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t23, t21), 2));
|
||||
float16x8_t m24 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t21, t23), 2));
|
||||
float16x8_t m25 = vaddq_f16(vsubq_f16(vmulq_n_f16(t21, 4), vmulq_n_f16(t23, 5)), t25);
|
||||
|
||||
float16x8_t m30 = vaddq_f16(vsubq_f16(vmulq_n_f16(t30, 4), vmulq_n_f16(t32, 5)), t34);
|
||||
float16x8_t m31 = vaddq_f16(vaddq_f16(t33, t34), vmulq_n_f16(vaddq_f16(t31, t32), -4));
|
||||
float16x8_t m32 = vaddq_f16(vsubq_f16(t34, t33), vmulq_n_f16(vsubq_f16(t31, t32), 4));
|
||||
float16x8_t m33 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t33, t31), 2));
|
||||
float16x8_t m34 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t31, t33), 2));
|
||||
float16x8_t m35 = vaddq_f16(vsubq_f16(vmulq_n_f16(t31, 4), vmulq_n_f16(t33, 5)), t35);
|
||||
|
||||
float16x8_t m40 = vaddq_f16(vsubq_f16(vmulq_n_f16(t40, 4), vmulq_n_f16(t42, 5)), t44);
|
||||
float16x8_t m41 = vaddq_f16(vaddq_f16(t43, t44), vmulq_n_f16(vaddq_f16(t41, t42), -4));
|
||||
float16x8_t m42 = vaddq_f16(vsubq_f16(t44, t43), vmulq_n_f16(vsubq_f16(t41, t42), 4));
|
||||
float16x8_t m43 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t43, t41), 2));
|
||||
float16x8_t m44 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t41, t43), 2));
|
||||
float16x8_t m45 = vaddq_f16(vsubq_f16(vmulq_n_f16(t41, 4), vmulq_n_f16(t43, 5)), t45);
|
||||
|
||||
float16x8_t m50 = vaddq_f16(vsubq_f16(vmulq_n_f16(t50, 4), vmulq_n_f16(t52, 5)), t54);
|
||||
float16x8_t m51 = vaddq_f16(vaddq_f16(t53, t54), vmulq_n_f16(vaddq_f16(t51, t52), -4));
|
||||
float16x8_t m52 = vaddq_f16(vsubq_f16(t54, t53), vmulq_n_f16(vsubq_f16(t51, t52), 4));
|
||||
float16x8_t m53 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t53, t51), 2));
|
||||
float16x8_t m54 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t51, t53), 2));
|
||||
float16x8_t m55 = vaddq_f16(vsubq_f16(vmulq_n_f16(t51, 4), vmulq_n_f16(t53, 5)), t55);
|
||||
|
||||
vst1_f16(trans_input_data, vget_low_f16(m00));
|
||||
vst1_f16(trans_input_data + 64, vget_high_f16(m00));
|
||||
vst1_f16(trans_input_data + step, vget_low_f16(m01));
|
||||
vst1_f16(trans_input_data + step + 64, vget_high_f16(m01));
|
||||
vst1_f16(trans_input_data + 2 * step, vget_low_f16(m02));
|
||||
vst1_f16(trans_input_data + 2 * step + 64, vget_high_f16(m02));
|
||||
vst1_f16(trans_input_data + 3 * step, vget_low_f16(m03));
|
||||
vst1_f16(trans_input_data + 3 * step + 64, vget_high_f16(m03));
|
||||
vst1_f16(trans_input_data + 4 * step, vget_low_f16(m04));
|
||||
vst1_f16(trans_input_data + 4 * step + 64, vget_high_f16(m04));
|
||||
vst1_f16(trans_input_data + 5 * step, vget_low_f16(m05));
|
||||
vst1_f16(trans_input_data + 5 * step + 64, vget_high_f16(m05));
|
||||
|
||||
vst1_f16(trans_input_data + 6 * step, vget_low_f16(m10));
|
||||
vst1_f16(trans_input_data + 6 * step + 64, vget_high_f16(m10));
|
||||
vst1_f16(trans_input_data + 7 * step, vget_low_f16(m11));
|
||||
vst1_f16(trans_input_data + 7 * step + 64, vget_high_f16(m11));
|
||||
vst1_f16(trans_input_data + 8 * step, vget_low_f16(m12));
|
||||
vst1_f16(trans_input_data + 8 * step + 64, vget_high_f16(m12));
|
||||
vst1_f16(trans_input_data + 9 * step, vget_low_f16(m13));
|
||||
vst1_f16(trans_input_data + 9 * step + 64, vget_high_f16(m13));
|
||||
vst1_f16(trans_input_data + 10 * step, vget_low_f16(m14));
|
||||
vst1_f16(trans_input_data + 10 * step + 64, vget_high_f16(m14));
|
||||
vst1_f16(trans_input_data + 11 * step, vget_low_f16(m15));
|
||||
vst1_f16(trans_input_data + 11 * step + 64, vget_high_f16(m15));
|
||||
|
||||
vst1_f16(trans_input_data + 12 * step, vget_low_f16(m20));
|
||||
vst1_f16(trans_input_data + 12 * step + 64, vget_high_f16(m20));
|
||||
vst1_f16(trans_input_data + 13 * step, vget_low_f16(m21));
|
||||
vst1_f16(trans_input_data + 13 * step + 64, vget_high_f16(m21));
|
||||
vst1_f16(trans_input_data + 14 * step, vget_low_f16(m22));
|
||||
vst1_f16(trans_input_data + 14 * step + 64, vget_high_f16(m22));
|
||||
vst1_f16(trans_input_data + 15 * step, vget_low_f16(m23));
|
||||
vst1_f16(trans_input_data + 15 * step + 64, vget_high_f16(m23));
|
||||
vst1_f16(trans_input_data + 16 * step, vget_low_f16(m24));
|
||||
vst1_f16(trans_input_data + 16 * step + 64, vget_high_f16(m24));
|
||||
vst1_f16(trans_input_data + 17 * step, vget_low_f16(m25));
|
||||
vst1_f16(trans_input_data + 17 * step + 64, vget_high_f16(m25));
|
||||
|
||||
vst1_f16(trans_input_data + 18 * step, vget_low_f16(m30));
|
||||
vst1_f16(trans_input_data + 18 * step + 64, vget_high_f16(m30));
|
||||
vst1_f16(trans_input_data + 19 * step, vget_low_f16(m31));
|
||||
vst1_f16(trans_input_data + 19 * step + 64, vget_high_f16(m31));
|
||||
vst1_f16(trans_input_data + 20 * step, vget_low_f16(m32));
|
||||
vst1_f16(trans_input_data + 20 * step + 64, vget_high_f16(m32));
|
||||
vst1_f16(trans_input_data + 21 * step, vget_low_f16(m33));
|
||||
vst1_f16(trans_input_data + 21 * step + 64, vget_high_f16(m33));
|
||||
vst1_f16(trans_input_data + 22 * step, vget_low_f16(m34));
|
||||
vst1_f16(trans_input_data + 22 * step + 64, vget_high_f16(m34));
|
||||
vst1_f16(trans_input_data + 23 * step, vget_low_f16(m35));
|
||||
vst1_f16(trans_input_data + 23 * step + 64, vget_high_f16(m35));
|
||||
|
||||
vst1_f16(trans_input_data + 24 * step, vget_low_f16(m40));
|
||||
vst1_f16(trans_input_data + 24 * step + 64, vget_high_f16(m40));
|
||||
vst1_f16(trans_input_data + 25 * step, vget_low_f16(m41));
|
||||
vst1_f16(trans_input_data + 25 * step + 64, vget_high_f16(m41));
|
||||
vst1_f16(trans_input_data + 26 * step, vget_low_f16(m42));
|
||||
vst1_f16(trans_input_data + 26 * step + 64, vget_high_f16(m42));
|
||||
vst1_f16(trans_input_data + 27 * step, vget_low_f16(m43));
|
||||
vst1_f16(trans_input_data + 27 * step + 64, vget_high_f16(m43));
|
||||
vst1_f16(trans_input_data + 28 * step, vget_low_f16(m44));
|
||||
vst1_f16(trans_input_data + 28 * step + 64, vget_high_f16(m44));
|
||||
vst1_f16(trans_input_data + 29 * step, vget_low_f16(m45));
|
||||
vst1_f16(trans_input_data + 29 * step + 64, vget_high_f16(m45));
|
||||
|
||||
vst1_f16(trans_input_data + 30 * step, vget_low_f16(m50));
|
||||
vst1_f16(trans_input_data + 30 * step + 64, vget_high_f16(m50));
|
||||
vst1_f16(trans_input_data + 31 * step, vget_low_f16(m51));
|
||||
vst1_f16(trans_input_data + 31 * step + 64, vget_high_f16(m51));
|
||||
vst1_f16(trans_input_data + 32 * step, vget_low_f16(m52));
|
||||
vst1_f16(trans_input_data + 32 * step + 64, vget_high_f16(m52));
|
||||
vst1_f16(trans_input_data + 33 * step, vget_low_f16(m53));
|
||||
vst1_f16(trans_input_data + 33 * step + 64, vget_high_f16(m53));
|
||||
vst1_f16(trans_input_data + 34 * step, vget_low_f16(m54));
|
||||
vst1_f16(trans_input_data + 34 * step + 64, vget_high_f16(m54));
|
||||
vst1_f16(trans_input_data + 35 * step, vget_low_f16(m55));
|
||||
vst1_f16(trans_input_data + 35 * step + 64, vget_high_f16(m55));
|
||||
}
|
||||
|
||||
// Builds the transformed input for the tiles [start_index, start_index + real_cal_num).
// For every tile and every block of C8NUM channels, the 6x6 input patch is copied into tmp_data
// (positions that fall outside the image stay zero) and handed to Conv3x3Fp16InputUnit.
//   input_data   : source activations (see format note below).
//   trans_input  : destination for the transformed patches.
//   tmp_data     : scratch buffer holding at least 6 * 6 * C8NUM float16_t elements.
//   out_w_block  : number of tiles per output row; nothing is done when it is 0.
//   conv_param   : supplies input_channel_, input_w_, input_h_, pad_l_ and pad_u_.
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
                               int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  // input data format : nhwc
  const int in_unit = 6;
  const int out_unit = 4;
  const int channels = conv_param->input_channel_;
  const int in_w = conv_param->input_w_;
  const int in_h = conv_param->input_h_;
  const int left_pad = conv_param->pad_l_;
  const int top_pad = conv_param->pad_u_;
  const int c8_blocks = UP_DIV(channels, C8NUM);
  if (out_w_block == 0) {
    return;
  }

  for (int tile = 0; tile < real_cal_num; ++tile) {
    const int tile_id = start_index + tile;
    // Top-left corner of this tile's patch in input coordinates (may be negative because of padding).
    const int x0 = (tile_id % out_w_block) * out_unit - left_pad;
    const int y0 = (tile_id / out_w_block) * out_unit - top_pad;

    // Portion of the 6x6 patch that actually lies inside the image.
    const int col_begin = x0 > 0 ? 0 : -x0;
    const int col_end = (x0 + in_unit) < in_w ? in_unit : (in_w - x0);
    const int row_begin = y0 > 0 ? 0 : -y0;
    const int row_end = (y0 + in_unit) < in_h ? in_unit : (in_h - y0);

    const int patch_offset = c8_blocks * C8NUM * (y0 * in_w + x0);
    float16_t *tile_dst = trans_input + tile * C4NUM;

    for (int blk = 0; blk < c8_blocks; ++blk) {
      // Start from an all-zero patch so the padded border needs no explicit handling.
      memset(tmp_data, 0, 6 * 6 * C8NUM * sizeof(float16_t));

      const int blk_offset = patch_offset + blk * C8NUM;
      for (int row = row_begin; row < row_end; ++row) {
        for (int col = col_begin; col < col_end; ++col) {
          // The offset is summed as an integer first: patch_offset alone can be negative.
          const int src_offset = blk_offset + (row * in_w + col) * c8_blocks * C8NUM;
          float16_t *dst = tmp_data + (row * in_unit + col) * C8NUM;
          vst1q_f16(dst, vld1q_f16(input_data + src_offset));
        }
      }

      const size_t unit_step = c8_blocks * C8NUM * 16;
      Conv3x3Fp16InputUnit(tmp_data, tile_dst + blk * 16 * C8NUM, unit_step);
    }
  }
}
|
||||
|
||||
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel,
|
||||
int kernel_plane) {
|
||||
int dst_step = iC4 * C4NUM * 8;
|
||||
for (int o = 0; o < output_channel; o++) {
|
||||
int oc8_block_num = o / C8NUM;
|
||||
int oc8_block_rem = o % C8NUM;
|
||||
int src_oc_offset = o * iC4 * C4NUM * kernel_plane;
|
||||
int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem;
|
||||
for (int i = 0; i < iC4; i++) {
|
||||
const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM;
|
||||
float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM;
|
||||
float16x4_t g00 = vld1_f16(src_ic4_ptr);
|
||||
float16x4_t g01 = vld1_f16(src_ic4_ptr + 4);
|
||||
float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4);
|
||||
float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4);
|
||||
float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4);
|
||||
float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4);
|
||||
float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4);
|
||||
float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4);
|
||||
float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4);
|
||||
|
||||
float16x4_t dst00 = vmul_n_f16(g00, 0.25);
|
||||
float16x4_t dst01 = vmul_n_f16(g01, 0.25);
|
||||
float16x4_t dst02 = vmul_n_f16(g02, 0.25);
|
||||
|
||||
float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667);
|
||||
float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667);
|
||||
float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667);
|
||||
|
||||
float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667);
|
||||
float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667);
|
||||
float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667);
|
||||
|
||||
float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)));
|
||||
float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)));
|
||||
float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)));
|
||||
|
||||
float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)),
|
||||
vmul_n_f16(g10, 0.08333333333333));
|
||||
float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)),
|
||||
vmul_n_f16(g11, 0.08333333333333));
|
||||
float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)),
|
||||
vmul_n_f16(g12, 0.08333333333333));
|
||||
|
||||
float16x4_t dst50 = g20;
|
||||
float16x4_t dst51 = g21;
|
||||
float16x4_t dst52 = g22;
|
||||
|
||||
float16x4_t m00 = vmul_n_f16(dst00, 0.25);
|
||||
float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667);
|
||||
float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667);
|
||||
float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)));
|
||||
float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)),
|
||||
vmul_n_f16(dst01, 0.08333333333333));
|
||||
float16x4_t m05 = dst02;
|
||||
|
||||
float16x4_t m10 = vmul_n_f16(dst10, 0.25);
|
||||
float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667);
|
||||
float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667);
|
||||
float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)));
|
||||
float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)),
|
||||
vmul_n_f16(dst11, 0.08333333333333));
|
||||
float16x4_t m15 = dst12;
|
||||
|
||||
float16x4_t m20 = vmul_n_f16(dst20, 0.25);
|
||||
float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667);
|
||||
float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667);
|
||||
float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)));
|
||||
float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)),
|
||||
vmul_n_f16(dst21, 0.08333333333333));
|
||||
float16x4_t m25 = dst22;
|
||||
|
||||
float16x4_t m30 = vmul_n_f16(dst30, 0.25);
|
||||
float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667);
|
||||
float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667);
|
||||
float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)));
|
||||
float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)),
|
||||
vmul_n_f16(dst31, 0.08333333333333));
|
||||
float16x4_t m35 = dst32;
|
||||
|
||||
float16x4_t m40 = vmul_n_f16(dst40, 0.25);
|
||||
float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667);
|
||||
float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667);
|
||||
float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)));
|
||||
float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)),
|
||||
vmul_n_f16(dst41, 0.08333333333333));
|
||||
float16x4_t m45 = dst42;
|
||||
|
||||
float16x4_t m50 = vmul_n_f16(dst50, 0.25);
|
||||
float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667);
|
||||
float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667);
|
||||
float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333),
|
||||
vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)));
|
||||
float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)),
|
||||
vmul_n_f16(dst51, 0.08333333333333));
|
||||
float16x4_t m55 = dst52;
|
||||
|
||||
for (int j = 0; j < 4; j++) {
|
||||
dst_ic4_ptr[j * 8] = m00[j];
|
||||
dst_ic4_ptr[j * 8 + dst_step] = m01[j];
|
||||
dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j];
|
||||
dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j];
|
||||
dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j];
|
||||
dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j];
|
||||
dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j];
|
||||
dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j];
|
||||
dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j];
|
||||
dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j];
|
||||
dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j];
|
||||
dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j];
|
||||
dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j];
|
||||
dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j];
|
||||
dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j];
|
||||
dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j];
|
||||
dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j];
|
||||
dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j];
|
||||
dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j];
|
||||
dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j];
|
||||
dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j];
|
||||
dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j];
|
||||
dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j];
|
||||
dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j];
|
||||
dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j];
|
||||
dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j];
|
||||
dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j];
|
||||
dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j];
|
||||
dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j];
|
||||
dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j];
|
||||
dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j];
|
||||
dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j];
|
||||
dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j];
|
||||
dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j];
|
||||
dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j];
|
||||
dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data,
|
||||
int output_w) {
|
||||
float16x8_t s00 = vld1q_f16(gemm_out);
|
||||
float16x8_t s01 = vld1q_f16(gemm_out + 8);
|
||||
float16x8_t s02 = vld1q_f16(gemm_out + 16);
|
||||
float16x8_t s03 = vld1q_f16(gemm_out + 24);
|
||||
float16x8_t s04 = vld1q_f16(gemm_out + 32);
|
||||
float16x8_t s05 = vld1q_f16(gemm_out + 40);
|
||||
|
||||
float16x8_t s10 = vld1q_f16(gemm_out + 48);
|
||||
float16x8_t s11 = vld1q_f16(gemm_out + 56);
|
||||
float16x8_t s12 = vld1q_f16(gemm_out + 64);
|
||||
float16x8_t s13 = vld1q_f16(gemm_out + 72);
|
||||
float16x8_t s14 = vld1q_f16(gemm_out + 80);
|
||||
float16x8_t s15 = vld1q_f16(gemm_out + 88);
|
||||
|
||||
float16x8_t s20 = vld1q_f16(gemm_out + 96);
|
||||
float16x8_t s21 = vld1q_f16(gemm_out + 104);
|
||||
float16x8_t s22 = vld1q_f16(gemm_out + 112);
|
||||
float16x8_t s23 = vld1q_f16(gemm_out + 120);
|
||||
float16x8_t s24 = vld1q_f16(gemm_out + 128);
|
||||
float16x8_t s25 = vld1q_f16(gemm_out + 136);
|
||||
|
||||
float16x8_t s30 = vld1q_f16(gemm_out + 144);
|
||||
float16x8_t s31 = vld1q_f16(gemm_out + 152);
|
||||
float16x8_t s32 = vld1q_f16(gemm_out + 160);
|
||||
float16x8_t s33 = vld1q_f16(gemm_out + 168);
|
||||
float16x8_t s34 = vld1q_f16(gemm_out + 176);
|
||||
float16x8_t s35 = vld1q_f16(gemm_out + 184);
|
||||
|
||||
float16x8_t s40 = vld1q_f16(gemm_out + 192);
|
||||
float16x8_t s41 = vld1q_f16(gemm_out + 200);
|
||||
float16x8_t s42 = vld1q_f16(gemm_out + 208);
|
||||
float16x8_t s43 = vld1q_f16(gemm_out + 216);
|
||||
float16x8_t s44 = vld1q_f16(gemm_out + 224);
|
||||
float16x8_t s45 = vld1q_f16(gemm_out + 232);
|
||||
|
||||
float16x8_t s50 = vld1q_f16(gemm_out + 240);
|
||||
float16x8_t s51 = vld1q_f16(gemm_out + 248);
|
||||
float16x8_t s52 = vld1q_f16(gemm_out + 256);
|
||||
float16x8_t s53 = vld1q_f16(gemm_out + 264);
|
||||
float16x8_t s54 = vld1q_f16(gemm_out + 272);
|
||||
float16x8_t s55 = vld1q_f16(gemm_out + 280);
|
||||
|
||||
float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40);
|
||||
float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41);
|
||||
float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42);
|
||||
float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43);
|
||||
float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44);
|
||||
float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45);
|
||||
|
||||
float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2));
|
||||
float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2));
|
||||
float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2));
|
||||
float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2));
|
||||
float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2));
|
||||
float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2));
|
||||
|
||||
float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4));
|
||||
float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4));
|
||||
float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4));
|
||||
float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4));
|
||||
float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4));
|
||||
float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4));
|
||||
|
||||
float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50);
|
||||
float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51);
|
||||
float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52);
|
||||
float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53);
|
||||
float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54);
|
||||
float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55);
|
||||
|
||||
float16x8_t bias_ptr = vld1q_f16(bias_data);
|
||||
float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04), bias_ptr);
|
||||
float16x8_t d01 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)), bias_ptr);
|
||||
float16x8_t d02 = vaddq_f16(vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)), bias_ptr);
|
||||
float16x8_t d03 =
|
||||
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05), bias_ptr);
|
||||
|
||||
float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14), bias_ptr);
|
||||
float16x8_t d11 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)), bias_ptr);
|
||||
float16x8_t d12 = vaddq_f16(vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)), bias_ptr);
|
||||
float16x8_t d13 =
|
||||
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15), bias_ptr);
|
||||
|
||||
float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24), bias_ptr);
|
||||
float16x8_t d21 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)), bias_ptr);
|
||||
float16x8_t d22 = vaddq_f16(vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)), bias_ptr);
|
||||
float16x8_t d23 =
|
||||
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25), bias_ptr);
|
||||
|
||||
float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34), bias_ptr);
|
||||
float16x8_t d31 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)), bias_ptr);
|
||||
float16x8_t d32 = vaddq_f16(vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)), bias_ptr);
|
||||
float16x8_t d33 =
|
||||
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35), bias_ptr);
|
||||
|
||||
vst1q_f16(output_data, d00);
|
||||
vst1q_f16(output_data + 8, d01);
|
||||
vst1q_f16(output_data + 16, d02);
|
||||
vst1q_f16(output_data + 24, d03);
|
||||
|
||||
vst1q_f16(output_data + output_w * 8, d10);
|
||||
vst1q_f16(output_data + output_w * 8 + 8, d11);
|
||||
vst1q_f16(output_data + output_w * 8 + 16, d12);
|
||||
vst1q_f16(output_data + output_w * 8 + 24, d13);
|
||||
|
||||
vst1q_f16(output_data + 2 * output_w * 8, d20);
|
||||
vst1q_f16(output_data + 2 * output_w * 8 + 8, d21);
|
||||
vst1q_f16(output_data + 2 * output_w * 8 + 16, d22);
|
||||
vst1q_f16(output_data + 2 * output_w * 8 + 24, d23);
|
||||
|
||||
vst1q_f16(output_data + 3 * output_w * 8, d30);
|
||||
vst1q_f16(output_data + 3 * output_w * 8 + 8, d31);
|
||||
vst1q_f16(output_data + 3 * output_w * 8 + 16, d32);
|
||||
vst1q_f16(output_data + 3 * output_w * 8 + 24, d33);
|
||||
}
|
||||
|
||||
// Runs the 3x3 winograd output transform for `real_cal_num` consecutive tiles, starting at tile
// `start_index`. For every tile and every block of C8NUM output channels, one 36 * C8NUM slice of
// `gemm_out` is handed to Conv3x3Fp16OutputUnit together with the matching C8NUM bias values and
// the destination address inside `out_data`. Does nothing when `out_w_block` is 0.
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
                                int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
  const int oc8_blocks = UP_DIV(conv_param->output_channel_, C8NUM);
  const int h_blocks = UP_DIV(conv_param->output_h_, C4NUM);
  if (out_w_block == 0) {
    // A zero-width tile grid would make the modulo/division below invalid.
    return;
  }

  // Row stride (in pixels) passed to the output unit.
  const int unit_output_w = out_w_block * C4NUM;

  for (int tile = 0; tile < real_cal_num; ++tile) {
    const int tile_pos = start_index + tile;
    const int w_idx = tile_pos % out_w_block;
    const int h_idx = tile_pos / out_w_block;
    const int src_base = tile * oc8_blocks * C8NUM * 36;
    const int dst_base = C8NUM * (w_idx * C4NUM + h_idx * C4NUM * out_w_block * C4NUM);

    for (int blk = 0; blk < oc8_blocks; ++blk) {
      const int src_off = src_base + blk * 36 * C8NUM;
      const int dst_off = dst_base + blk * C8NUM * h_blocks * out_w_block * C4NUM * C4NUM;
      Conv3x3Fp16OutputUnit(gemm_out + src_off, bias_data + blk * C8NUM, out_data + dst_off, unit_output_w);
    }
  }
}
|
||||
|
||||
// fp16 common winograd
|
||||
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
|
||||
InputTransFp16Func func) {
|
||||
#ifdef ENABLE_ARM64
|
||||
const int tile_num = 16;
|
||||
#else
|
||||
const int tile_num = 12;
|
||||
#endif
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int output_unit = conv_param->output_unit_;
|
||||
int in_channel = conv_param->input_channel_;
|
||||
|
|
|
@ -23,25 +23,11 @@
|
|||
#include "nnacl/fp16/cast_fp16.h"
|
||||
#include "nnacl/fp16/conv_fp16.h"
|
||||
#include "nnacl/fp16/matrix_fp16.h"
|
||||
#include "nnacl/fp16/pack_fp16.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// for fp16 convolution 3x3 filter/input/output transform
|
||||
void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step);
|
||||
|
||||
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
|
||||
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
|
||||
|
||||
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel,
|
||||
int kernel_plane);
|
||||
|
||||
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w);
|
||||
|
||||
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
|
||||
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
|
||||
|
||||
// fp16 common winograd
|
||||
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
#ifndef MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_
|
||||
#define MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
|
||||
|
||||
#define MAX_LEN 256
|
||||
|
||||
|
|
|
@ -17,9 +17,11 @@
|
|||
#ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
#define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
#include <math.h>
|
||||
|
||||
#ifdef ENABLE_ARM
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_SSE) || defined(ENABLE_AVX)
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
@ -46,7 +48,7 @@
|
|||
#ifdef ENABLE_ARM64
|
||||
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
|
||||
#else
|
||||
inline static float32x4_t vrecp(float32x4_t v) {
|
||||
static inline float32x4_t vrecp(float32x4_t v) {
|
||||
float32x4_t r = vrecpeq_f32(v);
|
||||
r = vmulq_f32(vrecpsq_f32(v, r), r);
|
||||
r = vmulq_f32(vrecpsq_f32(v, r), r);
|
||||
|
@ -205,25 +207,4 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
|
|||
return dst;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
|
||||
float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src));
|
||||
float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src));
|
||||
return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high)));
|
||||
}
|
||||
|
||||
static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
|
||||
float16x8_t dst;
|
||||
dst[0] = erff(src[0]);
|
||||
dst[1] = erff(src[1]);
|
||||
dst[2] = erff(src[2]);
|
||||
dst[3] = erff(src[3]);
|
||||
dst[4] = erff(src[4]);
|
||||
dst[5] = erff(src[5]);
|
||||
dst[6] = erff(src[6]);
|
||||
dst[7] = erff(src[7]);
|
||||
return dst;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
|
||||
#define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
|
||||
#include <math.h>
|
||||
#include "nnacl/intrinsics/ms_simd_instructions.h"
|
||||
|
||||
#if defined(ENABLE_ARM82_A32)
|
||||
// Element-wise in1 / in2 for eight fp16 lanes.
// The quotient is approximated as in1 * (1 / in2): a reciprocal estimate of in2 (vrecpe) is refined
// with two vrecps/vmul steps and then multiplied by in1.
// q14 and q15 are used as scratch registers and are declared as clobbers.
// `dst` is a pure output operand; it is no longer also passed as a tied input, because that read
// an uninitialized variable.
static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) {
  float16x8_t dst;
  asm volatile(
    "vrecpe.f16 q14, %2\n"
    "vrecps.f16 q15, %2, q14\n"
    "vmul.f16 q14, q15, q14\n"
    "vrecps.f16 q15, %2, q14\n"
    "vmul.f16 q14, q15, q14\n"
    "vmul.f16 %0, %1, q14\n"
    : "=w"(dst)
    : "w"(in1), "w"(in2)
    : "q14", "q15");
  return dst;
}
|
||||
|
||||
// Element-wise in1 / in2 for four fp16 lanes.
// The quotient is approximated as in1 * (1 / in2): a reciprocal estimate of in2 (vrecpe) is refined
// with two vrecps/vmul steps and then multiplied by in1.
// d14 and d16 are used as scratch registers and are declared as clobbers.
// `dst` is a pure output operand; it is no longer also passed as a tied input, because that read
// an uninitialized variable.
static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) {
  float16x4_t dst;
  asm volatile(
    "vrecpe.f16 d14, %2\n"
    "vrecps.f16 d16, %2, d14\n"
    "vmul.f16 d14, d16, d14\n"
    "vrecps.f16 d16, %2, d14\n"
    "vmul.f16 d14, d16, d14\n"
    "vmul.f16 %0, %1, d14\n"
    : "=w"(dst)
    : "w"(in1), "w"(in2)
    : "d14", "d16");
  return dst;
}
|
||||
|
||||
// Horizontal sum of the four float lanes. The vaddvq_f32 intrinsic is not supported in arm82
// aarch32 builds, so it is provided here. Lanes are added left to right (0, 1, 2, 3).
static inline float vaddvq_f32(float32x4_t in) {
  float sum = vgetq_lane_f32(in, 0);
  sum += vgetq_lane_f32(in, 1);
  sum += vgetq_lane_f32(in, 2);
  sum += vgetq_lane_f32(in, 3);
  return sum;
}
|
||||
|
||||
// Widens four fp16 lanes to four fp32 lanes with a single vcvt.f32.f16 instruction.
// `dst` is a pure output operand; it is no longer also passed as a tied input, because that read
// an uninitialized variable.
static inline float32x4_t cvt_f32_f16(float16x4_t in) {
  float32x4_t dst;
  asm volatile("vcvt.f32.f16 %0, %1\n" : "=w"(dst) : "w"(in) :);
  return dst;
}
|
||||
|
||||
// Narrows four fp32 lanes to four fp16 lanes with a single vcvt.f16.f32 instruction.
// `dst` is a pure output operand; it is no longer also passed as a tied input, because that read
// an uninitialized variable.
static inline float16x4_t cvt_f16_f32(float32x4_t in) {
  float16x4_t dst;
  asm volatile("vcvt.f16.f32 %0, %1\n" : "=w"(dst) : "w"(in) :);
  return dst;
}
|
||||
|
||||
#define MS_CVT_F32_F16(src) cvt_f32_f16(src)
|
||||
#define MS_CVT_F16_F32(src) cvt_f16_f32(src)
|
||||
#define MS_DIV_F16(src1, src2) div_f16(src1, src2)
|
||||
#define MS_DIVQ_F16(src1, src2) divq_f16(src1, src2)
|
||||
#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3))
|
||||
#else
|
||||
#define MS_CVT_F32_F16(src) vcvt_f32_f16(src)
|
||||
#define MS_CVT_F16_F32(src) vcvt_f16_f32(src)
|
||||
#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2)
|
||||
#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2)
|
||||
#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3)
|
||||
#endif
|
||||
|
||||
// tanh over eight fp16 lanes: each half of the vector is widened to fp32, passed through
// MS_TANHX4_F32, narrowed back to fp16, and the two halves are recombined.
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
  const float16x4_t lo_half = vget_low_f16(src);
  const float16x4_t hi_half = vget_high_f16(src);
  const float32x4_t lo_tanh = MS_TANHX4_F32(MS_CVT_F32_F16(lo_half));
  const float32x4_t hi_tanh = MS_TANHX4_F32(MS_CVT_F32_F16(hi_half));
  return vcombine_f16(MS_CVT_F16_F32(lo_tanh), MS_CVT_F16_F32(hi_tanh));
}
|
||||
|
||||
// erf over eight fp16 lanes, computed one lane at a time with the scalar erff().
static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
  float16x8_t dst;
  for (int lane = 0; lane < 8; ++lane) {
    dst[lane] = erff(src[lane]);
  }
  return dst;
}
|
||||
#endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
|
|
@ -4,11 +4,15 @@ set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
|||
include_directories(NNACL_DIR)
|
||||
|
||||
########################### optimized files ###########################
|
||||
file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S)
|
||||
file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c)
|
||||
file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S)
|
||||
if(PLATFORM_ARM32)
|
||||
file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S)
|
||||
else()
|
||||
file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S)
|
||||
file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S)
|
||||
set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C)
|
||||
endif()
|
||||
|
||||
set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C)
|
||||
set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C)
|
||||
set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C)
|
||||
|
||||
|
@ -17,7 +21,6 @@ if(APPLE)
|
|||
set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp")
|
||||
endif()
|
||||
########################### share library build ########################
|
||||
list(APPEND SDOT_FILES ${SDOT_SRC})
|
||||
list(APPEND FP16_FILES ${FP16_C_SRC})
|
||||
list(APPEND FP16_FILES ${FP16_NEON_SRC})
|
||||
|
||||
|
@ -27,13 +30,20 @@ if(SUPPORT_TRAIN)
|
|||
endif()
|
||||
|
||||
string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
|
||||
|
||||
add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES})
|
||||
add_dependencies(nnacl_optimize_mid fbs_src)
|
||||
if(NOT PLATFORM_ARM32)
|
||||
list(APPEND SDOT_FILES ${SDOT_SRC})
|
||||
add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES})
|
||||
add_dependencies(nnacl_optimize_mid fbs_src)
|
||||
endif()
|
||||
|
||||
if(ENABLE_FP16)
|
||||
add_library(nnacl_fp16_mid OBJECT ${FP16_FILES})
|
||||
if(PLATFORM_ARM32)
|
||||
target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
|
||||
else()
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
|
||||
endif()
|
||||
add_dependencies(nnacl_fp16_mid fbs_src)
|
||||
endif()
|
|
@ -5,6 +5,12 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_L
|
|||
message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
|
||||
endif()
|
||||
|
||||
# For 32-bit ARM builds with a Clang older than 9.0, turn fp16 off and tell the user which
# toolchain is required to build the fp16 kernels.
if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
    set(ENABLE_FP16 "off")
    message(WARNING "If you want to build fp16 in arm82_a32, \
your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android ndk r21e!")
endif()
|
||||
|
||||
option(MS_VERSION_MAJOR "major version" 0)
|
||||
option(MS_VERSION_MINOR "minor version" 7)
|
||||
option(MS_VERSION_REVISION "revision version" 0)
|
||||
|
@ -15,6 +21,7 @@ option(PLATFORM_ARM64 "if build device for arm64" off)
|
|||
option(PLATFORM_ARM32 "if build device for arm32" off)
|
||||
option(ENABLE_CONVERTER "if build converter" on)
|
||||
option(ENABLE_FP16 "if build fp16 ops" off)
|
||||
option(ENABLE_ARM82_A32 "if build fp16 on platform_arm32" off)
|
||||
option(ENABLE_TOOLS "if build tools" on)
|
||||
option(BUILD_TESTCASES "if build testcase" on)
|
||||
option(SUPPORT_GPU "if support gpu" off)
|
||||
|
@ -177,6 +184,9 @@ if(ENABLE_NEON)
|
|||
endif()
|
||||
# fp16 builds always define ENABLE_FP16; 32-bit ARM fp16 builds additionally define
# ENABLE_ARM82_A32.
if(ENABLE_FP16 AND PLATFORM_ARM32)
    add_compile_definitions(ENABLE_FP16 ENABLE_ARM82_A32)
elseif(ENABLE_FP16)
    add_compile_definitions(ENABLE_FP16)
endif()
|
||||
if(SUPPORT_GPU STREQUAL opencl)
|
||||
add_definitions(-DGPU_OPENCL)
|
||||
|
|
|
@ -3,6 +3,9 @@ if(ENABLE_V0)
|
|||
add_definitions(-DENABLE_V0)
|
||||
endif()
|
||||
include_directories(${CCSRC_DIR}/backend/kernel_compiler/cpu)
|
||||
set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
include_directories(${LITE_DIR}/nnacl/)
|
||||
include_directories(${LITE_DIR}/nnacl/optimize)
|
||||
|
||||
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
|
||||
#for performance
|
||||
|
@ -209,9 +212,11 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
|
|||
endif()
|
||||
|
||||
########################## build optimize and float16 library #################################
|
||||
if(PLATFORM_ARM64)
|
||||
target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid)
|
||||
target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid)
|
||||
if(PLATFORM_ARM)
|
||||
if(PLATFORM_ARM64)
|
||||
target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid)
|
||||
target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid)
|
||||
endif()
|
||||
if(ENABLE_FP16)
|
||||
target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid)
|
||||
target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid)
|
||||
|
@ -247,8 +252,10 @@ if(DEFINED ARCHS)
|
|||
target_link_libraries(mindspore_lite mindrt_mid)
|
||||
endif()
|
||||
|
||||
if(PLATFORM_ARM64)
|
||||
target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid)
|
||||
if(PLATFORM_ARM)
|
||||
if(PLATFORM_ARM64)
|
||||
target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid)
|
||||
endif()
|
||||
if(ENABLE_FP16)
|
||||
target_link_libraries(mindspore_lite cpu_fp16_kernel_mid nnacl_fp16_mid)
|
||||
endif()
|
||||
|
|
|
@ -155,7 +155,11 @@ bool IsSupportSDot() {
|
|||
|
||||
bool IsSupportFloat16() {
|
||||
bool status = false;
|
||||
#ifdef ENABLE_ARM64
|
||||
#ifdef ENABLE_ARM32
|
||||
status = true;
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_ARM64)
|
||||
#if defined(__ANDROID__)
|
||||
int hwcap_type = 16;
|
||||
uint32_t hwcap = getHwCap(hwcap_type);
|
||||
|
|
|
@ -44,7 +44,7 @@ uint64_t GetTimeUs();
|
|||
bool IsSupportSDot();
|
||||
|
||||
bool IsSupportFloat16();
|
||||
#if defined(__arm__) || defined(__aarch64__)
|
||||
#if defined(__arm__)
|
||||
uint32_t getHwCap(int hwcap_type);
|
||||
#endif
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "src/common/version_manager.h"
|
||||
#include "nnacl/pooling_parameter.h"
|
||||
#include "src/ios_reg_kernels.h"
|
||||
#ifdef ENABLE_ARM64
|
||||
#if defined(ENABLE_FP16) && defined(ENABLE_ARM)
|
||||
#if defined(__ANDROID__)
|
||||
#include <asm/hwcap.h>
|
||||
#endif
|
||||
|
@ -55,6 +55,8 @@ int KernelRegistry::Init() {
|
|||
} else {
|
||||
MS_LOG(INFO) << "The current device NOT supports Sdot.";
|
||||
}
|
||||
#endif
|
||||
#ifdef ENABLE_FP16
|
||||
if (mindspore::lite::IsSupportFloat16()) {
|
||||
MS_LOG(INFO) << "The current device supports float16.";
|
||||
} else {
|
||||
|
|
|
@ -17,7 +17,7 @@ endif()
|
|||
add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC})
|
||||
add_dependencies(cpu_kernel_mid fbs_src)
|
||||
|
||||
if(PLATFORM_ARM64)
|
||||
if(PLATFORM_ARM)
|
||||
if(ENABLE_FP16)
|
||||
file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
|
||||
if(SUPPORT_TRAIN)
|
||||
|
|
|
@ -52,7 +52,7 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
|
|||
ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride,
|
||||
param_->value_.int32_value_);
|
||||
break;
|
||||
#ifdef ENABLE_NEON
|
||||
#ifdef ENABLE_FP16
|
||||
case kNumberTypeFloat16:
|
||||
ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride,
|
||||
param_->value_.f32_value_);
|
||||
|
|
|
@ -31,8 +31,8 @@ int Convolution1x1FP16CPUKernel::InitMatmulParam() {
|
|||
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
|
||||
matmul_param_->col_ = conv_param_->output_channel_;
|
||||
matmul_param_->deep_ = conv_param_->input_channel_;
|
||||
matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM);
|
||||
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
|
||||
matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_);
|
||||
matmul_param_->col_align_ = UP_ROUND(matmul_param_->col_, col_tile_);
|
||||
matmul_param_->act_type_ = conv_param_->act_type_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -54,14 +54,14 @@ int Convolution1x1FP16CPUKernel::InitConv1x1Param() {
|
|||
pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
|
||||
conv_param_->stride_w_ != 1);
|
||||
|
||||
if ((matmul_param_->row_ > (C16NUM * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
|
||||
if ((matmul_param_->row_ > (row_tile_ * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
|
||||
multi_thread_by_hw_ = true;
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C16NUM));
|
||||
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, C16NUM), thread_count_) * C16NUM;
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_tile_));
|
||||
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_tile_), thread_count_) * row_tile_;
|
||||
} else {
|
||||
multi_thread_by_hw_ = false;
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
|
||||
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
|
||||
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile_));
|
||||
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile_), thread_count_) * col_tile_;
|
||||
}
|
||||
|
||||
if (pre_trans_input_) {
|
||||
|
@ -81,8 +81,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
|
|||
auto output_channel = weight_tensor->Batch();
|
||||
|
||||
if (in_tensors_.size() == 3) {
|
||||
size_t size = UP_ROUND(output_channel, C8NUM) * sizeof(float16_t);
|
||||
size_t weight_size = output_channel * sizeof(float16_t);
|
||||
size_t size = UP_ROUND(output_channel, col_tile_) * sizeof(float16_t);
|
||||
size_t bias_size = output_channel * sizeof(float16_t);
|
||||
bias_data_ = malloc(size);
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
|
||||
|
@ -94,11 +94,11 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
|
|||
Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_),
|
||||
output_channel);
|
||||
}
|
||||
memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
|
||||
memset(reinterpret_cast<char *>(bias_data_) + bias_size, 0, size - bias_size);
|
||||
}
|
||||
|
||||
size_t size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float16_t);
|
||||
size_t down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float16_t);
|
||||
size_t size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float16_t);
|
||||
size_t down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float16_t);
|
||||
weight_ptr_ = reinterpret_cast<float16_t *>(malloc(size));
|
||||
if (weight_ptr_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
|
||||
|
@ -111,6 +111,12 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
|
||||
int Convolution1x1FP16CPUKernel::Init() {
|
||||
col_tile_ = C8NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
row_tile_ = C16NUM;
|
||||
#else
|
||||
row_tile_ = C12NUM;
|
||||
#endif
|
||||
matmul_param_ = new (std::nothrow) MatMulParameter();
|
||||
if (matmul_param_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Init matmul_param_ failed.";
|
||||
|
@ -177,8 +183,11 @@ int Convolution1x1FP16CPUKernel::RunHw(int task_id) {
|
|||
|
||||
float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
|
||||
float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_;
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
|
||||
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
|
||||
#endif
|
||||
float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_;
|
||||
MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
|
||||
matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_,
|
||||
|
@ -211,7 +220,7 @@ int Convolution1x1FP16CPUKernel::Run() {
|
|||
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
|
||||
|
||||
pack_input_ = reinterpret_cast<float16_t *>(
|
||||
ctx_->allocator->Malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t)));
|
||||
ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float16_t)));
|
||||
if (pack_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
|
||||
return RET_MEMORY_FAILED;
|
||||
|
@ -231,7 +240,11 @@ int Convolution1x1FP16CPUKernel::Run() {
|
|||
if (multi_thread_by_hw_) {
|
||||
ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunHw, this, thread_count_);
|
||||
} else {
|
||||
#ifdef ENABLE_ARM64
|
||||
RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
|
||||
#else
|
||||
RowMajor2Col12MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
|
||||
#endif
|
||||
ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunOc, this, thread_count_);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -62,6 +62,8 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
|||
float16_t *pack_input_ = nullptr;
|
||||
float16_t *output_ptr_ = nullptr;
|
||||
MatMulParameter *matmul_param_ = nullptr;
|
||||
int col_tile_;
|
||||
int row_tile_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
|
|||
int out_channel = filter_tensor->Batch();
|
||||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
int oc8 = UP_ROUND(out_channel, C8NUM);
|
||||
int oc8 = UP_ROUND(out_channel, col_tile_);
|
||||
int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
|
||||
int pack_weight_size = oc8 * in_channel * kernel_plane;
|
||||
|
||||
|
@ -68,9 +68,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
|
||||
int ConvolutionFP16CPUKernel::InitTmpBuffer() {
|
||||
const int cal_num = 16;
|
||||
int unit_size =
|
||||
conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * cal_num * thread_count_;
|
||||
conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;
|
||||
|
||||
packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t)));
|
||||
if (packed_input_ == nullptr) {
|
||||
|
@ -87,6 +86,12 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() {
|
|||
}
|
||||
|
||||
int ConvolutionFP16CPUKernel::Init() {
|
||||
#ifdef ENABLE_ARM64
|
||||
row_tile_ = C16NUM;
|
||||
#else
|
||||
row_tile_ = C12NUM;
|
||||
#endif
|
||||
col_tile_ = C8NUM;
|
||||
auto ret = InitWeightBias();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init weight bias failed.";
|
||||
|
@ -98,7 +103,7 @@ int ConvolutionFP16CPUKernel::Init() {
|
|||
void ConvolutionFP16CPUKernel::AdjustNumberOfThread() {
|
||||
auto out_tensor = out_tensors_.front();
|
||||
int out_plane = out_tensor->Height() * out_tensor->Width();
|
||||
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM));
|
||||
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, row_tile_));
|
||||
conv_param_->thread_num_ = thread_count_;
|
||||
}
|
||||
|
||||
|
|
|
@ -62,6 +62,8 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
|||
float16_t *packed_input_ = nullptr;
|
||||
float16_t *packed_weight_ = nullptr;
|
||||
float16_t *col_major_input_ = nullptr;
|
||||
int col_tile_;
|
||||
int row_tile_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -38,13 +38,10 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
|||
int out_channel = filter_tensor->Batch();
|
||||
conv_param_->input_channel_ = in_channel;
|
||||
conv_param_->output_channel_ = out_channel;
|
||||
|
||||
const int oc_block = C8NUM;
|
||||
int oc_block_num = UP_DIV(out_channel, C8NUM);
|
||||
|
||||
int oc_block_num = UP_DIV(out_channel, col_tile_);
|
||||
// init weight
|
||||
// set data
|
||||
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t);
|
||||
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * col_tile_ * sizeof(float16_t);
|
||||
trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size));
|
||||
if (trans_weight_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_weight_ failed.";
|
||||
|
@ -73,7 +70,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
|||
MS_LOG(ERROR) << "get execute filter failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block);
|
||||
ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, col_tile_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "winograd filter transform failed.";
|
||||
return ret;
|
||||
|
@ -85,12 +82,12 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
|
||||
// init bias
|
||||
bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t));
|
||||
bias_data_ = malloc(oc_block_num * col_tile_ * sizeof(float16_t));
|
||||
if (bias_data_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc bias_data_ failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t));
|
||||
memset(bias_data_, 0, oc_block_num * col_tile_ * sizeof(float16_t));
|
||||
|
||||
if (in_tensors_.size() == kInputSize2) {
|
||||
if (origin_bias_data_type_ == kNumberTypeFloat16) {
|
||||
|
@ -105,11 +102,9 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
|
|||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
|
||||
const int cal_num = 16;
|
||||
int channel_out = conv_param_->output_channel_;
|
||||
|
||||
size_t tile_buffer_size =
|
||||
thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
|
||||
thread_count_ * row_tile_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
|
||||
trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size));
|
||||
if (trans_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc trans_input_ failed.";
|
||||
|
@ -117,7 +112,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
|
|||
}
|
||||
|
||||
gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(
|
||||
thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t)));
|
||||
thread_count_ * row_tile_ * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t)));
|
||||
if (gemm_out_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -131,7 +126,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
|
|||
}
|
||||
|
||||
col_buffer_ = reinterpret_cast<float16_t *>(
|
||||
ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t)));
|
||||
ctx_->allocator->Malloc(thread_count_ * row_tile_ * conv_param_->input_channel_ * sizeof(float16_t)));
|
||||
if (col_buffer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -159,6 +154,12 @@ int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
|
|||
}
|
||||
|
||||
int ConvolutionWinogradFP16CPUKernel::Init() {
|
||||
col_tile_ = C8NUM;
|
||||
#ifdef ENABLE_ARM64
|
||||
row_tile_ = C16NUM;
|
||||
#else
|
||||
row_tile_ = C12NUM;
|
||||
#endif
|
||||
kernel_unit_ = conv_param_->kernel_h_;
|
||||
input_unit_ = output_unit_ + kernel_unit_ - 1;
|
||||
conv_param_->input_unit_ = input_unit_;
|
||||
|
|
|
@ -86,6 +86,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
|
|||
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
|
||||
InputTransFp16Func in_func_;
|
||||
OutputTransFp16Func out_func_;
|
||||
int col_tile_;
|
||||
int row_tile_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -13,26 +13,20 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
#ifdef ENABLE_ARM
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#ifdef ENABLE_ARM64
|
||||
extern void Float32ToFloat16(const float *input, float16_t *output, int number);
|
||||
extern void Float16ToFloat32(const float16_t *input, float *output, int number);
|
||||
|
||||
inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) {
|
||||
static inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) {
|
||||
Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number);
|
||||
}
|
||||
|
||||
inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) {
|
||||
static inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) {
|
||||
Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -53,8 +53,8 @@ int PoolingFp16CPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int PoolingFp16CPUKernel::RunImpl(int task_id) {
|
||||
float16_t minf = -FLT_MAX;
|
||||
float16_t maxf = FLT_MAX;
|
||||
float16_t minf = -FLT16_MAX;
|
||||
float16_t maxf = FLT16_MAX;
|
||||
if (pooling_param_->act_type_ == ActType_Relu) {
|
||||
minf = 0.f;
|
||||
} else if (pooling_param_->act_type_ == ActType_Relu6) {
|
||||
|
|
|
@ -45,7 +45,7 @@
|
|||
#include "src/runtime/agent/npu/optimizer/npu_fusion_pass.h"
|
||||
#include "src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h"
|
||||
#endif
|
||||
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
|
||||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
|
||||
#endif
|
||||
|
||||
|
@ -230,7 +230,7 @@ int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origi
|
|||
auto origin_data = tensor->data_c();
|
||||
MS_ASSERT(origin_data != nullptr);
|
||||
if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
|
||||
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
|
||||
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
|
||||
auto restore_tensor = Tensor::CopyTensor(*tensor, false);
|
||||
restore_tensor->set_data(origin_data);
|
||||
restore_tensor->set_own_data(tensor->own_data());
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "src/sub_graph_kernel.h"
|
||||
#include "src/tensor.h"
|
||||
#include "src/tensorlist.h"
|
||||
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
|
||||
#ifdef ENABLE_FP16
|
||||
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
|
||||
#endif
|
||||
#include "src/common/version_manager.h"
|
||||
|
@ -283,7 +283,7 @@ int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) {
|
|||
}
|
||||
|
||||
int CpuFp16SubGraph::PreProcess() {
|
||||
#ifdef ENABLE_ARM64
|
||||
#ifdef ENABLE_FP16
|
||||
if (!mindspore::lite::IsSupportFloat16()) {
|
||||
MS_LOG(ERROR) << "Unsupported fp16 in this devices";
|
||||
return RET_ERROR;
|
||||
|
@ -347,7 +347,7 @@ int CpuFp16SubGraph::PreProcess() {
|
|||
}
|
||||
|
||||
int CpuFp16SubGraph::PostProcess() {
|
||||
#ifdef ENABLE_ARM64
|
||||
#ifdef ENABLE_FP16
|
||||
if (!mindspore::lite::IsSupportFloat16()) {
|
||||
MS_LOG(ERROR) << "Unsupported fp16 in this devices";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -377,8 +377,11 @@ add_dependencies(lite-test fbs_src)
|
|||
|
||||
target_link_libraries(lite-test dl mindspore::gtest)
|
||||
|
||||
if(PLATFORM_ARM64 AND ENABLE_FP16)
|
||||
target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid)
|
||||
if(PLATFORM_ARM AND ENABLE_FP16)
|
||||
target_link_libraries(lite-test nnacl_fp16_mid)
|
||||
if(PLATFORM_ARM64)
|
||||
target_link_libraries(lite-test nnacl_optimize_mid)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(PLATFORM_ARM)
|
||||
|
|
Loading…
Reference in New Issue