forked from mindspore-Ecosystem/mindspore

!9673 [MS][LITE][Develop] add avx fp32 matmul kernel

From: @lx0095
Reviewed-by: @zhang_xue_tong, @zhanghaibo5
Signed-off-by: @zhang_xue_tong

commit de2b53d9f0
@@ -21,7 +21,8 @@ option(SUPPORT_NPU "if support npu" off)
 option(OFFLINE_COMPILE "if offline compile OpenCL kernel" off)
 option(BUILD_MINDDATA_EXAMPLE "" on)
 option(ENABLE_VERBOSE "" off)
-option(ENABLE_X86_64_SSE "if x86_64 support SSE instruction set" off)
+option(ENABLE_SSE "if x86_64 support SSE instruction set" off)
+option(ENABLE_AVX "if x86_64 support AVX instruction set" off)

 set(DIR_PREFIX mindspore-lite)
 set(MS_VERSION ${MS_VERSION_MAJOR}.${MS_VERSION_MINOR}.${MS_VERSION_REVISION})
@@ -202,7 +203,13 @@ endif()

 if (NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
     if ("${X86_64_SIMD}" STREQUAL "sse")
-        add_compile_definitions(ENABLE_X86_64_SSE)
+        add_compile_definitions(ENABLE_SSE)
     endif ()
+    if ("${X86_64_SIMD}" STREQUAL "avx")
+        add_compile_definitions(ENABLE_SSE)
+        add_compile_definitions(ENABLE_AVX)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mfma")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx -mfma")
+    endif ()
 endif ()

@@ -37,6 +37,12 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
     set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
 endif()

+if ("${X86_64_SIMD}" STREQUAL "avx")
+    file(GLOB ASSEMBLY_SRC ${NNACL_DIR}/x86_64_sse/*.c
+                           ${NNACL_DIR}/assembly/avx/*.S)
+    set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
+endif()
+
 ########################### build nnacl static library ########################
 string(REPLACE "-fvisibility=hidden" "-fvisibility=default" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
 add_library(nnacl STATIC ${KERNEL_SRC} ${TRAIN_SRC} ${ASSEMBLY_SRC})

@@ -0,0 +1,941 @@
#ifdef ENABLE_AVX
#ifndef WIN32
.text
.align 4
.global MatmulFloatAvxOpt
#ifndef __APPLE__
.type MatmulFloatAvxOpt, %function
#endif

// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth,
//                        int row, int col, size_t stride, size_t writeMode)
// rdi: a
// rsi: b
// rdx: c
// rcx: bias
// r8: act_type
// r9: depth
// rsp + 8: row
// rsp + 16: col
// rsp + 24: stride
// rsp + 32: writeNhwc/writeWino

MatmulFloatAvxOpt:
|
||||
// rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention
|
||||
pushq %r15
|
||||
pushq %r14
|
||||
pushq %r13
|
||||
pushq %r12
|
||||
pushq %rbx
|
||||
pushq %rbp
|
||||
pushq %r9
|
||||
pushq %r8
|
||||
pushq %rcx
|
||||
pushq %rdx
|
||||
pushq %rsi
|
||||
pushq %rdi
|
||||
addq $96, %rsp
|
||||
|
||||
movq 8(%rsp), %rbp
|
||||
movq 16(%rsp), %rbx
|
||||
movq 24(%rsp), %r10
|
||||
movq 32(%rsp), %r14
|
||||
|
||||
movq $24, %r11
|
||||
imul %r9, %r11
|
||||
cmpq $0, %r14
|
||||
jne NoC8Steps
|
||||
movq $48, %r13
|
||||
imul %rbp, %r13
|
||||
NoC8Steps:
|
||||
cmpq $2, %r14
|
||||
jne NoWinoSteps
|
||||
movq $4, %r12
|
||||
imul %r10, %r12
|
||||
imul %rbx, %r12
|
||||
movq $48, %r13
|
||||
imul %r10, %r13
|
||||
NoWinoSteps:
|
||||
movq $4, %rax
|
||||
imul %rax, %r10
|
||||
|
||||
LoopRow:
|
||||
movq -88(%rsp), %rsi
|
||||
movq 16(%rsp), %rbx
|
||||
movq -72(%rsp), %rcx
|
||||
|
||||
LoopCol:
|
||||
cmpq $0, %r14
|
||||
je NoReloadDst
|
||||
movq -80(%rsp), %rdx
|
||||
NoReloadDst:
|
||||
movq -96(%rsp), %rdi
|
||||
movq -56(%rsp), %r9
|
||||
|
||||
vmovups (%rsi), %ymm0
|
||||
vmovups 32(%rsi), %ymm1
|
||||
vbroadcastss (%rdi), %ymm10
|
||||
vbroadcastss 4(%rdi), %ymm11
|
||||
vbroadcastss 8(%rdi), %ymm12
|
||||
vbroadcastss 12(%rdi), %ymm13
|
||||
vbroadcastss 16(%rdi), %ymm2
|
||||
vbroadcastss 20(%rdi), %ymm3
|
||||
addq $64, %rsi
|
||||
vmulps %ymm0, %ymm10, %ymm4
|
||||
vmulps %ymm1, %ymm10, %ymm5
|
||||
vmulps %ymm0, %ymm11, %ymm6
|
||||
vmulps %ymm1, %ymm11, %ymm7
|
||||
vmulps %ymm0, %ymm12, %ymm8
|
||||
vmulps %ymm1, %ymm12, %ymm9
|
||||
vmulps %ymm0, %ymm13, %ymm10
|
||||
vmulps %ymm1, %ymm13, %ymm11
|
||||
add $24, %rdi
|
||||
vmulps %ymm0, %ymm2, %ymm12
|
||||
vmulps %ymm1, %ymm2, %ymm13
|
||||
vmulps %ymm0, %ymm3, %ymm14
|
||||
vmulps %ymm1, %ymm3, %ymm15
|
||||
|
||||
subq $1, %r9
|
||||
cmpq $0, %r9
|
||||
je Bias
|
||||
|
||||
LoopDepth:
|
||||
vmovups (%rsi), %ymm0
|
||||
vmovups 32(%rsi), %ymm1
|
||||
vbroadcastss (%rdi), %ymm2
|
||||
vbroadcastss 4(%rdi), %ymm3
|
||||
vfmadd231ps %ymm0, %ymm2, %ymm4
|
||||
addq $64, %rsi
|
||||
vfmadd231ps %ymm1, %ymm2, %ymm5
|
||||
vbroadcastss 8(%rdi), %ymm2
|
||||
vfmadd231ps %ymm0, %ymm3, %ymm6
|
||||
vfmadd231ps %ymm1, %ymm3, %ymm7
|
||||
vbroadcastss 12(%rdi), %ymm3
|
||||
vfmadd231ps %ymm0, %ymm2, %ymm8
|
||||
prefetcht0 384(%rsi)
|
||||
vfmadd231ps %ymm1, %ymm2, %ymm9
|
||||
vbroadcastss 16(%rdi), %ymm2
|
||||
vfmadd231ps %ymm0, %ymm3, %ymm10
|
||||
vfmadd231ps %ymm1, %ymm3, %ymm11
|
||||
vbroadcastss 20(%rdi), %ymm3
|
||||
vfmadd231ps %ymm0, %ymm2, %ymm12
|
||||
vfmadd231ps %ymm1, %ymm2, %ymm13
|
||||
addq $24, %rdi
|
||||
vfmadd231ps %ymm0, %ymm3, %ymm14
|
||||
vfmadd231ps %ymm1, %ymm3, %ymm15
|
||||
|
||||
subq $1, %r9
|
||||
cmpq $0, %r9
|
||||
ja LoopDepth
|
||||
|
||||
Bias:
|
||||
cmpq $0, %rcx
|
||||
je Activation
|
||||
vmovups (%rcx), %ymm0
|
||||
vmovups 32(%rcx), %ymm1
|
||||
add $64, %rcx
|
||||
vaddps %ymm0, %ymm4, %ymm4
|
||||
vaddps %ymm1, %ymm5, %ymm5
|
||||
vaddps %ymm0, %ymm6, %ymm6
|
||||
vaddps %ymm1, %ymm7, %ymm7
|
||||
vaddps %ymm0, %ymm8, %ymm8
|
||||
vaddps %ymm1, %ymm9, %ymm9
|
||||
vaddps %ymm0, %ymm10, %ymm10
|
||||
vaddps %ymm1, %ymm11, %ymm11
|
||||
vaddps %ymm0, %ymm12, %ymm12
|
||||
vaddps %ymm1, %ymm13, %ymm13
|
||||
vaddps %ymm0, %ymm14, %ymm14
|
||||
vaddps %ymm1, %ymm15, %ymm15
|
||||
|
||||
Activation:
|
||||
cmpq $3, %r8
|
||||
je Relu6
|
||||
cmpq $1, %r8
|
||||
je Relu
|
||||
jmp Write
|
||||
|
||||
Relu6:
|
||||
movq $6, %rax
|
||||
vcvtsi2ss %rax, %xmm0, %xmm0
|
||||
vshufps $0, %xmm0, %xmm0, %xmm0
|
||||
vinsertf128 $1, %xmm0, %ymm0, %ymm0
|
||||
vminps %ymm0, %ymm4, %ymm4
|
||||
vminps %ymm0, %ymm5, %ymm5
|
||||
vminps %ymm0, %ymm6, %ymm6
|
||||
vminps %ymm0, %ymm7, %ymm7
|
||||
vminps %ymm0, %ymm8, %ymm8
|
||||
vminps %ymm0, %ymm9, %ymm9
|
||||
vminps %ymm0, %ymm10, %ymm10
|
||||
vminps %ymm0, %ymm11, %ymm11
|
||||
vminps %ymm0, %ymm12, %ymm12
|
||||
vminps %ymm0, %ymm13, %ymm13
|
||||
vminps %ymm0, %ymm14, %ymm14
|
||||
vminps %ymm0, %ymm15, %ymm15
|
||||
|
||||
Relu:
|
||||
vxorps %ymm1, %ymm1, %ymm1
|
||||
vmaxps %ymm1, %ymm4, %ymm4
|
||||
vmaxps %ymm1, %ymm5, %ymm5
|
||||
vmaxps %ymm1, %ymm6, %ymm6
|
||||
vmaxps %ymm1, %ymm7, %ymm7
|
||||
vmaxps %ymm1, %ymm8, %ymm8
|
||||
vmaxps %ymm1, %ymm9, %ymm9
|
||||
vmaxps %ymm1, %ymm10, %ymm10
|
||||
vmaxps %ymm1, %ymm11, %ymm11
|
||||
vmaxps %ymm1, %ymm12, %ymm12
|
||||
vmaxps %ymm1, %ymm13, %ymm13
|
||||
vmaxps %ymm1, %ymm14, %ymm14
|
||||
vmaxps %ymm1, %ymm15, %ymm15
|
||||
|
||||
Write:
|
||||
cmpq $2, %r14
|
||||
je WriteWino
|
||||
cmpq $0, %r14
|
||||
je WriteC8
|
||||
cmpq $1, %rbx
|
||||
je Write1
|
||||
cmpq $2, %rbx
|
||||
je Write2
|
||||
cmpq $3, %rbx
|
||||
je Write3
|
||||
cmpq $4, %rbx
|
||||
je Write4
|
||||
cmpq $5, %rbx
|
||||
je Write5
|
||||
cmpq $6, %rbx
|
||||
je Write6
|
||||
cmpq $7, %rbx
|
||||
je Write7
|
||||
cmpq $8, %rbx
|
||||
je Write8
|
||||
cmpq $9, %rbx
|
||||
je Write9
|
||||
cmpq $10, %rbx
|
||||
je Write10
|
||||
cmpq $11, %rbx
|
||||
je Write11
|
||||
cmpq $12, %rbx
|
||||
je Write12
|
||||
cmpq $13, %rbx
|
||||
je Write13
|
||||
cmpq $14, %rbx
|
||||
je Write14
|
||||
cmpq $15, %rbx
|
||||
je Write15
|
||||
jmp Write16
|
||||
|
||||
Write1:
|
||||
movq %rdx, %rax
|
||||
addq $4, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovss %xmm4, (%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovss %xmm6, (%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovss %xmm8, (%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovss %xmm10, (%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovss %xmm12, (%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovss %xmm14, (%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $4, %rdx
|
||||
jmp WriteEnd
|
||||
Write2:
|
||||
movq %rdx, %rax
|
||||
addq $8, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovsd %xmm4, (%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm6, (%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm8, (%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm10, (%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm12, (%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm14, (%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $8, %rdx
|
||||
jmp WriteEnd
|
||||
Write3:
|
||||
movq %rdx, %rax
|
||||
addq $12, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovsd %xmm4, (%rdx)
|
||||
movhlps %xmm4, %xmm4
|
||||
vmovss %xmm4, 8(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm6, (%rdx)
|
||||
movhlps %xmm6, %xmm6
|
||||
vmovss %xmm6, 8(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm8, (%rdx)
|
||||
movhlps %xmm8, %xmm8
|
||||
vmovss %xmm8, 8(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm10, (%rdx)
|
||||
movhlps %xmm10, %xmm10
|
||||
vmovss %xmm10, 8(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm12, (%rdx)
|
||||
movhlps %xmm12, %xmm12
|
||||
vmovss %xmm12, 8(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovsd %xmm14, (%rdx)
|
||||
movhlps %xmm14, %xmm14
|
||||
vmovss %xmm14, 8(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $12, %rdx
|
||||
jmp WriteEnd
|
||||
Write4:
|
||||
movq %rdx, %rax
|
||||
addq $16, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %xmm4, (%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm6, (%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm8, (%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm10, (%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm12, (%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm14, (%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $16, %rdx
|
||||
jmp WriteEnd
|
||||
Write5:
|
||||
movq %rdx, %rax
|
||||
addq $20, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %xmm4, (%rdx)
|
||||
vextractf128 $1, %ymm4, %xmm4
|
||||
vmovss %xmm4, 16(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm6, (%rdx)
|
||||
vextractf128 $1, %ymm6, %xmm6
|
||||
vmovss %xmm6, 16(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm8, (%rdx)
|
||||
vextractf128 $1, %ymm8, %xmm8
|
||||
vmovss %xmm8, 16(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm10, (%rdx)
|
||||
vextractf128 $1, %ymm10, %xmm10
|
||||
vmovss %xmm10, 16(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm12, (%rdx)
|
||||
vextractf128 $1, %ymm12, %xmm12
|
||||
vmovss %xmm12, 16(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm14, (%rdx)
|
||||
vextractf128 $1, %ymm14, %xmm14
|
||||
vmovss %xmm14, 16(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $20, %rdx
|
||||
jmp WriteEnd
|
||||
Write6:
|
||||
movq %rdx, %rax
|
||||
addq $24, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %xmm4, (%rdx)
|
||||
vextractf128 $1, %ymm4, %xmm4
|
||||
vmovsd %xmm4, 16(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm6, (%rdx)
|
||||
vextractf128 $1, %ymm6, %xmm6
|
||||
vmovsd %xmm6, 16(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm8, (%rdx)
|
||||
vextractf128 $1, %ymm8, %xmm8
|
||||
vmovsd %xmm8, 16(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm10, (%rdx)
|
||||
vextractf128 $1, %ymm10, %xmm10
|
||||
vmovsd %xmm10, 16(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm12, (%rdx)
|
||||
vextractf128 $1, %ymm12, %xmm12
|
||||
vmovsd %xmm12, 16(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm14, (%rdx)
|
||||
vextractf128 $1, %ymm14, %xmm14
|
||||
vmovsd %xmm14, 16(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $24, %rdx
|
||||
jmp WriteEnd
|
||||
Write7:
|
||||
movq %rdx, %rax
|
||||
addq $28, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %xmm4, (%rdx)
|
||||
vextractf128 $1, %ymm4, %xmm4
|
||||
vmovsd %xmm4, 16(%rdx)
|
||||
movhlps %xmm4, %xmm4
|
||||
vmovss %xmm4, 24(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm6, (%rdx)
|
||||
vextractf128 $1, %ymm6, %xmm6
|
||||
vmovsd %xmm6, 16(%rdx)
|
||||
movhlps %xmm6, %xmm6
|
||||
vmovss %xmm6, 24(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm8, (%rdx)
|
||||
vextractf128 $1, %ymm8, %xmm8
|
||||
vmovsd %xmm8, 16(%rdx)
|
||||
movhlps %xmm8, %xmm8
|
||||
vmovss %xmm8, 24(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm10, (%rdx)
|
||||
vextractf128 $1, %ymm10, %xmm10
|
||||
vmovsd %xmm10, 16(%rdx)
|
||||
movhlps %xmm10, %xmm10
|
||||
vmovss %xmm10, 24(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm12, (%rdx)
|
||||
vextractf128 $1, %ymm12, %xmm12
|
||||
vmovsd %xmm12, 16(%rdx)
|
||||
movhlps %xmm12, %xmm12
|
||||
vmovss %xmm12, 24(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %xmm14, (%rdx)
|
||||
vextractf128 $1, %ymm14, %xmm14
|
||||
vmovsd %xmm14, 16(%rdx)
|
||||
movhlps %xmm14, %xmm14
|
||||
vmovss %xmm14, 24(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $28, %rdx
|
||||
jmp WriteEnd
|
||||
Write8:
|
||||
movq %rdx, %rax
|
||||
addq $32, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $32, %rdx
|
||||
jmp WriteEnd
|
||||
Write9:
|
||||
movq %rdx, %rax
|
||||
addq $36, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovss %xmm5, 32(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovss %xmm7, 32(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovss %xmm9, 32(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovss %xmm11, 32(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovss %xmm13, 32(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovss %xmm15, 32(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $36, %rdx
|
||||
jmp WriteEnd
|
||||
Write10:
|
||||
movq %rdx, %rax
|
||||
addq $40, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovsd %xmm5, 32(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovsd %xmm7, 32(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovsd %xmm9, 32(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovsd %xmm11, 32(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovsd %xmm13, 32(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovsd %xmm15, 32(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $40, %rdx
|
||||
jmp WriteEnd
|
||||
Write11:
|
||||
movq %rdx, %rax
|
||||
addq $44, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovsd %xmm5, 32(%rdx)
|
||||
movhlps %xmm5, %xmm5
|
||||
vmovss %xmm5, 40(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovsd %xmm7, 32(%rdx)
|
||||
movhlps %xmm7, %xmm7
|
||||
vmovss %xmm7, 40(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovsd %xmm9, 32(%rdx)
|
||||
movhlps %xmm9, %xmm9
|
||||
vmovss %xmm9, 40(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovsd %xmm11, 32(%rdx)
|
||||
movhlps %xmm11, %xmm11
|
||||
vmovss %xmm11, 40(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovsd %xmm13, 32(%rdx)
|
||||
movhlps %xmm13, %xmm13
|
||||
vmovss %xmm13, 40(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovsd %xmm15, 32(%rdx)
|
||||
movhlps %xmm15, %xmm15
|
||||
vmovss %xmm15, 40(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $44, %rdx
|
||||
jmp WriteEnd
|
||||
Write12:
|
||||
movq %rdx, %rax
|
||||
addq $48, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovups %xmm5, 32(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovups %xmm7, 32(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovups %xmm9, 32(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovups %xmm11, 32(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovups %xmm13, 32(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovups %xmm15, 32(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $48, %rdx
|
||||
jmp WriteEnd
|
||||
Write13:
|
||||
movq %rdx, %rax
|
||||
addq $52, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovups %xmm5, 32(%rdx)
|
||||
vextractf128 $1, %ymm5, %xmm5
|
||||
vmovss %xmm5, 48(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovups %xmm7, 32(%rdx)
|
||||
vextractf128 $1, %ymm7, %xmm7
|
||||
vmovss %xmm7, 48(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovups %xmm9, 32(%rdx)
|
||||
vextractf128 $1, %ymm9, %xmm9
|
||||
vmovss %xmm9, 48(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovups %xmm11, 32(%rdx)
|
||||
vextractf128 $1, %ymm11, %xmm11
|
||||
vmovss %xmm11, 48(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovups %xmm13, 32(%rdx)
|
||||
vextractf128 $1, %ymm13, %xmm13
|
||||
vmovss %xmm13, 48(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovups %xmm15, 32(%rdx)
|
||||
vextractf128 $1, %ymm15, %xmm15
|
||||
vmovss %xmm15, 48(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $52, %rdx
|
||||
jmp WriteEnd
|
||||
Write14:
|
||||
movq %rdx, %rax
|
||||
addq $56, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovups %xmm5, 32(%rdx)
|
||||
vextractf128 $1, %ymm5, %xmm5
|
||||
vmovsd %xmm5, 48(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovups %xmm7, 32(%rdx)
|
||||
vextractf128 $1, %ymm7, %xmm7
|
||||
vmovsd %xmm7, 48(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovups %xmm9, 32(%rdx)
|
||||
vextractf128 $1, %ymm9, %xmm9
|
||||
vmovsd %xmm9, 48(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovups %xmm11, 32(%rdx)
|
||||
vextractf128 $1, %ymm11, %xmm11
|
||||
vmovsd %xmm11, 48(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovups %xmm13, 32(%rdx)
|
||||
vextractf128 $1, %ymm13, %xmm13
|
||||
vmovsd %xmm13, 48(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovups %xmm15, 32(%rdx)
|
||||
vextractf128 $1, %ymm15, %xmm15
|
||||
vmovsd %xmm15, 48(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $56, %rdx
|
||||
jmp WriteEnd
|
||||
Write15:
|
||||
movq %rdx, %rax
|
||||
addq $60, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovups %xmm5, 32(%rdx)
|
||||
vextractf128 $1, %ymm5, %xmm5
|
||||
vmovsd %xmm5, 48(%rdx)
|
||||
movhlps %xmm5, %xmm5
|
||||
vmovss %xmm5, 56(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovups %xmm7, 32(%rdx)
|
||||
vextractf128 $1, %ymm7, %xmm7
|
||||
vmovsd %xmm7, 48(%rdx)
|
||||
movhlps %xmm7, %xmm7
|
||||
vmovss %xmm7, 56(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovups %xmm9, 32(%rdx)
|
||||
vextractf128 $1, %ymm9, %xmm9
|
||||
vmovsd %xmm9, 48(%rdx)
|
||||
movhlps %xmm9, %xmm9
|
||||
vmovss %xmm9, 56(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovups %xmm11, 32(%rdx)
|
||||
vextractf128 $1, %ymm11, %xmm11
|
||||
vmovsd %xmm11, 48(%rdx)
|
||||
movhlps %xmm11, %xmm11
|
||||
vmovss %xmm11, 56(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovups %xmm13, 32(%rdx)
|
||||
vextractf128 $1, %ymm13, %xmm13
|
||||
vmovsd %xmm13, 48(%rdx)
|
||||
movhlps %xmm13, %xmm13
|
||||
vmovss %xmm13, 56(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovups %xmm15, 32(%rdx)
|
||||
vextractf128 $1, %ymm15, %xmm15
|
||||
vmovsd %xmm15, 48(%rdx)
|
||||
movhlps %xmm15, %xmm15
|
||||
vmovss %xmm15, 56(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $60, %rdx
|
||||
jmp WriteEnd
|
||||
WriteC8:
|
||||
movq %rdx, %rax
|
||||
addq %r11, %rdx
|
||||
movq %rdx, %r15
|
||||
addq %r11, %rdx
|
||||
movq %rdx, -80(%rsp)
|
||||
vmovups %ymm4, (%rax)
|
||||
vmovups %ymm6, 32(%rax)
|
||||
vmovups %ymm8, 64(%rax)
|
||||
vmovups %ymm10, 96(%rax)
|
||||
vmovups %ymm12, 128(%rax)
|
||||
vmovups %ymm14, 160(%rax)
|
||||
vmovups %ymm5, (%r15)
|
||||
vmovups %ymm7, 32(%r15)
|
||||
vmovups %ymm9, 64(%r15)
|
||||
vmovups %ymm11, 96(%r15)
|
||||
vmovups %ymm13, 128(%r15)
|
||||
vmovups %ymm15, 160(%r15)
|
||||
jmp WriteEnd
|
||||
WriteWino:
|
||||
movq %rdx, %rax
|
||||
addq %r13, %rdx
|
||||
movq %rdx, %r15
|
||||
addq %r13, %rdx
|
||||
movq %rdx, -80(%rsp)
|
||||
vmovups %ymm4, (%rax)
|
||||
vmovups %ymm5, (%r15)
|
||||
addq %r12, %rax
|
||||
addq %r12, %r15
|
||||
vmovups %ymm6, (%rax)
|
||||
vmovups %ymm7, (%r15)
|
||||
addq %r12, %rax
|
||||
addq %r12, %r15
|
||||
vmovups %ymm8, (%rax)
|
||||
vmovups %ymm9, (%r15)
|
||||
addq %r12, %rax
|
||||
addq %r12, %r15
|
||||
vmovups %ymm10, (%rax)
|
||||
vmovups %ymm11, (%r15)
|
||||
addq %r12, %rax
|
||||
addq %r12, %r15
|
||||
vmovups %ymm12, (%rax)
|
||||
vmovups %ymm13, (%r15)
|
||||
addq %r12, %rax
|
||||
addq %r12, %r15
|
||||
vmovups %ymm14, (%rax)
|
||||
vmovups %ymm15, (%r15)
|
||||
jmp WriteEnd
|
||||
Write16:
|
||||
movq %rdx, %rax
|
||||
addq $64, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
vmovups %ymm4, (%rdx)
|
||||
vmovups %ymm5, 32(%rdx)
|
||||
cmpq $1, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm6, (%rdx)
|
||||
vmovups %ymm7, 32(%rdx)
|
||||
cmpq $2, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm8, (%rdx)
|
||||
vmovups %ymm9, 32(%rdx)
|
||||
cmpq $3, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm10, (%rdx)
|
||||
vmovups %ymm11, 32(%rdx)
|
||||
cmpq $4, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm12, (%rdx)
|
||||
vmovups %ymm13, 32(%rdx)
|
||||
cmpq $5, %rbp
|
||||
je WriteEnd
|
||||
addq %r10, %rdx
|
||||
vmovups %ymm14, (%rdx)
|
||||
vmovups %ymm15, 32(%rdx)
|
||||
addq %r10, %rdx
|
||||
addq $64, %rdx
|
||||
|
||||
WriteEnd:
|
||||
cmpq $16, %rbx
|
||||
jbe LoopColEnd
|
||||
subq $16, %rbx
|
||||
jmp LoopCol
|
||||
|
||||
LoopColEnd:
|
||||
movq -96(%rsp), %rdi
|
||||
addq %r11, %rdi
|
||||
movq %rdi, -96(%rsp)
|
||||
cmpq $0, %r14
|
||||
je C8DstStep
|
||||
cmpq $2, %r14
|
||||
je WinoDstStep
|
||||
movq $4, %rax
|
||||
movq 16(%rsp), %rbx
|
||||
imul %rbx, %rax
|
||||
subq %rax, %rdx
|
||||
movq %rdx, -80(%rsp)
|
||||
jmp NoDstStep
|
||||
C8DstStep:
|
||||
movq -80(%rsp), %rax
|
||||
addq $384, %rax
|
||||
movq %rax, -80(%rsp)
|
||||
jmp NoDstStep
|
||||
WinoDstStep:
|
||||
addq %r13, %rdx
|
||||
movq %rdx, -80(%rsp)
|
||||
NoDstStep:
|
||||
cmpq $6, %rbp
|
||||
jbe LoopRowEnd
|
||||
subq $6, %rbp
|
||||
jmp LoopRow
|
||||
|
||||
LoopRowEnd:
|
||||
subq $96, %rsp
|
||||
popq %rdi
|
||||
popq %rsi
|
||||
popq %rdx
|
||||
popq %rcx
|
||||
popq %r8
|
||||
popq %r9
|
||||
popq %rbp
|
||||
popq %rbx
|
||||
popq %r12
|
||||
popq %r13
|
||||
popq %r14
|
||||
popq %r15
|
||||
retq
|
||||
#endif
|
||||
#endif
|
|
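For readers who do not want to trace the assembly above: the kernel keeps a 6x16 block of the output in ymm4 through ymm15, broadcasts one element of the packed A panel at a time (vbroadcastss) and multiply-accumulates it against two 8-float vectors of the packed B panel (vfmadd231ps), then applies bias and ReLU/ReLU6 before the Write paths store only the valid rows and columns. A minimal scalar sketch of that inner loop follows; the function name and the comments are illustrative assumptions, not part of the patch.

    /* Scalar model of one 6x16 output tile, for illustration only.
     * a6:  A packed 6 floats per depth step (layout produced by RowMajor2Col6Major)
     * b16: B packed 16 floats per depth step (layout produced by RowMajor2Col16Major)
     * acc: 6x16 accumulator block, row-major
     */
    static void Avx6x16TileSketch(const float *a6, const float *b16, float acc[6][16], int depth) {
      for (int d = 0; d < depth; ++d) {    /* LoopDepth in the assembly */
        for (int r = 0; r < 6; ++r) {      /* vbroadcastss of a6[d * 6 + r] */
          for (int c = 0; c < 16; ++c) {   /* two ymm loads covering b16[d * 16 + 0..15] */
            acc[r][c] += a6[d * 6 + r] * b16[d * 16 + c];  /* vfmadd231ps */
          }
        }
      }
    }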
@ -56,7 +56,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed
|
|||
int out_channel = conv_param->output_channel_;
|
||||
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
|
||||
int output_count = conv_param->output_h_ * conv_param->output_w_;
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
const int cal_num = C4NUM;
|
||||
#else
|
||||
const int cal_num = C12NUM;
|
||||
|
@ -78,7 +78,7 @@ void AdderFp32(const float *input_data, float *packed_input, const float *packed
|
|||
|
||||
int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
|
||||
float *gemm_output = output_data + out_offset;
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
|
||||
#else
|
||||
RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
|
||||
|
|
|
@ -43,7 +43,7 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p
|
|||
|
||||
void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
|
||||
size_t plane_size, size_t stride, size_t relu_type) {
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
|
||||
PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, relu_type, C8NUM);
|
||||
#else
|
||||
size_t oc8mod = output_channel % C8NUM;
|
||||
|
@ -68,7 +68,7 @@ void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bi
|
|||
return;
|
||||
}
|
||||
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
|
||||
void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
|
||||
const int unitStep = 4 * length;
|
||||
for (int y = 0; y < h; ++y) {
|
||||
|
|
|
@ -39,7 +39,7 @@ float ShortToFloat32(uint16_t src_value);
|
|||
|
||||
uint16_t Float32ToShort(float src_value);
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
void ConvDwFp32Center(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
|
||||
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
|
||||
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, size_t relu, size_t relu6);
|
||||
|
|
|
@ -202,7 +202,7 @@ void ConvDwBorder(float *dst, const float *src, const float *weight, const float
|
|||
const float *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_;
|
||||
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM;
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
ConvDwFp32Border(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
||||
sliding->in_kh_step_ * sizeof(float), sliding->in_kw_step_ * sizeof(float),
|
||||
conv_param->kernel_w_ * C4NUM * sizeof(float), relu, relu6);
|
||||
|
@ -285,7 +285,7 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig
|
|||
int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
||||
const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_;
|
||||
float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
|
||||
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
|
||||
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
|
||||
|
@ -839,7 +839,7 @@ void DeconvDwSWFp32(float *output_data, const float *input_data, const float *we
|
|||
float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_;
|
||||
const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_;
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
|
||||
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float),
|
||||
sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float),
|
||||
|
|
|
@@ -26,7 +26,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
   int out_channel = conv_param->output_channel_;
   int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
   int output_count = conv_param->output_h_ * conv_param->output_w_;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  const int cal_num = C6NUM;
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   const int cal_num = C4NUM;
 #else
   const int cal_num = C12NUM;
@@ -48,7 +50,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_

     int out_offset = thread_id * cal_num * out_channel + out_batch_offset;
     float *gemm_output = output_data + out_offset;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+    RowMajor2Col6Major(gemm_input, col_major_gemm_input, cal_num, deep);
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
     RowMajor2Col4Major(gemm_input, col_major_gemm_input, cal_num, deep);
 #else
     RowMajor2Col12Major(gemm_input, col_major_gemm_input, cal_num, deep);
@ -97,7 +101,7 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
|
|||
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
|
||||
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
||||
for (int i = 0; i < input_unit_square; ++i) {
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
RowMajor2Col4Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
|
||||
#else
|
||||
RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel);
|
||||
|
|
|
@ -41,7 +41,7 @@ void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *ds
|
|||
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
|
||||
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
|
||||
int oc8 = UP_ROUND(output_channel, C8NUM);
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
const int tile_num = 4;
|
||||
#else
|
||||
const int tile_num = 12;
|
||||
|
|
|
@ -28,9 +28,21 @@ void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col)
|
|||
for (int r = 0; r < row; r++) {
|
||||
const float *src = src_ptr + r * col;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd8 = c / 4;
|
||||
int cm8 = c % 4;
|
||||
dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c];
|
||||
int cd4 = c / C4NUM;
|
||||
int cm4 = c % C4NUM;
|
||||
dst_ptr[cd4 * C4NUM * row + r * C4NUM + cm4] = src[c];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
|
||||
for (int r = 0; r < row; r++) {
|
||||
const float *src = src_ptr + r * col;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd6 = c / C6NUM;
|
||||
int cm6 = c % C6NUM;
|
||||
dst_ptr[cd6 * C6NUM * row + r * C6NUM + cm6] = src[c];
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
@ -40,9 +52,9 @@ void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col)
|
|||
for (int r = 0; r < row; r++) {
|
||||
const float *src = src_ptr + r * col;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd8 = c / 8;
|
||||
int cm8 = c % 8;
|
||||
dst_ptr[cd8 * 8 * row + r * 8 + cm8] = src[c];
|
||||
int cd8 = c / C8NUM;
|
||||
int cm8 = c % C8NUM;
|
||||
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[c];
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
@ -52,9 +64,21 @@ void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col)
|
|||
for (int r = 0; r < row; r++) {
|
||||
const float *src = src_ptr + r * col;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd8 = c / C12NUM;
|
||||
int cm8 = c % C12NUM;
|
||||
dst_ptr[cd8 * C12NUM * row + r * C12NUM + cm8] = src[c];
|
||||
int cd12 = c / C12NUM;
|
||||
int cm12 = c % C12NUM;
|
||||
dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
|
||||
for (int r = 0; r < row; r++) {
|
||||
const float *src = src_ptr + r * col;
|
||||
for (int c = 0; c < col; c++) {
|
||||
int cd16 = c / C16NUM;
|
||||
int cm16 = c % C16NUM;
|
||||
dst_ptr[cd16 * C16NUM * row + r * C16NUM + cm16] = src[c];
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
@ -190,7 +214,7 @@ void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_
|
|||
:
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
||||
: "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
|
||||
#elif ENABLE_X86_64_SSE
|
||||
#elif ENABLE_SSE
|
||||
__m128 src1 = _mm_loadu_ps(src_c);
|
||||
__m128 src2 = _mm_loadu_ps(src_c + col);
|
||||
__m128 src3 = _mm_loadu_ps(src_c + 2 * col);
|
||||
|
@ -421,7 +445,7 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
|
|||
:
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
||||
: "r10", "r11", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
|
||||
#elif ENABLE_X86_64_SSE
|
||||
#elif ENABLE_SSE
|
||||
/* 8x4 row-major to col-major */
|
||||
__m128 src1 = _mm_loadu_ps(src_c);
|
||||
__m128 src2 = _mm_loadu_ps(src_c + col);
|
||||
|
@ -478,6 +502,145 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
|
|||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
||||
size_t row16 = row / C16NUM * C16NUM;
|
||||
size_t col_skip = col / C4NUM * C4NUM;
|
||||
int skip_size = C4NUM;
|
||||
const float *src_r = src_ptr;
|
||||
float *dst_r = dst_ptr;
|
||||
|
||||
size_t ri = 0;
|
||||
for (; ri < row16; ri += C16NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col_skip; ci += skip_size) {
|
||||
const float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C16NUM;
|
||||
for (int tr = 0; tr < C16NUM; tr++) {
|
||||
for (int tc = 0; tc < C4NUM; tc++) {
|
||||
dst_c[tc * C16NUM + tr] = src_c[tr * col + tc];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (; ci < col; ci++) {
|
||||
const float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C16NUM;
|
||||
for (size_t i = 0; i < C16NUM; i++) {
|
||||
dst_c[i] = src_c[i * col];
|
||||
}
|
||||
}
|
||||
src_r += C16NUM * col;
|
||||
dst_r += C16NUM * col;
|
||||
}
|
||||
for (; ri < row; ri++) {
|
||||
for (size_t i = 0; i < col; i++) {
|
||||
dst_r[i * C16NUM] = src_r[i];
|
||||
}
|
||||
src_r += col;
|
||||
dst_r += 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
||||
size_t totalRow = UP_ROUND(row, C6NUM);
|
||||
size_t row6 = row / C6NUM * C6NUM;
|
||||
size_t col8 = col / C8NUM * C8NUM;
|
||||
const float *src_r = src_ptr;
|
||||
float *dst_r = dst_ptr;
|
||||
|
||||
size_t ri = 0;
|
||||
for (; ri < row6; ri += C6NUM) {
|
||||
size_t ci = 0;
|
||||
for (; ci < col8; ci += C8NUM) {
|
||||
const float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C6NUM;
|
||||
|
||||
/* 6x8 row-major to col-major */
|
||||
#ifdef ENABLE_AVX
|
||||
__m256 src0 = _mm256_loadu_ps(src_c);
|
||||
__m256 src1 = _mm256_loadu_ps(src_c + col);
|
||||
__m256 src2 = _mm256_loadu_ps(src_c + 2 * col);
|
||||
__m256 src3 = _mm256_loadu_ps(src_c + 3 * col);
|
||||
__m256 src4 = _mm256_loadu_ps(src_c + 4 * col);
|
||||
__m256 src5 = _mm256_loadu_ps(src_c + 5 * col);
|
||||
__m256 trans0 = _mm256_unpacklo_ps(src0, src1);
|
||||
__m256 trans1 = _mm256_unpacklo_ps(src2, src3);
|
||||
__m256 trans2 = _mm256_unpacklo_ps(src4, src5);
|
||||
__m256 trans3 = _mm256_unpackhi_ps(src0, src1);
|
||||
__m256 trans4 = _mm256_unpackhi_ps(src2, src3);
|
||||
__m256 trans5 = _mm256_unpackhi_ps(src4, src5);
|
||||
__m128 lo0 = _mm256_castps256_ps128(trans0);
|
||||
__m128 lo1 = _mm256_castps256_ps128(trans1);
|
||||
__m128 lo2 = _mm256_castps256_ps128(trans2);
|
||||
__m128 lo3 = _mm256_castps256_ps128(trans3);
|
||||
__m128 lo4 = _mm256_castps256_ps128(trans4);
|
||||
__m128 lo5 = _mm256_castps256_ps128(trans5);
|
||||
__m128 hi0 = _mm256_extractf128_ps(trans0, 1);
|
||||
__m128 hi1 = _mm256_extractf128_ps(trans1, 1);
|
||||
__m128 hi2 = _mm256_extractf128_ps(trans2, 1);
|
||||
__m128 hi3 = _mm256_extractf128_ps(trans3, 1);
|
||||
__m128 hi4 = _mm256_extractf128_ps(trans4, 1);
|
||||
__m128 hi5 = _mm256_extractf128_ps(trans5, 1);
|
||||
__m128 res0 = _mm_shuffle_ps(lo0, lo1, _MM_SHUFFLE(1, 0, 1, 0));
|
||||
__m128 res1 = _mm_shuffle_ps(lo2, lo0, _MM_SHUFFLE(3, 2, 1, 0));
|
||||
__m128 res2 = _mm_shuffle_ps(lo1, lo2, _MM_SHUFFLE(3, 2, 3, 2));
|
||||
__m128 res3 = _mm_shuffle_ps(lo3, lo4, _MM_SHUFFLE(1, 0, 1, 0));
|
||||
__m128 res4 = _mm_shuffle_ps(lo5, lo3, _MM_SHUFFLE(3, 2, 1, 0));
|
||||
__m128 res5 = _mm_shuffle_ps(lo4, lo5, _MM_SHUFFLE(3, 2, 3, 2));
|
||||
__m128 res6 = _mm_shuffle_ps(hi0, hi1, _MM_SHUFFLE(1, 0, 1, 0));
|
||||
__m128 res7 = _mm_shuffle_ps(hi2, hi0, _MM_SHUFFLE(3, 2, 1, 0));
|
||||
__m128 res8 = _mm_shuffle_ps(hi1, hi2, _MM_SHUFFLE(3, 2, 3, 2));
|
||||
__m128 res9 = _mm_shuffle_ps(hi3, hi4, _MM_SHUFFLE(1, 0, 1, 0));
|
||||
__m128 res10 = _mm_shuffle_ps(hi5, hi3, _MM_SHUFFLE(3, 2, 1, 0));
|
||||
__m128 res11 = _mm_shuffle_ps(hi4, hi5, _MM_SHUFFLE(3, 2, 3, 2));
|
||||
_mm_storeu_ps(dst_c, res0);
|
||||
_mm_storeu_ps(dst_c + 4, res1);
|
||||
_mm_storeu_ps(dst_c + 8, res2);
|
||||
_mm_storeu_ps(dst_c + 12, res3);
|
||||
_mm_storeu_ps(dst_c + 16, res4);
|
||||
_mm_storeu_ps(dst_c + 20, res5);
|
||||
_mm_storeu_ps(dst_c + 24, res6);
|
||||
_mm_storeu_ps(dst_c + 28, res7);
|
||||
_mm_storeu_ps(dst_c + 32, res8);
|
||||
_mm_storeu_ps(dst_c + 36, res9);
|
||||
_mm_storeu_ps(dst_c + 40, res10);
|
||||
_mm_storeu_ps(dst_c + 44, res11);
|
||||
#else
|
||||
for (int tr = 0; tr < C6NUM; tr++) {
|
||||
for (int tc = 0; tc < C8NUM; tc++) {
|
||||
dst_c[tc * C6NUM + tr] = src_c[tr * col + tc];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
for (; ci < col; ci++) {
|
||||
const float *src_c = src_r + ci;
|
||||
float *dst_c = dst_r + ci * C6NUM;
|
||||
for (size_t i = 0; i < C6NUM; i++) {
|
||||
dst_c[i] = src_c[i * col];
|
||||
}
|
||||
}
|
||||
src_r += C6NUM * col;
|
||||
dst_r += C6NUM * col;
|
||||
}
|
||||
|
||||
for (; ri < row; ri++) {
|
||||
for (size_t i = 0; i < col; i++) {
|
||||
dst_r[i * C6NUM] = src_r[i];
|
||||
}
|
||||
src_r += col;
|
||||
dst_r += 1;
|
||||
}
|
||||
|
||||
for (; ri < totalRow; ri++) {
|
||||
for (size_t i = 0; i < col; i++) {
|
||||
dst_r[i * C6NUM] = 0;
|
||||
}
|
||||
dst_r += 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
||||
size_t row8 = row / C4NUM * C4NUM;
|
||||
size_t col4 = col / C4NUM * C4NUM;
|
||||
|
@ -519,7 +682,7 @@ void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
|
|||
:
|
||||
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
||||
: "r10", "r12", "q0", "q1", "q2", "q3");
|
||||
#elif ENABLE_X86_64_SSE
|
||||
#elif ENABLE_SSE
|
||||
__m128 src1 = _mm_loadu_ps(src_c);
|
||||
__m128 src2 = _mm_loadu_ps(src_c + col);
|
||||
__m128 src3 = _mm_loadu_ps(src_c + 2 * col);
|
||||
|
@@ -630,6 +793,34 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
   return;
 }

+#ifdef ENABLE_AVX
+#ifdef WIN32
+void MatMul6x16(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+                int col, int stride, int out_type) {
+  if (out_type == OutType_Nhwc) {
+    for (int r = 0; r < row; r++) {
+      for (int c = 0; c < col; c++) {
+        int r6div = r / C6NUM, r6mod = r % C6NUM;
+        int c16div = c / C16NUM, c16mod = c % C16NUM;
+        size_t ci = r * stride + c;
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r6div * deep * C6NUM + d * C6NUM + r6mod;
+          size_t bi = c16div * deep * C16NUM + d * C16NUM + c16mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  }
+  return;
+}
+#endif
+#endif
+
 void MatMul4x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
                int col, int stride, int out_type) {
   if (out_type == OutType_C8) {
@@ -670,7 +861,19 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
   } else {
     MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
   }
-#elif ENABLE_X86_64_SSE
+#elif ENABLE_AVX
+  if (out_type == OutType_Nhwc) {
+#ifdef WIN32
+    MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type);
+#else
+    MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
+#endif
+  } else if (out_type == OutType_C8) {
+    MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
+  } else {
+    MatmulFloatSse64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
+  }
+#elif ENABLE_SSE
   if (out_type == OutType_C8) {
     MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
   } else {

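Taken together with the packing routines added in this commit, the AVX path can be exercised roughly as sketched below. The helper name, the include path, the use of calloc for zero padding, and the choice of stride = col are illustrative assumptions; the MatMulOpt argument order follows the dispatch shown above.

    #include <stdlib.h>
    #include "nnacl/fp32/matmul_fp32.h"  /* assumed header for MatMulOpt and the RowMajor2Col* packers */

    /* Illustrative sketch only: C = act(A * W^T + bias), with A of shape row x deep and W of shape col x deep. */
    void GemmAvxSketch(const float *A, const float *W, const float *bias, float *C,
                       int row, int deep, int col) {
      float *a_pack = (float *)calloc(UP_ROUND(row, C6NUM) * deep, sizeof(float));
      float *b_pack = (float *)calloc(UP_ROUND(col, C16NUM) * deep, sizeof(float));
      RowMajor2Col6Major(A, a_pack, row, deep);   /* activations into 6-row panels */
      RowMajor2Col16Major(W, b_pack, col, deep);  /* weights into 16-column panels */
      MatMulOpt(a_pack, b_pack, C, bias, ActType_Relu, deep, row, col, col, OutType_Nhwc);
      free(a_pack);
      free(b_pack);
    }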
@@ -31,11 +31,15 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
 void MatVecMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int col);
 void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row4Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
+void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
+void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
 #ifdef ENABLE_ARM
 void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int col);
 #endif
@@ -49,11 +53,16 @@ void MatmulFloatNeon32(const float *a, const float *b, float *c, const float *bi
                        int col, int stride, size_t writeNhwc, size_t WriteWino);
 void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                           int col, int stride, int write_mode);
-#elif ENABLE_X86_64_SSE
+#elif ENABLE_SSE
+#include <x86intrin.h>
 void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                       int col, int stride, size_t writeNhwc, size_t WriteWino);
 void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                          int col, int stride, int write_mode);
+#ifdef ENABLE_AVX
+void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col, int stride, int write_mode);
+#endif
 #endif

 #ifdef ENABLE_NNACL_INFER_SHAPE

@@ -42,6 +42,7 @@ typedef struct MatMulParameter {
   int row_;
   int col_;
   int row_4_;
+  int row_6_;
   int row_8_;
   int row_12_;
   int row_16_;

@ -125,7 +125,7 @@ int B(const float *poly_array, float *matrix_b, int in_unit) {
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_X86_64_SSE)
|
||||
#if !defined(ENABLE_ARM) && !defined(ENABLE_SSE)
|
||||
void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
|
||||
int in_channel, int c4_channel) {
|
||||
int cnt = 0;
|
||||
|
@ -228,7 +228,7 @@ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *ma
|
|||
return NNACL_OK;
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c,
|
||||
const float *bias, int m, int k, int n) {
|
||||
int count = 0;
|
||||
|
|
|
@ -52,7 +52,7 @@ void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *
|
|||
int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt,
|
||||
int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack);
|
||||
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
void MatrixMultiplyVec(const MS_FLOAT32X4 *matrix_a, const MS_FLOAT32X4 *matrix_b, MS_FLOAT32X4 *matrix_c,
|
||||
const float *bias, int m, int k, int n);
|
||||
#endif
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
#include <arm_neon.h>
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_X86_64_SSE
|
||||
#include <nmmintrin.h>
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
@@ -31,6 +31,7 @@

 #define C2NUM 2
 #define C4NUM 4
+#define C6NUM 6
 #define C8NUM 8
 #define C12NUM 12
 #define C16NUM 16
@ -91,7 +92,7 @@ typedef enum ActType { ActType_No, ActType_Relu, ActType_Sigmod, ActType_Relu6,
|
|||
#define MS_MAXQ_F32 vmaxq_f32
|
||||
#define MS_MINQ_F32 vminq_f32
|
||||
#define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2)
|
||||
#elif defined(ENABLE_X86_64_SSE)
|
||||
#elif defined(ENABLE_SSE)
|
||||
#define MS_FLOAT32X4 __m128
|
||||
#define MS_LDQ_F32 _mm_loadu_ps
|
||||
#define MS_ADDQ_F32 _mm_add_ps
|
||||
|
|
|
@ -756,7 +756,7 @@ void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int ch
|
|||
return;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_X86_64_SSE
|
||||
#ifndef ENABLE_SSE
|
||||
void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int channel) {
|
||||
int hw8 = plane / C8NUM * C8NUM;
|
||||
int c8 = channel / C8NUM * C8NUM;
|
||||
|
|
|
@ -79,7 +79,7 @@ void GeneralInputTransformUnit(const float *src_data, float *dst_data, const flo
|
|||
int src_step, int dst_step, int in_unit) {
|
||||
int len = in_unit * in_unit;
|
||||
if (len > MAX_LEN) return;
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_X86_64_SSE)
|
||||
#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
|
||||
MS_FLOAT32X4 src[MAX_LEN];
|
||||
MS_FLOAT32X4 t[MAX_LEN];
|
||||
MS_FLOAT32X4 m[MAX_LEN];
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef ENABLE_X86_64_SSE
|
||||
#include <nmmintrin.h>
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/fp32/conv_depthwise_fp32.h"
|
||||
|
||||
void ConvDwFp32Border(float *dst, const float *src, const float *weight, const float *bias, size_t height, size_t width,
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef ENABLE_X86_64_SSE
|
||||
#include <nmmintrin.h>
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/minimal_filtering_generator.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef ENABLE_X86_64_SSE
|
||||
#include <nmmintrin.h>
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/pack.h"
|
||||
#include "nnacl/int8/conv_int8.h"
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifdef ENABLE_X86_64_SSE
|
||||
#include <nmmintrin.h>
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
|
||||
void PostFuncBiasReluC8(float *dst, const float *src, const float *bias, size_t oc8div, size_t oc8mod,
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifdef ENABLE_X86_64_SSE
|
||||
#include <nmmintrin.h>
|
||||
#ifdef ENABLE_SSE
|
||||
#include <x86intrin.h>
|
||||
#include "nnacl/fp32/common_func_fp32.h"
|
||||
|
||||
void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
|
||||
|
|
|
@@ -60,6 +60,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
   matmul_param_->col_ = conv_param_->output_channel_;
   matmul_param_->deep_ = conv_param_->input_channel_;
   matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
+  matmul_param_->row_6_ = UP_ROUND(matmul_param_->row_, C6NUM);
   matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
   matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
   matmul_param_->act_type_ = conv_param_->act_type_;
@@ -71,8 +72,13 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
   auto input_channel = filter_tensor->Channel();
   auto output_channel = filter_tensor->Batch();

+#ifdef ENABLE_AVX
+  int col_tile = C16NUM;
+#else
+  int col_tile = C8NUM;
+#endif
   if (in_tensors_.size() == 3) {
-    int size = UP_ROUND(output_channel, C8NUM) * sizeof(float);
+    int size = UP_ROUND(output_channel, col_tile) * sizeof(float);
     int weight_size = output_channel * sizeof(float);
     bias_data_ = malloc(size);
     if (bias_data_ == nullptr) {
@@ -83,22 +89,29 @@ int Convolution1x1CPUKernel::InitConv1x1BiasWeight() {
     memset(reinterpret_cast<char *>(bias_data_) + weight_size, 0, size - weight_size);
   }

-  int size = input_channel * UP_ROUND(output_channel, C8NUM) * sizeof(float);
-  int down_size = input_channel * DOWN_DIV(output_channel, C8NUM) * C8NUM * sizeof(float);
+  int size = input_channel * UP_ROUND(output_channel, col_tile) * sizeof(float);
+  int down_size = input_channel * DOWN_DIV(output_channel, col_tile) * col_tile * sizeof(float);
   weight_ptr_ = reinterpret_cast<float *>(malloc(size));
   if (weight_ptr_ == nullptr) {
     MS_LOG(ERROR) << "Conv1x1 Malloc weight_ptr_ error!";
     return RET_ERROR;
   }
   memset(reinterpret_cast<char *>(weight_ptr_) + down_size, 0, size - down_size);
+#ifdef ENABLE_AVX
+  RowMajor2Col16Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
+                      input_channel);
+#else
   RowMajor2Col8Major(reinterpret_cast<float *>(filter_tensor->MutableData()), weight_ptr_, output_channel,
                      input_channel);
+#endif
   return RET_OK;
 }

 int Convolution1x1CPUKernel::InitConv1x1Param() {
   int hw_tile = C12NUM;
-#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
+#ifdef ENABLE_AVX
+  hw_tile = C6NUM;
+#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
   hw_tile = C4NUM;
 #endif
   if ((matmul_param_->row_ > (hw_tile * op_parameter_->thread_num_)) && (matmul_param_->row_ > matmul_param_->col_)) {
@@ -106,9 +119,14 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
     thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, hw_tile));
     thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, hw_tile), thread_count_) * hw_tile;
   } else {
+#ifdef ENABLE_AVX
+    int col_tile = C16NUM;
+#else
+    int col_tile = C8NUM;
+#endif
     multi_thread_by_hw_ = false;
-    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM));
-    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_) * C8NUM;
+    thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_tile));
+    thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_tile), thread_count_) * col_tile;
   }

   pre_trans_input_ = (conv_param_->pad_u_ != 0 || conv_param_->pad_l_ != 0 || conv_param_->stride_h_ != 1 ||
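A quick worked example of the row partition above when threading over rows (assuming UP_DIV is the usual ceiling division used throughout nnacl): on the AVX path hw_tile is 6, so with row_ = 100 and thread_num_ = 4, UP_DIV(100, 6) = 17 row tiles, UP_DIV(17, 4) = 5 tiles per thread, and thread_stride_ = 5 * 6 = 30 rows; the first three threads each handle 30 rows and the last one handles the remaining 10.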
@ -175,7 +193,9 @@ int Convolution1x1CPUKernel::DoConv1x1Hw(int task_id) {
|
|||
float *thread_input_ptr = input_ptr_ + task_id * thread_stride_ * matmul_param_->deep_;
|
||||
float *thread_pack_input = pack_input_ + task_id * thread_stride_ * matmul_param_->deep_;
|
||||
|
||||
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
|
||||
#if ENABLE_AVX
|
||||
RowMajor2Col6Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
|
||||
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
|
||||
RowMajor2Col4Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
|
||||
#else
|
||||
RowMajor2Col12Major(thread_input_ptr, thread_pack_input, cur_hw_, matmul_param_->deep_);
|
||||
|
@ -202,7 +222,10 @@ int Convolution1x1CPUKernel::Run() {
auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData());

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_6_ * matmul_param_->deep_ * sizeof(float)));
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else

@ -226,7 +249,9 @@ int Convolution1x1CPUKernel::Run() {
if (multi_thread_by_hw_) {
ParallelLaunch(this->context_->thread_pool_, Convolution1x1RunHw, this, thread_count_);
} else {
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
RowMajor2Col6Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);

@ -42,7 +42,12 @@ int ConvolutionCPUKernel::InitWeightBias() {
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int kernel_plane = filter_tensor->Height() * filter_tensor->Width();
int oc_block_num = UP_ROUND(out_channel, C8NUM);
#ifdef ENABLE_AVX
const int oc_block = C16NUM;
#else
const int oc_block = C8NUM;
#endif
int oc_block_num = UP_ROUND(out_channel, oc_block);
int pack_weight_size = oc_block_num * in_channel * kernel_plane;

auto origin_weight = reinterpret_cast<float *>(filter_tensor->data_c());

@ -52,7 +57,11 @@ int ConvolutionCPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
#ifdef ENABLE_AVX
RowMajor2Col16Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#else
RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
#endif

bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * sizeof(float)));
if (bias_data_ == nullptr) {

@ -72,7 +81,10 @@ int ConvolutionCPUKernel::InitWeightBias() {

int ConvolutionCPUKernel::InitTmpBuffer() {
MS_ASSERT(ctx_->allocator != nullptr);
#ifdef ENABLE_ARM32

#ifdef ENABLE_AVX
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C6NUM * thread_count_;
#elif ENABLE_ARM32 || ENABLE_SSE
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * conv_param_->input_channel_ * C4NUM * thread_count_;
#else
int unit_size =
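
The unit_size above sizes the per-thread im2col buffer to one AVX row tile of unfolded input. An illustrative computation (values are made up, only the formula comes from the patch):

    // Hedged example of the AVX unit_size formula above.
    constexpr int kKernelH = 3, kKernelW = 3, kInChannel = 16;  // illustrative conv shape
    constexpr int kRowTile = 6;                                 // C6NUM, the AVX matmul row tile
    constexpr int kThreads = 4;                                 // thread_count_
    constexpr int kUnitSize = kKernelH * kKernelW * kInChannel * kRowTile * kThreads;  // 3456 floats
    // Each worker packs up to 6 output pixels worth of 3x3x16 patches at a time.
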
@ -115,7 +115,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
return RET_OK;
}

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_,

@ -174,7 +174,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
return RET_NULL_PTR;
}

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float)));
#else

@ -186,7 +186,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
return RET_NULL_PTR;
}

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
#else

@ -215,7 +215,7 @@ int DeConvolutionCPUKernel::Run() {
input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#if defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
#else
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);

@ -51,12 +51,18 @@ int FullconnectionCPUKernel::ReSize() {
fc_param_->col_ = out_tensors_.at(0)->shape().back();
fc_param_->deep_ = (in_tensors_.at(1)->shape()).at(1);

#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, col_tile);
fc_param_->row_6_ = UP_ROUND(fc_param_->row_, C6NUM);
fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);

thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, col_tile), thread_count_);

#ifdef ENABLE_ARM
if (fc_param_->row_ == 1) {

@ -75,7 +81,9 @@ int FullconnectionCPUKernel::ReSize() {
memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float));
}

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_6_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_4_;
#else
int row_tmp = is_vector_input_ ? 1 : fc_param_->row_12_;

@ -120,7 +128,9 @@ void FullconnectionCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr)
return;
}

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
RowMajor2Col6Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
RowMajor2Col4Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);
#else
RowMajor2Col12Major(src_ptr, a_pack_ptr_, fc_param_->row_, fc_param_->deep_);

@ -132,8 +142,11 @@ void FullconnectionCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr)
memcpy(dst_ptr, src_ptr, fc_param_->col_ * fc_param_->deep_ * sizeof(float));
return;
}

#ifdef ENABLE_AVX
RowMajor2Col16Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
#else
RowMajor2Col8Major(src_ptr, dst_ptr, fc_param_->col_, fc_param_->deep_);
#endif
}

int FcFp32MatmulRun(void *cdata, int task_id) {
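
Under AVX the weight matrix B is repacked into 16-wide output-channel blocks (RowMajor2Col16Major), so the padded column count is now rounded to C16NUM rather than 8. A small illustration of the padded buffer this implies, with made-up layer sizes:

    // Hedged example of the padded B size implied by 16-wide packing.
    constexpr int kCol = 1000, kDeep = 512, kColTile = 16;                     // C16NUM under AVX
    constexpr int kColPadded = ((kCol + kColTile - 1) / kColTile) * kColTile;  // UP_ROUND -> 1008
    constexpr int kPackedBFloats = kColPadded * kDeep;                         // packed B buffer size
    // The 8 extra output lanes stay zero so the kernel always consumes full 16-column blocks.
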
@ -147,14 +160,19 @@ int FcFp32MatmulRun(void *cdata, int task_id) {
}

int FullconnectionCPUKernel::DoMatmul(int task_id) {
int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM);
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
int cur_oc = MSMIN(thread_stride_ * col_tile, fc_param_->col_ - task_id * thread_stride_ * col_tile);
if (cur_oc <= 0) {
return RET_OK;
}

auto b = b_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * C8NUM;
auto c = c_ptr_ + task_id * thread_stride_ * C8NUM;
auto b = b_ptr_ + task_id * thread_stride_ * col_tile * fc_param_->deep_;
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * col_tile;
auto c = c_ptr_ + task_id * thread_stride_ * col_tile;
if (is_vector_input_) {
MatVecMul(a_ptr_, b, c, bias, fc_param_->act_type_, fc_param_->deep_, cur_oc);
} else {
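
With the tile now configurable, every per-task pointer advances by whole col_tile blocks of output columns. A hedged sketch of that slicing (names are illustrative, not from the patch):

    // Each task owns thread_stride blocks of col_tile output columns; B, bias and C
    // are advanced by the same column offset (rows are stepped inside the kernel).
    struct TaskSlice {
      const float *b;     // packed weights for this column band
      const float *bias;  // bias for this column band (may be null)
      float *c;           // first row of this column band in the output
    };

    TaskSlice SliceForTask(const float *b_packed, const float *bias, float *c,
                           int task_id, int thread_stride, int col_tile, int deep) {
      int col_offset = task_id * thread_stride * col_tile;
      return {b_packed + col_offset * deep, bias ? bias + col_offset : nullptr, c + col_offset};
    }
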
@ -75,9 +75,12 @@ int MatmulCPUKernel::MallocMatrixABuffer() {
#endif
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
params_->row_6_ = UP_ROUND(params_->row_, C6NUM);
params_->row_12_ = UP_ROUND(params_->row_, C12NUM);

#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
int row_tmp = is_vector_a_ ? 1 : params_->row_6_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
int row_tmp = is_vector_a_ ? 1 : params_->row_4_;
#else
int row_tmp = is_vector_a_ ? 1 : params_->row_12_;

@ -106,9 +109,14 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
for (size_t i = 0; i < b_shape.size() - 2; ++i) {
batch *= b_shape[i];
}
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
params_->batch = batch;
params_->col_ = params_->b_transpose_ ? b_shape[b_shape.size() - 2] : b_shape[b_shape.size() - 1];
params_->col_8_ = UP_ROUND(params_->col_, 8);
params_->col_8_ = UP_ROUND(params_->col_, col_tile);
params_->deep_ = params_->b_transpose_ ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];

int col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;

@ -123,8 +131,8 @@ int MatmulCPUKernel::MallocMatrixBBuffer() {
return RET_MEMORY_FAILED;
}

thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, col_tile));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, col_tile), thread_count_);
return RET_OK;
}

@ -134,7 +142,12 @@ int MatmulCPUKernel::InitBias() {
params_->col_ = params_->b_const_
? (params_->b_transpose_ ? b_shape.at(b_shape.size() - 2) : b_shape.at(b_shape.size() - 1))
: (c_shape.at(c_shape.size() - 1));
params_->col_8_ = UP_ROUND(params_->col_, 8);
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
params_->col_8_ = UP_ROUND(params_->col_, col_tile);
auto col_tmp = is_vector_a_ ? params_->col_ : params_->col_8_;
if (bias_ptr_ == nullptr) {
bias_ptr_ = reinterpret_cast<float *>(malloc(col_tmp * sizeof(float)));

@ -171,7 +184,14 @@ void MatmulCPUKernel::InitMatrixA(const float *src_ptr, float *dst_ptr) {

for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->row_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
float *dst = dst_ptr + i * params_->deep_ * params_->row_6_;
if (params_->a_transpose_) {
RowMajor2Row6Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col6Major(src, dst, params_->row_, params_->deep_);
}
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
float *dst = dst_ptr + i * params_->deep_ * params_->row_4_;
if (params_->a_transpose_) {
RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);

@ -207,11 +227,19 @@ void MatmulCPUKernel::InitMatrixB(const float *src_ptr, float *dst_ptr) {
for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->col_;
float *dst = dst_ptr + i * params_->deep_ * params_->col_8_;
#ifdef ENABLE_AVX
if (params_->b_transpose_) {
RowMajor2Col16Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row16Major(src, dst, params_->deep_, params_->col_);
}
#else
if (params_->b_transpose_) {
RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row8Major(src, dst, params_->deep_, params_->col_);
}
#endif
}
return;
}

@ -247,13 +275,18 @@ int MatmulCPUKernel::Init() {
}

int MatmulCPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM);
#ifdef ENABLE_AVX
int col_tile = C16NUM;
#else
int col_tile = C8NUM;
#endif
int cur_oc = MSMIN(thread_stride_ * col_tile, params_->col_ - task_id * thread_stride_ * col_tile);
if (cur_oc <= 0) {
return RET_OK;
}
auto b = cur_b_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto c = cur_c_ptr_ + task_id * thread_stride_ * C8NUM;
auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * C8NUM : NULL;
auto b = cur_b_ptr_ + task_id * thread_stride_ * col_tile * params_->deep_;
auto c = cur_c_ptr_ + task_id * thread_stride_ * col_tile;
auto bias = bias_ptr_ ? bias_ptr_ + task_id * thread_stride_ * col_tile : NULL;
MS_ASSERT(cur_a_ptr_);
MS_ASSERT(b);
MS_ASSERT(c);
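
For orientation, here is a plain scalar model of the product each task computes before it is handed to the optimized kernel: C = A * B + bias with an optional ReLU, which the AVX path evaluates in 6-row by 16-column blocks (hence row_6_ and C16NUM). This is only a correctness reference under that assumption; it does not reproduce the packed layouts the assembly kernel consumes.

    // Hedged scalar reference for the matmul performed per task.
    void MatMulRef(const float *a, const float *b, float *c, const float *bias,
                   int row, int deep, int col, bool relu) {
      for (int r = 0; r < row; ++r) {
        for (int j = 0; j < col; ++j) {
          float acc = bias ? bias[j] : 0.0f;
          for (int d = 0; d < deep; ++d) {
            acc += a[r * deep + d] * b[d * col + j];  // A is row x deep, B is deep x col
          }
          c[r * col + j] = (relu && acc < 0.0f) ? 0.0f : acc;
        }
      }
    }
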
@ -323,7 +356,9 @@ int MatmulCPUKernel::Run() {
cur_b_ptr_ = b_ptr_ + i * params_->deep_ * params_->col_;
cur_c_ptr_ = c_src + i * params_->row_ * params_->col_;
} else {
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
#ifdef ENABLE_AVX
cur_a_ptr_ = a_ptr_ + i * params_->row_6_ * params_->deep_;
#elif defined(ENABLE_ARM32) || defined(ENABLE_SSE)
cur_a_ptr_ = a_ptr_ + i * params_->row_4_ * params_->deep_;
#else
cur_a_ptr_ = a_ptr_ + i * params_->row_12_ * params_->deep_;

@ -74,6 +74,16 @@ if ("${X86_64_SIMD}" STREQUAL "sse")
|
|||
)
|
||||
endif()
|
||||
|
||||
if ("${X86_64_SIMD}" STREQUAL "avx")
|
||||
file(GLOB TEST_ASSEMBLY_SRC ${LITE_DIR}/nnacl/x86_64_sse/*.c
|
||||
${LITE_DIR}/nnacl/assembly/avx/*.S)
|
||||
set_property(SOURCE ${TEST_ASSEMBLY_SRC} PROPERTY LANGUAGE C)
|
||||
set(KERNEL_OP_SRC
|
||||
${KERNEL_OP_SRC}
|
||||
${TEST_ASSEMBLY_SRC}
|
||||
)
|
||||
endif()
|
||||
|
||||
### gpu kernel
|
||||
if (SUPPORT_GPU)
|
||||
file(GLOB GPU_KERNEL_OP_SRC
|
||||
|
|