code generator

This commit is contained in:
lzk 2021-12-02 18:35:29 -08:00
parent 2496637ead
commit dfbe38aa5b
66 changed files with 14818 additions and 0 deletions

View File

@ -118,3 +118,7 @@
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py" "redefined-outer-name"
"mindspore/tests/st/pynative/parser/test_parser_construct.py" "bad-super-call"
"mindspore/tests/ut/python/optimizer/test_auto_grad.py" "broad-except"
#MindSpore Lite
"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "redefined-builtin"
"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "exec-used"

View File

@ -96,3 +96,54 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/Tiled
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetScalarToProto
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32

View File

@ -0,0 +1,194 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4");
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vbroadcastss 0(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
// block 1
"vmovups 320(%[weight]), %%zmm31\n"
"vmovups 384(%[weight]), %%zmm30\n"
"vmovups 448(%[weight]), %%zmm29\n"
"vmovups 512(%[weight]), %%zmm28\n"
"vmovups 576(%[weight]), %%zmm27\n"
"vbroadcastss 4(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
// block 2
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
"vmovups 768(%[weight]), %%zmm29\n"
"vmovups 832(%[weight]), %%zmm28\n"
"vmovups 896(%[weight]), %%zmm27\n"
"vbroadcastss 8(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
// block 3
"vmovups 960(%[weight]), %%zmm31\n"
"vmovups 1024(%[weight]), %%zmm30\n"
"vmovups 1088(%[weight]), %%zmm29\n"
"vmovups 1152(%[weight]), %%zmm28\n"
"vmovups 1216(%[weight]), %%zmm27\n"
"vbroadcastss 12(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
// block 4
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vmovups 1536(%[weight]), %%zmm27\n"
"vbroadcastss 16(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
// block 5
"vmovups 1600(%[weight]), %%zmm31\n"
"vmovups 1664(%[weight]), %%zmm30\n"
"vmovups 1728(%[weight]), %%zmm29\n"
"vmovups 1792(%[weight]), %%zmm28\n"
"vmovups 1856(%[weight]), %%zmm27\n"
"vbroadcastss 20(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
// block 6
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vbroadcastss 24(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
// block 7
"vmovups 2240(%[weight]), %%zmm31\n"
"vmovups 2304(%[weight]), %%zmm30\n"
"vmovups 2368(%[weight]), %%zmm29\n"
"vmovups 2432(%[weight]), %%zmm28\n"
"vmovups 2496(%[weight]), %%zmm27\n"
"vbroadcastss 28(%[src_0]), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,216 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 320(%[dst_0]), %%zmm5\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 320(%[bias]), %%zmm5\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5");
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vmovups 320(%[weight]), %%zmm26\n"
"vbroadcastss 0(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
// block 1
"vmovups 384(%[weight]), %%zmm31\n"
"vmovups 448(%[weight]), %%zmm30\n"
"vmovups 512(%[weight]), %%zmm29\n"
"vmovups 576(%[weight]), %%zmm28\n"
"vmovups 640(%[weight]), %%zmm27\n"
"vmovups 704(%[weight]), %%zmm26\n"
"vbroadcastss 4(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
// block 2
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vmovups 896(%[weight]), %%zmm29\n"
"vmovups 960(%[weight]), %%zmm28\n"
"vmovups 1024(%[weight]), %%zmm27\n"
"vmovups 1088(%[weight]), %%zmm26\n"
"vbroadcastss 8(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
// block 3
"vmovups 1152(%[weight]), %%zmm31\n"
"vmovups 1216(%[weight]), %%zmm30\n"
"vmovups 1280(%[weight]), %%zmm29\n"
"vmovups 1344(%[weight]), %%zmm28\n"
"vmovups 1408(%[weight]), %%zmm27\n"
"vmovups 1472(%[weight]), %%zmm26\n"
"vbroadcastss 12(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
// block 4
"vmovups 1536(%[weight]), %%zmm31\n"
"vmovups 1600(%[weight]), %%zmm30\n"
"vmovups 1664(%[weight]), %%zmm29\n"
"vmovups 1728(%[weight]), %%zmm28\n"
"vmovups 1792(%[weight]), %%zmm27\n"
"vmovups 1856(%[weight]), %%zmm26\n"
"vbroadcastss 16(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
// block 5
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vmovups 2240(%[weight]), %%zmm26\n"
"vbroadcastss 20(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
// block 6
"vmovups 2304(%[weight]), %%zmm31\n"
"vmovups 2368(%[weight]), %%zmm30\n"
"vmovups 2432(%[weight]), %%zmm29\n"
"vmovups 2496(%[weight]), %%zmm28\n"
"vmovups 2560(%[weight]), %%zmm27\n"
"vmovups 2624(%[weight]), %%zmm26\n"
"vbroadcastss 24(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
// block 7
"vmovups 2688(%[weight]), %%zmm31\n"
"vmovups 2752(%[weight]), %%zmm30\n"
"vmovups 2816(%[weight]), %%zmm29\n"
"vmovups 2880(%[weight]), %%zmm28\n"
"vmovups 2944(%[weight]), %%zmm27\n"
"vmovups 3008(%[weight]), %%zmm26\n"
"vbroadcastss 28(%[src_0]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 320(%[dst_0])\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,272 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 0(%[bias]), %%zmm5\n"
"vmovaps 64(%[bias]), %%zmm6\n"
"vmovaps 128(%[bias]), %%zmm7\n"
"vmovaps 192(%[bias]), %%zmm8\n"
"vmovaps 256(%[bias]), %%zmm9\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9");
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vbroadcastss 0(%[src_0]), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
// block 1
"vmovups 320(%[weight]), %%zmm31\n"
"vmovups 384(%[weight]), %%zmm30\n"
"vmovups 448(%[weight]), %%zmm29\n"
"vmovups 512(%[weight]), %%zmm28\n"
"vmovups 576(%[weight]), %%zmm27\n"
"vbroadcastss 4(%[src_0]), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
// block 2
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
"vmovups 768(%[weight]), %%zmm29\n"
"vmovups 832(%[weight]), %%zmm28\n"
"vmovups 896(%[weight]), %%zmm27\n"
"vbroadcastss 8(%[src_0]), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
// block 3
"vmovups 960(%[weight]), %%zmm31\n"
"vmovups 1024(%[weight]), %%zmm30\n"
"vmovups 1088(%[weight]), %%zmm29\n"
"vmovups 1152(%[weight]), %%zmm28\n"
"vmovups 1216(%[weight]), %%zmm27\n"
"vbroadcastss 12(%[src_0]), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
// block 4
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vmovups 1536(%[weight]), %%zmm27\n"
"vbroadcastss 16(%[src_0]), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
// block 5
"vmovups 1600(%[weight]), %%zmm31\n"
"vmovups 1664(%[weight]), %%zmm30\n"
"vmovups 1728(%[weight]), %%zmm29\n"
"vmovups 1792(%[weight]), %%zmm28\n"
"vmovups 1856(%[weight]), %%zmm27\n"
"vbroadcastss 20(%[src_0]), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
// block 6
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vbroadcastss 24(%[src_0]), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
// block 7
"vmovups 2240(%[weight]), %%zmm31\n"
"vmovups 2304(%[weight]), %%zmm30\n"
"vmovups 2368(%[weight]), %%zmm29\n"
"vmovups 2432(%[weight]), %%zmm28\n"
"vmovups 2496(%[weight]), %%zmm27\n"
"vbroadcastss 28(%[src_0]), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,308 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 320(%[dst_0]), %%zmm5\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
"vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 320(%[bias]), %%zmm5\n"
"vmovaps 0(%[bias]), %%zmm6\n"
"vmovaps 64(%[bias]), %%zmm7\n"
"vmovaps 128(%[bias]), %%zmm8\n"
"vmovaps 192(%[bias]), %%zmm9\n"
"vmovaps 256(%[bias]), %%zmm10\n"
"vmovaps 320(%[bias]), %%zmm11\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11");
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vmovups 320(%[weight]), %%zmm26\n"
"vbroadcastss 0(%[src_0]), %%zmm25\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
// block 1
"vmovups 384(%[weight]), %%zmm31\n"
"vmovups 448(%[weight]), %%zmm30\n"
"vmovups 512(%[weight]), %%zmm29\n"
"vmovups 576(%[weight]), %%zmm28\n"
"vmovups 640(%[weight]), %%zmm27\n"
"vmovups 704(%[weight]), %%zmm26\n"
"vbroadcastss 4(%[src_0]), %%zmm25\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
// block 2
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vmovups 896(%[weight]), %%zmm29\n"
"vmovups 960(%[weight]), %%zmm28\n"
"vmovups 1024(%[weight]), %%zmm27\n"
"vmovups 1088(%[weight]), %%zmm26\n"
"vbroadcastss 8(%[src_0]), %%zmm25\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
// block 3
"vmovups 1152(%[weight]), %%zmm31\n"
"vmovups 1216(%[weight]), %%zmm30\n"
"vmovups 1280(%[weight]), %%zmm29\n"
"vmovups 1344(%[weight]), %%zmm28\n"
"vmovups 1408(%[weight]), %%zmm27\n"
"vmovups 1472(%[weight]), %%zmm26\n"
"vbroadcastss 12(%[src_0]), %%zmm25\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
// block 4
"vmovups 1536(%[weight]), %%zmm31\n"
"vmovups 1600(%[weight]), %%zmm30\n"
"vmovups 1664(%[weight]), %%zmm29\n"
"vmovups 1728(%[weight]), %%zmm28\n"
"vmovups 1792(%[weight]), %%zmm27\n"
"vmovups 1856(%[weight]), %%zmm26\n"
"vbroadcastss 16(%[src_0]), %%zmm25\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
// block 5
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vmovups 2240(%[weight]), %%zmm26\n"
"vbroadcastss 20(%[src_0]), %%zmm25\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
// block 6
"vmovups 2304(%[weight]), %%zmm31\n"
"vmovups 2368(%[weight]), %%zmm30\n"
"vmovups 2432(%[weight]), %%zmm29\n"
"vmovups 2496(%[weight]), %%zmm28\n"
"vmovups 2560(%[weight]), %%zmm27\n"
"vmovups 2624(%[weight]), %%zmm26\n"
"vbroadcastss 24(%[src_0]), %%zmm25\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
// block 7
"vmovups 2688(%[weight]), %%zmm31\n"
"vmovups 2752(%[weight]), %%zmm30\n"
"vmovups 2816(%[weight]), %%zmm29\n"
"vmovups 2880(%[weight]), %%zmm28\n"
"vmovups 2944(%[weight]), %%zmm27\n"
"vmovups 3008(%[weight]), %%zmm26\n"
"vbroadcastss 28(%[src_0]), %%zmm25\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 320(%[dst_0])\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,351 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 0(%[bias]), %%zmm5\n"
"vmovaps 64(%[bias]), %%zmm6\n"
"vmovaps 128(%[bias]), %%zmm7\n"
"vmovaps 192(%[bias]), %%zmm8\n"
"vmovaps 256(%[bias]), %%zmm9\n"
"vmovaps 0(%[bias]), %%zmm10\n"
"vmovaps 64(%[bias]), %%zmm11\n"
"vmovaps 128(%[bias]), %%zmm12\n"
"vmovaps 192(%[bias]), %%zmm13\n"
"vmovaps 256(%[bias]), %%zmm14\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14");
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vbroadcastss 0(%[src_0]), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
// block 1
"vmovups 320(%[weight]), %%zmm31\n"
"vmovups 384(%[weight]), %%zmm30\n"
"vmovups 448(%[weight]), %%zmm29\n"
"vmovups 512(%[weight]), %%zmm28\n"
"vmovups 576(%[weight]), %%zmm27\n"
"vbroadcastss 4(%[src_0]), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
// block 2
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
"vmovups 768(%[weight]), %%zmm29\n"
"vmovups 832(%[weight]), %%zmm28\n"
"vmovups 896(%[weight]), %%zmm27\n"
"vbroadcastss 8(%[src_0]), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
// block 3
"vmovups 960(%[weight]), %%zmm31\n"
"vmovups 1024(%[weight]), %%zmm30\n"
"vmovups 1088(%[weight]), %%zmm29\n"
"vmovups 1152(%[weight]), %%zmm28\n"
"vmovups 1216(%[weight]), %%zmm27\n"
"vbroadcastss 12(%[src_0]), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
// block 4
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vmovups 1536(%[weight]), %%zmm27\n"
"vbroadcastss 16(%[src_0]), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
// block 5
"vmovups 1600(%[weight]), %%zmm31\n"
"vmovups 1664(%[weight]), %%zmm30\n"
"vmovups 1728(%[weight]), %%zmm29\n"
"vmovups 1792(%[weight]), %%zmm28\n"
"vmovups 1856(%[weight]), %%zmm27\n"
"vbroadcastss 20(%[src_0]), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
// block 6
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vbroadcastss 24(%[src_0]), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
// block 7
"vmovups 2240(%[weight]), %%zmm31\n"
"vmovups 2304(%[weight]), %%zmm30\n"
"vmovups 2368(%[weight]), %%zmm29\n"
"vmovups 2432(%[weight]), %%zmm28\n"
"vmovups 2496(%[weight]), %%zmm27\n"
"vbroadcastss 28(%[src_0]), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,401 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 320(%[dst_0]), %%zmm5\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
"vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n"
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n"
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n"
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n"
"vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 320(%[bias]), %%zmm5\n"
"vmovaps 0(%[bias]), %%zmm6\n"
"vmovaps 64(%[bias]), %%zmm7\n"
"vmovaps 128(%[bias]), %%zmm8\n"
"vmovaps 192(%[bias]), %%zmm9\n"
"vmovaps 256(%[bias]), %%zmm10\n"
"vmovaps 320(%[bias]), %%zmm11\n"
"vmovaps 0(%[bias]), %%zmm12\n"
"vmovaps 64(%[bias]), %%zmm13\n"
"vmovaps 128(%[bias]), %%zmm14\n"
"vmovaps 192(%[bias]), %%zmm15\n"
"vmovaps 256(%[bias]), %%zmm16\n"
"vmovaps 320(%[bias]), %%zmm17\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17");
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vmovups 320(%[weight]), %%zmm26\n"
"vbroadcastss 0(%[src_0]), %%zmm25\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
// block 1
"vmovups 384(%[weight]), %%zmm31\n"
"vmovups 448(%[weight]), %%zmm30\n"
"vmovups 512(%[weight]), %%zmm29\n"
"vmovups 576(%[weight]), %%zmm28\n"
"vmovups 640(%[weight]), %%zmm27\n"
"vmovups 704(%[weight]), %%zmm26\n"
"vbroadcastss 4(%[src_0]), %%zmm25\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
// block 2
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vmovups 896(%[weight]), %%zmm29\n"
"vmovups 960(%[weight]), %%zmm28\n"
"vmovups 1024(%[weight]), %%zmm27\n"
"vmovups 1088(%[weight]), %%zmm26\n"
"vbroadcastss 8(%[src_0]), %%zmm25\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
// block 3
"vmovups 1152(%[weight]), %%zmm31\n"
"vmovups 1216(%[weight]), %%zmm30\n"
"vmovups 1280(%[weight]), %%zmm29\n"
"vmovups 1344(%[weight]), %%zmm28\n"
"vmovups 1408(%[weight]), %%zmm27\n"
"vmovups 1472(%[weight]), %%zmm26\n"
"vbroadcastss 12(%[src_0]), %%zmm25\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
// block 4
"vmovups 1536(%[weight]), %%zmm31\n"
"vmovups 1600(%[weight]), %%zmm30\n"
"vmovups 1664(%[weight]), %%zmm29\n"
"vmovups 1728(%[weight]), %%zmm28\n"
"vmovups 1792(%[weight]), %%zmm27\n"
"vmovups 1856(%[weight]), %%zmm26\n"
"vbroadcastss 16(%[src_0]), %%zmm25\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
// block 5
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vmovups 2240(%[weight]), %%zmm26\n"
"vbroadcastss 20(%[src_0]), %%zmm25\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
// block 6
"vmovups 2304(%[weight]), %%zmm31\n"
"vmovups 2368(%[weight]), %%zmm30\n"
"vmovups 2432(%[weight]), %%zmm29\n"
"vmovups 2496(%[weight]), %%zmm28\n"
"vmovups 2560(%[weight]), %%zmm27\n"
"vmovups 2624(%[weight]), %%zmm26\n"
"vbroadcastss 24(%[src_0]), %%zmm25\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
// block 7
"vmovups 2688(%[weight]), %%zmm31\n"
"vmovups 2752(%[weight]), %%zmm30\n"
"vmovups 2816(%[weight]), %%zmm29\n"
"vmovups 2880(%[weight]), %%zmm28\n"
"vmovups 2944(%[weight]), %%zmm27\n"
"vmovups 3008(%[weight]), %%zmm26\n"
"vbroadcastss 28(%[src_0]), %%zmm25\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"vminps %%zmm15, %%zmm30, %%zmm15\n"
"vminps %%zmm16, %%zmm30, %%zmm16\n"
"vminps %%zmm17, %%zmm30, %%zmm17\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 320(%[dst_0])\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,434 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
"vmovups 0(%[dst_3]), %%zmm15\n"
"vmovups 64(%[dst_3]), %%zmm16\n"
"vmovups 128(%[dst_3]), %%zmm17\n"
"vmovups 192(%[dst_3]), %%zmm18\n"
"vmovups 256(%[dst_3]), %%zmm19\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 0(%[bias]), %%zmm5\n"
"vmovaps 64(%[bias]), %%zmm6\n"
"vmovaps 128(%[bias]), %%zmm7\n"
"vmovaps 192(%[bias]), %%zmm8\n"
"vmovaps 256(%[bias]), %%zmm9\n"
"vmovaps 0(%[bias]), %%zmm10\n"
"vmovaps 64(%[bias]), %%zmm11\n"
"vmovaps 128(%[bias]), %%zmm12\n"
"vmovaps 192(%[bias]), %%zmm13\n"
"vmovaps 256(%[bias]), %%zmm14\n"
"vmovaps 0(%[bias]), %%zmm15\n"
"vmovaps 64(%[bias]), %%zmm16\n"
"vmovaps 128(%[bias]), %%zmm17\n"
"vmovaps 192(%[bias]), %%zmm18\n"
"vmovaps 256(%[bias]), %%zmm19\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
"vxorps %%zmm18, %%zmm18, %%zmm18\n"
"vxorps %%zmm19, %%zmm19, %%zmm19\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_3 ] "r"(dst_3)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19");
const float *src_3 = src + 3 * dst_stride;
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vbroadcastss 0(%[src_0]), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 0(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
// block 1
"vmovups 320(%[weight]), %%zmm31\n"
"vmovups 384(%[weight]), %%zmm30\n"
"vmovups 448(%[weight]), %%zmm29\n"
"vmovups 512(%[weight]), %%zmm28\n"
"vmovups 576(%[weight]), %%zmm27\n"
"vbroadcastss 4(%[src_0]), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 4(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
// block 2
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
"vmovups 768(%[weight]), %%zmm29\n"
"vmovups 832(%[weight]), %%zmm28\n"
"vmovups 896(%[weight]), %%zmm27\n"
"vbroadcastss 8(%[src_0]), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 8(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
// block 3
"vmovups 960(%[weight]), %%zmm31\n"
"vmovups 1024(%[weight]), %%zmm30\n"
"vmovups 1088(%[weight]), %%zmm29\n"
"vmovups 1152(%[weight]), %%zmm28\n"
"vmovups 1216(%[weight]), %%zmm27\n"
"vbroadcastss 12(%[src_0]), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 12(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
// block 4
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vmovups 1536(%[weight]), %%zmm27\n"
"vbroadcastss 16(%[src_0]), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 16(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
// block 5
"vmovups 1600(%[weight]), %%zmm31\n"
"vmovups 1664(%[weight]), %%zmm30\n"
"vmovups 1728(%[weight]), %%zmm29\n"
"vmovups 1792(%[weight]), %%zmm28\n"
"vmovups 1856(%[weight]), %%zmm27\n"
"vbroadcastss 20(%[src_0]), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 20(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
// block 6
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vbroadcastss 24(%[src_0]), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 24(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
// block 7
"vmovups 2240(%[weight]), %%zmm31\n"
"vmovups 2304(%[weight]), %%zmm30\n"
"vmovups 2368(%[weight]), %%zmm29\n"
"vmovups 2432(%[weight]), %%zmm28\n"
"vmovups 2496(%[weight]), %%zmm27\n"
"vbroadcastss 28(%[src_0]), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 28(%[src_0], %[src_stride], 3), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
"add $32, %[src_3]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
"vmaxps %%zmm18, %%zmm31, %%zmm18\n"
"vmaxps %%zmm19, %%zmm31, %%zmm19\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"vminps %%zmm15, %%zmm30, %%zmm15\n"
"vminps %%zmm16, %%zmm30, %%zmm16\n"
"vminps %%zmm17, %%zmm30, %%zmm17\n"
"vminps %%zmm18, %%zmm30, %%zmm18\n"
"vminps %%zmm19, %%zmm30, %%zmm19\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm15, 0(%[dst_3])\n"
"vmovups %%zmm16, 64(%[dst_3])\n"
"vmovups %%zmm17, 128(%[dst_3])\n"
"vmovups %%zmm18, 192(%[dst_3])\n"
"vmovups %%zmm19, 256(%[dst_3])\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,499 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 320(%[dst_0]), %%zmm5\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
"vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n"
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n"
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n"
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n"
"vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n"
"vmovups 0(%[dst_3]), %%zmm18\n"
"vmovups 64(%[dst_3]), %%zmm19\n"
"vmovups 128(%[dst_3]), %%zmm20\n"
"vmovups 192(%[dst_3]), %%zmm21\n"
"vmovups 256(%[dst_3]), %%zmm22\n"
"vmovups 320(%[dst_3]), %%zmm23\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 320(%[bias]), %%zmm5\n"
"vmovaps 0(%[bias]), %%zmm6\n"
"vmovaps 64(%[bias]), %%zmm7\n"
"vmovaps 128(%[bias]), %%zmm8\n"
"vmovaps 192(%[bias]), %%zmm9\n"
"vmovaps 256(%[bias]), %%zmm10\n"
"vmovaps 320(%[bias]), %%zmm11\n"
"vmovaps 0(%[bias]), %%zmm12\n"
"vmovaps 64(%[bias]), %%zmm13\n"
"vmovaps 128(%[bias]), %%zmm14\n"
"vmovaps 192(%[bias]), %%zmm15\n"
"vmovaps 256(%[bias]), %%zmm16\n"
"vmovaps 320(%[bias]), %%zmm17\n"
"vmovaps 0(%[bias]), %%zmm18\n"
"vmovaps 64(%[bias]), %%zmm19\n"
"vmovaps 128(%[bias]), %%zmm20\n"
"vmovaps 192(%[bias]), %%zmm21\n"
"vmovaps 256(%[bias]), %%zmm22\n"
"vmovaps 320(%[bias]), %%zmm23\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
"vxorps %%zmm18, %%zmm18, %%zmm18\n"
"vxorps %%zmm19, %%zmm19, %%zmm19\n"
"vxorps %%zmm20, %%zmm20, %%zmm20\n"
"vxorps %%zmm21, %%zmm21, %%zmm21\n"
"vxorps %%zmm22, %%zmm22, %%zmm22\n"
"vxorps %%zmm23, %%zmm23, %%zmm23\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_3 ] "r"(dst_3)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
"%zmm23");
const float *src_3 = src + 3 * dst_stride;
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vmovups 320(%[weight]), %%zmm26\n"
"vbroadcastss 0(%[src_0]), %%zmm25\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 0(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 1
"vmovups 384(%[weight]), %%zmm31\n"
"vmovups 448(%[weight]), %%zmm30\n"
"vmovups 512(%[weight]), %%zmm29\n"
"vmovups 576(%[weight]), %%zmm28\n"
"vmovups 640(%[weight]), %%zmm27\n"
"vmovups 704(%[weight]), %%zmm26\n"
"vbroadcastss 4(%[src_0]), %%zmm25\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 4(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 2
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vmovups 896(%[weight]), %%zmm29\n"
"vmovups 960(%[weight]), %%zmm28\n"
"vmovups 1024(%[weight]), %%zmm27\n"
"vmovups 1088(%[weight]), %%zmm26\n"
"vbroadcastss 8(%[src_0]), %%zmm25\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 8(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 3
"vmovups 1152(%[weight]), %%zmm31\n"
"vmovups 1216(%[weight]), %%zmm30\n"
"vmovups 1280(%[weight]), %%zmm29\n"
"vmovups 1344(%[weight]), %%zmm28\n"
"vmovups 1408(%[weight]), %%zmm27\n"
"vmovups 1472(%[weight]), %%zmm26\n"
"vbroadcastss 12(%[src_0]), %%zmm25\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 12(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 4
"vmovups 1536(%[weight]), %%zmm31\n"
"vmovups 1600(%[weight]), %%zmm30\n"
"vmovups 1664(%[weight]), %%zmm29\n"
"vmovups 1728(%[weight]), %%zmm28\n"
"vmovups 1792(%[weight]), %%zmm27\n"
"vmovups 1856(%[weight]), %%zmm26\n"
"vbroadcastss 16(%[src_0]), %%zmm25\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 16(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 5
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vmovups 2240(%[weight]), %%zmm26\n"
"vbroadcastss 20(%[src_0]), %%zmm25\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 20(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 6
"vmovups 2304(%[weight]), %%zmm31\n"
"vmovups 2368(%[weight]), %%zmm30\n"
"vmovups 2432(%[weight]), %%zmm29\n"
"vmovups 2496(%[weight]), %%zmm28\n"
"vmovups 2560(%[weight]), %%zmm27\n"
"vmovups 2624(%[weight]), %%zmm26\n"
"vbroadcastss 24(%[src_0]), %%zmm25\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 24(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 7
"vmovups 2688(%[weight]), %%zmm31\n"
"vmovups 2752(%[weight]), %%zmm30\n"
"vmovups 2816(%[weight]), %%zmm29\n"
"vmovups 2880(%[weight]), %%zmm28\n"
"vmovups 2944(%[weight]), %%zmm27\n"
"vmovups 3008(%[weight]), %%zmm26\n"
"vbroadcastss 28(%[src_0]), %%zmm25\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 28(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"
"add $32, %[src_3]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
"vmaxps %%zmm18, %%zmm31, %%zmm18\n"
"vmaxps %%zmm19, %%zmm31, %%zmm19\n"
"vmaxps %%zmm20, %%zmm31, %%zmm20\n"
"vmaxps %%zmm21, %%zmm31, %%zmm21\n"
"vmaxps %%zmm22, %%zmm31, %%zmm22\n"
"vmaxps %%zmm23, %%zmm31, %%zmm23\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"vminps %%zmm15, %%zmm30, %%zmm15\n"
"vminps %%zmm16, %%zmm30, %%zmm16\n"
"vminps %%zmm17, %%zmm30, %%zmm17\n"
"vminps %%zmm18, %%zmm30, %%zmm18\n"
"vminps %%zmm19, %%zmm30, %%zmm19\n"
"vminps %%zmm20, %%zmm30, %%zmm20\n"
"vminps %%zmm21, %%zmm30, %%zmm21\n"
"vminps %%zmm22, %%zmm30, %%zmm22\n"
"vminps %%zmm23, %%zmm30, %%zmm23\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 320(%[dst_0])\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm18, 0(%[dst_3])\n"
"vmovups %%zmm19, 64(%[dst_3])\n"
"vmovups %%zmm20, 128(%[dst_3])\n"
"vmovups %%zmm21, 192(%[dst_3])\n"
"vmovups %%zmm22, 256(%[dst_3])\n"
"vmovups %%zmm23, 320(%[dst_3])\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,513 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 256(%[dst_0]), %%zmm4\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
"vmovups 0(%[dst_3]), %%zmm15\n"
"vmovups 64(%[dst_3]), %%zmm16\n"
"vmovups 128(%[dst_3]), %%zmm17\n"
"vmovups 192(%[dst_3]), %%zmm18\n"
"vmovups 256(%[dst_3]), %%zmm19\n"
"vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n"
"vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n"
"vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n"
"vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n"
"vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 256(%[bias]), %%zmm4\n"
"vmovaps 0(%[bias]), %%zmm5\n"
"vmovaps 64(%[bias]), %%zmm6\n"
"vmovaps 128(%[bias]), %%zmm7\n"
"vmovaps 192(%[bias]), %%zmm8\n"
"vmovaps 256(%[bias]), %%zmm9\n"
"vmovaps 0(%[bias]), %%zmm10\n"
"vmovaps 64(%[bias]), %%zmm11\n"
"vmovaps 128(%[bias]), %%zmm12\n"
"vmovaps 192(%[bias]), %%zmm13\n"
"vmovaps 256(%[bias]), %%zmm14\n"
"vmovaps 0(%[bias]), %%zmm15\n"
"vmovaps 64(%[bias]), %%zmm16\n"
"vmovaps 128(%[bias]), %%zmm17\n"
"vmovaps 192(%[bias]), %%zmm18\n"
"vmovaps 256(%[bias]), %%zmm19\n"
"vmovaps 0(%[bias]), %%zmm20\n"
"vmovaps 64(%[bias]), %%zmm21\n"
"vmovaps 128(%[bias]), %%zmm22\n"
"vmovaps 192(%[bias]), %%zmm23\n"
"vmovaps 256(%[bias]), %%zmm24\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
"vxorps %%zmm18, %%zmm18, %%zmm18\n"
"vxorps %%zmm19, %%zmm19, %%zmm19\n"
"vxorps %%zmm20, %%zmm20, %%zmm20\n"
"vxorps %%zmm21, %%zmm21, %%zmm21\n"
"vxorps %%zmm22, %%zmm22, %%zmm22\n"
"vxorps %%zmm23, %%zmm23, %%zmm23\n"
"vxorps %%zmm24, %%zmm24, %%zmm24\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_3 ] "r"(dst_3)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
"%zmm23", "%zmm24");
const float *src_3 = src + 3 * dst_stride;
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vmovups 256(%[weight]), %%zmm27\n"
"vbroadcastss 0(%[src_0]), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 0(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
// block 1
"vmovups 320(%[weight]), %%zmm31\n"
"vmovups 384(%[weight]), %%zmm30\n"
"vmovups 448(%[weight]), %%zmm29\n"
"vmovups 512(%[weight]), %%zmm28\n"
"vmovups 576(%[weight]), %%zmm27\n"
"vbroadcastss 4(%[src_0]), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 4(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
// block 2
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
"vmovups 768(%[weight]), %%zmm29\n"
"vmovups 832(%[weight]), %%zmm28\n"
"vmovups 896(%[weight]), %%zmm27\n"
"vbroadcastss 8(%[src_0]), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 8(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
// block 3
"vmovups 960(%[weight]), %%zmm31\n"
"vmovups 1024(%[weight]), %%zmm30\n"
"vmovups 1088(%[weight]), %%zmm29\n"
"vmovups 1152(%[weight]), %%zmm28\n"
"vmovups 1216(%[weight]), %%zmm27\n"
"vbroadcastss 12(%[src_0]), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 12(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
// block 4
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vmovups 1536(%[weight]), %%zmm27\n"
"vbroadcastss 16(%[src_0]), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 16(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
// block 5
"vmovups 1600(%[weight]), %%zmm31\n"
"vmovups 1664(%[weight]), %%zmm30\n"
"vmovups 1728(%[weight]), %%zmm29\n"
"vmovups 1792(%[weight]), %%zmm28\n"
"vmovups 1856(%[weight]), %%zmm27\n"
"vbroadcastss 20(%[src_0]), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 20(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
// block 6
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
"vmovups 2048(%[weight]), %%zmm29\n"
"vmovups 2112(%[weight]), %%zmm28\n"
"vmovups 2176(%[weight]), %%zmm27\n"
"vbroadcastss 24(%[src_0]), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 24(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
// block 7
"vmovups 2240(%[weight]), %%zmm31\n"
"vmovups 2304(%[weight]), %%zmm30\n"
"vmovups 2368(%[weight]), %%zmm29\n"
"vmovups 2432(%[weight]), %%zmm28\n"
"vmovups 2496(%[weight]), %%zmm27\n"
"vbroadcastss 28(%[src_0]), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 28(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
"add $32, %[src_3]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
"vmaxps %%zmm18, %%zmm31, %%zmm18\n"
"vmaxps %%zmm19, %%zmm31, %%zmm19\n"
"vmaxps %%zmm20, %%zmm31, %%zmm20\n"
"vmaxps %%zmm21, %%zmm31, %%zmm21\n"
"vmaxps %%zmm22, %%zmm31, %%zmm22\n"
"vmaxps %%zmm23, %%zmm31, %%zmm23\n"
"vmaxps %%zmm24, %%zmm31, %%zmm24\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"vminps %%zmm15, %%zmm30, %%zmm15\n"
"vminps %%zmm16, %%zmm30, %%zmm16\n"
"vminps %%zmm17, %%zmm30, %%zmm17\n"
"vminps %%zmm18, %%zmm30, %%zmm18\n"
"vminps %%zmm19, %%zmm30, %%zmm19\n"
"vminps %%zmm20, %%zmm30, %%zmm20\n"
"vminps %%zmm21, %%zmm30, %%zmm21\n"
"vminps %%zmm22, %%zmm30, %%zmm22\n"
"vminps %%zmm23, %%zmm30, %%zmm23\n"
"vminps %%zmm24, %%zmm30, %%zmm24\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm15, 0(%[dst_3])\n"
"vmovups %%zmm16, 64(%[dst_3])\n"
"vmovups %%zmm17, 128(%[dst_3])\n"
"vmovups %%zmm18, 192(%[dst_3])\n"
"vmovups %%zmm19, 256(%[dst_3])\n"
"vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1),\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,297 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
__m256 dst5;
__m256 dst6;
__m256 dst7;
__m256 dst8;
__m256 dst9;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 0);
dst8 = _mm256_load_ps(bias + 0);
dst9 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
__m256 src60 = _mm256_set1_ps(*(src + 48));
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
__m256 src70 = _mm256_set1_ps(*(src + 56));
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
__m256 src80 = _mm256_set1_ps(*(src + 64));
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
__m256 src90 = _mm256_set1_ps(*(src + 72));
dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
__m256 src61 = _mm256_set1_ps(*(src + 49));
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
__m256 src71 = _mm256_set1_ps(*(src + 57));
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
__m256 src81 = _mm256_set1_ps(*(src + 65));
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
__m256 src91 = _mm256_set1_ps(*(src + 73));
dst9 = _mm256_fmadd_ps(dst9, src91, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
__m256 src62 = _mm256_set1_ps(*(src + 50));
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
__m256 src72 = _mm256_set1_ps(*(src + 58));
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
__m256 src82 = _mm256_set1_ps(*(src + 66));
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
__m256 src92 = _mm256_set1_ps(*(src + 74));
dst9 = _mm256_fmadd_ps(dst9, src92, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
__m256 src63 = _mm256_set1_ps(*(src + 51));
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
__m256 src73 = _mm256_set1_ps(*(src + 59));
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
__m256 src83 = _mm256_set1_ps(*(src + 67));
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
__m256 src93 = _mm256_set1_ps(*(src + 75));
dst9 = _mm256_fmadd_ps(dst9, src93, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
__m256 src64 = _mm256_set1_ps(*(src + 52));
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
__m256 src74 = _mm256_set1_ps(*(src + 60));
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
__m256 src84 = _mm256_set1_ps(*(src + 68));
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
__m256 src94 = _mm256_set1_ps(*(src + 76));
dst9 = _mm256_fmadd_ps(dst9, src94, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
__m256 src65 = _mm256_set1_ps(*(src + 53));
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
__m256 src75 = _mm256_set1_ps(*(src + 61));
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
__m256 src85 = _mm256_set1_ps(*(src + 69));
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
__m256 src95 = _mm256_set1_ps(*(src + 77));
dst9 = _mm256_fmadd_ps(dst9, src95, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
__m256 src66 = _mm256_set1_ps(*(src + 54));
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
__m256 src76 = _mm256_set1_ps(*(src + 62));
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
__m256 src86 = _mm256_set1_ps(*(src + 70));
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
__m256 src96 = _mm256_set1_ps(*(src + 78));
dst9 = _mm256_fmadd_ps(dst9, src96, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
__m256 src67 = _mm256_set1_ps(*(src + 55));
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
__m256 src77 = _mm256_set1_ps(*(src + 63));
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
__m256 src87 = _mm256_set1_ps(*(src + 71));
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
__m256 src97 = _mm256_set1_ps(*(src + 79));
dst9 = _mm256_fmadd_ps(dst9, src97, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst9 = _mm256_max_ps(dst9, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst9 = _mm256_max_ps(dst9, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
_mm256_store_ps(dst + 0 * src_stride + 64, dst8);
_mm256_store_ps(dst + 0 * src_stride + 72, dst9);
}

View File

@ -0,0 +1,303 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"vmovups 192(%[dst]), %%ymm6\n"
"vmovups 224(%[dst]), %%ymm7\n"
"vmovups 256(%[dst]), %%ymm8\n"
"vmovups 288(%[dst]), %%ymm9\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"vmovaps 0(%[bias]), %%ymm6\n"
"vmovaps 0(%[bias]), %%ymm7\n"
"vmovaps 0(%[bias]), %%ymm8\n"
"vmovaps 0(%[bias]), %%ymm9\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 192(%[src]), %%ymm14\n"
"vbroadcastss 224(%[src]), %%ymm13\n"
"vbroadcastss 256(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 288(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 193(%[src]), %%ymm14\n"
"vbroadcastss 225(%[src]), %%ymm13\n"
"vbroadcastss 257(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 289(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 194(%[src]), %%ymm14\n"
"vbroadcastss 226(%[src]), %%ymm13\n"
"vbroadcastss 258(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 290(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 195(%[src]), %%ymm14\n"
"vbroadcastss 227(%[src]), %%ymm13\n"
"vbroadcastss 259(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 291(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 196(%[src]), %%ymm14\n"
"vbroadcastss 228(%[src]), %%ymm13\n"
"vbroadcastss 260(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 292(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 197(%[src]), %%ymm14\n"
"vbroadcastss 229(%[src]), %%ymm13\n"
"vbroadcastss 261(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 293(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 198(%[src]), %%ymm14\n"
"vbroadcastss 230(%[src]), %%ymm13\n"
"vbroadcastss 262(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 294(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 199(%[src]), %%ymm14\n"
"vbroadcastss 231(%[src]), %%ymm13\n"
"vbroadcastss 263(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 295(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"vminps %%ymm9, %%ymm14, %%ymm9\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
"vmovups %%ymm6, 192(%[dst])\n"
"vmovups %%ymm7, 224(%[dst])\n"
"vmovups %%ymm8, 256(%[dst])\n"
"vmovups %%ymm9, 288(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,321 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
__m256 dst5;
__m256 dst6;
__m256 dst7;
__m256 dst8;
__m256 dst9;
__m256 dst10;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
dst10 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 0);
dst8 = _mm256_load_ps(bias + 0);
dst9 = _mm256_load_ps(bias + 0);
dst10 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
__m256 src60 = _mm256_set1_ps(*(src + 48));
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
__m256 src70 = _mm256_set1_ps(*(src + 56));
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
__m256 src80 = _mm256_set1_ps(*(src + 64));
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
__m256 src90 = _mm256_set1_ps(*(src + 72));
dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
__m256 src100 = _mm256_set1_ps(*(src + 80));
dst10 = _mm256_fmadd_ps(dst10, src100, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
__m256 src61 = _mm256_set1_ps(*(src + 49));
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
__m256 src71 = _mm256_set1_ps(*(src + 57));
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
__m256 src81 = _mm256_set1_ps(*(src + 65));
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
__m256 src91 = _mm256_set1_ps(*(src + 73));
dst9 = _mm256_fmadd_ps(dst9, src91, weight01);
__m256 src101 = _mm256_set1_ps(*(src + 81));
dst10 = _mm256_fmadd_ps(dst10, src101, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
__m256 src62 = _mm256_set1_ps(*(src + 50));
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
__m256 src72 = _mm256_set1_ps(*(src + 58));
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
__m256 src82 = _mm256_set1_ps(*(src + 66));
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
__m256 src92 = _mm256_set1_ps(*(src + 74));
dst9 = _mm256_fmadd_ps(dst9, src92, weight02);
__m256 src102 = _mm256_set1_ps(*(src + 82));
dst10 = _mm256_fmadd_ps(dst10, src102, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
__m256 src63 = _mm256_set1_ps(*(src + 51));
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
__m256 src73 = _mm256_set1_ps(*(src + 59));
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
__m256 src83 = _mm256_set1_ps(*(src + 67));
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
__m256 src93 = _mm256_set1_ps(*(src + 75));
dst9 = _mm256_fmadd_ps(dst9, src93, weight03);
__m256 src103 = _mm256_set1_ps(*(src + 83));
dst10 = _mm256_fmadd_ps(dst10, src103, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
__m256 src64 = _mm256_set1_ps(*(src + 52));
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
__m256 src74 = _mm256_set1_ps(*(src + 60));
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
__m256 src84 = _mm256_set1_ps(*(src + 68));
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
__m256 src94 = _mm256_set1_ps(*(src + 76));
dst9 = _mm256_fmadd_ps(dst9, src94, weight04);
__m256 src104 = _mm256_set1_ps(*(src + 84));
dst10 = _mm256_fmadd_ps(dst10, src104, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
__m256 src65 = _mm256_set1_ps(*(src + 53));
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
__m256 src75 = _mm256_set1_ps(*(src + 61));
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
__m256 src85 = _mm256_set1_ps(*(src + 69));
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
__m256 src95 = _mm256_set1_ps(*(src + 77));
dst9 = _mm256_fmadd_ps(dst9, src95, weight05);
__m256 src105 = _mm256_set1_ps(*(src + 85));
dst10 = _mm256_fmadd_ps(dst10, src105, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
__m256 src66 = _mm256_set1_ps(*(src + 54));
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
__m256 src76 = _mm256_set1_ps(*(src + 62));
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
__m256 src86 = _mm256_set1_ps(*(src + 70));
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
__m256 src96 = _mm256_set1_ps(*(src + 78));
dst9 = _mm256_fmadd_ps(dst9, src96, weight06);
__m256 src106 = _mm256_set1_ps(*(src + 86));
dst10 = _mm256_fmadd_ps(dst10, src106, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
__m256 src67 = _mm256_set1_ps(*(src + 55));
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
__m256 src77 = _mm256_set1_ps(*(src + 63));
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
__m256 src87 = _mm256_set1_ps(*(src + 71));
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
__m256 src97 = _mm256_set1_ps(*(src + 79));
dst9 = _mm256_fmadd_ps(dst9, src97, weight07);
__m256 src107 = _mm256_set1_ps(*(src + 87));
dst10 = _mm256_fmadd_ps(dst10, src107, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
dst10 = _mm256_min_ps(dst10, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst10 = _mm256_max_ps(dst10, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst10 = _mm256_max_ps(dst10, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
_mm256_store_ps(dst + 0 * src_stride + 64, dst8);
_mm256_store_ps(dst + 0 * src_stride + 72, dst9);
_mm256_store_ps(dst + 0 * src_stride + 80, dst10);
}

View File

@ -0,0 +1,325 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"vmovups 192(%[dst]), %%ymm6\n"
"vmovups 224(%[dst]), %%ymm7\n"
"vmovups 256(%[dst]), %%ymm8\n"
"vmovups 288(%[dst]), %%ymm9\n"
"vmovups 320(%[dst]), %%ymm10\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"vmovaps 0(%[bias]), %%ymm6\n"
"vmovaps 0(%[bias]), %%ymm7\n"
"vmovaps 0(%[bias]), %%ymm8\n"
"vmovaps 0(%[bias]), %%ymm9\n"
"vmovaps 0(%[bias]), %%ymm10\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 192(%[src]), %%ymm14\n"
"vbroadcastss 224(%[src]), %%ymm13\n"
"vbroadcastss 256(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 288(%[src]), %%ymm14\n"
"vbroadcastss 320(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 193(%[src]), %%ymm14\n"
"vbroadcastss 225(%[src]), %%ymm13\n"
"vbroadcastss 257(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 289(%[src]), %%ymm14\n"
"vbroadcastss 321(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 194(%[src]), %%ymm14\n"
"vbroadcastss 226(%[src]), %%ymm13\n"
"vbroadcastss 258(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 290(%[src]), %%ymm14\n"
"vbroadcastss 322(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 195(%[src]), %%ymm14\n"
"vbroadcastss 227(%[src]), %%ymm13\n"
"vbroadcastss 259(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 291(%[src]), %%ymm14\n"
"vbroadcastss 323(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 196(%[src]), %%ymm14\n"
"vbroadcastss 228(%[src]), %%ymm13\n"
"vbroadcastss 260(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 292(%[src]), %%ymm14\n"
"vbroadcastss 324(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 197(%[src]), %%ymm14\n"
"vbroadcastss 229(%[src]), %%ymm13\n"
"vbroadcastss 261(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 293(%[src]), %%ymm14\n"
"vbroadcastss 325(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 198(%[src]), %%ymm14\n"
"vbroadcastss 230(%[src]), %%ymm13\n"
"vbroadcastss 262(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 294(%[src]), %%ymm14\n"
"vbroadcastss 326(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 199(%[src]), %%ymm14\n"
"vbroadcastss 231(%[src]), %%ymm13\n"
"vbroadcastss 263(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 295(%[src]), %%ymm14\n"
"vbroadcastss 327(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"vminps %%ymm9, %%ymm14, %%ymm9\n"
"vminps %%ymm10, %%ymm14, %%ymm10\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
"vmovups %%ymm6, 192(%[dst])\n"
"vmovups %%ymm7, 224(%[dst])\n"
"vmovups %%ymm8, 256(%[dst])\n"
"vmovups %%ymm9, 288(%[dst])\n"
"vmovups %%ymm10, 320(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,345 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
__m256 dst5;
__m256 dst6;
__m256 dst7;
__m256 dst8;
__m256 dst9;
__m256 dst10;
__m256 dst11;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80);
dst11 = _mm256_load_ps(dst + 0 * dst_stride + 88);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
dst10 = _mm256_setzero_ps();
dst11 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 0);
dst8 = _mm256_load_ps(bias + 0);
dst9 = _mm256_load_ps(bias + 0);
dst10 = _mm256_load_ps(bias + 0);
dst11 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
__m256 src60 = _mm256_set1_ps(*(src + 48));
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
__m256 src70 = _mm256_set1_ps(*(src + 56));
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
__m256 src80 = _mm256_set1_ps(*(src + 64));
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
__m256 src90 = _mm256_set1_ps(*(src + 72));
dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
__m256 src100 = _mm256_set1_ps(*(src + 80));
dst10 = _mm256_fmadd_ps(dst10, src100, weight00);
__m256 src110 = _mm256_set1_ps(*(src + 88));
dst11 = _mm256_fmadd_ps(dst11, src110, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
__m256 src61 = _mm256_set1_ps(*(src + 49));
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
__m256 src71 = _mm256_set1_ps(*(src + 57));
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
__m256 src81 = _mm256_set1_ps(*(src + 65));
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
__m256 src91 = _mm256_set1_ps(*(src + 73));
dst9 = _mm256_fmadd_ps(dst9, src91, weight01);
__m256 src101 = _mm256_set1_ps(*(src + 81));
dst10 = _mm256_fmadd_ps(dst10, src101, weight01);
__m256 src111 = _mm256_set1_ps(*(src + 89));
dst11 = _mm256_fmadd_ps(dst11, src111, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
__m256 src62 = _mm256_set1_ps(*(src + 50));
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
__m256 src72 = _mm256_set1_ps(*(src + 58));
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
__m256 src82 = _mm256_set1_ps(*(src + 66));
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
__m256 src92 = _mm256_set1_ps(*(src + 74));
dst9 = _mm256_fmadd_ps(dst9, src92, weight02);
__m256 src102 = _mm256_set1_ps(*(src + 82));
dst10 = _mm256_fmadd_ps(dst10, src102, weight02);
__m256 src112 = _mm256_set1_ps(*(src + 90));
dst11 = _mm256_fmadd_ps(dst11, src112, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
__m256 src63 = _mm256_set1_ps(*(src + 51));
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
__m256 src73 = _mm256_set1_ps(*(src + 59));
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
__m256 src83 = _mm256_set1_ps(*(src + 67));
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
__m256 src93 = _mm256_set1_ps(*(src + 75));
dst9 = _mm256_fmadd_ps(dst9, src93, weight03);
__m256 src103 = _mm256_set1_ps(*(src + 83));
dst10 = _mm256_fmadd_ps(dst10, src103, weight03);
__m256 src113 = _mm256_set1_ps(*(src + 91));
dst11 = _mm256_fmadd_ps(dst11, src113, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
__m256 src64 = _mm256_set1_ps(*(src + 52));
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
__m256 src74 = _mm256_set1_ps(*(src + 60));
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
__m256 src84 = _mm256_set1_ps(*(src + 68));
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
__m256 src94 = _mm256_set1_ps(*(src + 76));
dst9 = _mm256_fmadd_ps(dst9, src94, weight04);
__m256 src104 = _mm256_set1_ps(*(src + 84));
dst10 = _mm256_fmadd_ps(dst10, src104, weight04);
__m256 src114 = _mm256_set1_ps(*(src + 92));
dst11 = _mm256_fmadd_ps(dst11, src114, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
__m256 src65 = _mm256_set1_ps(*(src + 53));
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
__m256 src75 = _mm256_set1_ps(*(src + 61));
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
__m256 src85 = _mm256_set1_ps(*(src + 69));
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
__m256 src95 = _mm256_set1_ps(*(src + 77));
dst9 = _mm256_fmadd_ps(dst9, src95, weight05);
__m256 src105 = _mm256_set1_ps(*(src + 85));
dst10 = _mm256_fmadd_ps(dst10, src105, weight05);
__m256 src115 = _mm256_set1_ps(*(src + 93));
dst11 = _mm256_fmadd_ps(dst11, src115, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
__m256 src66 = _mm256_set1_ps(*(src + 54));
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
__m256 src76 = _mm256_set1_ps(*(src + 62));
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
__m256 src86 = _mm256_set1_ps(*(src + 70));
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
__m256 src96 = _mm256_set1_ps(*(src + 78));
dst9 = _mm256_fmadd_ps(dst9, src96, weight06);
__m256 src106 = _mm256_set1_ps(*(src + 86));
dst10 = _mm256_fmadd_ps(dst10, src106, weight06);
__m256 src116 = _mm256_set1_ps(*(src + 94));
dst11 = _mm256_fmadd_ps(dst11, src116, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
__m256 src67 = _mm256_set1_ps(*(src + 55));
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
__m256 src77 = _mm256_set1_ps(*(src + 63));
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
__m256 src87 = _mm256_set1_ps(*(src + 71));
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
__m256 src97 = _mm256_set1_ps(*(src + 79));
dst9 = _mm256_fmadd_ps(dst9, src97, weight07);
__m256 src107 = _mm256_set1_ps(*(src + 87));
dst10 = _mm256_fmadd_ps(dst10, src107, weight07);
__m256 src117 = _mm256_set1_ps(*(src + 95));
dst11 = _mm256_fmadd_ps(dst11, src117, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
dst10 = _mm256_min_ps(dst10, relu6);
dst11 = _mm256_min_ps(dst11, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
_mm256_store_ps(dst + 0 * src_stride + 64, dst8);
_mm256_store_ps(dst + 0 * src_stride + 72, dst9);
_mm256_store_ps(dst + 0 * src_stride + 80, dst10);
_mm256_store_ps(dst + 0 * src_stride + 88, dst11);
}

View File

@ -0,0 +1,347 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"vmovups 192(%[dst]), %%ymm6\n"
"vmovups 224(%[dst]), %%ymm7\n"
"vmovups 256(%[dst]), %%ymm8\n"
"vmovups 288(%[dst]), %%ymm9\n"
"vmovups 320(%[dst]), %%ymm10\n"
"vmovups 352(%[dst]), %%ymm11\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"vmovaps 0(%[bias]), %%ymm6\n"
"vmovaps 0(%[bias]), %%ymm7\n"
"vmovaps 0(%[bias]), %%ymm8\n"
"vmovaps 0(%[bias]), %%ymm9\n"
"vmovaps 0(%[bias]), %%ymm10\n"
"vmovaps 0(%[bias]), %%ymm11\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 192(%[src]), %%ymm14\n"
"vbroadcastss 224(%[src]), %%ymm13\n"
"vbroadcastss 256(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 288(%[src]), %%ymm14\n"
"vbroadcastss 320(%[src]), %%ymm13\n"
"vbroadcastss 352(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 193(%[src]), %%ymm14\n"
"vbroadcastss 225(%[src]), %%ymm13\n"
"vbroadcastss 257(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 289(%[src]), %%ymm14\n"
"vbroadcastss 321(%[src]), %%ymm13\n"
"vbroadcastss 353(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 194(%[src]), %%ymm14\n"
"vbroadcastss 226(%[src]), %%ymm13\n"
"vbroadcastss 258(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 290(%[src]), %%ymm14\n"
"vbroadcastss 322(%[src]), %%ymm13\n"
"vbroadcastss 354(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 195(%[src]), %%ymm14\n"
"vbroadcastss 227(%[src]), %%ymm13\n"
"vbroadcastss 259(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 291(%[src]), %%ymm14\n"
"vbroadcastss 323(%[src]), %%ymm13\n"
"vbroadcastss 355(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 196(%[src]), %%ymm14\n"
"vbroadcastss 228(%[src]), %%ymm13\n"
"vbroadcastss 260(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 292(%[src]), %%ymm14\n"
"vbroadcastss 324(%[src]), %%ymm13\n"
"vbroadcastss 356(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 197(%[src]), %%ymm14\n"
"vbroadcastss 229(%[src]), %%ymm13\n"
"vbroadcastss 261(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 293(%[src]), %%ymm14\n"
"vbroadcastss 325(%[src]), %%ymm13\n"
"vbroadcastss 357(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 198(%[src]), %%ymm14\n"
"vbroadcastss 230(%[src]), %%ymm13\n"
"vbroadcastss 262(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 294(%[src]), %%ymm14\n"
"vbroadcastss 326(%[src]), %%ymm13\n"
"vbroadcastss 358(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 199(%[src]), %%ymm14\n"
"vbroadcastss 231(%[src]), %%ymm13\n"
"vbroadcastss 263(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"vbroadcastss 295(%[src]), %%ymm14\n"
"vbroadcastss 327(%[src]), %%ymm13\n"
"vbroadcastss 359(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"vminps %%ymm9, %%ymm14, %%ymm9\n"
"vminps %%ymm10, %%ymm14, %%ymm10\n"
"vminps %%ymm11, %%ymm14, %%ymm11\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
"vmovups %%ymm6, 192(%[dst])\n"
"vmovups %%ymm7, 224(%[dst])\n"
"vmovups %%ymm8, 256(%[dst])\n"
"vmovups %%ymm9, 288(%[dst])\n"
"vmovups %%ymm10, 320(%[dst])\n"
"vmovups %%ymm11, 352(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,105 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 8);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 16);
__m256 weight11 = _mm256_load_ps(weight + 24);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 32);
__m256 weight12 = _mm256_load_ps(weight + 40);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 48);
__m256 weight13 = _mm256_load_ps(weight + 56);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 64);
__m256 weight14 = _mm256_load_ps(weight + 72);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 80);
__m256 weight15 = _mm256_load_ps(weight + 88);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 96);
__m256 weight16 = _mm256_load_ps(weight + 104);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 112);
__m256 weight17 = _mm256_load_ps(weight + 120);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
src = src + src_stride;
weight += 512;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 1 * src_stride + 0, dst1);
}

View File

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 32(%[bias]), %%ymm1\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vbroadcastss 0(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
// block 1
"vmovaps 64(%[weight]), %%ymm15\n"
"vmovaps 96(%[weight]), %%ymm14\n"
"vbroadcastss 1(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
// block 2
"vmovaps 128(%[weight]), %%ymm15\n"
"vmovaps 160(%[weight]), %%ymm14\n"
"vbroadcastss 2(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
// block 3
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vbroadcastss 3(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
// block 4
"vmovaps 256(%[weight]), %%ymm15\n"
"vmovaps 288(%[weight]), %%ymm14\n"
"vbroadcastss 4(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
// block 5
"vmovaps 320(%[weight]), %%ymm15\n"
"vmovaps 352(%[weight]), %%ymm14\n"
"vbroadcastss 5(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
// block 6
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vbroadcastss 6(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
// block 7
"vmovaps 448(%[weight]), %%ymm15\n"
"vmovaps 480(%[weight]), %%ymm14\n"
"vbroadcastss 7(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"dec %[deep]\n"
"add 512, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 8);
dst2 = _mm256_load_ps(bias + 16);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 weight20 = _mm256_load_ps(weight + 16);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
dst2 = _mm256_fmadd_ps(dst2, src00, weight20);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 24);
__m256 weight11 = _mm256_load_ps(weight + 32);
__m256 weight21 = _mm256_load_ps(weight + 40);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
dst2 = _mm256_fmadd_ps(dst2, src01, weight21);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 48);
__m256 weight12 = _mm256_load_ps(weight + 56);
__m256 weight22 = _mm256_load_ps(weight + 64);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
dst2 = _mm256_fmadd_ps(dst2, src02, weight22);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 72);
__m256 weight13 = _mm256_load_ps(weight + 80);
__m256 weight23 = _mm256_load_ps(weight + 88);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
dst2 = _mm256_fmadd_ps(dst2, src03, weight23);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 96);
__m256 weight14 = _mm256_load_ps(weight + 104);
__m256 weight24 = _mm256_load_ps(weight + 112);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
dst2 = _mm256_fmadd_ps(dst2, src04, weight24);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 120);
__m256 weight15 = _mm256_load_ps(weight + 128);
__m256 weight25 = _mm256_load_ps(weight + 136);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
dst2 = _mm256_fmadd_ps(dst2, src05, weight25);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 144);
__m256 weight16 = _mm256_load_ps(weight + 152);
__m256 weight26 = _mm256_load_ps(weight + 160);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
dst2 = _mm256_fmadd_ps(dst2, src06, weight26);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 168);
__m256 weight17 = _mm256_load_ps(weight + 176);
__m256 weight27 = _mm256_load_ps(weight + 184);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
dst2 = _mm256_fmadd_ps(dst2, src07, weight27);
src = src + src_stride;
weight += 768;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 1 * src_stride + 0, dst1);
_mm256_store_ps(dst + 2 * src_stride + 0, dst2);
}

View File

@ -0,0 +1,149 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 32(%[bias]), %%ymm1\n"
"vmovaps 64(%[bias]), %%ymm2\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vmovaps 64(%[weight]), %%ymm13\n"
"vbroadcastss 0(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
// block 1
"vmovaps 96(%[weight]), %%ymm15\n"
"vmovaps 128(%[weight]), %%ymm14\n"
"vmovaps 160(%[weight]), %%ymm13\n"
"vbroadcastss 1(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
// block 2
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vmovaps 256(%[weight]), %%ymm13\n"
"vbroadcastss 2(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
// block 3
"vmovaps 288(%[weight]), %%ymm15\n"
"vmovaps 320(%[weight]), %%ymm14\n"
"vmovaps 352(%[weight]), %%ymm13\n"
"vbroadcastss 3(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
// block 4
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vmovaps 448(%[weight]), %%ymm13\n"
"vbroadcastss 4(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
// block 5
"vmovaps 480(%[weight]), %%ymm15\n"
"vmovaps 512(%[weight]), %%ymm14\n"
"vmovaps 544(%[weight]), %%ymm13\n"
"vbroadcastss 5(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
// block 6
"vmovaps 576(%[weight]), %%ymm15\n"
"vmovaps 608(%[weight]), %%ymm14\n"
"vmovaps 640(%[weight]), %%ymm13\n"
"vbroadcastss 6(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
// block 7
"vmovaps 672(%[weight]), %%ymm15\n"
"vmovaps 704(%[weight]), %%ymm14\n"
"vmovaps 736(%[weight]), %%ymm13\n"
"vbroadcastss 7(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"dec %[deep]\n"
"add 768, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,153 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0);
dst3 = _mm256_load_ps(dst + 3 * dst_stride + 0);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 8);
dst2 = _mm256_load_ps(bias + 16);
dst3 = _mm256_load_ps(bias + 24);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 src00 = _mm256_set1_ps(*(src + 0));
__m256 weight00 = _mm256_load_ps(weight + 0);
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 weight10 = _mm256_load_ps(weight + 8);
dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
__m256 weight20 = _mm256_load_ps(weight + 16);
dst2 = _mm256_fmadd_ps(dst2, src00, weight20);
__m256 weight30 = _mm256_load_ps(weight + 24);
dst3 = _mm256_fmadd_ps(dst3, src00, weight30);
// bock1
__m256 src01 = _mm256_set1_ps(*(src + 1));
__m256 weight01 = _mm256_load_ps(weight + 32);
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 weight11 = _mm256_load_ps(weight + 40);
dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
__m256 weight21 = _mm256_load_ps(weight + 48);
dst2 = _mm256_fmadd_ps(dst2, src01, weight21);
__m256 weight31 = _mm256_load_ps(weight + 56);
dst3 = _mm256_fmadd_ps(dst3, src01, weight31);
// bock2
__m256 src02 = _mm256_set1_ps(*(src + 2));
__m256 weight02 = _mm256_load_ps(weight + 64);
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 weight12 = _mm256_load_ps(weight + 72);
dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
__m256 weight22 = _mm256_load_ps(weight + 80);
dst2 = _mm256_fmadd_ps(dst2, src02, weight22);
__m256 weight32 = _mm256_load_ps(weight + 88);
dst3 = _mm256_fmadd_ps(dst3, src02, weight32);
// bock3
__m256 src03 = _mm256_set1_ps(*(src + 3));
__m256 weight03 = _mm256_load_ps(weight + 96);
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 weight13 = _mm256_load_ps(weight + 104);
dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
__m256 weight23 = _mm256_load_ps(weight + 112);
dst2 = _mm256_fmadd_ps(dst2, src03, weight23);
__m256 weight33 = _mm256_load_ps(weight + 120);
dst3 = _mm256_fmadd_ps(dst3, src03, weight33);
// bock4
__m256 src04 = _mm256_set1_ps(*(src + 4));
__m256 weight04 = _mm256_load_ps(weight + 128);
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 weight14 = _mm256_load_ps(weight + 136);
dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
__m256 weight24 = _mm256_load_ps(weight + 144);
dst2 = _mm256_fmadd_ps(dst2, src04, weight24);
__m256 weight34 = _mm256_load_ps(weight + 152);
dst3 = _mm256_fmadd_ps(dst3, src04, weight34);
// bock5
__m256 src05 = _mm256_set1_ps(*(src + 5));
__m256 weight05 = _mm256_load_ps(weight + 160);
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 weight15 = _mm256_load_ps(weight + 168);
dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
__m256 weight25 = _mm256_load_ps(weight + 176);
dst2 = _mm256_fmadd_ps(dst2, src05, weight25);
__m256 weight35 = _mm256_load_ps(weight + 184);
dst3 = _mm256_fmadd_ps(dst3, src05, weight35);
// bock6
__m256 src06 = _mm256_set1_ps(*(src + 6));
__m256 weight06 = _mm256_load_ps(weight + 192);
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 weight16 = _mm256_load_ps(weight + 200);
dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
__m256 weight26 = _mm256_load_ps(weight + 208);
dst2 = _mm256_fmadd_ps(dst2, src06, weight26);
__m256 weight36 = _mm256_load_ps(weight + 216);
dst3 = _mm256_fmadd_ps(dst3, src06, weight36);
// bock7
__m256 src07 = _mm256_set1_ps(*(src + 7));
__m256 weight07 = _mm256_load_ps(weight + 224);
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 weight17 = _mm256_load_ps(weight + 232);
dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
__m256 weight27 = _mm256_load_ps(weight + 240);
dst2 = _mm256_fmadd_ps(dst2, src07, weight27);
__m256 weight37 = _mm256_load_ps(weight + 248);
dst3 = _mm256_fmadd_ps(dst3, src07, weight37);
src = src + src_stride;
weight += 1024;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 1 * src_stride + 0, dst1);
_mm256_store_ps(dst + 2 * src_stride + 0, dst2);
_mm256_store_ps(dst + 3 * src_stride + 0, dst3);
}

View File

@ -0,0 +1,174 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_4 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n"
"vmovups 0(%[dst_4]), %%ymm3\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 32(%[bias]), %%ymm1\n"
"vmovaps 64(%[bias]), %%ymm2\n"
"vmovaps 96(%[bias]), %%ymm3\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_4 ] "r"(dst_4)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3");
asm volatile(
"0:\n"
// block 0
"vbroadcastss 0(%[src]), %%ymm15\n"
"vmovaps 0(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 64(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 96(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 1
"vbroadcastss 1(%[src]), %%ymm15\n"
"vmovaps 128(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 160(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 192(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 2
"vbroadcastss 2(%[src]), %%ymm15\n"
"vmovaps 256(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 288(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 320(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 352(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 3
"vbroadcastss 3(%[src]), %%ymm15\n"
"vmovaps 384(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 448(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 480(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 4
"vbroadcastss 4(%[src]), %%ymm15\n"
"vmovaps 512(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 544(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 576(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 608(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 5
"vbroadcastss 5(%[src]), %%ymm15\n"
"vmovaps 640(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 672(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 704(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 736(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 6
"vbroadcastss 6(%[src]), %%ymm15\n"
"vmovaps 768(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 800(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 832(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 864(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 7
"vbroadcastss 7(%[src]), %%ymm15\n"
"vmovaps 896(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vmovaps 928(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
"vmovaps 960(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
"vmovaps 992(%[weight]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"dec %[deep]\n"
"add 1024, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm3, 0(%[dst_4])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_4 ] "r"(dst_4)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,81 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
}

View File

@ -0,0 +1,105 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,145 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst2;
__m256 dst1;
__m256 dst3;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 8);
dst1 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 8);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst2 = _mm256_fmadd_ps(dst2, src00, weight10);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst3 = _mm256_fmadd_ps(dst3, src10, weight10);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 16);
__m256 weight11 = _mm256_load_ps(weight + 24);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst2 = _mm256_fmadd_ps(dst2, src01, weight11);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst3 = _mm256_fmadd_ps(dst3, src11, weight11);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 32);
__m256 weight12 = _mm256_load_ps(weight + 40);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst2 = _mm256_fmadd_ps(dst2, src02, weight12);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst3 = _mm256_fmadd_ps(dst3, src12, weight12);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 48);
__m256 weight13 = _mm256_load_ps(weight + 56);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst2 = _mm256_fmadd_ps(dst2, src03, weight13);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst3 = _mm256_fmadd_ps(dst3, src13, weight13);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 64);
__m256 weight14 = _mm256_load_ps(weight + 72);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst2 = _mm256_fmadd_ps(dst2, src04, weight14);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst3 = _mm256_fmadd_ps(dst3, src14, weight14);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 80);
__m256 weight15 = _mm256_load_ps(weight + 88);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst2 = _mm256_fmadd_ps(dst2, src05, weight15);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst3 = _mm256_fmadd_ps(dst3, src15, weight15);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 96);
__m256 weight16 = _mm256_load_ps(weight + 104);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst2 = _mm256_fmadd_ps(dst2, src06, weight16);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst3 = _mm256_fmadd_ps(dst3, src16, weight16);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 112);
__m256 weight17 = _mm256_load_ps(weight + 120);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst2 = _mm256_fmadd_ps(dst2, src07, weight17);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst3 = _mm256_fmadd_ps(dst3, src17, weight17);
src = src + src_stride;
weight += 512;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst3 = _mm256_max_ps(dst3, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst3 = _mm256_max_ps(dst3, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 1 * src_stride + 0, dst2);
_mm256_store_ps(dst + 1 * src_stride + 8, dst3);
}

