!9673 [MS][LITE][Develop]add avx fp32 matmul kernel

From: @lx0095
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
This commit is contained in:
mindspore-ci-bot 2020-12-10 09:28:03 +08:00 committed by Gitee
commit de2b53d9f0
28 changed files with 1354 additions and 82 deletions

View File

@ -21,7 +21,8 @@ option(SUPPORT_NPU "if support npu" off)
option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off)
option(BUILD_MINDDATA_EXAMPLE "" on)
option(ENABLE_VERBOSE "" off)
option(ENABLE_X86_64_SSE "if x86_64 support SSE instruction set" off)
option(ENABLE_SSE "if x86_64 support SSE instruction set" off)
option(ENABLE_AVX "if x86_64 support SSE instruction set" off)
set(DIR_PREFIX mindspore-lite)
set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
@ -202,7 +203,13 @@ endif()
if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
if ("${X86_64_SIMD}" STREQUAL "sse")
add_compile_definitions(ENABLE_X86_64_SSE)
add_compile_definitions(ENABLE_SSE)
endif ()
if ("${X86_64_SIMD}" STREQUAL "avx")
add_compile_definitions(ENABLE_SSE)
add_compile_definitions(ENABLE_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma")
endif ()
endif ()

View File

@ -37,6 +37,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
endif()
if ("${X86_64_SIMD}" STREQUAL "avx")
file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c
${NNACL_DIR}/assembly/avx/*.S)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
endif()
########################### build nnacl static library ########################
string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
add_library(nnacl STATIC ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC})

View File

@ -0,0 +1,941 @@
#ifdef ENABLE_AVX
#ifndef WIN32
.text
.align 4
.global MatmulFloatAvxOpt
#ifndef __APPLE__
.type MatmulFloatAvxOpt, %function
#endif
// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
// int row, int col, size_t stride, size_t writeMode)
// rdi: a
// rsi: b
// rdx: c
// rcx: bias
// r8: act_type
// r9: depth
// 8: row
// 16: col
// 24: stride
// 32: writeNhwc/writeWino
MatmulFloatAvxOpt:
// rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention
pushq %r15
pushq %r14
pushq %r13
pushq %r12
pushq %rbx
pushq %rbp
pushq %r9
pushq %r8
pushq %rcx
pushq %rdx
pushq %rsi
pushq %rdi
addq $96, %rsp
movq 8(%rsp), %rbp
movq 16(%rsp), %rbx
movq 24(%rsp), %r10
movq 32(%rsp), %r14
movq $24, %r11
imul %r9, %r11
cmpq $0, %r14
jne NoC8Steps
movq $48, %r13
imul %rbp, %r13
NoC8Steps:
cmpq $2, %r14
jne NoWinoSteps
movq $4, %r12
imul %r10, %r12
imul %rbx, %r12
movq $48, %r13
imul %r10, %r13
NoWinoSteps:
movq $4, %rax
imul %rax, %r10
LoopRow:
movq -88(%rsp), %rsi
movq 16(%rsp), %rbx
movq -72(%rsp), %rcx
LoopCol:
cmpq $0, %r14
je NoReloadDst
movq -80(%rsp), %rdx
NoReloadDst:
movq -96(%rsp), %rdi
movq -56(%rsp), %r9
vmovups (%rsi), %ymm0
vmovups 32(%rsi), %ymm1
vbroadcastss (%rdi), %ymm10
vbroadcastss 4(%rdi), %ymm11
vbroadcastss 8(%rdi), %ymm12
vbroadcastss 12(%rdi), %ymm13
vbroadcastss 16(%rdi), %ymm2
vbroadcastss 20(%rdi), %ymm3
addq $64, %rsi
vmulps %ymm0, %ymm10, %ymm4
vmulps %ymm1, %ymm10, %ymm5
vmulps %ymm0, %ymm11, %ymm6
vmulps %ymm1, %ymm11, %ymm7
vmulps %ymm0, %ymm12, %ymm8
vmulps %ymm1, %ymm12, %ymm9
vmulps %ymm0, %ymm13, %ymm10
vmulps %ymm1, %ymm13, %ymm11
add $24, %rdi
vmulps %ymm0, %ymm2, %ymm12
vmulps %ymm1, %ymm2, %ymm13
vmulps %ymm0, %ymm3, %ymm14
vmulps %ymm1, %ymm3, %ymm15
subq $1, %r9
cmpq $0, %r9
je Bias
LoopDepth:
vmovups (%rsi), %ymm0
vmovups 32(%rsi), %ymm1
vbroadcastss (%rdi), %ymm2
vbroadcastss 4(%rdi), %ymm3
vfmadd231ps %ymm0, %ymm2, %ymm4
addq $64, %rsi
vfmadd231ps %ymm1, %ymm2, %ymm5
vbroadcastss 8(%rdi), %ymm2
vfmadd231ps %ymm0, %ymm3, %ymm6
vfmadd231ps %ymm1, %ymm3, %ymm7
vbroadcastss 12(%rdi), %ymm3
vfmadd231ps %ymm0, %ymm2, %ymm8
prefetcht0 384(%rsi)
vfmadd231ps %ymm1, %ymm2, %ymm9
vbroadcastss 16(%rdi), %ymm2
vfmadd231ps %ymm0, %ymm3, %ymm10
vfmadd231ps %ymm1, %ymm3, %ymm11
vbroadcastss 20(%rdi), %ymm3
vfmadd231ps %ymm0, %ymm2, %ymm12
vfmadd231ps %ymm1, %ymm2, %ymm13
addq $24, %rdi
vfmadd231ps %ymm0, %ymm3, %ymm14
vfmadd231ps %ymm1, %ymm3, %ymm15
subq $1, %r9
cmpq $0, %r9
ja LoopDepth
Bias:
cmpq $0, %rcx
je Activation
vmovups (%rcx), %ymm0
vmovups 32(%rcx), %ymm1
add $64, %rcx
vaddps %ymm0, %ymm4, %ymm4
vaddps %ymm1, %ymm5, %ymm5
vaddps %ymm0, %ymm6, %ymm6
vaddps %ymm1, %ymm7, %ymm7
vaddps %ymm0, %ymm8, %ymm8
vaddps %ymm1, %ymm9, %ymm9
vaddps %ymm0, %ymm10, %ymm10
vaddps %ymm1, %ymm11, %ymm11
vaddps %ymm0, %ymm12, %ymm12
vaddps %ymm1, %ymm13, %ymm13
vaddps %ymm0, %ymm14, %ymm14
vaddps %ymm1, %ymm15, %ymm15
Activation:
cmpq $3, %r8
je Relu6
cmpq $1, %r8
je Relu
jmp Write
Relu6:
movq $6, %rax
vcvtsi2ss %rax, %xmm0, %xmm0
vshufps $0, %xmm0, %xmm0, %xmm0
vinsertf128 $1, %xmm0, %ymm0, %ymm0
vminps %ymm0, %ymm4, %ymm4
vminps %ymm0, %ymm5, %ymm5
vminps %ymm0, %ymm6, %ymm6
vminps %ymm0, %ymm7, %ymm7
vminps %ymm0, %ymm8, %ymm8
vminps %ymm0, %ymm9, %ymm9
vminps %ymm0, %ymm10, %ymm10
vminps %ymm0, %ymm11, %ymm11
vminps %ymm0, %ymm12, %ymm12
vminps %ymm0, %ymm13, %ymm13
vminps %ymm0, %ymm14, %ymm14
vminps %ymm0, %ymm15, %ymm15
Relu:
vxorps %ymm1, %ymm1, %ymm1
vmaxps %ymm1, %ymm4, %ymm4
vmaxps %ymm1, %ymm5, %ymm5
vmaxps %ymm1, %ymm6, %ymm6
vmaxps %ymm1, %ymm7, %ymm7
vmaxps %ymm1, %ymm8, %ymm8
vmaxps %ymm1, %ymm9, %ymm9
vmaxps %ymm1, %ymm10, %ymm10
vmaxps %ymm1, %ymm11, %ymm11
vmaxps %ymm1, %ymm12, %ymm12
vmaxps %ymm1, %ymm13, %ymm13
vmaxps %ymm1, %ymm14, %ymm14
vmaxps %ymm1, %ymm15, %ymm15
Write:
cmpq $2, %r14
je WriteWino
cmpq $0, %r14
je WriteC8
cmpq $1, %rbx
je Write1
cmpq $2, %rbx
je Write2
cmpq $3, %rbx
je Write3
cmpq $4, %rbx
je Write4
cmpq $5, %rbx
je Write5
cmpq $6, %rbx
je Write6
cmpq $7, %rbx
je Write7
cmpq $8, %rbx
je Write8
cmpq $9, %rbx
je Write9
cmpq $10, %rbx
je Write10
cmpq $11, %rbx
je Write11
cmpq $12, %rbx
je Write12
cmpq $13, %rbx
je Write13
cmpq $14, %rbx
je Write14
cmpq $15, %rbx
je Write15
jmp Write16
Write1:
movq %rdx, %rax
addq $4, %rax
movq %rax, -80(%rsp)
vmovss %xmm4, (%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovss %xmm6, (%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovss %xmm8, (%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovss %xmm10, (%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovss %xmm12, (%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovss %xmm14, (%rdx)
addq %r10, %rdx
addq $4, %rdx
jmp WriteEnd
Write2:
movq %rdx, %rax
addq $8, %rax
movq %rax, -80(%rsp)
vmovsd %xmm4, (%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm6, (%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm8, (%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm10, (%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm12, (%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm14, (%rdx)
addq %r10, %rdx
addq $8, %rdx
jmp WriteEnd
Write3:
movq %rdx, %rax
addq $12, %rax
movq %rax, -80(%rsp)
vmovsd %xmm4, (%rdx)
movhlps %xmm4, %xmm4
vmovss %xmm4, 8(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm6, (%rdx)
movhlps %xmm6, %xmm6
vmovss %xmm6, 8(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm8, (%rdx)
movhlps %xmm8, %xmm8
vmovss %xmm8, 8(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm10, (%rdx)
movhlps %xmm10, %xmm10
vmovss %xmm10, 8(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm12, (%rdx)
movhlps %xmm12, %xmm12
vmovss %xmm12, 8(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovsd %xmm14, (%rdx)
movhlps %xmm14, %xmm14
vmovss %xmm14, 8(%rdx)
addq %r10, %rdx
addq $12, %rdx
jmp WriteEnd
Write4:
movq %rdx, %rax
addq $16, %rax
movq %rax, -80(%rsp)
vmovups %xmm4, (%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm6, (%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm8, (%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm10, (%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm12, (%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm14, (%rdx)
addq %r10, %rdx
addq $16, %rdx
jmp WriteEnd
Write5:
movq %rdx, %rax
addq $20, %rax
movq %rax, -80(%rsp)
vmovups %xmm4, (%rdx)
vextractf128 $1, %ymm4, %xmm4
vmovss %xmm4, 16(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm6, (%rdx)
vextractf128 $1, %ymm6, %xmm6
vmovss %xmm6, 16(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm8, (%rdx)
vextractf128 $1, %ymm8, %xmm8
vmovss %xmm8, 16(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm10, (%rdx)
vextractf128 $1, %ymm10, %xmm10
vmovss %xmm10, 16(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm12, (%rdx)
vextractf128 $1, %ymm12, %xmm12
vmovss %xmm12, 16(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm14, (%rdx)
vextractf128 $1, %ymm14, %xmm14
vmovss %xmm14, 16(%rdx)
addq %r10, %rdx
addq $20, %rdx
jmp WriteEnd
Write6:
movq %rdx, %rax
addq $24, %rax
movq %rax, -80(%rsp)
vmovups %xmm4, (%rdx)
vextractf128 $1, %ymm4, %xmm4
vmovsd %xmm4, 16(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm6, (%rdx)
vextractf128 $1, %ymm6, %xmm6
vmovsd %xmm6, 16(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm8, (%rdx)
vextractf128 $1, %ymm8, %xmm8
vmovsd %xmm8, 16(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm10, (%rdx)
vextractf128 $1, %ymm10, %xmm10
vmovsd %xmm10, 16(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm12, (%rdx)
vextractf128 $1, %ymm12, %xmm12
vmovsd %xmm12, 16(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm14, (%rdx)
vextractf128 $1, %ymm14, %xmm14
vmovsd %xmm14, 16(%rdx)
addq %r10, %rdx
addq $24, %rdx
jmp WriteEnd
Write7:
movq %rdx, %rax
addq $28, %rax
movq %rax, -80(%rsp)
vmovups %xmm4, (%rdx)
vextractf128 $1, %ymm4, %xmm4
vmovsd %xmm4, 16(%rdx)
movhlps %xmm4, %xmm4
vmovss %xmm4, 24(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm6, (%rdx)
vextractf128 $1, %ymm6, %xmm6
vmovsd %xmm6, 16(%rdx)
movhlps %xmm6, %xmm6
vmovss %xmm6, 24(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm8, (%rdx)
vextractf128 $1, %ymm8, %xmm8
vmovsd %xmm8, 16(%rdx)
movhlps %xmm8, %xmm8
vmovss %xmm8, 24(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm10, (%rdx)
vextractf128 $1, %ymm10, %xmm10
vmovsd %xmm10, 16(%rdx)
movhlps %xmm10, %xmm10
vmovss %xmm10, 24(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm12, (%rdx)
vextractf128 $1, %ymm12, %xmm12
vmovsd %xmm12, 16(%rdx)
movhlps %xmm12, %xmm12
vmovss %xmm12, 24(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %xmm14, (%rdx)
vextractf128 $1, %ymm14, %xmm14
vmovsd %xmm14, 16(%rdx)
movhlps %xmm14, %xmm14
vmovss %xmm14, 24(%rdx)
addq %r10, %rdx
addq $28, %rdx
jmp WriteEnd
Write8:
movq %rdx, %rax
addq $32, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
addq %r10, %rdx
addq $32, %rdx
jmp WriteEnd
Write9:
movq %rdx, %rax
addq $36, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovss %xmm5, 32(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovss %xmm7, 32(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovss %xmm9, 32(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovss %xmm11, 32(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovss %xmm13, 32(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovss %xmm15, 32(%rdx)
addq %r10, %rdx
addq $36, %rdx
jmp WriteEnd
Write10:
movq %rdx, %rax
addq $40, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovsd %xmm5, 32(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovsd %xmm7, 32(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovsd %xmm9, 32(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovsd %xmm11, 32(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovsd %xmm13, 32(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovsd %xmm15, 32(%rdx)
addq %r10, %rdx
addq $40, %rdx
jmp WriteEnd
Write11:
movq %rdx, %rax
addq $44, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovsd %xmm5, 32(%rdx)
movhlps %xmm5, %xmm5
vmovss %xmm5, 40(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovsd %xmm7, 32(%rdx)
movhlps %xmm7, %xmm7
vmovss %xmm7, 40(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovsd %xmm9, 32(%rdx)
movhlps %xmm9, %xmm9
vmovss %xmm9, 40(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovsd %xmm11, 32(%rdx)
movhlps %xmm11, %xmm11
vmovss %xmm11, 40(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovsd %xmm13, 32(%rdx)
movhlps %xmm13, %xmm13
vmovss %xmm13, 40(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovsd %xmm15, 32(%rdx)
movhlps %xmm15, %xmm15
vmovss %xmm15, 40(%rdx)
addq %r10, %rdx
addq $44, %rdx
jmp WriteEnd
Write12:
movq %rdx, %rax
addq $48, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovups %xmm5, 32(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovups %xmm7, 32(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovups %xmm9, 32(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovups %xmm11, 32(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovups %xmm13, 32(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovups %xmm15, 32(%rdx)
addq %r10, %rdx
addq $48, %rdx
jmp WriteEnd
Write13:
movq %rdx, %rax
addq $52, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovups %xmm5, 32(%rdx)
vextractf128 $1, %ymm5, %xmm5
vmovss %xmm5, 48(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovups %xmm7, 32(%rdx)
vextractf128 $1, %ymm7, %xmm7
vmovss %xmm7, 48(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovups %xmm9, 32(%rdx)
vextractf128 $1, %ymm9, %xmm9
vmovss %xmm9, 48(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovups %xmm11, 32(%rdx)
vextractf128 $1, %ymm11, %xmm11
vmovss %xmm11, 48(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovups %xmm13, 32(%rdx)
vextractf128 $1, %ymm13, %xmm13
vmovss %xmm13, 48(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovups %xmm15, 32(%rdx)
vextractf128 $1, %ymm15, %xmm15
vmovss %xmm15, 48(%rdx)
addq %r10, %rdx
addq $52, %rdx
jmp WriteEnd
Write14:
movq %rdx, %rax
addq $56, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovups %xmm5, 32(%rdx)
vextractf128 $1, %ymm5, %xmm5
vmovsd %xmm5, 48(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovups %xmm7, 32(%rdx)
vextractf128 $1, %ymm7, %xmm7
vmovsd %xmm7, 48(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovups %xmm9, 32(%rdx)
vextractf128 $1, %ymm9, %xmm9
vmovsd %xmm9, 48(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovups %xmm11, 32(%rdx)
vextractf128 $1, %ymm11, %xmm11
vmovsd %xmm11, 48(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovups %xmm13, 32(%rdx)
vextractf128 $1, %ymm13, %xmm13
vmovsd %xmm13, 48(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovups %xmm15, 32(%rdx)
vextractf128 $1, %ymm15, %xmm15
vmovsd %xmm15, 48(%rdx)
addq %r10, %rdx
addq $56, %rdx
jmp WriteEnd
Write15:
movq %rdx, %rax
addq $60, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovups %xmm5, 32(%rdx)
vextractf128 $1, %ymm5, %xmm5
vmovsd %xmm5, 48(%rdx)
movhlps %xmm5, %xmm5
vmovss %xmm5, 56(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovups %xmm7, 32(%rdx)
vextractf128 $1, %ymm7, %xmm7
vmovsd %xmm7, 48(%rdx)
movhlps %xmm7, %xmm7
vmovss %xmm7, 56(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovups %xmm9, 32(%rdx)
vextractf128 $1, %ymm9, %xmm9
vmovsd %xmm9, 48(%rdx)
movhlps %xmm9, %xmm9
vmovss %xmm9, 56(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovups %xmm11, 32(%rdx)
vextractf128 $1, %ymm11, %xmm11
vmovsd %xmm11, 48(%rdx)
movhlps %xmm11, %xmm11
vmovss %xmm11, 56(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovups %xmm13, 32(%rdx)
vextractf128 $1, %ymm13, %xmm13
vmovsd %xmm13, 48(%rdx)
movhlps %xmm13, %xmm13
vmovss %xmm13, 56(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovups %xmm15, 32(%rdx)
vextractf128 $1, %ymm15, %xmm15
vmovsd %xmm15, 48(%rdx)
movhlps %xmm15, %xmm15
vmovss %xmm15, 56(%rdx)
addq %r10, %rdx
addq $60, %rdx
jmp WriteEnd
WriteC8:
movq %rdx, %rax
addq %r11, %rdx
movq %rdx, %r15
addq %r11, %rdx
movq %rdx, -80(%rsp)
vmovups %ymm4, (%rax)
vmovups %ymm6, 32(%rax)
vmovups %ymm8, 64(%rax)
vmovups %ymm10, 96(%rax)
vmovups %ymm12, 128(%rax)
vmovups %ymm14, 160(%rax)
vmovups %ymm5, (%r15)
vmovups %ymm7, 32(%r15)
vmovups %ymm9, 64(%r15)
vmovups %ymm11, 96(%r15)
vmovups %ymm13, 128(%r15)
vmovups %ymm15, 160(%r15)
jmp WriteEnd
WriteWino:
movq %rdx, %rax
addq %r13, %rdx
movq %rdx, %r15
addq %r13, %rdx
movq %rdx, -80(%rsp)
vmovups %ymm4, (%rax)
vmovups %ymm5, (%r15)
addq %r12, %rax
addq %r12, %r15
vmovups %ymm6, (%rax)
vmovups %ymm7, (%r15)
addq %r12, %rax
addq %r12, %r15
vmovups %ymm8, (%rax)
vmovups %ymm9, (%r15)
addq %r12, %rax
addq %r12, %r15
vmovups %ymm10, (%rax)
vmovups %ymm11, (%r15)
addq %r12, %rax
addq %r12, %r15
vmovups %ymm12, (%rax)
vmovups %ymm13, (%r15)
addq %r12, %rax
addq %r12, %r15
vmovups %ymm14, (%rax)
vmovups %ymm15, (%r15)
jmp WriteEnd
Write16:
movq %rdx, %rax
addq $64, %rax
movq %rax, -80(%rsp)
vmovups %ymm4, (%rdx)
vmovups %ymm5, 32(%rdx)
cmpq $1, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm6, (%rdx)
vmovups %ymm7, 32(%rdx)
cmpq $2, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm8, (%rdx)
vmovups %ymm9, 32(%rdx)
cmpq $3, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm10, (%rdx)
vmovups %ymm11, 32(%rdx)
cmpq $4, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm12, (%rdx)
vmovups %ymm13, 32(%rdx)
cmpq $5, %rbp
je WriteEnd
addq %r10, %rdx
vmovups %ymm14, (%rdx)
vmovups %ymm15, 32(%rdx)
addq %r10, %rdx
addq $64, %rdx
WriteEnd:
cmpq $16, %rbx
jbe LoopColEnd
subq $16, %rbx
jmp LoopCol
LoopColEnd:
movq -96(%rsp), %rdi
addq %r11, %rdi
movq %rdi, -96(%rsp)
cmpq $0, %r14
je C8DstStep
cmpq $2, %r14
je WinoDstStep
movq $4, %rax
movq 16(%rsp), %rbx
imul %rbx, %rax
subq %rax, %rdx
movq %rdx, -80(%rsp)
jmp NoDstStep
C8DstStep:
movq -80(%rsp), %rax
addq $384, %rax
movq %rax, -80(%rsp)
jmp NoDstStep
WinoDstStep:
addq %r13, %rdx
movq %rdx, -80(%rsp)
NoDstStep:
cmpq $6, %rbp
jbe LoopRowEnd
subq $6, %rbp
jmp LoopRow
LoopRowEnd:
subq $96, %rsp
popq %rdi
popq %rsi
popq %rdx
popq %rcx
popq %r8
popq %r9
popq %rbp
popq %rbx
popq %r12
popq %r13
popq %r14
popq %r15
retq
#endif
#endif

View File

@ -56,7 +56,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed
int out_channel = conv_param->output_channel_;
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
int output_count = conv_param->output_h_ * conv_param->output_w_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
const int cal_num = C4NUM;
#else
const int cal_num = C12NUM;
@ -78,7 +78,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed
int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
float *gemm_output = output_data + out_offset;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
#else
RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);

View File

@ -43,7 +43,7 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p
void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, size_t relu_type) {
#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM);
#else
size_t oc8mod = output_channel % C8NUM;
@ -68,7 +68,7 @@ void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bi
return;
}
#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
const int unitStep = 4 * length;
for (int y = 0; y < h; ++y) {

View File

@ -39,7 +39,7 @@ float ShortToFloat32(uint16_t src_value);
uint16_t Float32ToShort(float src_value);
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6);

View File

@ -202,7 +202,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float
const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
@ -285,7 +285,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
@ -839,7 +839,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we
float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),

View File

@ -26,7 +26,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
int out_channel = conv_param->output_channel_;
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
int output_count = conv_param->output_h_ * conv_param->output_w_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
const int cal_num = C6NUM;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
const int cal_num = C4NUM;
#else
const int cal_num = C12NUM;
@ -48,7 +50,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
float *gemm_output = output_data + out_offset;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
RowMajor2Col6Major(gemm_input, col_major_gemm_input, cal_num, deep);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
#else
RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
@ -97,7 +101,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
#else
RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);

View File

@ -41,7 +41,7 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_ROUND(output_channel, C8NUM);
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
const int tile_num = 4;
#else
const int tile_num = 12;

View File

@ -28,9 +28,21 @@ void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col)
for (int r = 0; r < row; r++) {
const float *src = src_ptr + r * col;
for (int c = 0; c < col; c++) {
int cd8 = c / 4;
int cm8 = c % 4;
dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c];
int cd4 = c / C4NUM;
int cm4 = c % C4NUM;
dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c];
}
}
return;
}
void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) {
const float *src = src_ptr + r * col;
for (int c = 0; c < col; c++) {
int cd6 = c / C6NUM;
int cm6 = c % C6NUM;
dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c];
}
}
return;
@ -40,9 +52,9 @@ void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col)
for (int r = 0; r < row; r++) {
const float *src = src_ptr + r * col;
for (int c = 0; c < col; c++) {
int cd8 = c / 8;
int cm8 = c % 8;
dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c];
int cd8 = c / C8NUM;
int cm8 = c % C8NUM;
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
}
}
return;
@ -52,9 +64,21 @@ void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col)
for (int r = 0; r < row; r++) {
const float *src = src_ptr + r * col;
for (int c = 0; c < col; c++) {
int cd8 = c / C12NUM;
int cm8 = c % C12NUM;
dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c];
int cd12 = c / C12NUM;
int cm12 = c % C12NUM;
dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
}
}
return;
}
void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) {
const float *src = src_ptr + r * col;
for (int c = 0; c < col; c++) {
int cd16 = c / C16NUM;
int cm16 = c % C16NUM;
dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
}
}
return;
@ -190,7 +214,7 @@ void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#elif ENABLE_X86_64_SSE
#elif ENABLE_SSE
__m128 src1 = _mm_loadu_ps(src_c);
__m128 src2 = _mm_loadu_ps(src_c + col);
__m128 src3 = _mm_loadu_ps(src_c + 2 * col);
@ -421,7 +445,7 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#elif ENABLE_X86_64_SSE
#elif ENABLE_SSE
/* 8x4 row-major to col-major */
__m128 src1 = _mm_loadu_ps(src_c);
__m128 src2 = _mm_loadu_ps(src_c + col);
@ -478,6 +502,145 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
return;
}
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row16 = row / C16NUM * C16NUM;
size_t col_skip = col / C4NUM * C4NUM;
int skip_size = C4NUM;
const float *src_r = src_ptr;
float *dst_r = dst_ptr;
size_t ri = 0;
for (; ri < row16; ri += C16NUM) {
size_t ci = 0;
for (; ci < col_skip; ci += skip_size) {
const float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C16NUM;
for (int tr = 0; tr < C16NUM; tr++) {
for (int tc = 0; tc < C4NUM; tc++) {
dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
}
}
}
for (; ci < col; ci++) {
const float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C16NUM;
for (size_t i = 0; i < C16NUM; i++) {
dst_c[i] = src_c[i * col];
}
}
src_r += C16NUM * col;
dst_r += C16NUM * col;
}
for (; ri < row; ri++) {
for (size_t i = 0; i < col; i++) {
dst_r[i * C16NUM] = src_r[i];
}
src_r += col;
dst_r += 1;
}
return;
}
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t totalRow = UP_ROUND(row, C6NUM);
size_t row6 = row / C6NUM * C6NUM;
size_t col8 = col / C8NUM * C8NUM;
const float *src_r = src_ptr;
float *dst_r = dst_ptr;
size_t ri = 0;
for (; ri < row6; ri += C6NUM) {
size_t ci = 0;
for (; ci < col8; ci += C8NUM) {
const float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C6NUM;
/* 6x8 row-major to col-major */
#ifdef ENABLE_AVX
__m256 src0 = _mm256_loadu_ps(src_c);
__m256 src1 = _mm256_loadu_ps(src_c + col);
__m256 src2 = _mm256_loadu_ps(src_c + 2 * col);
__m256 src3 = _mm256_loadu_ps(src_c + 3 * col);
__m256 src4 = _mm256_loadu_ps(src_c + 4 * col);
__m256 src5 = _mm256_loadu_ps(src_c + 5 * col);
__m256 trans0 = _mm256_unpacklo_ps(src0, src1);
__m256 trans1 = _mm256_unpacklo_ps(src2, src3);
__m256 trans2 = _mm256_unpacklo_ps(src4, src5);
__m256 trans3 = _mm256_unpackhi_ps(src0, src1);
__m256 trans4 = _mm256_unpackhi_ps(src2, src3);
__m256 trans5 = _mm256_unpackhi_ps(src4, src5);
__m128 lo0 = _mm256_castps256_ps128(trans0);
__m128 lo1 = _mm256_castps256_ps128(trans1);
__m128 lo2 = _mm256_castps256_ps128(trans2);
__m128 lo3 = _mm256_castps256_ps128(trans3);
__m128 lo4 = _mm256_castps256_ps128(trans4);
__m128 lo5 = _mm256_castps256_ps128(trans5);
__m128 hi0 = _mm256_extractf128_ps(trans0, 1);
__m128 hi1 = _mm256_extractf128_ps(trans1, 1);
__m128 hi2 = _mm256_extractf128_ps(trans2, 1);
__m128 hi3 = _mm256_extractf128_ps(trans3, 1);
__m128 hi4 = _mm256_extractf128_ps(trans4, 1);
__m128 hi5 = _mm256_extractf128_ps(trans5, 1);
__m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0));
__m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0));
__m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2));
__m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0));
__m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0));
__m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2));
__m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0));
__m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0));
__m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2));
__m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0));
__m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0));
__m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2));
_mm_storeu_ps(dst_c, res0);
_mm_storeu_ps(dst_c + 4, res1);
_mm_storeu_ps(dst_c + 8, res2);
_mm_storeu_ps(dst_c + 12, res3);
_mm_storeu_ps(dst_c + 16, res4);
_mm_storeu_ps(dst_c + 20, res5);
_mm_storeu_ps(dst_c + 24, res6);
_mm_storeu_ps(dst_c + 28, res7);
_mm_storeu_ps(dst_c + 32, res8);
_mm_storeu_ps(dst_c + 36, res9);
_mm_storeu_ps(dst_c + 40, res10);
_mm_storeu_ps(dst_c + 44, res11);
#else
for (int tr = 0; tr < C6NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
dst_c[tc * C6NUM + tr] = src_c[tr * col + tc];
}
}
#endif
}
for (; ci < col; ci++) {
const float *src_c = src_r + ci;
float *dst_c = dst_r + ci * C6NUM;
for (size_t i = 0; i < C6NUM; i++) {
dst_c[i] = src_c[i * col];
}
}
src_r += C6NUM * col;
dst_r += C6NUM * col;
}
for (; ri < row; ri++) {
for (size_t i = 0; i < col; i++) {
dst_r[i * C6NUM] = src_r[i];
}
src_r += col;
dst_r += 1;
}
for (; ri < totalRow; ri++) {
for (size_t i = 0; i < col; i++) {
dst_r[i * C6NUM] = 0;
}
dst_r += 1;
}
return;
}
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
size_t row8 = row / C4NUM * C4NUM;
size_t col4 = col / C4NUM * C4NUM;
@ -519,7 +682,7 @@ void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
: "r10", "r12", "q0", "q1", "q2", "q3");
#elif ENABLE_X86_64_SSE
#elif ENABLE_SSE
__m128 src1 = _mm_loadu_ps(src_c);
__m128 src2 = _mm_loadu_ps(src_c + col);
__m128 src3 = _mm_loadu_ps(src_c + 2 * col);
@ -630,6 +793,34 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
return;
}
#ifdef ENABLE_AVX
#ifdef WIN32
void MatMul6x16(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, int out_type) {
if (out_type == OutType_Nhwc) {
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r6div = r / C6NUM, r6mod = r % C6NUM;
int c16div = c / C16NUM, c16mod = c % C16NUM;
size_t ci = r * stride + c;
float value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r6div * deep * C6NUM + d * C6NUM + r6mod;
size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[c];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
}
return;
}
#endif
#endif
void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
int col, int stride, int out_type) {
if (out_type == OutType_C8) {
@ -670,7 +861,19 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
} else {
MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
}
#elif ENABLE_X86_64_SSE
#elif ENABLE_AVX
if (out_type == OutType_Nhwc) {
#ifdef WIN32
MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#else
MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
#endif
} else if (out_type == OutType_C8) {
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
} else {
MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
}
#elif ENABLE_SSE
if (out_type == OutType_C8) {
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
} else {

View File

@ -31,11 +31,15 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
void MatVecMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int col);
void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
#ifdef ENABLE_ARM
void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
#endif
@ -49,11 +53,16 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi
int col, int stride, size_t writeNhwc, size_t WriteWino);
void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode);
#elif ENABLE_X86_64_SSE
#elif ENABLE_SSE
#include <x86intrin.h>
void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, size_t writeNhwc, size_t WriteWino);
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode);
#ifdef ENABLE_AVX
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, int stride, int write_mode);
#endif
#endif
#ifdef ENABLE_NNACL_INFER_SHAPE

