fp16 bug fix

lzk 2021-04-12 23:10:40 -07:00
parent 679985adb8
commit f18d7cf435
67 changed files with 1457 additions and 1568 deletions

View File

@ -607,7 +607,7 @@ build_lite()
cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \
-DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \
-DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
-DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
-DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DENABLE_FP16="on" \
-DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
-DSUPPORT_GPU=${LOCAL_LITE_ENABLE_GPU} -DSUPPORT_NPU=${LOCAL_LITE_ENABLE_NPU} -DENABLE_V0=on \
-DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \

View File

@ -68,7 +68,7 @@ add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC})
add_dependencies(nnacl fbs_src)
add_dependencies(nnacl_mid fbs_src)
########################### arm64 build optimize library ########################
if(PLATFORM_ARM64)
########################### arm fp16 build optimize library ########################
if(ENABLE_FP16)
add_subdirectory(${NNACL_DIR}/optimize)
endif()

View File

@ -30,7 +30,7 @@ typedef struct ArgElement {
int8_t i8_data_;
int32_t i_data_;
float f_data_;
#ifdef ENABLE_ARM64
#ifdef ENABLE_ARM
float16_t f16_data_;
#endif
} data_;

View File

@ -0,0 +1,602 @@
#ifdef ENABLE_ARM32
#include "nnacl/assembly_global.h"
.text
.align 5
.global MatMul12x8A32Fp16
#ifndef __APPLE__
.type MatMul12x8A32Fp16, %function
#endif
// void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
// int deep, int row, int col, int stride, int write_mode);
// r0: a
// r1: b
// r2: dst
// r3: bias
// #0: act_type
// #4: depth
// #8: row
// #12: col
// #16: stride
// #20: writeNhwc/writeWino
asm_function MatMul12x8A32Fp16
// r13 (sp) and r15 (pc) cannot be used!
// r4 and r9 are temporary registers
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r3-r11, lr}
vpush {q4-q7}
add sp, sp, #104
ldr r5, [sp, #4]
ldr r6, [sp, #8]
ldr r7, [sp, #12]
ldr r8, [sp, #16]
ldr lr, [sp, #20]
mov r10, r1 // b
mov r11, r0 // a
mov r12, r2 // dst
cmp lr, #2
bne NoWinograd
mul r4, r8, r7 // stride * col
add r4, r4, r4 // r4 * sizeof(float16_t)
mov r9, #16
mul r9, r8, r9 // stride * 8 * sizeof(float16_t)
NoWinograd:
add r8, r8, r8 // stride * sizeof(float16_t)
a .req r0
weight .req r1
dst .req r2
bias .req r3
depth .req r5
row .req r6
col .req r7
stride .req r8
b_tmp .req r10
a_tmp .req r11
dst_tmp .req r12
.macro STORE_12x8 p1
vst1.16 {\p1}, [dst]
add dst, dst, stride
.endm
.macro STORE_12x7 p1, p2, p3
add r4, dst, #8
add r9, dst, #12
vst1.16 {\p1}, [dst]
vst1.32 {\p2}, [r4]
vst1.16 {\p3}, [r9]
add dst, dst, stride
.endm
.macro STORE_12x6 p1, p2
add r4, dst, #8
vst1.16 {\p1}, [dst]
vst1.32 {\p2}, [r4]
add dst, dst, stride
.endm
.macro STORE_12x5 p1, p2
add r4, dst, #8
vst1.16 {\p1}, [dst]
vst1.16 {\p2}, [r4]
add dst, dst, stride
.endm
.macro STORE_12x4 p1
vst1.16 {\p1}, [dst]
add dst, dst, stride
.endm
.macro STORE_12x3 p1, p2
add r4, dst, #4
vst1.32 {\p1}, [dst]
vst1.16 {\p2}, [r4]
add dst, dst, stride
.endm
.macro STORE_12x2 p1
vst1.32 {\p1}, [dst]
add dst, dst, stride
.endm
.macro STORE_12x1 p1
vst1.16 {\p1}, [dst]
add dst, dst, stride
.endm
.macro STORE_C8 p1, p2
vst1.16 {\p1}, [dst]
cmp row, \p2
add dst, dst, stride
beq WriteEnd
.endm
.macro STORE_C7 p1, p2, p3, p4
add r4, dst, #8
add r9, dst, #12
vst1.16 {\p1}, [dst]
vst1.32 {\p2}, [r4]
vst1.16 {\p3}, [r9]
add dst, dst, stride
cmp row, \p4
beq WriteEnd
.endm
.macro STORE_C6 p1, p2, p3
add r4, dst, #8
vst1.16 {\p1}, [dst]
vst1.32 {\p2}, [r4]
add dst, dst, stride
cmp row, \p3
beq WriteEnd
.endm
.macro STORE_C5 p1, p2, p3
add r4, dst, #8
vst1.16 {\p1}, [dst]
vst1.16 {\p2}, [r4]
add dst, dst, stride
cmp row, \p3
beq WriteEnd
.endm
.macro STORE_C4 p1, p2
vst1.16 {\p1}, [dst]
cmp row, \p2
add dst, dst, stride
beq WriteEnd
.endm
.macro STORE_C3 p1, p2, p3
add r4, dst, #4
vst1.32 {\p1}, [dst]
vst1.16 {\p2}, [r4]
add dst, dst, stride
cmp row, \p3
beq WriteEnd
.endm
.macro STORE_C2 p1, p2
vst1.32 {\p1}, [dst]
add dst, dst, stride
cmp row, \p2
beq WriteEnd
.endm
.macro STORE_C1 p1, p2
vst1.16 {\p1}, [dst]
add dst, dst, stride
cmp row, \p2
beq WriteEnd
.endm
LoopRow12:
ldr bias, [sp, #-40]
LoopCol8:
mov dst, dst_tmp
mov a, a_tmp
ldr depth, [sp, #4]
veor q4, q4, q4
veor q5, q5, q5
veor q6, q6, q6
veor q7, q7, q7
veor q8, q8, q8
veor q9, q9, q9
veor q10, q10, q10
veor q11, q11, q11
veor q12, q12, q12
veor q13, q13, q13
veor q14, q14, q14
veor q15, q15, q15
LoopDepth:
vld1.16 {q0, d2}, [a]!
vld1.16 {q2}, [weight]!
vmla.f16 q4, q2, d0[0]
vmla.f16 q5, q2, d0[1]
vmla.f16 q6, q2, d0[2]
vmla.f16 q7, q2, d0[3]
vmla.f16 q8, q2, d1[0]
vmla.f16 q9, q2, d1[1]
vmla.f16 q10, q2, d1[2]
vmla.f16 q11, q2, d1[3]
vmla.f16 q12, q2, d2[0]
vmla.f16 q13, q2, d2[1]
vmla.f16 q14, q2, d2[2]
vmla.f16 q15, q2, d2[3]
subs depth, depth, #1
bne LoopDepth
Bias:
cmp bias, #0
beq Activation
vld1.16 {q0}, [bias]!
vadd.f16 q4, q4, q0
vadd.f16 q5, q5, q0
vadd.f16 q6, q6, q0
vadd.f16 q7, q7, q0
vadd.f16 q8, q8, q0
vadd.f16 q9, q9, q0
vadd.f16 q10, q10, q0
vadd.f16 q11, q11, q0
vadd.f16 q12, q12, q0
vadd.f16 q13, q13, q0
vadd.f16 q14, q14, q0
vadd.f16 q15, q15, q0
Activation:
ldr lr, [sp]
cmp lr, #3
beq Relu6
cmp lr, #1
beq Relu
b Write
Relu6:
vmov.i16 q2, #0x4600
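// 0x4600 is 6.0 in fp16: sign 0, exponent 0b10001, mantissa 0b1000000000 -> 1.5 * 2^(17-15) = 6.0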
vmin.f16 q4, q4, q2
vmin.f16 q5, q5, q2
vmin.f16 q6, q6, q2
vmin.f16 q7, q7, q2
vmin.f16 q8, q8, q2
vmin.f16 q9, q9, q2
vmin.f16 q10, q10, q2
vmin.f16 q11, q11, q2
vmin.f16 q12, q12, q2
vmin.f16 q13, q13, q2
vmin.f16 q14, q14, q2
vmin.f16 q15, q15, q2
Relu:
veor q3, q3, q3
vmax.f16 q4, q4, q3
vmax.f16 q5, q5, q3
vmax.f16 q6, q6, q3
vmax.f16 q7, q7, q3
vmax.f16 q8, q8, q3
vmax.f16 q9, q9, q3
vmax.f16 q10, q10, q3
vmax.f16 q11, q11, q3
vmax.f16 q12, q12, q3
vmax.f16 q13, q13, q3
vmax.f16 q14, q14, q3
vmax.f16 q15, q15, q3
Write:
ldr lr, [sp, #20]
cmp lr, #2
beq WriteWinograd
cmp row, #12
bge Write12xCol
b WriteRowxCol
WriteWinograd:
vst1.16 {q4}, [dst]
add dst, dst, r4
vst1.16 {q5}, [dst]
add dst, dst, r4
vst1.16 {q6}, [dst]
add dst, dst, r4
vst1.16 {q7}, [dst]
add dst, dst, r4
vst1.16 {q8}, [dst]
add dst, dst, r4
vst1.16 {q9}, [dst]
add dst, dst, r4
vst1.16 {q10}, [dst]
add dst, dst, r4
vst1.16 {q11}, [dst]
add dst, dst, r4
vst1.16 {q12}, [dst]
add dst, dst, r4
vst1.16 {q13}, [dst]
add dst, dst, r4
vst1.16 {q14}, [dst]
add dst, dst, r4
vst1.16 {q15}, [dst]
add dst_tmp, dst_tmp, r9
b WriteEnd
Write12xCol:
cmp col, #8
bge Write12x8
cmp col, #1
beq Write12x1
cmp col, #2
beq Write12x2
cmp col, #3
beq Write12x3
cmp col, #4
beq Write12x4
cmp col, #5
beq Write12x5
cmp col, #6
beq Write12x6
b Write12x7
WriteRowxCol:
cmp col, #8
bge WriteRowx8
cmp col, #1
beq WriteRowx1
cmp col, #2
beq WriteRowx2
cmp col, #3
beq WriteRowx3
cmp col, #4
beq WriteRowx4
cmp col, #5
beq WriteRowx5
cmp col, #6
beq WriteRowx6
b WriteRowx7
Write12x8:
STORE_12x8 q4
STORE_12x8 q5
STORE_12x8 q6
STORE_12x8 q7
STORE_12x8 q8
STORE_12x8 q9
STORE_12x8 q10
STORE_12x8 q11
STORE_12x8 q12
STORE_12x8 q13
STORE_12x8 q14
STORE_12x8 q15
b WriteEnd
WriteRowx8:
STORE_C8 q4, #1
STORE_C8 q5, #2
STORE_C8 q6, #3
STORE_C8 q7, #4
STORE_C8 q8, #5
STORE_C8 q9, #6
STORE_C8 q10, #7
STORE_C8 q11, #8
STORE_C8 q12, #9
STORE_C8 q13, #10
STORE_C8 q14, #11
STORE_C8 q15, #12
b WriteEnd
Write12x1:
STORE_12x1 d8[0]
STORE_12x1 d10[0]
STORE_12x1 d12[0]
STORE_12x1 d14[0]
STORE_12x1 d16[0]
STORE_12x1 d18[0]
STORE_12x1 d20[0]
STORE_12x1 d22[0]
STORE_12x1 d24[0]
STORE_12x1 d26[0]
STORE_12x1 d28[0]
STORE_12x1 d30[0]
b WriteEnd
WriteRowx1:
STORE_C1 d8[0], #1
STORE_C1 d10[0], #2
STORE_C1 d12[0], #3
STORE_C1 d14[0], #4
STORE_C1 d16[0], #5
STORE_C1 d18[0], #6
STORE_C1 d20[0], #7
STORE_C1 d22[0], #8
STORE_C1 d24[0], #9
STORE_C1 d26[0], #10
STORE_C1 d28[0], #11
STORE_C1 d30[0], #12
b WriteEnd
Write12x2:
STORE_12x2 d8[0]
STORE_12x2 d10[0]
STORE_12x2 d12[0]
STORE_12x2 d14[0]
STORE_12x2 d16[0]
STORE_12x2 d18[0]
STORE_12x2 d20[0]
STORE_12x2 d22[0]
STORE_12x2 d24[0]
STORE_12x2 d26[0]
STORE_12x2 d28[0]
STORE_12x2 d30[0]
b WriteEnd
WriteRowx2:
STORE_C2 d8[0], #1
STORE_C2 d10[0], #2
STORE_C2 d12[0], #3
STORE_C2 d14[0], #4
STORE_C2 d16[0], #5
STORE_C2 d18[0], #6
STORE_C2 d20[0], #7
STORE_C2 d22[0], #8
STORE_C2 d24[0], #9
STORE_C2 d26[0], #10
STORE_C2 d28[0], #11
STORE_C2 d30[0], #12
b WriteEnd
Write12x3:
STORE_12x3 d8[0], d8[2]
STORE_12x3 d10[0], d10[2]
STORE_12x3 d12[0], d12[2]
STORE_12x3 d14[0], d14[2]
STORE_12x3 d16[0], d16[2]
STORE_12x3 d18[0], d18[2]
STORE_12x3 d20[0], d20[2]
STORE_12x3 d22[0], d22[2]
STORE_12x3 d24[0], d24[2]
STORE_12x3 d26[0], d26[2]
STORE_12x3 d28[0], d28[2]
STORE_12x3 d30[0], d30[2]
b WriteEnd
WriteRowx3:
STORE_C3 d8[0], d8[2], #1
STORE_C3 d10[0], d10[2], #2
STORE_C3 d12[0], d12[2], #3
STORE_C3 d14[0], d14[2], #4
STORE_C3 d16[0], d16[2], #5
STORE_C3 d18[0], d18[2], #6
STORE_C3 d20[0], d20[2], #7
STORE_C3 d22[0], d22[2], #8
STORE_C3 d24[0], d24[2], #9
STORE_C3 d26[0], d26[2], #10
STORE_C3 d28[0], d28[2], #11
STORE_C3 d30[0], d30[2], #12
b WriteEnd
Write12x4:
STORE_12x4 d8
STORE_12x4 d10
STORE_12x4 d12
STORE_12x4 d14
STORE_12x4 d16
STORE_12x4 d18
STORE_12x4 d20
STORE_12x4 d22
STORE_12x4 d24
STORE_12x4 d26
STORE_12x4 d28
STORE_12x4 d30
b WriteEnd
WriteRowx4:
STORE_C4 d8, #1
STORE_C4 d10, #2
STORE_C4 d12, #3
STORE_C4 d14, #4
STORE_C4 d16, #5
STORE_C4 d18, #6
STORE_C4 d20, #7
STORE_C4 d22, #8
STORE_C4 d24, #9
STORE_C4 d26, #10
STORE_C4 d28, #11
STORE_C4 d30, #12
b WriteEnd
Write12x5:
STORE_12x5 d8, d9[0]
STORE_12x5 d10, d11[0]
STORE_12x5 d12, d13[0]
STORE_12x5 d14, d15[0]
STORE_12x5 d16, d17[0]
STORE_12x5 d18, d19[0]
STORE_12x5 d20, d21[0]
STORE_12x5 d22, d23[0]
STORE_12x5 d24, d25[0]
STORE_12x5 d26, d27[0]
STORE_12x5 d28, d29[0]
STORE_12x5 d30, d31[0]
b WriteEnd
WriteRowx5:
STORE_C5 d8, d9[0], #1
STORE_C5 d10, d11[0], #2
STORE_C5 d12, d13[0], #3
STORE_C5 d14, d15[0], #4
STORE_C5 d16, d17[0], #5
STORE_C5 d18, d19[0], #6
STORE_C5 d20, d21[0], #7
STORE_C5 d22, d23[0], #8
STORE_C5 d24, d25[0], #9
STORE_C5 d26, d27[0], #10
STORE_C5 d28, d29[0], #11
STORE_C5 d30, d31[0], #12
b WriteEnd
Write12x6:
STORE_12x6 d8, d9[0]
STORE_12x6 d10, d11[0]
STORE_12x6 d12, d13[0]
STORE_12x6 d14, d15[0]
STORE_12x6 d16, d17[0]
STORE_12x6 d18, d19[0]
STORE_12x6 d20, d21[0]
STORE_12x6 d22, d23[0]
STORE_12x6 d24, d25[0]
STORE_12x6 d26, d27[0]
STORE_12x6 d28, d29[0]
STORE_12x6 d30, d31[0]
b WriteEnd
WriteRowx6:
STORE_C6 d8, d9[0], #1
STORE_C6 d10, d11[0], #2
STORE_C6 d12, d13[0], #3
STORE_C6 d14, d15[0], #4
STORE_C6 d16, d17[0], #5
STORE_C6 d18, d19[0], #6
STORE_C6 d20, d21[0], #7
STORE_C6 d22, d23[0], #8
STORE_C6 d24, d25[0], #9
STORE_C6 d26, d27[0], #10
STORE_C6 d28, d29[0], #11
STORE_C6 d30, d31[0], #12
b WriteEnd
Write12x7:
STORE_12x7 d8, d9[0], d9[2]
STORE_12x7 d10, d11[0], d11[2]
STORE_12x7 d12, d13[0], d13[2]
STORE_12x7 d14, d15[0], d15[2]
STORE_12x7 d16, d17[0], d17[2]
STORE_12x7 d18, d19[0], d19[2]
STORE_12x7 d20, d21[0], d21[2]
STORE_12x7 d22, d23[0], d23[2]
STORE_12x7 d24, d25[0], d25[2]
STORE_12x7 d26, d27[0], d27[2]
STORE_12x7 d28, d29[0], d29[2]
STORE_12x7 d30, d31[0], d31[2]
b WriteEnd
WriteRowx7:
STORE_C7 d8, d9[0], d9[2], #1
STORE_C7 d10, d11[0], d11[2], #2
STORE_C7 d12, d13[0], d13[2], #3
STORE_C7 d14, d15[0], d15[2], #4
STORE_C7 d16, d17[0], d17[2], #5
STORE_C7 d18, d19[0], d19[2], #6
STORE_C7 d20, d21[0], d21[2], #7
STORE_C7 d22, d23[0], d23[2], #8
STORE_C7 d24, d25[0], d25[2], #9
STORE_C7 d26, d27[0], d27[2], #10
STORE_C7 d28, d29[0], d29[2], #11
STORE_C7 d30, d31[0], d31[2], #12
b WriteEnd
WriteEnd:
cmp col, #8
ble LoopColEnd
sub col, col, #8
ldr lr, [sp, #20]
cmp lr, #2
beq LoopCol8
add dst_tmp, dst_tmp, #16
b LoopCol8
LoopColEnd:
cmp row, #12
ble LoopRowEnd
sub row, row, #12
mov a_tmp, a
mov weight, b_tmp
ldr lr, [sp, #20]
cmp lr, #2
beq WinogradDst
ldr lr, [sp, #12]
sub lr, lr, col
add lr, lr, lr // col *= 2
sub dst_tmp, dst, lr
b LoopRow
WinogradDst:
add dst_tmp, dst, r9
LoopRow:
mov dst, dst_tmp
ldr col, [sp, #12]
b LoopRow12
LoopRowEnd:
sub sp, sp, #104
vpop {q4-q7}
pop {r3-r11, pc}
#endif
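
For readers tracing the new arm32 path: the C reference MatMul12x8Fp16, added to matmul_fp16.c further down in this commit, spells out the same 12x8 tiling scalar by scalar. A minimal invocation sketch, assuming the LHS has been packed by RowMajor2Col12MajorFp16Opt and the RHS is packed 8-column-major (buffer names are illustrative, not from the source):

/* sketch only: pack the LHS into 12-row blocks, then run the kernel */
RowMajor2Col12MajorFp16Opt(a_row_major, packed_a, row, deep);
MatMul12x8A32Fp16(packed_a, packed_b, dst, bias, ActType_Relu, /* act_type: 1 = relu, 3 = relu6 */
                  deep, row, col, stride, OutType_Nhwc);       /* the kernel treats write_mode 2 as the Winograd store path */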

View File

@ -1,667 +0,0 @@
#ifdef ENABLE_ARM64
#include "nnacl/assembly_global.h"
.text
.align 5
// void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias,
// size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6);
// x0: output, x1: input, x2: weight, x3: bias, x4: step, x5: ic4, x6: oc8, x7: offset,
// x8:mode, x9: writeC4, x10:relu, x11: relu6
// compute 8 channel for 16 outputs
asm_function IndirectGemmFp16_16x8
.macro INIT_BIAS
dup v16.4s, wzr
cbz x3, InitBias
ld1 {v16.8h}, [x3]
InitBias:
mov v17.16b, v16.16b
mov v18.16b, v16.16b
mov v19.16b, v16.16b
mov v20.16b, v16.16b
mov v21.16b, v16.16b
mov v22.16b, v16.16b
mov v23.16b, v16.16b
mov v24.16b, v16.16b
mov v25.16b, v16.16b
mov v26.16b, v16.16b
mov v27.16b, v16.16b
mov v28.16b, v16.16b
mov v29.16b, v16.16b
mov v30.16b, v16.16b
mov v31.16b, v16.16b
.endm
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should also be preserved
// whereas our coding style does not permit such an amount of parameters
sub sp, sp, #144
// whether storing 4 registers at once or storing them separately performs better on in-order cores
// has not been tested yet
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
ldr x8, [sp, #0]
ldr x9, [sp, #8]
ldr x10, [sp, #16]
ldr x11, [sp, #24]
cbnz x8, IndirectGemmStart
// step is one for common convolution, where ic8 should be multiplied by the kernel size
// step is (a+b-1) for F(a,b) in winograd
mul x5, x4, x5
mov x4, #1
IndirectGemmStart:
LoopOc:
mov x14, x4
mov x12, x1
LoopKsize:
mov x15, x0
INIT_BIAS
// load input for output 1-8
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64
// load weight
ld1 {v8.8h, v9.8h}, [x2], #32
// first 2 steps for output 1 and 3
fmla v16.8h, v8.8h, v0.h[0]
fmla v18.8h, v8.8h, v1.h[0]
fmla v16.8h, v9.8h, v0.h[1]
fmla v18.8h, v9.8h, v1.h[1]
// load weight
ld1 {v10.8h, v11.8h}, [x2], #32
// first 2 steps for output 2 and 4
fmla v17.8h, v8.8h, v0.h[4]
fmla v19.8h, v8.8h, v1.h[4]
fmla v17.8h, v9.8h, v0.h[5]
fmla v19.8h, v9.8h, v1.h[5]
// load input for output 9-16
// the input cache should be refreshed after loading
// ATTENTION: issuing loads in advance is preferred, but advancing too far may lead to prefetching invalid addresses
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64
// last 2 steps for output 1 and 3
fmla v16.8h, v10.8h, v0.h[2]
fmla v18.8h, v10.8h, v1.h[2]
fmla v16.8h, v11.8h, v0.h[3]
fmla v18.8h, v11.8h, v1.h[3]
// check if ic4=1
subs x13, x5, #1
beq LoopIcEnd
LoopIc:
// last 2 steps for output 2 and 4
fmla v17.8h, v10.8h, v0.h[6]
fmla v19.8h, v10.8h, v1.h[6]
fmla v17.8h, v11.8h, v0.h[7]
fmla v19.8h, v11.8h, v1.h[7]
// steps for output 5-8
fmla v20.8h, v8.8h, v2.h[0]
fmla v22.8h, v8.8h, v3.h[0]
fmla v20.8h, v9.8h, v2.h[1]
fmla v22.8h, v9.8h, v3.h[1]
fmla v21.8h, v8.8h, v2.h[4]
fmla v23.8h, v8.8h, v3.h[4]
fmla v21.8h, v9.8h, v2.h[5]
fmla v23.8h, v9.8h, v3.h[5]
fmla v20.8h, v10.8h, v2.h[2]
fmla v22.8h, v10.8h, v3.h[2]
fmla v20.8h, v11.8h, v2.h[3]
fmla v22.8h, v11.8h, v3.h[3]
fmla v21.8h, v10.8h, v2.h[6]
fmla v23.8h, v10.8h, v3.h[6]
fmla v21.8h, v11.8h, v2.h[7]
fmla v23.8h, v11.8h, v3.h[7]
// load input for output 1-8
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64
// steps for output 9-12
fmla v24.8h, v8.8h, v4.h[0]
fmla v26.8h, v8.8h, v5.h[0]
fmla v24.8h, v9.8h, v4.h[1]
fmla v26.8h, v9.8h, v5.h[1]
fmla v25.8h, v8.8h, v4.h[4]
fmla v27.8h, v8.8h, v5.h[4]
fmla v25.8h, v9.8h, v4.h[5]
fmla v27.8h, v9.8h, v5.h[5]
fmla v24.8h, v10.8h, v4.h[2]
fmla v26.8h, v10.8h, v5.h[2]
fmla v24.8h, v11.8h, v4.h[3]
fmla v26.8h, v11.8h, v5.h[3]
fmla v25.8h, v10.8h, v4.h[6]
fmla v27.8h, v10.8h, v5.h[6]
fmla v25.8h, v11.8h, v4.h[7]
fmla v27.8h, v11.8h, v5.h[7]
// steps for output 13-16
fmla v28.8h, v8.8h, v6.h[0]
fmla v30.8h, v8.8h, v7.h[0]
fmla v28.8h, v9.8h, v6.h[1]
fmla v30.8h, v9.8h, v7.h[1]
fmla v29.8h, v8.8h, v6.h[4]
fmla v31.8h, v8.8h, v7.h[4]
fmla v29.8h, v9.8h, v6.h[5]
fmla v31.8h, v9.8h, v7.h[5]
// load weight
ld1 {v8.8h, v9.8h}, [x2], #32
fmla v28.8h, v10.8h, v6.h[2]
fmla v30.8h, v10.8h, v7.h[2]
fmla v28.8h, v11.8h, v6.h[3]
fmla v30.8h, v11.8h, v7.h[3]
fmla v29.8h, v10.8h, v6.h[6]
fmla v31.8h, v10.8h, v7.h[6]
fmla v29.8h, v11.8h, v6.h[7]
fmla v31.8h, v11.8h, v7.h[7]
// load weight
ld1 {v10.8h, v11.8h}, [x2], #32
// first 2 steps for output 1-4
fmla v16.8h, v8.8h, v0.h[0]
fmla v18.8h, v8.8h, v1.h[0]
fmla v16.8h, v9.8h, v0.h[1]
fmla v18.8h, v9.8h, v1.h[1]
fmla v17.8h, v8.8h, v0.h[4]
fmla v19.8h, v8.8h, v1.h[4]
fmla v17.8h, v9.8h, v0.h[5]
fmla v19.8h, v9.8h, v1.h[5]
// load input for output 9-16
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64
// last 2 steps for output 1 and 3
fmla v16.8h, v10.8h, v0.h[2]
fmla v18.8h, v10.8h, v1.h[2]
fmla v16.8h, v11.8h, v0.h[3]
fmla v18.8h, v11.8h, v1.h[3]
subs x13, x13, #1
bne LoopIc
LoopIcEnd:
fmla v17.8h, v10.8h, v0.h[6]
fmla v19.8h, v10.8h, v1.h[6]
fmla v17.8h, v11.8h, v0.h[7]
fmla v19.8h, v11.8h, v1.h[7]
// steps for output 5-8
fmla v20.8h, v8.8h, v2.h[0]
fmla v22.8h, v8.8h, v3.h[0]
fmla v20.8h, v9.8h, v2.h[1]
fmla v22.8h, v9.8h, v3.h[1]
fmla v21.8h, v8.8h, v2.h[4]
fmla v23.8h, v8.8h, v3.h[4]
fmla v21.8h, v9.8h, v2.h[5]
fmla v23.8h, v9.8h, v3.h[5]
fmla v20.8h, v10.8h, v2.h[2]
fmla v22.8h, v10.8h, v3.h[2]
fmla v20.8h, v11.8h, v2.h[3]
fmla v22.8h, v11.8h, v3.h[3]
fmla v21.8h, v10.8h, v2.h[6]
fmla v23.8h, v10.8h, v3.h[6]
fmla v21.8h, v11.8h, v2.h[7]
fmla v23.8h, v11.8h, v3.h[7]
// steps for output 9-12
fmla v24.8h, v8.8h, v4.h[0]
fmla v26.8h, v8.8h, v5.h[0]
fmla v24.8h, v9.8h, v4.h[1]
fmla v26.8h, v9.8h, v5.h[1]
fmla v25.8h, v8.8h, v4.h[4]
fmla v27.8h, v8.8h, v5.h[4]
fmla v25.8h, v9.8h, v4.h[5]
fmla v27.8h, v9.8h, v5.h[5]
fmla v24.8h, v10.8h, v4.h[2]
fmla v26.8h, v10.8h, v5.h[2]
fmla v24.8h, v11.8h, v4.h[3]
fmla v26.8h, v11.8h, v5.h[3]
fmla v25.8h, v10.8h, v4.h[6]
fmla v27.8h, v10.8h, v5.h[6]
fmla v25.8h, v11.8h, v4.h[7]
fmla v27.8h, v11.8h, v5.h[7]
// steps for output 13-16
fmla v28.8h, v8.8h, v6.h[0]
fmla v30.8h, v8.8h, v7.h[0]
fmla v28.8h, v9.8h, v6.h[1]
fmla v30.8h, v9.8h, v7.h[1]
fmla v29.8h, v8.8h, v6.h[4]
fmla v31.8h, v8.8h, v7.h[4]
fmla v29.8h, v9.8h, v6.h[5]
fmla v31.8h, v9.8h, v7.h[5]
fmla v28.8h, v10.8h, v6.h[2]
fmla v30.8h, v10.8h, v7.h[2]
fmla v28.8h, v11.8h, v6.h[3]
fmla v30.8h, v11.8h, v7.h[3]
fmla v29.8h, v10.8h, v6.h[6]
fmla v31.8h, v10.8h, v7.h[6]
fmla v29.8h, v11.8h, v6.h[7]
fmla v31.8h, v11.8h, v7.h[7]
cbnz x11, Relu6
cbnz x10, Relu
b WriteStart
Relu6:
movi v9.8h, #0x46, lsl #8
fmin v16.8h, v16.8h, v9.8h
fmin v17.8h, v17.8h, v9.8h
fmin v18.8h, v18.8h, v9.8h
fmin v19.8h, v19.8h, v9.8h
fmin v20.8h, v20.8h, v9.8h
fmin v21.8h, v21.8h, v9.8h
fmin v22.8h, v22.8h, v9.8h
fmin v23.8h, v23.8h, v9.8h
fmin v24.8h, v24.8h, v9.8h
fmin v25.8h, v25.8h, v9.8h
fmin v26.8h, v26.8h, v9.8h
fmin v27.8h, v27.8h, v9.8h
fmin v28.8h, v28.8h, v9.8h
fmin v29.8h, v29.8h, v9.8h
fmin v30.8h, v30.8h, v9.8h
fmin v31.8h, v31.8h, v9.8h
Relu:
dup v8.4s, wzr
fmax v16.8h, v16.8h, v8.8h
fmax v17.8h, v17.8h, v8.8h
fmax v18.8h, v18.8h, v8.8h
fmax v19.8h, v19.8h, v8.8h
fmax v20.8h, v20.8h, v8.8h
fmax v21.8h, v21.8h, v8.8h
fmax v22.8h, v22.8h, v8.8h
fmax v23.8h, v23.8h, v8.8h
fmax v24.8h, v24.8h, v8.8h
fmax v25.8h, v25.8h, v8.8h
fmax v26.8h, v26.8h, v8.8h
fmax v27.8h, v27.8h, v8.8h
fmax v28.8h, v28.8h, v8.8h
fmax v29.8h, v29.8h, v8.8h
fmax v30.8h, v30.8h, v8.8h
fmax v31.8h, v31.8h, v8.8h
WriteStart:
cbnz x9, Write8
cmp x6, #1
beq Write1
cmp x6, #2
beq Write2
cmp x6, #3
beq Write3
cmp x6, #4
beq Write4
cmp x6, #5
beq Write5
cmp x6, #6
beq Write6
cmp x6, #7
beq Write7
b Write8
// prefetching is not preferred while writing results, in spite of cache misses
// you could try prfm pstl2strm
// almost no benefit has been observed, though
Write1:
str h16, [x15]
add x15, x15, x7
str h17, [x15]
add x15, x15, x7
str h18, [x15]
add x15, x15, x7
str h19, [x15]
add x15, x15, x7
str h20, [x15]
add x15, x15, x7
str h21, [x15]
add x15, x15, x7
str h22, [x15]
add x15, x15, x7
str h23, [x15]
add x15, x15, x7
str h24, [x15]
add x15, x15, x7
str h25, [x15]
add x15, x15, x7
str h26, [x15]
add x15, x15, x7
str h27, [x15]
add x15, x15, x7
str h28, [x15]
add x15, x15, x7
str h29, [x15]
add x15, x15, x7
str h30, [x15]
add x15, x15, x7
str h31, [x15]
add x0, x0, #2
b WriteEnd
Write2:
add x17, x15, #2
st1 {v16.h}[0], [x15], x7
st1 {v16.h}[1], [x17], x7
st1 {v17.h}[0], [x15], x7
st1 {v17.h}[1], [x17], x7
st1 {v18.h}[0], [x15], x7
st1 {v18.h}[1], [x17], x7
st1 {v19.h}[0], [x15], x7
st1 {v19.h}[1], [x17], x7
st1 {v20.h}[0], [x15], x7
st1 {v20.h}[1], [x17], x7
st1 {v21.h}[0], [x15], x7
st1 {v21.h}[1], [x17], x7
st1 {v22.h}[0], [x15], x7
st1 {v22.h}[1], [x17], x7
st1 {v23.h}[0], [x15], x7
st1 {v23.h}[1], [x17], x7
st1 {v24.h}[0], [x15], x7
st1 {v24.h}[1], [x17], x7
st1 {v25.h}[0], [x15], x7
st1 {v25.h}[1], [x17], x7
st1 {v26.h}[0], [x15], x7
st1 {v26.h}[1], [x17], x7
st1 {v27.h}[0], [x15], x7
st1 {v27.h}[1], [x17], x7
st1 {v28.h}[0], [x15], x7
st1 {v28.h}[1], [x17], x7
st1 {v29.h}[0], [x15], x7
st1 {v29.h}[1], [x17], x7
st1 {v30.h}[0], [x15], x7
st1 {v30.h}[1], [x17], x7
st1 {v31.h}[0], [x15]
st1 {v31.h}[1], [x17]
add x0, x0, #4
b WriteEnd
Write3:
add x17, x15, #4
add x16, x15, #2
st1 {v16.h}[0], [x15], x7
st1 {v16.h}[1], [x16], x7
st1 {v16.h}[2], [x17], x7
st1 {v17.h}[0], [x15], x7
st1 {v17.h}[1], [x16], x7
st1 {v17.h}[2], [x17], x7
st1 {v18.h}[0], [x15], x7
st1 {v18.h}[1], [x16], x7
st1 {v18.h}[2], [x17], x7
st1 {v19.h}[0], [x15], x7
st1 {v19.h}[1], [x16], x7
st1 {v19.h}[2], [x17], x7
st1 {v20.h}[0], [x15], x7
st1 {v20.h}[1], [x16], x7
st1 {v20.h}[2], [x17], x7
st1 {v21.h}[0], [x15], x7
st1 {v21.h}[1], [x16], x7
st1 {v21.h}[2], [x17], x7
st1 {v22.h}[0], [x15], x7
st1 {v22.h}[1], [x16], x7
st1 {v22.h}[2], [x17], x7
st1 {v23.h}[0], [x15], x7
st1 {v23.h}[1], [x16], x7
st1 {v23.h}[2], [x17], x7
st1 {v24.h}[0], [x15], x7
st1 {v24.h}[1], [x16], x7
st1 {v24.h}[2], [x17], x7
st1 {v25.h}[0], [x15], x7
st1 {v25.h}[1], [x16], x7
st1 {v25.h}[2], [x17], x7
st1 {v26.h}[0], [x15], x7
st1 {v26.h}[1], [x16], x7
st1 {v26.h}[2], [x17], x7
st1 {v27.h}[0], [x15], x7
st1 {v27.h}[1], [x16], x7
st1 {v27.h}[2], [x17], x7
st1 {v28.h}[0], [x15], x7
st1 {v28.h}[1], [x16], x7
st1 {v28.h}[2], [x17], x7
st1 {v29.h}[0], [x15], x7
st1 {v29.h}[1], [x16], x7
st1 {v29.h}[2], [x17], x7
st1 {v30.h}[0], [x15], x7
st1 {v30.h}[1], [x16], x7
st1 {v30.h}[2], [x17], x7
st1 {v31.h}[0], [x15]
st1 {v31.h}[1], [x16]
st1 {v31.h}[2], [x17]
add x0, x0, #6
b WriteEnd
Write4:
st1 {v16.4h}, [x15], x7
st1 {v17.4h}, [x15], x7
st1 {v18.4h}, [x15], x7
st1 {v19.4h}, [x15], x7
st1 {v20.4h}, [x15], x7
st1 {v21.4h}, [x15], x7
st1 {v22.4h}, [x15], x7
st1 {v23.4h}, [x15], x7
st1 {v24.4h}, [x15], x7
st1 {v25.4h}, [x15], x7
st1 {v26.4h}, [x15], x7
st1 {v27.4h}, [x15], x7
st1 {v28.4h}, [x15], x7
st1 {v29.4h}, [x15], x7
st1 {v30.4h}, [x15], x7
st1 {v31.4h}, [x15]
add x0, x0, #8
b WriteEnd
Write5:
add x17, x15, #8
st1 {v16.4h}, [x15], x7
st1 {v16.h}[4], [x17], x7
st1 {v17.4h}, [x15], x7
st1 {v17.h}[4], [x17], x7
st1 {v18.4h}, [x15], x7
st1 {v18.h}[4], [x17], x7
st1 {v19.4h}, [x15], x7
st1 {v19.h}[4], [x17], x7
st1 {v20.4h}, [x15], x7
st1 {v20.h}[4], [x17], x7
st1 {v21.4h}, [x15], x7
st1 {v21.h}[4], [x17], x7
st1 {v22.4h}, [x15], x7
st1 {v22.h}[4], [x17], x7
st1 {v23.4h}, [x15], x7
st1 {v23.h}[4], [x17], x7
st1 {v24.4h}, [x15], x7
st1 {v24.h}[4], [x17], x7
st1 {v25.4h}, [x15], x7
st1 {v25.h}[4], [x17], x7
st1 {v26.4h}, [x15], x7
st1 {v26.h}[4], [x17], x7
st1 {v27.4h}, [x15], x7
st1 {v27.h}[4], [x17], x7
st1 {v28.4h}, [x15], x7
st1 {v28.h}[4], [x17], x7
st1 {v29.4h}, [x15], x7
st1 {v29.h}[4], [x17], x7
st1 {v30.4h}, [x15], x7
st1 {v30.h}[4], [x17], x7
st1 {v31.4h}, [x15]
st1 {v31.h}[4], [x17]
add x0, x0, #10
b WriteEnd
Write6:
add x17, x15, #8
add x16, x15, #10
st1 {v16.4h}, [x15], x7
ins v0.s[0], v16.s[2]
st1 {v0.h}[0], [x17], x7
st1 {v0.h}[1], [x16], x7
st1 {v17.4h}, [x15], x7
ins v1.s[0], v17.s[2]
st1 {v1.h}[0], [x17], x7
st1 {v1.h}[1], [x16], x7
st1 {v18.4h}, [x15], x7
ins v2.s[0], v18.s[2]
st1 {v2.h}[0], [x17], x7
st1 {v2.h}[1], [x16], x7
st1 {v19.4h}, [x15], x7
ins v3.s[0], v19.s[2]
st1 {v3.h}[0], [x17], x7
st1 {v3.h}[1], [x16], x7
st1 {v20.4h}, [x15], x7
ins v4.s[0], v20.s[2]
st1 {v4.h}[0], [x17], x7
st1 {v4.h}[1], [x16], x7
st1 {v21.4h}, [x15], x7
ins v5.s[0], v21.s[2]
st1 {v5.h}[0], [x17], x7
st1 {v5.h}[1], [x16], x7
st1 {v22.4h}, [x15], x7
ins v6.s[0], v22.s[2]
st1 {v6.h}[0], [x17], x7
st1 {v6.h}[1], [x16], x7
st1 {v23.4h}, [x15], x7
ins v7.s[0], v23.s[2]
st1 {v7.h}[0], [x17], x7
st1 {v7.h}[1], [x16], x7
st1 {v24.4h}, [x15], x7
ins v8.s[0], v24.s[2]
st1 {v8.h}[0], [x17], x7
st1 {v8.h}[1], [x16], x7
st1 {v25.4h}, [x15], x7
ins v9.s[0], v25.s[2]
st1 {v9.h}[0], [x17], x7
st1 {v9.h}[1], [x16], x7
st1 {v26.4h}, [x15], x7
ins v10.s[0], v26.s[2]
st1 {v10.h}[0], [x17], x7
st1 {v10.h}[1], [x16], x7
st1 {v27.4h}, [x15], x7
ins v11.s[0], v27.s[2]
st1 {v11.h}[0], [x17], x7
st1 {v11.h}[1], [x16], x7
st1 {v28.4h}, [x15], x7
ins v12.s[0], v28.s[2]
st1 {v12.h}[0], [x17], x7
st1 {v12.h}[1], [x16], x7
st1 {v29.4h}, [x15], x7
ins v13.s[0], v29.s[2]
st1 {v13.h}[0], [x17], x7
st1 {v13.h}[1], [x16], x7
st1 {v30.4h}, [x15], x7
ins v14.s[0], v30.s[2]
st1 {v14.h}[0], [x17], x7
st1 {v14.h}[1], [x16], x7
st1 {v31.4h}, [x15]
ins v15.s[0], v31.s[2]
st1 {v15.h}[0], [x17]
st1 {v15.h}[1], [x16]
add x0, x0, #12
b WriteEnd
Write7:
add x17, x15, #8
add x19, x15, #10
add x16, x15, #12
st1 {v16.4h}, [x15], x7
ins v0.s[0], v16.s[2]
st1 {v0.h}[0], [x17], x7
st1 {v0.h}[1], [x19], x7
st1 {v16.h}[6], [x16], x7
st1 {v17.4h}, [x15], x7
ins v1.s[0], v17.s[2]
st1 {v1.h}[0], [x17], x7
st1 {v1.h}[1], [x19], x7
st1 {v17.h}[6], [x16], x7
st1 {v18.4h}, [x15], x7
ins v2.s[0], v18.s[2]
st1 {v2.h}[0], [x17], x7
st1 {v2.h}[1], [x19], x7
st1 {v18.h}[6], [x16], x7
st1 {v19.4h}, [x15], x7
ins v3.s[0], v19.s[2]
st1 {v3.h}[0], [x17], x7
st1 {v3.h}[1], [x19], x7
st1 {v19.h}[6], [x16], x7
st1 {v20.4h}, [x15], x7
ins v4.s[0], v20.s[2]
st1 {v4.h}[0], [x17], x7
st1 {v4.h}[1], [x19], x7
st1 {v20.h}[6], [x16], x7
st1 {v21.4h}, [x15], x7
ins v5.s[0], v21.s[2]
st1 {v5.h}[0], [x17], x7
st1 {v5.h}[1], [x19], x7
st1 {v21.h}[6], [x16], x7
st1 {v22.4h}, [x15], x7
ins v6.s[0], v22.s[2]
st1 {v6.h}[0], [x17], x7
st1 {v6.h}[1], [x19], x7
st1 {v22.h}[6], [x16], x7
st1 {v23.4h}, [x15], x7
ins v7.s[0], v23.s[2]
st1 {v7.h}[0], [x17], x7
st1 {v7.h}[1], [x19], x7
st1 {v23.h}[6], [x16], x7
st1 {v24.4h}, [x15], x7
ins v8.s[0], v24.s[2]
st1 {v8.h}[0], [x17], x7
st1 {v8.h}[1], [x19], x7
st1 {v24.h}[6], [x16], x7
st1 {v25.4h}, [x15], x7
ins v9.s[0], v25.s[2]
st1 {v9.h}[0], [x17], x7
st1 {v9.h}[1], [x19], x7
st1 {v25.h}[6], [x16], x7
st1 {v26.4h}, [x15], x7
ins v10.s[0], v26.s[2]
st1 {v10.h}[0], [x17], x7
st1 {v10.h}[1], [x19], x7
st1 {v26.h}[6], [x16], x7
st1 {v27.4h}, [x15], x7
ins v11.s[0], v27.s[2]
st1 {v11.h}[0], [x17], x7
st1 {v11.h}[1], [x19], x7
st1 {v27.h}[6], [x16], x7
st1 {v28.4h}, [x15], x7
ins v12.s[0], v28.s[2]
st1 {v12.h}[0], [x17], x7
st1 {v12.h}[1], [x19], x7
st1 {v28.h}[6], [x16], x7
st1 {v29.4h}, [x15], x7
ins v13.s[0], v29.s[2]
st1 {v13.h}[0], [x17], x7
st1 {v13.h}[1], [x19], x7
st1 {v29.h}[6], [x16], x7
st1 {v30.4h}, [x15], x7
ins v14.s[0], v30.s[2]
st1 {v14.h}[0], [x17], x7
st1 {v14.h}[1], [x19], x7
st1 {v30.h}[6], [x16], x7
st1 {v31.4h}, [x15]
ins v15.s[0], v31.s[2]
st1 {v15.h}[0], [x17]
st1 {v15.h}[1], [x19]
st1 {v31.h}[6], [x16]
add x0, x0, #14
b WriteEnd
Write8:
st1 {v16.8h}, [x15], x7
st1 {v17.8h}, [x15], x7
st1 {v18.8h}, [x15], x7
st1 {v19.8h}, [x15], x7
st1 {v20.8h}, [x15], x7
st1 {v21.8h}, [x15], x7
st1 {v22.8h}, [x15], x7
st1 {v23.8h}, [x15], x7
st1 {v24.8h}, [x15], x7
st1 {v25.8h}, [x15], x7
st1 {v26.8h}, [x15], x7
st1 {v27.8h}, [x15], x7
st1 {v28.8h}, [x15], x7
st1 {v29.8h}, [x15], x7
st1 {v30.8h}, [x15], x7
st1 {v31.8h}, [x15]
add x0, x0, #16
WriteEnd:
subs x14, x14, #1
bne LoopKsize
subs x6, x6, #8
cbz x3, NoStepForward
add x3, x3, #16
NoStepForward:
bgt LoopOc
sub sp, sp, #144
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ret
#endif

View File

@ -29,7 +29,7 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
}
#endif
for (; offset < ele_num; offset++) {
dst[offset] = src[offset] < 0 ? 0 : src[offset];
dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset];
}
return NNACL_OK;
}
@ -47,14 +47,24 @@ int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
}
#endif
for (; offset < ele_num; offset++) {
dst[offset] = data[offset] < 0 ? 0 : data[offset];
dst[offset] = dst[offset] > 6 ? 6 : dst[offset];
dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset];
dst[offset] = dst[offset] > 6.0f ? 6.0f : dst[offset];
}
return NNACL_OK;
}
int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
for (int i = 0; i < ele_num; ++i) {
int i = 0;
#ifdef ENABLE_NEON
int ele_c8 = UP_ROUND(ele_num, C8NUM);
for (; i < ele_c8; i += C8NUM) {
float16x8_t src_tmp = vld1q_f16(src + i);
float16x8_t mul_tmp = vmulq_n_f16(src_tmp, alpha);
float16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f));
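// mask lanes are all-ones where src > 0, selecting src; other lanes take alpha * src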
vst1q_f16(dst + i, vbslq_f16(mask, src_tmp, mul_tmp));
}
#endif
for (; i < ele_num; ++i) {
dst[i] = src[i] > (float16_t)0.0f ? src[i] : (src[i] * alpha);
}
return NNACL_OK;
@ -62,12 +72,12 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha
int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
int i = 0;
#ifdef ENABLE_ARM64
#ifdef ENABLE_NEON
int count = (ele_num / C4NUM) * C4NUM;
for (; i < count; i += C4NUM) {
float32x4_t tmp;
simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp);
vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp))));
vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp))));
}
#endif
for (; i < ele_num; ++i) {
@ -79,9 +89,9 @@ int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) {
}
float16_t TanhOptFp16(float16_t src) {
if (src > 5.0) {
if (src > 5.0f) {
return 1.0f;
} else if (src < -5.0) {
} else if (src < -5.0f) {
return -1.0f;
} else {
float square = src * src;
@ -93,7 +103,7 @@ float16_t TanhOptFp16(float16_t src) {
int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
int i = 0;
#ifdef ENABLE_ARM64
#ifdef ENABLE_NEON
static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f},
{17325.0f, 17325.0f, 17325.0f, 17325.0f},
{135135.0f, 135135.0f, 135135.0f, 135135.0f},
@ -112,7 +122,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
float32x4_t b = vaddq_f32(
vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
paramv[2]);
vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one)));
vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one)));
}
#endif
for (; i < ele_num; ++i) {
@ -130,7 +140,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) {
int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) {
for (int i = 0; i < ele_num; ++i) {
float16_t in = src[i];
float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6);
float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f);
dst[i] = in * relu6 / (float16_t)6.0f;
}
return NNACL_OK;
@ -181,26 +191,26 @@ int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate)
for (; i < C8; i += C8NUM) {
float16x8_t in = vld1q_f16(src + i);
float16x8_t res =
0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in));
0.5f * in * (1.0f + MS_TANHX8_F16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * in * in) * in));
vst1q_f16(dst + i, res);
}
#endif
for (; i < length; i++) {
dst[i] =
0.5 * src[i] *
(1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i]));
0.5f * src[i] *
(1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i]));
}
} else {
#ifdef ENABLE_NEON
int C8 = UP_ROUND(length, C8NUM);
for (; i < C8; i += C8NUM) {
float16x8_t in = vld1q_f16(src + i);
const float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f));
const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f));
vst1q_f16(dst + i, res);
}
#endif
for (; i < length; i++) {
dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f));
dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f));
}
}
return NNACL_OK;
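
For reference, the two branches implement the standard GELU forms: the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), where 0.79788456080287 = sqrt(2/pi) and 0.035677408136 = 0.044715 * sqrt(2/pi), and the exact form 0.5 * x * (1 + erf(x / sqrt(2))), where 1.4142135623730951 = sqrt(2).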

View File

@ -16,11 +16,9 @@
#ifndef MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_
#define MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/int8/fixed_point.h"
#ifdef __cplusplus

View File

@ -569,7 +569,7 @@ int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t *
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vin1 = vld1q_f16(input1 + index);
float16x8_t vout = vdivq_f16(vin0, vin1);
float16x8_t vout = MS_DIVQ_F16(vin0, vin1);
vst1q_f16(output + index, vout);
}
#endif
@ -591,7 +591,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin1 = vld1q_f16(input1 + index);
float16x8_t vout = vdivq_f16(vin0_opt, vin1);
float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1);
vst1q_f16(output + index, vout);
}
#endif
@ -606,7 +606,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vout = vdivq_f16(vin0, vin1_opt);
float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt);
vst1q_f16(output + index, vout);
}
#endif
@ -624,7 +624,7 @@ int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vin1 = vld1q_f16(input1 + index);
float16x8_t vout = vdivq_f16(vin0, vin1);
float16x8_t vout = MS_DIVQ_F16(vin0, vin1);
vout = vmaxq_f16(vout, zeros);
vst1q_f16(output + index, vout);
}
@ -652,7 +652,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin1 = vld1q_f16(input1 + index);
float16x8_t vout = vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros);
float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros);
vst1q_f16(output + index, vout);
}
#endif
@ -670,7 +670,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros);
float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros);
vst1q_f16(output + index, vout);
}
#endif
@ -689,7 +689,7 @@ int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float1
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vin1 = vld1q_f16(input1 + index);
float16x8_t vout = vdivq_f16(vin0, vin1);
float16x8_t vout = MS_DIVQ_F16(vin0, vin1);
vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
vst1q_f16(output + index, vout);
}
@ -716,7 +716,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin1 = vld1q_f16(input1 + index);
float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros), bounds);
float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds);
vst1q_f16(output + index, vout);
}
#endif
@ -733,7 +733,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo
#ifdef ENABLE_NEON
for (; index <= element_size - 8; index += C8NUM) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros), bounds);
float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds);
vst1q_f16(output + index, vout);
}
#endif

View File

@ -16,10 +16,8 @@
#ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_
#define MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/base/arithmetic_base.h"
#include "nnacl/errorcode.h"

View File

@ -80,7 +80,7 @@ int ElementLogicalNotFp16(float16_t *input, float16_t *output, int element_size)
int ElementRoundFp16(float16_t *input, float16_t *output, int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = round(input[i]);
output[i] = roundf(input[i]);
}
return NNACL_OK;
}
@ -94,7 +94,7 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size) {
int ElementCeilFp16(float16_t *input, float16_t *output, int number) {
for (int i = 0; i < number; ++i) {
output[i] = ceil(input[i]);
output[i] = ceilf(input[i]);
}
return NNACL_OK;
}

View File

@ -16,10 +16,8 @@
#ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_
#define MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/errorcode.h"
#ifdef __cplusplus

View File

@ -26,7 +26,7 @@ void BatchNormFp16(const float16_t *input, const void *mean, const void *varianc
for (int i = 0; i < cur_unit; i++) {
for (int c = 0; c < param->channel_; c++) {
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_);
if (variance_sqrt != 0) {
output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
}
@ -44,7 +44,7 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
for (int i = 0; i < cur_unit; i++) {
for (int c = 0; c < param->channel_; c++) {
float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_);
float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_);
if (variance_sqrt != 0) {
float16_t norm_val =
(((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt;
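
Both routines implement the usual normalization y_c = gamma_c * (x_c - mean_c) / sqrt(var_c + epsilon) + beta_c; in the code shown, the plain BatchNormFp16 applies only the (x - mean) / sqrt(var + eps) step, while the fused variant also applies scale and offset.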

View File

@ -13,12 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_
#ifndef MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_
#define MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/batchnorm_parameter.h"
#ifdef __cplusplus
@ -34,4 +31,4 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset
}
#endif
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_
#endif // MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_

View File

@ -16,7 +16,6 @@
#ifndef MINDSPORE_NNACL_CAST_FP16_H_
#define MINDSPORE_NNACL_CAST_FP16_H_
#include <arm_neon.h>
#include "nnacl/op_base.h"
#ifdef __cplusplus

View File

@ -56,3 +56,17 @@ void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const floa
PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type);
return;
}
#ifdef ENABLE_ARM82_A32
void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod,
size_t plane_size, size_t plane_stride, size_t relu_type) {
// TODO(fun): function
return;
}
void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod,
size_t plane_size, size_t stride, size_t relu_type) {
// TODO(fun): function
return;
}
#endif

View File

@ -16,7 +16,6 @@
#ifndef MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_
#define MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_
#include <arm_neon.h>
#include "nnacl/op_base.h"
#ifdef __cplusplus

View File

@ -16,9 +16,6 @@
#ifndef MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_
#define MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
#include "nnacl/constant_of_shape_parameter.h"
@ -27,7 +24,7 @@
extern "C" {
#endif
#ifdef __cplusplus
#ifdef ENABLE_NEON
#ifdef ENABLE_FP16
inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) {
for (int i = start; i < end; i++) {
output[i] = value;

View File

@ -18,6 +18,18 @@
#include <string.h>
#include "nnacl/fp16/activation_fp16.h"
#ifdef ENABLE_ARM82_A32
void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels,
size_t output_channel, size_t input_step) {
for (int i = 0; i < num_pixels; i++) {
for (int c = 0; c < output_channel; c++) {
*output_ptr++ += weight_ptr[c] * input_ptr[c];
}
input_ptr += input_step;
}
}
#endif
void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data,
const float16_t *bias_data, const ConvParameter *conv_param, int task_id) {
int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
@ -57,7 +69,6 @@ void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float
const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_;
int num_pixels = out_w_end - out_w_start;
ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step);
weight_kh += conv_param->output_channel_;
}

View File

@ -23,9 +23,9 @@
#ifdef __cplusplus
extern "C" {
#endif
#ifdef ENABLE_ARM64
void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels,
size_t input_channel, size_t input_step);
#ifdef ENABLE_ARM64
void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias,
size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu,
size_t relu6);

View File

@ -31,7 +31,7 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh
#ifdef __cplusplus
}
#endif
#ifndef ENABLE_NEON
#ifndef ENABLE_ARM64
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu,
size_t relu6) {
@ -124,7 +124,11 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we
// fp16 convolution common (im2col+gemm)
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) {
#ifdef ENABLE_ARM64
const int tile_n = 16;
#else
const int tile_n = 12;
#endif
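// the row tile must match the packed GEMM kernel: 16 rows for the arm64 16x8 kernel, 12 for the arm32 12x8 kernel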
int out_channel = conv_param->output_channel_;
int output_count = conv_param->output_h_ * conv_param->output_w_;
int output_tile_count = UP_DIV(output_count, tile_n);
@ -144,7 +148,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
#else
RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep);
#endif
MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
real_cal_num, out_channel, out_channel, OutType_Nhwc);
}
@ -155,7 +163,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_
void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data,
float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
InputTransFp16Func in_func, OutputTransFp16Func out_func) {
#ifdef ENABLE_ARM64
const int tile_num = 16;
#else
const int tile_num = 12;
#endif
int in_channel = conv_param->input_channel_;
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
@ -194,7 +206,11 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#else
RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#endif
MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
}

View File

@ -24,7 +24,7 @@
typedef float16_t *TmpBufferAddressFp16;
typedef float16_t *MatricesFp16;
#ifndef ENABLE_NEON
#ifndef ENABLE_ARM64
void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step,
size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu,
size_t relu6);

View File

@ -17,7 +17,6 @@
#ifndef MINDSPORE_NNACL_FP16_CROP_FP16_H_
#define MINDSPORE_NNACL_FP16_CROP_FP16_H_
#include <arm_neon.h>
#include "nnacl/op_base.h"
#include "nnacl/crop_parameter.h"

View File

@ -53,11 +53,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride;
float16_t *tmp_dst = dst_ptr + dst_index;
const float16_t *tmp_src = src_ptr + src_index;
#ifdef DEBUG_CODE
for (int i = 0; i < C8NUM; i++) {
tmp_dst[i] += tmp_src[i];
}
#else
#ifdef ENABLE_ARM64
asm volatile(
"mov x0, %[tmp_src] \n"
"mov x1, %[tmp_dst] \n"
@ -72,6 +68,10 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias,
:
: [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst)
: "x0", "x1", "v0", "v1");
#else
for (int i = 0; i < C8NUM; i++) {
tmp_dst[i] += tmp_src[i];
}
#endif
} /*kw*/
} /*kh*/

View File

@ -47,6 +47,7 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride,
size_t count8 = count / C8NUM * C8NUM;
int i = 0;
for (; i < count8; i += C8NUM) {
#ifdef ENABLE_ARM64
size_t src_step = src_stride * sizeof(float16_t);
size_t dst_step = dst_stride * sizeof(float16_t);
asm volatile(
@ -93,7 +94,9 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride,
:
: [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step)
: "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
// TODO(fun): arm32
#endif
src_ptr += C8NUM * src_stride;
dst_ptr += C8NUM * dst_stride;
}
@ -373,3 +376,23 @@ void DeconvWgPostFp16(float16_t *tile_out, float16_t *nc4hw4_output, ConvParamet
}
return;
}
#ifdef ENABLE_ARM82_A32
void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length) {
// TODO(fun): function
return;
}
void WinogradTransRightFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k,
size_t length) {
// TODO(fun): function
return;
}
void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t ic4, size_t cal_num,
size_t oc4) {
// TODO(fun): function
return;
}
#endif

View File

@ -18,6 +18,7 @@
#define MINDSPORE_NNACL_FP16_EXP_H_
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#ifdef __cplusplus
extern "C" {

View File

@ -16,6 +16,7 @@
#include "nnacl/fp16/instance_norm_fp16.h"
#include <math.h>
#include "nnacl/errorcode.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data,
const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) {

View File

@ -17,7 +17,6 @@
#define MINDSPORE_NNACL_FP16_INSTANCE_NORM_H_
#include "nnacl/instance_norm_parameter.h"
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -21,6 +21,7 @@
#include "nnacl/fp16/arithmetic_fp16.h"
#include "nnacl/fp16/matmul_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) {
for (int i = 0; i < batch; i++) {
@ -121,7 +122,7 @@ int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float1
for (; index <= element_size - 8; index += 8) {
float16x8_t vin0 = vld1q_f16(input0 + index);
float16x8_t vout = vld1q_f16(output + index);
vout = vfmaq_n_f16(vout, vin0, input1);
vout = MS_FMAQ_N_F16(vout, vin0, input1);
vst1q_f16(output + index, vout);
}
for (; index < element_size; index++) {

View File

@ -226,24 +226,43 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row,
return;
}
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, bool write_nhwc) {
if (write_nhwc) {
/* col16-major * row8-major => col-major */
void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, int write_mode) {
if (write_mode == OutType_Nhwc) {
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r16div = r / 16, r16mod = r % 16;
int c8div = c / 8, c8mod = c % 8;
size_t ci = r * stride + c;
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r16div * deep * 16 + d * 16 + r16mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
ADD_BIAS(value, bias, c)
DO_RELU(value, act_type)
DO_RELU6(value, act_type)
dst[ci] = value;
}
}
} else if (write_mode == OutType_C8) {
int col_8 = UP_ROUND(col, C8NUM);
int row_16 = UP_ROUND(row, C16NUM);
for (int r = 0; r < row_16; r++) {
for (int c = 0; c < col_8; c++) {
int r16div = r / C16NUM, r16mod = r % C16NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = r * stride + c;
float value = 0;
size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod);
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
ADD_BIAS(value, bias, c)
DO_RELU(value, act_type)
DO_RELU6(value, act_type)
dst[ci] = value;
}
}
@ -254,37 +273,119 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl
for (int j = 0; j < col; ++j) {
int c8div = j / 8, c8mod = j % 8;
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float value = 0;
float16_t value = 0;
for (int d = 0; d < deep; ++d) {
size_t ai = src_r_offset + d * C16NUM;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[j];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
ADD_BIAS(value, bias, j)
DO_RELU(value, act_type)
DO_RELU6(value, act_type)
dst[ci] = value;
}
}
}
}
void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, int write_mode) {
if (write_mode == OutType_Nhwc) {
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r12div = r / 12, r12mod = r % 12;
int c8div = c / 8, c8mod = c % 8;
size_t ci = r * stride + c;
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * 12 + d * 12 + r12mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
ADD_BIAS(value, bias, c)
DO_RELU(value, act_type)
DO_RELU6(value, act_type)
dst[ci] = value;
}
}
} else if (write_mode == OutType_C8) {
int col_8 = UP_ROUND(col, C8NUM);
int row_12 = UP_ROUND(row, C12NUM);
for (int r = 0; r < row_12; r++) {
for (int c = 0; c < col_8; c++) {
int r12div = r / C12NUM, r12mod = r % C12NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
value = value + a[ai] * b[bi];
}
ADD_BIAS(value, bias, c)
DO_RELU(value, act_type)
DO_RELU6(value, act_type)
dst[ci] = value;
}
}
} else {
for (int i = 0; i < row; ++i) {
int src_r_offset = i;
int dst_r_offset = i * col * stride;
for (int j = 0; j < col; ++j) {
int c8div = j / 8, c8mod = j % 8;
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float16_t value = 0;
for (int d = 0; d < deep; ++d) {
size_t ai = src_r_offset + d * C12NUM;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
ADD_BIAS(value, bias, j)
DO_RELU(value, act_type)
DO_RELU6(value, act_type)
dst[ci] = value;
}
}
}
return;
}
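
A quick index check for the OutType_C8 branch: with row = 13 and col = 9, the layout pads to row_12 = 24 and col_8 = 16; the element at (r = 12, c = 8) is then stored at ci = 1 * 8 * 24 + 12 * 8 + 0 = 288, i.e. row 12 of the second 8-column block.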
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, int out_type) {
if (out_type == OutType_C8) {
#ifdef ENABLE_ARM64
MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false);
#else
MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
#endif
} else {
#ifdef ENABLE_ARM64
MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
#else
MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type);
#endif
}
return;
}
#ifdef ENABLE_ARM82_A32
void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int col) {
// TODO(fun): function
return;
}
#endif
void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int col) {
#ifdef ENABLE_ARM64
MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col);
#else
MatVecMulA32Fp16(a, b, c, bias, (int)act_type, depth, col);
#endif
}
#ifdef ENABLE_ARM64
static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) {
size_t stride = col * 2;
asm volatile(
@ -392,6 +493,7 @@ static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");
}
#endif
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
size_t row_up_16 = UP_ROUND(row, C16NUM);
@ -442,6 +544,54 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
return;
}
void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
size_t row_up_12 = UP_ROUND(row, C12NUM);
size_t row12 = row / C12NUM * C12NUM;
size_t col8 = col / C8NUM * C8NUM;
const float16_t *src_r = src_ptr;
float16_t *dst_r = dst_ptr;
size_t ri = 0;
// transpose 12x8
for (; ri < row12; ri += C12NUM) {
size_t ci = 0;
for (; ci < col8; ci += C8NUM) {
const float16_t *src_c = src_r + ci;
float16_t *dst_c = dst_r + ci * C12NUM;
#ifdef ENABLE_ARM82_A32
Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24);
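// 24 is assumed to be the packed dst stride in bytes: C12NUM * sizeof(float16_t)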
#else
for (int tr = 0; tr < C12NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
dst_c[tc * C12NUM + tr] = src_c[tr * col + tc];
}
}
#endif
}
for (; ci < col; ci++) {
const float16_t *src_c = src_r + ci;
float16_t *dst_c = dst_r + ci * C12NUM;
for (size_t i = 0; i < C12NUM; i++) {
dst_c[i] = src_c[i * col];
}
}
src_r += C12NUM * col;
dst_r += C12NUM * col;
}
for (; ri < row; ri++) {
for (size_t i = 0; i < col; ++i) {
dst_r[i * C12NUM] = src_r[i];
}
src_r += col;
dst_r += 1;
}
for (; ri < row_up_12; ri++) {
for (size_t i = 0; i < col; i++) {
dst_r[i * C12NUM] = 0;
}
dst_r += 1;
}
}
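
A small illustration of the resulting layout: for row = 2, col = 3 with row-major src {a, b, c, d, e, f}, the tail loops produce dst = {a, d, 0, ..., 0, b, e, 0, ..., 0, c, f, 0, ..., 0}, each column padded to C12NUM entries.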
void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
if (is_fp32_src) {
const float *fp32_src = (const float *)src;

View File

@ -19,18 +19,51 @@
#include <float.h>
#include <string.h>
#ifdef ENABLE_NEON
#ifdef ENABLE_ARM64
#include <arm_neon.h>
#endif
#include "nnacl/errorcode.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/fp16/pack_fp16.h"
#define ADD_BIAS(value, bias, c) \
if (bias != NULL) value = value + bias[c];
#define DO_RELU(value, act_type) \
if (act_type == ActType_Relu) value = MSMAX(0.0f, value);
#define DO_RELU6(value, act_type) \
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \
if (act_type == ActType_Relu6) value = MSMAX(0.0f, value);
#ifdef __cplusplus
extern "C" {
#endif
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, bool write_nhwc);
void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, int write_mode);
void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, int write_mode);
#ifdef ENABLE_ARM64
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int col);
#elif ENABLE_ARM82_A32
void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, int write_mode);
void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int col);
#endif
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, int out_type);
@ -42,14 +75,7 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row,
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc);
void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc);
void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type,
int depth, int col);
void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col);
void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src);

View File

@ -39,7 +39,7 @@ void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matri
for (int i = 0; i < m; ++i) {
for (int j = 0; j < n; ++j) {
for (int y = 0; y < in_channel; ++y) {
float16_t tmp = 0;
float tmp = 0;
for (int z = 0; z < k; ++z) {
tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
}

View File

@ -160,129 +160,35 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
}
void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel) {
int hw16 = plane / C16NUM * C16NUM;
#ifdef ENABLE_ARM64
// Transpose16x8 on arm64
const int hw_tile = C16NUM;
#else
// Transpose8x8 on other platforms
const int hw_tile = C8NUM;
#endif
int hw_align = plane / hw_tile * hw_tile;
int c8 = channel / C8NUM * C8NUM;
int batch = plane * channel;
for (int n = 0; n < batches; n++) {
const float16_t *src_batch = (const float16_t *)src + n * batch;
float16_t *dst_batch = (float16_t *)dst + n * batch;
int hw = 0;
for (; hw < hw16; hw += C16NUM) {
for (; hw < hw_align; hw += hw_tile) {
int c = 0;
for (; c < c8; c += C8NUM) {
const float16_t *src_ptr = src_batch + hw * channel + c;
float16_t *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
size_t srcStride = channel * sizeof(float16_t);
size_t dstStride = plane * sizeof(float16_t);
asm volatile(
"mov x10, %[src_ptr]\n"
"mov x11, %[dst_ptr]\n"
"ld1 {v0.8h}, [x10], %[srcStride]\n"
"ld1 {v1.8h}, [x10], %[srcStride]\n"
"ld1 {v2.8h}, [x10], %[srcStride]\n"
"ld1 {v3.8h}, [x10], %[srcStride]\n"
"ld1 {v4.8h}, [x10], %[srcStride]\n"
"ld1 {v5.8h}, [x10], %[srcStride]\n"
"ld1 {v6.8h}, [x10], %[srcStride]\n"
"ld1 {v7.8h}, [x10], %[srcStride]\n"
"zip1 v16.8h, v0.8h, v1.8h\n"
"zip1 v17.8h, v2.8h, v3.8h\n"
"zip1 v18.8h, v4.8h, v5.8h\n"
"zip1 v19.8h, v6.8h, v7.8h\n"
"ld1 {v8.8h}, [x10], %[srcStride]\n"
"ld1 {v9.8h}, [x10], %[srcStride]\n"
"ld1 {v10.8h}, [x10], %[srcStride]\n"
"ld1 {v11.8h}, [x10], %[srcStride]\n"
"ld1 {v12.8h}, [x10], %[srcStride]\n"
"ld1 {v13.8h}, [x10], %[srcStride]\n"
"ld1 {v14.8h}, [x10], %[srcStride]\n"
"ld1 {v15.8h}, [x10], %[srcStride]\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip1 v16.8h, v8.8h, v9.8h\n"
"zip1 v17.8h, v10.8h, v11.8h\n"
"zip1 v18.8h, v12.8h, v13.8h\n"
"zip1 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"add x10, x11, #16\n"
"st1 {v24.8h}, [x11], %[dstStride]\n"
"st1 {v28.8h}, [x10], %[dstStride]\n"
"st1 {v26.8h}, [x11], %[dstStride]\n"
"st1 {v30.8h}, [x10], %[dstStride]\n"
"st1 {v25.8h}, [x11], %[dstStride]\n"
"st1 {v29.8h}, [x10], %[dstStride]\n"
"st1 {v27.8h}, [x11], %[dstStride]\n"
"st1 {v31.8h}, [x10], %[dstStride]\n"
"zip2 v16.8h, v0.8h, v1.8h\n"
"zip2 v17.8h, v2.8h, v3.8h\n"
"zip2 v18.8h, v4.8h, v5.8h\n"
"zip2 v19.8h, v6.8h, v7.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip2 v16.8h, v8.8h, v9.8h\n"
"zip2 v17.8h, v10.8h, v11.8h\n"
"zip2 v18.8h, v12.8h, v13.8h\n"
"zip2 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"st1 {v24.8h}, [x11], %[dstStride]\n"
"st1 {v28.8h}, [x10], %[dstStride]\n"
"st1 {v26.8h}, [x11], %[dstStride]\n"
"st1 {v30.8h}, [x10], %[dstStride]\n"
"st1 {v25.8h}, [x11], %[dstStride]\n"
"st1 {v29.8h}, [x10], %[dstStride]\n"
"st1 {v27.8h}, [x11], %[dstStride]\n"
"st1 {v31.8h}, [x10], %[dstStride]\n"
:
:
[ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
size_t src_stride = channel * sizeof(float16_t);
size_t dst_stride = plane * sizeof(float16_t);
Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride);
#elif defined(ENABLE_ARM82_A32)
size_t src_stride = channel * sizeof(float16_t);
size_t dst_stride = plane * sizeof(float16_t);
Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride);
#else
for (int tr = 0; tr < C16NUM; tr++) {
for (int tr = 0; tr < hw_tile; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc];
}
@ -292,7 +198,7 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int
for (; c < channel; c++) {
const float16_t *src_ptr = src_batch + hw * channel + c;
float16_t *dst_ptr = dst_batch + c * plane + hw;
for (size_t i = 0; i < C16NUM; i++) {
for (size_t i = 0; i < hw_tile; i++) {
dst_ptr[i] = src_ptr[i * channel];
}
}
@ -305,7 +211,6 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int
}
}
}
return;
}
void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) {
@ -565,3 +470,246 @@ void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, i
}
}
}
#ifdef ENABLE_ARM82_A32
inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride) {
asm volatile(
"mov r10, %[src]\n"
"mov r12, %[dst]\n"
"vld1.16 {q0}, [r10], %[src_stride]\n"
"vld1.16 {q2}, [r10], %[src_stride]\n"
"vld1.16 {q4}, [r10], %[src_stride]\n"
"vld1.16 {q6}, [r10], %[src_stride]\n"
"vtrn.16 d0, d4\n"
"vtrn.16 d1, d5\n"
"vtrn.16 d8, d12\n"
"vtrn.16 d9, d13\n"
"vld1.16 {q8}, [r10], %[src_stride]\n"
"vld1.16 {q10}, [r10], %[src_stride]\n"
"vld1.16 {q12}, [r10], %[src_stride]\n"
"vld1.16 {q14}, [r10], %[src_stride]\n"
"vtrn.32 d0, d8\n"
"vtrn.32 d4, d12\n"
"vtrn.32 d1, d9\n"
"vtrn.32 d5, d13\n"
"vtrn.16 d16, d20\n"
"vtrn.16 d17, d21\n"
"vtrn.16 d24, d28\n"
"vtrn.16 d25, d29\n"
"vtrn.32 d16, d24\n"
"vtrn.32 d20, d28\n"
"vtrn.32 d17, d25\n"
"vtrn.32 d21, d29\n"
"vswp d1, d16\n"
"vswp d5, d20\n"
"vswp d9, d24\n"
"vswp d13, d28\n"
"vst1.16 {q0}, [r12], %[dst_stride]\n"
"vst1.16 {q2}, [r12], %[dst_stride]\n"
"vst1.16 {q4}, [r12], %[dst_stride]\n"
"vst1.16 {q6}, [r12], %[dst_stride]\n"
"vst1.16 {q8}, [r12], %[dst_stride]\n"
"vst1.16 {q10}, [r12], %[dst_stride]\n"
"vst1.16 {q12}, [r12], %[dst_stride]\n"
"vst1.16 {q14}, [r12], %[dst_stride]\n"
:
: [ dst ] "r"(dst), [ src ] "r"(src), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
: "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15");
}
inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) {
asm volatile(
"mov r10, %[src_c]\n"
"mov r12, %[dst_c]\n"
"vld1.16 {q0}, [r10], %[src_stride]\n"
"vld1.16 {q2}, [r10], %[src_stride]\n"
"vld1.16 {q4}, [r10], %[src_stride]\n"
"vld1.16 {q6}, [r10], %[src_stride]\n"
"vtrn.16 d0, d4\n"
"vtrn.16 d1, d5\n"
"vtrn.16 d8, d12\n"
"vtrn.16 d9, d13\n"
"vld1.16 {q8}, [r10], %[src_stride]\n"
"vld1.16 {q10}, [r10], %[src_stride]\n"
"vld1.16 {q12}, [r10], %[src_stride]\n"
"vld1.16 {q14}, [r10], %[src_stride]\n"
"vtrn.32 d0, d8\n"
"vtrn.32 d4, d12\n"
"vtrn.32 d1, d9\n"
"vtrn.32 d5, d13\n"
"vtrn.16 d16, d20\n"
"vtrn.16 d17, d21\n"
"vtrn.16 d24, d28\n"
"vtrn.16 d25, d29\n"
"vld1.16 {q1}, [r10], %[src_stride]\n"
"vld1.16 {q3}, [r10], %[src_stride]\n"
"vld1.16 {q5}, [r10], %[src_stride]\n"
"vld1.16 {q7}, [r10], %[src_stride]\n"
"vtrn.32 d16, d24\n"
"vtrn.32 d20, d28\n"
"vtrn.32 d17, d25\n"
"vtrn.32 d21, d29\n"
"vswp d1, d16\n"
"vswp d5, d20\n"
"vswp d9, d24\n"
"vswp d13, d28\n"
"vtrn.16 d2, d6\n"
"vtrn.16 d3, d7\n"
"vtrn.16 d10, d14\n"
"vtrn.16 d11, d15\n"
"vtrn.32 d2, d10\n"
"vtrn.32 d6, d14\n"
"vtrn.32 d3, d11\n"
"vtrn.32 d7, d15\n"
"vst1.16 {q0, d2}, [r12], %[dst_stride]\n"
"vst1.16 {q2, d6}, [r12], %[dst_stride]\n"
"vst1.16 {q4, d10}, [r12], %[dst_stride]\n"
"vst1.16 {q6, d14}, [r12], %[dst_stride]\n"
"vswp d3, d18\n"
"vswp d7, d22\n"
"vswp d11, d26\n"
"vswp d15, d30\n"
"vst1.16 {q8, d18}, [r12], %[dst_stride]\n"
"vst1.16 {q10, d22}, [r12], %[dst_stride]\n"
"vst1.16 {q12, d26}, [r12], %[dst_stride]\n"
"vst1.16 {q14, d30}, [r12], %[dst_stride]\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
: "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
"q15");
}
#endif
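Both routines above are in-register transposes of a (rows x 8) fp16 block. A scalar reference of their effect (a sketch, not part of this commit; strides are in bytes, exactly as the asm consumes them — rows = 8 for Transpose8x8A32Fp16, rows = 12 for Transpose12x8A32Fp16):
// Scalar model of the two A32 transposes above (illustrative only).
static void TransposeRefFp16(const float16_t *src, float16_t *dst,
                             size_t src_stride, size_t dst_stride, int rows) {
  for (int r = 0; r < rows; ++r) {
    const float16_t *s = (const float16_t *)((const char *)src + (size_t)r * src_stride);
    for (int c = 0; c < 8; ++c) {  // 8 columns in both variants
      float16_t *d = (float16_t *)((char *)dst + (size_t)c * dst_stride);
      d[r] = s[c];  // dst row c, element r
    }
  }
}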
#ifdef ENABLE_ARM64
inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, size_t src_stride, size_t dst_stride) {
asm volatile(
"mov x10, %[src_ptr]\n"
"mov x11, %[dst_ptr]\n"
"ld1 {v0.8h}, [x10], %[src_stride]\n"
"ld1 {v1.8h}, [x10], %[src_stride]\n"
"ld1 {v2.8h}, [x10], %[src_stride]\n"
"ld1 {v3.8h}, [x10], %[src_stride]\n"
"ld1 {v4.8h}, [x10], %[src_stride]\n"
"ld1 {v5.8h}, [x10], %[src_stride]\n"
"ld1 {v6.8h}, [x10], %[src_stride]\n"
"ld1 {v7.8h}, [x10], %[src_stride]\n"
"zip1 v16.8h, v0.8h, v1.8h\n"
"zip1 v17.8h, v2.8h, v3.8h\n"
"zip1 v18.8h, v4.8h, v5.8h\n"
"zip1 v19.8h, v6.8h, v7.8h\n"
"ld1 {v8.8h}, [x10], %[src_stride]\n"
"ld1 {v9.8h}, [x10], %[src_stride]\n"
"ld1 {v10.8h}, [x10], %[src_stride]\n"
"ld1 {v11.8h}, [x10], %[src_stride]\n"
"ld1 {v12.8h}, [x10], %[src_stride]\n"
"ld1 {v13.8h}, [x10], %[src_stride]\n"
"ld1 {v14.8h}, [x10], %[src_stride]\n"
"ld1 {v15.8h}, [x10], %[src_stride]\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip1 v16.8h, v8.8h, v9.8h\n"
"zip1 v17.8h, v10.8h, v11.8h\n"
"zip1 v18.8h, v12.8h, v13.8h\n"
"zip1 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"add x10, x11, #16\n"
"st1 {v24.8h}, [x11], %[dst_stride]\n"
"st1 {v28.8h}, [x10], %[dst_stride]\n"
"st1 {v26.8h}, [x11], %[dst_stride]\n"
"st1 {v30.8h}, [x10], %[dst_stride]\n"
"st1 {v25.8h}, [x11], %[dst_stride]\n"
"st1 {v29.8h}, [x10], %[dst_stride]\n"
"st1 {v27.8h}, [x11], %[dst_stride]\n"
"st1 {v31.8h}, [x10], %[dst_stride]\n"
"zip2 v16.8h, v0.8h, v1.8h\n"
"zip2 v17.8h, v2.8h, v3.8h\n"
"zip2 v18.8h, v4.8h, v5.8h\n"
"zip2 v19.8h, v6.8h, v7.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v24.2d, v20.2d, v22.2d\n"
"trn2 v25.2d, v20.2d, v22.2d\n"
"trn1 v26.2d, v21.2d, v23.2d\n"
"trn2 v27.2d, v21.2d, v23.2d\n"
"zip2 v16.8h, v8.8h, v9.8h\n"
"zip2 v17.8h, v10.8h, v11.8h\n"
"zip2 v18.8h, v12.8h, v13.8h\n"
"zip2 v19.8h, v14.8h, v15.8h\n"
"trn1 v20.4s, v16.4s, v17.4s\n"
"trn2 v21.4s, v16.4s, v17.4s\n"
"trn1 v22.4s, v18.4s, v19.4s\n"
"trn2 v23.4s, v18.4s, v19.4s\n"
"trn1 v28.2d, v20.2d, v22.2d\n"
"trn2 v29.2d, v20.2d, v22.2d\n"
"trn1 v30.2d, v21.2d, v23.2d\n"
"trn2 v31.2d, v21.2d, v23.2d\n"
"st1 {v24.8h}, [x11], %[dst_stride]\n"
"st1 {v28.8h}, [x10], %[dst_stride]\n"
"st1 {v26.8h}, [x11], %[dst_stride]\n"
"st1 {v30.8h}, [x10], %[dst_stride]\n"
"st1 {v25.8h}, [x11], %[dst_stride]\n"
"st1 {v29.8h}, [x10], %[dst_stride]\n"
"st1 {v27.8h}, [x11], %[dst_stride]\n"
"st1 {v31.8h}, [x10], %[dst_stride]\n"
:
: [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
"v31");
}
#endif
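The zip/trn ladder above is the classic power-of-two block transpose: zip1/zip2 interleave 16-bit lanes, trn1/trn2 on .4s swap 2-element pairs, and trn1/trn2 on .2d swap 4-element pairs. A plain-C model of the same idea on one 8x8 block (illustrative, not from the source):
// Transpose by swapping 1-, 2-, then 4-element pairs -- the scalar analogue
// of the zip.8h / trn.4s / trn.2d stages in the asm above.
static void BlockTranspose8x8Fp16(float16_t m[8][8]) {
  for (int step = 1; step < 8; step <<= 1) {  // step = 1, 2, 4
    for (int r = 0; r < 8; ++r) {
      if (!(r & step)) continue;
      for (int c = 0; c < 8; ++c) {
        if (c & step) continue;
        float16_t t = m[r][c];  // swap the off-diagonal sub-blocks
        m[r][c] = m[r - step][c + step];
        m[r - step][c + step] = t;
      }
    }
  }
}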

View File

@ -17,11 +17,9 @@
#ifndef MINDSPORE_NNACL_FP16_PACK_FP16_H_
#define MINDSPORE_NNACL_FP16_PACK_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/conv_parameter.h"
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#ifdef __cplusplus
extern "C" {
@ -72,6 +70,17 @@ void PackNHWCFp16ToC8HWN8Fp16(float16_t *src, float16_t *dst, int batch, int pla
void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel);
void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel);
#ifdef ENABLE_ARM82_A32
void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
#endif
#ifdef ENABLE_ARM64
void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride);
#endif
#ifdef __cplusplus
}
#endif

View File

@ -16,9 +16,6 @@
#ifndef MINDSPORE_NNACL_FP16_PAD_FP16_H_
#define MINDSPORE_NNACL_FP16_PAD_FP16_H_
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/fp32/pad_fp32.h"
#ifdef __cplusplus

View File

@ -18,10 +18,8 @@
#define MINDSPORE_NNACL_FP16_POOLING_FP16_H_
#include <math.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/pooling_parameter.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -17,7 +17,7 @@
#include "nnacl/fp16/power_fp16.h"
#include "nnacl/errorcode.h"
#if defined(ENABLE_NEON)
#if defined(ENABLE_ARM64)
float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) {
int tmp = (int)(*(float16_t *)exponent);
int exp = abs(tmp);
@ -53,23 +53,23 @@ float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) {
void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale,
float shift) {
PowerScalarFunFp16 PowerScalarFunFp16_ = NULL;
#if defined(ENABLE_NEON)
#if defined(ENABLE_ARM64)
PowerSimdFunFp16 PowerSimdFunFp16_ = NULL;
#endif
if (CheckInteger(*exponent)) {
#if defined(ENABLE_NEON)
#if defined(ENABLE_ARM64)
PowerSimdFunFp16_ = OptimizedPowerSimdFp16;
#endif
PowerScalarFunFp16_ = OptimizedPowerScalarFp16;
} else {
#if defined(ENABLE_NEON)
#if defined(ENABLE_ARM64)
PowerSimdFunFp16_ = StdPowerSimdFp16;
#endif
PowerScalarFunFp16_ = StdPowerScalarFp16;
}
int i = 0;
#ifdef ENABLE_NEON
#ifdef ENABLE_ARM64
int len_c8 = UP_ROUND(len, C8NUM);
float16x8_t scale_8 = vmovq_n_f16(scale);
float16x8_t shift_8 = vmovq_n_f16(shift);
@ -87,7 +87,7 @@ void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_
float shift) {
int i = 0;
PowerScalarFunFp16 PowerScalarFunFp16_ = NULL;
#ifdef ENABLE_NEON
#ifdef ENABLE_ARM64
int len_c8 = UP_ROUND(len, C8NUM);
float16x8_t scale_8 = vmovq_n_f16(scale);
float16x8_t shift_8 = vmovq_n_f16(shift);

View File

@ -19,9 +19,10 @@
#include <math.h>
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/power_parameter.h"
#if defined(ENABLE_NEON)
#if defined(ENABLE_ARM64)
typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent);
#endif
typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent);
@ -36,7 +37,7 @@ static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) {
return powf(x, *(float16_t *)exponent);
}
#if defined(ENABLE_NEON)
#if defined(ENABLE_ARM64)
static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) {
float16x8_t result;
result[0] = powf(x[0], *(float16_t *)exponent);

View File

@ -18,10 +18,7 @@
#define MINDSPORE_NNACL_FP16_QUANTDTYPECAST_FP16_H_
#include "nnacl/op_base.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#ifdef __cplusplus
extern "C" {

View File

@ -19,9 +19,6 @@
#include "nnacl/op_base.h"
#include "nnacl/reduce_parameter.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -18,10 +18,9 @@
#define MINDSPORE_NNACL_SCALE_FP16_H_
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/scale.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -40,7 +40,7 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe
}
}
int k = 0;
#ifdef ENABLE_NEON
#ifdef ENABLE_ARM64
int count2 = (channel / C8NUM) * C8NUM;
for (; k < count2; k += C8NUM) {
float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k);
@ -58,9 +58,9 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe
void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) {
int cur_batch_offset = 0;
for (int i = 0; i < batch; i++, cur_batch_offset += channel) {
float16_t sum = 0;
float16_t sum = 0.0f;
int j = 0;
#ifdef ENABLE_NEON
#ifdef ENABLE_ARM64
float16x8_t sum8 = vdupq_n_f16(0);
int count = (channel / C8NUM) * C8NUM;
for (; j < count; j += C8NUM) {
@ -72,7 +72,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel)
sum += src[cur_batch_offset + j];
}
int k = 0;
#ifdef ENABLE_NEON
#ifdef ENABLE_ARM64
const float16_t div = 1.0f / sum;
for (; k < count; k += C8NUM) {
vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div));
@ -117,7 +117,7 @@ void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *s
}
for (int j = 0; j < input_shape[axis]; j++) {
int axis_offset = inner_offset + j * inner_size;
output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data);
output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data);
sum_data[k + sum_outter_offset] += output_ptr[axis_offset];
}
}
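The hunk above is the numerically stable softmax: subtracting the row maximum before exponentiation keeps expf from overflowing, and expf avoids the silent float-double-float round trip that plain exp incurs. In LaTeX:
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_k e^{x_k - m}}, \qquad m = \max_j x_j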

View File

@ -18,10 +18,9 @@
#define MINDSPORE_NNACL_FP16_SOFTMAX_FP16_H_
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/softmax_parameter.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -18,10 +18,8 @@
#define MINDSPORE_NNACL_FP16_TRANSPOSE_FP16_H_
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#include "nnacl/transpose.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {

View File

@ -16,562 +16,15 @@
#include "nnacl/fp16/winograd_transform_fp16.h"
// for fp16 convolution 3x3 filter/input/output transform F(4,3)
void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) {
float16x8_t d00 = vld1q_f16(tmp_data);
float16x8_t d01 = vld1q_f16(tmp_data + 8);
float16x8_t d02 = vld1q_f16(tmp_data + 2 * 8);
float16x8_t d03 = vld1q_f16(tmp_data + 3 * 8);
float16x8_t d04 = vld1q_f16(tmp_data + 4 * 8);
float16x8_t d05 = vld1q_f16(tmp_data + 5 * 8);
float16x8_t d10 = vld1q_f16(tmp_data + 6 * 8);
float16x8_t d11 = vld1q_f16(tmp_data + 7 * 8);
float16x8_t d12 = vld1q_f16(tmp_data + 8 * 8);
float16x8_t d13 = vld1q_f16(tmp_data + 9 * 8);
float16x8_t d14 = vld1q_f16(tmp_data + 10 * 8);
float16x8_t d15 = vld1q_f16(tmp_data + 11 * 8);
float16x8_t d20 = vld1q_f16(tmp_data + 12 * 8);
float16x8_t d21 = vld1q_f16(tmp_data + 13 * 8);
float16x8_t d22 = vld1q_f16(tmp_data + 14 * 8);
float16x8_t d23 = vld1q_f16(tmp_data + 15 * 8);
float16x8_t d24 = vld1q_f16(tmp_data + 16 * 8);
float16x8_t d25 = vld1q_f16(tmp_data + 17 * 8);
float16x8_t d30 = vld1q_f16(tmp_data + 18 * 8);
float16x8_t d31 = vld1q_f16(tmp_data + 19 * 8);
float16x8_t d32 = vld1q_f16(tmp_data + 20 * 8);
float16x8_t d33 = vld1q_f16(tmp_data + 21 * 8);
float16x8_t d34 = vld1q_f16(tmp_data + 22 * 8);
float16x8_t d35 = vld1q_f16(tmp_data + 23 * 8);
float16x8_t d40 = vld1q_f16(tmp_data + 24 * 8);
float16x8_t d41 = vld1q_f16(tmp_data + 25 * 8);
float16x8_t d42 = vld1q_f16(tmp_data + 26 * 8);
float16x8_t d43 = vld1q_f16(tmp_data + 27 * 8);
float16x8_t d44 = vld1q_f16(tmp_data + 28 * 8);
float16x8_t d45 = vld1q_f16(tmp_data + 29 * 8);
float16x8_t d50 = vld1q_f16(tmp_data + 30 * 8);
float16x8_t d51 = vld1q_f16(tmp_data + 31 * 8);
float16x8_t d52 = vld1q_f16(tmp_data + 32 * 8);
float16x8_t d53 = vld1q_f16(tmp_data + 33 * 8);
float16x8_t d54 = vld1q_f16(tmp_data + 34 * 8);
float16x8_t d55 = vld1q_f16(tmp_data + 35 * 8);
float16x8_t t00 = vaddq_f16(vsubq_f16(vmulq_n_f16(d00, 4), vmulq_n_f16(d20, 5)), d40);
float16x8_t t01 = vaddq_f16(vsubq_f16(vmulq_n_f16(d01, 4), vmulq_n_f16(d21, 5)), d41);
float16x8_t t02 = vaddq_f16(vsubq_f16(vmulq_n_f16(d02, 4), vmulq_n_f16(d22, 5)), d42);
float16x8_t t03 = vaddq_f16(vsubq_f16(vmulq_n_f16(d03, 4), vmulq_n_f16(d23, 5)), d43);
float16x8_t t04 = vaddq_f16(vsubq_f16(vmulq_n_f16(d04, 4), vmulq_n_f16(d24, 5)), d44);
float16x8_t t05 = vaddq_f16(vsubq_f16(vmulq_n_f16(d05, 4), vmulq_n_f16(d25, 5)), d45);
float16x8_t t10 = vaddq_f16(vaddq_f16(d30, d40), vmulq_n_f16(vaddq_f16(d10, d20), -4));
float16x8_t t11 = vaddq_f16(vaddq_f16(d31, d41), vmulq_n_f16(vaddq_f16(d11, d21), -4));
float16x8_t t12 = vaddq_f16(vaddq_f16(d32, d42), vmulq_n_f16(vaddq_f16(d12, d22), -4));
float16x8_t t13 = vaddq_f16(vaddq_f16(d33, d43), vmulq_n_f16(vaddq_f16(d13, d23), -4));
float16x8_t t14 = vaddq_f16(vaddq_f16(d34, d44), vmulq_n_f16(vaddq_f16(d14, d24), -4));
float16x8_t t15 = vaddq_f16(vaddq_f16(d35, d45), vmulq_n_f16(vaddq_f16(d15, d25), -4));
float16x8_t t20 = vaddq_f16(vsubq_f16(d40, d30), vmulq_n_f16(vsubq_f16(d10, d20), 4));
float16x8_t t21 = vaddq_f16(vsubq_f16(d41, d31), vmulq_n_f16(vsubq_f16(d11, d21), 4));
float16x8_t t22 = vaddq_f16(vsubq_f16(d42, d32), vmulq_n_f16(vsubq_f16(d12, d22), 4));
float16x8_t t23 = vaddq_f16(vsubq_f16(d43, d33), vmulq_n_f16(vsubq_f16(d13, d23), 4));
float16x8_t t24 = vaddq_f16(vsubq_f16(d44, d34), vmulq_n_f16(vsubq_f16(d14, d24), 4));
float16x8_t t25 = vaddq_f16(vsubq_f16(d45, d35), vmulq_n_f16(vsubq_f16(d15, d25), 4));
float16x8_t t30 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d30, d10), 2));
float16x8_t t31 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d31, d11), 2));
float16x8_t t32 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d32, d12), 2));
float16x8_t t33 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d33, d13), 2));
float16x8_t t34 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d34, d14), 2));
float16x8_t t35 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d35, d15), 2));
float16x8_t t40 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d10, d30), 2));
float16x8_t t41 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d11, d31), 2));
float16x8_t t42 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d12, d32), 2));
float16x8_t t43 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d13, d33), 2));
float16x8_t t44 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d14, d34), 2));
float16x8_t t45 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d15, d35), 2));
float16x8_t t50 = vaddq_f16(vsubq_f16(vmulq_n_f16(d10, 4), vmulq_n_f16(d30, 5)), d50);
float16x8_t t51 = vaddq_f16(vsubq_f16(vmulq_n_f16(d11, 4), vmulq_n_f16(d31, 5)), d51);
float16x8_t t52 = vaddq_f16(vsubq_f16(vmulq_n_f16(d12, 4), vmulq_n_f16(d32, 5)), d52);
float16x8_t t53 = vaddq_f16(vsubq_f16(vmulq_n_f16(d13, 4), vmulq_n_f16(d33, 5)), d53);
float16x8_t t54 = vaddq_f16(vsubq_f16(vmulq_n_f16(d14, 4), vmulq_n_f16(d34, 5)), d54);
float16x8_t t55 = vaddq_f16(vsubq_f16(vmulq_n_f16(d15, 4), vmulq_n_f16(d35, 5)), d55);
float16x8_t m00 = vaddq_f16(vsubq_f16(vmulq_n_f16(t00, 4), vmulq_n_f16(t02, 5)), t04);
float16x8_t m01 = vaddq_f16(vaddq_f16(t03, t04), vmulq_n_f16(vaddq_f16(t01, t02), -4));
float16x8_t m02 = vaddq_f16(vsubq_f16(t04, t03), vmulq_n_f16(vsubq_f16(t01, t02), 4));
float16x8_t m03 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t03, t01), 2));
float16x8_t m04 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t01, t03), 2));
float16x8_t m05 = vaddq_f16(vsubq_f16(vmulq_n_f16(t01, 4), vmulq_n_f16(t03, 5)), t05);
float16x8_t m10 = vaddq_f16(vsubq_f16(vmulq_n_f16(t10, 4), vmulq_n_f16(t12, 5)), t14);
float16x8_t m11 = vaddq_f16(vaddq_f16(t13, t14), vmulq_n_f16(vaddq_f16(t11, t12), -4));
float16x8_t m12 = vaddq_f16(vsubq_f16(t14, t13), vmulq_n_f16(vsubq_f16(t11, t12), 4));
float16x8_t m13 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t13, t11), 2));
float16x8_t m14 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t11, t13), 2));
float16x8_t m15 = vaddq_f16(vsubq_f16(vmulq_n_f16(t11, 4), vmulq_n_f16(t13, 5)), t15);
float16x8_t m20 = vaddq_f16(vsubq_f16(vmulq_n_f16(t20, 4), vmulq_n_f16(t22, 5)), t24);
float16x8_t m21 = vaddq_f16(vaddq_f16(t23, t24), vmulq_n_f16(vaddq_f16(t21, t22), -4));
float16x8_t m22 = vaddq_f16(vsubq_f16(t24, t23), vmulq_n_f16(vsubq_f16(t21, t22), 4));
float16x8_t m23 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t23, t21), 2));
float16x8_t m24 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t21, t23), 2));
float16x8_t m25 = vaddq_f16(vsubq_f16(vmulq_n_f16(t21, 4), vmulq_n_f16(t23, 5)), t25);
float16x8_t m30 = vaddq_f16(vsubq_f16(vmulq_n_f16(t30, 4), vmulq_n_f16(t32, 5)), t34);
float16x8_t m31 = vaddq_f16(vaddq_f16(t33, t34), vmulq_n_f16(vaddq_f16(t31, t32), -4));
float16x8_t m32 = vaddq_f16(vsubq_f16(t34, t33), vmulq_n_f16(vsubq_f16(t31, t32), 4));
float16x8_t m33 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t33, t31), 2));
float16x8_t m34 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t31, t33), 2));
float16x8_t m35 = vaddq_f16(vsubq_f16(vmulq_n_f16(t31, 4), vmulq_n_f16(t33, 5)), t35);
float16x8_t m40 = vaddq_f16(vsubq_f16(vmulq_n_f16(t40, 4), vmulq_n_f16(t42, 5)), t44);
float16x8_t m41 = vaddq_f16(vaddq_f16(t43, t44), vmulq_n_f16(vaddq_f16(t41, t42), -4));
float16x8_t m42 = vaddq_f16(vsubq_f16(t44, t43), vmulq_n_f16(vsubq_f16(t41, t42), 4));
float16x8_t m43 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t43, t41), 2));
float16x8_t m44 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t41, t43), 2));
float16x8_t m45 = vaddq_f16(vsubq_f16(vmulq_n_f16(t41, 4), vmulq_n_f16(t43, 5)), t45);
float16x8_t m50 = vaddq_f16(vsubq_f16(vmulq_n_f16(t50, 4), vmulq_n_f16(t52, 5)), t54);
float16x8_t m51 = vaddq_f16(vaddq_f16(t53, t54), vmulq_n_f16(vaddq_f16(t51, t52), -4));
float16x8_t m52 = vaddq_f16(vsubq_f16(t54, t53), vmulq_n_f16(vsubq_f16(t51, t52), 4));
float16x8_t m53 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t53, t51), 2));
float16x8_t m54 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t51, t53), 2));
float16x8_t m55 = vaddq_f16(vsubq_f16(vmulq_n_f16(t51, 4), vmulq_n_f16(t53, 5)), t55);
vst1_f16(trans_input_data, vget_low_f16(m00));
vst1_f16(trans_input_data + 64, vget_high_f16(m00));
vst1_f16(trans_input_data + step, vget_low_f16(m01));
vst1_f16(trans_input_data + step + 64, vget_high_f16(m01));
vst1_f16(trans_input_data + 2 * step, vget_low_f16(m02));
vst1_f16(trans_input_data + 2 * step + 64, vget_high_f16(m02));
vst1_f16(trans_input_data + 3 * step, vget_low_f16(m03));
vst1_f16(trans_input_data + 3 * step + 64, vget_high_f16(m03));
vst1_f16(trans_input_data + 4 * step, vget_low_f16(m04));
vst1_f16(trans_input_data + 4 * step + 64, vget_high_f16(m04));
vst1_f16(trans_input_data + 5 * step, vget_low_f16(m05));
vst1_f16(trans_input_data + 5 * step + 64, vget_high_f16(m05));
vst1_f16(trans_input_data + 6 * step, vget_low_f16(m10));
vst1_f16(trans_input_data + 6 * step + 64, vget_high_f16(m10));
vst1_f16(trans_input_data + 7 * step, vget_low_f16(m11));
vst1_f16(trans_input_data + 7 * step + 64, vget_high_f16(m11));
vst1_f16(trans_input_data + 8 * step, vget_low_f16(m12));
vst1_f16(trans_input_data + 8 * step + 64, vget_high_f16(m12));
vst1_f16(trans_input_data + 9 * step, vget_low_f16(m13));
vst1_f16(trans_input_data + 9 * step + 64, vget_high_f16(m13));
vst1_f16(trans_input_data + 10 * step, vget_low_f16(m14));
vst1_f16(trans_input_data + 10 * step + 64, vget_high_f16(m14));
vst1_f16(trans_input_data + 11 * step, vget_low_f16(m15));
vst1_f16(trans_input_data + 11 * step + 64, vget_high_f16(m15));
vst1_f16(trans_input_data + 12 * step, vget_low_f16(m20));
vst1_f16(trans_input_data + 12 * step + 64, vget_high_f16(m20));
vst1_f16(trans_input_data + 13 * step, vget_low_f16(m21));
vst1_f16(trans_input_data + 13 * step + 64, vget_high_f16(m21));
vst1_f16(trans_input_data + 14 * step, vget_low_f16(m22));
vst1_f16(trans_input_data + 14 * step + 64, vget_high_f16(m22));
vst1_f16(trans_input_data + 15 * step, vget_low_f16(m23));
vst1_f16(trans_input_data + 15 * step + 64, vget_high_f16(m23));
vst1_f16(trans_input_data + 16 * step, vget_low_f16(m24));
vst1_f16(trans_input_data + 16 * step + 64, vget_high_f16(m24));
vst1_f16(trans_input_data + 17 * step, vget_low_f16(m25));
vst1_f16(trans_input_data + 17 * step + 64, vget_high_f16(m25));
vst1_f16(trans_input_data + 18 * step, vget_low_f16(m30));
vst1_f16(trans_input_data + 18 * step + 64, vget_high_f16(m30));
vst1_f16(trans_input_data + 19 * step, vget_low_f16(m31));
vst1_f16(trans_input_data + 19 * step + 64, vget_high_f16(m31));
vst1_f16(trans_input_data + 20 * step, vget_low_f16(m32));
vst1_f16(trans_input_data + 20 * step + 64, vget_high_f16(m32));
vst1_f16(trans_input_data + 21 * step, vget_low_f16(m33));
vst1_f16(trans_input_data + 21 * step + 64, vget_high_f16(m33));
vst1_f16(trans_input_data + 22 * step, vget_low_f16(m34));
vst1_f16(trans_input_data + 22 * step + 64, vget_high_f16(m34));
vst1_f16(trans_input_data + 23 * step, vget_low_f16(m35));
vst1_f16(trans_input_data + 23 * step + 64, vget_high_f16(m35));
vst1_f16(trans_input_data + 24 * step, vget_low_f16(m40));
vst1_f16(trans_input_data + 24 * step + 64, vget_high_f16(m40));
vst1_f16(trans_input_data + 25 * step, vget_low_f16(m41));
vst1_f16(trans_input_data + 25 * step + 64, vget_high_f16(m41));
vst1_f16(trans_input_data + 26 * step, vget_low_f16(m42));
vst1_f16(trans_input_data + 26 * step + 64, vget_high_f16(m42));
vst1_f16(trans_input_data + 27 * step, vget_low_f16(m43));
vst1_f16(trans_input_data + 27 * step + 64, vget_high_f16(m43));
vst1_f16(trans_input_data + 28 * step, vget_low_f16(m44));
vst1_f16(trans_input_data + 28 * step + 64, vget_high_f16(m44));
vst1_f16(trans_input_data + 29 * step, vget_low_f16(m45));
vst1_f16(trans_input_data + 29 * step + 64, vget_high_f16(m45));
vst1_f16(trans_input_data + 30 * step, vget_low_f16(m50));
vst1_f16(trans_input_data + 30 * step + 64, vget_high_f16(m50));
vst1_f16(trans_input_data + 31 * step, vget_low_f16(m51));
vst1_f16(trans_input_data + 31 * step + 64, vget_high_f16(m51));
vst1_f16(trans_input_data + 32 * step, vget_low_f16(m52));
vst1_f16(trans_input_data + 32 * step + 64, vget_high_f16(m52));
vst1_f16(trans_input_data + 33 * step, vget_low_f16(m53));
vst1_f16(trans_input_data + 33 * step + 64, vget_high_f16(m53));
vst1_f16(trans_input_data + 34 * step, vget_low_f16(m54));
vst1_f16(trans_input_data + 34 * step + 64, vget_high_f16(m54));
vst1_f16(trans_input_data + 35 * step, vget_low_f16(m55));
vst1_f16(trans_input_data + 35 * step + 64, vget_high_f16(m55));
}
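Reading the coefficients back from the vector arithmetic (the t-stage computes B^T d, the m-stage multiplies by B on the right), this is the standard Winograd F(4,3) input transform, tile = B^\top d\,B with
B^\top = \begin{pmatrix}
 4 &  0 & -5 &  0 & 1 & 0 \\
 0 & -4 & -4 &  1 & 1 & 0 \\
 0 &  4 & -4 & -1 & 1 & 0 \\
 0 & -2 & -1 &  2 & 1 & 0 \\
 0 &  2 & -1 & -2 & 1 & 0 \\
 0 &  4 &  0 & -5 & 0 & 1
\end{pmatrix}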
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
// input data format: NHWC
const int output_unit = 4;
int input_channel = conv_param->input_channel_;
int input_width = conv_param->input_w_;
int input_height = conv_param->input_h_;
int pad_w = conv_param->pad_l_;
int pad_h = conv_param->pad_u_;
int ic8 = UP_DIV(input_channel, C8NUM);
if (out_w_block == 0) {
return;
}
for (int cal_id = 0; cal_id < real_cal_num; cal_id++) {
int x_id = start_index + cal_id;
int origin_x = (x_id % out_w_block) * output_unit - pad_w;
int origin_y = (x_id / out_w_block) * output_unit - pad_h;
int real_x_start = origin_x > 0 ? 0 : -origin_x;
int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x);
int real_y_start = origin_y > 0 ? 0 : -origin_y;
int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y);
int src_plane_offset = ic8 * C8NUM * (origin_y * input_width + origin_x);
int dst_plane_offset = cal_id * C4NUM;
for (int ic = 0; ic < ic8; ic++) {
// clear tmp buffer
memset(tmp_data, 0, 6 * 6 * C8NUM * sizeof(float16_t));
// get real input block with padding
int src_ic4_offset = src_plane_offset + ic * C8NUM;
for (int interval = real_y_start; interval < real_y_end; interval++) {
int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic8 * C8NUM;
int dst_y_offset = interval * 6 * C8NUM + real_x_start * C8NUM;
for (int j = 0; j < (real_x_end - real_x_start); j++) {
int src_x_offset = src_y_offset + j * ic8 * C8NUM;
int dst_x_offset = dst_y_offset + j * C8NUM;
float16_t *src_addr = (float16_t *)(input_data) + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
vst1q_f16(dst_addr, vld1q_f16(src_addr));
}
}
// input transform
int dst_ic4_offset = dst_plane_offset + ic * 16 * C8NUM;
size_t dst_step = ic8 * C8NUM * 16;
float16_t *trans_input_ptr = trans_input + dst_ic4_offset;
Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step);
}
}
}
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel,
int kernel_plane) {
int dst_step = iC4 * C4NUM * 8;
for (int o = 0; o < output_channel; o++) {
int oc8_block_num = o / C8NUM;
int oc8_block_rem = o % C8NUM;
int src_oc_offset = o * iC4 * C4NUM * kernel_plane;
int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem;
for (int i = 0; i < iC4; i++) {
const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM;
float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM;
float16x4_t g00 = vld1_f16(src_ic4_ptr);
float16x4_t g01 = vld1_f16(src_ic4_ptr + 4);
float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4);
float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4);
float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4);
float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4);
float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4);
float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4);
float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4);
float16x4_t dst00 = vmul_n_f16(g00, 0.25);
float16x4_t dst01 = vmul_n_f16(g01, 0.25);
float16x4_t dst02 = vmul_n_f16(g02, 0.25);
float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667);
float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667);
float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667);
float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667);
float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667);
float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667);
float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333),
vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)));
float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333),
vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)));
float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333),
vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)));
float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)),
vmul_n_f16(g10, 0.08333333333333));
float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667)),
vmul_n_f16(g11, 0.08333333333333));
float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)),
vmul_n_f16(g12, 0.08333333333333));
float16x4_t dst50 = g20;
float16x4_t dst51 = g21;
float16x4_t dst52 = g22;
float16x4_t m00 = vmul_n_f16(dst00, 0.25);
float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667);
float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667);
float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333),
vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)));
float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)),
vmul_n_f16(dst01, 0.08333333333333));
float16x4_t m05 = dst02;
float16x4_t m10 = vmul_n_f16(dst10, 0.25);
float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667);
float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667);
float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333),
vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)));
float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)),
vmul_n_f16(dst11, 0.08333333333333));
float16x4_t m15 = dst12;
float16x4_t m20 = vmul_n_f16(dst20, 0.25);
float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667);
float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667);
float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333),
vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)));
float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)),
vmul_n_f16(dst21, 0.08333333333333));
float16x4_t m25 = dst22;
float16x4_t m30 = vmul_n_f16(dst30, 0.25);
float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667);
float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667);
float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333),
vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)));
float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)),
vmul_n_f16(dst31, 0.08333333333333));
float16x4_t m35 = dst32;
float16x4_t m40 = vmul_n_f16(dst40, 0.25);
float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667);
float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667);
float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333),
vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)));
float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)),
vmul_n_f16(dst41, 0.08333333333333));
float16x4_t m45 = dst42;
float16x4_t m50 = vmul_n_f16(dst50, 0.25);
float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667);
float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667);
float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333),
vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)));
float16x4_t m54 = vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)),
vmul_n_f16(dst51, 0.08333333333333));
float16x4_t m55 = dst52;
for (int j = 0; j < 4; j++) {
dst_ic4_ptr[j * 8] = m00[j];
dst_ic4_ptr[j * 8 + dst_step] = m01[j];
dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j];
dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j];
dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j];
dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j];
dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j];
dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j];
dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j];
dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j];
dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j];
dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j];
dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j];
dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j];
dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j];
dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j];
dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j];
dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j];
dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j];
dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j];
dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j];
dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j];
dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j];
dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j];
dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j];
dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j];
dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j];
dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j];
dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j];
dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j];
dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j];
dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j];
dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j];
dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j];
dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j];
dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j];
}
}
}
}
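Similarly, the constants 0.25, 0.1666..., 0.0833..., 0.0416... above are 1/4, 1/6, 1/12, and 1/24: the rows of the F(4,3) filter-transform matrix G, applied to each 3x3 filter slice as G\,g\,G^\top with
G = \begin{pmatrix}
 1/4  &  0    &  0   \\
-1/6  & -1/6  & -1/6 \\
-1/6  &  1/6  & -1/6 \\
 1/24 &  1/12 &  1/6 \\
 1/24 & -1/12 &  1/6 \\
 0    &  0    &  1
\end{pmatrix}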
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data,
int output_w) {
float16x8_t s00 = vld1q_f16(gemm_out);
float16x8_t s01 = vld1q_f16(gemm_out + 8);
float16x8_t s02 = vld1q_f16(gemm_out + 16);
float16x8_t s03 = vld1q_f16(gemm_out + 24);
float16x8_t s04 = vld1q_f16(gemm_out + 32);
float16x8_t s05 = vld1q_f16(gemm_out + 40);
float16x8_t s10 = vld1q_f16(gemm_out + 48);
float16x8_t s11 = vld1q_f16(gemm_out + 56);
float16x8_t s12 = vld1q_f16(gemm_out + 64);
float16x8_t s13 = vld1q_f16(gemm_out + 72);
float16x8_t s14 = vld1q_f16(gemm_out + 80);
float16x8_t s15 = vld1q_f16(gemm_out + 88);
float16x8_t s20 = vld1q_f16(gemm_out + 96);
float16x8_t s21 = vld1q_f16(gemm_out + 104);
float16x8_t s22 = vld1q_f16(gemm_out + 112);
float16x8_t s23 = vld1q_f16(gemm_out + 120);
float16x8_t s24 = vld1q_f16(gemm_out + 128);
float16x8_t s25 = vld1q_f16(gemm_out + 136);
float16x8_t s30 = vld1q_f16(gemm_out + 144);
float16x8_t s31 = vld1q_f16(gemm_out + 152);
float16x8_t s32 = vld1q_f16(gemm_out + 160);
float16x8_t s33 = vld1q_f16(gemm_out + 168);
float16x8_t s34 = vld1q_f16(gemm_out + 176);
float16x8_t s35 = vld1q_f16(gemm_out + 184);
float16x8_t s40 = vld1q_f16(gemm_out + 192);
float16x8_t s41 = vld1q_f16(gemm_out + 200);
float16x8_t s42 = vld1q_f16(gemm_out + 208);
float16x8_t s43 = vld1q_f16(gemm_out + 216);
float16x8_t s44 = vld1q_f16(gemm_out + 224);
float16x8_t s45 = vld1q_f16(gemm_out + 232);
float16x8_t s50 = vld1q_f16(gemm_out + 240);
float16x8_t s51 = vld1q_f16(gemm_out + 248);
float16x8_t s52 = vld1q_f16(gemm_out + 256);
float16x8_t s53 = vld1q_f16(gemm_out + 264);
float16x8_t s54 = vld1q_f16(gemm_out + 272);
float16x8_t s55 = vld1q_f16(gemm_out + 280);
float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40);
float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41);
float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42);
float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43);
float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44);
float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45);
float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2));
float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2));
float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2));
float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2));
float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2));
float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2));
float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4));
float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4));
float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4));
float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4));
float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4));
float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4));
float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50);
float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51);
float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52);
float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53);
float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54);
float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55);
float16x8_t bias_ptr = vld1q_f16(bias_data);
float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04), bias_ptr);
float16x8_t d01 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)), bias_ptr);
float16x8_t d02 = vaddq_f16(vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)), bias_ptr);
float16x8_t d03 =
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05), bias_ptr);
float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14), bias_ptr);
float16x8_t d11 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)), bias_ptr);
float16x8_t d12 = vaddq_f16(vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)), bias_ptr);
float16x8_t d13 =
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15), bias_ptr);
float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24), bias_ptr);
float16x8_t d21 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)), bias_ptr);
float16x8_t d22 = vaddq_f16(vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)), bias_ptr);
float16x8_t d23 =
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25), bias_ptr);
float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34), bias_ptr);
float16x8_t d31 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)), bias_ptr);
float16x8_t d32 = vaddq_f16(vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)), bias_ptr);
float16x8_t d33 =
vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35), bias_ptr);
vst1q_f16(output_data, d00);
vst1q_f16(output_data + 8, d01);
vst1q_f16(output_data + 16, d02);
vst1q_f16(output_data + 24, d03);
vst1q_f16(output_data + output_w * 8, d10);
vst1q_f16(output_data + output_w * 8 + 8, d11);
vst1q_f16(output_data + output_w * 8 + 16, d12);
vst1q_f16(output_data + output_w * 8 + 24, d13);
vst1q_f16(output_data + 2 * output_w * 8, d20);
vst1q_f16(output_data + 2 * output_w * 8 + 8, d21);
vst1q_f16(output_data + 2 * output_w * 8 + 16, d22);
vst1q_f16(output_data + 2 * output_w * 8 + 24, d23);
vst1q_f16(output_data + 3 * output_w * 8, d30);
vst1q_f16(output_data + 3 * output_w * 8 + 8, d31);
vst1q_f16(output_data + 3 * output_w * 8 + 16, d32);
vst1q_f16(output_data + 3 * output_w * 8 + 24, d33);
}
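The output stage inverts the tiling: with m the 6x6 per-channel GEMM result, it computes the 4x4 output block Y = A^\top m\,A + b (bias added to every element), where
A^\top = \begin{pmatrix}
1 & 1 &  1 & 1 &  1 & 0 \\
0 & 1 & -1 & 2 & -2 & 0 \\
0 & 1 &  1 & 4 &  4 & 0 \\
0 & 1 & -1 & 8 & -8 & 1
\end{pmatrix}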
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) {
int output_channel = conv_param->output_channel_;
int output_h = conv_param->output_h_;
int out_h_block = UP_DIV(output_h, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
if (out_w_block == 0) {
return;
}
for (int i = 0; i < real_cal_num; i++) {
int out_w_index = (start_index + i) % out_w_block;
int out_h_index = (start_index + i) / out_w_block;
int src_tile_offset = i * oc8 * C8NUM * 36;
int dst_tile_offset = C8NUM * (out_w_index * C4NUM + out_h_index * C4NUM * out_w_block * C4NUM);
for (int j = 0; j < oc8; j++) {
int src_oc8_offset = src_tile_offset + j * 36 * C8NUM;
int dst_oc8_offset = dst_tile_offset + j * C8NUM * out_h_block * out_w_block * C4NUM * C4NUM;
const float16_t *src_ptr = gemm_out + src_oc8_offset;
const float16_t *bias_ptr = bias_data + j * C8NUM;
float16_t *dst_ptr = out_data + dst_oc8_offset;
// output transform
Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, out_w_block * C4NUM);
}
}
}
// fp16 common winograd
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,
InputTransFp16Func func) {
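  // Note: the tile count matches the matmul row-tile width used elsewhere in
  // this commit: 16 (Col16 packing) on arm64, 12 (Col12 packing, 12x8 kernel) on a32.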
#ifdef ENABLE_ARM64
const int tile_num = 16;
#else
const int tile_num = 12;
#endif
int input_unit = conv_param->input_unit_;
int output_unit = conv_param->output_unit_;
int in_channel = conv_param->input_channel_;

View File

@ -23,25 +23,11 @@
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/matrix_fp16.h"
#include "nnacl/fp16/pack_fp16.h"
#ifdef __cplusplus
extern "C" {
#endif
// for fp16 convolution 3x3 filter/input/output transform
void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step);
void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel,
int kernel_plane);
void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w);
void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data,
int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param);
// fp16 common winograd
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param,

View File

@ -17,9 +17,9 @@
#ifndef MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_
#define MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_
#include <arm_neon.h>
#include "nnacl/conv_parameter.h"
#include "nnacl/op_base.h"
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"
#define MAX_LEN 256

View File

@ -17,9 +17,11 @@
#ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_
#include <math.h>
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#if defined(ENABLE_SSE) || defined(ENABLE_AVX)
#include <x86intrin.h>
#endif
@ -46,7 +48,7 @@
#ifdef ENABLE_ARM64
#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
#else
inline static float32x4_t vrecp(float32x4_t v) {
static inline float32x4_t vrecp(float32x4_t v) {
float32x4_t r = vrecpeq_f32(v);
r = vmulq_f32(vrecpsq_f32(v, r), r);
r = vmulq_f32(vrecpsq_f32(v, r), r);
@ -205,25 +207,4 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) {
return dst;
}
#ifdef ENABLE_ARM64
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src));
float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src));
return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high)));
}
static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
float16x8_t dst;
dst[0] = erff(src[0]);
dst[1] = erff(src[1]);
dst[2] = erff(src[2]);
dst[3] = erff(src[3]);
dst[4] = erff(src[4]);
dst[5] = erff(src[5]);
dst[6] = erff(src[6]);
dst[7] = erff(src[7]);
return dst;
}
#endif
#endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_

View File

@ -0,0 +1,99 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
#define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
#include <math.h>
#include "nnacl/intrinsics/ms_simd_instructions.h"
#if defined(ENABLE_ARM82_A32)
static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) {
float16x8_t dst;
asm volatile(
"vrecpe.f16 q14, %3\n"
"vrecps.f16 q15, %3, q14\n"
"vmul.f16 q14, q15, q14\n"
"vrecps.f16 q15, %3, q14\n"
"vmul.f16 q14, q15, q14\n"
"vmul.f16 %0, %2, q14\n"
: "=w"(dst)
: "0"(dst), "w"(in1), "w"(in2)
: "q14", "q15");
return dst;
}
static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) {
float16x4_t dst;
asm volatile(
"vrecpe.f16 d14, %3\n"
"vrecps.f16 d16, %3, d14\n"
"vmul.f16 d14, d16, d14\n"
"vrecps.f16 d16, %3, d14\n"
"vmul.f16 d14, d16, d14\n"
"vmul.f16 %0, %2, d14\n"
: "=w"(dst)
: "0"(dst), "w"(in1), "w"(in2)
: "d14", "d16");
return dst;
}
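vrecpe.f16 only produces a low-precision reciprocal estimate, so each vrecps/vmul pair above is one Newton-Raphson refinement (vrecps computes exactly the 2 - x*r factor); after the two iterations the quotient is formed as a product:
r_{n+1} = r_n\,(2 - x\,r_n), \qquad \frac{a}{x} \approx a \cdot r_2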
static inline float vaddvq_f32(float32x4_t in) {  // vaddvq_f32 is not supported on arm82 aarch32
return in[0] + in[1] + in[2] + in[3];
}
static inline float32x4_t cvt_f32_f16(float16x4_t in) {
float32x4_t dst;
asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :);
return dst;
}
static inline float16x4_t cvt_f16_f32(float32x4_t in) {
float16x4_t dst;
asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :);
return dst;
}
#define MS_CVT_F32_F16(src) cvt_f32_f16(src)
#define MS_CVT_F16_F32(src) cvt_f16_f32(src)
#define MS_DIV_F16(src1, src2) div_f16(src1, src2)
#define MS_DIVQ_F16(src1, src2) divq_f16(src1, src2)
#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3))
#else
#define MS_CVT_F32_F16(src) vcvt_f32_f16(src)
#define MS_CVT_F16_F32(src) vcvt_f16_f32(src)
#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2)
#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2)
#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3)
#endif
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src));
float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src));
return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high)));
}
static inline float16x8_t MS_ERFX8_F16(float16x8_t src) {
float16x8_t dst;
dst[0] = erff(src[0]);
dst[1] = erff(src[1]);
dst[2] = erff(src[2]);
dst[3] = erff(src[3]);
dst[4] = erff(src[4]);
dst[5] = erff(src[5]);
dst[6] = erff(src[6]);
dst[7] = erff(src[7]);
return dst;
}
#endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
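These macros let fp16 kernels stay intrinsic-agnostic: aarch64 builds resolve `MS_DIVQ_F16` to `vdivq_f16`, while arm82_a32 builds fall back to the inline-assembly reciprocal sequence above. A hedged usage sketch (the element-wise kernel itself is illustrative, not from the source; assumes an ARM toolchain with fp16 NEON and `count` a multiple of 8):

```cpp
#include <arm_neon.h>
#include "nnacl/intrinsics/ms_simd_instructions_fp16.h"

// Illustrative element-wise divide over fp16 buffers using the dispatch
// macro: vdivq_f16 on aarch64, the asm reciprocal fallback on a32.
void ElementDivFp16(const float16_t *a, const float16_t *b, float16_t *out, int count) {
  for (int i = 0; i < count; i += 8) {
    float16x8_t va = vld1q_f16(a + i);
    float16x8_t vb = vld1q_f16(b + i);
    vst1q_f16(out + i, MS_DIVQ_F16(va, vb));
  }
}
```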

View File

@ -4,11 +4,15 @@ set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
include_directories(${NNACL_DIR})
########################### optimized files ###########################
file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S)
file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c)
file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S)
if(PLATFORM_ARM32)
file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S)
else()
file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S)
file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S)
set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C)
endif()
set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C)
set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C)
set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C)
@ -17,7 +21,6 @@ if(APPLE)
set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp")
endif()
########################### share library build ########################
list(APPEND SDOT_FILES ${SDOT_SRC})
list(APPEND FP16_FILES ${FP16_C_SRC})
list(APPEND FP16_FILES ${FP16_NEON_SRC})
@ -27,13 +30,20 @@ if(SUPPORT_TRAIN)
endif()
string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES})
add_dependencies(nnacl_optimize_mid fbs_src)
if(NOT PLATFORM_ARM32)
list(APPEND SDOT_FILES ${SDOT_SRC})
add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES})
add_dependencies(nnacl_optimize_mid fbs_src)
endif()
if(ENABLE_FP16)
add_library(nnacl_fp16_mid OBJECT ${FP16_FILES})
if(PLATFORM_ARM32)
target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp)
else()
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16")
endif()
add_dependencies(nnacl_fp16_mid fbs_src)
endif()

View File

@ -5,6 +5,12 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_L
message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0")
endif()
if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0)
set(ENABLE_FP16 "off")
message(WARNING "If you want to build fp16 in arm82_a32, \
your Clang version:[${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0 and please use android nkd r21e!")
endif()
option(MS_VERSION_MAJOR "major version" 0)
option(MS_VERSION_MINOR "minor version" 7)
option(MS_VERSION_REVISION "revision version" 0)
@ -15,6 +21,7 @@ option(PLATFORM_ARM64 "if build device for arm64" off)
option(PLATFORM_ARM32 "if build device for arm32" off)
option(ENABLE_CONVERTER "if build converter" on)
option(ENABLE_FP16 "if build fp16 ops" off)
option(ENABLE_ARM82_A32 "if build fp16 on platform_arm32" off)
option(ENABLE_TOOLS "if build tools" on)
option(BUILD_TESTCASES "if build testcase" on)
option(SUPPORT_GPU "if support gpu" off)
@ -177,6 +184,9 @@ if(ENABLE_NEON)
endif()
if(ENABLE_FP16)
add_compile_definitions(ENABLE_FP16)
if(PLATFORM_ARM32)
add_compile_definitions(ENABLE_ARM82_A32)
endif()
endif()
if(SUPPORT_GPU STREQUAL opencl)
add_definitions(-DGPU_OPENCL)
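With `ENABLE_FP16` now implying `ENABLE_ARM82_A32` on 32-bit targets, source files can branch on this pair of definitions instead of on `ENABLE_ARM64`. A sketch of the guard pattern this commit standardizes on (the function body is illustrative only):

```cpp
#ifdef ENABLE_FP16
#include <arm_neon.h>

float16_t HalfScale(float16_t x) {
#ifdef ENABLE_ARM82_A32
  // 32-bit armv8.2 path: stick to a32-safe scalar fp16 arithmetic.
  return x * (float16_t)0.5f;
#else
  // aarch64 path: the full fp16 intrinsic set is available.
  return vget_lane_f16(vmul_f16(vdup_n_f16(x), vdup_n_f16((float16_t)0.5f)), 0);
#endif
}
#endif  // ENABLE_FP16
```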

View File

@ -3,6 +3,9 @@ if(ENABLE_V0)
add_definitions(-DENABLE_V0)
endif()
include_directories(${CCSRC_DIR}/backend/kernel_compiler/cpu)
set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
include_directories(${LITE_DIR}/nnacl/)
include_directories(${LITE_DIR}/nnacl/optimize)
if(PLATFORM_ARM32 OR PLATFORM_ARM64)
#for performance
@ -209,9 +212,11 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
endif()
########################## build optimize and float16 library #################################
if(PLATFORM_ARM64)
target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid)
target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid)
if(PLATFORM_ARM)
if(PLATFORM_ARM64)
target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid)
target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid)
endif()
if(ENABLE_FP16)
target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid)
target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid)
@ -247,8 +252,10 @@ if(DEFINED ARCHS)
target_link_libraries(mindspore_lite mindrt_mid)
endif()
if(PLATFORM_ARM64)
target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid)
if(PLATFORM_ARM)
if(PLATFORM_ARM64)
target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid)
endif()
if(ENABLE_FP16)
target_link_libraries(mindspore_lite cpu_fp16_kernel_mid nnacl_fp16_mid)
endif()

View File

@ -155,7 +155,11 @@ bool IsSupportSDot() {
bool IsSupportFloat16() {
bool status = false;
#ifdef ENABLE_ARM64
#ifdef ENABLE_ARM32
status = true;
#endif
#if defined(ENABLE_ARM64)
#if defined(__ANDROID__)
int hwcap_type = 16;
uint32_t hwcap = getHwCap(hwcap_type);
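On aarch64 Android, the check reads the hardware capability bits from the auxiliary vector (`hwcap_type = 16` is `AT_HWCAP`). A minimal sketch of that path, assuming Linux/Android headers, where `HWCAP_FPHP`/`HWCAP_ASIMDHP` flag scalar and SIMD half-precision support:

```cpp
#include <sys/auxv.h>   // getauxval, AT_HWCAP
#include <asm/hwcap.h>  // HWCAP_FPHP, HWCAP_ASIMDHP on aarch64

bool DeviceSupportsFp16() {
#if defined(__aarch64__)
  unsigned long hwcap = getauxval(AT_HWCAP);
  // Require both scalar (FPHP) and Advanced SIMD (ASIMDHP) half precision.
  return (hwcap & HWCAP_FPHP) != 0 && (hwcap & HWCAP_ASIMDHP) != 0;
#else
  return false;  // 32-bit builds rely on the compile-time ENABLE_ARM32 path above
#endif
}
```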

View File

@ -44,7 +44,7 @@ uint64_t GetTimeUs();
bool IsSupportSDot();
bool IsSupportFloat16();
#if defined(__arm__) || defined(__aarch64__)
#if defined(__arm__)
uint32_t getHwCap(int hwcap_type);
#endif

View File

@ -19,7 +19,7 @@
#include "src/common/version_manager.h"
#include "nnacl/pooling_parameter.h"
#include "src/ios_reg_kernels.h"
#ifdef ENABLE_ARM64
#if defined(ENABLE_FP16) && defined(ENABLE_ARM)
#if defined(__ANDROID__)
#include <asm/hwcap.h>
#endif
@ -55,6 +55,8 @@ int KernelRegistry::Init() {
} else {
MS_LOG(INFO) << "The current device NOT supports Sdot.";
}
#endif
#ifdef ENABLE_FP16
if (mindspore::lite::IsSupportFloat16()) {
MS_LOG(INFO) << "The current device supports float16.";
} else {

View File

@ -17,7 +17,7 @@ endif()
add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC})
add_dependencies(cpu_kernel_mid fbs_src)
if(PLATFORM_ARM64)
if(PLATFORM_ARM)
if(ENABLE_FP16)
file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc)
if(SUPPORT_TRAIN)

View File

@ -52,7 +52,7 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) {
ConstantOfShapeInt32(reinterpret_cast<int32_t *>(output_ptr_), start, start + current_stride,
param_->value_.int32_value_);
break;
#ifdef ENABLE_NEON
#ifdef ENABLE_FP16
case kNumberTypeFloat16:
ConstantOfShapeFp16(reinterpret_cast<float16_t *>(output_ptr_), start, start + current_stride,
param_->value_.f32_value_);

View File

@ -31,8 +31,8 @@ int Convolution1x1FP16CPUKernel::InitMatmulParam() {
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
matmul_param_->col_ = conv_param_->output_channel_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_);
matmul_param_->col_align_ = UP_ROUND(matmul_param_->col_, col_tile_);
matmul_param_->act_type_ = conv_param_->act_type_;
return RET_OK;
}
@ -54,14 +54,14 @@ int Convolution1x1FP16CPUKernel::InitConv1x1Param() {
pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
conv_param_->stride_w_ != 1);
if ((matmul_param_->row_ > (C16NUM * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
if ((matmul_param_->row_ > (row_tile_ * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
multi_thread_by_hw_ = true;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C16NUM));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, C16NUM), thread_count_) * C16NUM;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_tile_));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_tile_), thread_count_) * row_tile_;
} else {
multi_thread_by_hw_ = false;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile_));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile_), thread_count_) * col_tile_;
}
if (pre_trans_input_) {
@ -81,8 +81,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
auto output_channel = weight_tensor->Batch();
if (in_tensors_.size() == 3) {
size_t size = UP_ROUND(output_channel, C8NUM) * sizeof(float16_t);
size_t weight_size = output_channel * sizeof(float16_t);
size_t size = UP_ROUND(output_channel, col_tile_) * sizeof(float16_t);
size_t bias_size = output_channel * sizeof(float16_t);
bias_data_ = malloc(size);
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!";
@ -94,11 +94,11 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_),
output_channel);
}
memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
memset(reinterpret_cast<char *>(bias_data_) + bias_size, 0, size - bias_size);
}
size_t size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float16_t);
size_t down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float16_t);
size_t size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float16_t);
size_t down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float16_t);
weight_ptr_ = reinterpret_cast<float16_t *>(malloc(size));
if (weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
@ -111,6 +111,12 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() {
}
int Convolution1x1FP16CPUKernel::Init() {
col_tile_ = C8NUM;
#ifdef ENABLE_ARM64
row_tile_ = C16NUM;
#else
row_tile_ = C12NUM;
#endif
matmul_param_ = new (std::nothrow) MatMulParameter();
if (matmul_param_ == nullptr) {
MS_LOG(ERROR) << "Init matmul_param_ failed.";
@ -177,8 +183,11 @@ int Convolution1x1FP16CPUKernel::RunHw(int task_id) {
float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_;
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#else
RowMajor2Col12MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#endif
float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_;
MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_),
matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_,
@ -211,7 +220,7 @@ int Convolution1x1FP16CPUKernel::Run() {
ConvolutionBaseFP16CPUKernel::GetExecuteTensor();
pack_input_ = reinterpret_cast<float16_t *>(
ctx_->allocator->Malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t)));
ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float16_t)));
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
return RET_MEMORY_FAILED;
@ -231,7 +240,11 @@ int Convolution1x1FP16CPUKernel::Run() {
if (multi_thread_by_hw_) {
ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunHw, this, thread_count_);
} else {
#ifdef ENABLE_ARM64
RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#endif
ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunOc, this, thread_count_);
}
if (ret != RET_OK) {
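The 1x1 convolution runs as a GEMM with `row = out_h * out_w`, `col = out_channel`, `deep = in_channel`; rows are padded to the per-architecture tile (16 on aarch64, 12 on arm82_a32) and columns to 8, which is why `row_16_` became the tile-agnostic `row_align_`. A sketch of the alignment and thread-splitting arithmetic, using the `UP_DIV`/`UP_ROUND` conventions from nnacl (`MSMIN` is a min macro):

```cpp
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

#ifdef ENABLE_ARM64
constexpr int kRowTile = 16;  // C16NUM
#else
constexpr int kRowTile = 12;  // C12NUM on arm82_a32
#endif
constexpr int kColTile = 8;   // C8NUM

// e.g. row = 50 (out_h * out_w), 4 threads, a32 tile of 12:
// row_align = UP_ROUND(50, 12) = 60, tiles = UP_DIV(50, 12) = 5,
// thread_count = 4, thread_stride = UP_DIV(5, 4) * 12 = 24 rows per thread.
void SplitByRow(int row, int threads, int *thread_count, int *thread_stride) {
  *thread_count = MSMIN(threads, UP_DIV(row, kRowTile));
  *thread_stride = UP_DIV(UP_DIV(row, kRowTile), *thread_count) * kRowTile;
}
```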

View File

@ -62,6 +62,8 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
float16_t *pack_input_ = nullptr;
float16_t *output_ptr_ = nullptr;
MatMulParameter *matmul_param_ = nullptr;
int col_tile_;
int row_tile_;
};
} // namespace mindspore::kernel

View File

@ -34,7 +34,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
int out_channel = filter_tensor->Batch();
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int oc8 = UP_ROUND(out_channel, C8NUM);
int oc8 = UP_ROUND(out_channel, col_tile_);
int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
int pack_weight_size = oc8 * in_channel * kernel_plane;
@ -68,9 +68,8 @@ int ConvolutionFP16CPUKernel::InitWeightBias() {
}
int ConvolutionFP16CPUKernel::InitTmpBuffer() {
const int cal_num = 16;
int unit_size =
conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * cal_num * thread_count_;
conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_;
packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t)));
if (packed_input_ == nullptr) {
@ -87,6 +86,12 @@ int ConvolutionFP16CPUKernel::InitTmpBuffer() {
}
int ConvolutionFP16CPUKernel::Init() {
#ifdef ENABLE_ARM64
row_tile_ = C16NUM;
#else
row_tile_ = C12NUM;
#endif
col_tile_ = C8NUM;
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
@ -98,7 +103,7 @@ int ConvolutionFP16CPUKernel::Init() {
void ConvolutionFP16CPUKernel::AdjustNumberOfThread() {
auto out_tensor = out_tensors_.front();
int out_plane = out_tensor->Height() * out_tensor->Width();
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM));
thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, row_tile_));
conv_param_->thread_num_ = thread_count_;
}
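The im2col packing buffer is sized per thread from the same tile constant, so the a32 build now allocates 12-row tiles instead of 16. A sketch of the sizing arithmetic (names mirror the kernel but are illustrative):

```cpp
#include <cstddef>

// Per-thread im2col scratch: one row-tile of input patches per thread.
// unit = kh * kw * ic * row_tile * threads  (elements, not bytes)
size_t PackedInputElems(int kernel_h, int kernel_w, int in_channel,
                        int row_tile, int thread_count) {
  return (size_t)kernel_h * kernel_w * in_channel * row_tile * thread_count;
}

// e.g. 3x3 kernel, 32 channels, row_tile 12, 4 threads
// -> 3*3*32*12*4 = 13824 float16 elements (27648 bytes).
```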

View File

@ -62,6 +62,8 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
float16_t *packed_input_ = nullptr;
float16_t *packed_weight_ = nullptr;
float16_t *col_major_input_ = nullptr;
int col_tile_;
int row_tile_;
};
} // namespace mindspore::kernel

View File

@ -38,13 +38,10 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
int out_channel = filter_tensor->Batch();
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
const int oc_block = C8NUM;
int oc_block_num = UP_DIV(out_channel, C8NUM);
int oc_block_num = UP_DIV(out_channel, col_tile_);
// init weight
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t);
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * col_tile_ * sizeof(float16_t);
trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_weight_ failed.";
@ -73,7 +70,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
MS_LOG(ERROR) << "get execute filter failed.";
return ret;
}
ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block);
ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, col_tile_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "winograd filter transform failed.";
return ret;
@ -85,12 +82,12 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
}
// init bias
bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t));
bias_data_ = malloc(oc_block_num * col_tile_ * sizeof(float16_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_data_ failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t));
memset(bias_data_, 0, oc_block_num * col_tile_ * sizeof(float16_t));
if (in_tensors_.size() == kInputSize2) {
if (origin_bias_data_type_ == kNumberTypeFloat16) {
@ -105,11 +102,9 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
}
int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
const int cal_num = 16;
int channel_out = conv_param_->output_channel_;
size_t tile_buffer_size =
thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
thread_count_ * row_tile_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
@ -117,7 +112,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
}
gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(
thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t)));
thread_count_ * row_tile_ * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t)));
if (gemm_out_ == nullptr) {
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
return RET_ERROR;
@ -131,7 +126,7 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
}
col_buffer_ = reinterpret_cast<float16_t *>(
ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t)));
ctx_->allocator->Malloc(thread_count_ * row_tile_ * conv_param_->input_channel_ * sizeof(float16_t)));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;
@ -159,6 +154,12 @@ int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() {
}
int ConvolutionWinogradFP16CPUKernel::Init() {
col_tile_ = C8NUM;
#ifdef ENABLE_ARM64
row_tile_ = C16NUM;
#else
row_tile_ = C12NUM;
#endif
kernel_unit_ = conv_param_->kernel_h_;
input_unit_ = output_unit_ + kernel_unit_ - 1;
conv_param_->input_unit_ = input_unit_;

View File

@ -86,6 +86,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
InputTransFp16Func in_func_;
OutputTransFp16Func out_func_;
int col_tile_;
int row_tile_;
};
} // namespace mindspore::kernel

View File

@ -13,26 +13,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef ENABLE_ARM64
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#include "nnacl/fp16/cast_fp16.h"
#ifdef __cplusplus
extern "C" {
#endif
#ifdef ENABLE_ARM64
extern void Float32ToFloat16(const float *input, float16_t *output, int number);
extern void Float16ToFloat32(const float16_t *input, float *output, int number);
inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) {
static inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) {
Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number);
}
inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) {
static inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) {
Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number);
}
#endif
#ifdef __cplusplus
}
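`Float32ToFloat16`/`Float16ToFloat32` are implemented in assembly; a C-level equivalent using the NEON convert instructions looks roughly like this (a sketch, not the shipped routine):

```cpp
#include <arm_neon.h>

// Vectorized 4-at-a-time conversion with a scalar tail, matching the
// contract of the assembly routine: convert `number` fp32 values to fp16.
void Float32ToFloat16Sketch(const float *input, float16_t *output, int number) {
  int i = 0;
  for (; i <= number - 4; i += 4) {
    vst1_f16(output + i, vcvt_f16_f32(vld1q_f32(input + i)));
  }
  for (; i < number; ++i) {
    output[i] = (float16_t)input[i];
  }
}
```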

View File

@ -53,8 +53,8 @@ int PoolingFp16CPUKernel::ReSize() {
}
int PoolingFp16CPUKernel::RunImpl(int task_id) {
float16_t minf = -FLT_MAX;
float16_t maxf = FLT_MAX;
float16_t minf = -FLT16_MAX;
float16_t maxf = FLT16_MAX;
if (pooling_param_->act_type_ == ActType_Relu) {
minf = 0.f;
} else if (pooling_param_->act_type_ == ActType_Relu6) {
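The switch from `FLT_MAX` to `FLT16_MAX` is the actual overflow fix: `FLT_MAX` (~3.4e38) is not representable in half precision, so initializing a `float16_t` with it rounds to infinity, whereas fp16's largest finite value is 65504. A small demonstration (assumes `FLT16_MAX` is defined as 65504, as in nnacl):

```cpp
#include <arm_neon.h>  // float16_t
#include <cfloat>      // FLT_MAX
#include <cstdio>

#ifndef FLT16_MAX
#define FLT16_MAX 65504.0f  // largest finite fp16 value (assumed definition)
#endif

int main() {
  float16_t bad = (float16_t)FLT_MAX;     // overflows to +inf in fp16
  float16_t good = (float16_t)FLT16_MAX;  // largest finite half
  printf("bad=%f good=%f\n", (float)bad, (float)good);  // bad=inf good=65504
  return 0;
}
```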

View File

@ -45,7 +45,7 @@
#include "src/runtime/agent/npu/optimizer/npu_fusion_pass.h"
#include "src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h"
#endif
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif
@ -230,7 +230,7 @@ int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origi
auto origin_data = tensor->data_c();
MS_ASSERT(origin_data != nullptr);
if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) {
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
auto restore_tensor = Tensor::CopyTensor(*tensor, false);
restore_tensor->set_data(origin_data);
restore_tensor->set_own_data(tensor->own_data());

View File

@ -17,7 +17,7 @@
#include "src/sub_graph_kernel.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
#ifdef ENABLE_FP16
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif
#include "src/common/version_manager.h"
@ -283,7 +283,7 @@ int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) {
}
int CpuFp16SubGraph::PreProcess() {
#ifdef ENABLE_ARM64
#ifdef ENABLE_FP16
if (!mindspore::lite::IsSupportFloat16()) {
MS_LOG(ERROR) << "Unsupported fp16 in this devices";
return RET_ERROR;
@ -347,7 +347,7 @@ int CpuFp16SubGraph::PreProcess() {
}
int CpuFp16SubGraph::PostProcess() {
#ifdef ENABLE_ARM64
#ifdef ENABLE_FP16
if (!mindspore::lite::IsSupportFloat16()) {
MS_LOG(ERROR) << "Unsupported fp16 in this devices";
return RET_ERROR;

View File

@ -377,8 +377,11 @@ add_dependencies(lite-test fbs_src)
target_link_libraries(lite-test dl mindspore::gtest)
if(PLATFORM_ARM64 AND ENABLE_FP16)
target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid)
if(PLATFORM_ARM AND ENABLE_FP16)
target_link_libraries(lite-test nnacl_fp16_mid)
if(PLATFORM_ARM64)
target_link_libraries(lite-test nnacl_optimize_mid)
endif()
endif()
if(PLATFORM_ARM)