View File

@ -0,0 +1,163 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 32(%[bias]), %%ymm2\n"
"vmovaps 32(%[bias]), %%ymm3\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vbroadcastss 0(%[src]), %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
// block 1
"vmovaps 64(%[weight]), %%ymm15\n"
"vmovaps 96(%[weight]), %%ymm14\n"
"vbroadcastss 1(%[src]), %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
// block 2
"vmovaps 128(%[weight]), %%ymm15\n"
"vmovaps 160(%[weight]), %%ymm14\n"
"vbroadcastss 2(%[src]), %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
// block 3
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vbroadcastss 3(%[src]), %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
// block 4
"vmovaps 256(%[weight]), %%ymm15\n"
"vmovaps 288(%[weight]), %%ymm14\n"
"vbroadcastss 4(%[src]), %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
// block 5
"vmovaps 320(%[weight]), %%ymm15\n"
"vmovaps 352(%[weight]), %%ymm14\n"
"vbroadcastss 5(%[src]), %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
// block 6
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vbroadcastss 6(%[src]), %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
// block 7
"vmovaps 448(%[weight]), %%ymm15\n"
"vmovaps 480(%[weight]), %%ymm14\n"
"vbroadcastss 7(%[src]), %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"dec %[deep]\n"
"add 512, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,185 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst2;
__m256 dst4;
__m256 dst1;
__m256 dst3;
__m256 dst5;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 8);
dst4 = _mm256_load_ps(bias + 16);
dst1 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 8);
dst5 = _mm256_load_ps(bias + 16);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 weight20 = _mm256_load_ps(weight + 16);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst2 = _mm256_fmadd_ps(dst2, src00, weight10);
dst4 = _mm256_fmadd_ps(dst4, src00, weight20);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst3 = _mm256_fmadd_ps(dst3, src10, weight10);
dst5 = _mm256_fmadd_ps(dst5, src10, weight20);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 24);
__m256 weight11 = _mm256_load_ps(weight + 32);
__m256 weight21 = _mm256_load_ps(weight + 40);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst2 = _mm256_fmadd_ps(dst2, src01, weight11);
dst4 = _mm256_fmadd_ps(dst4, src01, weight21);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst3 = _mm256_fmadd_ps(dst3, src11, weight11);
dst5 = _mm256_fmadd_ps(dst5, src11, weight21);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 48);
__m256 weight12 = _mm256_load_ps(weight + 56);
__m256 weight22 = _mm256_load_ps(weight + 64);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst2 = _mm256_fmadd_ps(dst2, src02, weight12);
dst4 = _mm256_fmadd_ps(dst4, src02, weight22);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst3 = _mm256_fmadd_ps(dst3, src12, weight12);
dst5 = _mm256_fmadd_ps(dst5, src12, weight22);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 72);
__m256 weight13 = _mm256_load_ps(weight + 80);
__m256 weight23 = _mm256_load_ps(weight + 88);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst2 = _mm256_fmadd_ps(dst2, src03, weight13);
dst4 = _mm256_fmadd_ps(dst4, src03, weight23);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst3 = _mm256_fmadd_ps(dst3, src13, weight13);
dst5 = _mm256_fmadd_ps(dst5, src13, weight23);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 96);
__m256 weight14 = _mm256_load_ps(weight + 104);
__m256 weight24 = _mm256_load_ps(weight + 112);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst2 = _mm256_fmadd_ps(dst2, src04, weight14);
dst4 = _mm256_fmadd_ps(dst4, src04, weight24);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst3 = _mm256_fmadd_ps(dst3, src14, weight14);
dst5 = _mm256_fmadd_ps(dst5, src14, weight24);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 120);
__m256 weight15 = _mm256_load_ps(weight + 128);
__m256 weight25 = _mm256_load_ps(weight + 136);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst2 = _mm256_fmadd_ps(dst2, src05, weight15);
dst4 = _mm256_fmadd_ps(dst4, src05, weight25);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst3 = _mm256_fmadd_ps(dst3, src15, weight15);
dst5 = _mm256_fmadd_ps(dst5, src15, weight25);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 144);
__m256 weight16 = _mm256_load_ps(weight + 152);
__m256 weight26 = _mm256_load_ps(weight + 160);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst2 = _mm256_fmadd_ps(dst2, src06, weight16);
dst4 = _mm256_fmadd_ps(dst4, src06, weight26);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst3 = _mm256_fmadd_ps(dst3, src16, weight16);
dst5 = _mm256_fmadd_ps(dst5, src16, weight26);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 168);
__m256 weight17 = _mm256_load_ps(weight + 176);
__m256 weight27 = _mm256_load_ps(weight + 184);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst2 = _mm256_fmadd_ps(dst2, src07, weight17);
dst4 = _mm256_fmadd_ps(dst4, src07, weight27);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst3 = _mm256_fmadd_ps(dst3, src17, weight17);
dst5 = _mm256_fmadd_ps(dst5, src17, weight27);
src = src + src_stride;
weight += 768;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst5 = _mm256_max_ps(dst5, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst5 = _mm256_max_ps(dst5, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 1 * src_stride + 0, dst2);
_mm256_store_ps(dst + 1 * src_stride + 8, dst3);
_mm256_store_ps(dst + 2 * src_stride + 0, dst4);
_mm256_store_ps(dst + 2 * src_stride + 8, dst5);
}