View File

@ -42,6 +42,7 @@ typedef struct MatMulParameter {
int row_;
int col_;
int row_4_;
int row_6_;
int row_8_;
int row_12_;
int row_16_;

View File

@ -125,7 +125,7 @@ int B(const float *poly_array, float *matrix_b, int in_unit) {
return NNACL_OK;
}
#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
int in_channel, int c4_channel) {
int cnt = 0;
@ -228,7 +228,7 @@ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *ma
return NNACL_OK;
}
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c,
const float *bias, int m, int k, int n) {
int count = 0;

View File

@ -52,7 +52,7 @@ void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *
int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt,
int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack);
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c,
const float *bias, int m, int k, int n);
#endif

View File

@ -21,8 +21,8 @@
#include <arm_neon.h>
#endif
#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#ifdef ENABLE_SSE
#include <x86intrin.h>
#endif
#include <stdint.h>
@ -31,6 +31,7 @@
#define C2NUM 2
#define C4NUM 4
#define C6NUM 6
#define C8NUM 8
#define C12NUM 12
#define C16NUM 16
@ -91,7 +92,7 @@ typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6,
#define MS_MAXQ_F32 vmaxq_f32
#define MS_MINQ_F32 vminq_f32
#define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2)
#elif defined(ENABLE_X86_64_SSE)
#elif defined(ENABLE_SSE)
#define MS_FLOAT32X4 __m128
#define MS_LDQ_F32 _mm_loadu_ps
#define MS_ADDQ_F32 _mm_add_ps

