From f18d7cf43592fa891a2fa77d202bf1a7351a2e7e Mon Sep 17 00:00:00 2001 From: lzk Date: Mon, 12 Apr 2021 23:10:40 -0700 Subject: [PATCH] fp16 bug fix --- build.sh | 2 +- .../kernel_compiler/cpu/nnacl/CMakeLists.txt | 4 +- .../cpu/nnacl/arg_min_max_parameter.h | 2 +- .../arm82_aarch32_fp16/Matmul12x8Fp16.S | 602 ++++++++++++++++ .../assembly/fp16/IndirectGemmFp16_16x8.S | 667 ------------------ .../cpu/nnacl/fp16/activation_fp16.c | 42 +- .../cpu/nnacl/fp16/activation_fp16.h | 4 +- .../cpu/nnacl/fp16/arithmetic_fp16.c | 18 +- .../cpu/nnacl/fp16/arithmetic_fp16.h | 4 +- .../cpu/nnacl/fp16/arithmetic_self_fp16.c | 4 +- .../cpu/nnacl/fp16/arithmetic_self_fp16.h | 4 +- .../cpu/nnacl/fp16/batchnorm_fp16.c | 4 +- .../cpu/nnacl/fp16/batchnorm_fp16.h | 9 +- .../cpu/nnacl/fp16/cast_fp16.h | 1 - .../cpu/nnacl/fp16/common_func_fp16.c | 14 + .../cpu/nnacl/fp16/common_func_fp16.h | 1 - .../cpu/nnacl/fp16/constant_of_shape_fp16.h | 5 +- .../cpu/nnacl/fp16/conv_depthwise_fp16.c | 13 +- .../cpu/nnacl/fp16/conv_depthwise_fp16.h | 2 +- .../cpu/nnacl/fp16/conv_fp16.c | 18 +- .../cpu/nnacl/fp16/conv_fp16.h | 2 +- .../cpu/nnacl/fp16/crop_fp16.h | 1 - .../cpu/nnacl/fp16/deconv_fp16.c | 10 +- .../cpu/nnacl/fp16/deconv_winograd_fp16.c | 25 +- .../kernel_compiler/cpu/nnacl/fp16/exp_fp16.h | 1 + .../cpu/nnacl/fp16/instance_norm_fp16.c | 1 + .../cpu/nnacl/fp16/instance_norm_fp16.h | 1 - .../cpu/nnacl/fp16/lstm_fp16.c | 3 +- .../cpu/nnacl/fp16/matmul_fp16.c | 178 ++++- .../cpu/nnacl/fp16/matmul_fp16.h | 48 +- .../cpu/nnacl/fp16/matrix_fp16.c | 2 +- .../cpu/nnacl/fp16/pack_fp16.c | 374 +++++++--- .../cpu/nnacl/fp16/pack_fp16.h | 15 +- .../kernel_compiler/cpu/nnacl/fp16/pad_fp16.h | 3 - .../cpu/nnacl/fp16/pooling_fp16.h | 4 +- .../cpu/nnacl/fp16/power_fp16.c | 12 +- .../cpu/nnacl/fp16/power_fp16.h | 5 +- .../cpu/nnacl/fp16/quant_dtype_cast_fp16.h | 5 +- .../cpu/nnacl/fp16/reduce_fp16.h | 3 - .../cpu/nnacl/fp16/scale_fp16.h | 5 +- .../cpu/nnacl/fp16/softmax_fp16.c | 10 +- .../cpu/nnacl/fp16/softmax_fp16.h | 5 +- .../cpu/nnacl/fp16/transpose_fp16.h | 4 +- .../cpu/nnacl/fp16/winograd_transform_fp16.c | 555 +-------------- .../cpu/nnacl/fp16/winograd_transform_fp16.h | 16 +- .../cpu/nnacl/fp16/winograd_utils_fp16.h | 2 +- .../nnacl/intrinsics/ms_simd_instructions.h | 25 +- .../intrinsics/ms_simd_instructions_fp16.h | 99 +++ .../cpu/nnacl/optimize/CMakeLists.txt | 26 +- mindspore/lite/CMakeLists.txt | 10 + mindspore/lite/src/CMakeLists.txt | 17 +- mindspore/lite/src/common/utils.cc | 6 +- mindspore/lite/src/common/utils.h | 2 +- mindspore/lite/src/kernel_registry.cc | 4 +- .../src/runtime/kernel/arm/CMakeLists.txt | 2 +- .../kernel/arm/base/constant_of_shape.cc | 2 +- .../kernel/arm/fp16/convolution_1x1_fp16.cc | 41 +- .../kernel/arm/fp16/convolution_1x1_fp16.h | 2 + .../kernel/arm/fp16/convolution_fp16.cc | 13 +- .../kernel/arm/fp16/convolution_fp16.h | 2 + .../arm/fp16/convolution_winograd_fp16.cc | 27 +- .../arm/fp16/convolution_winograd_fp16.h | 2 + .../runtime/kernel/arm/fp16/fp16_op_handler.h | 14 +- .../runtime/kernel/arm/fp16/pooling_fp16.cc | 4 +- mindspore/lite/src/scheduler.cc | 4 +- mindspore/lite/src/sub_graph_kernel.cc | 6 +- mindspore/lite/test/CMakeLists.txt | 7 +- 67 files changed, 1457 insertions(+), 1568 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S delete mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/fp16/IndirectGemmFp16_16x8.S create mode 100644 
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h diff --git a/build.sh b/build.sh index 8b44fa692ce..034999172d4 100755 --- a/build.sh +++ b/build.sh @@ -607,7 +607,7 @@ build_lite() cmake -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" -DANDROID_NATIVE_API_LEVEL="19" \ -DANDROID_NDK="${ANDROID_NDK}" -DANDROID_ABI="armeabi-v7a" -DANDROID_TOOLCHAIN_NAME="clang" \ -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ - -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \ + -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} -DENABLE_FP16="on" \ -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \ -DSUPPORT_GPU=${LOCAL_LITE_ENABLE_GPU} -DSUPPORT_NPU=${LOCAL_LITE_ENABLE_NPU} -DENABLE_V0=on \ -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/CMakeLists.txt index ce584c85886..94a001d3922 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/CMakeLists.txt +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/CMakeLists.txt @@ -68,7 +68,7 @@ add_library(nnacl_mid OBJECT ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC}) add_dependencies(nnacl fbs_src) add_dependencies(nnacl_mid fbs_src) -########################### arm64 build optimize library ######################## -if(PLATFORM_ARM64) +########################### arm fp16 build optimize library ######################## +if(ENABLE_FP16) add_subdirectory(${NNACL_DIR}/optimize) endif() diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/arg_min_max_parameter.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/arg_min_max_parameter.h index f6cf61471c7..bb2cb50a060 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/arg_min_max_parameter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/arg_min_max_parameter.h @@ -30,7 +30,7 @@ typedef struct ArgElement { int8_t i8_data_; int32_t i_data_; float f_data_; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_ARM float16_t f16_data_; #endif } data_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S new file mode 100644 index 00000000000..11733db5f54 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/arm82_aarch32_fp16/Matmul12x8Fp16.S @@ -0,0 +1,602 @@ +#ifdef ENABLE_ARM32 +#include "nnacl/assembly_global.h" + .text + .align 5 + .global MatMul12x8A32Fp16 +#ifndef __APPLE__ + .type MatMul12x8A32Fp16, %function +#endif + +// void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, +// int deep, int row, int col, int stride, bool write_mode); +// r0: a +// r1: b +// r2: dst +// r3: bias +// #4: depth +// #8: row +// #12: col +// #16: stride +// #20: writeNhwc/writeWino + +asm_function MatMul12x8A32Fp16 + // r13(sp) and r15(pc) can not be used!! 
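+ // Stack frame: push {r3-r11, lr} saves 40 bytes and vpush {q4-q7} saves 64 bytes,
+ // so the "add sp, sp, #104" below restores sp to its value at entry. The stack
+ // arguments are then at [sp]: act_type, [sp, #4]: depth, [sp, #8]: row,
+ // [sp, #12]: col, [sp, #16]: stride, [sp, #20]: write mode, and the saved bias
+ // pointer (r3) is re-read from [sp, #-40] at the top of each row loop.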
+ // r9 r4 is tmp register + // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf + push {r3-r11, lr} + vpush {q4-q7} + add sp, sp, #104 + + ldr r5, [sp, #4] + ldr r6, [sp, #8] + ldr r7, [sp, #12] + ldr r8, [sp, #16] + ldr lr, [sp, #20] + + mov r10, r1 // b + mov r11, r0 // a + mov r12, r2 // dst + + cmp lr, #2 + bne NoWinograd + mul r4, r8, r7 // stride * col + add r4, r4, r4 // r4 * sizeof(float16_t) + mov r9, #16 + mul r9, r8, r9 // stride * 8 * sizeof(float16_t) +NoWinograd: + add r8, r8, r8 // stride * sizeof(float16_t) + +a .req r0 +weight .req r1 +dst .req r2 +bias .req r3 +depth .req r5 +row .req r6 +col .req r7 +stride .req r8 +b_tmp .req r10 +a_tmp .req r11 +dst_tmp .req r12 + +.macro STORE_12x8 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x7 p1, p2, p3 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride +.endm + +.macro STORE_12x6 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x5 p1, p2 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x4 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x3 p1, p2 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride +.endm + +.macro STORE_12x2 p1 + vst1.32 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_12x1 p1 + vst1.16 {\p1}, [dst] + add dst, dst, stride +.endm + +.macro STORE_C8 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C7 p1, p2, p3, p4 + add r4, dst, #8 + add r9, dst, #12 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + vst1.16 {\p3}, [r9] + add dst, dst, stride + cmp row, \p4 + beq WriteEnd +.endm + +.macro STORE_C6 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.32 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C5 p1, p2, p3 + add r4, dst, #8 + vst1.16 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C4 p1, p2 + vst1.16 {\p1}, [dst] + cmp row, \p2 + add dst, dst, stride + beq WriteEnd +.endm + +.macro STORE_C3 p1, p2, p3 + add r4, dst, #4 + vst1.32 {\p1}, [dst] + vst1.16 {\p2}, [r4] + add dst, dst, stride + cmp row, \p3 + beq WriteEnd +.endm + +.macro STORE_C2 p1, p2 + vst1.32 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +.macro STORE_C1 p1, p2 + vst1.16 {\p1}, [dst] + add dst, dst, stride + cmp row, \p2 + beq WriteEnd +.endm + +LoopRow12: + ldr bias, [sp, #-40] + LoopCol8: + mov dst, dst_tmp + mov a, a_tmp + ldr depth, [sp, #4] + veor q4, q4, q4 + veor q5, q5, q5 + veor q6, q6, q6 + veor q7, q7, q7 + veor q8, q8, q8 + veor q9, q9, q9 + veor q10, q10, q10 + veor q11, q11, q11 + veor q12, q12, q12 + veor q13, q13, q13 + veor q14, q14, q14 + veor q15, q15, q15 + LoopDepth: + vld1.16 {q0, d2}, [a]! + vld1.16 {q2}, [weight]! + vmla.f16 q4, q2, d0[0] + vmla.f16 q5, q2, d0[1] + vmla.f16 q6, q2, d0[2] + vmla.f16 q7, q2, d0[3] + vmla.f16 q8, q2, d1[0] + vmla.f16 q9, q2, d1[1] + vmla.f16 q10, q2, d1[2] + vmla.f16 q11, q2, d1[3] + vmla.f16 q12, q2, d2[0] + vmla.f16 q13, q2, d2[1] + vmla.f16 q14, q2, d2[2] + vmla.f16 q15, q2, d2[3] + + subs depth, depth, #1 + bne LoopDepth + + Bias: + cmp bias, #0 + beq Activation + vld1.16 {q0}, [bias]! 
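+ // q0 now holds the next 8 consecutive fp16 bias values (one per output column);
+ // they are added to each of the 12 row accumulators q4-q15.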
+ vadd.f16 q4, q4, q0
+ vadd.f16 q5, q5, q0
+ vadd.f16 q6, q6, q0
+ vadd.f16 q7, q7, q0
+ vadd.f16 q8, q8, q0
+ vadd.f16 q9, q9, q0
+ vadd.f16 q10, q10, q0
+ vadd.f16 q11, q11, q0
+ vadd.f16 q12, q12, q0
+ vadd.f16 q13, q13, q0
+ vadd.f16 q14, q14, q0
+ vadd.f16 q15, q15, q0
+
+ Activation:
+ ldr lr, [sp]
+ cmp lr, #3
+ beq Relu6
+ cmp lr, #1
+ beq Relu
+ b Write
+
+ Relu6:
+ vmov.i16 q2, #0x4600 // 6.0f in fp16
+ vmin.f16 q4, q4, q2 // clamp all 12 accumulators to 6.0, then fall through to Relu for the lower bound
+ vmin.f16 q5, q5, q2
+ vmin.f16 q6, q6, q2
+ vmin.f16 q7, q7, q2
+ vmin.f16 q8, q8, q2
+ vmin.f16 q9, q9, q2
+ vmin.f16 q10, q10, q2
+ vmin.f16 q11, q11, q2
+ vmin.f16 q12, q12, q2
+ vmin.f16 q13, q13, q2
+ vmin.f16 q14, q14, q2
+ vmin.f16 q15, q15, q2
+
+ Relu:
+ veor q3, q3, q3
+ vmax.f16 q4, q4, q3
+ vmax.f16 q5, q5, q3
+ vmax.f16 q6, q6, q3
+ vmax.f16 q7, q7, q3
+ vmax.f16 q8, q8, q3
+ vmax.f16 q9, q9, q3
+ vmax.f16 q10, q10, q3
+ vmax.f16 q11, q11, q3
+ vmax.f16 q12, q12, q3
+ vmax.f16 q13, q13, q3
+ vmax.f16 q14, q14, q3
+ vmax.f16 q15, q15, q3
+
+ Write:
+ ldr lr, [sp, #20]
+ cmp lr, #2
+ beq WriteWinograd
+ cmp row, #12
+ bge Write12xCol
+ b WriteRowxCol
+
+ WriteWinograd:
+ vst1.16 {q4}, [dst]
+ add dst, dst, r4
+ vst1.16 {q5}, [dst]
+ add dst, dst, r4
+ vst1.16 {q6}, [dst]
+ add dst, dst, r4
+ vst1.16 {q7}, [dst]
+ add dst, dst, r4
+ vst1.16 {q8}, [dst]
+ add dst, dst, r4
+ vst1.16 {q9}, [dst]
+ add dst, dst, r4
+ vst1.16 {q10}, [dst]
+ add dst, dst, r4
+ vst1.16 {q11}, [dst]
+ add dst, dst, r4
+ vst1.16 {q12}, [dst]
+ add dst, dst, r4
+ vst1.16 {q13}, [dst]
+ add dst, dst, r4
+ vst1.16 {q14}, [dst]
+ add dst, dst, r4
+ vst1.16 {q15}, [dst]
+ add dst_tmp, dst_tmp, r9
+ b WriteEnd
+ Write12xCol:
+ cmp col, #8
+ bge Write12x8
+ cmp col, #1
+ beq Write12x1
+ cmp col, #2
+ beq Write12x2
+ cmp col, #3
+ beq Write12x3
+ cmp col, #4
+ beq Write12x4
+ cmp col, #5
+ beq Write12x5
+ cmp col, #6
+ beq Write12x6
+ b Write12x7
+
+ WriteRowxCol:
+ cmp col, #8
+ bge WriteRowx8
+ cmp col, #1
+ beq WriteRowx1
+ cmp col, #2
+ beq WriteRowx2
+ cmp col, #3
+ beq WriteRowx3
+ cmp col, #4
+ beq WriteRowx4
+ cmp col, #5
+ beq WriteRowx5
+ cmp col, #6
+ beq WriteRowx6
+ b WriteRowx7
+
+ Write12x8:
+ STORE_12x8 q4
+ STORE_12x8 q5
+ STORE_12x8 q6
+ STORE_12x8 q7
+ STORE_12x8 q8
+ STORE_12x8 q9
+ STORE_12x8 q10
+ STORE_12x8 q11
+ STORE_12x8 q12
+ STORE_12x8 q13
+ STORE_12x8 q14
+ STORE_12x8 q15
+ b WriteEnd
+ WriteRowx8:
+ STORE_C8 q4, #1
+ STORE_C8 q5, #2
+ STORE_C8 q6, #3
+ STORE_C8 q7, #4
+ STORE_C8 q8, #5
+ STORE_C8 q9, #6
+ STORE_C8 q10, #7
+ STORE_C8 q11, #8
+ STORE_C8 q12, #9
+ STORE_C8 q13, #10
+ STORE_C8 q14, #11
+ STORE_C8 q15, #12
+ b WriteEnd
+
+ Write12x1:
+ STORE_12x1 d8[0]
+ STORE_12x1 d10[0]
+ STORE_12x1 d12[0]
+ STORE_12x1 d14[0]
+ STORE_12x1 d16[0]
+ STORE_12x1 d18[0]
+ STORE_12x1 d20[0]
+ STORE_12x1 d22[0]
+ STORE_12x1 d24[0]
+ STORE_12x1 d26[0]
+ STORE_12x1 d28[0]
+ STORE_12x1 d30[0]
+ b WriteEnd
+ WriteRowx1:
+ STORE_C1 d8[0], #1
+ STORE_C1 d10[0], #2
+ STORE_C1 d12[0], #3
+ STORE_C1 d14[0], #4
+ STORE_C1 d16[0], #5
+ STORE_C1 d18[0], #6
+ STORE_C1 d20[0], #7
+ STORE_C1 d22[0], #8
+ STORE_C1 d24[0], #9
+ STORE_C1 d26[0], #10
+ STORE_C1 d28[0], #11
+ STORE_C1 d30[0], #12
+ b WriteEnd
+
+ Write12x2:
+ STORE_12x2 d8[0]
+ STORE_12x2 d10[0]
+ STORE_12x2 d12[0]
+ STORE_12x2 d14[0]
+ STORE_12x2 d16[0]
+ STORE_12x2 d18[0]
+ STORE_12x2 d20[0]
+ STORE_12x2 d22[0]
+ STORE_12x2 d24[0]
+ STORE_12x2 d26[0]
+ STORE_12x2 d28[0]
+ STORE_12x2 d30[0]
+ b WriteEnd
+ WriteRowx2:
+ STORE_C2 d8[0], #1
+ STORE_C2 d10[0], #2
+ STORE_C2 d12[0], #3
+ STORE_C2 d14[0], #4
+ STORE_C2 d16[0], #5
+
STORE_C2 d18[0], #6 + STORE_C2 d20[0], #7 + STORE_C2 d22[0], #8 + STORE_C2 d24[0], #9 + STORE_C2 d26[0], #10 + STORE_C2 d28[0], #11 + STORE_C2 d30[0], #12 + b WriteEnd + + Write12x3: + STORE_12x3 d8[0], d8[2] + STORE_12x3 d10[0], d10[2] + STORE_12x3 d12[0], d12[2] + STORE_12x3 d14[0], d14[2] + STORE_12x3 d16[0], d16[2] + STORE_12x3 d18[0], d18[2] + STORE_12x3 d20[0], d20[2] + STORE_12x3 d22[0], d22[2] + STORE_12x3 d24[0], d24[2] + STORE_12x3 d26[0], d26[2] + STORE_12x3 d28[0], d28[2] + STORE_12x3 d30[0], d30[2] + b WriteEnd + WriteRowx3: + STORE_C3 d8[0], d8[2], #1 + STORE_C3 d10[0], d10[2], #2 + STORE_C3 d12[0], d12[2], #3 + STORE_C3 d14[0], d14[2], #4 + STORE_C3 d16[0], d16[2], #5 + STORE_C3 d18[0], d18[2], #6 + STORE_C3 d20[0], d20[2], #7 + STORE_C3 d22[0], d22[2], #8 + STORE_C3 d24[0], d24[2], #9 + STORE_C3 d26[0], d26[2], #10 + STORE_C3 d28[0], d28[2], #11 + STORE_C3 d30[0], d30[2], #12 + b WriteEnd + + Write12x4: + STORE_12x4 d8 + STORE_12x4 d10 + STORE_12x4 d12 + STORE_12x4 d14 + STORE_12x4 d16 + STORE_12x4 d18 + STORE_12x4 d20 + STORE_12x4 d22 + STORE_12x4 d24 + STORE_12x4 d26 + STORE_12x4 d28 + STORE_12x4 d30 + b WriteEnd + WriteRowx4: + STORE_C4 d8, #1 + STORE_C4 d10, #2 + STORE_C4 d12, #3 + STORE_C4 d14, #4 + STORE_C4 d16, #5 + STORE_C4 d18, #6 + STORE_C4 d20, #7 + STORE_C4 d22, #8 + STORE_C4 d24, #9 + STORE_C4 d26, #10 + STORE_C4 d28, #11 + STORE_C4 d30, #12 + b WriteEnd + + Write12x5: + STORE_12x5 d8, d9[0] + STORE_12x5 d10, d11[0] + STORE_12x5 d12, d13[0] + STORE_12x5 d14, d15[0] + STORE_12x5 d16, d17[0] + STORE_12x5 d18, d19[0] + STORE_12x5 d20, d21[0] + STORE_12x5 d22, d23[0] + STORE_12x5 d24, d25[0] + STORE_12x5 d26, d27[0] + STORE_12x5 d28, d29[0] + STORE_12x5 d30, d31[0] + b WriteEnd + WriteRowx5: + STORE_C5 d8, d9[0], #1 + STORE_C5 d10, d11[0], #2 + STORE_C5 d12, d13[0], #3 + STORE_C5 d14, d15[0], #4 + STORE_C5 d16, d17[0], #5 + STORE_C5 d18, d19[0], #6 + STORE_C5 d20, d21[0], #7 + STORE_C5 d22, d23[0], #8 + STORE_C5 d24, d25[0], #9 + STORE_C5 d26, d27[0], #10 + STORE_C5 d28, d29[0], #11 + STORE_C5 d30, d31[0], #12 + b WriteEnd + + Write12x6: + STORE_12x6 d8, d9[0] + STORE_12x6 d10, d11[0] + STORE_12x6 d12, d13[0] + STORE_12x6 d14, d15[0] + STORE_12x6 d16, d17[0] + STORE_12x6 d18, d19[0] + STORE_12x6 d20, d21[0] + STORE_12x6 d22, d23[0] + STORE_12x6 d24, d25[0] + STORE_12x6 d26, d27[0] + STORE_12x6 d28, d29[0] + STORE_12x6 d30, d31[0] + b WriteEnd + WriteRowx6: + STORE_C6 d8, d9[0], #1 + STORE_C6 d10, d11[0], #2 + STORE_C6 d12, d13[0], #3 + STORE_C6 d14, d15[0], #4 + STORE_C6 d16, d17[0], #5 + STORE_C6 d18, d19[0], #6 + STORE_C6 d20, d21[0], #7 + STORE_C6 d22, d23[0], #8 + STORE_C6 d24, d25[0], #9 + STORE_C6 d26, d27[0], #10 + STORE_C6 d28, d29[0], #11 + STORE_C6 d30, d31[0], #12 + b WriteEnd + + Write12x7: + STORE_12x7 d8, d9[0], d9[2] + STORE_12x7 d10, d11[0], d11[2] + STORE_12x7 d12, d13[0], d13[2] + STORE_12x7 d14, d15[0], d15[2] + STORE_12x7 d16, d17[0], d17[2] + STORE_12x7 d18, d19[0], d19[2] + STORE_12x7 d20, d21[0], d21[2] + STORE_12x7 d22, d23[0], d23[2] + STORE_12x7 d24, d25[0], d25[2] + STORE_12x7 d26, d27[0], d27[2] + STORE_12x7 d28, d29[0], d29[2] + STORE_12x7 d30, d31[0], d31[2] + b WriteEnd + WriteRowx7: + STORE_C7 d8, d9[0], d9[2], #1 + STORE_C7 d10, d11[0], d11[2], #2 + STORE_C7 d12, d13[0], d13[2], #3 + STORE_C7 d14, d15[0], d15[2], #4 + STORE_C7 d16, d17[0], d17[2], #5 + STORE_C7 d18, d19[0], d19[2], #6 + STORE_C7 d20, d21[0], d21[2], #7 + STORE_C7 d22, d23[0], d23[2], #8 + STORE_C7 d24, d25[0], d25[2], #9 + STORE_C7 d26, d27[0], d27[2], #10 + 
STORE_C7 d28, d29[0], d29[2], #11 + STORE_C7 d30, d31[0], d31[2], #12 + b WriteEnd + + WriteEnd: + cmp col, #8 + ble LoopColEnd + sub col, col, #8 + ldr lr, [sp, #20] + cmp lr, #2 + beq LoopCol8 + add dst_tmp, dst_tmp, #16 + b LoopCol8 + LoopColEnd: + cmp row, #12 + ble LoopRowEnd + sub row, row, #12 + mov a_tmp, a + mov weight, b_tmp + ldr lr, [sp, #20] + cmp lr, #2 + beq WinogradDst + ldr lr, [sp, #12] + sub lr, lr, col + add lr, lr, lr // col *= 2 + sub dst_tmp, dst, lr + b LoopRow + WinogradDst: + add dst_tmp, dst, r9 + LoopRow: + mov dst, dst_tmp + ldr col, [sp, #12] + b LoopRow12 +LoopRowEnd: + sub sp, sp, #104 + vpop {q4-q7} + pop {r3-r11, pc} +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/fp16/IndirectGemmFp16_16x8.S b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/fp16/IndirectGemmFp16_16x8.S deleted file mode 100644 index 5bd1da914e6..00000000000 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/assembly/fp16/IndirectGemmFp16_16x8.S +++ /dev/null @@ -1,667 +0,0 @@ -#ifdef ENABLE_ARM64 -#include "nnacl/assembly_global.h" - -.text -.align 5 - -// void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, -// size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC4, size_t relu, size_t relu6); -// x0: output, x1: input, x2: weight, x3: bias, x4: step, x5: ic4, x6: oc8, x7: offset, -// x8:mode, x9: writeC4, x10:relu, x11: relu6 -// compute 8 channel for 16 outputs -asm_function IndirectGemmFp16_16x8 - - .macro INIT_BIAS - dup v16.4s, wzr - cbz x3, InitBias - ld1 {v16.8h}, [x3] - InitBias: - mov v17.16b, v16.16b - mov v18.16b, v16.16b - mov v19.16b, v16.16b - mov v20.16b, v16.16b - mov v21.16b, v16.16b - mov v22.16b, v16.16b - mov v23.16b, v16.16b - mov v24.16b, v16.16b - mov v25.16b, v16.16b - mov v26.16b, v16.16b - mov v27.16b, v16.16b - mov v28.16b, v16.16b - mov v29.16b, v16.16b - mov v30.16b, v16.16b - mov v31.16b, v16.16b - .endm - - // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to - // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers - // x19 ~ r29 should be also preserved - // whereas our coding style do not permit such amount of parameters - sub sp, sp, #144 - // performance between storing 4 registers at the same time and separately storing them on in-order cores - // is not tested yet - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 - stp x19, x20, [sp], #16 - - ldr x8, [sp, #0] - ldr x9, [sp, #8] - ldr x10, [sp, #16] - ldr x11, [sp, #24] - - cbnz x8, IndirectGemmStart - // step is one for common convolution, where ic8 should multiply by kernel size - // step is (a+b-1) for F(a,b) in winograd - mul x5, x4, x5 - mov x4, #1 - -IndirectGemmStart: - - LoopOc: - - mov x14, x4 - mov x12, x1 - - LoopKsize: - - mov x15, x0 - INIT_BIAS - // load input for output 1-8 - ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 - // load weight - ld1 {v8.8h, v9.8h}, [x2], #32 - // first 2 steps for output 1 and 3 - fmla v16.8h, v8.8h, v0.h[0] - fmla v18.8h, v8.8h, v1.h[0] - fmla v16.8h, v9.8h, v0.h[1] - fmla v18.8h, v9.8h, v1.h[1] - // load weight - ld1 {v10.8h, v11.8h}, [x2], #32 - // first 2 steps for output 2 and 4 - fmla v17.8h, v8.8h, v0.h[4] - fmla v19.8h, v8.8h, v1.h[4] - fmla v17.8h, v9.8h, v0.h[5] - fmla v19.8h, v9.8h, v1.h[5] - // load input for output 9-16 - // input cache should be refreshed after loading - // ATTENTION: advance 
is preferred, but advancing too much may lead to invalid prefetching - ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 - // last 2 steps for output 1 and 3 - fmla v16.8h, v10.8h, v0.h[2] - fmla v18.8h, v10.8h, v1.h[2] - fmla v16.8h, v11.8h, v0.h[3] - fmla v18.8h, v11.8h, v1.h[3] - - // check if ic4=1 - subs x13, x5, #1 - beq LoopIcEnd - - LoopIc: - // last 2 steps for output 2 and 4 - fmla v17.8h, v10.8h, v0.h[6] - fmla v19.8h, v10.8h, v1.h[6] - fmla v17.8h, v11.8h, v0.h[7] - fmla v19.8h, v11.8h, v1.h[7] - // steps for output 5-8 - fmla v20.8h, v8.8h, v2.h[0] - fmla v22.8h, v8.8h, v3.h[0] - fmla v20.8h, v9.8h, v2.h[1] - fmla v22.8h, v9.8h, v3.h[1] - fmla v21.8h, v8.8h, v2.h[4] - fmla v23.8h, v8.8h, v3.h[4] - fmla v21.8h, v9.8h, v2.h[5] - fmla v23.8h, v9.8h, v3.h[5] - fmla v20.8h, v10.8h, v2.h[2] - fmla v22.8h, v10.8h, v3.h[2] - fmla v20.8h, v11.8h, v2.h[3] - fmla v22.8h, v11.8h, v3.h[3] - fmla v21.8h, v10.8h, v2.h[6] - fmla v23.8h, v10.8h, v3.h[6] - fmla v21.8h, v11.8h, v2.h[7] - fmla v23.8h, v11.8h, v3.h[7] - // load input for output 1-8 - ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x12], #64 - // steps for output 9-12 - fmla v24.8h, v8.8h, v4.h[0] - fmla v26.8h, v8.8h, v5.h[0] - fmla v24.8h, v9.8h, v4.h[1] - fmla v26.8h, v9.8h, v5.h[1] - fmla v25.8h, v8.8h, v4.h[4] - fmla v27.8h, v8.8h, v5.h[4] - fmla v25.8h, v9.8h, v4.h[5] - fmla v27.8h, v9.8h, v5.h[5] - fmla v24.8h, v10.8h, v4.h[2] - fmla v26.8h, v10.8h, v5.h[2] - fmla v24.8h, v11.8h, v4.h[3] - fmla v26.8h, v11.8h, v5.h[3] - fmla v25.8h, v10.8h, v4.h[6] - fmla v27.8h, v10.8h, v5.h[6] - fmla v25.8h, v11.8h, v4.h[7] - fmla v27.8h, v11.8h, v5.h[7] - // steps for output 13-16 - fmla v28.8h, v8.8h, v6.h[0] - fmla v30.8h, v8.8h, v7.h[0] - fmla v28.8h, v9.8h, v6.h[1] - fmla v30.8h, v9.8h, v7.h[1] - fmla v29.8h, v8.8h, v6.h[4] - fmla v31.8h, v8.8h, v7.h[4] - fmla v29.8h, v9.8h, v6.h[5] - fmla v31.8h, v9.8h, v7.h[5] - // load weight - ld1 {v8.8h, v9.8h}, [x2], #32 - fmla v28.8h, v10.8h, v6.h[2] - fmla v30.8h, v10.8h, v7.h[2] - fmla v28.8h, v11.8h, v6.h[3] - fmla v30.8h, v11.8h, v7.h[3] - fmla v29.8h, v10.8h, v6.h[6] - fmla v31.8h, v10.8h, v7.h[6] - fmla v29.8h, v11.8h, v6.h[7] - fmla v31.8h, v11.8h, v7.h[7] - // load weight - ld1 {v10.8h, v11.8h}, [x2], #32 - // first 2 steps for output 1-4 - fmla v16.8h, v8.8h, v0.h[0] - fmla v18.8h, v8.8h, v1.h[0] - fmla v16.8h, v9.8h, v0.h[1] - fmla v18.8h, v9.8h, v1.h[1] - fmla v17.8h, v8.8h, v0.h[4] - fmla v19.8h, v8.8h, v1.h[4] - fmla v17.8h, v9.8h, v0.h[5] - fmla v19.8h, v9.8h, v1.h[5] - // load input for output 9-16 - ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12], #64 - // last 2 steps for output 1 and 3 - fmla v16.8h, v10.8h, v0.h[2] - fmla v18.8h, v10.8h, v1.h[2] - fmla v16.8h, v11.8h, v0.h[3] - fmla v18.8h, v11.8h, v1.h[3] - - subs x13, x13, #1 - bne LoopIc - - LoopIcEnd: - fmla v17.8h, v10.8h, v0.h[6] - fmla v19.8h, v10.8h, v1.h[6] - fmla v17.8h, v11.8h, v0.h[7] - fmla v19.8h, v11.8h, v1.h[7] - // steps for output 5-8 - fmla v20.8h, v8.8h, v2.h[0] - fmla v22.8h, v8.8h, v3.h[0] - fmla v20.8h, v9.8h, v2.h[1] - fmla v22.8h, v9.8h, v3.h[1] - fmla v21.8h, v8.8h, v2.h[4] - fmla v23.8h, v8.8h, v3.h[4] - fmla v21.8h, v9.8h, v2.h[5] - fmla v23.8h, v9.8h, v3.h[5] - fmla v20.8h, v10.8h, v2.h[2] - fmla v22.8h, v10.8h, v3.h[2] - fmla v20.8h, v11.8h, v2.h[3] - fmla v22.8h, v11.8h, v3.h[3] - fmla v21.8h, v10.8h, v2.h[6] - fmla v23.8h, v10.8h, v3.h[6] - fmla v21.8h, v11.8h, v2.h[7] - fmla v23.8h, v11.8h, v3.h[7] - // steps for output 9-12 - fmla v24.8h, v8.8h, v4.h[0] - fmla v26.8h, v8.8h, v5.h[0] - fmla v24.8h, v9.8h, v4.h[1] - 
fmla v26.8h, v9.8h, v5.h[1] - fmla v25.8h, v8.8h, v4.h[4] - fmla v27.8h, v8.8h, v5.h[4] - fmla v25.8h, v9.8h, v4.h[5] - fmla v27.8h, v9.8h, v5.h[5] - fmla v24.8h, v10.8h, v4.h[2] - fmla v26.8h, v10.8h, v5.h[2] - fmla v24.8h, v11.8h, v4.h[3] - fmla v26.8h, v11.8h, v5.h[3] - fmla v25.8h, v10.8h, v4.h[6] - fmla v27.8h, v10.8h, v5.h[6] - fmla v25.8h, v11.8h, v4.h[7] - fmla v27.8h, v11.8h, v5.h[7] - // steps for output 13-16 - fmla v28.8h, v8.8h, v6.h[0] - fmla v30.8h, v8.8h, v7.h[0] - fmla v28.8h, v9.8h, v6.h[1] - fmla v30.8h, v9.8h, v7.h[1] - fmla v29.8h, v8.8h, v6.h[4] - fmla v31.8h, v8.8h, v7.h[4] - fmla v29.8h, v9.8h, v6.h[5] - fmla v31.8h, v9.8h, v7.h[5] - fmla v28.8h, v10.8h, v6.h[2] - fmla v30.8h, v10.8h, v7.h[2] - fmla v28.8h, v11.8h, v6.h[3] - fmla v30.8h, v11.8h, v7.h[3] - fmla v29.8h, v10.8h, v6.h[6] - fmla v31.8h, v10.8h, v7.h[6] - fmla v29.8h, v11.8h, v6.h[7] - fmla v31.8h, v11.8h, v7.h[7] - - cbnz x11, Relu6 - cbnz x10, Relu - b WriteStart - Relu6: - movi v9.8h, #0x46, lsl #8 - fmin v16.8h, v16.8h, v9.8h - fmin v17.8h, v17.8h, v9.8h - fmin v18.8h, v18.8h, v9.8h - fmin v19.8h, v19.8h, v9.8h - fmin v20.8h, v20.8h, v9.8h - fmin v21.8h, v21.8h, v9.8h - fmin v22.8h, v22.8h, v9.8h - fmin v23.8h, v23.8h, v9.8h - fmin v24.8h, v24.8h, v9.8h - fmin v25.8h, v25.8h, v9.8h - fmin v26.8h, v26.8h, v9.8h - fmin v27.8h, v27.8h, v9.8h - fmin v28.8h, v28.8h, v9.8h - fmin v29.8h, v29.8h, v9.8h - fmin v30.8h, v30.8h, v9.8h - fmin v31.8h, v31.8h, v9.8h - Relu: - dup v8.4s, wzr - fmax v16.8h, v16.8h, v8.8h - fmax v17.8h, v17.8h, v8.8h - fmax v18.8h, v18.8h, v8.8h - fmax v19.8h, v19.8h, v8.8h - fmax v20.8h, v20.8h, v8.8h - fmax v21.8h, v21.8h, v8.8h - fmax v22.8h, v22.8h, v8.8h - fmax v23.8h, v23.8h, v8.8h - fmax v24.8h, v24.8h, v8.8h - fmax v25.8h, v25.8h, v8.8h - fmax v26.8h, v26.8h, v8.8h - fmax v27.8h, v27.8h, v8.8h - fmax v28.8h, v28.8h, v8.8h - fmax v29.8h, v29.8h, v8.8h - fmax v30.8h, v30.8h, v8.8h - fmax v31.8h, v31.8h, v8.8h - - WriteStart: - cbnz x9, Write8 - cmp x6, #1 - beq Write1 - cmp x6, #2 - beq Write2 - cmp x6, #3 - beq Write3 - cmp x6, #4 - beq Write4 - cmp x6, #5 - beq Write5 - cmp x6, #6 - beq Write6 - cmp x6, #7 - beq Write7 - b Write8 - // prefetching is not preferred while writing results in spite of cache missing - // you could try prfm pstl2strm - // there are almost no benefits observed though - Write1: - str h16, [x15] - add x15, x15, x7 - str h17, [x15] - add x15, x15, x7 - str h18, [x15] - add x15, x15, x7 - str h19, [x15] - add x15, x15, x7 - str h20, [x15] - add x15, x15, x7 - str h21, [x15] - add x15, x15, x7 - str h22, [x15] - add x15, x15, x7 - str h23, [x15] - add x15, x15, x7 - str h24, [x15] - add x15, x15, x7 - str h25, [x15] - add x15, x15, x7 - str h26, [x15] - add x15, x15, x7 - str h27, [x15] - add x15, x15, x7 - str h28, [x15] - add x15, x15, x7 - str h29, [x15] - add x15, x15, x7 - str h30, [x15] - add x15, x15, x7 - str h31, [x15] - add x0, x0, #2 - b WriteEnd - Write2: - add x17, x15, #2 - st1 {v16.h}[0], [x15], x7 - st1 {v16.h}[1], [x17], x7 - st1 {v17.h}[0], [x15], x7 - st1 {v17.h}[1], [x17], x7 - st1 {v18.h}[0], [x15], x7 - st1 {v18.h}[1], [x17], x7 - st1 {v19.h}[0], [x15], x7 - st1 {v19.h}[1], [x17], x7 - st1 {v20.h}[0], [x15], x7 - st1 {v20.h}[1], [x17], x7 - st1 {v21.h}[0], [x15], x7 - st1 {v21.h}[1], [x17], x7 - st1 {v22.h}[0], [x15], x7 - st1 {v22.h}[1], [x17], x7 - st1 {v23.h}[0], [x15], x7 - st1 {v23.h}[1], [x17], x7 - st1 {v24.h}[0], [x15], x7 - st1 {v24.h}[1], [x17], x7 - st1 {v25.h}[0], [x15], x7 - st1 {v25.h}[1], [x17], x7 - st1 {v26.h}[0], 
[x15], x7 - st1 {v26.h}[1], [x17], x7 - st1 {v27.h}[0], [x15], x7 - st1 {v27.h}[1], [x17], x7 - st1 {v28.h}[0], [x15], x7 - st1 {v28.h}[1], [x17], x7 - st1 {v29.h}[0], [x15], x7 - st1 {v29.h}[1], [x17], x7 - st1 {v30.h}[0], [x15], x7 - st1 {v30.h}[1], [x17], x7 - st1 {v31.h}[0], [x15] - st1 {v31.h}[1], [x17] - add x0, x0, #4 - b WriteEnd - Write3: - add x17, x15, #4 - add x16, x15, #2 - st1 {v16.h}[0], [x15], x7 - st1 {v16.h}[1], [x16], x7 - st1 {v16.h}[2], [x17], x7 - st1 {v17.h}[0], [x15], x7 - st1 {v17.h}[1], [x16], x7 - st1 {v17.h}[2], [x17], x7 - st1 {v18.h}[0], [x15], x7 - st1 {v18.h}[1], [x16], x7 - st1 {v18.h}[2], [x17], x7 - st1 {v19.h}[0], [x15], x7 - st1 {v19.h}[1], [x16], x7 - st1 {v19.h}[2], [x17], x7 - st1 {v20.h}[0], [x15], x7 - st1 {v20.h}[1], [x16], x7 - st1 {v20.h}[2], [x17], x7 - st1 {v21.h}[0], [x15], x7 - st1 {v21.h}[1], [x16], x7 - st1 {v21.h}[2], [x17], x7 - st1 {v22.h}[0], [x15], x7 - st1 {v22.h}[1], [x16], x7 - st1 {v22.h}[2], [x17], x7 - st1 {v23.h}[0], [x15], x7 - st1 {v23.h}[1], [x16], x7 - st1 {v23.h}[2], [x17], x7 - st1 {v24.h}[0], [x15], x7 - st1 {v24.h}[1], [x16], x7 - st1 {v24.h}[2], [x17], x7 - st1 {v25.h}[0], [x15], x7 - st1 {v25.h}[1], [x16], x7 - st1 {v25.h}[2], [x17], x7 - st1 {v26.h}[0], [x15], x7 - st1 {v26.h}[1], [x16], x7 - st1 {v26.h}[2], [x17], x7 - st1 {v27.h}[0], [x15], x7 - st1 {v27.h}[1], [x16], x7 - st1 {v27.h}[2], [x17], x7 - st1 {v28.h}[0], [x15], x7 - st1 {v28.h}[1], [x16], x7 - st1 {v28.h}[2], [x17], x7 - st1 {v29.h}[0], [x15], x7 - st1 {v29.h}[1], [x16], x7 - st1 {v29.h}[2], [x17], x7 - st1 {v30.h}[0], [x15], x7 - st1 {v30.h}[1], [x16], x7 - st1 {v30.h}[2], [x17], x7 - st1 {v31.h}[0], [x15] - st1 {v31.h}[1], [x16] - st1 {v31.h}[2], [x17] - add x0, x0, #6 - b WriteEnd - Write4: - st1 {v16.4h}, [x15], x7 - st1 {v17.4h}, [x15], x7 - st1 {v18.4h}, [x15], x7 - st1 {v19.4h}, [x15], x7 - st1 {v20.4h}, [x15], x7 - st1 {v21.4h}, [x15], x7 - st1 {v22.4h}, [x15], x7 - st1 {v23.4h}, [x15], x7 - st1 {v24.4h}, [x15], x7 - st1 {v25.4h}, [x15], x7 - st1 {v26.4h}, [x15], x7 - st1 {v27.4h}, [x15], x7 - st1 {v28.4h}, [x15], x7 - st1 {v29.4h}, [x15], x7 - st1 {v30.4h}, [x15], x7 - st1 {v31.4h}, [x15] - add x0, x0, #8 - b WriteEnd - Write5: - add x17, x15, #8 - st1 {v16.4h}, [x15], x7 - st1 {v16.h}[4], [x17], x7 - st1 {v17.4h}, [x15], x7 - st1 {v17.h}[4], [x17], x7 - st1 {v18.4h}, [x15], x7 - st1 {v18.h}[4], [x17], x7 - st1 {v19.4h}, [x15], x7 - st1 {v19.h}[4], [x17], x7 - st1 {v20.4h}, [x15], x7 - st1 {v20.h}[4], [x17], x7 - st1 {v21.4h}, [x15], x7 - st1 {v21.h}[4], [x17], x7 - st1 {v22.4h}, [x15], x7 - st1 {v22.h}[4], [x17], x7 - st1 {v23.4h}, [x15], x7 - st1 {v23.h}[4], [x17], x7 - st1 {v24.4h}, [x15], x7 - st1 {v24.h}[4], [x17], x7 - st1 {v25.4h}, [x15], x7 - st1 {v25.h}[4], [x17], x7 - st1 {v26.4h}, [x15], x7 - st1 {v26.h}[4], [x17], x7 - st1 {v27.4h}, [x15], x7 - st1 {v27.h}[4], [x17], x7 - st1 {v28.4h}, [x15], x7 - st1 {v28.h}[4], [x17], x7 - st1 {v29.4h}, [x15], x7 - st1 {v29.h}[4], [x17], x7 - st1 {v30.4h}, [x15], x7 - st1 {v30.h}[4], [x17], x7 - st1 {v31.4h}, [x15] - st1 {v31.h}[4], [x17] - add x0, x0, #10 - b WriteEnd - Write6: - add x17, x15, #8 - add x16, x15, #10 - st1 {v16.4h}, [x15], x7 - ins v0.s[0], v16.s[2] - st1 {v0.h}[0], [x17], x7 - st1 {v0.h}[1], [x16], x7 - st1 {v17.4h}, [x15], x7 - ins v1.s[0], v17.s[2] - st1 {v1.h}[0], [x17], x7 - st1 {v1.h}[1], [x16], x7 - st1 {v18.4h}, [x15], x7 - ins v2.s[0], v18.s[2] - st1 {v2.h}[0], [x17], x7 - st1 {v2.h}[1], [x16], x7 - st1 {v19.4h}, [x15], x7 - ins v3.s[0], v19.s[2] - st1 {v3.h}[0], [x17], 
x7 - st1 {v3.h}[1], [x16], x7 - st1 {v20.4h}, [x15], x7 - ins v4.s[0], v20.s[2] - st1 {v4.h}[0], [x17], x7 - st1 {v4.h}[1], [x16], x7 - st1 {v21.4h}, [x15], x7 - ins v5.s[0], v21.s[2] - st1 {v5.h}[0], [x17], x7 - st1 {v5.h}[1], [x16], x7 - st1 {v22.4h}, [x15], x7 - ins v6.s[0], v22.s[2] - st1 {v6.h}[0], [x17], x7 - st1 {v6.h}[1], [x16], x7 - st1 {v23.4h}, [x15], x7 - ins v7.s[0], v23.s[2] - st1 {v7.h}[0], [x17], x7 - st1 {v7.h}[1], [x16], x7 - st1 {v24.4h}, [x15], x7 - ins v8.s[0], v24.s[2] - st1 {v8.h}[0], [x17], x7 - st1 {v8.h}[1], [x16], x7 - st1 {v25.4h}, [x15], x7 - ins v9.s[0], v25.s[2] - st1 {v9.h}[0], [x17], x7 - st1 {v9.h}[1], [x16], x7 - st1 {v26.4h}, [x15], x7 - ins v10.s[0], v26.s[2] - st1 {v10.h}[0], [x17], x7 - st1 {v10.h}[1], [x16], x7 - st1 {v27.4h}, [x15], x7 - ins v11.s[0], v27.s[2] - st1 {v11.h}[0], [x17], x7 - st1 {v11.h}[1], [x16], x7 - st1 {v28.4h}, [x15], x7 - ins v12.s[0], v28.s[2] - st1 {v12.h}[0], [x17], x7 - st1 {v12.h}[1], [x16], x7 - st1 {v29.4h}, [x15], x7 - ins v13.s[0], v29.s[2] - st1 {v13.h}[0], [x17], x7 - st1 {v13.h}[1], [x16], x7 - st1 {v30.4h}, [x15], x7 - ins v14.s[0], v30.s[2] - st1 {v14.h}[0], [x17], x7 - st1 {v14.h}[1], [x16], x7 - st1 {v31.4h}, [x15] - ins v15.s[0], v31.s[2] - st1 {v14.h}[0], [x17] - st1 {v14.h}[1], [x16] - add x0, x0, #12 - b WriteEnd - Write7: - add x17, x15, #8 - add x19, x15, #10 - add x16, x15, #12 - st1 {v16.4h}, [x15], x7 - ins v0.s[0], v16.s[2] - st1 {v0.h}[0], [x17], x7 - st1 {v0.h}[1], [x19], x7 - st1 {v16.h}[6], [x16], x7 - st1 {v17.4h}, [x15], x7 - ins v1.s[0], v17.s[2] - st1 {v1.h}[0], [x17], x7 - st1 {v1.h}[1], [x19], x7 - st1 {v17.h}[6], [x16], x7 - st1 {v18.4h}, [x15], x7 - ins v2.s[0], v18.s[2] - st1 {v2.h}[0], [x17], x7 - st1 {v2.h}[1], [x19], x7 - st1 {v18.h}[6], [x16], x7 - st1 {v19.4h}, [x15], x7 - ins v3.s[0], v19.s[2] - st1 {v3.h}[0], [x17], x7 - st1 {v3.h}[1], [x19], x7 - st1 {v19.h}[6], [x16], x7 - st1 {v20.4h}, [x15], x7 - ins v4.s[0], v20.s[2] - st1 {v4.h}[0], [x17], x7 - st1 {v4.h}[1], [x19], x7 - st1 {v20.h}[6], [x16], x7 - st1 {v21.4h}, [x15], x7 - ins v5.s[0], v21.s[2] - st1 {v5.h}[0], [x17], x7 - st1 {v5.h}[1], [x19], x7 - st1 {v21.h}[6], [x16], x7 - st1 {v22.4h}, [x15], x7 - ins v6.s[0], v22.s[2] - st1 {v6.h}[0], [x17], x7 - st1 {v6.h}[1], [x19], x7 - st1 {v22.h}[6], [x16], x7 - st1 {v23.4h}, [x15], x7 - ins v7.s[0], v23.s[2] - st1 {v7.h}[0], [x17], x7 - st1 {v7.h}[1], [x19], x7 - st1 {v23.h}[6], [x16], x7 - st1 {v24.4h}, [x15], x7 - ins v8.s[0], v24.s[2] - st1 {v8.h}[0], [x17], x7 - st1 {v8.h}[1], [x19], x7 - st1 {v24.h}[6], [x16], x7 - st1 {v25.4h}, [x15], x7 - ins v9.s[0], v25.s[2] - st1 {v9.h}[0], [x17], x7 - st1 {v9.h}[1], [x19], x7 - st1 {v25.h}[6], [x16], x7 - st1 {v26.4h}, [x15], x7 - ins v10.s[0], v26.s[2] - st1 {v10.h}[0], [x17], x7 - st1 {v10.h}[1], [x19], x7 - st1 {v26.h}[6], [x16], x7 - st1 {v27.4h}, [x15], x7 - ins v11.s[0], v27.s[2] - st1 {v11.h}[0], [x17], x7 - st1 {v11.h}[1], [x19], x7 - st1 {v27.h}[6], [x16], x7 - st1 {v28.4h}, [x15], x7 - ins v12.s[0], v28.s[2] - st1 {v12.h}[0], [x17], x7 - st1 {v12.h}[1], [x19], x7 - st1 {v28.h}[6], [x16], x7 - st1 {v29.4h}, [x15], x7 - ins v13.s[0], v29.s[2] - st1 {v13.h}[0], [x17], x7 - st1 {v13.h}[1], [x19], x7 - st1 {v29.h}[6], [x16], x7 - st1 {v30.4h}, [x15], x7 - ins v14.s[0], v30.s[2] - st1 {v14.h}[0], [x17], x7 - st1 {v14.h}[1], [x19], x7 - st1 {v30.h}[6], [x16], x7 - st1 {v31.4h}, [x15] - ins v15.s[0], v31.s[2] - st1 {v15.h}[0], [x17] - st1 {v15.h}[1], [x19] - st1 {v31.h}[6], [x16] - add x0, x0, #14 - b WriteEnd - Write8: - st1 {v16.8h}, 
[x15], x7
- st1 {v17.8h}, [x15], x7
- st1 {v18.8h}, [x15], x7
- st1 {v19.8h}, [x15], x7
- st1 {v20.8h}, [x15], x7
- st1 {v21.8h}, [x15], x7
- st1 {v22.8h}, [x15], x7
- st1 {v23.8h}, [x15], x7
- st1 {v24.8h}, [x15], x7
- st1 {v25.8h}, [x15], x7
- st1 {v26.8h}, [x15], x7
- st1 {v27.8h}, [x15], x7
- st1 {v28.8h}, [x15], x7
- st1 {v29.8h}, [x15], x7
- st1 {v30.8h}, [x15], x7
- st1 {v31.8h}, [x15]
- add x0, x0, #16
-
- WriteEnd:
- subs x14, x14, #1
- bne LoopKsize
-
- subs x6, x6, #8
- cbz x3, NoStepForward
- add x3, x3, #16
- NoStepForward:
- bgt LoopOc
-
- sub sp, sp, #144
- ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
- ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
- ldp x19, x20, [sp], #16
- ret
-#endif
-
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.c
index f4afa205ca8..4f63c98b547 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.c
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.c
@@ -29,7 +29,7 @@ int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
  }
 #endif
  for (; offset < ele_num; offset++) {
-    dst[offset] = src[offset] < 0 ? 0 : src[offset];
+    dst[offset] = src[offset] < 0.0f ? 0.0f : src[offset];
  }
  return NNACL_OK;
 }
@@ -47,14 +47,24 @@ int Relu6Fp16(const float16_t *data, float16_t *dst, int ele_num) {
  }
 #endif
  for (; offset < ele_num; offset++) {
-    dst[offset] = data[offset] < 0 ? 0 : data[offset];
-    dst[offset] = dst[offset] > 6 ? 6 : dst[offset];
+    dst[offset] = data[offset] < 0.0f ? 0.0f : data[offset];
+    dst[offset] = dst[offset] > 6.0f ? 6.0f : dst[offset];
  }
  return NNACL_OK;
 }
 
 int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha) {
-  for (int i = 0; i < ele_num; ++i) {
+  int i = 0;
+#ifdef ENABLE_NEON
+  int ele_c8 = ele_num / C8NUM * C8NUM;  // round down; the scalar loop below handles the tail
+  for (; i < ele_c8; i += C8NUM) {
+    float16x8_t src_tmp = vld1q_f16(src + i);
+    float16x8_t mul_tmp = vmulq_n_f16(src_tmp, alpha);
+    uint16x8_t mask = vcgtq_f16(src_tmp, vdupq_n_f16(0.0f));  // vcgtq_f16 yields a uint16x8_t lane mask
+    vst1q_f16(dst + i, vbslq_f16(mask, src_tmp, mul_tmp));
+  }
+#endif
+  for (; i < ele_num; ++i) {
 dst[i] = src[i] > (float16_t)0.0f ?
src[i] : (src[i] * alpha); } return NNACL_OK; @@ -62,12 +72,12 @@ int LReluFp16(const float16_t *src, float16_t *dst, int ele_num, float16_t alpha int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { int i = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON int count = (ele_num / C4NUM) * C4NUM; for (; i < count; i += C4NUM) { float32x4_t tmp; simd_exp(vnegq_f32(vcvt_f32_f16(vld1_f16(src + i))), (float *)&tmp); - vst1_f16(dst + i, vcvt_f16_f32(vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); + vst1_f16(dst + i, vcvt_f16_f32(MS_DIVQ_F32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), tmp)))); } #endif for (; i < ele_num; ++i) { @@ -79,9 +89,9 @@ int SigmoidFp16(const float16_t *src, float16_t *dst, int ele_num) { } float16_t TanhOptFp16(float16_t src) { - if (src > 5.0) { + if (src > 5.0f) { return 1.0f; - } else if (src < -5.0) { + } else if (src < -5.0f) { return -1.0f; } else { float square = src * src; @@ -93,7 +103,7 @@ float16_t TanhOptFp16(float16_t src) { int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { int i = 0; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_NEON static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f}, {17325.0f, 17325.0f, 17325.0f, 17325.0f}, {135135.0f, 135135.0f, 135135.0f, 135135.0f}, @@ -112,7 +122,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { float32x4_t b = vaddq_f32( vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square), paramv[2]); - vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one))); + vst1_f16(dst + i, vcvt_f16_f32(vminq_f32(vmaxq_f32(MS_DIVQ_F32(a, b), neg_one), pos_one))); } #endif for (; i < ele_num; ++i) { @@ -130,7 +140,7 @@ int TanhFp16(const float16_t *src, float16_t *dst, int ele_num) { int HSwishFp16(const float16_t *src, float16_t *dst, int ele_num) { for (int i = 0; i < ele_num; ++i) { float16_t in = src[i]; - float16_t relu6 = MSMIN(MSMAX(in + 3, 0), 6); + float16_t relu6 = MSMIN(MSMAX(in + 3.0f, 0.0f), 6.0f); dst[i] = in * relu6 / (float16_t)6.0f; } return NNACL_OK; @@ -181,26 +191,26 @@ int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate) for (; i < C8; i += C8NUM) { float16x8_t in = vld1q_f16(src + i); float16x8_t res = - 0.5 * in * (1.0 + MS_TANHX8_F16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * in * in) * in)); + 0.5f * in * (1.0f + MS_TANHX8_F16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * in * in) * in)); vst1q_f16(dst + i, res); } #endif for (; i < length; i++) { dst[i] = - 0.5 * src[i] * - (1.0 + TanhOptFp16(((float16_t)0.79788456080287 + (float16_t)0.035677408136 * src[i] * src[i]) * src[i])); + 0.5f * src[i] * + (1.0f + TanhOptFp16(((float16_t)0.79788456080287f + (float16_t)0.035677408136f * src[i] * src[i]) * src[i])); } } else { #ifdef ENABLE_NEON int C8 = UP_ROUND(length, C8NUM); for (; i < C8; i += C8NUM) { float16x8_t in = vld1q_f16(src + i); - const float16x8_t res = 0.5 * in * (1.0 + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); + const float16x8_t res = 0.5f * in * (1.0f + MS_ERFX8_F16(in / (float16_t)1.4142135623730951f)); vst1q_f16(dst + i, res); } #endif for (; i < length; i++) { - dst[i] = 0.5 * src[i] * (1.0 + erf(src[i] / 1.4142135623730951f)); + dst[i] = 0.5f * src[i] * (1.0f + erff(src[i] / 1.4142135623730951f)); } } return NNACL_OK; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.h 
index e2879626023..d97454b2d28 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/activation_fp16.h @@ -16,11 +16,9 @@ #ifndef MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ #define MINDSPORE_NNACL_FP16_ACTIVATION_FP16_H_ -#ifdef ENABLE_NEON -#include -#endif #include #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/int8/fixed_point.h" #ifdef __cplusplus diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.c index c8aa9b92b4b..31b9ac36efe 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.c @@ -569,7 +569,7 @@ int ElementDivFp16(const float16_t *input0, const float16_t *input1, float16_t * for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vdivq_f16(vin0, vin1); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); vst1q_f16(output + index, vout); } #endif @@ -591,7 +591,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_ #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vdivq_f16(vin0_opt, vin1); + float16x8_t vout = MS_DIVQ_F16(vin0_opt, vin1); vst1q_f16(output + index, vout); } #endif @@ -606,7 +606,7 @@ int ElementOptDivFp16(const float16_t *input0, const float16_t *input1, float16_ #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vdivq_f16(vin0, vin1_opt); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1_opt); vst1q_f16(output + index, vout); } #endif @@ -624,7 +624,7 @@ int ElementDivReluFp16(const float16_t *input0, const float16_t *input1, float16 for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vdivq_f16(vin0, vin1); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); vout = vmaxq_f16(vout, zeros); vst1q_f16(output + index, vout); } @@ -652,7 +652,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros); vst1q_f16(output + index, vout); } #endif @@ -670,7 +670,7 @@ int ElementOptDivReluFp16(const float16_t *input0, const float16_t *input1, floa #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros); + float16x8_t vout = vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros); vst1q_f16(output + index, vout); } #endif @@ -689,7 +689,7 @@ int ElementDivRelu6Fp16(const float16_t *input0, const float16_t *input1, float1 for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vdivq_f16(vin0, vin1); + float16x8_t vout = MS_DIVQ_F16(vin0, vin1); vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); vst1q_f16(output + index, vout); 
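 // NOTE: MS_DIVQ_F16 (added by this patch in ms_simd_instructions_fp16.h) stands in
 // for vdivq_f16, which only exists on aarch64; on the armv8.2 aarch32 build it is
 // presumably implemented as a vrecpeq_f16/vrecpsq_f16 reciprocal-refinement
 // sequence, since aarch32 NEON has no vector fp16 divide instruction.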
} @@ -716,7 +716,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin1 = vld1q_f16(input1 + index); - float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0_opt, vin1), zeros), bounds); + float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0_opt, vin1), zeros), bounds); vst1q_f16(output + index, vout); } #endif @@ -733,7 +733,7 @@ int ElementOptDivRelu6Fp16(const float16_t *input0, const float16_t *input1, flo #ifdef ENABLE_NEON for (; index <= element_size - 8; index += C8NUM) { float16x8_t vin0 = vld1q_f16(input0 + index); - float16x8_t vout = vminq_f16(vmaxq_f16(vdivq_f16(vin0, vin1_opt), zeros), bounds); + float16x8_t vout = vminq_f16(vmaxq_f16(MS_DIVQ_F16(vin0, vin1_opt), zeros), bounds); vst1q_f16(output + index, vout); } #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.h index e500aaccd45..813e48c7079 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_fp16.h @@ -16,10 +16,8 @@ #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ #define MINDSPORE_NNACL_FP16_ARITHMETIC_FP16_H_ -#ifdef ENABLE_NEON -#include -#endif #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/base/arithmetic_base.h" #include "nnacl/errorcode.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.c index 044da62500a..be3c5f0b0be 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.c @@ -80,7 +80,7 @@ int ElementLogicalNotFp16(float16_t *input, float16_t *output, int element_size) int ElementRoundFp16(float16_t *input, float16_t *output, int element_size) { for (int i = 0; i < element_size; i++) { - output[i] = round(input[i]); + output[i] = roundf(input[i]); } return NNACL_OK; } @@ -94,7 +94,7 @@ int ElementFloorFp16(float16_t *input, float16_t *output, int element_size) { int ElementCeilFp16(float16_t *input, float16_t *output, int number) { for (int i = 0; i < number; ++i) { - output[i] = ceil(input[i]); + output[i] = ceilf(input[i]); } return NNACL_OK; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.h index 3b94e104066..58ad411aa29 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/arithmetic_self_fp16.h @@ -16,10 +16,8 @@ #ifndef MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ #define MINDSPORE_NNACL_FP16_ARITHMETIC_SELF_FP16_H_ -#ifdef ENABLE_NEON -#include -#endif #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/errorcode.h" #ifdef __cplusplus diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.c index bae64529bcc..05f0b8078e3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.c @@ -26,7 +26,7 @@ void 
BatchNormFp16(const float16_t *input, const void *mean, const void *varianc for (int i = 0; i < cur_unit; i++) { for (int c = 0; c < param->channel_; c++) { - float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_); + float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_); if (variance_sqrt != 0) { output[cur_offset + c] = (input[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; } @@ -44,7 +44,7 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset for (int i = 0; i < cur_unit; i++) { for (int c = 0; c < param->channel_; c++) { - float16_t variance_sqrt = sqrt(((const float16_t *)variance)[c] + param->epsilon_); + float16_t variance_sqrt = sqrtf(((const float16_t *)variance)[c] + param->epsilon_); if (variance_sqrt != 0) { float16_t norm_val = (((const float16_t *)input)[cur_offset + c] - ((const float16_t *)mean)[c]) / variance_sqrt; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.h index 8f6d6aa485f..e6f459cb572 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/batchnorm_fp16.h @@ -13,12 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ +#ifndef MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ +#define MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ -#ifdef ENABLE_NEON -#include -#endif #include "nnacl/batchnorm_parameter.h" #ifdef __cplusplus @@ -34,4 +31,4 @@ void FusedBatchNormFp16(const void *input, const void *scale, const void *offset } #endif -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_BATCHNORM_FP16_H_ +#endif // MINDSPORE_NNACL_FP16_BATCHNORM_FP16_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/cast_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/cast_fp16.h index d7136cd1ae5..da818ed4f8a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/cast_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/cast_fp16.h @@ -16,7 +16,6 @@ #ifndef MINDSPORE_NNACL_CAST_FP16_H_ #define MINDSPORE_NNACL_CAST_FP16_H_ -#include #include "nnacl/op_base.h" #ifdef __cplusplus diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.c index 621de2c77bf..9a66f52e5de 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.c @@ -56,3 +56,17 @@ void PostConvFuncFp16C4(const float16_t *c4_out, float16_t *nhwc_out, const floa PostFuncBiasReluC4Fp16(nhwc_out, c4_out, bias, oc4div, oc4mod, plane, stride_size, act_type); return; } + +#ifdef ENABLE_ARM82_A32 +void PostFuncBiasReluC4Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc4div, size_t oc4mod, + size_t plane_size, size_t plane_stride, size_t relu_type) { + // TODO(fun): function + return; +} + +void PostFuncBiasReluC8Fp16(float16_t *dst, const float16_t *src, const float16_t *bias, size_t oc8div, size_t oc8mod, + size_t plane_size, size_t stride, size_t relu_type) { + // TODO(fun): function + return; +} +#endif diff --git 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.h index b6abd3e2677..01d0a7b456b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/common_func_fp16.h @@ -16,7 +16,6 @@ #ifndef MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ #define MINDSPORE_NNACL_FP16_COMMON_FUNC_FP16_H_ -#include #include "nnacl/op_base.h" #ifdef __cplusplus diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/constant_of_shape_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/constant_of_shape_fp16.h index 6c42506a467..bba9538ef30 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/constant_of_shape_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/constant_of_shape_fp16.h @@ -16,9 +16,6 @@ #ifndef MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ #define MINDSPORE_NNACL_FP16_CONSTANT_OF_SHAPE_FP16_H_ -#ifdef ENABLE_NEON -#include -#endif #include "nnacl/op_base.h" #include "nnacl/errorcode.h" #include "nnacl/constant_of_shape_parameter.h" @@ -27,7 +24,7 @@ extern "C" { #endif #ifdef __cplusplus -#ifdef ENABLE_NEON +#ifdef ENABLE_FP16 inline int ConstantOfShapeFp16(float16_t *output, int start, int end, float16_t value) { for (int i = start; i < end; i++) { output[i] = value; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c index f4a2e740078..f7176e51d6b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.c @@ -18,6 +18,18 @@ #include #include "nnacl/fp16/activation_fp16.h" +#ifdef ENABLE_ARM82_A32 +void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *weight_ptr, size_t num_pixels, + size_t output_channel, size_t input_step) { + for (int i = 0; i < num_pixels; i++) { + for (int c = 0; c < output_channel; c++) { + *output_ptr++ += weight_ptr[c] * input_ptr[c]; + } + input_ptr += input_step; + } +} +#endif + void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float16_t *weight_data, const float16_t *bias_data, const ConvParameter *conv_param, int task_id) { int h_step = UP_DIV(conv_param->output_h_, conv_param->thread_num_); @@ -57,7 +69,6 @@ void ConvDwFp16(float16_t *output_data, const float16_t *input_data, const float const float16_t *src_kw = src_kh + iw_origin * conv_param->input_channel_; int num_pixels = out_w_end - out_w_start; - ConvDwFp16Row(dst_w, src_kw, weight_kh, num_pixels, conv_param->output_channel_, in_sw_step); weight_kh += conv_param->output_channel_; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.h index 80c53471860..e5b781dfd69 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_depthwise_fp16.h @@ -23,9 +23,9 @@ #ifdef __cplusplus extern "C" { #endif -#ifdef ENABLE_ARM64 void ConvDwFp16Row(float16_t *output_ptr, const float16_t *input_ptr, const float16_t *filter_ptr, size_t num_pixels, size_t input_channel, size_t input_step); +#ifdef ENABLE_ARM64 void ConvDwFp16Border(float16_t *dst, const float16_t *src, const float16_t *weight, const 
float16_t *bias, size_t height, size_t width, size_t in_kh_step, size_t in_kw_step, size_t kernel_w, size_t relu, size_t relu6); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c index 069fe5139f9..df9524c5c98 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.c @@ -31,7 +31,7 @@ void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weigh #ifdef __cplusplus } #endif -#ifndef ENABLE_NEON +#ifndef ENABLE_ARM64 void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, size_t ic4, size_t out_channel, size_t offset, size_t mode, size_t writeC8, size_t relu, size_t relu6) { @@ -124,7 +124,11 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) { +#ifdef ENABLE_ARM64 const int tile_n = 16; +#else + const int tile_n = 12; +#endif int out_channel = conv_param->output_channel_; int output_count = conv_param->output_h_ * conv_param->output_w_; int output_tile_count = UP_DIV(output_count, tile_n); @@ -144,7 +148,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ Im2ColPackUnitFp16(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); int out_offset = thread_id * tile_n * out_channel + out_batch_offset; +#ifdef ENABLE_ARM64 RowMajor2Col16MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); +#else + RowMajor2Col12MajorFp16Opt(gemm_input, col_major_gemm_input, tile_n, deep); +#endif MatMulFp16(col_major_gemm_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep, real_cal_num, out_channel, out_channel, OutType_Nhwc); } @@ -155,7 +163,11 @@ void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func) { +#ifdef ENABLE_ARM64 const int tile_num = 16; +#else + const int tile_num = 12; +#endif int in_channel = conv_param->input_channel_; int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); @@ -194,7 +206,11 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; for (int i = 0; i < input_unit_square; ++i) { +#ifdef ENABLE_ARM64 RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#else + RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); +#endif MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h 
index 8ecde2bc6fe..971044769ad 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/conv_fp16.h @@ -24,7 +24,7 @@ typedef float16_t *TmpBufferAddressFp16; typedef float16_t *MatricesFp16; -#ifndef ENABLE_NEON +#ifndef ENABLE_ARM64 void IndirectGemmFp16_16x8(float16_t *output, float16_t *input, float16_t *weight, float16_t *bias, size_t step, size_t ic4, size_t oc8, size_t offset, size_t mode, size_t writeC8, size_t relu, size_t relu6); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/crop_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/crop_fp16.h index 6efa3a98375..2bae96ca4f4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/crop_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/crop_fp16.h @@ -17,7 +17,6 @@ #ifndef MINDSPORE_NNACL_FP16_CROP_FP16_H_ #define MINDSPORE_NNACL_FP16_CROP_FP16_H_ -#include #include "nnacl/op_base.h" #include "nnacl/crop_parameter.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_fp16.c index d29dc2b2bfd..4ef8f232357 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_fp16.c @@ -53,11 +53,7 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; float16_t *tmp_dst = dst_ptr + dst_index; const float16_t *tmp_src = src_ptr + src_index; -#ifdef DEBUG_CODE - for (int i = 0; i < C8NUM; i++) { - tmp_dst[i] += tmp_src[i]; - } -#else +#ifdef ENABLE_ARM64 asm volatile( "mov x0, %[tmp_src] \n" "mov x1, %[tmp_dst] \n" @@ -72,6 +68,10 @@ int DeConvPostFp16(const float16_t *src, float16_t *tmp, const float16_t *bias, : : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) : "x0", "x1", "v0", "v1"); +#else + for (int i = 0; i < C8NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } #endif } /*kw*/ } /*kh*/ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_winograd_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_winograd_fp16.c index 98db40619ee..f8003ec939f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_winograd_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/deconv_winograd_fp16.c @@ -47,6 +47,7 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t cuont8 = count / C8NUM * C8NUM; int i = 0; for (; i < cuont8; i += C8NUM) { +#ifdef ENABLE_ARM64 size_t src_step = src_stride * sizeof(float16_t); size_t dst_step = dst_stride * sizeof(float16_t); asm volatile( @@ -93,7 +94,9 @@ void DeConvWgMergeFp16(const float16_t *src, float16_t *dst, size_t src_stride, : : [ src_ptr ] "r"(src_ptr), [ dst_ptr ] "r"(dst_ptr), [ src_step ] "r"(src_step), [ dst_step ] "r"(dst_step) : "x7", "x8", "x10", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); - +#else + // TODO(fun): arm32 +#endif src_ptr += C8NUM * src_stride; dst_ptr += C8NUM * dst_stride; } @@ -373,3 +376,23 @@ void DeconvWgPostFp16(float16_t *tile_out, float16_t *nc4hw4_output, ConvParamet } return; } + +#ifdef ENABLE_ARM82_A32 +void WinogradTransLeftFp16(const float16_t *S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + // TODO(fun): function + return; +} + +void WinogradTransRightFp16(const float16_t 
*S, const float16_t *B, float16_t *M, size_t w, size_t h, size_t k, + size_t length) { + // TODO(fun): function + return; +} + +void TiledC4MatmulFp16(float16_t *dst, const float16_t *src, const float16_t *weight, size_t ic4, size_t cal_num, + size_t oc4) { + // TODO(fun): function + return; +} +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h index 8bc71c38054..9607239bda4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/exp_fp16.h @@ -18,6 +18,7 @@ #define MINDSPORE_NNACL_FP16_EXP_H_ #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c index c5d286c01b3..75ca35b3d0b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.c @@ -16,6 +16,7 @@ #include "nnacl/fp16/instance_norm_fp16.h" #include #include "nnacl/errorcode.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" int InstanceNormFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *gamma_data, const float16_t *beta_data, const InstanceNormParameter *param, size_t task_id) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.h index e22bd885197..5b743f2d74e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/instance_norm_fp16.h @@ -17,7 +17,6 @@ #define MINDSPORE_NNACL_FP16_INSTANCE_NORM_H_ #include "nnacl/instance_norm_parameter.h" - #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/lstm_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/lstm_fp16.c index a1a2e7bf5cd..30f122f0cca 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/lstm_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/lstm_fp16.c @@ -21,6 +21,7 @@ #include "nnacl/fp16/arithmetic_fp16.h" #include "nnacl/fp16/matmul_fp16.h" #include "nnacl/fp16/cast_fp16.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" void PackLstmWeightFp32ToFp16(float16_t *dst, const float *src, int batch, int deep, int col, int col_align) { for (int i = 0; i < batch; i++) { @@ -121,7 +122,7 @@ int ElementOptMulAccFp16(const float16_t *input0, const float16_t input1, float1 for (; index <= element_size - 8; index += 8) { float16x8_t vin0 = vld1q_f16(input0 + index); float16x8_t vout = vld1q_f16(output + index); - vout = vfmaq_n_f16(vout, vin0, input1); + vout = MS_FMAQ_N_F16(vout, vin0, input1); vst1q_f16(output + index, vout); } for (; index < element_size; index++) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c index 982fc392320..8a70d04f72d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.c @@ -226,24 +226,43 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, return; } -void 
MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, - int deep, int row, int col, int stride, bool write_nhwc) { - if (write_nhwc) { - /* col16-major * row8-major => col-major */ +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { for (int r = 0; r < row; r++) { for (int c = 0; c < col; c++) { + int r16div = r / 16, r16mod = r % 16; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r16div * deep * 16 + d * 16 + r16mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { + int col_8 = UP_ROUND(col, C8NUM); + int row_16 = UP_ROUND(row, C16NUM); + for (int r = 0; r < row_16; r++) { + for (int c = 0; c < col_8; c++) { int r16div = r / C16NUM, r16mod = r % C16NUM; int c8div = c / C8NUM, c8mod = c % C8NUM; - size_t ci = r * stride + c; - float value = 0; + size_t ci = (c8div * C8NUM * row_16 + r * C8NUM + c8mod); + float16_t value = 0; for (int d = 0; d < deep; d++) { size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; value = value + a[ai] * b[bi]; } - if (bias != NULL) value += bias[c]; - if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); - if (act_type != ActType_No) value = MSMAX(0.0f, value); + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) dst[ci] = value; } } @@ -254,37 +273,119 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl for (int j = 0; j < col; ++j) { int c8div = j / 8, c8mod = j % 8; size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; - float value = 0; + float16_t value = 0; for (int d = 0; d < deep; ++d) { size_t ai = src_r_offset + d * C16NUM; size_t bi = c8div * deep * 8 + d * 8 + c8mod; value = value + a[ai] * b[bi]; } - if (bias != NULL) value += bias[j]; - if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); - if (act_type != ActType_No) value = MSMAX(0.0f, value); + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } +} + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode) { + if (write_mode == OutType_Nhwc) { + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r12div = r / 12, r12mod = r % 12; + int c8div = c / 8, c8mod = c % 8; + size_t ci = r * stride + c; + float16_t value = 0; + for (int d = 0; d < deep; d++) { + size_t ai = r12div * deep * 12 + d * 12 + r12mod; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else if (write_mode == OutType_C8) { + int col_8 = UP_ROUND(col, C8NUM); + int row_12 = UP_ROUND(row, C12NUM); + for (int r = 0; r < row_12; r++) { + for (int c = 0; c < col_8; c++) { + int r12div = r / C12NUM, r12mod = r % C12NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = (c8div * C8NUM * row_12 + r * C8NUM + c8mod); + float16_t value = 0; + for (int d = 0; d < deep; d++) { + 
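// Layout sketch (assumption: A was packed col12-major by
// RowMajor2Col12MajorFp16Opt and B col8-major, as the c8div/c8mod arithmetic
// implies): element (r, d) of A sits in panel r12div at offset
// d * C12NUM + r12mod, and element (d, c) of B sits in panel c8div at offset
// d * C8NUM + c8mod, so both reads below stay contiguous as d advances.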
size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod; + size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, c) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) + dst[ci] = value; + } + } + } else { + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float16_t value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C12NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; + value = value + a[ai] * b[bi]; + } + ADD_BIAS(value, bias, j) + DO_RELU(value, act_type) + DO_RELU6(value, act_type) dst[ci] = value; } } } - return; } void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, int depth, int row, int col, int stride, int out_type) { if (out_type == OutType_C8) { +#ifdef ENABLE_ARM64 MatmulFp16Neon64(a, b, c, bias, (int)act_type, depth, row, col, stride, false); +#else + MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif } else { +#ifdef ENABLE_ARM64 MatmulFp16Neon64Opt(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#else + MatMul12x8A32Fp16(a, b, c, bias, (int)act_type, depth, row, col, stride, out_type); +#endif } return; } +#ifdef ENABLE_ARM82_A32 +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col) { + // TODO(fun): function + return; +} +#endif + void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, int depth, int col) { +#ifdef ENABLE_ARM64 MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); +#else + MatVecMulA32Fp16(a, b, c, bias, (int)act_type, depth, col); +#endif } +#ifdef ENABLE_ARM64 static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { size_t stride = col * 2; asm volatile( @@ -392,6 +493,7 @@ static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_ "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); } +#endif void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { size_t row_up_16 = UP_ROUND(row, C16NUM); @@ -442,6 +544,54 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si return; } +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + size_t row_up_12 = UP_ROUND(row, C12NUM); + size_t row12 = row / C12NUM * C12NUM; + size_t col8 = col / C8NUM * C8NUM; + const float16_t *src_r = src_ptr; + float16_t *dst_r = dst_ptr; + size_t ri = 0; + // transpose 12x8 + for (; ri < row12; ri += C12NUM) { + size_t ci = 0; + for (; ci < col8; ci += C8NUM) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; +#ifdef ENABLE_ARM82_A32 + Transpose12x8A32Fp16(src_c, dst_c, col * sizeof(float16_t), 24); +#else + for (int tr = 0; tr < C12NUM; tr++) { + for (int tc = 0; tc < C8NUM; tc++) { + dst_c[tc * C12NUM + tr] = src_c[tr * col + tc]; + } + } +#endif + } + for (; ci < col; ci++) { + const float16_t *src_c = src_r + ci; + float16_t *dst_c = dst_r + ci * C12NUM; + for (size_t i = 0; i < C12NUM; i++) { + dst_c[i] = src_c[i * col]; + } + } + src_r += C12NUM * col; + dst_r += C12NUM * col; 
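// State sketch: at this point one full 12-row stripe has been repacked into
// col12-major panels (C12NUM * col values per stripe); the remainder loops
// that follow copy the last row % C12NUM rows element-by-element and
// zero-fill up to row_up_12 so the 12x8 kernel always loads complete panels.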
+ } + for (; ri < row; ri++) { + for (size_t i = 0; i < col; ++i) { + dst_r[i * C12NUM] = src_r[i]; + } + src_r += col; + dst_r += 1; + } + for (; ri < row_up_12; ri++) { + for (size_t i = 0; i < col; i++) { + dst_r[i * C12NUM] = 0; + } + dst_r += 1; + } +} + void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) { if (is_fp32_src) { const float *fp32_src = (const float *)src; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.h index 113553ef597..12a564a6774 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matmul_fp16.h @@ -19,18 +19,51 @@ #include #include -#ifdef ENABLE_NEON +#ifdef ENABLE_ARM64 #include <arm_neon.h> #endif #include "nnacl/errorcode.h" #include "nnacl/matmul_parameter.h" #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" +#include "nnacl/fp16/pack_fp16.h" + +#define ADD_BIAS(value, bias, c) \ + if (bias != NULL) value = value + bias[c]; + +#define DO_RELU(value, act_type) \ + if (act_type == ActType_Relu) value = MSMAX(0.0f, value); + +#define DO_RELU6(value, act_type) \ + if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); \ + if (act_type == ActType_Relu6) value = MSMAX(0.0f, value); #ifdef __cplusplus extern "C" { #endif -void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, - int deep, int row, int col, int stride, bool write_nhwc); +void MatMul16x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatMul12x8Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +#ifdef ENABLE_ARM64 +void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); + +void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + size_t depth, size_t row, size_t col, size_t stride, size_t write_nhwc); + +void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); + +#elif ENABLE_ARM82_A32 +void MatMul12x8A32Fp16(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, int write_mode); + +void MatVecMulA32Fp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, + int depth, int col); +#endif void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, int depth, int row, int col, int stride, int out_type); @@ -42,14 +75,7 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); -void MatmulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, - size_t depth, size_t row, size_t col, size_t stride, bool write_nhwc); - -void MatmulFp16Neon64Opt(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, - size_t depth, size_t
row, size_t col, size_t stride, size_t write_nhwc); - -void MatVecMulFp16Neon64(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, int act_type, - int depth, int col); +void RowMajor2Col12MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col); void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matrix_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matrix_fp16.c index ce8acfba751..a716dd974ec 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matrix_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/matrix_fp16.c @@ -39,7 +39,7 @@ void MatrixMultiplyWinogradFp16(const float16_t *matix_a, const float16_t *matri for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { for (int y = 0; y < in_channel; ++y) { - float16_t tmp = 0; + float tmp = 0; for (int z = 0; z < k; ++z) { tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.c index aecc351ec07..5cecafc98e3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.c @@ -160,129 +160,35 @@ void PackNCHWToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int } void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int channel) { - int hw16 = plane / C16NUM * C16NUM; +#ifdef ENABLE_ARM64 + // Transpose16x8 in arm64 + const int hw_tile = C16NUM; +#else + // Transpose8x8 in others + const int hw_tile = C8NUM; +#endif + int hw_align = plane / hw_tile * hw_tile; int c8 = channel / C8NUM * C8NUM; int batch = plane * channel; for (int n = 0; n < batches; n++) { const float16_t *src_batch = (const float16_t *)src + n * batch; float16_t *dst_batch = (float16_t *)dst + n * batch; int hw = 0; - for (; hw < hw16; hw += C16NUM) { + for (; hw < hw_align; hw += hw_tile) { int c = 0; for (; c < c8; c += C8NUM) { const float16_t *src_ptr = src_batch + hw * channel + c; float16_t *dst_ptr = dst_batch + c * plane + hw; #ifdef ENABLE_ARM64 - size_t srcStride = channel * sizeof(float16_t); - size_t dstStride = plane * sizeof(float16_t); - asm volatile( - "mov x10, %[src_ptr]\n" - "mov x11, %[dst_ptr]\n" - - "ld1 {v0.8h}, [x10], %[srcStride]\n" - "ld1 {v1.8h}, [x10], %[srcStride]\n" - "ld1 {v2.8h}, [x10], %[srcStride]\n" - "ld1 {v3.8h}, [x10], %[srcStride]\n" - "ld1 {v4.8h}, [x10], %[srcStride]\n" - "ld1 {v5.8h}, [x10], %[srcStride]\n" - "ld1 {v6.8h}, [x10], %[srcStride]\n" - "ld1 {v7.8h}, [x10], %[srcStride]\n" - - "zip1 v16.8h, v0.8h, v1.8h\n" - "zip1 v17.8h, v2.8h, v3.8h\n" - "zip1 v18.8h, v4.8h, v5.8h\n" - "zip1 v19.8h, v6.8h, v7.8h\n" - - "ld1 {v8.8h}, [x10], %[srcStride]\n" - "ld1 {v9.8h}, [x10], %[srcStride]\n" - "ld1 {v10.8h}, [x10], %[srcStride]\n" - "ld1 {v11.8h}, [x10], %[srcStride]\n" - "ld1 {v12.8h}, [x10], %[srcStride]\n" - "ld1 {v13.8h}, [x10], %[srcStride]\n" - "ld1 {v14.8h}, [x10], %[srcStride]\n" - "ld1 {v15.8h}, [x10], %[srcStride]\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v24.2d, v20.2d, v22.2d\n" - "trn2 v25.2d, v20.2d, v22.2d\n" - "trn1 v26.2d, v21.2d, v23.2d\n" - "trn2 v27.2d, v21.2d, v23.2d\n" - - "zip1 v16.8h, v8.8h, 
v9.8h\n" - "zip1 v17.8h, v10.8h, v11.8h\n" - "zip1 v18.8h, v12.8h, v13.8h\n" - "zip1 v19.8h, v14.8h, v15.8h\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v28.2d, v20.2d, v22.2d\n" - "trn2 v29.2d, v20.2d, v22.2d\n" - "trn1 v30.2d, v21.2d, v23.2d\n" - "trn2 v31.2d, v21.2d, v23.2d\n" - - "add x10, x11, #16\n" - "st1 {v24.8h}, [x11], %[dstStride]\n" - "st1 {v28.8h}, [x10], %[dstStride]\n" - "st1 {v26.8h}, [x11], %[dstStride]\n" - "st1 {v30.8h}, [x10], %[dstStride]\n" - "st1 {v25.8h}, [x11], %[dstStride]\n" - "st1 {v29.8h}, [x10], %[dstStride]\n" - "st1 {v27.8h}, [x11], %[dstStride]\n" - "st1 {v31.8h}, [x10], %[dstStride]\n" - - "zip2 v16.8h, v0.8h, v1.8h\n" - "zip2 v17.8h, v2.8h, v3.8h\n" - "zip2 v18.8h, v4.8h, v5.8h\n" - "zip2 v19.8h, v6.8h, v7.8h\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v24.2d, v20.2d, v22.2d\n" - "trn2 v25.2d, v20.2d, v22.2d\n" - "trn1 v26.2d, v21.2d, v23.2d\n" - "trn2 v27.2d, v21.2d, v23.2d\n" - - "zip2 v16.8h, v8.8h, v9.8h\n" - "zip2 v17.8h, v10.8h, v11.8h\n" - "zip2 v18.8h, v12.8h, v13.8h\n" - "zip2 v19.8h, v14.8h, v15.8h\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v28.2d, v20.2d, v22.2d\n" - "trn2 v29.2d, v20.2d, v22.2d\n" - "trn1 v30.2d, v21.2d, v23.2d\n" - "trn2 v31.2d, v21.2d, v23.2d\n" - - "st1 {v24.8h}, [x11], %[dstStride]\n" - "st1 {v28.8h}, [x10], %[dstStride]\n" - "st1 {v26.8h}, [x11], %[dstStride]\n" - "st1 {v30.8h}, [x10], %[dstStride]\n" - "st1 {v25.8h}, [x11], %[dstStride]\n" - "st1 {v29.8h}, [x10], %[dstStride]\n" - "st1 {v27.8h}, [x11], %[dstStride]\n" - "st1 {v31.8h}, [x10], %[dstStride]\n" - : - : - [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride) - : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31"); + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose16x8ARM64Fp16(src_ptr, dst_ptr, src_stride, dst_stride); +#elif defined(ENABLE_ARM82_A32) + size_t src_stride = channel * sizeof(float16_t); + size_t dst_stride = plane * sizeof(float16_t); + Transpose8x8A32Fp16(src_ptr, dst_ptr, src_stride, dst_stride); #else - for (int tr = 0; tr < C16NUM; tr++) { + for (int tr = 0; tr < hw_tile; tr++) { for (int tc = 0; tc < C8NUM; tc++) { dst_ptr[tc * plane + tr] = src_ptr[tr * channel + tc]; } @@ -292,7 +198,7 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int for (; c < channel; c++) { const float16_t *src_ptr = src_batch + hw * channel + c; float16_t *dst_ptr = dst_batch + c * plane + hw; - for (size_t i = 0; i < C16NUM; i++) { + for (size_t i = 0; i < hw_tile; i++) { dst_ptr[i] = src_ptr[i * channel]; } } @@ -305,7 +211,6 @@ void PackNHWCToNCHWFp16(const void *src, void *dst, int batches, int plane, int } } } - return; } void PackNCHWToNHWCFp16(const void *src, void *dst, int batch, int plane, int channel) { @@ -565,3 +470,246 @@ void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, i } } } + +#ifdef ENABLE_ARM82_A32 +inline void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t 
src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src]\n" + "mov r12, %[dst]\n" + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, d28\n" + + "vst1.16 {q0}, [r12], %[dst_stride]\n" + "vst1.16 {q2}, [r12], %[dst_stride]\n" + "vst1.16 {q4}, [r12], %[dst_stride]\n" + "vst1.16 {q6}, [r12], %[dst_stride]\n" + + "vst1.16 {q8}, [r12], %[dst_stride]\n" + "vst1.16 {q10}, [r12], %[dst_stride]\n" + "vst1.16 {q12}, [r12], %[dst_stride]\n" + "vst1.16 {q14}, [r12], %[dst_stride]\n" + + : + : [ dst ] "r"(dst), [ src ] "r"(src), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} + +inline void Transpose12x8A32Fp16(const float16_t *src_c, float16_t *dst_c, size_t src_stride, size_t dst_stride) { + asm volatile( + "mov r10, %[src_c]\n" + "mov r12, %[dst_c]\n" + + "vld1.16 {q0}, [r10], %[src_stride]\n" + "vld1.16 {q2}, [r10], %[src_stride]\n" + "vld1.16 {q4}, [r10], %[src_stride]\n" + "vld1.16 {q6}, [r10], %[src_stride]\n" + + "vtrn.16 d0, d4\n" + "vtrn.16 d1, d5\n" + "vtrn.16 d8, d12\n" + "vtrn.16 d9, d13\n" + + "vld1.16 {q8}, [r10], %[src_stride]\n" + "vld1.16 {q10}, [r10], %[src_stride]\n" + "vld1.16 {q12}, [r10], %[src_stride]\n" + "vld1.16 {q14}, [r10], %[src_stride]\n" + + "vtrn.32 d0, d8\n" + "vtrn.32 d4, d12\n" + "vtrn.32 d1, d9\n" + "vtrn.32 d5, d13\n" + + "vtrn.16 d16, d20\n" + "vtrn.16 d17, d21\n" + "vtrn.16 d24, d28\n" + "vtrn.16 d25, d29\n" + + "vld1.16 {q1}, [r10], %[src_stride]\n" + "vld1.16 {q3}, [r10], %[src_stride]\n" + "vld1.16 {q5}, [r10], %[src_stride]\n" + "vld1.16 {q7}, [r10], %[src_stride]\n" + + "vtrn.32 d16, d24\n" + "vtrn.32 d20, d28\n" + "vtrn.32 d17, d25\n" + "vtrn.32 d21, d29\n" + + "vswp d1, d16\n" + "vswp d5, d20\n" + "vswp d9, d24\n" + "vswp d13, d28\n" + + "vtrn.16 d2, d6\n" + "vtrn.16 d3, d7\n" + "vtrn.16 d10, d14\n" + "vtrn.16 d11, d15\n" + + "vtrn.32 d2, d10\n" + "vtrn.32 d6, d14\n" + "vtrn.32 d3, d11\n" + "vtrn.32 d7, d15\n" + + "vst1.16 {q0, d2}, [r12], %[dst_stride]\n" + "vst1.16 {q2, d6}, [r12], %[dst_stride]\n" + "vst1.16 {q4, d10}, [r12], %[dst_stride]\n" + "vst1.16 {q6, d14}, [r12], %[dst_stride]\n" + + "vswp d3, d18\n" + "vswp d7, d22\n" + "vswp d11, d26\n" + "vswp d15, d30\n" + + "vst1.16 {q8, d18}, [r12], %[dst_stride]\n" + "vst1.16 {q10, d22}, [r12], %[dst_stride]\n" + "vst1.16 {q12, d26}, [r12], %[dst_stride]\n" + "vst1.16 {q14, d30}, [r12], %[dst_stride]\n" + + : + : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) + : "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15"); +} +#endif + +#ifdef ENABLE_ARM64 +inline void Transpose16x8ARM64Fp16(const float16_t *src_ptr, float16_t *dst_ptr, 
size_t src_stride, size_t dst_stride) { + asm volatile( + "mov x10, %[src_ptr]\n" + "mov x11, %[dst_ptr]\n" + + "ld1 {v0.8h}, [x10], %[src_stride]\n" + "ld1 {v1.8h}, [x10], %[src_stride]\n" + "ld1 {v2.8h}, [x10], %[src_stride]\n" + "ld1 {v3.8h}, [x10], %[src_stride]\n" + "ld1 {v4.8h}, [x10], %[src_stride]\n" + "ld1 {v5.8h}, [x10], %[src_stride]\n" + "ld1 {v6.8h}, [x10], %[src_stride]\n" + "ld1 {v7.8h}, [x10], %[src_stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[src_stride]\n" + "ld1 {v9.8h}, [x10], %[src_stride]\n" + "ld1 {v10.8h}, [x10], %[src_stride]\n" + "ld1 {v11.8h}, [x10], %[src_stride]\n" + "ld1 {v12.8h}, [x10], %[src_stride]\n" + "ld1 {v13.8h}, [x10], %[src_stride]\n" + "ld1 {v14.8h}, [x10], %[src_stride]\n" + "ld1 {v15.8h}, [x10], %[src_stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "add x10, x11, #16\n" + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], %[dst_stride]\n" + "st1 {v28.8h}, [x10], %[dst_stride]\n" + "st1 {v26.8h}, [x11], %[dst_stride]\n" + "st1 {v30.8h}, [x10], %[dst_stride]\n" + "st1 {v25.8h}, [x11], %[dst_stride]\n" + "st1 {v29.8h}, [x10], %[dst_stride]\n" + "st1 {v27.8h}, [x11], %[dst_stride]\n" + "st1 {v31.8h}, [x10], %[dst_stride]\n" + : + : [ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ src_stride ] "r"(src_stride), [ dst_stride ] "r"(dst_stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} +#endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.h 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.h index 5f8f1632c07..9856f9d76b4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pack_fp16.h @@ -17,11 +17,9 @@ #ifndef MINDSPORE_NNACL_FP16_PACK_FP16_H_ #define MINDSPORE_NNACL_FP16_PACK_FP16_H_ -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif #include "nnacl/conv_parameter.h" #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #ifdef __cplusplus extern "C" { @@ -72,6 +70,17 @@ void PackNHWCFp16ToC8HWN8Fp16(float16_t *src, float16_t *dst, int batch, int pla void PackNHWC8Fp16ToNHWCFp32(float16_t *src, float *dst, int batch, int plane, int channel); void PackNHWC8ToNHWCFp16(float16_t *src, float16_t *dst, int batch, int plane, int channel); + +#ifdef ENABLE_ARM82_A32 +void Transpose8x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); + +void Transpose12x8A32Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + +#ifdef ENABLE_ARM64 +void Transpose16x8ARM64Fp16(const float16_t *src, float16_t *dst, size_t src_stride, size_t dst_stride); +#endif + #ifdef __cplusplus } #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pad_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pad_fp16.h index 514e9f3fbc7..e41db9528d6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pad_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pad_fp16.h @@ -16,9 +16,6 @@ #ifndef MINDSPORE_NNACL_FP16_PAD_FP16_H_ #define MINDSPORE_NNACL_FP16_PAD_FP16_H_ -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif #include "nnacl/fp32/pad_fp32.h" #ifdef __cplusplus diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pooling_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pooling_fp16.h index f3dec61de72..d20ca72457f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pooling_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/pooling_fp16.h @@ -18,10 +18,8 @@ #define MINDSPORE_NNACL_FP16_POOLING_FP16_H_ #include -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif #include "nnacl/pooling_parameter.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c index 5d23c750319..471221e81bb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.c @@ -17,7 +17,7 @@ #include "nnacl/fp16/power_fp16.h" #include "nnacl/errorcode.h" -#if defined(ENABLE_NEON) +#if defined(ENABLE_ARM64) float16x8_t OptimizedPowerSimdFp16(float16x8_t x, const void *exponent) { int tmp = (int)(*(float16_t *)exponent); int exp = abs(tmp); @@ -53,23 +53,23 @@ float16_t OptimizedPowerScalarFp16(float16_t x, const void *exponent) { void PowerBroadCastFp16(const float16_t *input, const float16_t *exponent, float16_t *output, int len, float scale, float shift) { PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; -#if defined(ENABLE_NEON) +#if defined(ENABLE_ARM64) PowerSimdFunFp16 PowerSimdFunFp16_ = NULL; #endif if (CheckInteger(*exponent)) { -#if defined(ENABLE_NEON) +#if defined(ENABLE_ARM64) PowerSimdFunFp16_ = OptimizedPowerSimdFp16; #endif PowerScalarFunFp16_ = OptimizedPowerScalarFp16; } else { -#if defined(ENABLE_NEON) +#if
defined(ENABLE_ARM64) PowerSimdFunFp16_ = StdPowerSimdFp16; #endif PowerScalarFunFp16_ = StdPowerScalarFp16; } int i = 0; -#ifdef ENABLE_NEON +#ifdef ENABLE_ARM64 int len_c8 = UP_ROUND(len, C8NUM); float16x8_t scale_8 = vmovq_n_f16(scale); float16x8_t shift_8 = vmovq_n_f16(shift); @@ -87,7 +87,7 @@ void PowerSingleFp16(const float16_t *input, const float16_t *exponent, float16_ float shift) { int i = 0; PowerScalarFunFp16 PowerScalarFunFp16_ = NULL; -#ifdef ENABLE_NEON +#ifdef ENABLE_ARM64 int len_c8 = UP_ROUND(len, C8NUM); float16x8_t scale_8 = vmovq_n_f16(scale); float16x8_t shift_8 = vmovq_n_f16(shift); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h index 8a5ad96fb02..206c68afd83 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/power_fp16.h @@ -19,9 +19,10 @@ #include <math.h> #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/power_parameter.h" -#if defined(ENABLE_NEON) +#if defined(ENABLE_ARM64) typedef float16x8_t (*PowerSimdFunFp16)(float16x8_t x, const void *exponent); #endif typedef float16_t (*PowerScalarFunFp16)(float16_t x, const void *exponent); @@ -36,7 +37,7 @@ static inline float16_t StdPowerScalarFp16(float16_t x, const void *exponent) { return powf(x, *(float16_t *)exponent); } -#if defined(ENABLE_NEON) +#if defined(ENABLE_ARM64) static inline float16x8_t StdPowerSimdFp16(float16x8_t x, const void *exponent) { float16x8_t result; result[0] = powf(x[0], *(float16_t *)exponent); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.h index a92025d5c1a..f9a612526b4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/quant_dtype_cast_fp16.h @@ -18,10 +18,7 @@ #define MINDSPORE_NNACL_FP16_QUANTDTYPECAST_FP16_H_ #include "nnacl/op_base.h" - -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #ifdef __cplusplus extern "C" { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h index 442b1c1b64c..c080882c971 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/reduce_fp16.h @@ -19,9 +19,6 @@ #include "nnacl/op_base.h" #include "nnacl/reduce_parameter.h" -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.h index 81da793d6bc..98208d3819e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/scale_fp16.h @@ -18,10 +18,9 @@ #define MINDSPORE_NNACL_SCALE_FP16_H_ #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/scale.h" -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif + #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c index 31c44ad268c..dcc64d9199e 100644 ---
a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.c @@ -40,7 +40,7 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe } } int k = 0; -#ifdef ENABLE_NEON +#ifdef ENABLE_ARM64 int count2 = (channel / C8NUM) * C8NUM; for (; k < count2; k += C8NUM) { float16x8_t input_8 = vld1q_f16(src + cur_batch_offset + k); @@ -58,9 +58,9 @@ void SoftmaxNormFp16(const float16_t *src, float16_t *dst, int batch, int channe void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) { int cur_batch_offset = 0; for (int i = 0; i < batch; i++, cur_batch_offset += channel) { - float16_t sum = 0; + float16_t sum = 0.0f; int j = 0; -#ifdef ENABLE_NEON +#ifdef ENABLE_ARM64 float16x8_t sum8 = vdupq_n_f16(0); int count = (channel / C8NUM) * C8NUM; for (; j < count; j += C8NUM) { @@ -72,7 +72,7 @@ void SumAndDivFp16(const float16_t *src, float16_t *dst, int batch, int channel) sum += src[cur_batch_offset + j]; } int k = 0; -#ifdef ENABLE_NEON +#ifdef ENABLE_ARM64 const float16_t div = 1.0f / sum; for (; k < count; k += C8NUM) { vst1q_f16(dst + cur_batch_offset + k, vmulq_n_f16(vld1q_f16(src + cur_batch_offset + k), div)); @@ -117,7 +117,7 @@ void SoftmaxFp16(const float16_t *input_ptr, float16_t *output_ptr, float16_t *s } for (int j = 0; j < input_shape[axis]; j++) { int axis_offset = inner_offset + j * inner_size; - output_ptr[axis_offset] = exp(input_ptr[axis_offset] - max_data); + output_ptr[axis_offset] = expf(input_ptr[axis_offset] - max_data); sum_data[k + sum_outter_offset] += output_ptr[axis_offset]; } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.h index cdd6ead438b..3de8e7133e7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/softmax_fp16.h @@ -18,10 +18,9 @@ #define MINDSPORE_NNACL_FP16_SOFTMAX_FP16_H_ #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/softmax_parameter.h" -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif + #ifdef __cplusplus extern "C" { #endif diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/transpose_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/transpose_fp16.h index 9c300df2de7..49b5f2db122 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/transpose_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/transpose_fp16.h @@ -18,10 +18,8 @@ #define MINDSPORE_NNACL_FP16_TRANSPOSE_FP16_H_ #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #include "nnacl/transpose.h" -#ifdef ENABLE_NEON -#include <arm_neon.h> -#endif #ifdef __cplusplus extern "C" { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c index 26a6d5775c3..b0aa3383ec9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.c @@ -16,562 +16,15 @@ #include "nnacl/fp16/winograd_transform_fp16.h" -// for fp16 convolution 3x3 filter/input/output transform F(4,3) -void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step) { - float16x8_t d00 = vld1q_f16(tmp_data); - float16x8_t d01 =
vld1q_f16(tmp_data + 8); - float16x8_t d02 = vld1q_f16(tmp_data + 2 * 8); - float16x8_t d03 = vld1q_f16(tmp_data + 3 * 8); - float16x8_t d04 = vld1q_f16(tmp_data + 4 * 8); - float16x8_t d05 = vld1q_f16(tmp_data + 5 * 8); - - float16x8_t d10 = vld1q_f16(tmp_data + 6 * 8); - float16x8_t d11 = vld1q_f16(tmp_data + 7 * 8); - float16x8_t d12 = vld1q_f16(tmp_data + 8 * 8); - float16x8_t d13 = vld1q_f16(tmp_data + 9 * 8); - float16x8_t d14 = vld1q_f16(tmp_data + 10 * 8); - float16x8_t d15 = vld1q_f16(tmp_data + 11 * 8); - - float16x8_t d20 = vld1q_f16(tmp_data + 12 * 8); - float16x8_t d21 = vld1q_f16(tmp_data + 13 * 8); - float16x8_t d22 = vld1q_f16(tmp_data + 14 * 8); - float16x8_t d23 = vld1q_f16(tmp_data + 15 * 8); - float16x8_t d24 = vld1q_f16(tmp_data + 16 * 8); - float16x8_t d25 = vld1q_f16(tmp_data + 17 * 8); - - float16x8_t d30 = vld1q_f16(tmp_data + 18 * 8); - float16x8_t d31 = vld1q_f16(tmp_data + 19 * 8); - float16x8_t d32 = vld1q_f16(tmp_data + 20 * 8); - float16x8_t d33 = vld1q_f16(tmp_data + 21 * 8); - float16x8_t d34 = vld1q_f16(tmp_data + 22 * 8); - float16x8_t d35 = vld1q_f16(tmp_data + 23 * 8); - - float16x8_t d40 = vld1q_f16(tmp_data + 24 * 8); - float16x8_t d41 = vld1q_f16(tmp_data + 25 * 8); - float16x8_t d42 = vld1q_f16(tmp_data + 26 * 8); - float16x8_t d43 = vld1q_f16(tmp_data + 27 * 8); - float16x8_t d44 = vld1q_f16(tmp_data + 28 * 8); - float16x8_t d45 = vld1q_f16(tmp_data + 29 * 8); - - float16x8_t d50 = vld1q_f16(tmp_data + 30 * 8); - float16x8_t d51 = vld1q_f16(tmp_data + 31 * 8); - float16x8_t d52 = vld1q_f16(tmp_data + 32 * 8); - float16x8_t d53 = vld1q_f16(tmp_data + 33 * 8); - float16x8_t d54 = vld1q_f16(tmp_data + 34 * 8); - float16x8_t d55 = vld1q_f16(tmp_data + 35 * 8); - - float16x8_t t00 = vaddq_f16(vsubq_f16(vmulq_n_f16(d00, 4), vmulq_n_f16(d20, 5)), d40); - float16x8_t t01 = vaddq_f16(vsubq_f16(vmulq_n_f16(d01, 4), vmulq_n_f16(d21, 5)), d41); - float16x8_t t02 = vaddq_f16(vsubq_f16(vmulq_n_f16(d02, 4), vmulq_n_f16(d22, 5)), d42); - float16x8_t t03 = vaddq_f16(vsubq_f16(vmulq_n_f16(d03, 4), vmulq_n_f16(d23, 5)), d43); - float16x8_t t04 = vaddq_f16(vsubq_f16(vmulq_n_f16(d04, 4), vmulq_n_f16(d24, 5)), d44); - float16x8_t t05 = vaddq_f16(vsubq_f16(vmulq_n_f16(d05, 4), vmulq_n_f16(d25, 5)), d45); - - float16x8_t t10 = vaddq_f16(vaddq_f16(d30, d40), vmulq_n_f16(vaddq_f16(d10, d20), -4)); - float16x8_t t11 = vaddq_f16(vaddq_f16(d31, d41), vmulq_n_f16(vaddq_f16(d11, d21), -4)); - float16x8_t t12 = vaddq_f16(vaddq_f16(d32, d42), vmulq_n_f16(vaddq_f16(d12, d22), -4)); - float16x8_t t13 = vaddq_f16(vaddq_f16(d33, d43), vmulq_n_f16(vaddq_f16(d13, d23), -4)); - float16x8_t t14 = vaddq_f16(vaddq_f16(d34, d44), vmulq_n_f16(vaddq_f16(d14, d24), -4)); - float16x8_t t15 = vaddq_f16(vaddq_f16(d35, d45), vmulq_n_f16(vaddq_f16(d15, d25), -4)); - - float16x8_t t20 = vaddq_f16(vsubq_f16(d40, d30), vmulq_n_f16(vsubq_f16(d10, d20), 4)); - float16x8_t t21 = vaddq_f16(vsubq_f16(d41, d31), vmulq_n_f16(vsubq_f16(d11, d21), 4)); - float16x8_t t22 = vaddq_f16(vsubq_f16(d42, d32), vmulq_n_f16(vsubq_f16(d12, d22), 4)); - float16x8_t t23 = vaddq_f16(vsubq_f16(d43, d33), vmulq_n_f16(vsubq_f16(d13, d23), 4)); - float16x8_t t24 = vaddq_f16(vsubq_f16(d44, d34), vmulq_n_f16(vsubq_f16(d14, d24), 4)); - float16x8_t t25 = vaddq_f16(vsubq_f16(d45, d35), vmulq_n_f16(vsubq_f16(d15, d25), 4)); - - float16x8_t t30 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d30, d10), 2)); - float16x8_t t31 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d31, d11), 2)); - float16x8_t t32 = 
vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d32, d12), 2)); - float16x8_t t33 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d33, d13), 2)); - float16x8_t t34 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d34, d14), 2)); - float16x8_t t35 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d35, d15), 2)); - - float16x8_t t40 = vaddq_f16(vsubq_f16(d40, d20), vmulq_n_f16(vsubq_f16(d10, d30), 2)); - float16x8_t t41 = vaddq_f16(vsubq_f16(d41, d21), vmulq_n_f16(vsubq_f16(d11, d31), 2)); - float16x8_t t42 = vaddq_f16(vsubq_f16(d42, d22), vmulq_n_f16(vsubq_f16(d12, d32), 2)); - float16x8_t t43 = vaddq_f16(vsubq_f16(d43, d23), vmulq_n_f16(vsubq_f16(d13, d33), 2)); - float16x8_t t44 = vaddq_f16(vsubq_f16(d44, d24), vmulq_n_f16(vsubq_f16(d14, d34), 2)); - float16x8_t t45 = vaddq_f16(vsubq_f16(d45, d25), vmulq_n_f16(vsubq_f16(d15, d35), 2)); - - float16x8_t t50 = vaddq_f16(vsubq_f16(vmulq_n_f16(d10, 4), vmulq_n_f16(d30, 5)), d50); - float16x8_t t51 = vaddq_f16(vsubq_f16(vmulq_n_f16(d11, 4), vmulq_n_f16(d31, 5)), d51); - float16x8_t t52 = vaddq_f16(vsubq_f16(vmulq_n_f16(d12, 4), vmulq_n_f16(d32, 5)), d52); - float16x8_t t53 = vaddq_f16(vsubq_f16(vmulq_n_f16(d13, 4), vmulq_n_f16(d33, 5)), d53); - float16x8_t t54 = vaddq_f16(vsubq_f16(vmulq_n_f16(d14, 4), vmulq_n_f16(d34, 5)), d54); - float16x8_t t55 = vaddq_f16(vsubq_f16(vmulq_n_f16(d15, 4), vmulq_n_f16(d35, 5)), d55); - - float16x8_t m00 = vaddq_f16(vsubq_f16(vmulq_n_f16(t00, 4), vmulq_n_f16(t02, 5)), t04); - float16x8_t m01 = vaddq_f16(vaddq_f16(t03, t04), vmulq_n_f16(vaddq_f16(t01, t02), -4)); - float16x8_t m02 = vaddq_f16(vsubq_f16(t04, t03), vmulq_n_f16(vsubq_f16(t01, t02), 4)); - float16x8_t m03 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t03, t01), 2)); - float16x8_t m04 = vaddq_f16(vsubq_f16(t04, t02), vmulq_n_f16(vsubq_f16(t01, t03), 2)); - float16x8_t m05 = vaddq_f16(vsubq_f16(vmulq_n_f16(t01, 4), vmulq_n_f16(t03, 5)), t05); - - float16x8_t m10 = vaddq_f16(vsubq_f16(vmulq_n_f16(t10, 4), vmulq_n_f16(t12, 5)), t14); - float16x8_t m11 = vaddq_f16(vaddq_f16(t13, t14), vmulq_n_f16(vaddq_f16(t11, t12), -4)); - float16x8_t m12 = vaddq_f16(vsubq_f16(t14, t13), vmulq_n_f16(vsubq_f16(t11, t12), 4)); - float16x8_t m13 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t13, t11), 2)); - float16x8_t m14 = vaddq_f16(vsubq_f16(t14, t12), vmulq_n_f16(vsubq_f16(t11, t13), 2)); - float16x8_t m15 = vaddq_f16(vsubq_f16(vmulq_n_f16(t11, 4), vmulq_n_f16(t13, 5)), t15); - - float16x8_t m20 = vaddq_f16(vsubq_f16(vmulq_n_f16(t20, 4), vmulq_n_f16(t22, 5)), t24); - float16x8_t m21 = vaddq_f16(vaddq_f16(t23, t24), vmulq_n_f16(vaddq_f16(t21, t22), -4)); - float16x8_t m22 = vaddq_f16(vsubq_f16(t24, t23), vmulq_n_f16(vsubq_f16(t21, t22), 4)); - float16x8_t m23 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t23, t21), 2)); - float16x8_t m24 = vaddq_f16(vsubq_f16(t24, t22), vmulq_n_f16(vsubq_f16(t21, t23), 2)); - float16x8_t m25 = vaddq_f16(vsubq_f16(vmulq_n_f16(t21, 4), vmulq_n_f16(t23, 5)), t25); - - float16x8_t m30 = vaddq_f16(vsubq_f16(vmulq_n_f16(t30, 4), vmulq_n_f16(t32, 5)), t34); - float16x8_t m31 = vaddq_f16(vaddq_f16(t33, t34), vmulq_n_f16(vaddq_f16(t31, t32), -4)); - float16x8_t m32 = vaddq_f16(vsubq_f16(t34, t33), vmulq_n_f16(vsubq_f16(t31, t32), 4)); - float16x8_t m33 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t33, t31), 2)); - float16x8_t m34 = vaddq_f16(vsubq_f16(t34, t32), vmulq_n_f16(vsubq_f16(t31, t33), 2)); - float16x8_t m35 = vaddq_f16(vsubq_f16(vmulq_n_f16(t31, 4), vmulq_n_f16(t33, 5)), 
t35); - - float16x8_t m40 = vaddq_f16(vsubq_f16(vmulq_n_f16(t40, 4), vmulq_n_f16(t42, 5)), t44); - float16x8_t m41 = vaddq_f16(vaddq_f16(t43, t44), vmulq_n_f16(vaddq_f16(t41, t42), -4)); - float16x8_t m42 = vaddq_f16(vsubq_f16(t44, t43), vmulq_n_f16(vsubq_f16(t41, t42), 4)); - float16x8_t m43 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t43, t41), 2)); - float16x8_t m44 = vaddq_f16(vsubq_f16(t44, t42), vmulq_n_f16(vsubq_f16(t41, t43), 2)); - float16x8_t m45 = vaddq_f16(vsubq_f16(vmulq_n_f16(t41, 4), vmulq_n_f16(t43, 5)), t45); - - float16x8_t m50 = vaddq_f16(vsubq_f16(vmulq_n_f16(t50, 4), vmulq_n_f16(t52, 5)), t54); - float16x8_t m51 = vaddq_f16(vaddq_f16(t53, t54), vmulq_n_f16(vaddq_f16(t51, t52), -4)); - float16x8_t m52 = vaddq_f16(vsubq_f16(t54, t53), vmulq_n_f16(vsubq_f16(t51, t52), 4)); - float16x8_t m53 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t53, t51), 2)); - float16x8_t m54 = vaddq_f16(vsubq_f16(t54, t52), vmulq_n_f16(vsubq_f16(t51, t53), 2)); - float16x8_t m55 = vaddq_f16(vsubq_f16(vmulq_n_f16(t51, 4), vmulq_n_f16(t53, 5)), t55); - - vst1_f16(trans_input_data, vget_low_f16(m00)); - vst1_f16(trans_input_data + 64, vget_high_f16(m00)); - vst1_f16(trans_input_data + step, vget_low_f16(m01)); - vst1_f16(trans_input_data + step + 64, vget_high_f16(m01)); - vst1_f16(trans_input_data + 2 * step, vget_low_f16(m02)); - vst1_f16(trans_input_data + 2 * step + 64, vget_high_f16(m02)); - vst1_f16(trans_input_data + 3 * step, vget_low_f16(m03)); - vst1_f16(trans_input_data + 3 * step + 64, vget_high_f16(m03)); - vst1_f16(trans_input_data + 4 * step, vget_low_f16(m04)); - vst1_f16(trans_input_data + 4 * step + 64, vget_high_f16(m04)); - vst1_f16(trans_input_data + 5 * step, vget_low_f16(m05)); - vst1_f16(trans_input_data + 5 * step + 64, vget_high_f16(m05)); - - vst1_f16(trans_input_data + 6 * step, vget_low_f16(m10)); - vst1_f16(trans_input_data + 6 * step + 64, vget_high_f16(m10)); - vst1_f16(trans_input_data + 7 * step, vget_low_f16(m11)); - vst1_f16(trans_input_data + 7 * step + 64, vget_high_f16(m11)); - vst1_f16(trans_input_data + 8 * step, vget_low_f16(m12)); - vst1_f16(trans_input_data + 8 * step + 64, vget_high_f16(m12)); - vst1_f16(trans_input_data + 9 * step, vget_low_f16(m13)); - vst1_f16(trans_input_data + 9 * step + 64, vget_high_f16(m13)); - vst1_f16(trans_input_data + 10 * step, vget_low_f16(m14)); - vst1_f16(trans_input_data + 10 * step + 64, vget_high_f16(m14)); - vst1_f16(trans_input_data + 11 * step, vget_low_f16(m15)); - vst1_f16(trans_input_data + 11 * step + 64, vget_high_f16(m15)); - - vst1_f16(trans_input_data + 12 * step, vget_low_f16(m20)); - vst1_f16(trans_input_data + 12 * step + 64, vget_high_f16(m20)); - vst1_f16(trans_input_data + 13 * step, vget_low_f16(m21)); - vst1_f16(trans_input_data + 13 * step + 64, vget_high_f16(m21)); - vst1_f16(trans_input_data + 14 * step, vget_low_f16(m22)); - vst1_f16(trans_input_data + 14 * step + 64, vget_high_f16(m22)); - vst1_f16(trans_input_data + 15 * step, vget_low_f16(m23)); - vst1_f16(trans_input_data + 15 * step + 64, vget_high_f16(m23)); - vst1_f16(trans_input_data + 16 * step, vget_low_f16(m24)); - vst1_f16(trans_input_data + 16 * step + 64, vget_high_f16(m24)); - vst1_f16(trans_input_data + 17 * step, vget_low_f16(m25)); - vst1_f16(trans_input_data + 17 * step + 64, vget_high_f16(m25)); - - vst1_f16(trans_input_data + 18 * step, vget_low_f16(m30)); - vst1_f16(trans_input_data + 18 * step + 64, vget_high_f16(m30)); - vst1_f16(trans_input_data + 19 * step, vget_low_f16(m31)); - 
vst1_f16(trans_input_data + 19 * step + 64, vget_high_f16(m31)); - vst1_f16(trans_input_data + 20 * step, vget_low_f16(m32)); - vst1_f16(trans_input_data + 20 * step + 64, vget_high_f16(m32)); - vst1_f16(trans_input_data + 21 * step, vget_low_f16(m33)); - vst1_f16(trans_input_data + 21 * step + 64, vget_high_f16(m33)); - vst1_f16(trans_input_data + 22 * step, vget_low_f16(m34)); - vst1_f16(trans_input_data + 22 * step + 64, vget_high_f16(m34)); - vst1_f16(trans_input_data + 23 * step, vget_low_f16(m35)); - vst1_f16(trans_input_data + 23 * step + 64, vget_high_f16(m35)); - - vst1_f16(trans_input_data + 24 * step, vget_low_f16(m40)); - vst1_f16(trans_input_data + 24 * step + 64, vget_high_f16(m40)); - vst1_f16(trans_input_data + 25 * step, vget_low_f16(m41)); - vst1_f16(trans_input_data + 25 * step + 64, vget_high_f16(m41)); - vst1_f16(trans_input_data + 26 * step, vget_low_f16(m42)); - vst1_f16(trans_input_data + 26 * step + 64, vget_high_f16(m42)); - vst1_f16(trans_input_data + 27 * step, vget_low_f16(m43)); - vst1_f16(trans_input_data + 27 * step + 64, vget_high_f16(m43)); - vst1_f16(trans_input_data + 28 * step, vget_low_f16(m44)); - vst1_f16(trans_input_data + 28 * step + 64, vget_high_f16(m44)); - vst1_f16(trans_input_data + 29 * step, vget_low_f16(m45)); - vst1_f16(trans_input_data + 29 * step + 64, vget_high_f16(m45)); - - vst1_f16(trans_input_data + 30 * step, vget_low_f16(m50)); - vst1_f16(trans_input_data + 30 * step + 64, vget_high_f16(m50)); - vst1_f16(trans_input_data + 31 * step, vget_low_f16(m51)); - vst1_f16(trans_input_data + 31 * step + 64, vget_high_f16(m51)); - vst1_f16(trans_input_data + 32 * step, vget_low_f16(m52)); - vst1_f16(trans_input_data + 32 * step + 64, vget_high_f16(m52)); - vst1_f16(trans_input_data + 33 * step, vget_low_f16(m53)); - vst1_f16(trans_input_data + 33 * step + 64, vget_high_f16(m53)); - vst1_f16(trans_input_data + 34 * step, vget_low_f16(m54)); - vst1_f16(trans_input_data + 34 * step + 64, vget_high_f16(m54)); - vst1_f16(trans_input_data + 35 * step, vget_low_f16(m55)); - vst1_f16(trans_input_data + 35 * step + 64, vget_high_f16(m55)); -} - -void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { - // input data format : nhwc - const int output_unit = 4; - int input_channel = conv_param->input_channel_; - int input_width = conv_param->input_w_; - int input_height = conv_param->input_h_; - int pad_w = conv_param->pad_l_; - int pad_h = conv_param->pad_u_; - int ic8 = UP_DIV(input_channel, C8NUM); - if (out_w_block == 0) { - return; - } - for (int cal_id = 0; cal_id < real_cal_num; cal_id++) { - int x_id = start_index + cal_id; - int origin_x = (x_id % out_w_block) * output_unit - pad_w; - int origin_y = (x_id / out_w_block) * output_unit - pad_h; - int real_x_start = origin_x > 0 ? 0 : -origin_x; - int real_x_end = (origin_x + 6) < input_width ? 6 : (input_width - origin_x); - int real_y_start = origin_y > 0 ? 0 : -origin_y; - int real_y_end = (origin_y + 6) < input_height ? 
6 : (input_height - origin_y); - - int src_plane_offset = ic8 * C8NUM * (origin_y * input_width + origin_x); - int dst_plane_offset = cal_id * C4NUM; - for (int ic = 0; ic < ic8; ic++) { - // clear tmp buffer - memset(tmp_data, 0, 6 * 6 * C8NUM * sizeof(float16_t)); - - // get real input block with padding - int src_ic4_offset = src_plane_offset + ic * C8NUM; - for (int interval = real_y_start; interval < real_y_end; interval++) { - int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic8 * C8NUM; - int dst_y_offset = interval * 6 * C8NUM + real_x_start * C8NUM; - for (int j = 0; j < (real_x_end - real_x_start); j++) { - int src_x_offset = src_y_offset + j * ic8 * C8NUM; - int dst_x_offset = dst_y_offset + j * C8NUM; - float16_t *src_addr = (float16_t *)(input_data) + src_x_offset; - float16_t *dst_addr = tmp_data + dst_x_offset; - vst1q_f16(dst_addr, vld1q_f16(src_addr)); - } - } - - // input transform - int dst_ic4_offset = dst_plane_offset + ic * 16 * C8NUM; - size_t dst_step = ic8 * C8NUM * 16; - float16_t *trans_input_ptr = trans_input + dst_ic4_offset; - Conv3x3Fp16InputUnit(tmp_data, trans_input_ptr, dst_step); - } - } -} - -void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC4, int output_channel, - int kernel_plane) { - int dst_step = iC4 * C4NUM * 8; - for (int o = 0; o < output_channel; o++) { - int oc8_block_num = o / C8NUM; - int oc8_block_rem = o % C8NUM; - int src_oc_offset = o * iC4 * C4NUM * kernel_plane; - int dst_oc_offset = oc8_block_num * C8NUM * iC4 * C4NUM * 36 + oc8_block_rem; - for (int i = 0; i < iC4; i++) { - const float16_t *src_ic4_ptr = weight_data + src_oc_offset + i * kernel_plane * C4NUM; - float16_t *dst_ic4_ptr = trans_weight + dst_oc_offset + i * 8 * C4NUM; - float16x4_t g00 = vld1_f16(src_ic4_ptr); - float16x4_t g01 = vld1_f16(src_ic4_ptr + 4); - float16x4_t g02 = vld1_f16(src_ic4_ptr + 2 * 4); - float16x4_t g10 = vld1_f16(src_ic4_ptr + 3 * 4); - float16x4_t g11 = vld1_f16(src_ic4_ptr + 4 * 4); - float16x4_t g12 = vld1_f16(src_ic4_ptr + 5 * 4); - float16x4_t g20 = vld1_f16(src_ic4_ptr + 6 * 4); - float16x4_t g21 = vld1_f16(src_ic4_ptr + 7 * 4); - float16x4_t g22 = vld1_f16(src_ic4_ptr + 8 * 4); - - float16x4_t dst00 = vmul_n_f16(g00, 0.25); - float16x4_t dst01 = vmul_n_f16(g01, 0.25); - float16x4_t dst02 = vmul_n_f16(g02, 0.25); - - float16x4_t dst10 = vmul_n_f16(vadd_f16(g00, vadd_f16(g10, g20)), -0.1666666666667); - float16x4_t dst11 = vmul_n_f16(vadd_f16(g01, vadd_f16(g11, g21)), -0.1666666666667); - float16x4_t dst12 = vmul_n_f16(vadd_f16(g02, vadd_f16(g12, g22)), -0.1666666666667); - - float16x4_t dst20 = vmul_n_f16(vsub_f16(vadd_f16(g00, g20), g10), -0.1666666666667); - float16x4_t dst21 = vmul_n_f16(vsub_f16(vadd_f16(g01, g21), g11), -0.1666666666667); - float16x4_t dst22 = vmul_n_f16(vsub_f16(vadd_f16(g02, g22), g12), -0.1666666666667); - - float16x4_t dst30 = vadd_f16(vmul_n_f16(g10, 0.08333333333333), - vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667))); - float16x4_t dst31 = vadd_f16(vmul_n_f16(g11, 0.08333333333333), - vadd_f16(vmul_n_f16(g01, 0.04166666666667), vmul_n_f16(g21, 0.1666666666667))); - float16x4_t dst32 = vadd_f16(vmul_n_f16(g12, 0.08333333333333), - vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667))); - - float16x4_t dst40 = vsub_f16(vadd_f16(vmul_n_f16(g00, 0.04166666666667), vmul_n_f16(g20, 0.1666666666667)), - vmul_n_f16(g10, 0.08333333333333)); - float16x4_t dst41 = vsub_f16(vadd_f16(vmul_n_f16(g01, 
0.04166666666667), vmul_n_f16(g21, 0.1666666666667)), - vmul_n_f16(g11, 0.08333333333333)); - float16x4_t dst42 = vsub_f16(vadd_f16(vmul_n_f16(g02, 0.04166666666667), vmul_n_f16(g22, 0.1666666666667)), - vmul_n_f16(g12, 0.08333333333333)); - - float16x4_t dst50 = g20; - float16x4_t dst51 = g21; - float16x4_t dst52 = g22; - - float16x4_t m00 = vmul_n_f16(dst00, 0.25); - float16x4_t m01 = vmul_n_f16(vadd_f16(dst00, vadd_f16(dst01, dst02)), -0.1666666666667); - float16x4_t m02 = vmul_n_f16(vsub_f16(vadd_f16(dst00, dst02), dst01), -0.1666666666667); - float16x4_t m03 = vadd_f16(vmul_n_f16(dst01, 0.08333333333333), - vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667))); - float16x4_t m04 = vsub_f16(vadd_f16(vmul_n_f16(dst00, 0.04166666666667), vmul_n_f16(dst02, 0.1666666666667)), - vmul_n_f16(dst01, 0.08333333333333)); - float16x4_t m05 = dst02; - - float16x4_t m10 = vmul_n_f16(dst10, 0.25); - float16x4_t m11 = vmul_n_f16(vadd_f16(dst10, vadd_f16(dst11, dst12)), -0.1666666666667); - float16x4_t m12 = vmul_n_f16(vsub_f16(vadd_f16(dst10, dst12), dst11), -0.1666666666667); - float16x4_t m13 = vadd_f16(vmul_n_f16(dst11, 0.08333333333333), - vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667))); - float16x4_t m14 = vsub_f16(vadd_f16(vmul_n_f16(dst10, 0.04166666666667), vmul_n_f16(dst12, 0.1666666666667)), - vmul_n_f16(dst11, 0.08333333333333)); - float16x4_t m15 = dst12; - - float16x4_t m20 = vmul_n_f16(dst20, 0.25); - float16x4_t m21 = vmul_n_f16(vadd_f16(dst20, vadd_f16(dst21, dst22)), -0.1666666666667); - float16x4_t m22 = vmul_n_f16(vsub_f16(vadd_f16(dst20, dst22), dst21), -0.1666666666667); - float16x4_t m23 = vadd_f16(vmul_n_f16(dst21, 0.08333333333333), - vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667))); - float16x4_t m24 = vsub_f16(vadd_f16(vmul_n_f16(dst20, 0.04166666666667), vmul_n_f16(dst22, 0.1666666666667)), - vmul_n_f16(dst21, 0.08333333333333)); - float16x4_t m25 = dst22; - - float16x4_t m30 = vmul_n_f16(dst30, 0.25); - float16x4_t m31 = vmul_n_f16(vadd_f16(dst30, vadd_f16(dst31, dst32)), -0.1666666666667); - float16x4_t m32 = vmul_n_f16(vsub_f16(vadd_f16(dst30, dst32), dst31), -0.1666666666667); - float16x4_t m33 = vadd_f16(vmul_n_f16(dst31, 0.08333333333333), - vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667))); - float16x4_t m34 = vsub_f16(vadd_f16(vmul_n_f16(dst30, 0.04166666666667), vmul_n_f16(dst32, 0.1666666666667)), - vmul_n_f16(dst31, 0.08333333333333)); - float16x4_t m35 = dst32; - - float16x4_t m40 = vmul_n_f16(dst40, 0.25); - float16x4_t m41 = vmul_n_f16(vadd_f16(dst40, vadd_f16(dst41, dst42)), -0.1666666666667); - float16x4_t m42 = vmul_n_f16(vsub_f16(vadd_f16(dst40, dst42), dst41), -0.1666666666667); - float16x4_t m43 = vadd_f16(vmul_n_f16(dst41, 0.08333333333333), - vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667))); - float16x4_t m44 = vsub_f16(vadd_f16(vmul_n_f16(dst40, 0.04166666666667), vmul_n_f16(dst42, 0.1666666666667)), - vmul_n_f16(dst41, 0.08333333333333)); - float16x4_t m45 = dst42; - - float16x4_t m50 = vmul_n_f16(dst50, 0.25); - float16x4_t m51 = vmul_n_f16(vadd_f16(dst50, vadd_f16(dst51, dst52)), -0.1666666666667); - float16x4_t m52 = vmul_n_f16(vsub_f16(vadd_f16(dst50, dst52), dst51), -0.1666666666667); - float16x4_t m53 = vadd_f16(vmul_n_f16(dst51, 0.08333333333333), - vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667))); - float16x4_t m54 = 
vsub_f16(vadd_f16(vmul_n_f16(dst50, 0.04166666666667), vmul_n_f16(dst52, 0.1666666666667)), - vmul_n_f16(dst51, 0.08333333333333)); - float16x4_t m55 = dst52; - - for (int j = 0; j < 4; j++) { - dst_ic4_ptr[j * 8] = m00[j]; - dst_ic4_ptr[j * 8 + dst_step] = m01[j]; - dst_ic4_ptr[j * 8 + 2 * dst_step] = m02[j]; - dst_ic4_ptr[j * 8 + 3 * dst_step] = m03[j]; - dst_ic4_ptr[j * 8 + 4 * dst_step] = m04[j]; - dst_ic4_ptr[j * 8 + 5 * dst_step] = m05[j]; - dst_ic4_ptr[j * 8 + 6 * dst_step] = m10[j]; - dst_ic4_ptr[j * 8 + 7 * dst_step] = m11[j]; - dst_ic4_ptr[j * 8 + 8 * dst_step] = m12[j]; - dst_ic4_ptr[j * 8 + 9 * dst_step] = m13[j]; - dst_ic4_ptr[j * 8 + 10 * dst_step] = m14[j]; - dst_ic4_ptr[j * 8 + 11 * dst_step] = m15[j]; - dst_ic4_ptr[j * 8 + 12 * dst_step] = m20[j]; - dst_ic4_ptr[j * 8 + 13 * dst_step] = m21[j]; - dst_ic4_ptr[j * 8 + 14 * dst_step] = m22[j]; - dst_ic4_ptr[j * 8 + 15 * dst_step] = m23[j]; - dst_ic4_ptr[j * 8 + 16 * dst_step] = m24[j]; - dst_ic4_ptr[j * 8 + 17 * dst_step] = m25[j]; - dst_ic4_ptr[j * 8 + 18 * dst_step] = m30[j]; - dst_ic4_ptr[j * 8 + 19 * dst_step] = m31[j]; - dst_ic4_ptr[j * 8 + 20 * dst_step] = m32[j]; - dst_ic4_ptr[j * 8 + 21 * dst_step] = m33[j]; - dst_ic4_ptr[j * 8 + 22 * dst_step] = m34[j]; - dst_ic4_ptr[j * 8 + 23 * dst_step] = m35[j]; - dst_ic4_ptr[j * 8 + 24 * dst_step] = m40[j]; - dst_ic4_ptr[j * 8 + 25 * dst_step] = m41[j]; - dst_ic4_ptr[j * 8 + 26 * dst_step] = m42[j]; - dst_ic4_ptr[j * 8 + 27 * dst_step] = m43[j]; - dst_ic4_ptr[j * 8 + 28 * dst_step] = m44[j]; - dst_ic4_ptr[j * 8 + 29 * dst_step] = m45[j]; - dst_ic4_ptr[j * 8 + 30 * dst_step] = m50[j]; - dst_ic4_ptr[j * 8 + 31 * dst_step] = m51[j]; - dst_ic4_ptr[j * 8 + 32 * dst_step] = m52[j]; - dst_ic4_ptr[j * 8 + 33 * dst_step] = m53[j]; - dst_ic4_ptr[j * 8 + 34 * dst_step] = m54[j]; - dst_ic4_ptr[j * 8 + 35 * dst_step] = m55[j]; - } - } - } -} - -void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, - int output_w) { - float16x8_t s00 = vld1q_f16(gemm_out); - float16x8_t s01 = vld1q_f16(gemm_out + 8); - float16x8_t s02 = vld1q_f16(gemm_out + 16); - float16x8_t s03 = vld1q_f16(gemm_out + 24); - float16x8_t s04 = vld1q_f16(gemm_out + 32); - float16x8_t s05 = vld1q_f16(gemm_out + 40); - - float16x8_t s10 = vld1q_f16(gemm_out + 48); - float16x8_t s11 = vld1q_f16(gemm_out + 56); - float16x8_t s12 = vld1q_f16(gemm_out + 64); - float16x8_t s13 = vld1q_f16(gemm_out + 72); - float16x8_t s14 = vld1q_f16(gemm_out + 80); - float16x8_t s15 = vld1q_f16(gemm_out + 88); - - float16x8_t s20 = vld1q_f16(gemm_out + 96); - float16x8_t s21 = vld1q_f16(gemm_out + 104); - float16x8_t s22 = vld1q_f16(gemm_out + 112); - float16x8_t s23 = vld1q_f16(gemm_out + 120); - float16x8_t s24 = vld1q_f16(gemm_out + 128); - float16x8_t s25 = vld1q_f16(gemm_out + 136); - - float16x8_t s30 = vld1q_f16(gemm_out + 144); - float16x8_t s31 = vld1q_f16(gemm_out + 152); - float16x8_t s32 = vld1q_f16(gemm_out + 160); - float16x8_t s33 = vld1q_f16(gemm_out + 168); - float16x8_t s34 = vld1q_f16(gemm_out + 176); - float16x8_t s35 = vld1q_f16(gemm_out + 184); - - float16x8_t s40 = vld1q_f16(gemm_out + 192); - float16x8_t s41 = vld1q_f16(gemm_out + 200); - float16x8_t s42 = vld1q_f16(gemm_out + 208); - float16x8_t s43 = vld1q_f16(gemm_out + 216); - float16x8_t s44 = vld1q_f16(gemm_out + 224); - float16x8_t s45 = vld1q_f16(gemm_out + 232); - - float16x8_t s50 = vld1q_f16(gemm_out + 240); - float16x8_t s51 = vld1q_f16(gemm_out + 248); - float16x8_t s52 = vld1q_f16(gemm_out + 256); - float16x8_t 
s53 = vld1q_f16(gemm_out + 264); - float16x8_t s54 = vld1q_f16(gemm_out + 272); - float16x8_t s55 = vld1q_f16(gemm_out + 280); - - float16x8_t t00 = vaddq_f16(vaddq_f16(vaddq_f16(s00, s10), vaddq_f16(s20, s30)), s40); - float16x8_t t01 = vaddq_f16(vaddq_f16(vaddq_f16(s01, s11), vaddq_f16(s21, s31)), s41); - float16x8_t t02 = vaddq_f16(vaddq_f16(vaddq_f16(s02, s12), vaddq_f16(s22, s32)), s42); - float16x8_t t03 = vaddq_f16(vaddq_f16(vaddq_f16(s03, s13), vaddq_f16(s23, s33)), s43); - float16x8_t t04 = vaddq_f16(vaddq_f16(vaddq_f16(s04, s14), vaddq_f16(s24, s34)), s44); - float16x8_t t05 = vaddq_f16(vaddq_f16(vaddq_f16(s05, s15), vaddq_f16(s25, s35)), s45); - - float16x8_t t10 = vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 2)); - float16x8_t t11 = vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 2)); - float16x8_t t12 = vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 2)); - float16x8_t t13 = vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 2)); - float16x8_t t14 = vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 2)); - float16x8_t t15 = vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 2)); - - float16x8_t t20 = vaddq_f16(vaddq_f16(s10, s20), vmulq_n_f16(vaddq_f16(s30, s40), 4)); - float16x8_t t21 = vaddq_f16(vaddq_f16(s11, s21), vmulq_n_f16(vaddq_f16(s31, s41), 4)); - float16x8_t t22 = vaddq_f16(vaddq_f16(s12, s22), vmulq_n_f16(vaddq_f16(s32, s42), 4)); - float16x8_t t23 = vaddq_f16(vaddq_f16(s13, s23), vmulq_n_f16(vaddq_f16(s33, s43), 4)); - float16x8_t t24 = vaddq_f16(vaddq_f16(s14, s24), vmulq_n_f16(vaddq_f16(s34, s44), 4)); - float16x8_t t25 = vaddq_f16(vaddq_f16(s15, s25), vmulq_n_f16(vaddq_f16(s35, s45), 4)); - - float16x8_t t30 = vaddq_f16(vaddq_f16(vsubq_f16(s10, s20), vmulq_n_f16(vsubq_f16(s30, s40), 8)), s50); - float16x8_t t31 = vaddq_f16(vaddq_f16(vsubq_f16(s11, s21), vmulq_n_f16(vsubq_f16(s31, s41), 8)), s51); - float16x8_t t32 = vaddq_f16(vaddq_f16(vsubq_f16(s12, s22), vmulq_n_f16(vsubq_f16(s32, s42), 8)), s52); - float16x8_t t33 = vaddq_f16(vaddq_f16(vsubq_f16(s13, s23), vmulq_n_f16(vsubq_f16(s33, s43), 8)), s53); - float16x8_t t34 = vaddq_f16(vaddq_f16(vsubq_f16(s14, s24), vmulq_n_f16(vsubq_f16(s34, s44), 8)), s54); - float16x8_t t35 = vaddq_f16(vaddq_f16(vsubq_f16(s15, s25), vmulq_n_f16(vsubq_f16(s35, s45), 8)), s55); - - float16x8_t bias_ptr = vld1q_f16(bias_data); - float16x8_t d00 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t00, t01), vaddq_f16(t02, t03)), t04), bias_ptr); - float16x8_t d01 = vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 2)), bias_ptr); - float16x8_t d02 = vaddq_f16(vaddq_f16(vaddq_f16(t01, t02), vmulq_n_f16(vaddq_f16(t03, t04), 4)), bias_ptr); - float16x8_t d03 = - vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t01, t02), vmulq_n_f16(vsubq_f16(t03, t04), 8)), t05), bias_ptr); - - float16x8_t d10 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t10, t11), vaddq_f16(t12, t13)), t14), bias_ptr); - float16x8_t d11 = vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 2)), bias_ptr); - float16x8_t d12 = vaddq_f16(vaddq_f16(vaddq_f16(t11, t12), vmulq_n_f16(vaddq_f16(t13, t14), 4)), bias_ptr); - float16x8_t d13 = - vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t11, t12), vmulq_n_f16(vsubq_f16(t13, t14), 8)), t15), bias_ptr); - - float16x8_t d20 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t20, t21), vaddq_f16(t22, t23)), t24), bias_ptr); - float16x8_t d21 = vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 2)), 
bias_ptr); - float16x8_t d22 = vaddq_f16(vaddq_f16(vaddq_f16(t21, t22), vmulq_n_f16(vaddq_f16(t23, t24), 4)), bias_ptr); - float16x8_t d23 = - vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t21, t22), vmulq_n_f16(vsubq_f16(t23, t24), 8)), t25), bias_ptr); - - float16x8_t d30 = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t30, t31), vaddq_f16(t32, t33)), t34), bias_ptr); - float16x8_t d31 = vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 2)), bias_ptr); - float16x8_t d32 = vaddq_f16(vaddq_f16(vaddq_f16(t31, t32), vmulq_n_f16(vaddq_f16(t33, t34), 4)), bias_ptr); - float16x8_t d33 = - vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t31, t32), vmulq_n_f16(vsubq_f16(t33, t34), 8)), t35), bias_ptr); - - vst1q_f16(output_data, d00); - vst1q_f16(output_data + 8, d01); - vst1q_f16(output_data + 16, d02); - vst1q_f16(output_data + 24, d03); - - vst1q_f16(output_data + output_w * 8, d10); - vst1q_f16(output_data + output_w * 8 + 8, d11); - vst1q_f16(output_data + output_w * 8 + 16, d12); - vst1q_f16(output_data + output_w * 8 + 24, d13); - - vst1q_f16(output_data + 2 * output_w * 8, d20); - vst1q_f16(output_data + 2 * output_w * 8 + 8, d21); - vst1q_f16(output_data + 2 * output_w * 8 + 16, d22); - vst1q_f16(output_data + 2 * output_w * 8 + 24, d23); - - vst1q_f16(output_data + 3 * output_w * 8, d30); - vst1q_f16(output_data + 3 * output_w * 8 + 8, d31); - vst1q_f16(output_data + 3 * output_w * 8 + 16, d32); - vst1q_f16(output_data + 3 * output_w * 8 + 24, d33); -} - -void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { - int output_channel = conv_param->output_channel_; - int output_h = conv_param->output_h_; - int out_h_block = UP_DIV(output_h, C4NUM); - int oc8 = UP_DIV(output_channel, C8NUM); - if (out_w_block == 0) { - return; - } - for (int i = 0; i < real_cal_num; i++) { - int out_w_index = (start_index + i) % out_w_block; - int out_h_index = (start_index + i) / out_w_block; - int src_tile_offset = i * oc8 * C8NUM * 36; - int dst_tile_offset = C8NUM * (out_w_index * C4NUM + out_h_index * C4NUM * out_w_block * C4NUM); - - for (int j = 0; j < oc8; j++) { - int src_oc8_offset = src_tile_offset + j * 36 * C8NUM; - int dst_oc8_offset = dst_tile_offset + j * C8NUM * out_h_block * out_w_block * C4NUM * C4NUM; - const float16_t *src_ptr = gemm_out + src_oc8_offset; - const float16_t *bias_ptr = bias_data + j * C8NUM; - float16_t *dst_ptr = out_data + dst_oc8_offset; - - // output transform - Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, out_w_block * C4NUM); - } - } -} - // fp16 common winograd void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, int out_tile_index, int out_w_block_num, ConvParameter *conv_param, InputTransFp16Func func) { +#ifdef ENABLE_ARM64 const int tile_num = 16; +#else + const int tile_num = 12; +#endif int input_unit = conv_param->input_unit_; int output_unit = conv_param->output_unit_; int in_channel = conv_param->input_channel_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h index 863248950e8..e217d64b482 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_transform_fp16.h @@ -23,25 +23,11 @@ #include "nnacl/fp16/cast_fp16.h" 
#include "nnacl/fp16/conv_fp16.h" #include "nnacl/fp16/matrix_fp16.h" +#include "nnacl/fp16/pack_fp16.h" #ifdef __cplusplus extern "C" { #endif - -// for fp16 convolution 3x3 filter/input/output transform -void Conv3x3Fp16InputUnit(float16_t *tmp_data, float16_t *trans_input_data, size_t step); - -void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); - -void Conv3x3Fp16FilterTransform(const float16_t *weight_data, float16_t *trans_weight, int iC8, int output_channel, - int kernel_plane); - -void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data, float16_t *output_data, int output_w); - -void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, - int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param); - // fp16 common winograd void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, int out_tile_index, int out_w_block_num, ConvParameter *conv_param, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h index 72a4b709a09..f177e005bbc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp16/winograd_utils_fp16.h @@ -17,9 +17,9 @@ #ifndef MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ #define MINDSPORE_NNACL_FP16_WINOGRAD_UTILS_H_ -#include #include "nnacl/conv_parameter.h" #include "nnacl/op_base.h" +#include "nnacl/intrinsics/ms_simd_instructions_fp16.h" #define MAX_LEN 256 diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions.h index 4b46d798d70..070c498f4a3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions.h @@ -17,9 +17,11 @@ #ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ #define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ #include + #ifdef ENABLE_ARM #include #endif + #if defined(ENABLE_SSE) || defined(ENABLE_AVX) #include #endif @@ -46,7 +48,7 @@ #ifdef ENABLE_ARM64 #define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2) #else -inline static float32x4_t vrecp(float32x4_t v) { +static inline float32x4_t vrecp(float32x4_t v) { float32x4_t r = vrecpeq_f32(v); r = vmulq_f32(vrecpsq_f32(v, r), r); r = vmulq_f32(vrecpsq_f32(v, r), r); @@ -205,25 +207,4 @@ static inline MS_FLOAT32X4 MS_ERFX4_F32(MS_FLOAT32X4 src) { return dst; } -#ifdef ENABLE_ARM64 -static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { - float32x4_t src_low = vcvt_f32_f16(vget_low_f16(src)); - float32x4_t src_high = vcvt_f32_f16(vget_high_f16(src)); - return vcombine_f16(vcvt_f16_f32(MS_TANHX4_F32(src_low)), vcvt_f16_f32(MS_TANHX4_F32(src_high))); -} - -static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { - float16x8_t dst; - dst[0] = erff(src[0]); - dst[1] = erff(src[1]); - dst[2] = erff(src[2]); - dst[3] = erff(src[3]); - dst[4] = erff(src[4]); - dst[5] = erff(src[5]); - dst[6] = erff(src[6]); - dst[7] = erff(src[7]); - return dst; -} -#endif - #endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_H_ diff --git 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h new file mode 100644 index 00000000000..cbe3a3ab126 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/ms_simd_instructions_fp16.h @@ -0,0 +1,99 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ +#define MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_ +#include <arm_neon.h> +#include "nnacl/intrinsics/ms_simd_instructions.h" + +#if defined(ENABLE_ARM82_A32) +static inline float16x8_t divq_f16(float16x8_t in1, float16x8_t in2) { + float16x8_t dst; + asm volatile( + "vrecpe.f16 q14, %3\n" + "vrecps.f16 q15, %3, q14\n" + "vmul.f16 q14, q15, q14\n" + "vrecps.f16 q15, %3, q14\n" + "vmul.f16 q14, q15, q14\n" + "vmul.f16 %0, %2, q14\n" + : "=w"(dst) + : "0"(dst), "w"(in1), "w"(in2) + : "q14", "q15"); + return dst; +} + +static inline float16x4_t div_f16(float16x4_t in1, float16x4_t in2) { + float16x4_t dst; + asm volatile( + "vrecpe.f16 d14, %3\n" + "vrecps.f16 d16, %3, d14\n" + "vmul.f16 d14, d16, d14\n" + "vrecps.f16 d16, %3, d14\n" + "vmul.f16 d14, d16, d14\n" + "vmul.f16 %0, %2, d14\n" + : "=w"(dst) + : "0"(dst), "w"(in1), "w"(in2) + : "d14", "d16"); + return dst; +} + +static inline float vaddvq_f32(float32x4_t in) { // vaddvq_f32 is not supported on arm82 aarch32 + return in[0] + in[1] + in[2] + in[3]; +} + +static inline float32x4_t cvt_f32_f16(float16x4_t in) { + float32x4_t dst; + asm volatile("vcvt.f32.f16 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +static inline float16x4_t cvt_f16_f32(float32x4_t in) { + float16x4_t dst; + asm volatile("vcvt.f16.f32 %0, %2\n" : "=w"(dst) : "0"(dst), "w"(in) :); + return dst; +} + +#define MS_CVT_F32_F16(src) cvt_f32_f16(src) +#define MS_CVT_F16_F32(src) cvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) div_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) divq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_f16(src1, src2, vdupq_n_f16(src3)) +#else +#define MS_CVT_F32_F16(src) vcvt_f32_f16(src) +#define MS_CVT_F16_F32(src) vcvt_f16_f32(src) +#define MS_DIV_F16(src1, src2) vdiv_f16(src1, src2) +#define MS_DIVQ_F16(src1, src2) vdivq_f16(src1, src2) +#define MS_FMAQ_N_F16(src1, src2, src3) vfmaq_n_f16(src1, src2, src3) +#endif + +static inline float16x8_t MS_TANHX8_F16(float16x8_t src) { + float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src)); + float32x4_t src_high = MS_CVT_F32_F16(vget_high_f16(src)); + return vcombine_f16(MS_CVT_F16_F32(MS_TANHX4_F32(src_low)), MS_CVT_F16_F32(MS_TANHX4_F32(src_high))); +} + +static inline float16x8_t MS_ERFX8_F16(float16x8_t src) { + float16x8_t dst; + dst[0] = erff(src[0]); + dst[1] = erff(src[1]); + dst[2] = erff(src[2]); + dst[3] = erff(src[3]); + dst[4] = erff(src[4]); + dst[5] = erff(src[5]); + dst[6] = erff(src[6]); + dst[7] = erff(src[7]); + return dst; +} +#endif // MINDSPORE_NNACL_INTRINSICS_MS_SIMD_INSTRUCTIONS_FP16_H_
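The A32 inline assembly above is the fp16 counterpart of the fp32 vrecp() helper earlier in ms_simd_instructions.h: vrecpe.f16 yields a coarse reciprocal estimate and each vrecps.f16/vmul.f16 pair performs one Newton-Raphson refinement before the final multiply by the dividend. A plain-C model of that iteration (illustrative only; recip_nr2 is not a patch symbol):

/* Newton-Raphson reciprocal refinement: x' = x * (2 - d * x).
 * vrecps computes the (2 - d * x) factor; the two rounds below mirror
 * the two vrecps/vmul pairs in divq_f16/div_f16, after which
 * a / d is approximated by a * x2. */
static inline float recip_nr2(float d, float x0) {
  float x1 = x0 * (2.0f - d * x0);
  float x2 = x1 * (2.0f - d * x1);
  return x2;
}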
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/optimize/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/optimize/CMakeLists.txt index 69524448ce4..6e1f0e6c591 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/optimize/CMakeLists.txt +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/optimize/CMakeLists.txt @@ -4,11 +4,15 @@ set(NNACL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) include_directories(NNACL_DIR) ########################### optimized files ########################### -file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) file(GLOB FP16_C_SRC ${NNACL_DIR}/fp16/*.c) -file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) +if(PLATFORM_ARM32) + file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/arm82_aarch32_fp16/*.S) +else() + file(GLOB FP16_NEON_SRC ${NNACL_DIR}/assembly/fp16/*.S) + file(GLOB SDOT_SRC ${NNACL_DIR}/assembly/opt/*.S) + set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) +endif() -set_property(SOURCE ${SDOT_SRC} PROPERTY LANGUAGE C) set_property(SOURCE ${FP16_C_SRC} PROPERTY LANGUAGE C) set_property(SOURCE ${FP16_NEON_SRC} PROPERTY LANGUAGE C) @@ -17,7 +21,6 @@ if(APPLE) set_source_files_properties(${FP16_NEON_SRC} PROPERTIES COMPILE_FLAGS "-x assembler-with-cpp") endif() ########################### share library build ######################## -list(APPEND SDOT_FILES ${SDOT_SRC}) list(APPEND FP16_FILES ${FP16_C_SRC}) list(APPEND FP16_FILES ${FP16_NEON_SRC}) @@ -27,13 +30,20 @@ if(SUPPORT_TRAIN) endif() string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") -add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) -add_dependencies(nnacl_optimize_mid fbs_src) +if(NOT PLATFORM_ARM32) + list(APPEND SDOT_FILES ${SDOT_SRC}) + add_library(nnacl_optimize_mid OBJECT ${SDOT_FILES}) + add_dependencies(nnacl_optimize_mid fbs_src) +endif() if(ENABLE_FP16) add_library(nnacl_fp16_mid OBJECT ${FP16_FILES}) + if(PLATFORM_ARM32) + target_compile_options(nnacl_fp16_mid PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp) + else() + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod+fp16") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=armv8.2-a+dotprod+fp16") + endif() add_dependencies(nnacl_fp16_mid fbs_src) endif() \ No newline at end of file diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 37f1affff21..8932f6b63d3 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -5,6 +5,12 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_L message(FATAL_ERROR "GCC version ${CMAKE_CXX_COMPILER_VERSION} must not be less than 7.3.0") endif() +if(PLATFORM_ARM32 AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0) + set(ENABLE_FP16 "off") + message(WARNING "If you want to build fp16 on arm82_a32, \ + your Clang version [${CMAKE_CXX_COMPILER_VERSION}] must not be less than 9.0; please use Android NDK r21e!") +endif() + option(MS_VERSION_MAJOR "major version" 0) option(MS_VERSION_MINOR "minor version" 7) option(MS_VERSION_REVISION "revision version" 0) @@ -15,6 +21,7 @@ option(PLATFORM_ARM64 "if build device for arm64" off) option(PLATFORM_ARM32 "if build device for arm32" off) option(ENABLE_CONVERTER "if build converter" on) option(ENABLE_FP16 "if build fp16
ops" off) +option(ENABLE_ARM82_A32 "if build fp16 on platform_arm32" off) option(ENABLE_TOOLS "if build tools" on) option(BUILD_TESTCASES "if build testcase" on) option(SUPPORT_GPU "if support gpu" off) @@ -177,6 +184,9 @@ if(ENABLE_NEON) endif() if(ENABLE_FP16) add_compile_definitions(ENABLE_FP16) + if(PLATFORM_ARM32) + add_compile_definitions(ENABLE_ARM82_A32) + endif() endif() if(SUPPORT_GPU STREQUAL opencl) add_definitions(-DGPU_OPENCL) diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index d363f8f2d50..60ac8cb7fb4 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -3,6 +3,9 @@ if(ENABLE_V0) add_definitions(-DENABLE_V0) endif() include_directories(${CCSRC_DIR}/backend/kernel_compiler/cpu) +set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..) +include_directories(${LITE_DIR}/nnacl/) +include_directories(${LITE_DIR}/nnacl/optimize) if(PLATFORM_ARM32 OR PLATFORM_ARM64) #for performance @@ -209,9 +212,11 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") endif() ########################## build optimize and float16 library ################################# -if(PLATFORM_ARM64) - target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid) - target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid) +if(PLATFORM_ARM) + if(PLATFORM_ARM64) + target_link_libraries(mindspore-lite cpu_opt_kernel_mid nnacl_optimize_mid) + target_link_libraries(mindspore-lite_static cpu_opt_kernel_mid nnacl_optimize_mid) + endif() if(ENABLE_FP16) target_link_libraries(mindspore-lite cpu_fp16_kernel_mid nnacl_fp16_mid) target_link_libraries(mindspore-lite_static cpu_fp16_kernel_mid nnacl_fp16_mid) @@ -247,8 +252,10 @@ if(DEFINED ARCHS) target_link_libraries(mindspore_lite mindrt_mid) endif() - if(PLATFORM_ARM64) - target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid) + if(PLATFORM_ARM) + if(PLATFORM_ARM64) + target_link_libraries(mindspore_lite cpu_opt_kernel_mid nnacl_optimize_mid) + endif() if(ENABLE_FP16) target_link_libraries(mindspore_lite cpu_fp16_kernel_mid nnacl_fp16_mid) endif() diff --git a/mindspore/lite/src/common/utils.cc b/mindspore/lite/src/common/utils.cc index ab66fcc0764..b7c89aaaf5b 100644 --- a/mindspore/lite/src/common/utils.cc +++ b/mindspore/lite/src/common/utils.cc @@ -155,7 +155,11 @@ bool IsSupportSDot() { bool IsSupportFloat16() { bool status = false; -#ifdef ENABLE_ARM64 +#ifdef ENABLE_ARM32 + status = true; +#endif + +#if defined(ENABLE_ARM64) #if defined(__ANDROID__) int hwcap_type = 16; uint32_t hwcap = getHwCap(hwcap_type); diff --git a/mindspore/lite/src/common/utils.h b/mindspore/lite/src/common/utils.h index fca624997a7..0b4fb9874cc 100644 --- a/mindspore/lite/src/common/utils.h +++ b/mindspore/lite/src/common/utils.h @@ -44,7 +44,7 @@ uint64_t GetTimeUs(); bool IsSupportSDot(); bool IsSupportFloat16(); -#if defined(__arm__) || defined(__aarch64__) +#if defined(__arm__) uint32_t getHwCap(int hwcap_type); #endif diff --git a/mindspore/lite/src/kernel_registry.cc b/mindspore/lite/src/kernel_registry.cc index c99718b9cf0..8945b366fb9 100644 --- a/mindspore/lite/src/kernel_registry.cc +++ b/mindspore/lite/src/kernel_registry.cc @@ -19,7 +19,7 @@ #include "src/common/version_manager.h" #include "nnacl/pooling_parameter.h" #include "src/ios_reg_kernels.h" -#ifdef ENABLE_ARM64 +#if defined(ENABLE_FP16) && defined(ENABLE_ARM) #if defined(__ANDROID__) #include #endif @@ -55,6 +55,8 @@ int KernelRegistry::Init() { } else { MS_LOG(INFO) << "The current device NOT supports 
Sdot."; } +#endif +#ifdef ENABLE_FP16 if (mindspore::lite::IsSupportFloat16()) { MS_LOG(INFO) << "The current device supports float16."; } else { diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt index 6e6c4f40439..dabfdd768a4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -17,7 +17,7 @@ endif() add_library(cpu_kernel_mid OBJECT ${KERNEL_SRC}) add_dependencies(cpu_kernel_mid fbs_src) -if(PLATFORM_ARM64) +if(PLATFORM_ARM) if(ENABLE_FP16) file(GLOB FP16_KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fp16/*.cc) if(SUPPORT_TRAIN) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc b/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc index bcb2d92d10c..48c1ec53ccb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/constant_of_shape.cc @@ -52,7 +52,7 @@ int ConstantOfShapeCPUKernel::DoExecute(int task_id) { ConstantOfShapeInt32(reinterpret_cast(output_ptr_), start, start + current_stride, param_->value_.int32_value_); break; -#ifdef ENABLE_NEON +#ifdef ENABLE_FP16 case kNumberTypeFloat16: ConstantOfShapeFp16(reinterpret_cast(output_ptr_), start, start + current_stride, param_->value_.f32_value_); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc index b9152a48049..2e56e5cf644 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.cc @@ -31,8 +31,8 @@ int Convolution1x1FP16CPUKernel::InitMatmulParam() { matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; matmul_param_->col_ = conv_param_->output_channel_; matmul_param_->deep_ = conv_param_->input_channel_; - matmul_param_->row_16_ = UP_ROUND(matmul_param_->row_, C16NUM); - matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); + matmul_param_->row_align_ = UP_ROUND(matmul_param_->row_, row_tile_); + matmul_param_->col_align_ = UP_ROUND(matmul_param_->col_, col_tile_); matmul_param_->act_type_ = conv_param_->act_type_; return RET_OK; } @@ -54,14 +54,14 @@ int Convolution1x1FP16CPUKernel::InitConv1x1Param() { pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 || conv_param_->stride_w_ != 1); - if ((matmul_param_->row_ > (C16NUM * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { + if ((matmul_param_->row_ > (row_tile_ * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) { multi_thread_by_hw_ = true; - thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C16NUM)); - thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, C16NUM), thread_count_) * C16NUM; + thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_tile_)); + thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_tile_), thread_count_) * row_tile_; } else { multi_thread_by_hw_ = false; - thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); - thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM; + thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile_)); + thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile_), thread_count_) * col_tile_; } if (pre_trans_input_) { 
@@ -81,8 +81,8 @@ int Convolution1x1FP16CPUKernel::InitWeightBias() { auto output_channel = weight_tensor->Batch(); if (in_tensors_.size() == 3) { - size_t size = UP_ROUND(output_channel, C8NUM) * sizeof(float16_t); - size_t weight_size = output_channel * sizeof(float16_t); + size_t size = UP_ROUND(output_channel, col_tile_) * sizeof(float16_t); + size_t bias_size = output_channel * sizeof(float16_t); bias_data_ = malloc(size); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 Malloc bias_ptr_ error!"; @@ -94,11 +94,11 @@ Float32ToFloat16(reinterpret_cast<float *>(origin_bias_), reinterpret_cast<float16_t *>(bias_data_), output_channel); } - memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size); + memset(reinterpret_cast<char *>(bias_data_) + bias_size, 0, size - bias_size); } - size_t size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float16_t); - size_t down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float16_t); + size_t size = input_channel * UP_ROUND(output_channel, col_tile_) * sizeof(float16_t); + size_t down_size = input_channel * DOWN_DIV(output_channel, col_tile_) * col_tile_ * sizeof(float16_t); weight_ptr_ = reinterpret_cast<float16_t *>(malloc(size)); if (weight_ptr_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!"; @@ -111,6 +111,12 @@ } int Convolution1x1FP16CPUKernel::Init() { + col_tile_ = C8NUM; +#ifdef ENABLE_ARM64 + row_tile_ = C16NUM; +#else + row_tile_ = C12NUM; +#endif matmul_param_ = new (std::nothrow) MatMulParameter(); if (matmul_param_ == nullptr) { MS_LOG(ERROR) << "Init matmul_param_ failed."; @@ -177,8 +183,11 @@ int Convolution1x1FP16CPUKernel::RunHw(int task_id) { float16_t *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_; float16_t *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_; +#ifdef ENABLE_ARM64 RowMajor2Col16MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); - +#else + RowMajor2Col12MajorFp16Opt(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_); +#endif float16_t *thread_output_ptr = output_ptr_ + task_id * thread_stride_ * matmul_param_->col_; MatMulFp16(thread_pack_input, weight_ptr_, thread_output_ptr, reinterpret_cast<float16_t *>(bias_data_), matmul_param_->act_type_, matmul_param_->deep_, cur_hw_, matmul_param_->col_, matmul_param_->col_, @@ -211,7 +220,7 @@ int Convolution1x1FP16CPUKernel::Run() { ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); pack_input_ = reinterpret_cast<float16_t *>( - ctx_->allocator->Malloc(matmul_param_->row_16_ * matmul_param_->deep_ * sizeof(float16_t))); + ctx_->allocator->Malloc(matmul_param_->row_align_ * matmul_param_->deep_ * sizeof(float16_t))); if (pack_input_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!"; return RET_MEMORY_FAILED; @@ -231,7 +240,11 @@ if (multi_thread_by_hw_) { ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunHw, this, thread_count_); } else { +#ifdef ENABLE_ARM64 RowMajor2Col16MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); +#else + RowMajor2Col12MajorFp16Opt(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_); +#endif ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Fp16RunOc, this, thread_count_); } if (ret != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h
b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h index f3b2953c096..680ba6ef40c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h @@ -62,6 +62,8 @@ class Convolution1x1FP16CPUKernel : public ConvolutionBaseFP16CPUKernel { float16_t *pack_input_ = nullptr; float16_t *output_ptr_ = nullptr; MatMulParameter *matmul_param_ = nullptr; + int col_tile_; + int row_tile_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 67f4cb3d669..9c3624acd00 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -34,7 +34,7 @@ int ConvolutionFP16CPUKernel::InitWeightBias() { int out_channel = filter_tensor->Batch(); conv_param_->input_channel_ = in_channel; conv_param_->output_channel_ = out_channel; - int oc8 = UP_ROUND(out_channel, C8NUM); + int oc8 = UP_ROUND(out_channel, col_tile_); int kernel_plane = filter_tensor->Height() * filter_tensor->Width(); int pack_weight_size = oc8 * in_channel * kernel_plane; @@ -68,9 +68,8 @@ } int ConvolutionFP16CPUKernel::InitTmpBuffer() { - const int cal_num = 16; int unit_size = - conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * cal_num * thread_count_; + conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * row_tile_ * thread_count_; packed_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(unit_size * sizeof(float16_t))); if (packed_input_ == nullptr) { @@ -87,6 +86,12 @@ } int ConvolutionFP16CPUKernel::Init() { +#ifdef ENABLE_ARM64 + row_tile_ = C16NUM; +#else + row_tile_ = C12NUM; +#endif + col_tile_ = C8NUM; auto ret = InitWeightBias(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init weight bias failed."; @@ -98,7 +103,7 @@ void ConvolutionFP16CPUKernel::AdjustNumberOfThread() { auto out_tensor = out_tensors_.front(); int out_plane = out_tensor->Height() * out_tensor->Width(); - thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM)); + thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, row_tile_)); conv_param_->thread_num_ = thread_count_; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h index 7424ed80e1c..f9c08591394 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h @@ -62,6 +62,8 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { float16_t *packed_input_ = nullptr; float16_t *packed_weight_ = nullptr; float16_t *col_major_input_ = nullptr; + int col_tile_; + int row_tile_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc index e655e35234a..6db4fce2db0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc @@ -38,13 +38,10 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { int out_channel = filter_tensor->Batch(); conv_param_->input_channel_ = in_channel; conv_param_->output_channel_ = out_channel; - - const int oc_block = C8NUM; - int oc_block_num = UP_DIV(out_channel, C8NUM); - + int oc_block_num = UP_DIV(out_channel, col_tile_); // init weight // set data - auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t); + auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * col_tile_ * sizeof(float16_t); trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size)); if (trans_weight_ == nullptr) { MS_LOG(ERROR) << "malloc trans_weight_ failed.";
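As a sizing cross-check for the transformed-weight buffer just allocated: with a 3x3 kernel and output_unit 4, input_unit_ = 4 + 3 - 1 = 6, and the buffer holds input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, col_tile_) half floats. A worked example with assumed shapes (not taken from the patch):

static size_t winograd_weight_bytes_example(void) {
  /* in_channel = 16, out_channel = 20, col_tile_ = C8NUM = 8, so
   * oc_block_num = UP_DIV(20, 8) = 3 and channels pad up to 24.   */
  size_t elems = 6 * 6 * 16 * 3 * 8;       /* 13824 fp16 values    */
  return elems * sizeof(uint16_t);         /* 27648 bytes (~27 KB) */
}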
@@ -73,7 +70,7 @@ MS_LOG(ERROR) << "get execute filter failed."; return ret; } - ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block); + ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, col_tile_); if (ret != RET_OK) { MS_LOG(ERROR) << "winograd filter transform failed."; return ret; @@ -85,12 +82,12 @@ } // init bias - bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t)); + bias_data_ = malloc(oc_block_num * col_tile_ * sizeof(float16_t)); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "malloc bias_data_ failed."; return RET_ERROR; } - memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float16_t)); + memset(bias_data_, 0, oc_block_num * col_tile_ * sizeof(float16_t)); if (in_tensors_.size() == kInputSize2) { if (origin_bias_data_type_ == kNumberTypeFloat16) { @@ -105,11 +102,9 @@ } int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { - const int cal_num = 16; int channel_out = conv_param_->output_channel_; - size_t tile_buffer_size = - thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); + thread_count_ * row_tile_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size)); if (trans_input_ == nullptr) { MS_LOG(ERROR) << "malloc trans_input_ failed."; } gemm_out_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc( - thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t))); + thread_count_ * row_tile_ * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t))); if (gemm_out_ == nullptr) { MS_LOG(ERROR) << "malloc gemm_out_ failed."; return RET_ERROR; } col_buffer_ = reinterpret_cast<float16_t *>( - ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t))); + ctx_->allocator->Malloc(thread_count_ * row_tile_ * conv_param_->input_channel_ * sizeof(float16_t))); if (col_buffer_ == nullptr) { MS_LOG(ERROR) << "malloc col_buffer_ failed."; return RET_ERROR; @@ -159,6 +154,12 @@ int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() { } int ConvolutionWinogradFP16CPUKernel::Init() { + col_tile_ = C8NUM; +#ifdef ENABLE_ARM64 + row_tile_ = C16NUM; +#else + row_tile_ = C12NUM; +#endif kernel_unit_ = conv_param_->kernel_h_; input_unit_ = output_unit_ + kernel_unit_ - 1; conv_param_->input_unit_ = input_unit_; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h index 627fe5092e3..e9793875df2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h @@ -86,6 +86,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { TmpBufferAddressFp16 tmp_buffer_address_list_[4]; InputTransFp16Func in_func_; OutputTransFp16Func out_func_; + int col_tile_; + int row_tile_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fp16_op_handler.h b/mindspore/lite/src/runtime/kernel/arm/fp16/fp16_op_handler.h index 86db28de400..dc284d0630c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/fp16_op_handler.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fp16_op_handler.h @@ -13,26 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifdef ENABLE_ARM64 +#ifdef ENABLE_ARM #include <arm_neon.h> #endif +#include "nnacl/fp16/cast_fp16.h" #ifdef __cplusplus extern "C" { #endif -#ifdef ENABLE_ARM64 -extern void Float32ToFloat16(const float *input, float16_t *output, int number); -extern void Float16ToFloat32(const float16_t *input, float *output, int number); - -inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) { +static inline void Float32ToFloat16_fp16_handler(const void *input, void *output, int number) { Float32ToFloat16(reinterpret_cast<const float *>(input), reinterpret_cast<float16_t *>(output), number); } - -inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) { +static inline void Float16ToFloat32_fp16_handler(const void *input, void *output, int number) { Float16ToFloat32(reinterpret_cast<const float16_t *>(input), reinterpret_cast<float *>(output), number); } -#endif #ifdef __cplusplus } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc index 06dc0fb5d87..bcc91cb2b9d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc @@ -53,8 +53,8 @@ int PoolingFp16CPUKernel::ReSize() { } int PoolingFp16CPUKernel::RunImpl(int task_id) { - float16_t minf = -FLT_MAX; - float16_t maxf = FLT_MAX; + float16_t minf = -FLT16_MAX; + float16_t maxf = FLT16_MAX; if (pooling_param_->act_type_ == ActType_Relu) { minf = 0.f; } else if (pooling_param_->act_type_ == ActType_Relu6) { diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index 2a8e735a074..95793ab208b 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -45,7 +45,7 @@ #include "src/runtime/agent/npu/optimizer/npu_fusion_pass.h" #include "src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h" #endif -#if defined(ENABLE_ARM64) && defined(ENABLE_FP16) +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" #endif @@ -230,7 +230,7 @@ int CopyConstTensor(Tensor *tensor, std::map<Tensor *, Tensor *> *restored_origi auto origin_data = tensor->data_c(); MS_ASSERT(origin_data != nullptr); if (tensor->data_type() == kNumberTypeFloat32 && dst_data_type == kNumberTypeFloat16) { -#if defined(ENABLE_ARM64) && defined(ENABLE_FP16) +#if defined(ENABLE_ARM) && defined(ENABLE_FP16) auto restore_tensor = Tensor::CopyTensor(*tensor, false); restore_tensor->set_data(origin_data); restore_tensor->set_own_data(tensor->own_data()); diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc index eda97d010bb..44dcd7c30ba 100644 ---
a/mindspore/lite/src/sub_graph_kernel.cc +++ b/mindspore/lite/src/sub_graph_kernel.cc @@ -17,7 +17,7 @@ #include "src/sub_graph_kernel.h" #include "src/tensor.h" #include "src/tensorlist.h" -#if defined(ENABLE_ARM64) && defined(ENABLE_FP16) +#ifdef ENABLE_FP16 #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" #endif #include "src/common/version_manager.h" @@ -283,7 +283,7 @@ int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) { } int CpuFp16SubGraph::PreProcess() { -#ifdef ENABLE_ARM64 +#ifdef ENABLE_FP16 if (!mindspore::lite::IsSupportFloat16()) { MS_LOG(ERROR) << "Unsupported fp16 in this devices"; return RET_ERROR; @@ -347,7 +347,7 @@ int CpuFp16SubGraph::PreProcess() { } int CpuFp16SubGraph::PostProcess() { -#ifdef ENABLE_ARM64 +#ifdef ENABLE_FP16 if (!mindspore::lite::IsSupportFloat16()) { MS_LOG(ERROR) << "Unsupported fp16 in this devices"; return RET_ERROR; diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index f466970e3e4..4d19ac4c54f 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -377,8 +377,11 @@ add_dependencies(lite-test fbs_src) target_link_libraries(lite-test dl mindspore::gtest) -if(PLATFORM_ARM64 AND ENABLE_FP16) - target_link_libraries(lite-test nnacl_fp16_mid nnacl_optimize_mid) +if(PLATFORM_ARM AND ENABLE_FP16) + target_link_libraries(lite-test nnacl_fp16_mid) + if(PLATFORM_ARM64) + target_link_libraries(lite-test nnacl_optimize_mid) + endif() endif() if(PLATFORM_ARM)
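A closing note on the runtime gate: on both arm64 and arm32 the fp16 subgraphs now refuse to run unless IsSupportFloat16() reports hardware support (see common/utils.cc above). On AArch64 Linux/Android such a probe usually reads the auxiliary vector; a hedged sketch, not the patch's exact implementation (header locations vary by libc):

#include <stdbool.h>
#include <sys/auxv.h>   /* getauxval */
#include <asm/hwcap.h>  /* HWCAP_FPHP, HWCAP_ASIMDHP on AArch64 */

static bool cpu_supports_fp16_arith(void) {
#if defined(__aarch64__)
  unsigned long hwcap = getauxval(AT_HWCAP);
  /* FPHP: scalar fp16 arithmetic; ASIMDHP: NEON vector fp16 arithmetic. */
  return (hwcap & HWCAP_FPHP) && (hwcap & HWCAP_ASIMDHP);
#else
  /* On armv8.2 aarch32 builds the patch instead returns true
     unconditionally when compiled with ENABLE_ARM32 (see utils.cc). */
  return false;
#endif
}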