View File

@ -0,0 +1,199 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n"
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n"
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 32(%[bias]), %%ymm2\n"
"vmovaps 32(%[bias]), %%ymm3\n"
"vmovaps 64(%[bias]), %%ymm4\n"
"vmovaps 64(%[bias]), %%ymm5\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vmovaps 64(%[weight]), %%ymm13\n"
"vbroadcastss 0(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
// block 1
"vmovaps 96(%[weight]), %%ymm15\n"
"vmovaps 128(%[weight]), %%ymm14\n"
"vmovaps 160(%[weight]), %%ymm13\n"
"vbroadcastss 1(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
// block 2
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vmovaps 256(%[weight]), %%ymm13\n"
"vbroadcastss 2(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
// block 3
"vmovaps 288(%[weight]), %%ymm15\n"
"vmovaps 320(%[weight]), %%ymm14\n"
"vmovaps 352(%[weight]), %%ymm13\n"
"vbroadcastss 3(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
// block 4
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vmovaps 448(%[weight]), %%ymm13\n"
"vbroadcastss 4(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
// block 5
"vmovaps 480(%[weight]), %%ymm15\n"
"vmovaps 512(%[weight]), %%ymm14\n"
"vmovaps 544(%[weight]), %%ymm13\n"
"vbroadcastss 5(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
// block 6
"vmovaps 576(%[weight]), %%ymm15\n"
"vmovaps 608(%[weight]), %%ymm14\n"
"vmovaps 640(%[weight]), %%ymm13\n"
"vbroadcastss 6(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
// block 7
"vmovaps 672(%[weight]), %%ymm15\n"
"vmovaps 704(%[weight]), %%ymm14\n"
"vmovaps 736(%[weight]), %%ymm13\n"
"vbroadcastss 7(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"dec %[deep]\n"
"add 768, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,225 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst2;
__m256 dst4;
__m256 dst6;
__m256 dst1;
__m256 dst3;
__m256 dst5;
__m256 dst7;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0);
dst6 = _mm256_load_ps(dst + 3 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8);
dst7 = _mm256_load_ps(dst + 3 * dst_stride + 8);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 8);
dst4 = _mm256_load_ps(bias + 16);
dst6 = _mm256_load_ps(bias + 24);
dst1 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 8);
dst5 = _mm256_load_ps(bias + 16);
dst7 = _mm256_load_ps(bias + 24);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 src00 = _mm256_set1_ps(*(src + 0));
__m256 src10 = _mm256_set1_ps(*(src + 8));
__m256 weight00 = _mm256_load_ps(weight + 0);
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 weight10 = _mm256_load_ps(weight + 8);
dst2 = _mm256_fmadd_ps(dst2, src00, weight10);
dst3 = _mm256_fmadd_ps(dst3, src10, weight10);
__m256 weight20 = _mm256_load_ps(weight + 16);
dst4 = _mm256_fmadd_ps(dst4, src00, weight20);
dst5 = _mm256_fmadd_ps(dst5, src10, weight20);
__m256 weight30 = _mm256_load_ps(weight + 24);
dst6 = _mm256_fmadd_ps(dst6, src00, weight30);
dst7 = _mm256_fmadd_ps(dst7, src10, weight30);
// bock1
__m256 src01 = _mm256_set1_ps(*(src + 1));
__m256 src11 = _mm256_set1_ps(*(src + 9));
__m256 weight01 = _mm256_load_ps(weight + 32);
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 weight11 = _mm256_load_ps(weight + 40);
dst2 = _mm256_fmadd_ps(dst2, src01, weight11);
dst3 = _mm256_fmadd_ps(dst3, src11, weight11);
__m256 weight21 = _mm256_load_ps(weight + 48);
dst4 = _mm256_fmadd_ps(dst4, src01, weight21);
dst5 = _mm256_fmadd_ps(dst5, src11, weight21);
__m256 weight31 = _mm256_load_ps(weight + 56);
dst6 = _mm256_fmadd_ps(dst6, src01, weight31);
dst7 = _mm256_fmadd_ps(dst7, src11, weight31);
// bock2
__m256 src02 = _mm256_set1_ps(*(src + 2));
__m256 src12 = _mm256_set1_ps(*(src + 10));
__m256 weight02 = _mm256_load_ps(weight + 64);
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 weight12 = _mm256_load_ps(weight + 72);
dst2 = _mm256_fmadd_ps(dst2, src02, weight12);
dst3 = _mm256_fmadd_ps(dst3, src12, weight12);
__m256 weight22 = _mm256_load_ps(weight + 80);
dst4 = _mm256_fmadd_ps(dst4, src02, weight22);
dst5 = _mm256_fmadd_ps(dst5, src12, weight22);
__m256 weight32 = _mm256_load_ps(weight + 88);
dst6 = _mm256_fmadd_ps(dst6, src02, weight32);
dst7 = _mm256_fmadd_ps(dst7, src12, weight32);
// bock3
__m256 src03 = _mm256_set1_ps(*(src + 3));
__m256 src13 = _mm256_set1_ps(*(src + 11));
__m256 weight03 = _mm256_load_ps(weight + 96);
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 weight13 = _mm256_load_ps(weight + 104);
dst2 = _mm256_fmadd_ps(dst2, src03, weight13);
dst3 = _mm256_fmadd_ps(dst3, src13, weight13);
__m256 weight23 = _mm256_load_ps(weight + 112);
dst4 = _mm256_fmadd_ps(dst4, src03, weight23);
dst5 = _mm256_fmadd_ps(dst5, src13, weight23);
__m256 weight33 = _mm256_load_ps(weight + 120);
dst6 = _mm256_fmadd_ps(dst6, src03, weight33);
dst7 = _mm256_fmadd_ps(dst7, src13, weight33);
// bock4
__m256 src04 = _mm256_set1_ps(*(src + 4));
__m256 src14 = _mm256_set1_ps(*(src + 12));
__m256 weight04 = _mm256_load_ps(weight + 128);
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 weight14 = _mm256_load_ps(weight + 136);
dst2 = _mm256_fmadd_ps(dst2, src04, weight14);
dst3 = _mm256_fmadd_ps(dst3, src14, weight14);
__m256 weight24 = _mm256_load_ps(weight + 144);
dst4 = _mm256_fmadd_ps(dst4, src04, weight24);
dst5 = _mm256_fmadd_ps(dst5, src14, weight24);
__m256 weight34 = _mm256_load_ps(weight + 152);
dst6 = _mm256_fmadd_ps(dst6, src04, weight34);
dst7 = _mm256_fmadd_ps(dst7, src14, weight34);
// bock5
__m256 src05 = _mm256_set1_ps(*(src + 5));
__m256 src15 = _mm256_set1_ps(*(src + 13));
__m256 weight05 = _mm256_load_ps(weight + 160);
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 weight15 = _mm256_load_ps(weight + 168);
dst2 = _mm256_fmadd_ps(dst2, src05, weight15);
dst3 = _mm256_fmadd_ps(dst3, src15, weight15);
__m256 weight25 = _mm256_load_ps(weight + 176);
dst4 = _mm256_fmadd_ps(dst4, src05, weight25);
dst5 = _mm256_fmadd_ps(dst5, src15, weight25);
__m256 weight35 = _mm256_load_ps(weight + 184);
dst6 = _mm256_fmadd_ps(dst6, src05, weight35);
dst7 = _mm256_fmadd_ps(dst7, src15, weight35);
// bock6
__m256 src06 = _mm256_set1_ps(*(src + 6));
__m256 src16 = _mm256_set1_ps(*(src + 14));
__m256 weight06 = _mm256_load_ps(weight + 192);
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 weight16 = _mm256_load_ps(weight + 200);
dst2 = _mm256_fmadd_ps(dst2, src06, weight16);
dst3 = _mm256_fmadd_ps(dst3, src16, weight16);
__m256 weight26 = _mm256_load_ps(weight + 208);
dst4 = _mm256_fmadd_ps(dst4, src06, weight26);
dst5 = _mm256_fmadd_ps(dst5, src16, weight26);
__m256 weight36 = _mm256_load_ps(weight + 216);
dst6 = _mm256_fmadd_ps(dst6, src06, weight36);
dst7 = _mm256_fmadd_ps(dst7, src16, weight36);
// bock7
__m256 src07 = _mm256_set1_ps(*(src + 7));
__m256 src17 = _mm256_set1_ps(*(src + 15));
__m256 weight07 = _mm256_load_ps(weight + 224);
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 weight17 = _mm256_load_ps(weight + 232);
dst2 = _mm256_fmadd_ps(dst2, src07, weight17);
dst3 = _mm256_fmadd_ps(dst3, src17, weight17);
__m256 weight27 = _mm256_load_ps(weight + 240);
dst4 = _mm256_fmadd_ps(dst4, src07, weight27);
dst5 = _mm256_fmadd_ps(dst5, src17, weight27);
__m256 weight37 = _mm256_load_ps(weight + 248);
dst6 = _mm256_fmadd_ps(dst6, src07, weight37);
dst7 = _mm256_fmadd_ps(dst7, src17, weight37);
src = src + src_stride;
weight += 1024;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst7 = _mm256_max_ps(dst7, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst7 = _mm256_max_ps(dst7, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 1 * src_stride + 0, dst2);
_mm256_store_ps(dst + 1 * src_stride + 8, dst3);
_mm256_store_ps(dst + 2 * src_stride + 0, dst4);
_mm256_store_ps(dst + 2 * src_stride + 8, dst5);
_mm256_store_ps(dst + 3 * src_stride + 0, dst6);
_mm256_store_ps(dst + 3 * src_stride + 8, dst7);
}