View File

@ -756,7 +756,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch
return;
}
#ifndef ENABLE_X86_64_SSE
#ifndef ENABLE_SSE
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
int hw8 = plane / C8NUM * C8NUM;
int c8 = channel / C8NUM * C8NUM;

View File

@ -79,7 +79,7 @@ void GeneralInputTransformUnit(const float *src_data, float *dst_data, const flo
int src_step, int dst_step, int in_unit) {
int len = in_unit * in_unit;
if (len > MAX_LEN) return;
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
MS_FLOAT32X4 src[MAX_LEN];
MS_FLOAT32X4 t[MAX_LEN];
MS_FLOAT32X4 m[MAX_LEN];

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#ifdef ENABLE_SSE
#include <x86intrin.h>
#include "nnacl/fp32/conv_depthwise_fp32.h"
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#ifdef ENABLE_SSE
#include <x86intrin.h>
#include "nnacl/minimal_filtering_generator.h"
#include "nnacl/op_base.h"

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#ifdef ENABLE_SSE
#include <x86intrin.h>
#include "nnacl/pack.h"
#include "nnacl/int8/conv_int8.h"

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#ifdef ENABLE_SSE
#include <x86intrin.h>
#include "nnacl/fp32/common_func_fp32.h"
void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef ENABLE_X86_64_SSE
#include <nmmintrin.h>
#ifdef ENABLE_SSE
#include <x86intrin.h>
#include "nnacl/fp32/common_func_fp32.h"
void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {

View File

@ -60,6 +60,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
matmul_param_->col_ = conv_param_->output_channel_;
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
matmul_param_->row_6_ = UP_ROUND(matmul_param_->row_, C6NUM);
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = conv_param_->act_type_;
@ -71,8 +72,13 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
if (in_tensors_.size() == 3) {
int size = UP_ROUND(output_channel, C8NUM) * sizeof(float);
int size = UP_ROUND(output_channel, col_tile) * sizeof(float);
int weight_size = output_channel * sizeof(float);
bias_data_ = malloc(size);
if (bias_data_ == nullptr) {
@ -83,22 +89,29 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
}
int size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float);
int down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float);
int size = input_channel * UP_ROUND(output_channel, col_tile) * sizeof(float);
int down_size = input_channel * DOWN_DIV(output_channel, col_tile) * col_tile * sizeof(float);
weight_ptr_ = reinterpret_cast<float *>(malloc(size));
if (weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
return RET_ERROR;
}
memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
#ifdef ENABLE_AVX
RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
#else
RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
input_channel);
#endif
return RET_OK;
}
int Convolution1x1CPUKernel::InitConv1x1Param() {
int hw_tile = C12NUM;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
hw_tile = C6NUM;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
hw_tile = C4NUM;
#endif
if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
@ -106,9 +119,14 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, hw_tile));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, hw_tile), thread_count_) * hw_tile;
} else {
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
multi_thread_by_hw_ = false;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile), thread_count_) * col_tile;
}
pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
@ -175,7 +193,9 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) {
float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if ENABLE_AVX
RowMajor2Col6Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
#else
RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
@ -202,7 +222,10 @@ int Convolution1x1CPUKernel::Run() {
auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_6_ * matmul_param_->deep_ * sizeof(float)));
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else
@ -226,7 +249,9 @@ int Convolution1x1CPUKernel::Run() {
if (multi_thread_by_hw_) {
ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_);
} else {
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
RowMajor2Col6Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);