View File

@ -0,0 +1,238 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_4 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n"
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n"
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n"
"vmovups 0(%[dst_4]), %%ymm6\n"
"vmovups 32(%[dst_4]), %%ymm7\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 32(%[bias]), %%ymm2\n"
"vmovaps 32(%[bias]), %%ymm3\n"
"vmovaps 64(%[bias]), %%ymm4\n"
"vmovaps 64(%[bias]), %%ymm5\n"
"vmovaps 96(%[bias]), %%ymm6\n"
"vmovaps 96(%[bias]), %%ymm7\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_4 ] "r"(dst_4)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
asm volatile(
"0:\n"
// block 0
"vbroadcastss 0(%[src]), %%ymm15\n"
"vbroadcastss 32(%[src]), %%ymm14\n"
"vmovaps 0(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 32(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 64(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 96(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
// block 1
"vbroadcastss 1(%[src]), %%ymm15\n"
"vbroadcastss 33(%[src]), %%ymm14\n"
"vmovaps 128(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 160(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 192(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 224(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
// block 2
"vbroadcastss 2(%[src]), %%ymm15\n"
"vbroadcastss 34(%[src]), %%ymm14\n"
"vmovaps 256(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 288(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 320(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 352(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
// block 3
"vbroadcastss 3(%[src]), %%ymm15\n"
"vbroadcastss 35(%[src]), %%ymm14\n"
"vmovaps 384(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 416(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 448(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 480(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
// block 4
"vbroadcastss 4(%[src]), %%ymm15\n"
"vbroadcastss 36(%[src]), %%ymm14\n"
"vmovaps 512(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 544(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 576(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 608(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
// block 5
"vbroadcastss 5(%[src]), %%ymm15\n"
"vbroadcastss 37(%[src]), %%ymm14\n"
"vmovaps 640(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 672(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 704(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 736(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
// block 6
"vbroadcastss 6(%[src]), %%ymm15\n"
"vbroadcastss 38(%[src]), %%ymm14\n"
"vmovaps 768(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 800(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 832(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 864(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
// block 7
"vbroadcastss 7(%[src]), %%ymm15\n"
"vbroadcastss 39(%[src]), %%ymm14\n"
"vmovaps 896(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
"vmovaps 928(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vmovaps 960(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vmovaps 992(%[weight]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"dec %[deep]\n"
"add 1024, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm6, 0(%[dst_4])\n"
"vmovups %%ymm7, 32(%[dst_4])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_4 ] "r"(dst_4)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,105 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
}

View File

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,185 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst3;
__m256 dst1;
__m256 dst4;
__m256 dst2;
__m256 dst5;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 8);
dst1 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 8);
dst2 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 8);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst3 = _mm256_fmadd_ps(dst3, src00, weight10);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst4 = _mm256_fmadd_ps(dst4, src10, weight10);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
dst5 = _mm256_fmadd_ps(dst5, src20, weight10);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 16);
__m256 weight11 = _mm256_load_ps(weight + 24);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst3 = _mm256_fmadd_ps(dst3, src01, weight11);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst4 = _mm256_fmadd_ps(dst4, src11, weight11);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
dst5 = _mm256_fmadd_ps(dst5, src21, weight11);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 32);
__m256 weight12 = _mm256_load_ps(weight + 40);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst3 = _mm256_fmadd_ps(dst3, src02, weight12);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst4 = _mm256_fmadd_ps(dst4, src12, weight12);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
dst5 = _mm256_fmadd_ps(dst5, src22, weight12);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 48);
__m256 weight13 = _mm256_load_ps(weight + 56);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst3 = _mm256_fmadd_ps(dst3, src03, weight13);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst4 = _mm256_fmadd_ps(dst4, src13, weight13);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
dst5 = _mm256_fmadd_ps(dst5, src23, weight13);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 64);
__m256 weight14 = _mm256_load_ps(weight + 72);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst3 = _mm256_fmadd_ps(dst3, src04, weight14);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst4 = _mm256_fmadd_ps(dst4, src14, weight14);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
dst5 = _mm256_fmadd_ps(dst5, src24, weight14);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 80);
__m256 weight15 = _mm256_load_ps(weight + 88);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst3 = _mm256_fmadd_ps(dst3, src05, weight15);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst4 = _mm256_fmadd_ps(dst4, src15, weight15);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
dst5 = _mm256_fmadd_ps(dst5, src25, weight15);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 96);
__m256 weight16 = _mm256_load_ps(weight + 104);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst3 = _mm256_fmadd_ps(dst3, src06, weight16);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst4 = _mm256_fmadd_ps(dst4, src16, weight16);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
dst5 = _mm256_fmadd_ps(dst5, src26, weight16);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 112);
__m256 weight17 = _mm256_load_ps(weight + 120);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst3 = _mm256_fmadd_ps(dst3, src07, weight17);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst4 = _mm256_fmadd_ps(dst4, src17, weight17);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
dst5 = _mm256_fmadd_ps(dst5, src27, weight17);
src = src + src_stride;
weight += 512;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 1 * src_stride + 0, dst3);
_mm256_store_ps(dst + 1 * src_stride + 8, dst4);
_mm256_store_ps(dst + 1 * src_stride + 16, dst5);
}

View File

@ -0,0 +1,199 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n"
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 32(%[bias]), %%ymm3\n"
"vmovaps 32(%[bias]), %%ymm4\n"
"vmovaps 32(%[bias]), %%ymm5\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vbroadcastss 0(%[src]), %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 64(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
// block 1
"vmovaps 64(%[weight]), %%ymm15\n"
"vmovaps 96(%[weight]), %%ymm14\n"
"vbroadcastss 1(%[src]), %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 65(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
// block 2
"vmovaps 128(%[weight]), %%ymm15\n"
"vmovaps 160(%[weight]), %%ymm14\n"
"vbroadcastss 2(%[src]), %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 66(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
// block 3
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vbroadcastss 3(%[src]), %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 67(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
// block 4
"vmovaps 256(%[weight]), %%ymm15\n"
"vmovaps 288(%[weight]), %%ymm14\n"
"vbroadcastss 4(%[src]), %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 68(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
// block 5
"vmovaps 320(%[weight]), %%ymm15\n"
"vmovaps 352(%[weight]), %%ymm14\n"
"vbroadcastss 5(%[src]), %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 69(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
// block 6
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vbroadcastss 6(%[src]), %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 70(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
// block 7
"vmovaps 448(%[weight]), %%ymm15\n"
"vmovaps 480(%[weight]), %%ymm14\n"
"vbroadcastss 7(%[src]), %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vbroadcastss 71(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"dec %[deep]\n"
"add 512, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,241 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst3;
__m256 dst6;
__m256 dst1;
__m256 dst4;
__m256 dst7;
__m256 dst2;
__m256 dst5;
__m256 dst8;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 8);
dst6 = _mm256_load_ps(bias + 16);
dst1 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 8);
dst7 = _mm256_load_ps(bias + 16);
dst2 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 8);
dst8 = _mm256_load_ps(bias + 16);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 weight20 = _mm256_load_ps(weight + 16);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst3 = _mm256_fmadd_ps(dst3, src00, weight10);
dst6 = _mm256_fmadd_ps(dst6, src00, weight20);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst4 = _mm256_fmadd_ps(dst4, src10, weight10);
dst7 = _mm256_fmadd_ps(dst7, src10, weight20);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
dst5 = _mm256_fmadd_ps(dst5, src20, weight10);
dst8 = _mm256_fmadd_ps(dst8, src20, weight20);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 24);
__m256 weight11 = _mm256_load_ps(weight + 32);
__m256 weight21 = _mm256_load_ps(weight + 40);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst3 = _mm256_fmadd_ps(dst3, src01, weight11);
dst6 = _mm256_fmadd_ps(dst6, src01, weight21);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst4 = _mm256_fmadd_ps(dst4, src11, weight11);
dst7 = _mm256_fmadd_ps(dst7, src11, weight21);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
dst5 = _mm256_fmadd_ps(dst5, src21, weight11);
dst8 = _mm256_fmadd_ps(dst8, src21, weight21);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 48);
__m256 weight12 = _mm256_load_ps(weight + 56);
__m256 weight22 = _mm256_load_ps(weight + 64);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst3 = _mm256_fmadd_ps(dst3, src02, weight12);
dst6 = _mm256_fmadd_ps(dst6, src02, weight22);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst4 = _mm256_fmadd_ps(dst4, src12, weight12);
dst7 = _mm256_fmadd_ps(dst7, src12, weight22);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
dst5 = _mm256_fmadd_ps(dst5, src22, weight12);
dst8 = _mm256_fmadd_ps(dst8, src22, weight22);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 72);
__m256 weight13 = _mm256_load_ps(weight + 80);
__m256 weight23 = _mm256_load_ps(weight + 88);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst3 = _mm256_fmadd_ps(dst3, src03, weight13);
dst6 = _mm256_fmadd_ps(dst6, src03, weight23);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst4 = _mm256_fmadd_ps(dst4, src13, weight13);
dst7 = _mm256_fmadd_ps(dst7, src13, weight23);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
dst5 = _mm256_fmadd_ps(dst5, src23, weight13);
dst8 = _mm256_fmadd_ps(dst8, src23, weight23);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 96);
__m256 weight14 = _mm256_load_ps(weight + 104);
__m256 weight24 = _mm256_load_ps(weight + 112);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst3 = _mm256_fmadd_ps(dst3, src04, weight14);
dst6 = _mm256_fmadd_ps(dst6, src04, weight24);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst4 = _mm256_fmadd_ps(dst4, src14, weight14);
dst7 = _mm256_fmadd_ps(dst7, src14, weight24);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
dst5 = _mm256_fmadd_ps(dst5, src24, weight14);
dst8 = _mm256_fmadd_ps(dst8, src24, weight24);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 120);
__m256 weight15 = _mm256_load_ps(weight + 128);
__m256 weight25 = _mm256_load_ps(weight + 136);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst3 = _mm256_fmadd_ps(dst3, src05, weight15);
dst6 = _mm256_fmadd_ps(dst6, src05, weight25);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst4 = _mm256_fmadd_ps(dst4, src15, weight15);
dst7 = _mm256_fmadd_ps(dst7, src15, weight25);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
dst5 = _mm256_fmadd_ps(dst5, src25, weight15);
dst8 = _mm256_fmadd_ps(dst8, src25, weight25);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 144);
__m256 weight16 = _mm256_load_ps(weight + 152);
__m256 weight26 = _mm256_load_ps(weight + 160);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst3 = _mm256_fmadd_ps(dst3, src06, weight16);
dst6 = _mm256_fmadd_ps(dst6, src06, weight26);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst4 = _mm256_fmadd_ps(dst4, src16, weight16);
dst7 = _mm256_fmadd_ps(dst7, src16, weight26);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
dst5 = _mm256_fmadd_ps(dst5, src26, weight16);
dst8 = _mm256_fmadd_ps(dst8, src26, weight26);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 168);
__m256 weight17 = _mm256_load_ps(weight + 176);
__m256 weight27 = _mm256_load_ps(weight + 184);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst3 = _mm256_fmadd_ps(dst3, src07, weight17);
dst6 = _mm256_fmadd_ps(dst6, src07, weight27);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst4 = _mm256_fmadd_ps(dst4, src17, weight17);
dst7 = _mm256_fmadd_ps(dst7, src17, weight27);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
dst5 = _mm256_fmadd_ps(dst5, src27, weight17);
dst8 = _mm256_fmadd_ps(dst8, src27, weight27);
src = src + src_stride;
weight += 768;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst8 = _mm256_max_ps(dst8, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst8 = _mm256_max_ps(dst8, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 1 * src_stride + 0, dst3);
_mm256_store_ps(dst + 1 * src_stride + 8, dst4);
_mm256_store_ps(dst + 1 * src_stride + 16, dst5);
_mm256_store_ps(dst + 2 * src_stride + 0, dst6);
_mm256_store_ps(dst + 2 * src_stride + 8, dst7);
_mm256_store_ps(dst + 2 * src_stride + 16, dst8);
}

View File

@ -0,0 +1,249 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n"
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n"
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n"
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n"
"vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 32(%[bias]), %%ymm3\n"
"vmovaps 32(%[bias]), %%ymm4\n"
"vmovaps 32(%[bias]), %%ymm5\n"
"vmovaps 64(%[bias]), %%ymm6\n"
"vmovaps 64(%[bias]), %%ymm7\n"
"vmovaps 64(%[bias]), %%ymm8\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vmovaps 64(%[weight]), %%ymm13\n"
"vbroadcastss 0(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
// block 1
"vmovaps 96(%[weight]), %%ymm15\n"
"vmovaps 128(%[weight]), %%ymm14\n"
"vmovaps 160(%[weight]), %%ymm13\n"
"vbroadcastss 1(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
// block 2
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vmovaps 256(%[weight]), %%ymm13\n"
"vbroadcastss 2(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
// block 3
"vmovaps 288(%[weight]), %%ymm15\n"
"vmovaps 320(%[weight]), %%ymm14\n"
"vmovaps 352(%[weight]), %%ymm13\n"
"vbroadcastss 3(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
// block 4
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vmovaps 448(%[weight]), %%ymm13\n"
"vbroadcastss 4(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
// block 5
"vmovaps 480(%[weight]), %%ymm15\n"
"vmovaps 512(%[weight]), %%ymm14\n"
"vmovaps 544(%[weight]), %%ymm13\n"
"vbroadcastss 5(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
// block 6
"vmovaps 576(%[weight]), %%ymm15\n"
"vmovaps 608(%[weight]), %%ymm14\n"
"vmovaps 640(%[weight]), %%ymm13\n"
"vbroadcastss 6(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
// block 7
"vmovaps 672(%[weight]), %%ymm15\n"
"vmovaps 704(%[weight]), %%ymm14\n"
"vmovaps 736(%[weight]), %%ymm13\n"
"vbroadcastss 7(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"dec %[deep]\n"
"add 768, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,297 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst3;
__m256 dst6;
__m256 dst9;
__m256 dst1;
__m256 dst4;
__m256 dst7;
__m256 dst10;
__m256 dst2;
__m256 dst5;
__m256 dst8;
__m256 dst11;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0);
dst9 = _mm256_load_ps(dst + 3 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8);
dst10 = _mm256_load_ps(dst + 3 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16);
dst11 = _mm256_load_ps(dst + 3 * dst_stride + 16);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
dst10 = _mm256_setzero_ps();
dst11 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 8);
dst6 = _mm256_load_ps(bias + 16);
dst9 = _mm256_load_ps(bias + 24);
dst1 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 8);
dst7 = _mm256_load_ps(bias + 16);
dst10 = _mm256_load_ps(bias + 24);
dst2 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 8);
dst8 = _mm256_load_ps(bias + 16);
dst11 = _mm256_load_ps(bias + 24);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 src00 = _mm256_set1_ps(*(src + 0));
__m256 src10 = _mm256_set1_ps(*(src + 8));
__m256 src20 = _mm256_set1_ps(*(src + 16));
__m256 weight00 = _mm256_load_ps(weight + 0);
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 weight10 = _mm256_load_ps(weight + 8);
dst3 = _mm256_fmadd_ps(dst3, src00, weight10);
dst4 = _mm256_fmadd_ps(dst4, src10, weight10);
dst5 = _mm256_fmadd_ps(dst5, src20, weight10);
__m256 weight20 = _mm256_load_ps(weight + 16);
dst6 = _mm256_fmadd_ps(dst6, src00, weight20);
dst7 = _mm256_fmadd_ps(dst7, src10, weight20);
dst8 = _mm256_fmadd_ps(dst8, src20, weight20);
__m256 weight30 = _mm256_load_ps(weight + 24);
dst9 = _mm256_fmadd_ps(dst9, src00, weight30);
dst10 = _mm256_fmadd_ps(dst10, src10, weight30);
dst11 = _mm256_fmadd_ps(dst11, src20, weight30);
// bock1
__m256 src01 = _mm256_set1_ps(*(src + 1));
__m256 src11 = _mm256_set1_ps(*(src + 9));
__m256 src21 = _mm256_set1_ps(*(src + 17));
__m256 weight01 = _mm256_load_ps(weight + 32);
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 weight11 = _mm256_load_ps(weight + 40);
dst3 = _mm256_fmadd_ps(dst3, src01, weight11);
dst4 = _mm256_fmadd_ps(dst4, src11, weight11);
dst5 = _mm256_fmadd_ps(dst5, src21, weight11);
__m256 weight21 = _mm256_load_ps(weight + 48);
dst6 = _mm256_fmadd_ps(dst6, src01, weight21);
dst7 = _mm256_fmadd_ps(dst7, src11, weight21);
dst8 = _mm256_fmadd_ps(dst8, src21, weight21);
__m256 weight31 = _mm256_load_ps(weight + 56);
dst9 = _mm256_fmadd_ps(dst9, src01, weight31);
dst10 = _mm256_fmadd_ps(dst10, src11, weight31);
dst11 = _mm256_fmadd_ps(dst11, src21, weight31);
// bock2
__m256 src02 = _mm256_set1_ps(*(src + 2));
__m256 src12 = _mm256_set1_ps(*(src + 10));
__m256 src22 = _mm256_set1_ps(*(src + 18));
__m256 weight02 = _mm256_load_ps(weight + 64);
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 weight12 = _mm256_load_ps(weight + 72);
dst3 = _mm256_fmadd_ps(dst3, src02, weight12);
dst4 = _mm256_fmadd_ps(dst4, src12, weight12);
dst5 = _mm256_fmadd_ps(dst5, src22, weight12);
__m256 weight22 = _mm256_load_ps(weight + 80);
dst6 = _mm256_fmadd_ps(dst6, src02, weight22);
dst7 = _mm256_fmadd_ps(dst7, src12, weight22);
dst8 = _mm256_fmadd_ps(dst8, src22, weight22);
__m256 weight32 = _mm256_load_ps(weight + 88);
dst9 = _mm256_fmadd_ps(dst9, src02, weight32);
dst10 = _mm256_fmadd_ps(dst10, src12, weight32);
dst11 = _mm256_fmadd_ps(dst11, src22, weight32);
// bock3
__m256 src03 = _mm256_set1_ps(*(src + 3));
__m256 src13 = _mm256_set1_ps(*(src + 11));
__m256 src23 = _mm256_set1_ps(*(src + 19));
__m256 weight03 = _mm256_load_ps(weight + 96);
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 weight13 = _mm256_load_ps(weight + 104);
dst3 = _mm256_fmadd_ps(dst3, src03, weight13);
dst4 = _mm256_fmadd_ps(dst4, src13, weight13);
dst5 = _mm256_fmadd_ps(dst5, src23, weight13);
__m256 weight23 = _mm256_load_ps(weight + 112);
dst6 = _mm256_fmadd_ps(dst6, src03, weight23);
dst7 = _mm256_fmadd_ps(dst7, src13, weight23);
dst8 = _mm256_fmadd_ps(dst8, src23, weight23);
__m256 weight33 = _mm256_load_ps(weight + 120);
dst9 = _mm256_fmadd_ps(dst9, src03, weight33);
dst10 = _mm256_fmadd_ps(dst10, src13, weight33);
dst11 = _mm256_fmadd_ps(dst11, src23, weight33);
// bock4
__m256 src04 = _mm256_set1_ps(*(src + 4));
__m256 src14 = _mm256_set1_ps(*(src + 12));
__m256 src24 = _mm256_set1_ps(*(src + 20));
__m256 weight04 = _mm256_load_ps(weight + 128);
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 weight14 = _mm256_load_ps(weight + 136);
dst3 = _mm256_fmadd_ps(dst3, src04, weight14);
dst4 = _mm256_fmadd_ps(dst4, src14, weight14);
dst5 = _mm256_fmadd_ps(dst5, src24, weight14);
__m256 weight24 = _mm256_load_ps(weight + 144);
dst6 = _mm256_fmadd_ps(dst6, src04, weight24);
dst7 = _mm256_fmadd_ps(dst7, src14, weight24);
dst8 = _mm256_fmadd_ps(dst8, src24, weight24);
__m256 weight34 = _mm256_load_ps(weight + 152);
dst9 = _mm256_fmadd_ps(dst9, src04, weight34);
dst10 = _mm256_fmadd_ps(dst10, src14, weight34);
dst11 = _mm256_fmadd_ps(dst11, src24, weight34);
// bock5
__m256 src05 = _mm256_set1_ps(*(src + 5));
__m256 src15 = _mm256_set1_ps(*(src + 13));
__m256 src25 = _mm256_set1_ps(*(src + 21));
__m256 weight05 = _mm256_load_ps(weight + 160);
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 weight15 = _mm256_load_ps(weight + 168);
dst3 = _mm256_fmadd_ps(dst3, src05, weight15);
dst4 = _mm256_fmadd_ps(dst4, src15, weight15);
dst5 = _mm256_fmadd_ps(dst5, src25, weight15);
__m256 weight25 = _mm256_load_ps(weight + 176);
dst6 = _mm256_fmadd_ps(dst6, src05, weight25);
dst7 = _mm256_fmadd_ps(dst7, src15, weight25);
dst8 = _mm256_fmadd_ps(dst8, src25, weight25);
__m256 weight35 = _mm256_load_ps(weight + 184);
dst9 = _mm256_fmadd_ps(dst9, src05, weight35);
dst10 = _mm256_fmadd_ps(dst10, src15, weight35);
dst11 = _mm256_fmadd_ps(dst11, src25, weight35);
// bock6
__m256 src06 = _mm256_set1_ps(*(src + 6));
__m256 src16 = _mm256_set1_ps(*(src + 14));
__m256 src26 = _mm256_set1_ps(*(src + 22));
__m256 weight06 = _mm256_load_ps(weight + 192);
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 weight16 = _mm256_load_ps(weight + 200);
dst3 = _mm256_fmadd_ps(dst3, src06, weight16);
dst4 = _mm256_fmadd_ps(dst4, src16, weight16);
dst5 = _mm256_fmadd_ps(dst5, src26, weight16);
__m256 weight26 = _mm256_load_ps(weight + 208);
dst6 = _mm256_fmadd_ps(dst6, src06, weight26);
dst7 = _mm256_fmadd_ps(dst7, src16, weight26);
dst8 = _mm256_fmadd_ps(dst8, src26, weight26);
__m256 weight36 = _mm256_load_ps(weight + 216);
dst9 = _mm256_fmadd_ps(dst9, src06, weight36);
dst10 = _mm256_fmadd_ps(dst10, src16, weight36);
dst11 = _mm256_fmadd_ps(dst11, src26, weight36);
// bock7
__m256 src07 = _mm256_set1_ps(*(src + 7));
__m256 src17 = _mm256_set1_ps(*(src + 15));
__m256 src27 = _mm256_set1_ps(*(src + 23));
__m256 weight07 = _mm256_load_ps(weight + 224);
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 weight17 = _mm256_load_ps(weight + 232);
dst3 = _mm256_fmadd_ps(dst3, src07, weight17);
dst4 = _mm256_fmadd_ps(dst4, src17, weight17);
dst5 = _mm256_fmadd_ps(dst5, src27, weight17);
__m256 weight27 = _mm256_load_ps(weight + 240);
dst6 = _mm256_fmadd_ps(dst6, src07, weight27);
dst7 = _mm256_fmadd_ps(dst7, src17, weight27);
dst8 = _mm256_fmadd_ps(dst8, src27, weight27);
__m256 weight37 = _mm256_load_ps(weight + 248);
dst9 = _mm256_fmadd_ps(dst9, src07, weight37);
dst10 = _mm256_fmadd_ps(dst10, src17, weight37);
dst11 = _mm256_fmadd_ps(dst11, src27, weight37);
src = src + src_stride;
weight += 1024;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst10 = _mm256_min_ps(dst10, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst11 = _mm256_min_ps(dst11, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 1 * src_stride + 0, dst3);
_mm256_store_ps(dst + 1 * src_stride + 8, dst4);
_mm256_store_ps(dst + 1 * src_stride + 16, dst5);
_mm256_store_ps(dst + 2 * src_stride + 0, dst6);
_mm256_store_ps(dst + 2 * src_stride + 8, dst7);
_mm256_store_ps(dst + 2 * src_stride + 16, dst8);
_mm256_store_ps(dst + 3 * src_stride + 0, dst9);
_mm256_store_ps(dst + 3 * src_stride + 8, dst10);
_mm256_store_ps(dst + 3 * src_stride + 16, dst11);
}

View File

@ -0,0 +1,302 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_4 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n"
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n"
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n"
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n"
"vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n"
"vmovups 0(%[dst_4]), %%ymm9\n"
"vmovups 32(%[dst_4]), %%ymm10\n"
"vmovups 64(%[dst_4]), %%ymm11\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 32(%[bias]), %%ymm3\n"
"vmovaps 32(%[bias]), %%ymm4\n"
"vmovaps 32(%[bias]), %%ymm5\n"
"vmovaps 64(%[bias]), %%ymm6\n"
"vmovaps 64(%[bias]), %%ymm7\n"
"vmovaps 64(%[bias]), %%ymm8\n"
"vmovaps 96(%[bias]), %%ymm9\n"
"vmovaps 96(%[bias]), %%ymm10\n"
"vmovaps 96(%[bias]), %%ymm11\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_4 ] "r"(dst_4)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
asm volatile(
"0:\n"
// block 0
"vbroadcastss 0(%[src]), %%ymm15\n"
"vbroadcastss 32(%[src]), %%ymm14\n"
"vbroadcastss 64(%[src]), %%ymm13\n"
"vmovaps 0(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 32(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 64(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 96(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 1
"vbroadcastss 1(%[src]), %%ymm15\n"
"vbroadcastss 33(%[src]), %%ymm14\n"
"vbroadcastss 65(%[src]), %%ymm13\n"
"vmovaps 128(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 160(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 192(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 224(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 2
"vbroadcastss 2(%[src]), %%ymm15\n"
"vbroadcastss 34(%[src]), %%ymm14\n"
"vbroadcastss 66(%[src]), %%ymm13\n"
"vmovaps 256(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 288(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 320(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 352(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 3
"vbroadcastss 3(%[src]), %%ymm15\n"
"vbroadcastss 35(%[src]), %%ymm14\n"
"vbroadcastss 67(%[src]), %%ymm13\n"
"vmovaps 384(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 416(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 448(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 480(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 4
"vbroadcastss 4(%[src]), %%ymm15\n"
"vbroadcastss 36(%[src]), %%ymm14\n"
"vbroadcastss 68(%[src]), %%ymm13\n"
"vmovaps 512(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 544(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 576(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 608(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 5
"vbroadcastss 5(%[src]), %%ymm15\n"
"vbroadcastss 37(%[src]), %%ymm14\n"
"vbroadcastss 69(%[src]), %%ymm13\n"
"vmovaps 640(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 672(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 704(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 736(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 6
"vbroadcastss 6(%[src]), %%ymm15\n"
"vbroadcastss 38(%[src]), %%ymm14\n"
"vbroadcastss 70(%[src]), %%ymm13\n"
"vmovaps 768(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 800(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 832(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 864(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 7
"vbroadcastss 7(%[src]), %%ymm15\n"
"vbroadcastss 39(%[src]), %%ymm14\n"
"vbroadcastss 71(%[src]), %%ymm13\n"
"vmovaps 896(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
"vmovaps 928(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
"vmovaps 960(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vmovaps 992(%[weight]), %%ymm12\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
"dec %[deep]\n"
"add 1024, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"vminps %%ymm9, %%ymm14, %%ymm9\n"
"vminps %%ymm10, %%ymm14, %%ymm10\n"
"vminps %%ymm11, %%ymm14, %%ymm11\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm9, 0(%[dst_4])\n"
"vmovups %%ymm10, 32(%[dst_4])\n"
"vmovups %%ymm11, 64(%[dst_4])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_4 ] "r"(dst_4)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,129 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
}