View File

@ -42,7 +42,12 @@ int ConvolutionCPUKernel::InitWeightBias() {
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
int oc_block_num = UP_ROUND(out_channel, C8NUM);
#ifdef ENABLE_AVX
const int oc_block = C16NUM;
#else
const int oc_block = C8NUM;
#endif
int oc_block_num = UP_ROUND(out_channel, oc_block);
int pack_weight_size = oc_block_num * in_channel * kernel_plane;
auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c());
@ -52,7 +57,11 @@ int ConvolutionCPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
#ifdef ENABLE_AVX
RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#else
RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#endif
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * sizeof(float)));
if (bias_data_ == nullptr) {
@ -72,7 +81,10 @@ int ConvolutionCPUKernel::InitWeightBias() {
int ConvolutionCPUKernel::InitTmpBuffer() {
MS_ASSERT(ctx_->allocator != nullptr);
#ifdef ENABLE_ARM32
#ifdef ENABLE_AVX
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_;
#elif ENABLE_ARM32 || ENABLE_SSE
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_;
#else
int unit_size =

View File

@ -115,7 +115,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
return RET_OK;
}
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_,
@ -174,7 +174,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
return RET_NULL_PTR;
}
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float)));
#else
@ -186,7 +186,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
return RET_NULL_PTR;
}
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else
@ -215,7 +215,7 @@ int DeConvolutionCPUKernel::Run() {
input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);