View File

@ -0,0 +1,149 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,225 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst4;
__m256 dst1;
__m256 dst5;
__m256 dst2;
__m256 dst6;
__m256 dst3;
__m256 dst7;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 8);
dst1 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 8);
dst2 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 8);
dst3 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 8);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst4 = _mm256_fmadd_ps(dst4, src00, weight10);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst5 = _mm256_fmadd_ps(dst5, src10, weight10);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
dst6 = _mm256_fmadd_ps(dst6, src20, weight10);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
dst7 = _mm256_fmadd_ps(dst7, src30, weight10);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 16);
__m256 weight11 = _mm256_load_ps(weight + 24);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst4 = _mm256_fmadd_ps(dst4, src01, weight11);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst5 = _mm256_fmadd_ps(dst5, src11, weight11);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
dst6 = _mm256_fmadd_ps(dst6, src21, weight11);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
dst7 = _mm256_fmadd_ps(dst7, src31, weight11);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 32);
__m256 weight12 = _mm256_load_ps(weight + 40);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst4 = _mm256_fmadd_ps(dst4, src02, weight12);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst5 = _mm256_fmadd_ps(dst5, src12, weight12);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
dst6 = _mm256_fmadd_ps(dst6, src22, weight12);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
dst7 = _mm256_fmadd_ps(dst7, src32, weight12);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 48);
__m256 weight13 = _mm256_load_ps(weight + 56);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst4 = _mm256_fmadd_ps(dst4, src03, weight13);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst5 = _mm256_fmadd_ps(dst5, src13, weight13);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
dst6 = _mm256_fmadd_ps(dst6, src23, weight13);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
dst7 = _mm256_fmadd_ps(dst7, src33, weight13);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 64);
__m256 weight14 = _mm256_load_ps(weight + 72);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst4 = _mm256_fmadd_ps(dst4, src04, weight14);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst5 = _mm256_fmadd_ps(dst5, src14, weight14);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
dst6 = _mm256_fmadd_ps(dst6, src24, weight14);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
dst7 = _mm256_fmadd_ps(dst7, src34, weight14);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 80);
__m256 weight15 = _mm256_load_ps(weight + 88);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst4 = _mm256_fmadd_ps(dst4, src05, weight15);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst5 = _mm256_fmadd_ps(dst5, src15, weight15);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
dst6 = _mm256_fmadd_ps(dst6, src25, weight15);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
dst7 = _mm256_fmadd_ps(dst7, src35, weight15);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 96);
__m256 weight16 = _mm256_load_ps(weight + 104);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst4 = _mm256_fmadd_ps(dst4, src06, weight16);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst5 = _mm256_fmadd_ps(dst5, src16, weight16);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
dst6 = _mm256_fmadd_ps(dst6, src26, weight16);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
dst7 = _mm256_fmadd_ps(dst7, src36, weight16);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 112);
__m256 weight17 = _mm256_load_ps(weight + 120);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst4 = _mm256_fmadd_ps(dst4, src07, weight17);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst5 = _mm256_fmadd_ps(dst5, src17, weight17);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
dst6 = _mm256_fmadd_ps(dst6, src27, weight17);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
dst7 = _mm256_fmadd_ps(dst7, src37, weight17);
src = src + src_stride;
weight += 512;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst7 = _mm256_max_ps(dst7, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst7 = _mm256_max_ps(dst7, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 1 * src_stride + 0, dst4);
_mm256_store_ps(dst + 1 * src_stride + 8, dst5);
_mm256_store_ps(dst + 1 * src_stride + 16, dst6);
_mm256_store_ps(dst + 1 * src_stride + 24, dst7);
}

View File

@ -0,0 +1,235 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n"
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n"
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 32(%[bias]), %%ymm4\n"
"vmovaps 32(%[bias]), %%ymm5\n"
"vmovaps 32(%[bias]), %%ymm6\n"
"vmovaps 32(%[bias]), %%ymm7\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vbroadcastss 0(%[src]), %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 64(%[src]), %%ymm13\n"
"vbroadcastss 96(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
// block 1
"vmovaps 64(%[weight]), %%ymm15\n"
"vmovaps 96(%[weight]), %%ymm14\n"
"vbroadcastss 1(%[src]), %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 65(%[src]), %%ymm13\n"
"vbroadcastss 97(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
// block 2
"vmovaps 128(%[weight]), %%ymm15\n"
"vmovaps 160(%[weight]), %%ymm14\n"
"vbroadcastss 2(%[src]), %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 66(%[src]), %%ymm13\n"
"vbroadcastss 98(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
// block 3
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vbroadcastss 3(%[src]), %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 67(%[src]), %%ymm13\n"
"vbroadcastss 99(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
// block 4
"vmovaps 256(%[weight]), %%ymm15\n"
"vmovaps 288(%[weight]), %%ymm14\n"
"vbroadcastss 4(%[src]), %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 68(%[src]), %%ymm13\n"
"vbroadcastss 100(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
// block 5
"vmovaps 320(%[weight]), %%ymm15\n"
"vmovaps 352(%[weight]), %%ymm14\n"
"vbroadcastss 5(%[src]), %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 69(%[src]), %%ymm13\n"
"vbroadcastss 101(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
// block 6
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vbroadcastss 6(%[src]), %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 70(%[src]), %%ymm13\n"
"vbroadcastss 102(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
// block 7
"vmovaps 448(%[weight]), %%ymm15\n"
"vmovaps 480(%[weight]), %%ymm14\n"
"vbroadcastss 7(%[src]), %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vbroadcastss 71(%[src]), %%ymm13\n"
"vbroadcastss 103(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"dec %[deep]\n"
"add 512, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,297 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst4;
__m256 dst8;
__m256 dst1;
__m256 dst5;
__m256 dst9;
__m256 dst2;
__m256 dst6;
__m256 dst10;
__m256 dst3;
__m256 dst7;
__m256 dst11;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst8 = _mm256_load_ps(dst + 2 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst9 = _mm256_load_ps(dst + 2 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16);
dst10 = _mm256_load_ps(dst + 2 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24);
dst11 = _mm256_load_ps(dst + 2 * dst_stride + 24);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
dst10 = _mm256_setzero_ps();
dst11 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 8);
dst8 = _mm256_load_ps(bias + 16);
dst1 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 8);
dst9 = _mm256_load_ps(bias + 16);
dst2 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 8);
dst10 = _mm256_load_ps(bias + 16);
dst3 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 8);
dst11 = _mm256_load_ps(bias + 16);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 weight20 = _mm256_load_ps(weight + 16);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst4 = _mm256_fmadd_ps(dst4, src00, weight10);
dst8 = _mm256_fmadd_ps(dst8, src00, weight20);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst5 = _mm256_fmadd_ps(dst5, src10, weight10);
dst9 = _mm256_fmadd_ps(dst9, src10, weight20);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
dst6 = _mm256_fmadd_ps(dst6, src20, weight10);
dst10 = _mm256_fmadd_ps(dst10, src20, weight20);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
dst7 = _mm256_fmadd_ps(dst7, src30, weight10);
dst11 = _mm256_fmadd_ps(dst11, src30, weight20);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 24);
__m256 weight11 = _mm256_load_ps(weight + 32);
__m256 weight21 = _mm256_load_ps(weight + 40);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst4 = _mm256_fmadd_ps(dst4, src01, weight11);
dst8 = _mm256_fmadd_ps(dst8, src01, weight21);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst5 = _mm256_fmadd_ps(dst5, src11, weight11);
dst9 = _mm256_fmadd_ps(dst9, src11, weight21);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
dst6 = _mm256_fmadd_ps(dst6, src21, weight11);
dst10 = _mm256_fmadd_ps(dst10, src21, weight21);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
dst7 = _mm256_fmadd_ps(dst7, src31, weight11);
dst11 = _mm256_fmadd_ps(dst11, src31, weight21);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 48);
__m256 weight12 = _mm256_load_ps(weight + 56);
__m256 weight22 = _mm256_load_ps(weight + 64);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst4 = _mm256_fmadd_ps(dst4, src02, weight12);
dst8 = _mm256_fmadd_ps(dst8, src02, weight22);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst5 = _mm256_fmadd_ps(dst5, src12, weight12);
dst9 = _mm256_fmadd_ps(dst9, src12, weight22);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
dst6 = _mm256_fmadd_ps(dst6, src22, weight12);
dst10 = _mm256_fmadd_ps(dst10, src22, weight22);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
dst7 = _mm256_fmadd_ps(dst7, src32, weight12);
dst11 = _mm256_fmadd_ps(dst11, src32, weight22);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 72);
__m256 weight13 = _mm256_load_ps(weight + 80);
__m256 weight23 = _mm256_load_ps(weight + 88);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst4 = _mm256_fmadd_ps(dst4, src03, weight13);
dst8 = _mm256_fmadd_ps(dst8, src03, weight23);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst5 = _mm256_fmadd_ps(dst5, src13, weight13);
dst9 = _mm256_fmadd_ps(dst9, src13, weight23);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
dst6 = _mm256_fmadd_ps(dst6, src23, weight13);
dst10 = _mm256_fmadd_ps(dst10, src23, weight23);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
dst7 = _mm256_fmadd_ps(dst7, src33, weight13);
dst11 = _mm256_fmadd_ps(dst11, src33, weight23);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 96);
__m256 weight14 = _mm256_load_ps(weight + 104);
__m256 weight24 = _mm256_load_ps(weight + 112);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst4 = _mm256_fmadd_ps(dst4, src04, weight14);
dst8 = _mm256_fmadd_ps(dst8, src04, weight24);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst5 = _mm256_fmadd_ps(dst5, src14, weight14);
dst9 = _mm256_fmadd_ps(dst9, src14, weight24);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
dst6 = _mm256_fmadd_ps(dst6, src24, weight14);
dst10 = _mm256_fmadd_ps(dst10, src24, weight24);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
dst7 = _mm256_fmadd_ps(dst7, src34, weight14);
dst11 = _mm256_fmadd_ps(dst11, src34, weight24);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 120);
__m256 weight15 = _mm256_load_ps(weight + 128);
__m256 weight25 = _mm256_load_ps(weight + 136);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst4 = _mm256_fmadd_ps(dst4, src05, weight15);
dst8 = _mm256_fmadd_ps(dst8, src05, weight25);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst5 = _mm256_fmadd_ps(dst5, src15, weight15);
dst9 = _mm256_fmadd_ps(dst9, src15, weight25);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
dst6 = _mm256_fmadd_ps(dst6, src25, weight15);
dst10 = _mm256_fmadd_ps(dst10, src25, weight25);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
dst7 = _mm256_fmadd_ps(dst7, src35, weight15);
dst11 = _mm256_fmadd_ps(dst11, src35, weight25);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 144);
__m256 weight16 = _mm256_load_ps(weight + 152);
__m256 weight26 = _mm256_load_ps(weight + 160);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst4 = _mm256_fmadd_ps(dst4, src06, weight16);
dst8 = _mm256_fmadd_ps(dst8, src06, weight26);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst5 = _mm256_fmadd_ps(dst5, src16, weight16);
dst9 = _mm256_fmadd_ps(dst9, src16, weight26);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
dst6 = _mm256_fmadd_ps(dst6, src26, weight16);
dst10 = _mm256_fmadd_ps(dst10, src26, weight26);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
dst7 = _mm256_fmadd_ps(dst7, src36, weight16);
dst11 = _mm256_fmadd_ps(dst11, src36, weight26);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 168);
__m256 weight17 = _mm256_load_ps(weight + 176);
__m256 weight27 = _mm256_load_ps(weight + 184);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst4 = _mm256_fmadd_ps(dst4, src07, weight17);
dst8 = _mm256_fmadd_ps(dst8, src07, weight27);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst5 = _mm256_fmadd_ps(dst5, src17, weight17);
dst9 = _mm256_fmadd_ps(dst9, src17, weight27);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
dst6 = _mm256_fmadd_ps(dst6, src27, weight17);
dst10 = _mm256_fmadd_ps(dst10, src27, weight27);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
dst7 = _mm256_fmadd_ps(dst7, src37, weight17);
dst11 = _mm256_fmadd_ps(dst11, src37, weight27);
src = src + src_stride;
weight += 768;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst10 = _mm256_min_ps(dst10, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst11 = _mm256_min_ps(dst11, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 1 * src_stride + 0, dst4);
_mm256_store_ps(dst + 1 * src_stride + 8, dst5);
_mm256_store_ps(dst + 1 * src_stride + 16, dst6);
_mm256_store_ps(dst + 1 * src_stride + 24, dst7);
_mm256_store_ps(dst + 2 * src_stride + 0, dst8);
_mm256_store_ps(dst + 2 * src_stride + 8, dst9);
_mm256_store_ps(dst + 2 * src_stride + 16, dst10);
_mm256_store_ps(dst + 2 * src_stride + 24, dst11);
}

View File

@ -0,0 +1,299 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n"
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n"
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n"
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm8\n"
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm9\n"
"vmovups 64(%[dst], %[dst_stride], 2), %%ymm10\n"
"vmovups 96(%[dst], %[dst_stride], 2), %%ymm11\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 32(%[bias]), %%ymm4\n"
"vmovaps 32(%[bias]), %%ymm5\n"
"vmovaps 32(%[bias]), %%ymm6\n"
"vmovaps 32(%[bias]), %%ymm7\n"
"vmovaps 64(%[bias]), %%ymm8\n"
"vmovaps 64(%[bias]), %%ymm9\n"
"vmovaps 64(%[bias]), %%ymm10\n"
"vmovaps 64(%[bias]), %%ymm11\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vmovaps 64(%[weight]), %%ymm13\n"
"vbroadcastss 0(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 96(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 1
"vmovaps 96(%[weight]), %%ymm15\n"
"vmovaps 128(%[weight]), %%ymm14\n"
"vmovaps 160(%[weight]), %%ymm13\n"
"vbroadcastss 1(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 97(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 2
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vmovaps 256(%[weight]), %%ymm13\n"
"vbroadcastss 2(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 98(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 3
"vmovaps 288(%[weight]), %%ymm15\n"
"vmovaps 320(%[weight]), %%ymm14\n"
"vmovaps 352(%[weight]), %%ymm13\n"
"vbroadcastss 3(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 99(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 4
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vmovaps 448(%[weight]), %%ymm13\n"
"vbroadcastss 4(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 100(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 5
"vmovaps 480(%[weight]), %%ymm15\n"
"vmovaps 512(%[weight]), %%ymm14\n"
"vmovaps 544(%[weight]), %%ymm13\n"
"vbroadcastss 5(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 101(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 6
"vmovaps 576(%[weight]), %%ymm15\n"
"vmovaps 608(%[weight]), %%ymm14\n"
"vmovaps 640(%[weight]), %%ymm13\n"
"vbroadcastss 6(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 102(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
// block 7
"vmovaps 672(%[weight]), %%ymm15\n"
"vmovaps 704(%[weight]), %%ymm14\n"
"vmovaps 736(%[weight]), %%ymm13\n"
"vbroadcastss 7(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
"vbroadcastss 103(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
"dec %[deep]\n"
"add 768, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"vminps %%ymm9, %%ymm14, %%ymm9\n"
"vminps %%ymm10, %%ymm14, %%ymm10\n"
"vminps %%ymm11, %%ymm14, %%ymm11\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm8, 0(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm9, 32(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm10, 64(%[dst], %[dst_stride], 2)\n"
"vmovups %%ymm11, 96(%[dst], %[dst_stride], 2)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,153 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
}

View File

@ -0,0 +1,171 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,265 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst5;
__m256 dst1;
__m256 dst6;
__m256 dst2;
__m256 dst7;
__m256 dst3;
__m256 dst8;
__m256 dst4;
__m256 dst9;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst8 = _mm256_load_ps(dst + 1 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst9 = _mm256_load_ps(dst + 1 * dst_stride + 32);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 8);
dst1 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 8);
dst2 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 8);
dst3 = _mm256_load_ps(bias + 0);
dst8 = _mm256_load_ps(bias + 8);
dst4 = _mm256_load_ps(bias + 0);
dst9 = _mm256_load_ps(bias + 8);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst5 = _mm256_fmadd_ps(dst5, src00, weight10);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst6 = _mm256_fmadd_ps(dst6, src10, weight10);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
dst7 = _mm256_fmadd_ps(dst7, src20, weight10);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
dst8 = _mm256_fmadd_ps(dst8, src30, weight10);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
dst9 = _mm256_fmadd_ps(dst9, src40, weight10);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 16);
__m256 weight11 = _mm256_load_ps(weight + 24);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst5 = _mm256_fmadd_ps(dst5, src01, weight11);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst6 = _mm256_fmadd_ps(dst6, src11, weight11);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
dst7 = _mm256_fmadd_ps(dst7, src21, weight11);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
dst8 = _mm256_fmadd_ps(dst8, src31, weight11);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
dst9 = _mm256_fmadd_ps(dst9, src41, weight11);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 32);
__m256 weight12 = _mm256_load_ps(weight + 40);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst5 = _mm256_fmadd_ps(dst5, src02, weight12);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst6 = _mm256_fmadd_ps(dst6, src12, weight12);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
dst7 = _mm256_fmadd_ps(dst7, src22, weight12);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
dst8 = _mm256_fmadd_ps(dst8, src32, weight12);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
dst9 = _mm256_fmadd_ps(dst9, src42, weight12);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 48);
__m256 weight13 = _mm256_load_ps(weight + 56);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst5 = _mm256_fmadd_ps(dst5, src03, weight13);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst6 = _mm256_fmadd_ps(dst6, src13, weight13);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
dst7 = _mm256_fmadd_ps(dst7, src23, weight13);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
dst8 = _mm256_fmadd_ps(dst8, src33, weight13);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
dst9 = _mm256_fmadd_ps(dst9, src43, weight13);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 64);
__m256 weight14 = _mm256_load_ps(weight + 72);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst5 = _mm256_fmadd_ps(dst5, src04, weight14);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst6 = _mm256_fmadd_ps(dst6, src14, weight14);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
dst7 = _mm256_fmadd_ps(dst7, src24, weight14);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
dst8 = _mm256_fmadd_ps(dst8, src34, weight14);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
dst9 = _mm256_fmadd_ps(dst9, src44, weight14);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 80);
__m256 weight15 = _mm256_load_ps(weight + 88);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst5 = _mm256_fmadd_ps(dst5, src05, weight15);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst6 = _mm256_fmadd_ps(dst6, src15, weight15);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
dst7 = _mm256_fmadd_ps(dst7, src25, weight15);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
dst8 = _mm256_fmadd_ps(dst8, src35, weight15);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
dst9 = _mm256_fmadd_ps(dst9, src45, weight15);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 96);
__m256 weight16 = _mm256_load_ps(weight + 104);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst5 = _mm256_fmadd_ps(dst5, src06, weight16);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst6 = _mm256_fmadd_ps(dst6, src16, weight16);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
dst7 = _mm256_fmadd_ps(dst7, src26, weight16);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
dst8 = _mm256_fmadd_ps(dst8, src36, weight16);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
dst9 = _mm256_fmadd_ps(dst9, src46, weight16);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 112);
__m256 weight17 = _mm256_load_ps(weight + 120);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst5 = _mm256_fmadd_ps(dst5, src07, weight17);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst6 = _mm256_fmadd_ps(dst6, src17, weight17);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
dst7 = _mm256_fmadd_ps(dst7, src27, weight17);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
dst8 = _mm256_fmadd_ps(dst8, src37, weight17);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
dst9 = _mm256_fmadd_ps(dst9, src47, weight17);
src = src + src_stride;
weight += 512;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst9 = _mm256_max_ps(dst9, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst9 = _mm256_max_ps(dst9, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 1 * src_stride + 0, dst5);
_mm256_store_ps(dst + 1 * src_stride + 8, dst6);
_mm256_store_ps(dst + 1 * src_stride + 16, dst7);
_mm256_store_ps(dst + 1 * src_stride + 24, dst8);
_mm256_store_ps(dst + 1 * src_stride + 32, dst9);
}

View File