View File

@ -51,12 +51,18 @@ int FullconnectionCPUKernel::ReSize() {
fc_param_->col_ = out_tensors_.at(0)->shape().back();
fc_param_->deep_ = (in_tensors_.at(1)->shape()).at(1);
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile);
fc_param_->row_6_ = UP_ROUND(fc_param_->col_, C6NUM);
fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_);
#ifdef ENABLE_ARM
if (fc_param_->row_ == 1) {
@ -75,7 +81,9 @@ int FullconnectionCPUKernel::ReSize() {
memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float));
}
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_6_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_;
#else
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_;
@ -120,7 +128,9 @@ void FullconnectionCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr)
return;
}
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
RowMajor2Col6Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
#else
RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
@ -132,8 +142,11 @@ void FullconnectionCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr)
memcpy(dst_ptr, src_ptr, fc_param_->col_ * fc_param_->deep_ * sizeof(float));
return;
}
#ifdef ENABLE_AVX
RowMajor2Col16Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
#else
RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
#endif
}
int FcFp32MatmulRun(void *cdata, int task_id) {
@ -147,14 +160,19 @@ int FcFp32MatmulRun(void *cdata, int task_id) {
}
int FullconnectionCPUKernel::DoMatmul(int task_id) {
int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM);
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
int cur_oc = MSMIN(thread_stride_ * col_tile, fc_param_->col_ - task_id * thread_stride_ * col_tile);
if (cur_oc <= 0) {
return RET_OK;
}
auto b = b_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * C8NUM;
auto c = c_ptr_ + task_id * thread_stride_ * C8NUM;
auto b = b_ptr_ + task_id * thread_stride_ * col_tile * fc_param_->deep_;
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * col_tile;
auto c = c_ptr_ + task_id * thread_stride_ * col_tile;
if (is_vector_input_) {
MatVecMul(a_ptr_, b, c, bias, fc_param_->act_type_, fc_param_->deep_, cur_oc);
} else {

View File

@ -75,9 +75,12 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
#endif
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
params_->row_6_ = UP_ROUND(params_->row_, C6NUM);
params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
int row_tmp = is_vector_a_ ? 1 : params_->row_6_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
int row_tmp = is_vector_a_ ? 1 : params_->row_4_;
#else
int row_tmp = is_vector_a_ ? 1 : params_->row_12_;
@ -106,9 +109,14 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
for (size_t i = 0; i < b_shape.size() - 2; ++i) {
batch *= b_shape[i];
}
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
params_->batch = batch;
params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1];
params_->col_8_ = UP_ROUND(params_->col_, 8);
params_->col_8_ = UP_ROUND(params_->col_, col_tile);
params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];
int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
@ -123,8 +131,8 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
return RET_MEMORY_FAILED;
}
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, col_tile), thread_count_);
return RET_OK;
}
@ -134,7 +142,12 @@ int MatmulCPUKernel::InitBias() {
params_->col_ = params_->b_const_
? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1))
: (c_shape.at(c_shape.size() - 1));
params_->col_8_ = UP_ROUND(params_->col_, 8);
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
params_->col_8_ = UP_ROUND(params_->col_, col_tile);
auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
if (bias_ptr_ == nullptr) {
bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));
@ -171,7 +184,14 @@ void MatmulCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) {
for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->row_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
float *dst = dst_ptr + i * params_->deep_ * params_->row_6_;
if (params_->a_transpose_) {
RowMajor2Row6Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col6Major(src, dst, params_->row_, params_->deep_);
}
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
float *dst = dst_ptr + i * params_->deep_ * params_->row_4_;
if (params_->a_transpose_) {
RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
@ -207,11 +227,19 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {
for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->col_;
float *dst = dst_ptr + i * params_->deep_ * params_->col_8_;
#ifdef ENABLE_AVX
if (params_->b_transpose_) {
RowMajor2Col16Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row16Major(src, dst, params_->deep_, params_->col_);
}
#else
if (params_->b_transpose_) {
RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row8Major(src, dst, params_->deep_, params_->col_);
}
#endif
}
return;
}
@ -247,13 +275,18 @@ int MatmulCPUKernel::Init() {
}
int MatmulCPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM);
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
int cur_oc = MSMIN(thread_stride_ * col_tile, params_->col_ - task_id * thread_stride_ * col_tile);
if (cur_oc <= 0) {
return RET_OK;
}
auto b = cur_b_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto c = cur_c_ptr_ + task_id * thread_stride_ * C8NUM;
auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * C8NUM : NULL;
auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile * params_->deep_;
auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile;
auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile : NULL;
MS_ASSERT(cur_a_ptr_);
MS_ASSERT(b);
MS_ASSERT(c);
@ -323,7 +356,9 @@ int MatmulCPUKernel::Run() {
cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_;
cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
} else {
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_;
#else
cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_;

View File

@ -74,6 +74,16 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
)
endif()
if ("${X86_64_SIMD}" STREQUAL "avx")
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c
${LITE_DIR}/nnacl/assembly/avx/*.S)
set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_OP_SRC
${KERNEL_OP_SRC}
${TEST_ASSEMBLY_SRC}
)
endif()
### gpu kernel
if (SUPPORT_GPU)
file(GLOB GPU_KERNEL_OP_SRC