@ -0,0 +1,271 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm5\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm6\n"
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm7\n"
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm8\n"
"vmovups 128(%[dst], %[dst_stride], 1), %%ymm9\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 32(%[bias]), %%ymm5\n"
"vmovaps 32(%[bias]), %%ymm6\n"
"vmovaps 32(%[bias]), %%ymm7\n"
"vmovaps 32(%[bias]), %%ymm8\n"
"vmovaps 32(%[bias]), %%ymm9\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vbroadcastss 0(%[src]), %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 64(%[src]), %%ymm13\n"
"vbroadcastss 96(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
// block 1
"vmovaps 64(%[weight]), %%ymm15\n"
"vmovaps 96(%[weight]), %%ymm14\n"
"vbroadcastss 1(%[src]), %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 65(%[src]), %%ymm13\n"
"vbroadcastss 97(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
// block 2
"vmovaps 128(%[weight]), %%ymm15\n"
"vmovaps 160(%[weight]), %%ymm14\n"
"vbroadcastss 2(%[src]), %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 66(%[src]), %%ymm13\n"
"vbroadcastss 98(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
// block 3
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vbroadcastss 3(%[src]), %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 67(%[src]), %%ymm13\n"
"vbroadcastss 99(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
// block 4
"vmovaps 256(%[weight]), %%ymm15\n"
"vmovaps 288(%[weight]), %%ymm14\n"
"vbroadcastss 4(%[src]), %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 68(%[src]), %%ymm13\n"
"vbroadcastss 100(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
// block 5
"vmovaps 320(%[weight]), %%ymm15\n"
"vmovaps 352(%[weight]), %%ymm14\n"
"vbroadcastss 5(%[src]), %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 69(%[src]), %%ymm13\n"
"vbroadcastss 101(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
// block 6
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vbroadcastss 6(%[src]), %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 70(%[src]), %%ymm13\n"
"vbroadcastss 102(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
// block 7
"vmovaps 448(%[weight]), %%ymm15\n"
"vmovaps 480(%[weight]), %%ymm14\n"
"vbroadcastss 7(%[src]), %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
"vbroadcastss 71(%[src]), %%ymm13\n"
"vbroadcastss 103(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
"dec %[deep]\n"
"add 512, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"vminps %%ymm9, %%ymm14, %%ymm9\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm6, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm7, 64(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm8, 96(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm9, 128(%[dst], %[dst_stride], 1)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,177 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
}

View File

@ -0,0 +1,193 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,305 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst6;
__m256 dst1;
__m256 dst7;
__m256 dst2;
__m256 dst8;
__m256 dst3;
__m256 dst9;
__m256 dst4;
__m256 dst10;
__m256 dst5;
__m256 dst11;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst8 = _mm256_load_ps(dst + 1 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst9 = _mm256_load_ps(dst + 1 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst10 = _mm256_load_ps(dst + 1 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
dst11 = _mm256_load_ps(dst + 1 * dst_stride + 40);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
dst9 = _mm256_setzero_ps();
dst10 = _mm256_setzero_ps();
dst11 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 8);
dst1 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 8);
dst2 = _mm256_load_ps(bias + 0);
dst8 = _mm256_load_ps(bias + 8);
dst3 = _mm256_load_ps(bias + 0);
dst9 = _mm256_load_ps(bias + 8);
dst4 = _mm256_load_ps(bias + 0);
dst10 = _mm256_load_ps(bias + 8);
dst5 = _mm256_load_ps(bias + 0);
dst11 = _mm256_load_ps(bias + 8);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 weight10 = _mm256_load_ps(weight + 8);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
dst6 = _mm256_fmadd_ps(dst6, src00, weight10);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
dst7 = _mm256_fmadd_ps(dst7, src10, weight10);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
dst8 = _mm256_fmadd_ps(dst8, src20, weight10);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
dst9 = _mm256_fmadd_ps(dst9, src30, weight10);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
dst10 = _mm256_fmadd_ps(dst10, src40, weight10);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
dst11 = _mm256_fmadd_ps(dst11, src50, weight10);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 16);
__m256 weight11 = _mm256_load_ps(weight + 24);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
dst6 = _mm256_fmadd_ps(dst6, src01, weight11);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
dst7 = _mm256_fmadd_ps(dst7, src11, weight11);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
dst8 = _mm256_fmadd_ps(dst8, src21, weight11);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
dst9 = _mm256_fmadd_ps(dst9, src31, weight11);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
dst10 = _mm256_fmadd_ps(dst10, src41, weight11);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
dst11 = _mm256_fmadd_ps(dst11, src51, weight11);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 32);
__m256 weight12 = _mm256_load_ps(weight + 40);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
dst6 = _mm256_fmadd_ps(dst6, src02, weight12);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
dst7 = _mm256_fmadd_ps(dst7, src12, weight12);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
dst8 = _mm256_fmadd_ps(dst8, src22, weight12);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
dst9 = _mm256_fmadd_ps(dst9, src32, weight12);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
dst10 = _mm256_fmadd_ps(dst10, src42, weight12);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
dst11 = _mm256_fmadd_ps(dst11, src52, weight12);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 48);
__m256 weight13 = _mm256_load_ps(weight + 56);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
dst6 = _mm256_fmadd_ps(dst6, src03, weight13);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
dst7 = _mm256_fmadd_ps(dst7, src13, weight13);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
dst8 = _mm256_fmadd_ps(dst8, src23, weight13);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
dst9 = _mm256_fmadd_ps(dst9, src33, weight13);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
dst10 = _mm256_fmadd_ps(dst10, src43, weight13);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
dst11 = _mm256_fmadd_ps(dst11, src53, weight13);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 64);
__m256 weight14 = _mm256_load_ps(weight + 72);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
dst6 = _mm256_fmadd_ps(dst6, src04, weight14);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
dst7 = _mm256_fmadd_ps(dst7, src14, weight14);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
dst8 = _mm256_fmadd_ps(dst8, src24, weight14);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
dst9 = _mm256_fmadd_ps(dst9, src34, weight14);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
dst10 = _mm256_fmadd_ps(dst10, src44, weight14);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
dst11 = _mm256_fmadd_ps(dst11, src54, weight14);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 80);
__m256 weight15 = _mm256_load_ps(weight + 88);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
dst6 = _mm256_fmadd_ps(dst6, src05, weight15);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
dst7 = _mm256_fmadd_ps(dst7, src15, weight15);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
dst8 = _mm256_fmadd_ps(dst8, src25, weight15);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
dst9 = _mm256_fmadd_ps(dst9, src35, weight15);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
dst10 = _mm256_fmadd_ps(dst10, src45, weight15);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
dst11 = _mm256_fmadd_ps(dst11, src55, weight15);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 96);
__m256 weight16 = _mm256_load_ps(weight + 104);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
dst6 = _mm256_fmadd_ps(dst6, src06, weight16);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
dst7 = _mm256_fmadd_ps(dst7, src16, weight16);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
dst8 = _mm256_fmadd_ps(dst8, src26, weight16);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
dst9 = _mm256_fmadd_ps(dst9, src36, weight16);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
dst10 = _mm256_fmadd_ps(dst10, src46, weight16);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
dst11 = _mm256_fmadd_ps(dst11, src56, weight16);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 112);
__m256 weight17 = _mm256_load_ps(weight + 120);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
dst6 = _mm256_fmadd_ps(dst6, src07, weight17);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
dst7 = _mm256_fmadd_ps(dst7, src17, weight17);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
dst8 = _mm256_fmadd_ps(dst8, src27, weight17);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
dst9 = _mm256_fmadd_ps(dst9, src37, weight17);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
dst10 = _mm256_fmadd_ps(dst10, src47, weight17);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
dst11 = _mm256_fmadd_ps(dst11, src57, weight17);
src = src + src_stride;
weight += 512;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst9 = _mm256_min_ps(dst9, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst10 = _mm256_min_ps(dst10, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst11 = _mm256_min_ps(dst11, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst8 = _mm256_max_ps(dst8, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst9 = _mm256_max_ps(dst9, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst10 = _mm256_max_ps(dst10, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst11 = _mm256_max_ps(dst11, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
_mm256_store_ps(dst + 1 * src_stride + 0, dst6);
_mm256_store_ps(dst + 1 * src_stride + 8, dst7);
_mm256_store_ps(dst + 1 * src_stride + 16, dst8);
_mm256_store_ps(dst + 1 * src_stride + 24, dst9);
_mm256_store_ps(dst + 1 * src_stride + 32, dst10);
_mm256_store_ps(dst + 1 * src_stride + 40, dst11);
}

View File

@ -0,0 +1,307 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm6\n"
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm7\n"
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm8\n"
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm9\n"
"vmovups 128(%[dst], %[dst_stride], 1), %%ymm10\n"
"vmovups 160(%[dst], %[dst_stride], 1), %%ymm11\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"vmovaps 32(%[bias]), %%ymm6\n"
"vmovaps 32(%[bias]), %%ymm7\n"
"vmovaps 32(%[bias]), %%ymm8\n"
"vmovaps 32(%[bias]), %%ymm9\n"
"vmovaps 32(%[bias]), %%ymm10\n"
"vmovaps 32(%[bias]), %%ymm11\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vmovaps 32(%[weight]), %%ymm14\n"
"vbroadcastss 0(%[src]), %%ymm13\n"
"vbroadcastss 32(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 64(%[src]), %%ymm13\n"
"vbroadcastss 96(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
// block 1
"vmovaps 64(%[weight]), %%ymm15\n"
"vmovaps 96(%[weight]), %%ymm14\n"
"vbroadcastss 1(%[src]), %%ymm13\n"
"vbroadcastss 33(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 65(%[src]), %%ymm13\n"
"vbroadcastss 97(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
// block 2
"vmovaps 128(%[weight]), %%ymm15\n"
"vmovaps 160(%[weight]), %%ymm14\n"
"vbroadcastss 2(%[src]), %%ymm13\n"
"vbroadcastss 34(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 66(%[src]), %%ymm13\n"
"vbroadcastss 98(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
// block 3
"vmovaps 192(%[weight]), %%ymm15\n"
"vmovaps 224(%[weight]), %%ymm14\n"
"vbroadcastss 3(%[src]), %%ymm13\n"
"vbroadcastss 35(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 67(%[src]), %%ymm13\n"
"vbroadcastss 99(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
// block 4
"vmovaps 256(%[weight]), %%ymm15\n"
"vmovaps 288(%[weight]), %%ymm14\n"
"vbroadcastss 4(%[src]), %%ymm13\n"
"vbroadcastss 36(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 68(%[src]), %%ymm13\n"
"vbroadcastss 100(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
// block 5
"vmovaps 320(%[weight]), %%ymm15\n"
"vmovaps 352(%[weight]), %%ymm14\n"
"vbroadcastss 5(%[src]), %%ymm13\n"
"vbroadcastss 37(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 69(%[src]), %%ymm13\n"
"vbroadcastss 101(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
// block 6
"vmovaps 384(%[weight]), %%ymm15\n"
"vmovaps 416(%[weight]), %%ymm14\n"
"vbroadcastss 6(%[src]), %%ymm13\n"
"vbroadcastss 38(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 70(%[src]), %%ymm13\n"
"vbroadcastss 102(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
// block 7
"vmovaps 448(%[weight]), %%ymm15\n"
"vmovaps 480(%[weight]), %%ymm14\n"
"vbroadcastss 7(%[src]), %%ymm13\n"
"vbroadcastss 39(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
"vbroadcastss 71(%[src]), %%ymm13\n"
"vbroadcastss 103(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
"dec %[deep]\n"
"add 512, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"vminps %%ymm9, %%ymm14, %%ymm9\n"
"vminps %%ymm10, %%ymm14, %%ymm10\n"
"vminps %%ymm11, %%ymm14, %%ymm11\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
"vmovups %%ymm6, 0(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm7, 32(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm8, 64(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm9, 96(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm10, 128(%[dst], %[dst_stride], 1)\n"
"vmovups %%ymm11, 160(%[dst], %[dst_stride], 1)\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,201 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
__m256 dst5;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
}

View File

@ -0,0 +1,215 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,225 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
__m256 dst5;
__m256 dst6;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
__m256 src60 = _mm256_set1_ps(*(src + 48));
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
__m256 src61 = _mm256_set1_ps(*(src + 49));
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
__m256 src62 = _mm256_set1_ps(*(src + 50));
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
__m256 src63 = _mm256_set1_ps(*(src + 51));
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
__m256 src64 = _mm256_set1_ps(*(src + 52));
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
__m256 src65 = _mm256_set1_ps(*(src + 53));
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
__m256 src66 = _mm256_set1_ps(*(src + 54));
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
__m256 src67 = _mm256_set1_ps(*(src + 55));
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
}

View File

@ -0,0 +1,237 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"vmovups 192(%[dst]), %%ymm6\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"vmovaps 0(%[bias]), %%ymm6\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 192(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 193(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 194(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 195(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 196(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 197(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 198(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 199(%[src]), %%ymm14\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
"vmovups %%ymm6, 192(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,249 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
__m256 dst5;
__m256 dst6;
__m256 dst7;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
__m256 src60 = _mm256_set1_ps(*(src + 48));
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
__m256 src70 = _mm256_set1_ps(*(src + 56));
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
__m256 src61 = _mm256_set1_ps(*(src + 49));
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
__m256 src71 = _mm256_set1_ps(*(src + 57));
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
__m256 src62 = _mm256_set1_ps(*(src + 50));
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
__m256 src72 = _mm256_set1_ps(*(src + 58));
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
__m256 src63 = _mm256_set1_ps(*(src + 51));
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
__m256 src73 = _mm256_set1_ps(*(src + 59));
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
__m256 src64 = _mm256_set1_ps(*(src + 52));
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
__m256 src74 = _mm256_set1_ps(*(src + 60));
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
__m256 src65 = _mm256_set1_ps(*(src + 53));
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
__m256 src75 = _mm256_set1_ps(*(src + 61));
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
__m256 src66 = _mm256_set1_ps(*(src + 54));
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
__m256 src76 = _mm256_set1_ps(*(src + 62));
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
__m256 src67 = _mm256_set1_ps(*(src + 55));
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
__m256 src77 = _mm256_set1_ps(*(src + 63));
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
}

View File

@ -0,0 +1,259 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"vmovups 192(%[dst]), %%ymm6\n"
"vmovups 224(%[dst]), %%ymm7\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"vmovaps 0(%[bias]), %%ymm6\n"
"vmovaps 0(%[bias]), %%ymm7\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 192(%[src]), %%ymm14\n"
"vbroadcastss 224(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 193(%[src]), %%ymm14\n"
"vbroadcastss 225(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 194(%[src]), %%ymm14\n"
"vbroadcastss 226(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 195(%[src]), %%ymm14\n"
"vbroadcastss 227(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 196(%[src]), %%ymm14\n"
"vbroadcastss 228(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 197(%[src]), %%ymm14\n"
"vbroadcastss 229(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 198(%[src]), %%ymm14\n"
"vbroadcastss 230(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 199(%[src]), %%ymm14\n"
"vbroadcastss 231(%[src]), %%ymm13\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
"vmovups %%ymm6, 192(%[dst])\n"
"vmovups %%ymm7, 224(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,273 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
__m256 dst0;
__m256 dst1;
__m256 dst2;
__m256 dst3;
__m256 dst4;
__m256 dst5;
__m256 dst6;
__m256 dst7;
__m256 dst8;
if (inc_flag) {
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
} else if (bias == NULL) {
dst0 = _mm256_setzero_ps();
dst1 = _mm256_setzero_ps();
dst2 = _mm256_setzero_ps();
dst3 = _mm256_setzero_ps();
dst4 = _mm256_setzero_ps();
dst5 = _mm256_setzero_ps();
dst6 = _mm256_setzero_ps();
dst7 = _mm256_setzero_ps();
dst8 = _mm256_setzero_ps();
} else {
dst0 = _mm256_load_ps(bias + 0);
dst1 = _mm256_load_ps(bias + 0);
dst2 = _mm256_load_ps(bias + 0);
dst3 = _mm256_load_ps(bias + 0);
dst4 = _mm256_load_ps(bias + 0);
dst5 = _mm256_load_ps(bias + 0);
dst6 = _mm256_load_ps(bias + 0);
dst7 = _mm256_load_ps(bias + 0);
dst8 = _mm256_load_ps(bias + 0);
}
for (int i = 0; i < (deep >> 3); ++i) {
// bock0
__m256 weight00 = _mm256_load_ps(weight + 0);
__m256 src00 = _mm256_set1_ps(*(src + 0));
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
__m256 src10 = _mm256_set1_ps(*(src + 8));
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
__m256 src20 = _mm256_set1_ps(*(src + 16));
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
__m256 src30 = _mm256_set1_ps(*(src + 24));
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
__m256 src40 = _mm256_set1_ps(*(src + 32));
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
__m256 src50 = _mm256_set1_ps(*(src + 40));
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
__m256 src60 = _mm256_set1_ps(*(src + 48));
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
__m256 src70 = _mm256_set1_ps(*(src + 56));
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
__m256 src80 = _mm256_set1_ps(*(src + 64));
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
// bock1
__m256 weight01 = _mm256_load_ps(weight + 8);
__m256 src01 = _mm256_set1_ps(*(src + 1));
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
__m256 src11 = _mm256_set1_ps(*(src + 9));
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
__m256 src21 = _mm256_set1_ps(*(src + 17));
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
__m256 src31 = _mm256_set1_ps(*(src + 25));
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
__m256 src41 = _mm256_set1_ps(*(src + 33));
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
__m256 src51 = _mm256_set1_ps(*(src + 41));
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
__m256 src61 = _mm256_set1_ps(*(src + 49));
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
__m256 src71 = _mm256_set1_ps(*(src + 57));
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
__m256 src81 = _mm256_set1_ps(*(src + 65));
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
// bock2
__m256 weight02 = _mm256_load_ps(weight + 16);
__m256 src02 = _mm256_set1_ps(*(src + 2));
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
__m256 src12 = _mm256_set1_ps(*(src + 10));
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
__m256 src22 = _mm256_set1_ps(*(src + 18));
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
__m256 src32 = _mm256_set1_ps(*(src + 26));
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
__m256 src42 = _mm256_set1_ps(*(src + 34));
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
__m256 src52 = _mm256_set1_ps(*(src + 42));
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
__m256 src62 = _mm256_set1_ps(*(src + 50));
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
__m256 src72 = _mm256_set1_ps(*(src + 58));
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
__m256 src82 = _mm256_set1_ps(*(src + 66));
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
// bock3
__m256 weight03 = _mm256_load_ps(weight + 24);
__m256 src03 = _mm256_set1_ps(*(src + 3));
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
__m256 src13 = _mm256_set1_ps(*(src + 11));
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
__m256 src23 = _mm256_set1_ps(*(src + 19));
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
__m256 src33 = _mm256_set1_ps(*(src + 27));
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
__m256 src43 = _mm256_set1_ps(*(src + 35));
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
__m256 src53 = _mm256_set1_ps(*(src + 43));
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
__m256 src63 = _mm256_set1_ps(*(src + 51));
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
__m256 src73 = _mm256_set1_ps(*(src + 59));
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
__m256 src83 = _mm256_set1_ps(*(src + 67));
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
// bock4
__m256 weight04 = _mm256_load_ps(weight + 32);
__m256 src04 = _mm256_set1_ps(*(src + 4));
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
__m256 src14 = _mm256_set1_ps(*(src + 12));
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
__m256 src24 = _mm256_set1_ps(*(src + 20));
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
__m256 src34 = _mm256_set1_ps(*(src + 28));
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
__m256 src44 = _mm256_set1_ps(*(src + 36));
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
__m256 src54 = _mm256_set1_ps(*(src + 44));
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
__m256 src64 = _mm256_set1_ps(*(src + 52));
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
__m256 src74 = _mm256_set1_ps(*(src + 60));
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
__m256 src84 = _mm256_set1_ps(*(src + 68));
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
// bock5
__m256 weight05 = _mm256_load_ps(weight + 40);
__m256 src05 = _mm256_set1_ps(*(src + 5));
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
__m256 src15 = _mm256_set1_ps(*(src + 13));
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
__m256 src25 = _mm256_set1_ps(*(src + 21));
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
__m256 src35 = _mm256_set1_ps(*(src + 29));
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
__m256 src45 = _mm256_set1_ps(*(src + 37));
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
__m256 src55 = _mm256_set1_ps(*(src + 45));
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
__m256 src65 = _mm256_set1_ps(*(src + 53));
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
__m256 src75 = _mm256_set1_ps(*(src + 61));
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
__m256 src85 = _mm256_set1_ps(*(src + 69));
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
// bock6
__m256 weight06 = _mm256_load_ps(weight + 48);
__m256 src06 = _mm256_set1_ps(*(src + 6));
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
__m256 src16 = _mm256_set1_ps(*(src + 14));
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
__m256 src26 = _mm256_set1_ps(*(src + 22));
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
__m256 src36 = _mm256_set1_ps(*(src + 30));
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
__m256 src46 = _mm256_set1_ps(*(src + 38));
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
__m256 src56 = _mm256_set1_ps(*(src + 46));
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
__m256 src66 = _mm256_set1_ps(*(src + 54));
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
__m256 src76 = _mm256_set1_ps(*(src + 62));
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
__m256 src86 = _mm256_set1_ps(*(src + 70));
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
// bock7
__m256 weight07 = _mm256_load_ps(weight + 56);
__m256 src07 = _mm256_set1_ps(*(src + 7));
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
__m256 src17 = _mm256_set1_ps(*(src + 15));
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
__m256 src27 = _mm256_set1_ps(*(src + 23));
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
__m256 src37 = _mm256_set1_ps(*(src + 31));
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
__m256 src47 = _mm256_set1_ps(*(src + 39));
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
__m256 src57 = _mm256_set1_ps(*(src + 47));
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
__m256 src67 = _mm256_set1_ps(*(src + 55));
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
__m256 src77 = _mm256_set1_ps(*(src + 63));
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
__m256 src87 = _mm256_set1_ps(*(src + 71));
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
src = src + src_stride;
weight += 256;
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_min_ps(dst0, relu6);
dst1 = _mm256_min_ps(dst1, relu6);
dst2 = _mm256_min_ps(dst2, relu6);
dst3 = _mm256_min_ps(dst3, relu6);
dst4 = _mm256_min_ps(dst4, relu6);
dst5 = _mm256_min_ps(dst5, relu6);
dst6 = _mm256_min_ps(dst6, relu6);
dst7 = _mm256_min_ps(dst7, relu6);
dst8 = _mm256_min_ps(dst8, relu6);
// relu
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
dst0 = _mm256_max_ps(dst0, relu);
dst1 = _mm256_max_ps(dst1, relu);
dst2 = _mm256_max_ps(dst2, relu);
dst3 = _mm256_max_ps(dst3, relu);
dst4 = _mm256_max_ps(dst4, relu);
dst5 = _mm256_max_ps(dst5, relu);
dst6 = _mm256_max_ps(dst6, relu);
dst7 = _mm256_max_ps(dst7, relu);
dst8 = _mm256_max_ps(dst8, relu);
}
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
_mm256_store_ps(dst + 0 * src_stride + 64, dst8);
}

View File

@ -0,0 +1,281 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst]), %%ymm0\n"
"vmovups 32(%[dst]), %%ymm1\n"
"vmovups 64(%[dst]), %%ymm2\n"
"vmovups 96(%[dst]), %%ymm3\n"
"vmovups 128(%[dst]), %%ymm4\n"
"vmovups 160(%[dst]), %%ymm5\n"
"vmovups 192(%[dst]), %%ymm6\n"
"vmovups 224(%[dst]), %%ymm7\n"
"vmovups 256(%[dst]), %%ymm8\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%ymm0\n"
"vmovaps 0(%[bias]), %%ymm1\n"
"vmovaps 0(%[bias]), %%ymm2\n"
"vmovaps 0(%[bias]), %%ymm3\n"
"vmovaps 0(%[bias]), %%ymm4\n"
"vmovaps 0(%[bias]), %%ymm5\n"
"vmovaps 0(%[bias]), %%ymm6\n"
"vmovaps 0(%[bias]), %%ymm7\n"
"vmovaps 0(%[bias]), %%ymm8\n"
"jmp 2f\n"
"1:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
"2:\n"
:
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8");
asm volatile(
"0:\n"
// block 0
"vmovaps 0(%[weight]), %%ymm15\n"
"vbroadcastss 0(%[src]), %%ymm14\n"
"vbroadcastss 32(%[src]), %%ymm13\n"
"vbroadcastss 64(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 96(%[src]), %%ymm14\n"
"vbroadcastss 128(%[src]), %%ymm13\n"
"vbroadcastss 160(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 192(%[src]), %%ymm14\n"
"vbroadcastss 224(%[src]), %%ymm13\n"
"vbroadcastss 256(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
// block 1
"vmovaps 32(%[weight]), %%ymm15\n"
"vbroadcastss 1(%[src]), %%ymm14\n"
"vbroadcastss 33(%[src]), %%ymm13\n"
"vbroadcastss 65(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 97(%[src]), %%ymm14\n"
"vbroadcastss 129(%[src]), %%ymm13\n"
"vbroadcastss 161(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 193(%[src]), %%ymm14\n"
"vbroadcastss 225(%[src]), %%ymm13\n"
"vbroadcastss 257(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
// block 2
"vmovaps 64(%[weight]), %%ymm15\n"
"vbroadcastss 2(%[src]), %%ymm14\n"
"vbroadcastss 34(%[src]), %%ymm13\n"
"vbroadcastss 66(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 98(%[src]), %%ymm14\n"
"vbroadcastss 130(%[src]), %%ymm13\n"
"vbroadcastss 162(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 194(%[src]), %%ymm14\n"
"vbroadcastss 226(%[src]), %%ymm13\n"
"vbroadcastss 258(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
// block 3
"vmovaps 96(%[weight]), %%ymm15\n"
"vbroadcastss 3(%[src]), %%ymm14\n"
"vbroadcastss 35(%[src]), %%ymm13\n"
"vbroadcastss 67(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 99(%[src]), %%ymm14\n"
"vbroadcastss 131(%[src]), %%ymm13\n"
"vbroadcastss 163(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 195(%[src]), %%ymm14\n"
"vbroadcastss 227(%[src]), %%ymm13\n"
"vbroadcastss 259(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
// block 4
"vmovaps 128(%[weight]), %%ymm15\n"
"vbroadcastss 4(%[src]), %%ymm14\n"
"vbroadcastss 36(%[src]), %%ymm13\n"
"vbroadcastss 68(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 100(%[src]), %%ymm14\n"
"vbroadcastss 132(%[src]), %%ymm13\n"
"vbroadcastss 164(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 196(%[src]), %%ymm14\n"
"vbroadcastss 228(%[src]), %%ymm13\n"
"vbroadcastss 260(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
// block 5
"vmovaps 160(%[weight]), %%ymm15\n"
"vbroadcastss 5(%[src]), %%ymm14\n"
"vbroadcastss 37(%[src]), %%ymm13\n"
"vbroadcastss 69(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 101(%[src]), %%ymm14\n"
"vbroadcastss 133(%[src]), %%ymm13\n"
"vbroadcastss 165(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 197(%[src]), %%ymm14\n"
"vbroadcastss 229(%[src]), %%ymm13\n"
"vbroadcastss 261(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
// block 6
"vmovaps 192(%[weight]), %%ymm15\n"
"vbroadcastss 6(%[src]), %%ymm14\n"
"vbroadcastss 38(%[src]), %%ymm13\n"
"vbroadcastss 70(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 102(%[src]), %%ymm14\n"
"vbroadcastss 134(%[src]), %%ymm13\n"
"vbroadcastss 166(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 198(%[src]), %%ymm14\n"
"vbroadcastss 230(%[src]), %%ymm13\n"
"vbroadcastss 262(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
// block 7
"vmovaps 224(%[weight]), %%ymm15\n"
"vbroadcastss 7(%[src]), %%ymm14\n"
"vbroadcastss 39(%[src]), %%ymm13\n"
"vbroadcastss 71(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
"vbroadcastss 103(%[src]), %%ymm14\n"
"vbroadcastss 135(%[src]), %%ymm13\n"
"vbroadcastss 167(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
"vbroadcastss 199(%[src]), %%ymm14\n"
"vbroadcastss 231(%[src]), %%ymm13\n"
"vbroadcastss 263(%[src]), %%ymm12\n"
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
"dec %[deep]\n"
"add 256, %[weight]\n"
"add %[src_stride], %[src]\n"
"jg 0b\n"
"movq %[inc_flag], %%rax\n"
"and $0x2, %%eax\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm14\n"
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
"vminps %%ymm0, %%ymm14, %%ymm0\n"
"vminps %%ymm1, %%ymm14, %%ymm1\n"
"vminps %%ymm2, %%ymm14, %%ymm2\n"
"vminps %%ymm3, %%ymm14, %%ymm3\n"
"vminps %%ymm4, %%ymm14, %%ymm4\n"
"vminps %%ymm5, %%ymm14, %%ymm5\n"
"vminps %%ymm6, %%ymm14, %%ymm6\n"
"vminps %%ymm7, %%ymm14, %%ymm7\n"
"vminps %%ymm8, %%ymm14, %%ymm8\n"
"3:\n"
"vmovups %%ymm0, 0(%[dst])\n"
"vmovups %%ymm1, 32(%[dst])\n"
"vmovups %%ymm2, 64(%[dst])\n"
"vmovups %%ymm3, 96(%[dst])\n"
"vmovups %%ymm4, 128(%[dst])\n"
"vmovups %%ymm5, 160(%[dst])\n"
"vmovups %%ymm6, 192(%[dst])\n"
"vmovups %%ymm7, 224(%[dst])\n"
"vmovups %%ymm8, 256(%[dst])\n"
:
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}

View File

@ -0,0 +1,87 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# generate gemm fma asm code
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c
# generate gemm fma intrinics code
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c
# generate gemm avx512 asm code
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=5 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c

View File

@ -0,0 +1,148 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HPC generator"""
import sys
import os
import io
import argparse
from itertools import chain
def key_value_pair(line):
"""
split key and value
:param line:
:return:
"""
key, value = line.split("=", 1)
try:
value = int(value)
except ValueError:
print("Error: you input value must be integer, but now is ", value)
return key, value
def get_indent(line):
"""
get indent length
:param line:
:return:
"""
index = 0
for i in line:
if i == " ":
index += 1
else:
break
return index
def print_line(line):
"""
Convert line to a python string
:param line:
:return:
"""
global python_indent
global generate_code_indent
if line.strip()[0] == "}" or line.strip()[0] == ")":
python_indent = -1
split_str = line.split("@")
if line.strip()[0] != "@" and len(split_str) == 1:
if get_indent(line) == python_indent or python_indent == -1:
result = ["print(", line, ", file=OUT_STREAM)"]
python_indent = -1
if "{" in line or "asm volatile(" in line:
generate_code_indent = get_indent(line)
if line.strip().startswith("}") and "{" not in line:
generate_code_indent -= 4
if (len(line) == 1 and line[0] == "}"):
# modify next fun generate_code_indent
generate_code_indent = -4
return "\"".join(result)
if line.strip()[0] == '@':
# get python indent and first generate_code_indent
if python_indent == -1:
generate_code_indent = get_indent(line) - 4
python_indent = get_indent(line)
result = split_str[0][python_indent:] + split_str[1]
return result
index = get_indent(split_str[0])
result = [split_str[0][python_indent:index] + "print("]
Prefix = " " * (generate_code_indent + 4) + split_str[0].lstrip()
Suffix = " %("
for str_tmp in split_str[1:]:
second = str_tmp.find("}")
Suffix += str_tmp[1:second] + ', '
str_tmp = str_tmp.replace(str_tmp[0:second + 1], "%d")
Prefix += str_tmp
result.append(Prefix)
result.append(Suffix + "), file=OUT_STREAM)")
return "\"".join(result)
def generate_code(template_file, exec_dict):
"""
generate hpc
:param template_file: template file path
:param exec_dict: dict
:return: hpc
"""
output_stream = io.StringIO()
with open(template_file, 'r') as f:
generate_code_lines = []
for line in f:
line = line.replace("\n", "")
if line.strip() and line.strip()[0] != "@":
line = line.replace("\"", "\\\"")
if line.strip() and line.strip()[0] != "@":
line = line.replace("%", "%%")
if "print" in line:
line = line.replace("%%", "%")
if not line:
generate_code_lines.append("print(" + "\"" + line + "\"" + ", file=OUT_STREAM)")
else:
str = print_line(line)
if "%(" not in str:
str = str.replace("%%[", "%[")
generate_code_lines.append(str)
# print('\n'.join(generate_code_lines))
c = compile('\n'.join(generate_code_lines), '', 'exec')
exec_dict["OUT_STREAM"] = output_stream
exec(c, exec_dict)
return output_stream.getvalue()
generate_code_indent = -4
python_indent = -1
parser = argparse.ArgumentParser(description="MSLite NNACL Code Generator")
parser.add_argument("-I", dest="Template_File", nargs=1, help="template file to generate code")
parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=key_value_pair, action="append",
help="Custom Parameters")
parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path")
if __name__ == "__main__":
parameters = parser.parse_args(sys.argv[1:])
exec_globals = dict(chain(*parameters.defines))
generate_code_str = generate_code(parameters.Template_File[0], exec_globals)
if os.path.exists(parameters.Output_File[0]):
os.remove(parameters.Output_File[0])
saveDir = os.path.dirname(parameters.Output_File[0])
if not os.path.exists(saveDir):
os.mkdir(saveDir)
with open(parameters.Output_File[0], "w", encoding='utf-8') as output_file:
output_file.write(generate_code_str)

View File

@ -0,0 +1,169 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, const float *src, const float *weight,
const float *bias, const size_t act_flag, const size_t row_block,
const size_t col_block, const size_t deep, const size_t src_stride,
const size_t dst_stride, const size_t inc_flag) {
@asm_flag_list = []
@row_split_number = [row for row in range(3, row_block, 3)]
@for row in row_split_number:
const float *dst_@{row} = dst + @{row} * dst_stride;
@asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")");
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
@col_split_num = col_block >> 4;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\\n"
"je 0f\\n"
@for row in range(0, row_block):
@tmp = int(row / 3) * 3
@for col in range(0, col_split_num):
@if row % 3 == 0:
"vmovups @{col * 64}(%[dst_@{tmp}]), %%zmm@{row * col_split_num + col}\\n"
@else:
"vmovups @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}), %%zmm@{row * col_split_num + col}\\n"
"jmp 2f\\n"
"0:\\n"
"cmpq $0, %[bias]\\n"
"je 1f\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vmovaps @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n"
"jmp 2f\\n"
"1:\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n"
"2:\\n"
:
@list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"]
@list.extend(asm_flag_list)
@print(" : " + ", ".join(list), file=OUT_STREAM)
@print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM)
);
@for row in row_split_number:
const float *src_@{row} = src + @{row} * dst_stride;
@asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")");
asm volatile(
"0:\\n"
@loop_count = 8
@for i in range(0, loop_count):
// block @{i}
@for col in range(0, col_split_num):
"vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n"
@if col_split_num == 6:
@if row_block == 4:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
@else:
@for row in range(0, row_block):
@if row == 0:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
@elif col_split_num == 5:
@if row_block == 5:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src_3], %[src_stride], 1), %%zmm@{31 - col_split_num}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
@else:
@for row in range(0, row_block):
@if row == 0:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
@elif col_split_num == 2:
@for row in range(int(row_block / 6)):
@for j in range(0, 6):
@tmp = int(j / 3) * 3
@if j % 3 == 0:
"vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}]), %%zmm@{31 - col_split_num - j}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}], %[src_stride], @{j - tmp}), %%zmm@{31 - col_split_num - j}\\n"
@for col in range(0, col_split_num):
@for j in range(0, 6):
"fmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - j}, %%zmm@{(row * 6 + j) * col_split_num + col}\\n"
"dec %[deep]\\n"
"add $@{col_block * 4 * 8}, %[weight]\\n"
"add $@{loop_count * 4}, %[src_0]\\n"
@for row in row_split_number:
"add $@{loop_count * 4}, %[src_@{row}]\\n"
"jg 0b\\n"
"and $0x2, %[inc_flag]\\n"
"je 3f\\n"
"movq %[act_flag], %rax\\n"
"and $0x3, %eax\\n"
"je 3f\\n"
// relu
"vxorps %zmm31, %zmm31, %zmm31\\n"
@for col in range(0, col_split_num):
@for row in range(0, row_block):
"vmaxps %%zmm@{row + col * row_block}, %%zmm31, %%zmm@{row + col * row_block}\\n"
"and $0x1, %eax\\n"
"je 3f\\n"
// relu6
"mov $0x40C00000, %eax\\n"
"vmovd %eax, %xmm30\\n"
"vbroadcastss %xmm30, %zmm30\\n"
@for col in range(0, col_split_num):
@for row in range(0, row_block):
"vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n"
"3:\\n"
@for row in range(0, row_block):
@tmp = int(row / 3) * 3
@for col in range(0, col_split_num):
@if row % 3 == 0:
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}])\\n"
@else:
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}),\\n"
:
@list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"]
@list.extend(asm_flag_list)
@print(" : " + ", ".join(list), file=OUT_STREAM)
@print(" : \"%rax\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM)
);
}

View File

@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma intrinsic code
void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
const float *bias, const size_t act_flag, const size_t row_block,
const size_t col_block, const size_t deep, const size_t src_stride,
const size_t dst_stride, const size_t inc_flag) {
@for i in range(0, row_block):
@for j in range(0, col_block >> 3):
__m256 dst@{j * row_block + i};
if (inc_flag) {
@for i in range(0, row_block):
@for j in range(0, col_block >> 3):
dst@{j * row_block + i} = _mm256_load_ps(dst + @{j} * dst_stride + @{i * 8});
} else if (bias == NULL) {
@for i in range(0, row_block * col_block >> 3):
dst@{i} = _mm256_setzero_ps();
} else {
@for i in range(0, row_block):
@for j in range(0, col_block >> 3):
dst@{j * row_block + i} = _mm256_load_ps(bias + @{j * 8});
}
for (int i = 0; i < (deep >> 3); ++i) {
@for i in range(0, 8):
// bock@{i}
@if col_block == 32:
@for row in range(0, row_block):
__m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i}));
@for col in range(0, col_block >> 3):
__m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block});
@for row in range(0, row_block):
dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i});
@else:
@for col in range(0, col_block >> 3):
__m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block});
@for row in range(0, row_block):
__m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i}));
@for col in range(0, col_block >> 3):
dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i});
src = src + src_stride;
weight += @{8 * col_block * 4};
}
if (act_flag & 0x02) {
// relu6
__m256 relu6 = _mm256_set1_ps(6.0f);
__m256 relu = _mm256_setzero_ps();
@for i in range(0, row_block):
@for j in range(0, col_block >> 3):
dst@{i + j * row_block} = _mm256_min_ps(dst@{i + j * row_block}, relu6);
// relu
@for i in range(0, row_block):
@for j in range(0, col_block >> 3):
dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu);
}
if (act_flag & 0x01) {
// relu
__m256 relu = _mm256_setzero_ps();
@for i in range(0, row_block):
@for j in range(0, col_block >> 3):
dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu);
}
@if col_block == 32:
@for j in range(0, col_block >> 3):
@for i in range(0, row_block):
_mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i});
@else:
@for j in range(0, col_block >> 3):
@for i in range(0, row_block):
_mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i});
}

View File

@ -0,0 +1,149 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 fma asm code
void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
const float *bias, const size_t act_flag, const size_t row_block,
const size_t col_block, const size_t deep, const size_t src_stride,
const size_t dst_stride, const size_t inc_flag) {
@if col_block == 32:
const float *dst_4 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\\n"
"je 0f\\n"
@for col in range(0, min((col_block >> 3), 3)):
@for row in range(0, row_block):
@if col == 0:
"vmovups @{row * 32}(%[dst]), %%ymm@{row + col * row_block}\\n"
@else:
"vmovups @{row * 32}(%[dst], %[dst_stride], @{col}), %%ymm@{row + col * row_block}\\n"
@if col_block == 32:
@for row in range(0, row_block):
"vmovups @{row * 32}(%[dst_4]), %%ymm@{row + (col + 1) * row_block}\\n"
"jmp 2f\\n"
"0:\\n"
"cmpq $0, %[bias]\\n"
"je 1f\\n"
@for col in range(0, col_block >> 3):
@for row in range(0, row_block):
"vmovaps @{col * 32}(%[bias]), %%ymm@{row + col * row_block}\\n"
"jmp 2f\\n"
"1:\\n"
@for col in range(0, col_block >> 3):
@for row in range(0, row_block):
"vxorps %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}\\n"
"2:\\n"
:
@list = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"]
@if col_block == 32:
@list.append("[dst_4] \"r\"(dst_4)")
@print(" : " + ", ".join(list), file=OUT_STREAM)
@print(" : " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, row_block * col_block >> 3)]), file=OUT_STREAM)
);
asm volatile(
"0:\\n"
@for i in range(0, 8):
// block @{i}
@if col_block == 32:
@for row in range(0, row_block):
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - row}\\n"
@for col in range(0, col_block >> 3):
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - row_block}\\n"
@for row in range(0, row_block):
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - row_block}, %%ymm@{15 - row}\\n"
@elif col_block == 24:
@for col in range(0, col_block >> 3):
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
@for row in range(0, row_block):
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n"
@for col in range(0, col_block >> 3):
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n"
@elif col_block == 16:
@for col in range(0, col_block >> 3):
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
@for row in range(0, row_block >> 1):
"vbroadcastss @{row * 64 + i}(%[src]), %%ymm@{14 - col}\\n"
"vbroadcastss @{row * 64 + 32 + i}(%[src]), %%ymm@{13 - col}\\n"
@for col in range(0, col_block >> 3):
@for j in range(0, 2):
"vfmadd231ps %%ymm@{row * 2 + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n"
@for row in range(row_block >> 1 << 1, row_block):
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n"
@for col in range(0, col_block >> 3):
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n"
@else:
@for col in range(0, col_block >> 3):
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
@split_num = 3
@for row in range(0, int(row_block / split_num)):
@for j in range(0, split_num):
"vbroadcastss @{row * 96 + j * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - j}\\n"
@for col in range(0, col_block >> 3):
@for j in range(0, split_num):
"vfmadd231ps %%ymm@{row * split_num + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n"
@for row in range(int(row_block / split_num) * split_num, row_block):
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}\\n"
@for col in range(0, col_block >> 3):
@for row in range(int(row_block / split_num) * split_num, row_block):
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}, %%ymm@{15 - col}\\n"
"dec %[deep]\\n"
"add @{col_block * 4 * 8}, %[weight]\\n"
"add %[src_stride], %[src]\\n"
"jg 0b\\n"
"movq %[inc_flag], %rax\\n"
"and $0x2, %eax\\n"
"je 3f\\n"
"movq %[act_flag], %rax\\n"
"and $0x3, %eax\\n"
"je 3f\\n"
// relu
"vxorps %ymm15, %ymm15, %ymm15\\n"
@for col in range(0, col_block >> 3):
@for row in range(0, row_block):
"vmaxps %%ymm@{row + col * row_block}, %%ymm15, %%ymm@{row + col * row_block}\\n"
"and $0x1, %eax\\n"
"je 3f\\n"
// relu6
"mov $0x40C00000, %eax\\n"
"vmovd %eax, %xmm14\\n"
"vpermps %ymm14, %ymm15, %ymm14\\n"
@for col in range(0, col_block >> 3):
@for row in range(0, row_block):
"vminps %%ymm@{row + col * row_block}, %%ymm14, %%ymm@{row + col * row_block}\\n"
"3:\\n"
@for col in range(0, min((col_block >> 3), 3)):
@for row in range(0, row_block):
@if col == 0:
"vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst])\\n"
@else:
"vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst], %[dst_stride], @{col})\\n"
@if col_block == 32:
@for row in range(0, row_block):
"vmovups %%ymm@{row + (col + 1) * row_block}, @{row * 32}(%[dst_4])\\n"
:
@list = ["[src] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"]
@if col_block == 32:
@list.append("[dst_4] \"r\"(dst_4)")
@print(" : " + ", ".join(list), file=OUT_STREAM)
@print(" : \"%rax\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, 16)]), file=OUT_STREAM)
);
}