diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index e7f4e134b97..30a3ee48c0f 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -118,3 +118,7 @@ "mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py" "redefined-outer-name" "mindspore/tests/st/pynative/parser/test_parser_construct.py" "bad-super-call" "mindspore/tests/ut/python/optimizer/test_auto_grad.py" "broad-except" + +#MindSpore Lite +"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "redefined-builtin" +"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "exec-used" \ No newline at end of file diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index b8c0e8f12e3..8aa5770e181 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -96,3 +96,54 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/Tiled mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetScalarToProto +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32 
+mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32 
+mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32 
+mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_kernel_nhwc_fp32 
+mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32 diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..fead3fbf07f --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c @@ -0,0 +1,194 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 256(%[bias]), %%zmm4\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + 
"vbroadcastss 4(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, 
%%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + + "dec %[deep]\n" + "add $2560, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ 
src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..dfa732aedea --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c @@ -0,0 +1,216 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 256(%[bias]), %%zmm4\n" + "vmovaps 320(%[bias]), %%zmm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 1 + 
"vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps 
%%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + + "dec %[deep]\n" + "add $3072, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + 
"vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..49ee1fc9540 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c @@ -0,0 +1,272 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 256(%[bias]), %%zmm4\n" + "vmovaps 0(%[bias]), %%zmm5\n" + "vmovaps 64(%[bias]), %%zmm6\n" + "vmovaps 128(%[bias]), %%zmm7\n" + "vmovaps 192(%[bias]), %%zmm8\n" + "vmovaps 256(%[bias]), %%zmm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + 
"vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), 
%%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, 
%%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, 
%%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + + "dec %[deep]\n" + "add $2560, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ 
inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..a36d357f577 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c @@ -0,0 +1,308 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 avx512 asm code
+//
+// Computes one 2-row x 96-column fp32 output tile:
+//   dst[r][c] (+)= sum over k of src[r][k] * weight[k][c]
+//
+//   dst        output tile; row r starts dst_stride floats after row r - 1.
+//   src        input rows; row r starts src_stride floats after row r - 1, depth is contiguous.
+//   weight     96 consecutive floats per depth step.
+//   bias       optional 96 floats that seed both rows. It is read with vmovaps, so it must be
+//              64-byte aligned. When NULL the accumulators start from zero.
+//   act_flag   only bits 0-1 are read: 0 = no activation, non-zero = ReLU, bit 0 set = ReLU
+//              followed by a clamp to 6.0f (ReLU6).
+//   row_block, col_block   not read by this kernel.
+//   deep       depth; consumed in unrolled groups of 8 (deep >> 3 iterations). The loop body runs
+//              before the counter is tested, so one group of 8 is always processed.
+//   inc_flag   bit 0: start from the values already stored in dst instead of bias/zero.
+//              bit 1: apply the activation selected by act_flag before storing.
+//
+// Everything lives in ONE asm statement on purpose: the accumulators zmm0-zmm11 are carried from
+// the initialisation through the loop to the stores, and the compiler gives no guarantee that
+// vector registers survive between two separate asm statements.
+void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;  // row strides in bytes
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    // `test` instead of `and`: inc_flag is an input-only operand and must not be modified.
+    "test $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 192(%[dst_0]), %%zmm3\n"
+    "vmovups 256(%[dst_0]), %%zmm4\n"
+    "vmovups 320(%[dst_0]), %%zmm5\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
+    "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
+    "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%zmm0\n"
+    "vmovaps 64(%[bias]), %%zmm1\n"
+    "vmovaps 128(%[bias]), %%zmm2\n"
+    "vmovaps 192(%[bias]), %%zmm3\n"
+    "vmovaps 256(%[bias]), %%zmm4\n"
+    "vmovaps 320(%[bias]), %%zmm5\n"
+    "vmovaps 0(%[bias]), %%zmm6\n"
+    "vmovaps 64(%[bias]), %%zmm7\n"
+    "vmovaps 128(%[bias]), %%zmm8\n"
+    "vmovaps 192(%[bias]), %%zmm9\n"
+    "vmovaps 256(%[bias]), %%zmm10\n"
+    "vmovaps 320(%[bias]), %%zmm11\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    // accumulators are initialised; label 2 is also the head of the depth loop
+    "2:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vmovups 128(%[weight]), %%zmm29\n"
+    "vmovups 192(%[weight]), %%zmm28\n"
+    "vmovups 256(%[weight]), %%zmm27\n"
+    "vmovups 320(%[weight]), %%zmm26\n"
+    "vbroadcastss 0(%[src_0]), %%zmm25\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+    // block 1
+    "vmovups 384(%[weight]), %%zmm31\n"
+    "vmovups 448(%[weight]), %%zmm30\n"
+    "vmovups 512(%[weight]), %%zmm29\n"
+    "vmovups 576(%[weight]), %%zmm28\n"
+    "vmovups 640(%[weight]), %%zmm27\n"
+    "vmovups 704(%[weight]), %%zmm26\n"
+    "vbroadcastss 4(%[src_0]), %%zmm25\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+    // block 2
+    "vmovups 768(%[weight]), %%zmm31\n"
+    "vmovups 832(%[weight]), %%zmm30\n"
+    "vmovups 896(%[weight]), %%zmm29\n"
+    "vmovups 960(%[weight]), %%zmm28\n"
+    "vmovups 1024(%[weight]), %%zmm27\n"
+    "vmovups 1088(%[weight]), %%zmm26\n"
+    "vbroadcastss 8(%[src_0]), %%zmm25\n"
+    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+    // block 3
+    "vmovups 1152(%[weight]), %%zmm31\n"
+    "vmovups 1216(%[weight]), %%zmm30\n"
+    "vmovups 1280(%[weight]), %%zmm29\n"
+    "vmovups 1344(%[weight]), %%zmm28\n"
+    "vmovups 1408(%[weight]), %%zmm27\n"
+    "vmovups 1472(%[weight]), %%zmm26\n"
+    "vbroadcastss 12(%[src_0]), %%zmm25\n"
+    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+    // block 4
+    "vmovups 1536(%[weight]), %%zmm31\n"
+    "vmovups 1600(%[weight]), %%zmm30\n"
+    "vmovups 1664(%[weight]), %%zmm29\n"
+    "vmovups 1728(%[weight]), %%zmm28\n"
+    "vmovups 1792(%[weight]), %%zmm27\n"
+    "vmovups 1856(%[weight]), %%zmm26\n"
+    "vbroadcastss 16(%[src_0]), %%zmm25\n"
+    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+    // block 5
+    "vmovups 1920(%[weight]), %%zmm31\n"
+    "vmovups 1984(%[weight]), %%zmm30\n"
+    "vmovups 2048(%[weight]), %%zmm29\n"
+    "vmovups 2112(%[weight]), %%zmm28\n"
+    "vmovups 2176(%[weight]), %%zmm27\n"
+    "vmovups 2240(%[weight]), %%zmm26\n"
+    "vbroadcastss 20(%[src_0]), %%zmm25\n"
+    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+    // block 6
+    "vmovups 2304(%[weight]), %%zmm31\n"
+    "vmovups 2368(%[weight]), %%zmm30\n"
+    "vmovups 2432(%[weight]), %%zmm29\n"
+    "vmovups 2496(%[weight]), %%zmm28\n"
+    "vmovups 2560(%[weight]), %%zmm27\n"
+    "vmovups 2624(%[weight]), %%zmm26\n"
+    "vbroadcastss 24(%[src_0]), %%zmm25\n"
+    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+    // block 7
+    "vmovups 2688(%[weight]), %%zmm31\n"
+    "vmovups 2752(%[weight]), %%zmm30\n"
+    "vmovups 2816(%[weight]), %%zmm29\n"
+    "vmovups 2880(%[weight]), %%zmm28\n"
+    "vmovups 2944(%[weight]), %%zmm27\n"
+    "vmovups 3008(%[weight]), %%zmm26\n"
+    "vbroadcastss 28(%[src_0]), %%zmm25\n"
+    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
+    "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
+
+    // Advance the pointers first: `add` rewrites the flags, so `dec` has to be the last
+    // flag-setting instruction before `jg` for the branch to test the depth counter.
+    "add $3072, %[weight]\n"
+    "add $32, %[src_0]\n"
+    "dec %[deep]\n"
+    "jg 2b\n"
+
+    "test $0x2, %[inc_flag]\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
+    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
+    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
+    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
+    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
+    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
+    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
+    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
+    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
+    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
+    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
+    "vmaxps %%zmm10, %%zmm31, %%zmm10\n"
+    "vmaxps %%zmm11, %%zmm31, %%zmm11\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6 (0x40C00000 is the bit pattern of 6.0f)
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0\n"
+    "vminps %%zmm1, %%zmm30, %%zmm1\n"
+    "vminps %%zmm2, %%zmm30, %%zmm2\n"
+    "vminps %%zmm3, %%zmm30, %%zmm3\n"
+    "vminps %%zmm4, %%zmm30, %%zmm4\n"
+    "vminps %%zmm5, %%zmm30, %%zmm5\n"
+    "vminps %%zmm6, %%zmm30, %%zmm6\n"
+    "vminps %%zmm7, %%zmm30, %%zmm7\n"
+    "vminps %%zmm8, %%zmm30, %%zmm8\n"
+    "vminps %%zmm9, %%zmm30, %%zmm9\n"
+    "vminps %%zmm10, %%zmm30, %%zmm10\n"
+    "vminps %%zmm11, %%zmm30, %%zmm11\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0])\n"
+    "vmovups %%zmm1, 64(%[dst_0])\n"
+    "vmovups %%zmm2, 128(%[dst_0])\n"
+    "vmovups %%zmm3, 192(%[dst_0])\n"
+    "vmovups %%zmm4, 256(%[dst_0])\n"
+    "vmovups %%zmm5, 320(%[dst_0])\n"
+    "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n"
+    // src, weight and deep are advanced/decremented by the asm, so they are read-write
+    // (early-clobber) operands on the local copies; everything else is read-only.
+    : [ src_0 ] "+&r"(src), [ weight ] "+&r"(weight), [ deep ] "+&r"(deep_t)
+    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ src_stride ] "r"(src_stride_t),
+      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag)
+    : "%rax", "cc", "memory", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8",
+      "%zmm9", "%zmm10", "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19",
+      "%zmm20", "%zmm21", "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30",
+      "%zmm31");
+}
diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c
new file mode 100644
index 00000000000..5dbb4874098
--- /dev/null
+++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c
@@ -0,0 +1,351 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 avx512 asm code
+//
+// Computes one 3-row x 80-column fp32 output tile:
+//   dst[r][c] (+)= sum over k of src[r][k] * weight[k][c]
+//
+//   dst        output tile; row r starts dst_stride floats after row r - 1.
+//   src        input rows; row r starts src_stride floats after row r - 1, depth is contiguous.
+//   weight     80 consecutive floats per depth step.
+//   bias       optional 80 floats that seed all three rows. It is read with vmovaps, so it must be
+//              64-byte aligned. When NULL the accumulators start from zero.
+//   act_flag   only bits 0-1 are read: 0 = no activation, non-zero = ReLU, bit 0 set = ReLU
+//              followed by a clamp to 6.0f (ReLU6).
+//   row_block, col_block   not read by this kernel.
+//   deep       depth; consumed in unrolled groups of 8 (deep >> 3 iterations). The loop body runs
+//              before the counter is tested, so one group of 8 is always processed.
+//   inc_flag   bit 0: start from the values already stored in dst instead of bias/zero.
+//              bit 1: apply the activation selected by act_flag before storing.
+//
+// Everything lives in ONE asm statement on purpose: the accumulators zmm0-zmm14 are carried from
+// the initialisation through the loop to the stores, and the compiler gives no guarantee that
+// vector registers survive between two separate asm statements.
+void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                             const size_t act_flag, const size_t row_block, const size_t col_block,
+                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                             const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;  // row strides in bytes
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    // `test` instead of `and`: inc_flag is an input-only operand and must not be modified.
+    "test $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst_0]), %%zmm0\n"
+    "vmovups 64(%[dst_0]), %%zmm1\n"
+    "vmovups 128(%[dst_0]), %%zmm2\n"
+    "vmovups 192(%[dst_0]), %%zmm3\n"
+    "vmovups 256(%[dst_0]), %%zmm4\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
+    "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
+    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
+    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
+    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
+    "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
+    "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%zmm0\n"
+    "vmovaps 64(%[bias]), %%zmm1\n"
+    "vmovaps 128(%[bias]), %%zmm2\n"
+    "vmovaps 192(%[bias]), %%zmm3\n"
+    "vmovaps 256(%[bias]), %%zmm4\n"
+    "vmovaps 0(%[bias]), %%zmm5\n"
+    "vmovaps 64(%[bias]), %%zmm6\n"
+    "vmovaps 128(%[bias]), %%zmm7\n"
+    "vmovaps 192(%[bias]), %%zmm8\n"
+    "vmovaps 256(%[bias]), %%zmm9\n"
+    "vmovaps 0(%[bias]), %%zmm10\n"
+    "vmovaps 64(%[bias]), %%zmm11\n"
+    "vmovaps 128(%[bias]), %%zmm12\n"
+    "vmovaps 192(%[bias]), %%zmm13\n"
+    "vmovaps 256(%[bias]), %%zmm14\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
+    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
+    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
+    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
+    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
+    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
+    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
+    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
+    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
+    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
+    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
+    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
+    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
+    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
+    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
+    // accumulators are initialised; label 2 is also the head of the depth loop
+    "2:\n"
+    // block 0
+    "vmovups 0(%[weight]), %%zmm31\n"
+    "vmovups 64(%[weight]), %%zmm30\n"
+    "vmovups 128(%[weight]), %%zmm29\n"
+    "vmovups 192(%[weight]), %%zmm28\n"
+    "vmovups 256(%[weight]), %%zmm27\n"
+    "vbroadcastss 0(%[src_0]), %%zmm26\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+    // block 1
+    "vmovups 320(%[weight]), %%zmm31\n"
+    "vmovups 384(%[weight]), %%zmm30\n"
+    "vmovups 448(%[weight]), %%zmm29\n"
+    "vmovups 512(%[weight]), %%zmm28\n"
+    "vmovups 576(%[weight]), %%zmm27\n"
+    "vbroadcastss 4(%[src_0]), %%zmm26\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+    // block 2
+    "vmovups 640(%[weight]), %%zmm31\n"
+    "vmovups 704(%[weight]), %%zmm30\n"
+    "vmovups 768(%[weight]), %%zmm29\n"
+    "vmovups 832(%[weight]), %%zmm28\n"
+    "vmovups 896(%[weight]), %%zmm27\n"
+    "vbroadcastss 8(%[src_0]), %%zmm26\n"
+    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+    // block 3
+    "vmovups 960(%[weight]), %%zmm31\n"
+    "vmovups 1024(%[weight]), %%zmm30\n"
+    "vmovups 1088(%[weight]), %%zmm29\n"
+    "vmovups 1152(%[weight]), %%zmm28\n"
+    "vmovups 1216(%[weight]), %%zmm27\n"
+    "vbroadcastss 12(%[src_0]), %%zmm26\n"
+    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+    // block 4
+    "vmovups 1280(%[weight]), %%zmm31\n"
+    "vmovups 1344(%[weight]), %%zmm30\n"
+    "vmovups 1408(%[weight]), %%zmm29\n"
+    "vmovups 1472(%[weight]), %%zmm28\n"
+    "vmovups 1536(%[weight]), %%zmm27\n"
+    "vbroadcastss 16(%[src_0]), %%zmm26\n"
+    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+    // block 5
+    "vmovups 1600(%[weight]), %%zmm31\n"
+    "vmovups 1664(%[weight]), %%zmm30\n"
+    "vmovups 1728(%[weight]), %%zmm29\n"
+    "vmovups 1792(%[weight]), %%zmm28\n"
+    "vmovups 1856(%[weight]), %%zmm27\n"
+    "vbroadcastss 20(%[src_0]), %%zmm26\n"
+    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+    // block 6
+    "vmovups 1920(%[weight]), %%zmm31\n"
+    "vmovups 1984(%[weight]), %%zmm30\n"
+    "vmovups 2048(%[weight]), %%zmm29\n"
+    "vmovups 2112(%[weight]), %%zmm28\n"
+    "vmovups 2176(%[weight]), %%zmm27\n"
+    "vbroadcastss 24(%[src_0]), %%zmm26\n"
+    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+    // block 7
+    "vmovups 2240(%[weight]), %%zmm31\n"
+    "vmovups 2304(%[weight]), %%zmm30\n"
+    "vmovups 2368(%[weight]), %%zmm29\n"
+    "vmovups 2432(%[weight]), %%zmm28\n"
+    "vmovups 2496(%[weight]), %%zmm27\n"
+    "vbroadcastss 28(%[src_0]), %%zmm26\n"
+    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
+    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n"
+    "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
+    "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
+    "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
+    "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
+    "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
+    "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
+    "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
+    "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
+    "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
+    "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
+    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
+    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
+    "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
+    "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
+    "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
+
+    // Advance the pointers first: `add` rewrites the flags, so `dec` has to be the last
+    // flag-setting instruction before `jg` for the branch to test the depth counter.
+    "add $2560, %[weight]\n"
+    "add $32, %[src_0]\n"
+    "dec %[deep]\n"
+    "jg 2b\n"
+
+    "test $0x2, %[inc_flag]\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
+    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
+    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
+    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
+    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
+    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
+    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
+    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
+    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
+    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
+    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
+    "vmaxps %%zmm10, %%zmm31, %%zmm10\n"
+    "vmaxps %%zmm11, %%zmm31, %%zmm11\n"
+    "vmaxps %%zmm12, %%zmm31, %%zmm12\n"
+    "vmaxps %%zmm13, %%zmm31, %%zmm13\n"
+    "vmaxps %%zmm14, %%zmm31, %%zmm14\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6 (0x40C00000 is the bit pattern of 6.0f)
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm30\n"
+    "vbroadcastss %%xmm30, %%zmm30\n"
+    "vminps %%zmm0, %%zmm30, %%zmm0\n"
+    "vminps %%zmm1, %%zmm30, %%zmm1\n"
+    "vminps %%zmm2, %%zmm30, %%zmm2\n"
+    "vminps %%zmm3, %%zmm30, %%zmm3\n"
+    "vminps %%zmm4, %%zmm30, %%zmm4\n"
+    "vminps %%zmm5, %%zmm30, %%zmm5\n"
+    "vminps %%zmm6, %%zmm30, %%zmm6\n"
+    "vminps %%zmm7, %%zmm30, %%zmm7\n"
+    "vminps %%zmm8, %%zmm30, %%zmm8\n"
+    "vminps %%zmm9, %%zmm30, %%zmm9\n"
+    "vminps %%zmm10, %%zmm30, %%zmm10\n"
+    "vminps %%zmm11, %%zmm30, %%zmm11\n"
+    "vminps %%zmm12, %%zmm30, %%zmm12\n"
+    "vminps %%zmm13, %%zmm30, %%zmm13\n"
+    "vminps %%zmm14, %%zmm30, %%zmm14\n"
+    "3:\n"
+    "vmovups %%zmm0, 0(%[dst_0])\n"
+    "vmovups %%zmm1, 64(%[dst_0])\n"
+    "vmovups %%zmm2, 128(%[dst_0])\n"
+    "vmovups %%zmm3, 192(%[dst_0])\n"
+    "vmovups %%zmm4, 256(%[dst_0])\n"
+    "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n"
+    "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n"
+    "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n"
+    // src, weight and deep are advanced/decremented by the asm, so they are read-write
+    // (early-clobber) operands on the local copies; everything else is read-only.
+    : [ src_0 ] "+&r"(src), [ weight ] "+&r"(weight), [ deep ] "+&r"(deep_t)
+    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ src_stride ] "r"(src_stride_t),
+      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag)
+    : "%rax", "cc", "memory", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8",
+      "%zmm9", "%zmm10", "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19",
+      "%zmm20", "%zmm21", "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30",
+      "%zmm31");
+}
diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c
new file mode 100644
index 00000000000..d84aa83254b
--- /dev/null
+++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c
@@ -0,0 +1,401 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 256(%[bias]), %%zmm4\n" + "vmovaps 320(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "vmovaps 128(%[bias]), %%zmm8\n" + "vmovaps 192(%[bias]), %%zmm9\n" + "vmovaps 256(%[bias]), %%zmm10\n" + "vmovaps 320(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "vmovaps 128(%[bias]), %%zmm14\n" + "vmovaps 192(%[bias]), 
%%zmm15\n" + "vmovaps 256(%[bias]), %%zmm16\n" + "vmovaps 320(%[bias]), %%zmm17\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, 
%%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps 
%%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + 
"vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps 
%%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), %%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm27, 
%%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" + + "dec %[deep]\n" + "add $3072, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, 
%%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c new file mode 100644 index 
00000000000..d3e8bf83423 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c @@ -0,0 +1,434 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), 
%%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 256(%[bias]), %%zmm4\n" + "vmovaps 0(%[bias]), %%zmm5\n" + "vmovaps 64(%[bias]), %%zmm6\n" + "vmovaps 128(%[bias]), %%zmm7\n" + "vmovaps 192(%[bias]), %%zmm8\n" + "vmovaps 256(%[bias]), %%zmm9\n" + "vmovaps 0(%[bias]), %%zmm10\n" + "vmovaps 64(%[bias]), %%zmm11\n" + "vmovaps 128(%[bias]), %%zmm12\n" + "vmovaps 192(%[bias]), %%zmm13\n" + "vmovaps 256(%[bias]), %%zmm14\n" + "vmovaps 0(%[bias]), %%zmm15\n" + "vmovaps 64(%[bias]), %%zmm16\n" + "vmovaps 128(%[bias]), %%zmm17\n" + "vmovaps 192(%[bias]), %%zmm18\n" + "vmovaps 256(%[bias]), %%zmm19\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", 
"%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * dst_stride; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps 
%%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, 
%%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" 
+ "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1920(%[weight]), 
%%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_0], %[src_stride], 3), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps 
%%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" + + "dec %[deep]\n" + "add $2560, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, 
%%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git 
a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..d945400880a --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c @@ -0,0 +1,499 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 320(%[dst_0]), %%zmm5\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n" + 
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n" + "vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n" + "vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_3]), %%zmm18\n" + "vmovups 64(%[dst_3]), %%zmm19\n" + "vmovups 128(%[dst_3]), %%zmm20\n" + "vmovups 192(%[dst_3]), %%zmm21\n" + "vmovups 256(%[dst_3]), %%zmm22\n" + "vmovups 320(%[dst_3]), %%zmm23\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 256(%[bias]), %%zmm4\n" + "vmovaps 320(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "vmovaps 128(%[bias]), %%zmm8\n" + "vmovaps 192(%[bias]), %%zmm9\n" + "vmovaps 256(%[bias]), %%zmm10\n" + "vmovaps 320(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "vmovaps 128(%[bias]), %%zmm14\n" + "vmovaps 192(%[bias]), %%zmm15\n" + "vmovaps 256(%[bias]), %%zmm16\n" + "vmovaps 320(%[bias]), %%zmm17\n" + "vmovaps 0(%[bias]), %%zmm18\n" + "vmovaps 64(%[bias]), %%zmm19\n" + "vmovaps 128(%[bias]), %%zmm20\n" + "vmovaps 192(%[bias]), %%zmm21\n" + "vmovaps 256(%[bias]), %%zmm22\n" + "vmovaps 320(%[bias]), %%zmm23\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, 
%%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * dst_stride; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vmovups 320(%[weight]), %%zmm26\n" + "vbroadcastss 0(%[src_0]), %%zmm25\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps 
%%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vmovups 512(%[weight]), %%zmm29\n" + "vmovups 576(%[weight]), %%zmm28\n" + "vmovups 640(%[weight]), %%zmm27\n" + "vmovups 704(%[weight]), %%zmm26\n" + "vbroadcastss 4(%[src_0]), %%zmm25\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + 
"vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vmovups 1024(%[weight]), %%zmm27\n" + "vmovups 1088(%[weight]), %%zmm26\n" + "vbroadcastss 8(%[src_0]), %%zmm25\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 1152(%[weight]), %%zmm31\n" + "vmovups 1216(%[weight]), %%zmm30\n" + "vmovups 1280(%[weight]), %%zmm29\n" + "vmovups 1344(%[weight]), %%zmm28\n" + "vmovups 1408(%[weight]), %%zmm27\n" + "vmovups 1472(%[weight]), %%zmm26\n" + "vbroadcastss 12(%[src_0]), %%zmm25\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" 
+ "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vmovups 1792(%[weight]), %%zmm27\n" + "vmovups 1856(%[weight]), %%zmm26\n" + "vbroadcastss 16(%[src_0]), %%zmm25\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps 
%%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vmovups 2240(%[weight]), %%zmm26\n" + "vbroadcastss 20(%[src_0]), %%zmm25\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 2304(%[weight]), %%zmm31\n" + "vmovups 2368(%[weight]), %%zmm30\n" + "vmovups 2432(%[weight]), 
%%zmm29\n" + "vmovups 2496(%[weight]), %%zmm28\n" + "vmovups 2560(%[weight]), %%zmm27\n" + "vmovups 2624(%[weight]), %%zmm26\n" + "vbroadcastss 24(%[src_0]), %%zmm25\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 2688(%[weight]), %%zmm31\n" + "vmovups 2752(%[weight]), %%zmm30\n" + "vmovups 2816(%[weight]), %%zmm29\n" + "vmovups 2880(%[weight]), %%zmm28\n" + "vmovups 2944(%[weight]), %%zmm27\n" + "vmovups 3008(%[weight]), %%zmm26\n" + "vbroadcastss 28(%[src_0]), %%zmm25\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps 
%%zmm28, %%zmm25, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" + "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" + + "dec %[deep]\n" + "add $3072, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, 
%%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 320(%[dst_0])\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n" + 
"vmovups %%zmm18, 0(%[dst_3])\n" + "vmovups %%zmm19, 64(%[dst_3])\n" + "vmovups %%zmm20, 128(%[dst_3])\n" + "vmovups %%zmm21, 192(%[dst_3])\n" + "vmovups %%zmm22, 256(%[dst_3])\n" + "vmovups %%zmm23, 320(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..9a0c58579b4 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c @@ -0,0 +1,513 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 256(%[dst_0]), %%zmm4\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n" + "vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n" + "vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n" + "vmovups 0(%[dst_3]), %%zmm15\n" + "vmovups 64(%[dst_3]), %%zmm16\n" + "vmovups 128(%[dst_3]), %%zmm17\n" + "vmovups 192(%[dst_3]), %%zmm18\n" + "vmovups 256(%[dst_3]), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n" + "vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 256(%[bias]), %%zmm4\n" + "vmovaps 0(%[bias]), 
%%zmm5\n" + "vmovaps 64(%[bias]), %%zmm6\n" + "vmovaps 128(%[bias]), %%zmm7\n" + "vmovaps 192(%[bias]), %%zmm8\n" + "vmovaps 256(%[bias]), %%zmm9\n" + "vmovaps 0(%[bias]), %%zmm10\n" + "vmovaps 64(%[bias]), %%zmm11\n" + "vmovaps 128(%[bias]), %%zmm12\n" + "vmovaps 192(%[bias]), %%zmm13\n" + "vmovaps 256(%[bias]), %%zmm14\n" + "vmovaps 0(%[bias]), %%zmm15\n" + "vmovaps 64(%[bias]), %%zmm16\n" + "vmovaps 128(%[bias]), %%zmm17\n" + "vmovaps 192(%[bias]), %%zmm18\n" + "vmovaps 256(%[bias]), %%zmm19\n" + "vmovaps 0(%[bias]), %%zmm20\n" + "vmovaps 64(%[bias]), %%zmm21\n" + "vmovaps 128(%[bias]), %%zmm22\n" + "vmovaps 192(%[bias]), %%zmm23\n" + "vmovaps 256(%[bias]), %%zmm24\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "vxorps %%zmm24, %%zmm24, %%zmm24\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", 
"%zmm21", "%zmm22", + "%zmm23", "%zmm24"); + const float *src_3 = src + 3 * dst_stride; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vmovups 256(%[weight]), %%zmm27\n" + "vbroadcastss 0(%[src_0]), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 0(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + // block 1 + "vmovups 320(%[weight]), %%zmm31\n" + "vmovups 384(%[weight]), %%zmm30\n" + "vmovups 448(%[weight]), %%zmm29\n" + "vmovups 512(%[weight]), %%zmm28\n" + "vmovups 576(%[weight]), %%zmm27\n" + "vbroadcastss 4(%[src_0]), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 4(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + // block 2 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vmovups 768(%[weight]), %%zmm29\n" + "vmovups 832(%[weight]), %%zmm28\n" + "vmovups 896(%[weight]), %%zmm27\n" + "vbroadcastss 8(%[src_0]), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), 
%%zmm26\n" + "vbroadcastss 8(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + // block 3 + "vmovups 960(%[weight]), %%zmm31\n" + "vmovups 1024(%[weight]), %%zmm30\n" + "vmovups 1088(%[weight]), %%zmm29\n" + "vmovups 1152(%[weight]), %%zmm28\n" + "vmovups 1216(%[weight]), %%zmm27\n" + "vbroadcastss 12(%[src_0]), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 12(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, 
%%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + // block 4 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vmovups 1536(%[weight]), %%zmm27\n" + "vbroadcastss 16(%[src_0]), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 16(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + // block 5 + "vmovups 1600(%[weight]), %%zmm31\n" + "vmovups 1664(%[weight]), %%zmm30\n" + "vmovups 
1728(%[weight]), %%zmm29\n" + "vmovups 1792(%[weight]), %%zmm28\n" + "vmovups 1856(%[weight]), %%zmm27\n" + "vbroadcastss 20(%[src_0]), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 20(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + // block 6 + "vmovups 1920(%[weight]), %%zmm31\n" + "vmovups 1984(%[weight]), %%zmm30\n" + "vmovups 2048(%[weight]), %%zmm29\n" + "vmovups 2112(%[weight]), %%zmm28\n" + "vmovups 2176(%[weight]), %%zmm27\n" + "vbroadcastss 24(%[src_0]), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps 
%%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 24(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + // block 7 + "vmovups 2240(%[weight]), %%zmm31\n" + "vmovups 2304(%[weight]), %%zmm30\n" + "vmovups 2368(%[weight]), %%zmm29\n" + "vmovups 2432(%[weight]), %%zmm28\n" + "vmovups 2496(%[weight]), %%zmm27\n" + "vbroadcastss 28(%[src_0]), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" + "vbroadcastss 28(%[src_3]), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + 
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + + "dec %[deep]\n" + "add $2560, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "vmaxps %%zmm24, %%zmm31, %%zmm24\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, 
%%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "vminps %%zmm24, %%zmm30, %%zmm24\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 256(%[dst_0])\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm15, 0(%[dst_3])\n" + "vmovups %%zmm16, 64(%[dst_3])\n" + "vmovups %%zmm17, 128(%[dst_3])\n" + "vmovups %%zmm18, 192(%[dst_3])\n" + "vmovups %%zmm19, 256(%[dst_3])\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1),\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1),\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1),\n" + "vmovups 
%%zmm23, 192(%[dst_3], %[dst_stride], 1),\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1),\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..1603526877f --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#include <stddef.h>
#include <x86intrin.h>

/**
 * nnacl GEMM micro-kernel, x86 FMA intrinsic version: 10 rows x 8 columns, fp32.
 *
 * Computes, for r in [0, 10) and c in [0, 8):
 *   dst[8 * r + c] = init[r][c] + sum_k src[8 * r + k] * weight[8 * k + c]
 * where k runs over the processed depth, followed by an optional activation.
 *
 * @param dst        Output tile, 80 floats (row r starts at dst + 8 * r). Read as the
 *                   initial accumulator value when inc_flag is non-zero. Accessed with
 *                   aligned load/store intrinsics, so it must be 32-byte aligned.
 * @param src        Input tile. Within one 8-deep block, element k of row r is
 *                   src[8 * r + k]; consecutive blocks are src_stride floats apart.
 * @param weight     Packed weights, 64 floats per 8-deep block (8 floats per k step),
 *                   32-byte aligned.
 * @param bias       8 floats (32-byte aligned) used to initialise every row when
 *                   inc_flag is zero, or NULL to start from zero.
 * @param act_flag   Bit 1 (0x02): clamp the result to [0, 6]. Bit 0 (0x01): clamp below at 0.
 * @param row_block  Unused by this fixed-size kernel.
 * @param col_block  Unused by this fixed-size kernel.
 * @param deep       Reduction depth in floats; only deep >> 3 full blocks of 8 are
 *                   processed, any remainder is ignored.
 * @param src_stride Distance in floats between consecutive 8-deep blocks of src.
 * @param dst_stride Unused: the kernel produces a single 8-wide column block.
 * @param inc_flag   Non-zero: accumulate on top of the current contents of dst.
 *
 * NOTE(review): the inline-asm implementation of this same kernel tests individual
 * bits of inc_flag (bit 0 = accumulate, bit 1 = apply activation) and maps the
 * act_flag bits differently. The flag handling below is kept exactly as it was;
 * confirm against the callers which convention is intended.
 */
void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  (void)row_block;
  (void)col_block;
  (void)dst_stride;

  // One accumulator register per output row; kept as scalars so they stay in registers.
  __m256 dst0;
  __m256 dst1;
  __m256 dst2;
  __m256 dst3;
  __m256 dst4;
  __m256 dst5;
  __m256 dst6;
  __m256 dst7;
  __m256 dst8;
  __m256 dst9;
  if (inc_flag) {
    // Continue a partial sum that a previous call left in dst.
    dst0 = _mm256_load_ps(dst + 0);
    dst1 = _mm256_load_ps(dst + 8);
    dst2 = _mm256_load_ps(dst + 16);
    dst3 = _mm256_load_ps(dst + 24);
    dst4 = _mm256_load_ps(dst + 32);
    dst5 = _mm256_load_ps(dst + 40);
    dst6 = _mm256_load_ps(dst + 48);
    dst7 = _mm256_load_ps(dst + 56);
    dst8 = _mm256_load_ps(dst + 64);
    dst9 = _mm256_load_ps(dst + 72);
  } else if (bias == NULL) {
    dst0 = _mm256_setzero_ps();
    dst1 = _mm256_setzero_ps();
    dst2 = _mm256_setzero_ps();
    dst3 = _mm256_setzero_ps();
    dst4 = _mm256_setzero_ps();
    dst5 = _mm256_setzero_ps();
    dst6 = _mm256_setzero_ps();
    dst7 = _mm256_setzero_ps();
    dst8 = _mm256_setzero_ps();
    dst9 = _mm256_setzero_ps();
  } else {
    // The same 8 bias values seed every row.
    const __m256 bias_v = _mm256_load_ps(bias);
    dst0 = bias_v;
    dst1 = bias_v;
    dst2 = bias_v;
    dst3 = bias_v;
    dst4 = bias_v;
    dst5 = bias_v;
    dst6 = bias_v;
    dst7 = bias_v;
    dst8 = bias_v;
    dst9 = bias_v;
  }

  const size_t deep_blocks = deep >> 3;
  for (size_t i = 0; i < deep_blocks; ++i) {
    // One 8-deep block: for each k, broadcast src[row][k] and multiply by weight row k.
    for (size_t k = 0; k < 8; ++k) {
      const __m256 w = _mm256_load_ps(weight + 8 * k);
      // _mm256_fmadd_ps(a, b, c) returns a * b + c, so the accumulator must be the LAST argument.
      dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[0 + k]), w, dst0);
      dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[8 + k]), w, dst1);
      dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[16 + k]), w, dst2);
      dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[24 + k]), w, dst3);
      dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[32 + k]), w, dst4);
      dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[40 + k]), w, dst5);
      dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[48 + k]), w, dst6);
      dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[56 + k]), w, dst7);
      dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[64 + k]), w, dst8);
      dst9 = _mm256_fmadd_ps(_mm256_set1_ps(src[72 + k]), w, dst9);
    }
    src += src_stride;
    // 8 k-steps x 8 floats were consumed; the pointer is in floats, not bytes.
    weight += 64;
  }

  if (act_flag & 0x02) {
    // relu6: clamp to [0, 6]
    const __m256 relu6 = _mm256_set1_ps(6.0f);
    const __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_max_ps(_mm256_min_ps(dst0, relu6), relu);
    dst1 = _mm256_max_ps(_mm256_min_ps(dst1, relu6), relu);
    dst2 = _mm256_max_ps(_mm256_min_ps(dst2, relu6), relu);
    dst3 = _mm256_max_ps(_mm256_min_ps(dst3, relu6), relu);
    dst4 = _mm256_max_ps(_mm256_min_ps(dst4, relu6), relu);
    dst5 = _mm256_max_ps(_mm256_min_ps(dst5, relu6), relu);
    dst6 = _mm256_max_ps(_mm256_min_ps(dst6, relu6), relu);
    dst7 = _mm256_max_ps(_mm256_min_ps(dst7, relu6), relu);
    dst8 = _mm256_max_ps(_mm256_min_ps(dst8, relu6), relu);
    dst9 = _mm256_max_ps(_mm256_min_ps(dst9, relu6), relu);
  }
  if (act_flag & 0x01) {
    // relu: clamp below at 0
    const __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_max_ps(dst0, relu);
    dst1 = _mm256_max_ps(dst1, relu);
    dst2 = _mm256_max_ps(dst2, relu);
    dst3 = _mm256_max_ps(dst3, relu);
    dst4 = _mm256_max_ps(dst4, relu);
    dst5 = _mm256_max_ps(dst5, relu);
    dst6 = _mm256_max_ps(dst6, relu);
    dst7 = _mm256_max_ps(dst7, relu);
    dst8 = _mm256_max_ps(dst8, relu);
    dst9 = _mm256_max_ps(dst9, relu);
  }

  // Write back to the same offsets the accumulators are loaded from.
  _mm256_store_ps(dst + 0, dst0);
  _mm256_store_ps(dst + 8, dst1);
  _mm256_store_ps(dst + 16, dst2);
  _mm256_store_ps(dst + 24, dst3);
  _mm256_store_ps(dst + 32, dst4);
  _mm256_store_ps(dst + 40, dst5);
  _mm256_store_ps(dst + 48, dst6);
  _mm256_store_ps(dst + 56, dst7);
  _mm256_store_ps(dst + 64, dst8);
  _mm256_store_ps(dst + 72, dst9);
}
/dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,303 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "vmovups 288(%[dst]), %%ymm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "vmovaps 0(%[bias]), %%ymm9\n" + "jmp 
2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 288(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps 
%%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 289(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 290(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps 
%%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 291(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 292(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 293(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + 
"vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 294(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "vbroadcastss 295(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, 
%%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + "vmovups %%ymm9, 288(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..dcc963de7ef --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,321 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the 
#include <stddef.h>
#include <immintrin.h>

// nnacl gemm in x86 fma intrinsic code
// (added by this patch as gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c)

/**
 * 11x8 FMA GEMM micro-kernel: dst[r][c] = init[r][c] + sum_k src[r][k] * weight[k][c].
 *
 * Layout, as fixed by the addressing in the body:
 *  - dst:    11 rows of 8 floats, row r at dst + 8 * r.
 *  - src:    within one group of 8 depth values, row r / depth k is src[8 * r + k];
 *            the next group starts src_stride floats later.
 *  - weight: within one group of 8 depth values, depth k / column c is weight[8 * k + c];
 *            groups are contiguous (64 floats each).
 *  - bias:   8 floats (one per column), or NULL.
 * dst, weight and bias are accessed with aligned 256-bit loads/stores, so they must be
 * 32-byte aligned.
 *
 * @param dst        output tile; also the initial accumulator value when inc_flag is non-zero.
 * @param src        left-hand operand, see layout above.
 * @param weight     right-hand operand, see layout above.
 * @param bias       per-column start value used when inc_flag == 0; NULL means start from 0.
 * @param act_flag   bit 0x02: clamp the result to [0, 6]; bit 0x01: clamp the result to >= 0.
 * @param row_block  unused by this fixed-size kernel.
 * @param col_block  unused by this fixed-size kernel.
 * @param deep       depth; only whole groups of 8 are processed (deep >> 3 groups).
 * @param src_stride distance in floats between consecutive depth groups of src.
 * @param dst_stride unused: the tile has a single 8-column block.
 * @param inc_flag   non-zero: accumulate onto the current contents of dst.
 *
 * NOTE(review): the inline-asm variant of this kernel in the same patch decodes the flags
 * differently (activation only when inc_flag bit 1 is set; any of act_flag 0x3 -> relu,
 * act_flag bit 0 -> relu6). Confirm which encoding the callers use before unifying them.
 */
void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  __m256 dst0;
  __m256 dst1;
  __m256 dst2;
  __m256 dst3;
  __m256 dst4;
  __m256 dst5;
  __m256 dst6;
  __m256 dst7;
  __m256 dst8;
  __m256 dst9;
  __m256 dst10;
  (void)row_block;
  (void)col_block;
  (void)dst_stride;

  if (inc_flag) {
    // Continue a previous partial sum that was stored in dst.
    dst0 = _mm256_load_ps(dst + 0);
    dst1 = _mm256_load_ps(dst + 8);
    dst2 = _mm256_load_ps(dst + 16);
    dst3 = _mm256_load_ps(dst + 24);
    dst4 = _mm256_load_ps(dst + 32);
    dst5 = _mm256_load_ps(dst + 40);
    dst6 = _mm256_load_ps(dst + 48);
    dst7 = _mm256_load_ps(dst + 56);
    dst8 = _mm256_load_ps(dst + 64);
    dst9 = _mm256_load_ps(dst + 72);
    dst10 = _mm256_load_ps(dst + 80);
  } else if (bias == NULL) {
    dst0 = _mm256_setzero_ps();
    dst1 = _mm256_setzero_ps();
    dst2 = _mm256_setzero_ps();
    dst3 = _mm256_setzero_ps();
    dst4 = _mm256_setzero_ps();
    dst5 = _mm256_setzero_ps();
    dst6 = _mm256_setzero_ps();
    dst7 = _mm256_setzero_ps();
    dst8 = _mm256_setzero_ps();
    dst9 = _mm256_setzero_ps();
    dst10 = _mm256_setzero_ps();
  } else {
    // Every row starts from the same 8 per-column bias values.
    dst0 = _mm256_load_ps(bias);
    dst1 = dst0;
    dst2 = dst0;
    dst3 = dst0;
    dst4 = dst0;
    dst5 = dst0;
    dst6 = dst0;
    dst7 = dst0;
    dst8 = dst0;
    dst9 = dst0;
    dst10 = dst0;
  }

// One depth step k (0..7) of the current group: load weight row k once, then for each of the
// 11 rows broadcast src[row][k] and accumulate. _mm256_fmadd_ps(a, b, c) returns a * b + c,
// so the accumulator has to be the THIRD argument.
#define NNACL_GEMM_FMA_11X8_STEP(k)                                   \
  do {                                                                \
    const __m256 w = _mm256_load_ps(weight + 8 * (k));                \
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[0 + (k)]), w, dst0);    \
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[8 + (k)]), w, dst1);    \
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[16 + (k)]), w, dst2);   \
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[24 + (k)]), w, dst3);   \
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[32 + (k)]), w, dst4);   \
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[40 + (k)]), w, dst5);   \
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[48 + (k)]), w, dst6);   \
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[56 + (k)]), w, dst7);   \
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[64 + (k)]), w, dst8);   \
    dst9 = _mm256_fmadd_ps(_mm256_set1_ps(src[72 + (k)]), w, dst9);   \
    dst10 = _mm256_fmadd_ps(_mm256_set1_ps(src[80 + (k)]), w, dst10); \
  } while (0)

  const size_t deep_blocks = deep >> 3;
  for (size_t i = 0; i < deep_blocks; ++i) {
    NNACL_GEMM_FMA_11X8_STEP(0);
    NNACL_GEMM_FMA_11X8_STEP(1);
    NNACL_GEMM_FMA_11X8_STEP(2);
    NNACL_GEMM_FMA_11X8_STEP(3);
    NNACL_GEMM_FMA_11X8_STEP(4);
    NNACL_GEMM_FMA_11X8_STEP(5);
    NNACL_GEMM_FMA_11X8_STEP(6);
    NNACL_GEMM_FMA_11X8_STEP(7);
    src += src_stride;
    // 8 depth steps x 8 columns = 64 floats were consumed (256 is the size in bytes, not floats).
    weight += 64;
  }
#undef NNACL_GEMM_FMA_11X8_STEP

  if (act_flag & 0x02) {
    // relu6: upper clamp; the lower clamp is shared with relu below.
    const __m256 relu6 = _mm256_set1_ps(6.0f);
    dst0 = _mm256_min_ps(dst0, relu6);
    dst1 = _mm256_min_ps(dst1, relu6);
    dst2 = _mm256_min_ps(dst2, relu6);
    dst3 = _mm256_min_ps(dst3, relu6);
    dst4 = _mm256_min_ps(dst4, relu6);
    dst5 = _mm256_min_ps(dst5, relu6);
    dst6 = _mm256_min_ps(dst6, relu6);
    dst7 = _mm256_min_ps(dst7, relu6);
    dst8 = _mm256_min_ps(dst8, relu6);
    dst9 = _mm256_min_ps(dst9, relu6);
    dst10 = _mm256_min_ps(dst10, relu6);
  }
  if (act_flag & 0x03) {
    // relu (also the lower bound of relu6)
    const __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_max_ps(dst0, relu);
    dst1 = _mm256_max_ps(dst1, relu);
    dst2 = _mm256_max_ps(dst2, relu);
    dst3 = _mm256_max_ps(dst3, relu);
    dst4 = _mm256_max_ps(dst4, relu);
    dst5 = _mm256_max_ps(dst5, relu);
    dst6 = _mm256_max_ps(dst6, relu);
    dst7 = _mm256_max_ps(dst7, relu);
    dst8 = _mm256_max_ps(dst8, relu);
    dst9 = _mm256_max_ps(dst9, relu);
    dst10 = _mm256_max_ps(dst10, relu);
  }

  _mm256_store_ps(dst + 0, dst0);
  _mm256_store_ps(dst + 8, dst1);
  _mm256_store_ps(dst + 16, dst2);
  _mm256_store_ps(dst + 24, dst3);
  _mm256_store_ps(dst + 32, dst4);
  _mm256_store_ps(dst + 40, dst5);
  _mm256_store_ps(dst + 48, dst6);
  _mm256_store_ps(dst + 56, dst7);
  _mm256_store_ps(dst + 64, dst8);
  _mm256_store_ps(dst + 72, dst9);
  _mm256_store_ps(dst + 80, dst10);
}
#include <stddef.h>

// nnacl gemm in x86 fma asm code
// (added by this patch as gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c)

/**
 * 11x8 FMA GEMM micro-kernel (x86-64 inline asm):
 * dst[r][c] = init[r][c] + sum_k src[r][k] * weight[k][c].
 *
 * Layout, as fixed by the displacements in the asm:
 *  - dst:    11 rows of 8 floats, row r at byte offset 32 * r (unaligned accesses).
 *  - src:    within one group of 8 depth values, row r / depth k is at byte offset 32 * r + 4 * k;
 *            the next group starts src_stride floats later.
 *  - weight: depth k of a group at byte offset 32 * k, 256 bytes per group (vmovaps: 32-byte aligned).
 *  - bias:   8 floats or NULL (vmovaps: 32-byte aligned).
 *
 * @param act_flag   tested only when inc_flag bit 1 is set: any of bits 0x3 clamps the result to >= 0,
 *                   bit 0x1 additionally clamps it to <= 6.
 * @param row_block  unused by this fixed-size kernel.
 * @param col_block  unused by this fixed-size kernel.
 * @param deep       depth; only whole groups of 8 are processed (deep >> 3 groups, possibly none).
 * @param src_stride distance in floats between consecutive depth groups of src.
 * @param dst_stride unused: the tile has a single 8-column block.
 * @param inc_flag   bit 0: accumulate onto the current contents of dst (otherwise start from bias, or
 *                   from 0 when bias is NULL); bit 1: apply the activation selected by act_flag.
 *
 * NOTE(review): the intrinsic variant of this kernel in the same patch decodes act_flag/inc_flag
 * differently; confirm which encoding the callers use before unifying them.
 *
 * Everything is a single asm statement: ymm0..ymm10 hold the accumulators from initialisation to
 * the final store, and the compiler gives no guarantee about vector registers between two
 * separate asm statements.
 */
void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  size_t deep_t = deep >> 3;
  size_t src_stride_t = src_stride << 2;  // floats -> bytes
  (void)row_block;
  (void)col_block;
  (void)dst_stride;
  __asm__ __volatile__(
    // inc in deep: test a copy so the input operand itself is never modified
    "movq %[inc_flag], %%rax\n"
    "and $0x1, %%eax\n"
    "je 0f\n"
    "vmovups 0(%[dst]), %%ymm0\n"
    "vmovups 32(%[dst]), %%ymm1\n"
    "vmovups 64(%[dst]), %%ymm2\n"
    "vmovups 96(%[dst]), %%ymm3\n"
    "vmovups 128(%[dst]), %%ymm4\n"
    "vmovups 160(%[dst]), %%ymm5\n"
    "vmovups 192(%[dst]), %%ymm6\n"
    "vmovups 224(%[dst]), %%ymm7\n"
    "vmovups 256(%[dst]), %%ymm8\n"
    "vmovups 288(%[dst]), %%ymm9\n"
    "vmovups 320(%[dst]), %%ymm10\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%ymm0\n"
    "vmovaps 0(%[bias]), %%ymm1\n"
    "vmovaps 0(%[bias]), %%ymm2\n"
    "vmovaps 0(%[bias]), %%ymm3\n"
    "vmovaps 0(%[bias]), %%ymm4\n"
    "vmovaps 0(%[bias]), %%ymm5\n"
    "vmovaps 0(%[bias]), %%ymm6\n"
    "vmovaps 0(%[bias]), %%ymm7\n"
    "vmovaps 0(%[bias]), %%ymm8\n"
    "vmovaps 0(%[bias]), %%ymm9\n"
    "vmovaps 0(%[bias]), %%ymm10\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "2:\n"
    // nothing to accumulate when there is no complete group of 8 depth values
    "test %[deep], %[deep]\n"
    "je 5f\n"
    "4:\n"
    // In AT&T syntax the destination is the LAST operand:
    // "vfmadd231ps w, s, acc" computes acc = s * w + acc.
    // block 0
    "vmovaps 0(%[weight]), %%ymm15\n"
    "vbroadcastss 0(%[src]), %%ymm14\n"
    "vbroadcastss 32(%[src]), %%ymm13\n"
    "vbroadcastss 64(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 96(%[src]), %%ymm14\n"
    "vbroadcastss 128(%[src]), %%ymm13\n"
    "vbroadcastss 160(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 192(%[src]), %%ymm14\n"
    "vbroadcastss 224(%[src]), %%ymm13\n"
    "vbroadcastss 256(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 288(%[src]), %%ymm14\n"
    "vbroadcastss 320(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // block 1
    "vmovaps 32(%[weight]), %%ymm15\n"
    "vbroadcastss 4(%[src]), %%ymm14\n"
    "vbroadcastss 36(%[src]), %%ymm13\n"
    "vbroadcastss 68(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 100(%[src]), %%ymm14\n"
    "vbroadcastss 132(%[src]), %%ymm13\n"
    "vbroadcastss 164(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 196(%[src]), %%ymm14\n"
    "vbroadcastss 228(%[src]), %%ymm13\n"
    "vbroadcastss 260(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 292(%[src]), %%ymm14\n"
    "vbroadcastss 324(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // block 2
    "vmovaps 64(%[weight]), %%ymm15\n"
    "vbroadcastss 8(%[src]), %%ymm14\n"
    "vbroadcastss 40(%[src]), %%ymm13\n"
    "vbroadcastss 72(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 104(%[src]), %%ymm14\n"
    "vbroadcastss 136(%[src]), %%ymm13\n"
    "vbroadcastss 168(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 200(%[src]), %%ymm14\n"
    "vbroadcastss 232(%[src]), %%ymm13\n"
    "vbroadcastss 264(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 296(%[src]), %%ymm14\n"
    "vbroadcastss 328(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // block 3
    "vmovaps 96(%[weight]), %%ymm15\n"
    "vbroadcastss 12(%[src]), %%ymm14\n"
    "vbroadcastss 44(%[src]), %%ymm13\n"
    "vbroadcastss 76(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 108(%[src]), %%ymm14\n"
    "vbroadcastss 140(%[src]), %%ymm13\n"
    "vbroadcastss 172(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 204(%[src]), %%ymm14\n"
    "vbroadcastss 236(%[src]), %%ymm13\n"
    "vbroadcastss 268(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 300(%[src]), %%ymm14\n"
    "vbroadcastss 332(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // block 4
    "vmovaps 128(%[weight]), %%ymm15\n"
    "vbroadcastss 16(%[src]), %%ymm14\n"
    "vbroadcastss 48(%[src]), %%ymm13\n"
    "vbroadcastss 80(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 112(%[src]), %%ymm14\n"
    "vbroadcastss 144(%[src]), %%ymm13\n"
    "vbroadcastss 176(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 208(%[src]), %%ymm14\n"
    "vbroadcastss 240(%[src]), %%ymm13\n"
    "vbroadcastss 272(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 304(%[src]), %%ymm14\n"
    "vbroadcastss 336(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // block 5
    "vmovaps 160(%[weight]), %%ymm15\n"
    "vbroadcastss 20(%[src]), %%ymm14\n"
    "vbroadcastss 52(%[src]), %%ymm13\n"
    "vbroadcastss 84(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 116(%[src]), %%ymm14\n"
    "vbroadcastss 148(%[src]), %%ymm13\n"
    "vbroadcastss 180(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 212(%[src]), %%ymm14\n"
    "vbroadcastss 244(%[src]), %%ymm13\n"
    "vbroadcastss 276(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 308(%[src]), %%ymm14\n"
    "vbroadcastss 340(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // block 6
    "vmovaps 192(%[weight]), %%ymm15\n"
    "vbroadcastss 24(%[src]), %%ymm14\n"
    "vbroadcastss 56(%[src]), %%ymm13\n"
    "vbroadcastss 88(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 120(%[src]), %%ymm14\n"
    "vbroadcastss 152(%[src]), %%ymm13\n"
    "vbroadcastss 184(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 216(%[src]), %%ymm14\n"
    "vbroadcastss 248(%[src]), %%ymm13\n"
    "vbroadcastss 280(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 312(%[src]), %%ymm14\n"
    "vbroadcastss 344(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // block 7
    "vmovaps 224(%[weight]), %%ymm15\n"
    "vbroadcastss 28(%[src]), %%ymm14\n"
    "vbroadcastss 60(%[src]), %%ymm13\n"
    "vbroadcastss 92(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
    "vbroadcastss 124(%[src]), %%ymm14\n"
    "vbroadcastss 156(%[src]), %%ymm13\n"
    "vbroadcastss 188(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm5\n"
    "vbroadcastss 220(%[src]), %%ymm14\n"
    "vbroadcastss 252(%[src]), %%ymm13\n"
    "vbroadcastss 284(%[src]), %%ymm12\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm6\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm7\n"
    "vfmadd231ps %%ymm15, %%ymm12, %%ymm8\n"
    "vbroadcastss 316(%[src]), %%ymm14\n"
    "vbroadcastss 348(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm14, %%ymm9\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm10\n"
    // advance the pointers first: "add" rewrites the flags, so "dec" must directly precede "jg"
    "add $256, %[weight]\n"
    "add %[src_stride], %[src]\n"
    "dec %[deep]\n"
    "jg 4b\n"
    "5:\n"

    "movq %[inc_flag], %%rax\n"
    "and $0x2, %%eax\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
    "vmaxps %%ymm2, %%ymm15, %%ymm2\n"
    "vmaxps %%ymm3, %%ymm15, %%ymm3\n"
    "vmaxps %%ymm4, %%ymm15, %%ymm4\n"
    "vmaxps %%ymm5, %%ymm15, %%ymm5\n"
    "vmaxps %%ymm6, %%ymm15, %%ymm6\n"
    "vmaxps %%ymm7, %%ymm15, %%ymm7\n"
    "vmaxps %%ymm8, %%ymm15, %%ymm8\n"
    "vmaxps %%ymm9, %%ymm15, %%ymm9\n"
    "vmaxps %%ymm10, %%ymm15, %%ymm10\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: 0x40C00000 is the bit pattern of 6.0f; ymm15 is all-zero here, so using it as the
    // vpermps index vector copies element 0 into every lane
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm14\n"
    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
    "vminps %%ymm0, %%ymm14, %%ymm0\n"
    "vminps %%ymm1, %%ymm14, %%ymm1\n"
    "vminps %%ymm2, %%ymm14, %%ymm2\n"
    "vminps %%ymm3, %%ymm14, %%ymm3\n"
    "vminps %%ymm4, %%ymm14, %%ymm4\n"
    "vminps %%ymm5, %%ymm14, %%ymm5\n"
    "vminps %%ymm6, %%ymm14, %%ymm6\n"
    "vminps %%ymm7, %%ymm14, %%ymm7\n"
    "vminps %%ymm8, %%ymm14, %%ymm8\n"
    "vminps %%ymm9, %%ymm14, %%ymm9\n"
    "vminps %%ymm10, %%ymm14, %%ymm10\n"
    "3:\n"
    "vmovups %%ymm0, 0(%[dst])\n"
    "vmovups %%ymm1, 32(%[dst])\n"
    "vmovups %%ymm2, 64(%[dst])\n"
    "vmovups %%ymm3, 96(%[dst])\n"
    "vmovups %%ymm4, 128(%[dst])\n"
    "vmovups %%ymm5, 160(%[dst])\n"
    "vmovups %%ymm6, 192(%[dst])\n"
    "vmovups %%ymm7, 224(%[dst])\n"
    "vmovups %%ymm8, 256(%[dst])\n"
    "vmovups %%ymm9, 288(%[dst])\n"
    "vmovups %%ymm10, 320(%[dst])\n"
    // src, weight and the group counter are advanced inside the asm, so they are read-write
    // operands; early-clobber keeps them from sharing a register with any pure input.
    : [ src ] "+&r"(src), [ weight ] "+&r"(weight), [ deep ] "+&r"(deep_t)
    : [ src_stride ] "r"(src_stride_t), [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst),
      [ bias ] "r"(bias)
    : "cc", "memory", "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8",
      "%ymm9", "%ymm10", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + __m256 dst8; + __m256 dst9; + __m256 dst10; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64); + dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72); + dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80); + dst11 = _mm256_load_ps(dst + 0 * dst_stride + 88); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 0); + dst10 = _mm256_load_ps(bias + 0); + dst11 = _mm256_load_ps(bias + 0); + } + 
for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + __m256 src80 = _mm256_set1_ps(*(src + 64)); + dst8 = _mm256_fmadd_ps(dst8, src80, weight00); + __m256 src90 = _mm256_set1_ps(*(src + 72)); + dst9 = _mm256_fmadd_ps(dst9, src90, weight00); + __m256 src100 = _mm256_set1_ps(*(src + 80)); + dst10 = _mm256_fmadd_ps(dst10, src100, weight00); + __m256 src110 = _mm256_set1_ps(*(src + 88)); + dst11 = _mm256_fmadd_ps(dst11, src110, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = 
_mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + __m256 src81 = _mm256_set1_ps(*(src + 65)); + dst8 = _mm256_fmadd_ps(dst8, src81, weight01); + __m256 src91 = _mm256_set1_ps(*(src + 73)); + dst9 = _mm256_fmadd_ps(dst9, src91, weight01); + __m256 src101 = _mm256_set1_ps(*(src + 81)); + dst10 = _mm256_fmadd_ps(dst10, src101, weight01); + __m256 src111 = _mm256_set1_ps(*(src + 89)); + dst11 = _mm256_fmadd_ps(dst11, src111, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + __m256 src82 = _mm256_set1_ps(*(src + 66)); + dst8 = _mm256_fmadd_ps(dst8, src82, weight02); + __m256 src92 = _mm256_set1_ps(*(src + 74)); + dst9 = _mm256_fmadd_ps(dst9, src92, weight02); + __m256 src102 = _mm256_set1_ps(*(src + 82)); + dst10 = _mm256_fmadd_ps(dst10, src102, weight02); + __m256 src112 = _mm256_set1_ps(*(src + 90)); + dst11 = _mm256_fmadd_ps(dst11, src112, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, 
src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + __m256 src83 = _mm256_set1_ps(*(src + 67)); + dst8 = _mm256_fmadd_ps(dst8, src83, weight03); + __m256 src93 = _mm256_set1_ps(*(src + 75)); + dst9 = _mm256_fmadd_ps(dst9, src93, weight03); + __m256 src103 = _mm256_set1_ps(*(src + 83)); + dst10 = _mm256_fmadd_ps(dst10, src103, weight03); + __m256 src113 = _mm256_set1_ps(*(src + 91)); + dst11 = _mm256_fmadd_ps(dst11, src113, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + __m256 src84 = _mm256_set1_ps(*(src + 68)); + dst8 = _mm256_fmadd_ps(dst8, src84, weight04); + __m256 src94 = _mm256_set1_ps(*(src + 76)); + dst9 = _mm256_fmadd_ps(dst9, src94, weight04); + __m256 src104 = _mm256_set1_ps(*(src + 84)); + dst10 = _mm256_fmadd_ps(dst10, src104, weight04); + __m256 
src114 = _mm256_set1_ps(*(src + 92)); + dst11 = _mm256_fmadd_ps(dst11, src114, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + __m256 src85 = _mm256_set1_ps(*(src + 69)); + dst8 = _mm256_fmadd_ps(dst8, src85, weight05); + __m256 src95 = _mm256_set1_ps(*(src + 77)); + dst9 = _mm256_fmadd_ps(dst9, src95, weight05); + __m256 src105 = _mm256_set1_ps(*(src + 85)); + dst10 = _mm256_fmadd_ps(dst10, src105, weight05); + __m256 src115 = _mm256_set1_ps(*(src + 93)); + dst11 = _mm256_fmadd_ps(dst11, src115, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = 
_mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + __m256 src86 = _mm256_set1_ps(*(src + 70)); + dst8 = _mm256_fmadd_ps(dst8, src86, weight06); + __m256 src96 = _mm256_set1_ps(*(src + 78)); + dst9 = _mm256_fmadd_ps(dst9, src96, weight06); + __m256 src106 = _mm256_set1_ps(*(src + 86)); + dst10 = _mm256_fmadd_ps(dst10, src106, weight06); + __m256 src116 = _mm256_set1_ps(*(src + 94)); + dst11 = _mm256_fmadd_ps(dst11, src116, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = _mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + __m256 src87 = _mm256_set1_ps(*(src + 71)); + dst8 = _mm256_fmadd_ps(dst8, src87, weight07); + __m256 src97 = _mm256_set1_ps(*(src + 79)); + dst9 = _mm256_fmadd_ps(dst9, src97, weight07); + __m256 src107 = _mm256_set1_ps(*(src + 87)); + dst10 = _mm256_fmadd_ps(dst10, src107, weight07); + __m256 src117 = _mm256_set1_ps(*(src + 95)); + dst11 = _mm256_fmadd_ps(dst11, src117, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); 
+ dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst10 = _mm256_min_ps(dst10, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); + _mm256_store_ps(dst + 0 * src_stride + 64, dst8); + _mm256_store_ps(dst + 0 * src_stride + 72, dst9); + _mm256_store_ps(dst + 0 * src_stride + 80, dst10); + _mm256_store_ps(dst + 0 * src_stride + 88, dst11); +} diff --git 
a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..585957249bb --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,347 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +

// nnacl gemm in x86 fma asm code

/*
 * Three rows of one depth step `k`: broadcast the k-th source float of rows
 * r0..r2 (row r starts 32*r bytes into the current source block, each float
 * is 4 bytes) and accumulate src * weight into the row accumulators
 * ymm<r0>..ymm<r2>. In AT&T syntax the LAST operand of vfmadd231ps is the
 * destination, so the accumulator must come last.
 */
#define NNACL_FMA_12X8_ROWS3(k, r0, r1, r2)                \
  "vbroadcastss 32*" #r0 "+4*" #k "(%[src]), %%ymm14\n"   \
  "vbroadcastss 32*" #r1 "+4*" #k "(%[src]), %%ymm13\n"   \
  "vbroadcastss 32*" #r2 "+4*" #k "(%[src]), %%ymm12\n"   \
  "vfmadd231ps %%ymm15, %%ymm14, %%ymm" #r0 "\n"          \
  "vfmadd231ps %%ymm15, %%ymm13, %%ymm" #r1 "\n"          \
  "vfmadd231ps %%ymm15, %%ymm12, %%ymm" #r2 "\n"

/* One depth step `k` (0..7): load the 8 weights of that step, update all 12 rows. */
#define NNACL_FMA_12X8_DEPTH(k)             \
  "vmovaps 32*" #k "(%[weight]), %%ymm15\n" \
  NNACL_FMA_12X8_ROWS3(k, 0, 1, 2)          \
  NNACL_FMA_12X8_ROWS3(k, 3, 4, 5)          \
  NNACL_FMA_12X8_ROWS3(k, 6, 7, 8)          \
  NNACL_FMA_12X8_ROWS3(k, 9, 10, 11)

/*
 * 12-row x 8-column FP32 GEMM micro-kernel written in GCC inline assembly
 * (AVX2 + FMA).
 *
 * dst        : 12 consecutive vectors of 8 floats (row r at dst + 8 * r);
 *              read when (inc_flag & 0x1), always written.
 * src        : per depth block, 12 rows of 8 floats (row r at src + 8 * r);
 *              advanced by src_stride floats after every block of 8 depth steps.
 * weight     : 8 floats per depth step, 64 floats per depth block; must be
 *              32-byte aligned (vmovaps).
 * bias       : optional 8 floats (32-byte aligned) used as the start value of
 *              every row when (inc_flag & 0x1) is clear; NULL means start at zero.
 * act_flag   : if (act_flag & 0x3) the result is clamped below at 0; if
 *              additionally (act_flag & 0x1) it is clamped above at 6.
 * deep       : depth; deep >> 3 blocks of 8 steps are processed.
 * inc_flag   : bit 0 - start from the values already in dst;
 *              bit 1 - apply the activation selected by act_flag before storing.
 * row_block, col_block, dst_stride : not used by this kernel.
 *
 * Fixes relative to the generated original:
 *  - the accumulator (not the weight register) is the vfmadd231ps destination;
 *  - source displacements step by 4 bytes per depth step instead of 1;
 *  - "add $256" uses an immediate (without '$' it was a memory operand);
 *  - "dec" is the last flag-setting instruction before "jg";
 *  - operands modified by the asm are declared "+r", input flags are only
 *    tested, and everything lives in ONE asm statement so the accumulators
 *    cannot be disturbed between initialisation and use;
 *  - "memory"/"cc" clobbers are declared because dst is written;
 *  - deep < 8 no longer executes the loop body once.
 */
void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  (void)row_block;
  (void)col_block;
  (void)dst_stride;
  size_t deep_t = deep >> 3;                    /* number of 8-step depth blocks */
  const size_t src_stride_t = src_stride << 2;  /* floats -> bytes */
  asm volatile(
    // inc in deep
    "testq $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst]), %%ymm0\n"
    "vmovups 32(%[dst]), %%ymm1\n"
    "vmovups 64(%[dst]), %%ymm2\n"
    "vmovups 96(%[dst]), %%ymm3\n"
    "vmovups 128(%[dst]), %%ymm4\n"
    "vmovups 160(%[dst]), %%ymm5\n"
    "vmovups 192(%[dst]), %%ymm6\n"
    "vmovups 224(%[dst]), %%ymm7\n"
    "vmovups 256(%[dst]), %%ymm8\n"
    "vmovups 288(%[dst]), %%ymm9\n"
    "vmovups 320(%[dst]), %%ymm10\n"
    "vmovups 352(%[dst]), %%ymm11\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    // every row shares the same 8 bias values
    "vmovaps 0(%[bias]), %%ymm0\n"
    "vmovaps 0(%[bias]), %%ymm1\n"
    "vmovaps 0(%[bias]), %%ymm2\n"
    "vmovaps 0(%[bias]), %%ymm3\n"
    "vmovaps 0(%[bias]), %%ymm4\n"
    "vmovaps 0(%[bias]), %%ymm5\n"
    "vmovaps 0(%[bias]), %%ymm6\n"
    "vmovaps 0(%[bias]), %%ymm7\n"
    "vmovaps 0(%[bias]), %%ymm8\n"
    "vmovaps 0(%[bias]), %%ymm9\n"
    "vmovaps 0(%[bias]), %%ymm10\n"
    "vmovaps 0(%[bias]), %%ymm11\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
    "vxorps %%ymm5, %%ymm5, %%ymm5\n"
    "vxorps %%ymm6, %%ymm6, %%ymm6\n"
    "vxorps %%ymm7, %%ymm7, %%ymm7\n"
    "vxorps %%ymm8, %%ymm8, %%ymm8\n"
    "vxorps %%ymm9, %%ymm9, %%ymm9\n"
    "vxorps %%ymm10, %%ymm10, %%ymm10\n"
    "vxorps %%ymm11, %%ymm11, %%ymm11\n"
    "2:\n"
    // nothing to accumulate when deep < 8
    "testq %[deep], %[deep]\n"
    "je 4f\n"
    "3:\n"
    NNACL_FMA_12X8_DEPTH(0)
    NNACL_FMA_12X8_DEPTH(1)
    NNACL_FMA_12X8_DEPTH(2)
    NNACL_FMA_12X8_DEPTH(3)
    NNACL_FMA_12X8_DEPTH(4)
    NNACL_FMA_12X8_DEPTH(5)
    NNACL_FMA_12X8_DEPTH(6)
    NNACL_FMA_12X8_DEPTH(7)
    "addq $256, %[weight]\n"
    "addq %[src_stride], %[src]\n"
    // dec must be the last flag-setting instruction before the branch
    "decq %[deep]\n"
    "jg 3b\n"
    "4:\n"
    // the activation is applied only when bit 1 of inc_flag is set
    "testq $0x2, %[inc_flag]\n"
    "je 5f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 5f\n"
    // relu
    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
    "vmaxps %%ymm2, %%ymm15, %%ymm2\n"
    "vmaxps %%ymm3, %%ymm15, %%ymm3\n"
    "vmaxps %%ymm4, %%ymm15, %%ymm4\n"
    "vmaxps %%ymm5, %%ymm15, %%ymm5\n"
    "vmaxps %%ymm6, %%ymm15, %%ymm6\n"
    "vmaxps %%ymm7, %%ymm15, %%ymm7\n"
    "vmaxps %%ymm8, %%ymm15, %%ymm8\n"
    "vmaxps %%ymm9, %%ymm15, %%ymm9\n"
    "vmaxps %%ymm10, %%ymm15, %%ymm10\n"
    "vmaxps %%ymm11, %%ymm15, %%ymm11\n"
    "and $0x1, %%eax\n"
    "je 5f\n"
    // relu6: 0x40C00000 is 6.0f; ymm15 is all-zero, so vpermps copies lane 0 to every lane
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm14\n"
    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
    "vminps %%ymm0, %%ymm14, %%ymm0\n"
    "vminps %%ymm1, %%ymm14, %%ymm1\n"
    "vminps %%ymm2, %%ymm14, %%ymm2\n"
    "vminps %%ymm3, %%ymm14, %%ymm3\n"
    "vminps %%ymm4, %%ymm14, %%ymm4\n"
    "vminps %%ymm5, %%ymm14, %%ymm5\n"
    "vminps %%ymm6, %%ymm14, %%ymm6\n"
    "vminps %%ymm7, %%ymm14, %%ymm7\n"
    "vminps %%ymm8, %%ymm14, %%ymm8\n"
    "vminps %%ymm9, %%ymm14, %%ymm9\n"
    "vminps %%ymm10, %%ymm14, %%ymm10\n"
    "vminps %%ymm11, %%ymm14, %%ymm11\n"
    "5:\n"
    "vmovups %%ymm0, 0(%[dst])\n"
    "vmovups %%ymm1, 32(%[dst])\n"
    "vmovups %%ymm2, 64(%[dst])\n"
    "vmovups %%ymm3, 96(%[dst])\n"
    "vmovups %%ymm4, 128(%[dst])\n"
    "vmovups %%ymm5, 160(%[dst])\n"
    "vmovups %%ymm6, 192(%[dst])\n"
    "vmovups %%ymm7, 224(%[dst])\n"
    "vmovups %%ymm8, 256(%[dst])\n"
    "vmovups %%ymm9, 288(%[dst])\n"
    "vmovups %%ymm10, 320(%[dst])\n"
    "vmovups %%ymm11, 352(%[dst])\n"
    : [ src ] "+r"(src), [ weight ] "+r"(weight), [ deep ] "+r"(deep_t)
    : [ dst ] "r"(dst), [ bias ] "r"(bias), [ src_stride ] "r"(src_stride_t), [ inc_flag ] "r"(inc_flag),
      [ act_flag ] "r"(act_flag)
    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15", "memory", "cc");
}

#undef NNACL_FMA_12X8_DEPTH
#undef NNACL_FMA_12X8_ROWS3
diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..87a830fbc57 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei
Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, 
weight02); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); +} diff --git 
a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..6458f849f26 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include <x86intrin.h> +

// nnacl gemm in x86 fma asm code

/*
 * 1-row x 16-column FP32 GEMM micro-kernel written in GCC inline assembly
 * (AVX2 + FMA).
 *
 * dst        : two vectors of 8 floats, at dst and dst + dst_stride floats;
 *              read when (inc_flag & 0x1), always written.
 * src        : 8 floats per depth block; advanced by src_stride floats per block.
 * weight     : 16 floats per depth step, 128 floats (512 bytes) per depth
 *              block; must be 32-byte aligned (vmovaps).
 * bias       : optional 16 floats (32-byte aligned) used as the start value
 *              when (inc_flag & 0x1) is clear; NULL means start at zero.
 * act_flag   : if (act_flag & 0x3) the result is clamped below at 0; if
 *              additionally (act_flag & 0x1) it is clamped above at 6.
 * deep       : depth; deep >> 3 blocks of 8 steps are processed.
 * inc_flag   : bit 0 - start from the values already in dst;
 *              bit 1 - apply the activation selected by act_flag before storing.
 * row_block, col_block : not used by this kernel.
 *
 * Fixes relative to the generated original:
 *  - the accumulators ymm0/ymm1 (not the weight registers) are the
 *    vfmadd231ps destination - in AT&T syntax that is the LAST operand;
 *  - source displacements step by 4 bytes per depth step instead of 1;
 *  - "add $512" uses an immediate (without '$' it was a memory operand);
 *  - "dec" is the last flag-setting instruction before "jg";
 *  - operands modified by the asm are declared "+r", input flags are only
 *    tested, and everything lives in ONE asm statement so the accumulators
 *    cannot be disturbed between initialisation and use;
 *  - "memory"/"cc" clobbers are declared because dst is written;
 *  - deep < 8 no longer executes the loop body once.
 */
void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  (void)row_block;
  (void)col_block;
  size_t deep_t = deep >> 3;                    /* number of 8-step depth blocks */
  const size_t dst_stride_t = dst_stride << 2;  /* floats -> bytes */
  const size_t src_stride_t = src_stride << 2;  /* floats -> bytes */
  asm volatile(
    // inc in deep
    "testq $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst]), %%ymm0\n"
    "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%ymm0\n"
    "vmovaps 32(%[bias]), %%ymm1\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
    "2:\n"
    // nothing to accumulate when deep < 8
    "testq %[deep], %[deep]\n"
    "je 4f\n"
    "3:\n"
    // block 0
    "vmovaps 0(%[weight]), %%ymm15\n"
    "vmovaps 32(%[weight]), %%ymm14\n"
    "vbroadcastss 0(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    // block 1
    "vmovaps 64(%[weight]), %%ymm15\n"
    "vmovaps 96(%[weight]), %%ymm14\n"
    "vbroadcastss 4(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    // block 2
    "vmovaps 128(%[weight]), %%ymm15\n"
    "vmovaps 160(%[weight]), %%ymm14\n"
    "vbroadcastss 8(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    // block 3
    "vmovaps 192(%[weight]), %%ymm15\n"
    "vmovaps 224(%[weight]), %%ymm14\n"
    "vbroadcastss 12(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    // block 4
    "vmovaps 256(%[weight]), %%ymm15\n"
    "vmovaps 288(%[weight]), %%ymm14\n"
    "vbroadcastss 16(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    // block 5
    "vmovaps 320(%[weight]), %%ymm15\n"
    "vmovaps 352(%[weight]), %%ymm14\n"
    "vbroadcastss 20(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    // block 6
    "vmovaps 384(%[weight]), %%ymm15\n"
    "vmovaps 416(%[weight]), %%ymm14\n"
    "vbroadcastss 24(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    // block 7
    "vmovaps 448(%[weight]), %%ymm15\n"
    "vmovaps 480(%[weight]), %%ymm14\n"
    "vbroadcastss 28(%[src]), %%ymm13\n"
    "vfmadd231ps %%ymm15, %%ymm13, %%ymm0\n"
    "vfmadd231ps %%ymm14, %%ymm13, %%ymm1\n"
    "addq $512, %[weight]\n"
    "addq %[src_stride], %[src]\n"
    // dec must be the last flag-setting instruction before the branch
    "decq %[deep]\n"
    "jg 3b\n"
    "4:\n"
    // the activation is applied only when bit 1 of inc_flag is set
    "testq $0x2, %[inc_flag]\n"
    "je 5f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 5f\n"
    // relu
    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
    "and $0x1, %%eax\n"
    "je 5f\n"
    // relu6: 0x40C00000 is 6.0f; ymm15 is all-zero, so vpermps copies lane 0 to every lane
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm14\n"
    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
    "vminps %%ymm0, %%ymm14, %%ymm0\n"
    "vminps %%ymm1, %%ymm14, %%ymm1\n"
    "5:\n"
    "vmovups %%ymm0, 0(%[dst])\n"
    "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
    : [ src ] "+r"(src), [ weight ] "+r"(weight), [ deep ] "+r"(deep_t)
    : [ dst ] "r"(dst), [ bias ] "r"(bias), [ src_stride ] "r"(src_stride_t), [ dst_stride ] "r"(dst_stride_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag)
    : "%rax", "%ymm0", "%ymm1", "%ymm13", "%ymm14", "%ymm15", "memory", "cc");
}
diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..548f143c0a0 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 16); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + dst2 = _mm256_fmadd_ps(dst2, src00, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + dst2 = _mm256_fmadd_ps(dst2, src01, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = _mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + dst2 = _mm256_fmadd_ps(dst2, src02, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); 
+ __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + dst2 = _mm256_fmadd_ps(dst2, src03, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + dst2 = _mm256_fmadd_ps(dst2, src04, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + dst2 = _mm256_fmadd_ps(dst2, src05, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + dst2 = _mm256_fmadd_ps(dst2, src06, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, 
relu); + dst2 = _mm256_max_ps(dst2, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..465fc9aeea1 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // 
block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps 
%%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..c23a6ff6fa4 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 3 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 24); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 src00 = _mm256_set1_ps(*(src + 0)); + __m256 weight00 = _mm256_load_ps(weight + 0); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 weight10 = _mm256_load_ps(weight + 8); + dst1 = _mm256_fmadd_ps(dst1, src00, weight10); + __m256 weight20 = _mm256_load_ps(weight + 16); + dst2 = _mm256_fmadd_ps(dst2, src00, weight20); + __m256 weight30 = _mm256_load_ps(weight + 24); + dst3 = _mm256_fmadd_ps(dst3, src00, weight30); + // bock1 + __m256 src01 = _mm256_set1_ps(*(src + 1)); + __m256 weight01 = _mm256_load_ps(weight + 32); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 weight11 = _mm256_load_ps(weight + 40); + dst1 = _mm256_fmadd_ps(dst1, src01, weight11); + __m256 weight21 = _mm256_load_ps(weight + 48); + dst2 = _mm256_fmadd_ps(dst2, src01, weight21); + __m256 weight31 = _mm256_load_ps(weight + 56); + dst3 = _mm256_fmadd_ps(dst3, src01, weight31); + // bock2 + __m256 src02 = _mm256_set1_ps(*(src + 2)); + __m256 weight02 = _mm256_load_ps(weight + 64); + dst0 = 
_mm256_fmadd_ps(dst0, src02, weight02); + __m256 weight12 = _mm256_load_ps(weight + 72); + dst1 = _mm256_fmadd_ps(dst1, src02, weight12); + __m256 weight22 = _mm256_load_ps(weight + 80); + dst2 = _mm256_fmadd_ps(dst2, src02, weight22); + __m256 weight32 = _mm256_load_ps(weight + 88); + dst3 = _mm256_fmadd_ps(dst3, src02, weight32); + // bock3 + __m256 src03 = _mm256_set1_ps(*(src + 3)); + __m256 weight03 = _mm256_load_ps(weight + 96); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 weight13 = _mm256_load_ps(weight + 104); + dst1 = _mm256_fmadd_ps(dst1, src03, weight13); + __m256 weight23 = _mm256_load_ps(weight + 112); + dst2 = _mm256_fmadd_ps(dst2, src03, weight23); + __m256 weight33 = _mm256_load_ps(weight + 120); + dst3 = _mm256_fmadd_ps(dst3, src03, weight33); + // bock4 + __m256 src04 = _mm256_set1_ps(*(src + 4)); + __m256 weight04 = _mm256_load_ps(weight + 128); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 weight14 = _mm256_load_ps(weight + 136); + dst1 = _mm256_fmadd_ps(dst1, src04, weight14); + __m256 weight24 = _mm256_load_ps(weight + 144); + dst2 = _mm256_fmadd_ps(dst2, src04, weight24); + __m256 weight34 = _mm256_load_ps(weight + 152); + dst3 = _mm256_fmadd_ps(dst3, src04, weight34); + // bock5 + __m256 src05 = _mm256_set1_ps(*(src + 5)); + __m256 weight05 = _mm256_load_ps(weight + 160); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 weight15 = _mm256_load_ps(weight + 168); + dst1 = _mm256_fmadd_ps(dst1, src05, weight15); + __m256 weight25 = _mm256_load_ps(weight + 176); + dst2 = _mm256_fmadd_ps(dst2, src05, weight25); + __m256 weight35 = _mm256_load_ps(weight + 184); + dst3 = _mm256_fmadd_ps(dst3, src05, weight35); + // bock6 + __m256 src06 = _mm256_set1_ps(*(src + 6)); + __m256 weight06 = _mm256_load_ps(weight + 192); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 weight16 = _mm256_load_ps(weight + 200); + dst1 = _mm256_fmadd_ps(dst1, src06, weight16); + __m256 weight26 = _mm256_load_ps(weight + 208); + dst2 
= _mm256_fmadd_ps(dst2, src06, weight26); + __m256 weight36 = _mm256_load_ps(weight + 216); + dst3 = _mm256_fmadd_ps(dst3, src06, weight36); + // bock7 + __m256 src07 = _mm256_set1_ps(*(src + 7)); + __m256 weight07 = _mm256_load_ps(weight + 224); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 weight17 = _mm256_load_ps(weight + 232); + dst1 = _mm256_fmadd_ps(dst1, src07, weight17); + __m256 weight27 = _mm256_load_ps(weight + 240); + dst2 = _mm256_fmadd_ps(dst2, src07, weight27); + __m256 weight37 = _mm256_load_ps(weight + 248); + dst3 = _mm256_fmadd_ps(dst3, src07, weight37); + src = src + src_stride; + weight += 1024; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 1 * src_stride + 0, dst1); + _mm256_store_ps(dst + 2 * src_stride + 0, dst2); + _mm256_store_ps(dst + 3 * src_stride + 0, dst3); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..5efc8dcfa81 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,174 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + 
* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n" + "vmovups 0(%[dst_4]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 32(%[bias]), %%ymm1\n" + "vmovaps 64(%[bias]), %%ymm2\n" + "vmovaps 96(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_4 ] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vmovaps 0(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" 
+ "vmovaps 64(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 192(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vmovaps 256(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vmovaps 384(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 448(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 544(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 576(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vmovaps 640(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 672(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 736(%[weight]), %%ymm14\n" + 
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vmovaps 768(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 800(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 832(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 864(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vmovaps 896(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vmovaps 928(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n" + "vmovaps 960(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n" + "vmovaps 992(%[weight]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm3, 0(%[dst_4])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_4 ] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", 
"%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..a795d14be0e --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, 
relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..032876f77ae --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + 
"movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..55787a6905a --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,145 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst2; + __m256 dst1; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst2 = _mm256_fmadd_ps(dst2, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst3 = _mm256_fmadd_ps(dst3, src10, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst2 = _mm256_fmadd_ps(dst2, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst3 = _mm256_fmadd_ps(dst3, src11, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst2 = _mm256_fmadd_ps(dst2, 
src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst3 = _mm256_fmadd_ps(dst3, src12, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst2 = _mm256_fmadd_ps(dst2, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst3 = _mm256_fmadd_ps(dst3, src13, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst2 = _mm256_fmadd_ps(dst2, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst3 = _mm256_fmadd_ps(dst3, src14, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst2 = _mm256_fmadd_ps(dst2, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst3 = _mm256_fmadd_ps(dst3, src15, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst2 = _mm256_fmadd_ps(dst2, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst3 = _mm256_fmadd_ps(dst3, src16, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst2 = _mm256_fmadd_ps(dst2, src07, 
weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst3 = _mm256_fmadd_ps(dst3, src17, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 1 * src_stride + 0, dst2); + _mm256_store_ps(dst + 1 * src_stride + 8, dst3); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..c55abaab749 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, 
%%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n" + 
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..1de0370cb7c --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
// nnacl gemm in x86 fma intrinsic code
/**
 * 2-row x 24-column FMA GEMM micro-kernel (intrinsic variant).
 *
 * Keeps one 8-float accumulator per (column tile, row) pair: 3 tiles x 2 rows.
 *
 * Addressing, as used by the code below (all strides are given in floats):
 *   - dst   : tile c, row r lives at dst + c * dst_stride + r * 8 (unaligned access is used).
 *   - src   : per 8-deep block, row r / depth k is src[r * 8 + k]; the pointer
 *             advances by src_stride after each block.
 *   - weight: per depth step k, tile c is the 8 floats at weight + (k * 3 + c) * 8;
 *             one 8-deep block consumes 192 floats. Must be 32-byte aligned.
 *   - bias  : 24 floats (32-byte aligned) or NULL.
 *
 * @param act_flag  bit 1 set -> clamp to [0, 6]; bit 0 set -> max(x, 0).
 *                  NOTE(review): this encoding differs from the asm variant; confirm with caller.
 * @param row_block unused (fixed at 2 by this kernel).
 * @param col_block unused (fixed at 24 by this kernel).
 * @param deep      reduction depth; only full blocks of 8 are processed.
 * @param inc_flag  non-zero: start from the current contents of dst instead of bias/zero.
 */
void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  enum { kRows = 2, kColTiles = 3, kLanes = 8 };
  __m256 acc[kColTiles][kRows];
  (void)row_block;
  (void)col_block;

  // Accumulator init: previous partial result, zero, or bias (same bias vector for both rows).
  for (int c = 0; c < kColTiles; ++c) {
    for (int r = 0; r < kRows; ++r) {
      if (inc_flag) {
        acc[c][r] = _mm256_loadu_ps(dst + c * dst_stride + r * kLanes);
      } else if (bias == NULL) {
        acc[c][r] = _mm256_setzero_ps();
      } else {
        acc[c][r] = _mm256_load_ps(bias + c * kLanes);
      }
    }
  }

  for (size_t blk = 0; blk < (deep >> 3); ++blk) {
    for (int k = 0; k < kLanes; ++k) {
      const __m256 row0 = _mm256_set1_ps(src[k]);
      const __m256 row1 = _mm256_set1_ps(src[kLanes + k]);
      for (int c = 0; c < kColTiles; ++c) {
        const __m256 w = _mm256_load_ps(weight + (k * kColTiles + c) * kLanes);
        // _mm256_fmadd_ps(a, b, c) computes a * b + c, so the accumulator goes last.
        acc[c][0] = _mm256_fmadd_ps(row0, w, acc[c][0]);
        acc[c][1] = _mm256_fmadd_ps(row1, w, acc[c][1]);
      }
    }
    src += src_stride;
    // `weight` is a float pointer: advance by the float count consumed above (192), not bytes.
    weight += kLanes * kColTiles * kLanes;
  }

  if (act_flag & 0x02) {
    // relu6
    const __m256 six = _mm256_set1_ps(6.0f);
    const __m256 zero = _mm256_setzero_ps();
    for (int c = 0; c < kColTiles; ++c) {
      for (int r = 0; r < kRows; ++r) {
        acc[c][r] = _mm256_max_ps(_mm256_min_ps(acc[c][r], six), zero);
      }
    }
  }
  if (act_flag & 0x01) {
    // relu
    const __m256 zero = _mm256_setzero_ps();
    for (int c = 0; c < kColTiles; ++c) {
      for (int r = 0; r < kRows; ++r) {
        acc[c][r] = _mm256_max_ps(acc[c][r], zero);
      }
    }
  }

  // Write back with the same dst_stride indexing that the inc_flag load uses.
  for (int c = 0; c < kColTiles; ++c) {
    for (int r = 0; r < kRows; ++r) {
      _mm256_storeu_ps(dst + c * dst_stride + r * kLanes, acc[c][r]);
    }
  }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + 
"vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, 
%%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps 
%%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..0f9dbfa6ed9 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
// nnacl gemm in x86 fma intrinsic code
/**
 * 2-row x 32-column FMA GEMM micro-kernel (intrinsic variant).
 *
 * Keeps one 8-float accumulator per (column tile, row) pair: 4 tiles x 2 rows.
 *
 * Addressing, as used by the code below (all strides are given in floats):
 *   - dst   : tile c, row r lives at dst + c * dst_stride + r * 8 (unaligned access is used).
 *   - src   : per 8-deep block, row r / depth k is src[r * 8 + k]; the pointer
 *             advances by src_stride after each block.
 *   - weight: per depth step k, tile c is the 8 floats at weight + (k * 4 + c) * 8;
 *             one 8-deep block consumes 256 floats. Must be 32-byte aligned.
 *   - bias  : 32 floats (32-byte aligned) or NULL.
 *
 * @param act_flag  bit 1 set -> clamp to [0, 6]; bit 0 set -> max(x, 0).
 *                  NOTE(review): this encoding differs from the asm variant; confirm with caller.
 * @param row_block unused (fixed at 2 by this kernel).
 * @param col_block unused (fixed at 32 by this kernel).
 * @param deep      reduction depth; only full blocks of 8 are processed.
 * @param inc_flag  non-zero: start from the current contents of dst instead of bias/zero.
 */
void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  enum { kRows = 2, kColTiles = 4, kLanes = 8 };
  __m256 acc[kColTiles][kRows];
  (void)row_block;
  (void)col_block;

  // Accumulator init: previous partial result, zero, or bias (same bias vector for both rows).
  for (int c = 0; c < kColTiles; ++c) {
    for (int r = 0; r < kRows; ++r) {
      if (inc_flag) {
        acc[c][r] = _mm256_loadu_ps(dst + c * dst_stride + r * kLanes);
      } else if (bias == NULL) {
        acc[c][r] = _mm256_setzero_ps();
      } else {
        acc[c][r] = _mm256_load_ps(bias + c * kLanes);
      }
    }
  }

  for (size_t blk = 0; blk < (deep >> 3); ++blk) {
    for (int k = 0; k < kLanes; ++k) {
      const __m256 row0 = _mm256_set1_ps(src[k]);
      const __m256 row1 = _mm256_set1_ps(src[kLanes + k]);
      for (int c = 0; c < kColTiles; ++c) {
        const __m256 w = _mm256_load_ps(weight + (k * kColTiles + c) * kLanes);
        // _mm256_fmadd_ps(a, b, c) computes a * b + c, so the accumulator goes last.
        acc[c][0] = _mm256_fmadd_ps(row0, w, acc[c][0]);
        acc[c][1] = _mm256_fmadd_ps(row1, w, acc[c][1]);
      }
    }
    src += src_stride;
    // `weight` is a float pointer: advance by the float count consumed above (256), not bytes.
    weight += kLanes * kColTiles * kLanes;
  }

  if (act_flag & 0x02) {
    // relu6
    const __m256 six = _mm256_set1_ps(6.0f);
    const __m256 zero = _mm256_setzero_ps();
    for (int c = 0; c < kColTiles; ++c) {
      for (int r = 0; r < kRows; ++r) {
        acc[c][r] = _mm256_max_ps(_mm256_min_ps(acc[c][r], six), zero);
      }
    }
  }
  if (act_flag & 0x01) {
    // relu
    const __m256 zero = _mm256_setzero_ps();
    for (int c = 0; c < kColTiles; ++c) {
      for (int r = 0; r < kRows; ++r) {
        acc[c][r] = _mm256_max_ps(acc[c][r], zero);
      }
    }
  }

  // Write back with the same dst_stride indexing that the inc_flag load uses.
  for (int c = 0; c < kColTiles; ++c) {
    for (int r = 0; r < kRows; ++r) {
      _mm256_storeu_ps(dst + c * dst_stride + r * kLanes, acc[c][r]);
    }
  }
}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n" + "vmovups 0(%[dst_4]), %%ymm6\n" + "vmovups 32(%[dst_4]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 32(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 64(%[bias]), %%ymm4\n" + "vmovaps 64(%[bias]), %%ymm5\n" + "vmovaps 96(%[bias]), %%ymm6\n" + "vmovaps 96(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_4 ] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vmovaps 
0(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 32(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 96(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vmovaps 128(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 192(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 224(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 288(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 320(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vmovaps 384(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 416(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + 
"vmovaps 448(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 480(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vmovaps 512(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 576(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 608(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 672(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 704(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vmovaps 768(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 800(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 832(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 864(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, 
%%ymm14\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vmovaps 896(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n" + "vmovaps 928(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vmovaps 960(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vmovaps 992(%[weight]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm6, 0(%[dst_4])\n" + "vmovups %%ymm7, 32(%[dst_4])\n" + : + : [ src ] "r"(src), [ src_stride ] 
"r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_4 ] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..6fcf7958a74 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,105 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + // bock5 + __m256 weight05 = 
_mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..279370a0085 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + 
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", 
"%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..8cec925905b --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,185 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst3; + __m256 dst1; + __m256 dst4; + __m256 dst2; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst3 = _mm256_fmadd_ps(dst3, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst4 = _mm256_fmadd_ps(dst4, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst5 = _mm256_fmadd_ps(dst5, src20, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst3 = 
_mm256_fmadd_ps(dst3, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst4 = _mm256_fmadd_ps(dst4, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst5 = _mm256_fmadd_ps(dst5, src21, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst3 = _mm256_fmadd_ps(dst3, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst4 = _mm256_fmadd_ps(dst4, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst5 = _mm256_fmadd_ps(dst5, src22, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst3 = _mm256_fmadd_ps(dst3, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst4 = _mm256_fmadd_ps(dst4, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst5 = _mm256_fmadd_ps(dst5, src23, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst3 = _mm256_fmadd_ps(dst3, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst4 = _mm256_fmadd_ps(dst4, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst5 = _mm256_fmadd_ps(dst5, src24, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight 
+ 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst3 = _mm256_fmadd_ps(dst3, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst4 = _mm256_fmadd_ps(dst4, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst5 = _mm256_fmadd_ps(dst5, src25, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst3 = _mm256_fmadd_ps(dst3, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst4 = _mm256_fmadd_ps(dst4, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst5 = _mm256_fmadd_ps(dst5, src26, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst3 = _mm256_fmadd_ps(dst3, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst4 = _mm256_fmadd_ps(dst4, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst5 = _mm256_fmadd_ps(dst5, src27, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + 
dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 1 * src_stride + 0, dst3); + _mm256_store_ps(dst + 1 * src_stride + 8, dst4); + _mm256_store_ps(dst + 1 * src_stride + 16, dst5); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..e2ddb336e60 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 
96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + 
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + 
"vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..db5d05d6f0c --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,241 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#include <x86intrin.h>

// nnacl gemm in x86 fma intrinsic code
//
// Computes a 3-row x 24-column fp32 tile. The nine 8-lane accumulators are named dst{3 * c + r}, where
// c (0..2) is the 8-column block and r (0..2) is the row; accumulator dst{3 * c + r} is loaded from and
// stored to dst + c * dst_stride + 8 * r.
//
// dst        : output tile. Read first when inc_flag is non-zero. Accessed with aligned loads/stores, so dst
//              and dst_stride must keep every access 32-byte aligned.
// src        : input. Depth step k of a group reads one scalar per row at src[8 * r + k].
// weight     : packed weights, 24 consecutive floats per depth step (192 per group of 8 steps); loaded with
//              aligned loads.
// bias       : optional 24 floats (aligned loads); bias + 8 * c seeds every row of column block c. May be NULL,
//              in which case the accumulators start at zero.
// act_flag   : bit 0x02 clamps the result to [0, 6]; bit 0x01 clamps it below at 0.
//              NOTE(review): the asm variant of this kernel applies relu for (act_flag & 0x3) and relu6 for
//              (act_flag & 0x1) -- confirm which bit assignment callers use.
// row_block  : unused here.
// col_block  : unused here.
// deep       : depth; deep >> 3 groups of 8 steps are processed, any remainder is ignored.
// src_stride : distance in floats between consecutive 8-deep groups of src.
// dst_stride : distance in floats between consecutive 8-column blocks of dst.
// inc_flag   : non-zero means "accumulate on top of the values already in dst" (bias is ignored then).
//              NOTE(review): the asm variant tests only bit 0 here and gates the activation on bit 1 -- confirm.
void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  __m256 dst0;
  __m256 dst3;
  __m256 dst6;
  __m256 dst1;
  __m256 dst4;
  __m256 dst7;
  __m256 dst2;
  __m256 dst5;
  __m256 dst8;
  if (inc_flag) {
    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
    dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
    dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0);
    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
    dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
    dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8);
    dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
    dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
    dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16);
  } else if (bias == NULL) {
    dst0 = _mm256_setzero_ps();
    dst1 = _mm256_setzero_ps();
    dst2 = _mm256_setzero_ps();
    dst3 = _mm256_setzero_ps();
    dst4 = _mm256_setzero_ps();
    dst5 = _mm256_setzero_ps();
    dst6 = _mm256_setzero_ps();
    dst7 = _mm256_setzero_ps();
    dst8 = _mm256_setzero_ps();
  } else {
    dst0 = _mm256_load_ps(bias + 0);
    dst3 = _mm256_load_ps(bias + 8);
    dst6 = _mm256_load_ps(bias + 16);
    dst1 = _mm256_load_ps(bias + 0);
    dst4 = _mm256_load_ps(bias + 8);
    dst7 = _mm256_load_ps(bias + 16);
    dst2 = _mm256_load_ps(bias + 0);
    dst5 = _mm256_load_ps(bias + 8);
    dst8 = _mm256_load_ps(bias + 16);
  }
  // _mm256_fmadd_ps(a, b, c) returns a * b + c, so the accumulator has to be the third argument.
  for (size_t i = 0; i < (deep >> 3); ++i) {
    // block0
    __m256 weight00 = _mm256_load_ps(weight + 0);
    __m256 weight10 = _mm256_load_ps(weight + 8);
    __m256 weight20 = _mm256_load_ps(weight + 16);
    __m256 src00 = _mm256_set1_ps(*(src + 0));
    dst0 = _mm256_fmadd_ps(src00, weight00, dst0);
    dst3 = _mm256_fmadd_ps(src00, weight10, dst3);
    dst6 = _mm256_fmadd_ps(src00, weight20, dst6);
    __m256 src10 = _mm256_set1_ps(*(src + 8));
    dst1 = _mm256_fmadd_ps(src10, weight00, dst1);
    dst4 = _mm256_fmadd_ps(src10, weight10, dst4);
    dst7 = _mm256_fmadd_ps(src10, weight20, dst7);
    __m256 src20 = _mm256_set1_ps(*(src + 16));
    dst2 = _mm256_fmadd_ps(src20, weight00, dst2);
    dst5 = _mm256_fmadd_ps(src20, weight10, dst5);
    dst8 = _mm256_fmadd_ps(src20, weight20, dst8);
    // block1
    __m256 weight01 = _mm256_load_ps(weight + 24);
    __m256 weight11 = _mm256_load_ps(weight + 32);
    __m256 weight21 = _mm256_load_ps(weight + 40);
    __m256 src01 = _mm256_set1_ps(*(src + 1));
    dst0 = _mm256_fmadd_ps(src01, weight01, dst0);
    dst3 = _mm256_fmadd_ps(src01, weight11, dst3);
    dst6 = _mm256_fmadd_ps(src01, weight21, dst6);
    __m256 src11 = _mm256_set1_ps(*(src + 9));
    dst1 = _mm256_fmadd_ps(src11, weight01, dst1);
    dst4 = _mm256_fmadd_ps(src11, weight11, dst4);
    dst7 = _mm256_fmadd_ps(src11, weight21, dst7);
    __m256 src21 = _mm256_set1_ps(*(src + 17));
    dst2 = _mm256_fmadd_ps(src21, weight01, dst2);
    dst5 = _mm256_fmadd_ps(src21, weight11, dst5);
    dst8 = _mm256_fmadd_ps(src21, weight21, dst8);
    // block2
    __m256 weight02 = _mm256_load_ps(weight + 48);
    __m256 weight12 = _mm256_load_ps(weight + 56);
    __m256 weight22 = _mm256_load_ps(weight + 64);
    __m256 src02 = _mm256_set1_ps(*(src + 2));
    dst0 = _mm256_fmadd_ps(src02, weight02, dst0);
    dst3 = _mm256_fmadd_ps(src02, weight12, dst3);
    dst6 = _mm256_fmadd_ps(src02, weight22, dst6);
    __m256 src12 = _mm256_set1_ps(*(src + 10));
    dst1 = _mm256_fmadd_ps(src12, weight02, dst1);
    dst4 = _mm256_fmadd_ps(src12, weight12, dst4);
    dst7 = _mm256_fmadd_ps(src12, weight22, dst7);
    __m256 src22 = _mm256_set1_ps(*(src + 18));
    dst2 = _mm256_fmadd_ps(src22, weight02, dst2);
    dst5 = _mm256_fmadd_ps(src22, weight12, dst5);
    dst8 = _mm256_fmadd_ps(src22, weight22, dst8);
    // block3
    __m256 weight03 = _mm256_load_ps(weight + 72);
    __m256 weight13 = _mm256_load_ps(weight + 80);
    __m256 weight23 = _mm256_load_ps(weight + 88);
    __m256 src03 = _mm256_set1_ps(*(src + 3));
    dst0 = _mm256_fmadd_ps(src03, weight03, dst0);
    dst3 = _mm256_fmadd_ps(src03, weight13, dst3);
    dst6 = _mm256_fmadd_ps(src03, weight23, dst6);
    __m256 src13 = _mm256_set1_ps(*(src + 11));
    dst1 = _mm256_fmadd_ps(src13, weight03, dst1);
    dst4 = _mm256_fmadd_ps(src13, weight13, dst4);
    dst7 = _mm256_fmadd_ps(src13, weight23, dst7);
    __m256 src23 = _mm256_set1_ps(*(src + 19));
    dst2 = _mm256_fmadd_ps(src23, weight03, dst2);
    dst5 = _mm256_fmadd_ps(src23, weight13, dst5);
    dst8 = _mm256_fmadd_ps(src23, weight23, dst8);
    // block4
    __m256 weight04 = _mm256_load_ps(weight + 96);
    __m256 weight14 = _mm256_load_ps(weight + 104);
    __m256 weight24 = _mm256_load_ps(weight + 112);
    __m256 src04 = _mm256_set1_ps(*(src + 4));
    dst0 = _mm256_fmadd_ps(src04, weight04, dst0);
    dst3 = _mm256_fmadd_ps(src04, weight14, dst3);
    dst6 = _mm256_fmadd_ps(src04, weight24, dst6);
    __m256 src14 = _mm256_set1_ps(*(src + 12));
    dst1 = _mm256_fmadd_ps(src14, weight04, dst1);
    dst4 = _mm256_fmadd_ps(src14, weight14, dst4);
    dst7 = _mm256_fmadd_ps(src14, weight24, dst7);
    __m256 src24 = _mm256_set1_ps(*(src + 20));
    dst2 = _mm256_fmadd_ps(src24, weight04, dst2);
    dst5 = _mm256_fmadd_ps(src24, weight14, dst5);
    dst8 = _mm256_fmadd_ps(src24, weight24, dst8);
    // block5
    __m256 weight05 = _mm256_load_ps(weight + 120);
    __m256 weight15 = _mm256_load_ps(weight + 128);
    __m256 weight25 = _mm256_load_ps(weight + 136);
    __m256 src05 = _mm256_set1_ps(*(src + 5));
    dst0 = _mm256_fmadd_ps(src05, weight05, dst0);
    dst3 = _mm256_fmadd_ps(src05, weight15, dst3);
    dst6 = _mm256_fmadd_ps(src05, weight25, dst6);
    __m256 src15 = _mm256_set1_ps(*(src + 13));
    dst1 = _mm256_fmadd_ps(src15, weight05, dst1);
    dst4 = _mm256_fmadd_ps(src15, weight15, dst4);
    dst7 = _mm256_fmadd_ps(src15, weight25, dst7);
    __m256 src25 = _mm256_set1_ps(*(src + 21));
    dst2 = _mm256_fmadd_ps(src25, weight05, dst2);
    dst5 = _mm256_fmadd_ps(src25, weight15, dst5);
    dst8 = _mm256_fmadd_ps(src25, weight25, dst8);
    // block6
    __m256 weight06 = _mm256_load_ps(weight + 144);
    __m256 weight16 = _mm256_load_ps(weight + 152);
    __m256 weight26 = _mm256_load_ps(weight + 160);
    __m256 src06 = _mm256_set1_ps(*(src + 6));
    dst0 = _mm256_fmadd_ps(src06, weight06, dst0);
    dst3 = _mm256_fmadd_ps(src06, weight16, dst3);
    dst6 = _mm256_fmadd_ps(src06, weight26, dst6);
    __m256 src16 = _mm256_set1_ps(*(src + 14));
    dst1 = _mm256_fmadd_ps(src16, weight06, dst1);
    dst4 = _mm256_fmadd_ps(src16, weight16, dst4);
    dst7 = _mm256_fmadd_ps(src16, weight26, dst7);
    __m256 src26 = _mm256_set1_ps(*(src + 22));
    dst2 = _mm256_fmadd_ps(src26, weight06, dst2);
    dst5 = _mm256_fmadd_ps(src26, weight16, dst5);
    dst8 = _mm256_fmadd_ps(src26, weight26, dst8);
    // block7
    __m256 weight07 = _mm256_load_ps(weight + 168);
    __m256 weight17 = _mm256_load_ps(weight + 176);
    __m256 weight27 = _mm256_load_ps(weight + 184);
    __m256 src07 = _mm256_set1_ps(*(src + 7));
    dst0 = _mm256_fmadd_ps(src07, weight07, dst0);
    dst3 = _mm256_fmadd_ps(src07, weight17, dst3);
    dst6 = _mm256_fmadd_ps(src07, weight27, dst6);
    __m256 src17 = _mm256_set1_ps(*(src + 15));
    dst1 = _mm256_fmadd_ps(src17, weight07, dst1);
    dst4 = _mm256_fmadd_ps(src17, weight17, dst4);
    dst7 = _mm256_fmadd_ps(src17, weight27, dst7);
    __m256 src27 = _mm256_set1_ps(*(src + 23));
    dst2 = _mm256_fmadd_ps(src27, weight07, dst2);
    dst5 = _mm256_fmadd_ps(src27, weight17, dst5);
    dst8 = _mm256_fmadd_ps(src27, weight27, dst8);
    src = src + src_stride;
    // weight is a float pointer: one group consumes 8 steps * 24 floats = 192 elements (768 is the byte count).
    weight += 192;
  }
  if (act_flag & 0x02) {
    // relu6
    __m256 relu6 = _mm256_set1_ps(6.0f);
    __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_min_ps(dst0, relu6);
    dst3 = _mm256_min_ps(dst3, relu6);
    dst6 = _mm256_min_ps(dst6, relu6);
    dst1 = _mm256_min_ps(dst1, relu6);
    dst4 = _mm256_min_ps(dst4, relu6);
    dst7 = _mm256_min_ps(dst7, relu6);
    dst2 = _mm256_min_ps(dst2, relu6);
    dst5 = _mm256_min_ps(dst5, relu6);
    dst8 = _mm256_min_ps(dst8, relu6);
    // relu
    dst0 = _mm256_max_ps(dst0, relu);
    dst3 = _mm256_max_ps(dst3, relu);
    dst6 = _mm256_max_ps(dst6, relu);
    dst1 = _mm256_max_ps(dst1, relu);
    dst4 = _mm256_max_ps(dst4, relu);
    dst7 = _mm256_max_ps(dst7, relu);
    dst2 = _mm256_max_ps(dst2, relu);
    dst5 = _mm256_max_ps(dst5, relu);
    dst8 = _mm256_max_ps(dst8, relu);
  }
  if (act_flag & 0x01) {
    // relu
    __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_max_ps(dst0, relu);
    dst3 = _mm256_max_ps(dst3, relu);
    dst6 = _mm256_max_ps(dst6, relu);
    dst1 = _mm256_max_ps(dst1, relu);
    dst4 = _mm256_max_ps(dst4, relu);
    dst7 = _mm256_max_ps(dst7, relu);
    dst2 = _mm256_max_ps(dst2, relu);
    dst5 = _mm256_max_ps(dst5, relu);
    dst8 = _mm256_max_ps(dst8, relu);
  }
  // Store with dst_stride so the tile lands where the inc_flag path loads it from.
  _mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
  _mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
  _mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
  _mm256_store_ps(dst + 1 * dst_stride + 0, dst3);
  _mm256_store_ps(dst + 1 * dst_stride + 8, dst4);
  _mm256_store_ps(dst + 1 * dst_stride + 16, dst5);
  _mm256_store_ps(dst + 2 * dst_stride + 0, dst6);
  _mm256_store_ps(dst + 2 * dst_stride + 8, dst7);
  _mm256_store_ps(dst + 2 * dst_stride + 16, dst8);
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps 
%%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, 
%%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps 
%%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, 
%%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..eaf7595f1bc --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
#include <x86intrin.h>

// nnacl gemm in x86 fma intrinsic code
//
// Computes a 3-row x 32-column fp32 tile. The twelve 8-lane accumulators are named dst{3 * c + r}, where
// c (0..3) is the 8-column block and r (0..2) is the row; accumulator dst{3 * c + r} is loaded from and
// stored to dst + c * dst_stride + 8 * r.
//
// dst        : output tile. Read first when inc_flag is non-zero. Accessed with aligned loads/stores, so dst
//              and dst_stride must keep every access 32-byte aligned.
// src        : input. Depth step k of a group reads one scalar per row at src[8 * r + k].
// weight     : packed weights, 32 consecutive floats per depth step (256 per group of 8 steps); loaded with
//              aligned loads.
// bias       : optional 32 floats (aligned loads); bias + 8 * c seeds every row of column block c. May be NULL,
//              in which case the accumulators start at zero.
// act_flag   : bit 0x02 clamps the result to [0, 6]; bit 0x01 clamps it below at 0.
//              NOTE(review): confirm this bit assignment against the callers and the asm variants.
// row_block  : unused here.
// col_block  : unused here.
// deep       : depth; deep >> 3 groups of 8 steps are processed, any remainder is ignored.
// src_stride : distance in floats between consecutive 8-deep groups of src.
// dst_stride : distance in floats between consecutive 8-column blocks of dst.
// inc_flag   : non-zero means "accumulate on top of the values already in dst" (bias is ignored then).
void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                            const size_t act_flag, const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
                                            const size_t inc_flag) {
  __m256 dst0;
  __m256 dst3;
  __m256 dst6;
  __m256 dst9;
  __m256 dst1;
  __m256 dst4;
  __m256 dst7;
  __m256 dst10;
  __m256 dst2;
  __m256 dst5;
  __m256 dst8;
  __m256 dst11;
  if (inc_flag) {
    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
    dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
    dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0);
    dst9 = _mm256_load_ps(dst + 3 * dst_stride + 0);
    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
    dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
    dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8);
    dst10 = _mm256_load_ps(dst + 3 * dst_stride + 8);
    dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
    dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
    dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16);
    dst11 = _mm256_load_ps(dst + 3 * dst_stride + 16);
  } else if (bias == NULL) {
    dst0 = _mm256_setzero_ps();
    dst1 = _mm256_setzero_ps();
    dst2 = _mm256_setzero_ps();
    dst3 = _mm256_setzero_ps();
    dst4 = _mm256_setzero_ps();
    dst5 = _mm256_setzero_ps();
    dst6 = _mm256_setzero_ps();
    dst7 = _mm256_setzero_ps();
    dst8 = _mm256_setzero_ps();
    dst9 = _mm256_setzero_ps();
    dst10 = _mm256_setzero_ps();
    dst11 = _mm256_setzero_ps();
  } else {
    dst0 = _mm256_load_ps(bias + 0);
    dst3 = _mm256_load_ps(bias + 8);
    dst6 = _mm256_load_ps(bias + 16);
    dst9 = _mm256_load_ps(bias + 24);
    dst1 = _mm256_load_ps(bias + 0);
    dst4 = _mm256_load_ps(bias + 8);
    dst7 = _mm256_load_ps(bias + 16);
    dst10 = _mm256_load_ps(bias + 24);
    dst2 = _mm256_load_ps(bias + 0);
    dst5 = _mm256_load_ps(bias + 8);
    dst8 = _mm256_load_ps(bias + 16);
    dst11 = _mm256_load_ps(bias + 24);
  }
  // _mm256_fmadd_ps(a, b, c) returns a * b + c, so the accumulator has to be the third argument.
  for (size_t i = 0; i < (deep >> 3); ++i) {
    // block0
    __m256 src00 = _mm256_set1_ps(*(src + 0));
    __m256 src10 = _mm256_set1_ps(*(src + 8));
    __m256 src20 = _mm256_set1_ps(*(src + 16));
    __m256 weight00 = _mm256_load_ps(weight + 0);
    dst0 = _mm256_fmadd_ps(src00, weight00, dst0);
    dst1 = _mm256_fmadd_ps(src10, weight00, dst1);
    dst2 = _mm256_fmadd_ps(src20, weight00, dst2);
    __m256 weight10 = _mm256_load_ps(weight + 8);
    dst3 = _mm256_fmadd_ps(src00, weight10, dst3);
    dst4 = _mm256_fmadd_ps(src10, weight10, dst4);
    dst5 = _mm256_fmadd_ps(src20, weight10, dst5);
    __m256 weight20 = _mm256_load_ps(weight + 16);
    dst6 = _mm256_fmadd_ps(src00, weight20, dst6);
    dst7 = _mm256_fmadd_ps(src10, weight20, dst7);
    dst8 = _mm256_fmadd_ps(src20, weight20, dst8);
    __m256 weight30 = _mm256_load_ps(weight + 24);
    dst9 = _mm256_fmadd_ps(src00, weight30, dst9);
    dst10 = _mm256_fmadd_ps(src10, weight30, dst10);
    dst11 = _mm256_fmadd_ps(src20, weight30, dst11);
    // block1
    __m256 src01 = _mm256_set1_ps(*(src + 1));
    __m256 src11 = _mm256_set1_ps(*(src + 9));
    __m256 src21 = _mm256_set1_ps(*(src + 17));
    __m256 weight01 = _mm256_load_ps(weight + 32);
    dst0 = _mm256_fmadd_ps(src01, weight01, dst0);
    dst1 = _mm256_fmadd_ps(src11, weight01, dst1);
    dst2 = _mm256_fmadd_ps(src21, weight01, dst2);
    __m256 weight11 = _mm256_load_ps(weight + 40);
    dst3 = _mm256_fmadd_ps(src01, weight11, dst3);
    dst4 = _mm256_fmadd_ps(src11, weight11, dst4);
    dst5 = _mm256_fmadd_ps(src21, weight11, dst5);
    __m256 weight21 = _mm256_load_ps(weight + 48);
    dst6 = _mm256_fmadd_ps(src01, weight21, dst6);
    dst7 = _mm256_fmadd_ps(src11, weight21, dst7);
    dst8 = _mm256_fmadd_ps(src21, weight21, dst8);
    __m256 weight31 = _mm256_load_ps(weight + 56);
    dst9 = _mm256_fmadd_ps(src01, weight31, dst9);
    dst10 = _mm256_fmadd_ps(src11, weight31, dst10);
    dst11 = _mm256_fmadd_ps(src21, weight31, dst11);
    // block2
    __m256 src02 = _mm256_set1_ps(*(src + 2));
    __m256 src12 = _mm256_set1_ps(*(src + 10));
    __m256 src22 = _mm256_set1_ps(*(src + 18));
    __m256 weight02 = _mm256_load_ps(weight + 64);
    dst0 = _mm256_fmadd_ps(src02, weight02, dst0);
    dst1 = _mm256_fmadd_ps(src12, weight02, dst1);
    dst2 = _mm256_fmadd_ps(src22, weight02, dst2);
    __m256 weight12 = _mm256_load_ps(weight + 72);
    dst3 = _mm256_fmadd_ps(src02, weight12, dst3);
    dst4 = _mm256_fmadd_ps(src12, weight12, dst4);
    dst5 = _mm256_fmadd_ps(src22, weight12, dst5);
    __m256 weight22 = _mm256_load_ps(weight + 80);
    dst6 = _mm256_fmadd_ps(src02, weight22, dst6);
    dst7 = _mm256_fmadd_ps(src12, weight22, dst7);
    dst8 = _mm256_fmadd_ps(src22, weight22, dst8);
    __m256 weight32 = _mm256_load_ps(weight + 88);
    dst9 = _mm256_fmadd_ps(src02, weight32, dst9);
    dst10 = _mm256_fmadd_ps(src12, weight32, dst10);
    dst11 = _mm256_fmadd_ps(src22, weight32, dst11);
    // block3
    __m256 src03 = _mm256_set1_ps(*(src + 3));
    __m256 src13 = _mm256_set1_ps(*(src + 11));
    __m256 src23 = _mm256_set1_ps(*(src + 19));
    __m256 weight03 = _mm256_load_ps(weight + 96);
    dst0 = _mm256_fmadd_ps(src03, weight03, dst0);
    dst1 = _mm256_fmadd_ps(src13, weight03, dst1);
    dst2 = _mm256_fmadd_ps(src23, weight03, dst2);
    __m256 weight13 = _mm256_load_ps(weight + 104);
    dst3 = _mm256_fmadd_ps(src03, weight13, dst3);
    dst4 = _mm256_fmadd_ps(src13, weight13, dst4);
    dst5 = _mm256_fmadd_ps(src23, weight13, dst5);
    __m256 weight23 = _mm256_load_ps(weight + 112);
    dst6 = _mm256_fmadd_ps(src03, weight23, dst6);
    dst7 = _mm256_fmadd_ps(src13, weight23, dst7);
    dst8 = _mm256_fmadd_ps(src23, weight23, dst8);
    __m256 weight33 = _mm256_load_ps(weight + 120);
    dst9 = _mm256_fmadd_ps(src03, weight33, dst9);
    dst10 = _mm256_fmadd_ps(src13, weight33, dst10);
    dst11 = _mm256_fmadd_ps(src23, weight33, dst11);
    // block4
    __m256 src04 = _mm256_set1_ps(*(src + 4));
    __m256 src14 = _mm256_set1_ps(*(src + 12));
    __m256 src24 = _mm256_set1_ps(*(src + 20));
    __m256 weight04 = _mm256_load_ps(weight + 128);
    dst0 = _mm256_fmadd_ps(src04, weight04, dst0);
    dst1 = _mm256_fmadd_ps(src14, weight04, dst1);
    dst2 = _mm256_fmadd_ps(src24, weight04, dst2);
    __m256 weight14 = _mm256_load_ps(weight + 136);
    dst3 = _mm256_fmadd_ps(src04, weight14, dst3);
    dst4 = _mm256_fmadd_ps(src14, weight14, dst4);
    dst5 = _mm256_fmadd_ps(src24, weight14, dst5);
    __m256 weight24 = _mm256_load_ps(weight + 144);
    dst6 = _mm256_fmadd_ps(src04, weight24, dst6);
    dst7 = _mm256_fmadd_ps(src14, weight24, dst7);
    dst8 = _mm256_fmadd_ps(src24, weight24, dst8);
    __m256 weight34 = _mm256_load_ps(weight + 152);
    dst9 = _mm256_fmadd_ps(src04, weight34, dst9);
    dst10 = _mm256_fmadd_ps(src14, weight34, dst10);
    dst11 = _mm256_fmadd_ps(src24, weight34, dst11);
    // block5
    __m256 src05 = _mm256_set1_ps(*(src + 5));
    __m256 src15 = _mm256_set1_ps(*(src + 13));
    __m256 src25 = _mm256_set1_ps(*(src + 21));
    __m256 weight05 = _mm256_load_ps(weight + 160);
    dst0 = _mm256_fmadd_ps(src05, weight05, dst0);
    dst1 = _mm256_fmadd_ps(src15, weight05, dst1);
    dst2 = _mm256_fmadd_ps(src25, weight05, dst2);
    __m256 weight15 = _mm256_load_ps(weight + 168);
    dst3 = _mm256_fmadd_ps(src05, weight15, dst3);
    dst4 = _mm256_fmadd_ps(src15, weight15, dst4);
    dst5 = _mm256_fmadd_ps(src25, weight15, dst5);
    __m256 weight25 = _mm256_load_ps(weight + 176);
    dst6 = _mm256_fmadd_ps(src05, weight25, dst6);
    dst7 = _mm256_fmadd_ps(src15, weight25, dst7);
    dst8 = _mm256_fmadd_ps(src25, weight25, dst8);
    __m256 weight35 = _mm256_load_ps(weight + 184);
    dst9 = _mm256_fmadd_ps(src05, weight35, dst9);
    dst10 = _mm256_fmadd_ps(src15, weight35, dst10);
    dst11 = _mm256_fmadd_ps(src25, weight35, dst11);
    // block6
    __m256 src06 = _mm256_set1_ps(*(src + 6));
    __m256 src16 = _mm256_set1_ps(*(src + 14));
    __m256 src26 = _mm256_set1_ps(*(src + 22));
    __m256 weight06 = _mm256_load_ps(weight + 192);
    dst0 = _mm256_fmadd_ps(src06, weight06, dst0);
    dst1 = _mm256_fmadd_ps(src16, weight06, dst1);
    dst2 = _mm256_fmadd_ps(src26, weight06, dst2);
    __m256 weight16 = _mm256_load_ps(weight + 200);
    dst3 = _mm256_fmadd_ps(src06, weight16, dst3);
    dst4 = _mm256_fmadd_ps(src16, weight16, dst4);
    dst5 = _mm256_fmadd_ps(src26, weight16, dst5);
    __m256 weight26 = _mm256_load_ps(weight + 208);
    dst6 = _mm256_fmadd_ps(src06, weight26, dst6);
    dst7 = _mm256_fmadd_ps(src16, weight26, dst7);
    dst8 = _mm256_fmadd_ps(src26, weight26, dst8);
    __m256 weight36 = _mm256_load_ps(weight + 216);
    dst9 = _mm256_fmadd_ps(src06, weight36, dst9);
    dst10 = _mm256_fmadd_ps(src16, weight36, dst10);
    dst11 = _mm256_fmadd_ps(src26, weight36, dst11);
    // block7
    __m256 src07 = _mm256_set1_ps(*(src + 7));
    __m256 src17 = _mm256_set1_ps(*(src + 15));
    __m256 src27 = _mm256_set1_ps(*(src + 23));
    __m256 weight07 = _mm256_load_ps(weight + 224);
    dst0 = _mm256_fmadd_ps(src07, weight07, dst0);
    dst1 = _mm256_fmadd_ps(src17, weight07, dst1);
    dst2 = _mm256_fmadd_ps(src27, weight07, dst2);
    __m256 weight17 = _mm256_load_ps(weight + 232);
    dst3 = _mm256_fmadd_ps(src07, weight17, dst3);
    dst4 = _mm256_fmadd_ps(src17, weight17, dst4);
    dst5 = _mm256_fmadd_ps(src27, weight17, dst5);
    __m256 weight27 = _mm256_load_ps(weight + 240);
    dst6 = _mm256_fmadd_ps(src07, weight27, dst6);
    dst7 = _mm256_fmadd_ps(src17, weight27, dst7);
    dst8 = _mm256_fmadd_ps(src27, weight27, dst8);
    __m256 weight37 = _mm256_load_ps(weight + 248);
    dst9 = _mm256_fmadd_ps(src07, weight37, dst9);
    dst10 = _mm256_fmadd_ps(src17, weight37, dst10);
    dst11 = _mm256_fmadd_ps(src27, weight37, dst11);
    src = src + src_stride;
    // weight is a float pointer: one group consumes 8 steps * 32 floats = 256 elements (1024 is the byte count).
    weight += 256;
  }
  if (act_flag & 0x02) {
    // relu6
    __m256 relu6 = _mm256_set1_ps(6.0f);
    __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_min_ps(dst0, relu6);
    dst3 = _mm256_min_ps(dst3, relu6);
    dst6 = _mm256_min_ps(dst6, relu6);
    dst9 = _mm256_min_ps(dst9, relu6);
    dst1 = _mm256_min_ps(dst1, relu6);
    dst4 = _mm256_min_ps(dst4, relu6);
    dst7 = _mm256_min_ps(dst7, relu6);
    dst10 = _mm256_min_ps(dst10, relu6);
    dst2 = _mm256_min_ps(dst2, relu6);
    dst5 = _mm256_min_ps(dst5, relu6);
    dst8 = _mm256_min_ps(dst8, relu6);
    dst11 = _mm256_min_ps(dst11, relu6);
    // relu
    dst0 = _mm256_max_ps(dst0, relu);
    dst3 = _mm256_max_ps(dst3, relu);
    dst6 = _mm256_max_ps(dst6, relu);
    dst9 = _mm256_max_ps(dst9, relu);
    dst1 = _mm256_max_ps(dst1, relu);
    dst4 = _mm256_max_ps(dst4, relu);
    dst7 = _mm256_max_ps(dst7, relu);
    dst10 = _mm256_max_ps(dst10, relu);
    dst2 = _mm256_max_ps(dst2, relu);
    dst5 = _mm256_max_ps(dst5, relu);
    dst8 = _mm256_max_ps(dst8, relu);
    dst11 = _mm256_max_ps(dst11, relu);
  }
  if (act_flag & 0x01) {
    // relu
    __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_max_ps(dst0, relu);
    dst3 = _mm256_max_ps(dst3, relu);
    dst6 = _mm256_max_ps(dst6, relu);
    dst9 = _mm256_max_ps(dst9, relu);
    dst1 = _mm256_max_ps(dst1, relu);
    dst4 = _mm256_max_ps(dst4, relu);
    dst7 = _mm256_max_ps(dst7, relu);
    dst10 = _mm256_max_ps(dst10, relu);
    dst2 = _mm256_max_ps(dst2, relu);
    dst5 = _mm256_max_ps(dst5, relu);
    dst8 = _mm256_max_ps(dst8, relu);
    dst11 = _mm256_max_ps(dst11, relu);
  }
  // Store with dst_stride so the tile lands where the inc_flag path loads it from.
  _mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
  _mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
  _mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
  _mm256_store_ps(dst + 1 * dst_stride + 0, dst3);
  _mm256_store_ps(dst + 1 * dst_stride + 8, dst4);
  _mm256_store_ps(dst + 1 * dst_stride + 16, dst5);
  _mm256_store_ps(dst + 2 * dst_stride + 0, dst6);
  _mm256_store_ps(dst + 2 * dst_stride + 8, dst7);
  _mm256_store_ps(dst + 2 * dst_stride + 16, dst8);
  _mm256_store_ps(dst + 3 * dst_stride + 0, dst9);
  _mm256_store_ps(dst + 3 * dst_stride + 8, dst10);
  _mm256_store_ps(dst + 3 * dst_stride + 16, dst11);
}
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 0(%[dst_4]), %%ymm9\n" + "vmovups 32(%[dst_4]), %%ymm10\n" + "vmovups 64(%[dst_4]), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 32(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 64(%[bias]), %%ymm6\n" + "vmovaps 64(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 96(%[bias]), %%ymm9\n" + "vmovaps 96(%[bias]), %%ymm10\n" + "vmovaps 96(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" 
+ "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_4 ] "r"(dst_4) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vbroadcastss 0(%[src]), %%ymm15\n" + "vbroadcastss 32(%[src]), %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vmovaps 0(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 32(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 64(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 96(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vbroadcastss 1(%[src]), %%ymm15\n" + "vbroadcastss 33(%[src]), %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vmovaps 128(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 160(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 192(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 224(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vbroadcastss 2(%[src]), %%ymm15\n" + "vbroadcastss 34(%[src]), %%ymm14\n" + "vbroadcastss 66(%[src]), 
%%ymm13\n" + "vmovaps 256(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 288(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 320(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 352(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vbroadcastss 3(%[src]), %%ymm15\n" + "vbroadcastss 35(%[src]), %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vmovaps 384(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 416(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 448(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 480(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vbroadcastss 4(%[src]), %%ymm15\n" + "vbroadcastss 36(%[src]), %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vmovaps 512(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 544(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 576(%[weight]), %%ymm12\n" + "vfmadd231ps 
%%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 608(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vbroadcastss 5(%[src]), %%ymm15\n" + "vbroadcastss 37(%[src]), %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vmovaps 640(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 672(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 704(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 736(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vbroadcastss 6(%[src]), %%ymm15\n" + "vbroadcastss 38(%[src]), %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vmovaps 768(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 800(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 832(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 864(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vbroadcastss 7(%[src]), %%ymm15\n" + "vbroadcastss 39(%[src]), %%ymm14\n" + "vbroadcastss 71(%[src]), 
%%ymm13\n" + "vmovaps 896(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n" + "vmovaps 928(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n" + "vmovaps 960(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vmovaps 992(%[weight]), %%ymm12\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 1024, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + 
"vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 0(%[dst_4])\n" + "vmovups %%ymm10, 32(%[dst_4])\n" + "vmovups %%ymm11, 64(%[dst_4])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_4 ] "r"(dst_4) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..aff74f6f09f --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,129 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = 
_mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + 
if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..8c51f806442 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 
96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, 
%%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..4b34d19bea6 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst1; + __m256 dst5; + __m256 dst2; + __m256 dst6; + __m256 dst3; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = 
_mm256_fmadd_ps(dst6, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 
= _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); 
+ __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = 
_mm256_max_ps(dst6, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..061f1d8cd49 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,235 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, 
%%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps 
%%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), 
%%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : 
"%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..8d987c6153a --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst4; + __m256 dst8; + __m256 dst1; + __m256 dst5; + __m256 dst9; + __m256 dst2; + __m256 dst6; + __m256 dst10; + __m256 dst3; + __m256 dst7; + __m256 dst11; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst8 = _mm256_load_ps(dst + 2 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst9 = _mm256_load_ps(dst + 2 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst10 = _mm256_load_ps(dst + 2 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst11 = _mm256_load_ps(dst + 2 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + dst10 = _mm256_setzero_ps(); + dst11 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 8); + dst8 = _mm256_load_ps(bias + 16); + dst1 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst9 = _mm256_load_ps(bias + 16); + dst2 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst10 = _mm256_load_ps(bias + 16); + dst3 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 8); + dst11 = _mm256_load_ps(bias + 16); + } + 
for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 weight20 = _mm256_load_ps(weight + 16); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst4 = _mm256_fmadd_ps(dst4, src00, weight10); + dst8 = _mm256_fmadd_ps(dst8, src00, weight20); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst5 = _mm256_fmadd_ps(dst5, src10, weight10); + dst9 = _mm256_fmadd_ps(dst9, src10, weight20); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst6 = _mm256_fmadd_ps(dst6, src20, weight10); + dst10 = _mm256_fmadd_ps(dst10, src20, weight20); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst7 = _mm256_fmadd_ps(dst7, src30, weight10); + dst11 = _mm256_fmadd_ps(dst11, src30, weight20); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 24); + __m256 weight11 = _mm256_load_ps(weight + 32); + __m256 weight21 = _mm256_load_ps(weight + 40); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst4 = _mm256_fmadd_ps(dst4, src01, weight11); + dst8 = _mm256_fmadd_ps(dst8, src01, weight21); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst5 = _mm256_fmadd_ps(dst5, src11, weight11); + dst9 = _mm256_fmadd_ps(dst9, src11, weight21); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst6 = _mm256_fmadd_ps(dst6, src21, weight11); + dst10 = _mm256_fmadd_ps(dst10, src21, weight21); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst7 = _mm256_fmadd_ps(dst7, src31, weight11); + dst11 = _mm256_fmadd_ps(dst11, src31, weight21); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 48); + __m256 weight12 = 
_mm256_load_ps(weight + 56); + __m256 weight22 = _mm256_load_ps(weight + 64); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst4 = _mm256_fmadd_ps(dst4, src02, weight12); + dst8 = _mm256_fmadd_ps(dst8, src02, weight22); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst5 = _mm256_fmadd_ps(dst5, src12, weight12); + dst9 = _mm256_fmadd_ps(dst9, src12, weight22); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst6 = _mm256_fmadd_ps(dst6, src22, weight12); + dst10 = _mm256_fmadd_ps(dst10, src22, weight22); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst7 = _mm256_fmadd_ps(dst7, src32, weight12); + dst11 = _mm256_fmadd_ps(dst11, src32, weight22); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 72); + __m256 weight13 = _mm256_load_ps(weight + 80); + __m256 weight23 = _mm256_load_ps(weight + 88); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst4 = _mm256_fmadd_ps(dst4, src03, weight13); + dst8 = _mm256_fmadd_ps(dst8, src03, weight23); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst5 = _mm256_fmadd_ps(dst5, src13, weight13); + dst9 = _mm256_fmadd_ps(dst9, src13, weight23); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst6 = _mm256_fmadd_ps(dst6, src23, weight13); + dst10 = _mm256_fmadd_ps(dst10, src23, weight23); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst7 = _mm256_fmadd_ps(dst7, src33, weight13); + dst11 = _mm256_fmadd_ps(dst11, src33, weight23); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 96); + __m256 weight14 = _mm256_load_ps(weight + 104); + __m256 weight24 = _mm256_load_ps(weight + 112); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = 
_mm256_fmadd_ps(dst0, src04, weight04); + dst4 = _mm256_fmadd_ps(dst4, src04, weight14); + dst8 = _mm256_fmadd_ps(dst8, src04, weight24); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst5 = _mm256_fmadd_ps(dst5, src14, weight14); + dst9 = _mm256_fmadd_ps(dst9, src14, weight24); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + dst6 = _mm256_fmadd_ps(dst6, src24, weight14); + dst10 = _mm256_fmadd_ps(dst10, src24, weight24); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst7 = _mm256_fmadd_ps(dst7, src34, weight14); + dst11 = _mm256_fmadd_ps(dst11, src34, weight24); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 120); + __m256 weight15 = _mm256_load_ps(weight + 128); + __m256 weight25 = _mm256_load_ps(weight + 136); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst4 = _mm256_fmadd_ps(dst4, src05, weight15); + dst8 = _mm256_fmadd_ps(dst8, src05, weight25); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst5 = _mm256_fmadd_ps(dst5, src15, weight15); + dst9 = _mm256_fmadd_ps(dst9, src15, weight25); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst6 = _mm256_fmadd_ps(dst6, src25, weight15); + dst10 = _mm256_fmadd_ps(dst10, src25, weight25); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst7 = _mm256_fmadd_ps(dst7, src35, weight15); + dst11 = _mm256_fmadd_ps(dst11, src35, weight25); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 144); + __m256 weight16 = _mm256_load_ps(weight + 152); + __m256 weight26 = _mm256_load_ps(weight + 160); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst4 = _mm256_fmadd_ps(dst4, src06, weight16); + dst8 = _mm256_fmadd_ps(dst8, src06, 
weight26); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst5 = _mm256_fmadd_ps(dst5, src16, weight16); + dst9 = _mm256_fmadd_ps(dst9, src16, weight26); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst6 = _mm256_fmadd_ps(dst6, src26, weight16); + dst10 = _mm256_fmadd_ps(dst10, src26, weight26); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst7 = _mm256_fmadd_ps(dst7, src36, weight16); + dst11 = _mm256_fmadd_ps(dst11, src36, weight26); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 168); + __m256 weight17 = _mm256_load_ps(weight + 176); + __m256 weight27 = _mm256_load_ps(weight + 184); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst4 = _mm256_fmadd_ps(dst4, src07, weight17); + dst8 = _mm256_fmadd_ps(dst8, src07, weight27); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst5 = _mm256_fmadd_ps(dst5, src17, weight17); + dst9 = _mm256_fmadd_ps(dst9, src17, weight27); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst6 = _mm256_fmadd_ps(dst6, src27, weight17); + dst10 = _mm256_fmadd_ps(dst10, src27, weight27); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst7 = _mm256_fmadd_ps(dst7, src37, weight17); + dst11 = _mm256_fmadd_ps(dst11, src37, weight27); + src = src + src_stride; + weight += 768; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst10 = 
_mm256_min_ps(dst10, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst11 = _mm256_min_ps(dst11, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst9 = _mm256_max_ps(dst9, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst10 = _mm256_max_ps(dst10, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst11 = _mm256_max_ps(dst11, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 1 * src_stride + 0, dst4); + _mm256_store_ps(dst + 1 * src_stride + 8, dst5); + _mm256_store_ps(dst + 1 * src_stride + 16, dst6); + _mm256_store_ps(dst + 1 * src_stride + 24, dst7); + _mm256_store_ps(dst + 2 * src_stride + 0, dst8); + _mm256_store_ps(dst + 2 * src_stride + 8, dst9); + _mm256_store_ps(dst + 2 * src_stride + 16, dst10); + _mm256_store_ps(dst + 2 * src_stride + 24, dst11); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..d949566e075 --- /dev/null +++ 
b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 0(%[dst], %[dst_stride], 2), %%ymm8\n" + "vmovups 32(%[dst], %[dst_stride], 2), %%ymm9\n" + "vmovups 64(%[dst], %[dst_stride], 2), %%ymm10\n" + "vmovups 96(%[dst], %[dst_stride], 2), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 
32(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 64(%[bias]), %%ymm8\n" + "vmovaps 64(%[bias]), %%ymm9\n" + "vmovaps 64(%[bias]), %%ymm10\n" + "vmovaps 64(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vmovaps 64(%[weight]), %%ymm13\n" + "vbroadcastss 0(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 1 + "vmovaps 96(%[weight]), %%ymm15\n" + "vmovaps 128(%[weight]), %%ymm14\n" + "vmovaps 160(%[weight]), %%ymm13\n" + "vbroadcastss 1(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps 
%%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 2 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vmovaps 256(%[weight]), %%ymm13\n" + "vbroadcastss 2(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 3 + "vmovaps 288(%[weight]), %%ymm15\n" + "vmovaps 320(%[weight]), %%ymm14\n" + "vmovaps 352(%[weight]), %%ymm13\n" + "vbroadcastss 3(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, 
%%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 4 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vmovaps 448(%[weight]), %%ymm13\n" + "vbroadcastss 4(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 5 + "vmovaps 480(%[weight]), %%ymm15\n" + "vmovaps 512(%[weight]), %%ymm14\n" + "vmovaps 544(%[weight]), %%ymm13\n" + "vbroadcastss 5(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 6 + "vmovaps 576(%[weight]), %%ymm15\n" + "vmovaps 608(%[weight]), %%ymm14\n" + "vmovaps 640(%[weight]), %%ymm13\n" + "vbroadcastss 6(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + 
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + // block 7 + "vmovaps 672(%[weight]), %%ymm15\n" + "vmovaps 704(%[weight]), %%ymm14\n" + "vmovaps 736(%[weight]), %%ymm13\n" + "vbroadcastss 7(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n" + "dec %[deep]\n" + "add 768, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + 
"vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 0(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm9, 32(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm10, 64(%[dst], %[dst_stride], 2)\n" + "vmovups %%ymm11, 96(%[dst], %[dst_stride], 2)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..c972e8a846b --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c @@ -0,0 
+1,153 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = 
_mm256_fmadd_ps(dst3, src30, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + 
dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, 
dst3); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..05b800dfc1e --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,171 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + 
"vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), 
%%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..5426fb627ad --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c @@ -0,0 +1,265 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance 
with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst5; + __m256 dst1; + __m256 dst6; + __m256 dst2; + __m256 dst7; + __m256 dst3; + __m256 dst8; + __m256 dst4; + __m256 dst9; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst5 = _mm256_load_ps(dst + 1 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst6 = _mm256_load_ps(dst + 1 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst7 = _mm256_load_ps(dst + 1 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst8 = _mm256_load_ps(dst + 1 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst9 = _mm256_load_ps(dst + 1 * dst_stride + 32); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + dst8 = _mm256_setzero_ps(); + dst9 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 8); + dst1 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 8); + dst2 = _mm256_load_ps(bias + 0); + dst7 = 
_mm256_load_ps(bias + 8); + dst3 = _mm256_load_ps(bias + 0); + dst8 = _mm256_load_ps(bias + 8); + dst4 = _mm256_load_ps(bias + 0); + dst9 = _mm256_load_ps(bias + 8); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 weight10 = _mm256_load_ps(weight + 8); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + dst5 = _mm256_fmadd_ps(dst5, src00, weight10); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + dst6 = _mm256_fmadd_ps(dst6, src10, weight10); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + dst7 = _mm256_fmadd_ps(dst7, src20, weight10); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + dst8 = _mm256_fmadd_ps(dst8, src30, weight10); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + dst9 = _mm256_fmadd_ps(dst9, src40, weight10); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 16); + __m256 weight11 = _mm256_load_ps(weight + 24); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + dst5 = _mm256_fmadd_ps(dst5, src01, weight11); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + dst6 = _mm256_fmadd_ps(dst6, src11, weight11); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + dst7 = _mm256_fmadd_ps(dst7, src21, weight11); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + dst8 = _mm256_fmadd_ps(dst8, src31, weight11); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + dst9 = _mm256_fmadd_ps(dst9, src41, weight11); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 32); + __m256 weight12 = _mm256_load_ps(weight + 40); + __m256 src02 = 
_mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + dst5 = _mm256_fmadd_ps(dst5, src02, weight12); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + dst6 = _mm256_fmadd_ps(dst6, src12, weight12); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + dst7 = _mm256_fmadd_ps(dst7, src22, weight12); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + dst8 = _mm256_fmadd_ps(dst8, src32, weight12); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + dst9 = _mm256_fmadd_ps(dst9, src42, weight12); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 48); + __m256 weight13 = _mm256_load_ps(weight + 56); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + dst5 = _mm256_fmadd_ps(dst5, src03, weight13); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + dst6 = _mm256_fmadd_ps(dst6, src13, weight13); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + dst7 = _mm256_fmadd_ps(dst7, src23, weight13); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + dst8 = _mm256_fmadd_ps(dst8, src33, weight13); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + dst9 = _mm256_fmadd_ps(dst9, src43, weight13); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 64); + __m256 weight14 = _mm256_load_ps(weight + 72); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + dst5 = _mm256_fmadd_ps(dst5, src04, weight14); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + dst6 = _mm256_fmadd_ps(dst6, src14, weight14); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = 
_mm256_fmadd_ps(dst2, src24, weight04); + dst7 = _mm256_fmadd_ps(dst7, src24, weight14); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + dst8 = _mm256_fmadd_ps(dst8, src34, weight14); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + dst9 = _mm256_fmadd_ps(dst9, src44, weight14); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 80); + __m256 weight15 = _mm256_load_ps(weight + 88); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + dst5 = _mm256_fmadd_ps(dst5, src05, weight15); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + dst6 = _mm256_fmadd_ps(dst6, src15, weight15); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + dst7 = _mm256_fmadd_ps(dst7, src25, weight15); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + dst8 = _mm256_fmadd_ps(dst8, src35, weight15); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + dst9 = _mm256_fmadd_ps(dst9, src45, weight15); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 96); + __m256 weight16 = _mm256_load_ps(weight + 104); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + dst5 = _mm256_fmadd_ps(dst5, src06, weight16); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + dst6 = _mm256_fmadd_ps(dst6, src16, weight16); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + dst7 = _mm256_fmadd_ps(dst7, src26, weight16); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + dst8 = _mm256_fmadd_ps(dst8, src36, weight16); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + dst9 = 
_mm256_fmadd_ps(dst9, src46, weight16); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 112); + __m256 weight17 = _mm256_load_ps(weight + 120); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + dst5 = _mm256_fmadd_ps(dst5, src07, weight17); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + dst6 = _mm256_fmadd_ps(dst6, src17, weight17); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + dst7 = _mm256_fmadd_ps(dst7, src27, weight17); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + dst8 = _mm256_fmadd_ps(dst8, src37, weight17); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + dst9 = _mm256_fmadd_ps(dst9, src47, weight17); + src = src + src_stride; + weight += 512; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst8 = _mm256_min_ps(dst8, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst9 = _mm256_min_ps(dst9, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst2 = 
_mm256_max_ps(dst2, relu); + dst7 = _mm256_max_ps(dst7, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst8 = _mm256_max_ps(dst8, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst9 = _mm256_max_ps(dst9, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 1 * src_stride + 0, dst5); + _mm256_store_ps(dst + 1 * src_stride + 8, dst6); + _mm256_store_ps(dst + 1 * src_stride + 16, dst7); + _mm256_store_ps(dst + 1 * src_stride + 24, dst8); + _mm256_store_ps(dst + 1 * src_stride + 32, dst9); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..56f6aed3c69 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,271 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm5\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 32(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), 
%%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 
130(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" 
+ "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps 
%%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm6, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 128(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..c3a6e1a6197 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c 
@@ -0,0 +1,177 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code: computes one 5x8 fp32 tile, dst = act(src * weight + seed).
+// Row r of the tile reads src[8 * r + k] for depth k and is kept at dst[8 * r .. 8 * r + 7]; each loop step uses
+// 8 depth values (weight + 8 * k holds the 8 weights for depth k), then src advances by src_stride floats.
+// seed is the current dst tile when inc_flag is non-zero, else bias[0..7] for every row, or zero when bias is NULL.
+// act_flag bit 0x02 clamps to [0, 6] (relu6), bit 0x01 applies relu; row_block and col_block are not used here.
+void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                           const size_t act_flag, const size_t row_block, const size_t col_block,
+                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                           const size_t inc_flag) {
+  __m256 dst0, dst1, dst2, dst3, dst4;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
+    dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
+    dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
+    dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+    dst2 = _mm256_setzero_ps();
+    dst3 = _mm256_setzero_ps();
+    dst4 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst1 = _mm256_load_ps(bias + 0);
+    dst2 = _mm256_load_ps(bias + 0);
+    dst3 = _mm256_load_ps(bias + 0);
+    dst4 = _mm256_load_ps(bias + 0);
+  }
+  for (size_t i = 0; i < (deep >> 3); ++i) {
+    // block0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(src00, weight00, dst0);
+    __m256 src10 = _mm256_set1_ps(*(src + 8));
+    dst1 = _mm256_fmadd_ps(src10, weight00, dst1);
+    __m256 src20 = _mm256_set1_ps(*(src + 16));
+    dst2 = _mm256_fmadd_ps(src20, weight00, dst2);
+    __m256 src30 = _mm256_set1_ps(*(src + 24));
+    dst3 = _mm256_fmadd_ps(src30, weight00, dst3);
+    __m256 src40 = _mm256_set1_ps(*(src + 32));
+    dst4 = _mm256_fmadd_ps(src40, weight00, dst4);
+    // block1
+    __m256 weight01 = _mm256_load_ps(weight + 8);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(src01, weight01, dst0);
+    __m256 src11 = _mm256_set1_ps(*(src + 9));
+    dst1 = _mm256_fmadd_ps(src11, weight01, dst1);
+    __m256 src21 = _mm256_set1_ps(*(src + 17));
+    dst2 = _mm256_fmadd_ps(src21, weight01, dst2);
+    __m256 src31 = _mm256_set1_ps(*(src + 25));
+    dst3 = _mm256_fmadd_ps(src31, weight01, dst3);
+    __m256 src41 = _mm256_set1_ps(*(src + 33));
+    dst4 = _mm256_fmadd_ps(src41, weight01, dst4);
+    // block2
+    __m256 weight02 = _mm256_load_ps(weight + 16);
+    __m256 src02 = _mm256_set1_ps(*(src + 2));
+    dst0 = _mm256_fmadd_ps(src02, weight02, dst0);
+    __m256 src12 = _mm256_set1_ps(*(src + 10));
+    dst1 = _mm256_fmadd_ps(src12, weight02, dst1);
+    __m256 src22 = _mm256_set1_ps(*(src + 18));
+    dst2 = _mm256_fmadd_ps(src22, weight02, dst2);
+    __m256 src32 = _mm256_set1_ps(*(src + 26));
+    dst3 = _mm256_fmadd_ps(src32, weight02, dst3);
+    __m256 src42 = _mm256_set1_ps(*(src + 34));
+    dst4 = _mm256_fmadd_ps(src42, weight02, dst4);
+    // block3
+    __m256 weight03 = _mm256_load_ps(weight + 24);
+    __m256 src03 = _mm256_set1_ps(*(src + 3));
+    dst0 = _mm256_fmadd_ps(src03, weight03, dst0);
+    __m256 src13 = _mm256_set1_ps(*(src + 11));
+    dst1 = _mm256_fmadd_ps(src13, weight03, dst1);
+    __m256 src23 = _mm256_set1_ps(*(src + 19));
+    dst2 = _mm256_fmadd_ps(src23, weight03, dst2);
+    __m256 src33 = _mm256_set1_ps(*(src + 27));
+    dst3 = _mm256_fmadd_ps(src33, weight03, dst3);
+    __m256 src43 = _mm256_set1_ps(*(src + 35));
+    dst4 = _mm256_fmadd_ps(src43, weight03, dst4);
+    // block4
+    __m256 weight04 = _mm256_load_ps(weight + 32);
+    __m256 src04 = _mm256_set1_ps(*(src + 4));
+    dst0 = _mm256_fmadd_ps(src04, weight04, dst0);
+    __m256 src14 = _mm256_set1_ps(*(src + 12));
+    dst1 = _mm256_fmadd_ps(src14, weight04, dst1);
+    __m256 src24 = _mm256_set1_ps(*(src + 20));
+    dst2 = _mm256_fmadd_ps(src24, weight04, dst2);
+    __m256 src34 = _mm256_set1_ps(*(src + 28));
+    dst3 = _mm256_fmadd_ps(src34, weight04, dst3);
+    __m256 src44 = _mm256_set1_ps(*(src + 36));
+    dst4 = _mm256_fmadd_ps(src44, weight04, dst4);
+    // block5
+    __m256 weight05 = _mm256_load_ps(weight + 40);
+    __m256 src05 = _mm256_set1_ps(*(src + 5));
+    dst0 = _mm256_fmadd_ps(src05, weight05, dst0);
+    __m256 src15 = _mm256_set1_ps(*(src + 13));
+    dst1 = _mm256_fmadd_ps(src15, weight05, dst1);
+    __m256 src25 = _mm256_set1_ps(*(src + 21));
+    dst2 = _mm256_fmadd_ps(src25, weight05, dst2);
+    __m256 src35 = _mm256_set1_ps(*(src + 29));
+    dst3 = _mm256_fmadd_ps(src35, weight05, dst3);
+    __m256 src45 = _mm256_set1_ps(*(src + 37));
+    dst4 = _mm256_fmadd_ps(src45, weight05, dst4);
+    // block6
+    __m256 weight06 = _mm256_load_ps(weight + 48);
+    __m256 src06 = _mm256_set1_ps(*(src + 6));
+    dst0 = _mm256_fmadd_ps(src06, weight06, dst0);
+    __m256 src16 = _mm256_set1_ps(*(src + 14));
+    dst1 = _mm256_fmadd_ps(src16, weight06, dst1);
+    __m256 src26 = _mm256_set1_ps(*(src + 22));
+    dst2 = _mm256_fmadd_ps(src26, weight06, dst2);
+    __m256 src36 = _mm256_set1_ps(*(src + 30));
+    dst3 = _mm256_fmadd_ps(src36, weight06, dst3);
+    __m256 src46 = _mm256_set1_ps(*(src + 38));
+    dst4 = _mm256_fmadd_ps(src46, weight06, dst4);
+    // block7
+    __m256 weight07 = _mm256_load_ps(weight + 56);
+    __m256 src07 = _mm256_set1_ps(*(src + 7));
+    dst0 = _mm256_fmadd_ps(src07, weight07, dst0);
+    __m256 src17 = _mm256_set1_ps(*(src + 15));
+    dst1 = _mm256_fmadd_ps(src17, weight07, dst1);
+    __m256 src27 = _mm256_set1_ps(*(src + 23));
+    dst2 = _mm256_fmadd_ps(src27, weight07, dst2);
+    __m256 src37 = _mm256_set1_ps(*(src + 31));
+    dst3 = _mm256_fmadd_ps(src37, weight07, dst3);
+    __m256 src47 = _mm256_set1_ps(*(src + 39));
+    dst4 = _mm256_fmadd_ps(src47, weight07, dst4);
+    src = src + src_stride;
+    weight += 64;  // 8 depth values x 8 floats were consumed; the pointer counts floats, not bytes
+  }
+  if (act_flag & 0x02) {
+    // relu6
+    __m256 relu6 = _mm256_set1_ps(6.0f);
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_min_ps(dst0, relu6);
+    dst1 = _mm256_min_ps(dst1, relu6);
+    dst2 = _mm256_min_ps(dst2, relu6);
+    dst3 = _mm256_min_ps(dst3, relu6);
+    dst4 = _mm256_min_ps(dst4, relu6);
+    // relu
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+    dst2 = _mm256_max_ps(dst2, relu);
+    dst3 = _mm256_max_ps(dst3, relu);
+    dst4 = _mm256_max_ps(dst4, relu);
+  }
+  if (act_flag & 0x01) {
+    // relu
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+    dst2 = _mm256_max_ps(dst2, relu);
+    dst3 = _mm256_max_ps(dst3, relu);
+    dst4 = _mm256_max_ps(dst4, relu);
+  }
+  _mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
+  _mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
+  _mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
+  _mm256_store_ps(dst + 0 * dst_stride + 24, dst3);
+  _mm256_store_ps(dst + 0 * dst_stride + 32, dst4);
+}
diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c
new file mode 100644
index 00000000000..069f81ee1a8
--- /dev/null
+++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c
@@ -0,0 +1,193 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma asm code: computes one 5x8 fp32 tile, dst = act(src * weight + seed), in ymm0..ymm4.
+// seed is the current dst tile when inc_flag bit 0x1 is set, else bias[0..7] per row, or zero when bias is NULL.
+// Activation needs inc_flag bit 0x2: act_flag & 0x3 enables relu, and act_flag bit 0x1 then adds the clamp to 6.
+// NOTE(review): the depth loop tests its counter after the body, so deep < 8 still runs one step - confirm callers.
+// The whole kernel is one asm statement so the accumulators cannot be disturbed between seeding and the loop.
+void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                           const size_t act_flag, const size_t row_block, const size_t col_block,
+                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                           const size_t inc_flag) {
+  size_t deep_t = deep >> 3;
+  size_t dst_stride_t = dst_stride << 2;
+  size_t src_stride_t = src_stride << 2;
+  asm volatile(
+    // inc in deep
+    "testq $0x1, %[inc_flag]\n"
+    "je 0f\n"
+    "vmovups 0(%[dst]), %%ymm0\n"
+    "vmovups 32(%[dst]), %%ymm1\n"
+    "vmovups 64(%[dst]), %%ymm2\n"
+    "vmovups 96(%[dst]), %%ymm3\n"
+    "vmovups 128(%[dst]), %%ymm4\n"
+    "jmp 2f\n"
+    "0:\n"
+    "cmpq $0, %[bias]\n"
+    "je 1f\n"
+    "vmovaps 0(%[bias]), %%ymm0\n"
+    "vmovaps 0(%[bias]), %%ymm1\n"
+    "vmovaps 0(%[bias]), %%ymm2\n"
+    "vmovaps 0(%[bias]), %%ymm3\n"
+    "vmovaps 0(%[bias]), %%ymm4\n"
+    "jmp 2f\n"
+    "1:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
+    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
+    "vxorps %%ymm4, %%ymm4, %%ymm4\n"
+    "2:\n"
+    // depth loop: one pass consumes 8 depth values (block 0-7); src offsets are in bytes (4 per float)
+    // block 0
+    "vmovaps 0(%[weight]), %%ymm15\n"
+    "vbroadcastss 0(%[src]), %%ymm14\n"
+    "vbroadcastss 32(%[src]), %%ymm13\n"
+    "vbroadcastss 64(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 96(%[src]), %%ymm14\n"
+    "vbroadcastss 128(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    // block 1
+    "vmovaps 32(%[weight]), %%ymm15\n"
+    "vbroadcastss 4(%[src]), %%ymm14\n"
+    "vbroadcastss 36(%[src]), %%ymm13\n"
+    "vbroadcastss 68(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 100(%[src]), %%ymm14\n"
+    "vbroadcastss 132(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    // block 2
+    "vmovaps 64(%[weight]), %%ymm15\n"
+    "vbroadcastss 8(%[src]), %%ymm14\n"
+    "vbroadcastss 40(%[src]), %%ymm13\n"
+    "vbroadcastss 72(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 104(%[src]), %%ymm14\n"
+    "vbroadcastss 136(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    // block 3
+    "vmovaps 96(%[weight]), %%ymm15\n"
+    "vbroadcastss 12(%[src]), %%ymm14\n"
+    "vbroadcastss 44(%[src]), %%ymm13\n"
+    "vbroadcastss 76(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 108(%[src]), %%ymm14\n"
+    "vbroadcastss 140(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    // block 4
+    "vmovaps 128(%[weight]), %%ymm15\n"
+    "vbroadcastss 16(%[src]), %%ymm14\n"
+    "vbroadcastss 48(%[src]), %%ymm13\n"
+    "vbroadcastss 80(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 112(%[src]), %%ymm14\n"
+    "vbroadcastss 144(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    // block 5
+    "vmovaps 160(%[weight]), %%ymm15\n"
+    "vbroadcastss 20(%[src]), %%ymm14\n"
+    "vbroadcastss 52(%[src]), %%ymm13\n"
+    "vbroadcastss 84(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 116(%[src]), %%ymm14\n"
+    "vbroadcastss 148(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    // block 6
+    "vmovaps 192(%[weight]), %%ymm15\n"
+    "vbroadcastss 24(%[src]), %%ymm14\n"
+    "vbroadcastss 56(%[src]), %%ymm13\n"
+    "vbroadcastss 88(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 120(%[src]), %%ymm14\n"
+    "vbroadcastss 152(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    // block 7
+    "vmovaps 224(%[weight]), %%ymm15\n"
+    "vbroadcastss 28(%[src]), %%ymm14\n"
+    "vbroadcastss 60(%[src]), %%ymm13\n"
+    "vbroadcastss 92(%[src]), %%ymm12\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm0\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm1\n"
+    "vfmadd231ps %%ymm15, %%ymm12, %%ymm2\n"
+    "vbroadcastss 124(%[src]), %%ymm14\n"
+    "vbroadcastss 156(%[src]), %%ymm13\n"
+    "vfmadd231ps %%ymm15, %%ymm14, %%ymm3\n"
+    "vfmadd231ps %%ymm15, %%ymm13, %%ymm4\n"
+    "add $256, %[weight]\n"
+    "add %[src_stride], %[src]\n"
+    "dec %[deep]\n"  // must be the last flag-setting instruction before jg
+    "jg 2b\n"
+
+    "movq %[inc_flag], %%rax\n"
+    "and $0x2, %%eax\n"
+    "je 3f\n"
+    "movq %[act_flag], %%rax\n"
+    "and $0x3, %%eax\n"
+    "je 3f\n"
+    // relu
+    "vxorps %%ymm15, %%ymm15, %%ymm15\n"
+    "vmaxps %%ymm0, %%ymm15, %%ymm0\n"
+    "vmaxps %%ymm1, %%ymm15, %%ymm1\n"
+    "vmaxps %%ymm2, %%ymm15, %%ymm2\n"
+    "vmaxps %%ymm3, %%ymm15, %%ymm3\n"
+    "vmaxps %%ymm4, %%ymm15, %%ymm4\n"
+    "and $0x1, %%eax\n"
+    "je 3f\n"
+    // relu6
+    "mov $0x40C00000, %%eax\n"
+    "vmovd %%eax, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm15, %%ymm14\n"
+    "vminps %%ymm0, %%ymm14, %%ymm0\n"
+    "vminps %%ymm1, %%ymm14, %%ymm1\n"
+    "vminps %%ymm2, %%ymm14, %%ymm2\n"
+    "vminps %%ymm3, %%ymm14, %%ymm3\n"
+    "vminps %%ymm4, %%ymm14, %%ymm4\n"
+    "3:\n"
+    "vmovups %%ymm0, 0(%[dst])\n"
+    "vmovups %%ymm1, 32(%[dst])\n"
+    "vmovups %%ymm2, 64(%[dst])\n"
+    "vmovups %%ymm3, 96(%[dst])\n"
+    "vmovups %%ymm4, 128(%[dst])\n"
+    : [ src ] "+&r"(src), [ weight ] "+&r"(weight), [ deep ] "+&r"(deep_t)
+    : [ src_stride ] "r"(src_stride_t), [ bias ] "r"(bias), [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag),
+      [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
+    : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
+      "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15", "cc", "memory");
+}
diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
new file mode 100644
index 00000000000..5e074d19f9a
--- /dev/null
+++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
@@ -0,0 +1,305 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <x86intrin.h>
+
+// nnacl gemm in x86 fma intrinsic code: computes one 6x16 fp32 tile, dst = act(src * weight + seed).
+// The tile is held as two 8-column halves: dst0..dst5 are rows 0..5 of the first half (at dst + 8 * r) and
+// dst6..dst11 are the same rows of the second half (at dst + dst_stride + 8 * r).
+// Row r reads src[8 * r + k] for depth k; weight + 16 * k holds the 8 first-half weights for depth k followed
+// by the 8 second-half weights. Each loop step uses 8 depth values, then src advances by src_stride floats.
+// seed is the current dst tile when inc_flag is non-zero, else bias[0..7] / bias[8..15] for the two halves,
+// or zero when bias is NULL.
+// act_flag bit 0x02 clamps to [0, 6] (relu6), bit 0x01 applies relu; row_block and col_block are not used here.
+// All loads and stores use the aligned forms, so dst, weight and bias offsets must be 32-byte aligned.
+void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
+                                            const size_t act_flag, const size_t row_block, const size_t col_block,
+                                            const size_t deep, const size_t src_stride, const size_t dst_stride,
+                                            const size_t inc_flag) {
+  __m256 dst0, dst1, dst2, dst3, dst4, dst5;
+  __m256 dst6, dst7, dst8, dst9, dst10, dst11;
+  if (inc_flag) {
+    dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
+    dst6 = _mm256_load_ps(dst + 1 * dst_stride + 0);
+    dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
+    dst7 = _mm256_load_ps(dst + 1 * dst_stride + 8);
+    dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
+    dst8 = _mm256_load_ps(dst + 1 * dst_stride + 16);
+    dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
+    dst9 = _mm256_load_ps(dst + 1 * dst_stride + 24);
+    dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
+    dst10 = _mm256_load_ps(dst + 1 * dst_stride + 32);
+    dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
+    dst11 = _mm256_load_ps(dst + 1 * dst_stride + 40);
+  } else if (bias == NULL) {
+    dst0 = _mm256_setzero_ps();
+    dst1 = _mm256_setzero_ps();
+    dst2 = _mm256_setzero_ps();
+    dst3 = _mm256_setzero_ps();
+    dst4 = _mm256_setzero_ps();
+    dst5 = _mm256_setzero_ps();
+    dst6 = _mm256_setzero_ps();
+    dst7 = _mm256_setzero_ps();
+    dst8 = _mm256_setzero_ps();
+    dst9 = _mm256_setzero_ps();
+    dst10 = _mm256_setzero_ps();
+    dst11 = _mm256_setzero_ps();
+  } else {
+    dst0 = _mm256_load_ps(bias + 0);
+    dst6 = _mm256_load_ps(bias + 8);
+    dst1 = _mm256_load_ps(bias + 0);
+    dst7 = _mm256_load_ps(bias + 8);
+    dst2 = _mm256_load_ps(bias + 0);
+    dst8 = _mm256_load_ps(bias + 8);
+    dst3 = _mm256_load_ps(bias + 0);
+    dst9 = _mm256_load_ps(bias + 8);
+    dst4 = _mm256_load_ps(bias + 0);
+    dst10 = _mm256_load_ps(bias + 8);
+    dst5 = _mm256_load_ps(bias + 0);
+    dst11 = _mm256_load_ps(bias + 8);
+  }
+  // accumulate src * weight over the depth dimension, 8 depth values per step
+  for (size_t i = 0; i < (deep >> 3); ++i) {
+    // block0
+    __m256 weight00 = _mm256_load_ps(weight + 0);
+    __m256 weight10 = _mm256_load_ps(weight + 8);
+    __m256 src00 = _mm256_set1_ps(*(src + 0));
+    dst0 = _mm256_fmadd_ps(src00, weight00, dst0);
+    dst6 = _mm256_fmadd_ps(src00, weight10, dst6);
+    __m256 src10 = _mm256_set1_ps(*(src + 8));
+    dst1 = _mm256_fmadd_ps(src10, weight00, dst1);
+    dst7 = _mm256_fmadd_ps(src10, weight10, dst7);
+    __m256 src20 = _mm256_set1_ps(*(src + 16));
+    dst2 = _mm256_fmadd_ps(src20, weight00, dst2);
+    dst8 = _mm256_fmadd_ps(src20, weight10, dst8);
+    __m256 src30 = _mm256_set1_ps(*(src + 24));
+    dst3 = _mm256_fmadd_ps(src30, weight00, dst3);
+    dst9 = _mm256_fmadd_ps(src30, weight10, dst9);
+    __m256 src40 = _mm256_set1_ps(*(src + 32));
+    dst4 = _mm256_fmadd_ps(src40, weight00, dst4);
+    dst10 = _mm256_fmadd_ps(src40, weight10, dst10);
+    __m256 src50 = _mm256_set1_ps(*(src + 40));
+    dst5 = _mm256_fmadd_ps(src50, weight00, dst5);
+    dst11 = _mm256_fmadd_ps(src50, weight10, dst11);
+    // block1
+    __m256 weight01 = _mm256_load_ps(weight + 16);
+    __m256 weight11 = _mm256_load_ps(weight + 24);
+    __m256 src01 = _mm256_set1_ps(*(src + 1));
+    dst0 = _mm256_fmadd_ps(src01, weight01, dst0);
+    dst6 = _mm256_fmadd_ps(src01, weight11, dst6);
+    __m256 src11 = _mm256_set1_ps(*(src + 9));
+    dst1 = _mm256_fmadd_ps(src11, weight01, dst1);
+    dst7 = _mm256_fmadd_ps(src11, weight11, dst7);
+    __m256 src21 = _mm256_set1_ps(*(src + 17));
+    dst2 = _mm256_fmadd_ps(src21, weight01, dst2);
+    dst8 = _mm256_fmadd_ps(src21, weight11, dst8);
+    __m256 src31 = _mm256_set1_ps(*(src + 25));
+    dst3 = _mm256_fmadd_ps(src31, weight01, dst3);
+    dst9 = _mm256_fmadd_ps(src31, weight11, dst9);
+    __m256 src41 = _mm256_set1_ps(*(src + 33));
+    dst4 = _mm256_fmadd_ps(src41, weight01, dst4);
+    dst10 = _mm256_fmadd_ps(src41, weight11, dst10);
+    __m256 src51 = _mm256_set1_ps(*(src + 41));
+    dst5 = _mm256_fmadd_ps(src51, weight01, dst5);
+    dst11 = _mm256_fmadd_ps(src51, weight11, dst11);
+    // block2
+    __m256 weight02 = _mm256_load_ps(weight + 32);
+    __m256 weight12 = _mm256_load_ps(weight + 40);
+    __m256 src02 = _mm256_set1_ps(*(src + 2));
+    dst0 = _mm256_fmadd_ps(src02, weight02, dst0);
+    dst6 = _mm256_fmadd_ps(src02, weight12, dst6);
+    __m256 src12 = _mm256_set1_ps(*(src + 10));
+    dst1 = _mm256_fmadd_ps(src12, weight02, dst1);
+    dst7 = _mm256_fmadd_ps(src12, weight12, dst7);
+    __m256 src22 = _mm256_set1_ps(*(src + 18));
+    dst2 = _mm256_fmadd_ps(src22, weight02, dst2);
+    dst8 = _mm256_fmadd_ps(src22, weight12, dst8);
+    __m256 src32 = _mm256_set1_ps(*(src + 26));
+    dst3 = _mm256_fmadd_ps(src32, weight02, dst3);
+    dst9 = _mm256_fmadd_ps(src32, weight12, dst9);
+    __m256 src42 = _mm256_set1_ps(*(src + 34));
+    dst4 = _mm256_fmadd_ps(src42, weight02, dst4);
+    dst10 = _mm256_fmadd_ps(src42, weight12, dst10);
+    __m256 src52 = _mm256_set1_ps(*(src + 42));
+    dst5 = _mm256_fmadd_ps(src52, weight02, dst5);
+    dst11 = _mm256_fmadd_ps(src52, weight12, dst11);
+    // block3
+    __m256 weight03 = _mm256_load_ps(weight + 48);
+    __m256 weight13 = _mm256_load_ps(weight + 56);
+    __m256 src03 = _mm256_set1_ps(*(src + 3));
+    dst0 = _mm256_fmadd_ps(src03, weight03, dst0);
+    dst6 = _mm256_fmadd_ps(src03, weight13, dst6);
+    __m256 src13 = _mm256_set1_ps(*(src + 11));
+    dst1 = _mm256_fmadd_ps(src13, weight03, dst1);
+    dst7 = _mm256_fmadd_ps(src13, weight13, dst7);
+    __m256 src23 = _mm256_set1_ps(*(src + 19));
+    dst2 = _mm256_fmadd_ps(src23, weight03, dst2);
+    dst8 = _mm256_fmadd_ps(src23, weight13, dst8);
+    __m256 src33 = _mm256_set1_ps(*(src + 27));
+    dst3 = _mm256_fmadd_ps(src33, weight03, dst3);
+    dst9 = _mm256_fmadd_ps(src33, weight13, dst9);
+    __m256 src43 = _mm256_set1_ps(*(src + 35));
+    dst4 = _mm256_fmadd_ps(src43, weight03, dst4);
+    dst10 = _mm256_fmadd_ps(src43, weight13, dst10);
+    __m256 src53 = _mm256_set1_ps(*(src + 43));
+    dst5 = _mm256_fmadd_ps(src53, weight03, dst5);
+    dst11 = _mm256_fmadd_ps(src53, weight13, dst11);
+    // block4
+    __m256 weight04 = _mm256_load_ps(weight + 64);
+    __m256 weight14 = _mm256_load_ps(weight + 72);
+    __m256 src04 = _mm256_set1_ps(*(src + 4));
+    dst0 = _mm256_fmadd_ps(src04, weight04, dst0);
+    dst6 = _mm256_fmadd_ps(src04, weight14, dst6);
+    __m256 src14 = _mm256_set1_ps(*(src + 12));
+    dst1 = _mm256_fmadd_ps(src14, weight04, dst1);
+    dst7 = _mm256_fmadd_ps(src14, weight14, dst7);
+    __m256 src24 = _mm256_set1_ps(*(src + 20));
+    dst2 = _mm256_fmadd_ps(src24, weight04, dst2);
+    dst8 = _mm256_fmadd_ps(src24, weight14, dst8);
+    __m256 src34 = _mm256_set1_ps(*(src + 28));
+    dst3 = _mm256_fmadd_ps(src34, weight04, dst3);
+    dst9 = _mm256_fmadd_ps(src34, weight14, dst9);
+    __m256 src44 = _mm256_set1_ps(*(src + 36));
+    dst4 = _mm256_fmadd_ps(src44, weight04, dst4);
+    dst10 = _mm256_fmadd_ps(src44, weight14, dst10);
+    __m256 src54 = _mm256_set1_ps(*(src + 44));
+    dst5 = _mm256_fmadd_ps(src54, weight04, dst5);
+    dst11 = _mm256_fmadd_ps(src54, weight14, dst11);
+    // block5
+    __m256 weight05 = _mm256_load_ps(weight + 80);
+    __m256 weight15 = _mm256_load_ps(weight + 88);
+    __m256 src05 = _mm256_set1_ps(*(src + 5));
+    dst0 = _mm256_fmadd_ps(src05, weight05, dst0);
+    dst6 = _mm256_fmadd_ps(src05, weight15, dst6);
+    __m256 src15 = _mm256_set1_ps(*(src + 13));
+    dst1 = _mm256_fmadd_ps(src15, weight05, dst1);
+    dst7 = _mm256_fmadd_ps(src15, weight15, dst7);
+    __m256 src25 = _mm256_set1_ps(*(src + 21));
+    dst2 = _mm256_fmadd_ps(src25, weight05, dst2);
+    dst8 = _mm256_fmadd_ps(src25, weight15, dst8);
+    __m256 src35 = _mm256_set1_ps(*(src + 29));
+    dst3 = _mm256_fmadd_ps(src35, weight05, dst3);
+    dst9 = _mm256_fmadd_ps(src35, weight15, dst9);
+    __m256 src45 = _mm256_set1_ps(*(src + 37));
+    dst4 = _mm256_fmadd_ps(src45, weight05, dst4);
+    dst10 = _mm256_fmadd_ps(src45, weight15, dst10);
+    __m256 src55 = _mm256_set1_ps(*(src + 45));
+    dst5 = _mm256_fmadd_ps(src55, weight05, dst5);
+    dst11 = _mm256_fmadd_ps(src55, weight15, dst11);
+    // block6
+    __m256 weight06 = _mm256_load_ps(weight + 96);
+    __m256 weight16 = _mm256_load_ps(weight + 104);
+    __m256 src06 = _mm256_set1_ps(*(src + 6));
+    dst0 = _mm256_fmadd_ps(src06, weight06, dst0);
+    dst6 = _mm256_fmadd_ps(src06, weight16, dst6);
+    __m256 src16 = _mm256_set1_ps(*(src + 14));
+    dst1 = _mm256_fmadd_ps(src16, weight06, dst1);
+    dst7 = _mm256_fmadd_ps(src16, weight16, dst7);
+    __m256 src26 = _mm256_set1_ps(*(src + 22));
+    dst2 = _mm256_fmadd_ps(src26, weight06, dst2);
+    dst8 = _mm256_fmadd_ps(src26, weight16, dst8);
+    __m256 src36 = _mm256_set1_ps(*(src + 30));
+    dst3 = _mm256_fmadd_ps(src36, weight06, dst3);
+    dst9 = _mm256_fmadd_ps(src36, weight16, dst9);
+    __m256 src46 = _mm256_set1_ps(*(src + 38));
+    dst4 = _mm256_fmadd_ps(src46, weight06, dst4);
+    dst10 = _mm256_fmadd_ps(src46, weight16, dst10);
+    __m256 src56 = _mm256_set1_ps(*(src + 46));
+    dst5 = _mm256_fmadd_ps(src56, weight06, dst5);
+    dst11 = _mm256_fmadd_ps(src56, weight16, dst11);
+    // block7
+    __m256 weight07 = _mm256_load_ps(weight + 112);
+    __m256 weight17 = _mm256_load_ps(weight + 120);
+    __m256 src07 = _mm256_set1_ps(*(src + 7));
+    dst0 = _mm256_fmadd_ps(src07, weight07, dst0);
+    dst6 = _mm256_fmadd_ps(src07, weight17, dst6);
+    __m256 src17 = _mm256_set1_ps(*(src + 15));
+    dst1 = _mm256_fmadd_ps(src17, weight07, dst1);
+    dst7 = _mm256_fmadd_ps(src17, weight17, dst7);
+    __m256 src27 = _mm256_set1_ps(*(src + 23));
+    dst2 = _mm256_fmadd_ps(src27, weight07, dst2);
+    dst8 = _mm256_fmadd_ps(src27, weight17, dst8);
+    __m256 src37 = _mm256_set1_ps(*(src + 31));
+    dst3 = _mm256_fmadd_ps(src37, weight07, dst3);
+    dst9 = _mm256_fmadd_ps(src37, weight17, dst9);
+    __m256 src47 = _mm256_set1_ps(*(src + 39));
+    dst4 = _mm256_fmadd_ps(src47, weight07, dst4);
+    dst10 = _mm256_fmadd_ps(src47, weight17, dst10);
+    __m256 src57 = _mm256_set1_ps(*(src + 47));
+    dst5 = _mm256_fmadd_ps(src57, weight07, dst5);
+    dst11 = _mm256_fmadd_ps(src57, weight17, dst11);
+    src = src + src_stride;
+    weight += 128;  // 8 depth values x 16 floats were consumed; the pointer counts floats, not bytes
+  }
+  if (act_flag & 0x02) {
+    // relu6
+    __m256 relu6 = _mm256_set1_ps(6.0f);
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_min_ps(dst0, relu6);
+    dst6 = _mm256_min_ps(dst6, relu6);
+    dst1 = _mm256_min_ps(dst1, relu6);
+    dst7 = _mm256_min_ps(dst7, relu6);
+    dst2 = _mm256_min_ps(dst2, relu6);
+    dst8 = _mm256_min_ps(dst8, relu6);
+    dst3 = _mm256_min_ps(dst3, relu6);
+    dst9 = _mm256_min_ps(dst9, relu6);
+    dst4 = _mm256_min_ps(dst4, relu6);
+    dst10 = _mm256_min_ps(dst10, relu6);
+    dst5 = _mm256_min_ps(dst5, relu6);
+    dst11 = _mm256_min_ps(dst11, relu6);
+    // relu
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst6 = _mm256_max_ps(dst6, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+    dst7 = _mm256_max_ps(dst7, relu);
+    dst2 = _mm256_max_ps(dst2, relu);
+    dst8 = _mm256_max_ps(dst8, relu);
+    dst3 = _mm256_max_ps(dst3, relu);
+    dst9 = _mm256_max_ps(dst9, relu);
+    dst4 = _mm256_max_ps(dst4, relu);
+    dst10 = _mm256_max_ps(dst10, relu);
+    dst5 = _mm256_max_ps(dst5, relu);
+    dst11 = _mm256_max_ps(dst11, relu);
+  }
+  if (act_flag & 0x01) {
+    // relu
+    __m256 relu = _mm256_setzero_ps();
+    dst0 = _mm256_max_ps(dst0, relu);
+    dst6 = _mm256_max_ps(dst6, relu);
+    dst1 = _mm256_max_ps(dst1, relu);
+    dst7 = _mm256_max_ps(dst7, relu);
+    dst2 = _mm256_max_ps(dst2, relu);
+    dst8 = _mm256_max_ps(dst8, relu);
+    dst3 = _mm256_max_ps(dst3, relu);
+    dst9 = _mm256_max_ps(dst9, relu);
+    dst4 = _mm256_max_ps(dst4, relu);
+    dst10 = _mm256_max_ps(dst10, relu);
+    dst5 = _mm256_max_ps(dst5, relu);
+    dst11 = _mm256_max_ps(dst11, relu);
+  }
+  // write the tile back with dst_stride, the same addressing the inc_flag loads use
+  _mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
+  _mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
+  _mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
+  _mm256_store_ps(dst + 0 * dst_stride + 24, dst3);
+  _mm256_store_ps(dst + 0 * dst_stride + 32, dst4);
+  _mm256_store_ps(dst + 0 * dst_stride + 40, dst5);
+  _mm256_store_ps(dst + 1 * dst_stride + 0, dst6);
+  _mm256_store_ps(dst + 1 * dst_stride + 8, dst7);
+  _mm256_store_ps(dst + 1 * dst_stride + 16, dst8);
+  _mm256_store_ps(dst + 1 * dst_stride + 24, dst9);
+  _mm256_store_ps(dst + 1 * dst_stride + 32, dst10);
+  _mm256_store_ps(dst + 1 * dst_stride + 40, dst11);
+}
diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c
new file mode 100644
index 00000000000..bb47361dd31
--- /dev/null
+++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c
@@ -0,0 +1,307 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 0(%[dst], %[dst_stride], 1), %%ymm6\n" + "vmovups 32(%[dst], %[dst_stride], 1), %%ymm7\n" + "vmovups 64(%[dst], %[dst_stride], 1), %%ymm8\n" + "vmovups 96(%[dst], %[dst_stride], 1), %%ymm9\n" + "vmovups 128(%[dst], %[dst_stride], 1), %%ymm10\n" + "vmovups 160(%[dst], %[dst_stride], 1), %%ymm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 32(%[bias]), %%ymm6\n" + "vmovaps 32(%[bias]), %%ymm7\n" + "vmovaps 32(%[bias]), %%ymm8\n" + "vmovaps 32(%[bias]), %%ymm9\n" + "vmovaps 32(%[bias]), %%ymm10\n" + "vmovaps 32(%[bias]), %%ymm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "vxorps %%ymm9, %%ymm9, %%ymm9\n" + "vxorps %%ymm10, %%ymm10, %%ymm10\n" + "vxorps %%ymm11, %%ymm11, %%ymm11\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] 
"r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vmovaps 32(%[weight]), %%ymm14\n" + "vbroadcastss 0(%[src]), %%ymm13\n" + "vbroadcastss 32(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 64(%[src]), %%ymm13\n" + "vbroadcastss 96(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 1 + "vmovaps 64(%[weight]), %%ymm15\n" + "vmovaps 96(%[weight]), %%ymm14\n" + "vbroadcastss 1(%[src]), %%ymm13\n" + "vbroadcastss 33(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 65(%[src]), %%ymm13\n" + "vbroadcastss 97(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 2 + "vmovaps 128(%[weight]), %%ymm15\n" + "vmovaps 160(%[weight]), %%ymm14\n" + "vbroadcastss 
2(%[src]), %%ymm13\n" + "vbroadcastss 34(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 66(%[src]), %%ymm13\n" + "vbroadcastss 98(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 3 + "vmovaps 192(%[weight]), %%ymm15\n" + "vmovaps 224(%[weight]), %%ymm14\n" + "vbroadcastss 3(%[src]), %%ymm13\n" + "vbroadcastss 35(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 67(%[src]), %%ymm13\n" + "vbroadcastss 99(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 4 + "vmovaps 256(%[weight]), %%ymm15\n" + "vmovaps 288(%[weight]), %%ymm14\n" + "vbroadcastss 4(%[src]), %%ymm13\n" + "vbroadcastss 36(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 68(%[src]), %%ymm13\n" + "vbroadcastss 100(%[src]), %%ymm12\n" + 
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 5 + "vmovaps 320(%[weight]), %%ymm15\n" + "vmovaps 352(%[weight]), %%ymm14\n" + "vbroadcastss 5(%[src]), %%ymm13\n" + "vbroadcastss 37(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 69(%[src]), %%ymm13\n" + "vbroadcastss 101(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 6 + "vmovaps 384(%[weight]), %%ymm15\n" + "vmovaps 416(%[weight]), %%ymm14\n" + "vbroadcastss 6(%[src]), %%ymm13\n" + "vbroadcastss 38(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 70(%[src]), %%ymm13\n" + "vbroadcastss 102(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps 
%%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + // block 7 + "vmovaps 448(%[weight]), %%ymm15\n" + "vmovaps 480(%[weight]), %%ymm14\n" + "vbroadcastss 7(%[src]), %%ymm13\n" + "vbroadcastss 39(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n" + "vbroadcastss 71(%[src]), %%ymm13\n" + "vbroadcastss 103(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n" + "vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n" + "dec %[deep]\n" + "add 512, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "vmaxps %%ymm9, %%ymm15, %%ymm9\n" + "vmaxps %%ymm10, %%ymm15, %%ymm10\n" + "vmaxps %%ymm11, %%ymm15, %%ymm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" 
+ "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "vminps %%ymm9, %%ymm14, %%ymm9\n" + "vminps %%ymm10, %%ymm14, %%ymm10\n" + "vminps %%ymm11, %%ymm14, %%ymm11\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 0(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm7, 32(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm8, 64(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm9, 96(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm10, 128(%[dst], %[dst_stride], 1)\n" + "vmovups %%ymm11, 160(%[dst], %[dst_stride], 1)\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..7efdcd8d0e0 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,201 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, 
weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 
src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = 
_mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); +} diff --git 
a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..7332ad3917a --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,215 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "jmp 2f\n" + "1:\n" + "vxorps 
%%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + 
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), 
%%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] 
"r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..e99e39da7eb --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,225 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = _mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + 
__m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + // bock3 + __m256 weight03 = _mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = 
_mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = 
_mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = 
_mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..2d57354543c --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,237 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps 
%%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + 
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 
7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", 
"%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..e11e2af8412 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,249 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + __m256 dst0; + __m256 dst1; + __m256 dst2; + __m256 dst3; + __m256 dst4; + __m256 dst5; + __m256 dst6; + __m256 dst7; + if (inc_flag) { + dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0); + dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8); + dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16); + dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24); + dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32); + dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40); + dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48); + dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56); + } else if (bias == NULL) { + dst0 = _mm256_setzero_ps(); + dst1 = _mm256_setzero_ps(); + dst2 = _mm256_setzero_ps(); + dst3 = _mm256_setzero_ps(); + dst4 = _mm256_setzero_ps(); + dst5 = _mm256_setzero_ps(); + dst6 = _mm256_setzero_ps(); + dst7 = _mm256_setzero_ps(); + } else { + dst0 = _mm256_load_ps(bias + 0); + dst1 = _mm256_load_ps(bias + 0); + dst2 = _mm256_load_ps(bias + 0); + dst3 = _mm256_load_ps(bias + 0); + dst4 = _mm256_load_ps(bias + 0); + dst5 = _mm256_load_ps(bias + 0); + dst6 = _mm256_load_ps(bias + 0); + dst7 = _mm256_load_ps(bias + 0); + } + for (int i = 0; i < (deep >> 3); ++i) { + // bock0 + __m256 weight00 = _mm256_load_ps(weight + 0); + __m256 src00 = _mm256_set1_ps(*(src + 0)); + dst0 = _mm256_fmadd_ps(dst0, src00, weight00); + __m256 src10 = _mm256_set1_ps(*(src + 8)); + dst1 = _mm256_fmadd_ps(dst1, src10, weight00); + __m256 src20 = _mm256_set1_ps(*(src + 16)); + dst2 = _mm256_fmadd_ps(dst2, src20, weight00); + __m256 src30 = _mm256_set1_ps(*(src + 24)); + dst3 = _mm256_fmadd_ps(dst3, src30, weight00); + __m256 src40 = _mm256_set1_ps(*(src + 32)); + dst4 = 
_mm256_fmadd_ps(dst4, src40, weight00); + __m256 src50 = _mm256_set1_ps(*(src + 40)); + dst5 = _mm256_fmadd_ps(dst5, src50, weight00); + __m256 src60 = _mm256_set1_ps(*(src + 48)); + dst6 = _mm256_fmadd_ps(dst6, src60, weight00); + __m256 src70 = _mm256_set1_ps(*(src + 56)); + dst7 = _mm256_fmadd_ps(dst7, src70, weight00); + // bock1 + __m256 weight01 = _mm256_load_ps(weight + 8); + __m256 src01 = _mm256_set1_ps(*(src + 1)); + dst0 = _mm256_fmadd_ps(dst0, src01, weight01); + __m256 src11 = _mm256_set1_ps(*(src + 9)); + dst1 = _mm256_fmadd_ps(dst1, src11, weight01); + __m256 src21 = _mm256_set1_ps(*(src + 17)); + dst2 = _mm256_fmadd_ps(dst2, src21, weight01); + __m256 src31 = _mm256_set1_ps(*(src + 25)); + dst3 = _mm256_fmadd_ps(dst3, src31, weight01); + __m256 src41 = _mm256_set1_ps(*(src + 33)); + dst4 = _mm256_fmadd_ps(dst4, src41, weight01); + __m256 src51 = _mm256_set1_ps(*(src + 41)); + dst5 = _mm256_fmadd_ps(dst5, src51, weight01); + __m256 src61 = _mm256_set1_ps(*(src + 49)); + dst6 = _mm256_fmadd_ps(dst6, src61, weight01); + __m256 src71 = _mm256_set1_ps(*(src + 57)); + dst7 = _mm256_fmadd_ps(dst7, src71, weight01); + // bock2 + __m256 weight02 = _mm256_load_ps(weight + 16); + __m256 src02 = _mm256_set1_ps(*(src + 2)); + dst0 = _mm256_fmadd_ps(dst0, src02, weight02); + __m256 src12 = _mm256_set1_ps(*(src + 10)); + dst1 = _mm256_fmadd_ps(dst1, src12, weight02); + __m256 src22 = _mm256_set1_ps(*(src + 18)); + dst2 = _mm256_fmadd_ps(dst2, src22, weight02); + __m256 src32 = _mm256_set1_ps(*(src + 26)); + dst3 = _mm256_fmadd_ps(dst3, src32, weight02); + __m256 src42 = _mm256_set1_ps(*(src + 34)); + dst4 = _mm256_fmadd_ps(dst4, src42, weight02); + __m256 src52 = _mm256_set1_ps(*(src + 42)); + dst5 = _mm256_fmadd_ps(dst5, src52, weight02); + __m256 src62 = _mm256_set1_ps(*(src + 50)); + dst6 = _mm256_fmadd_ps(dst6, src62, weight02); + __m256 src72 = _mm256_set1_ps(*(src + 58)); + dst7 = _mm256_fmadd_ps(dst7, src72, weight02); + // bock3 + __m256 weight03 = 
_mm256_load_ps(weight + 24); + __m256 src03 = _mm256_set1_ps(*(src + 3)); + dst0 = _mm256_fmadd_ps(dst0, src03, weight03); + __m256 src13 = _mm256_set1_ps(*(src + 11)); + dst1 = _mm256_fmadd_ps(dst1, src13, weight03); + __m256 src23 = _mm256_set1_ps(*(src + 19)); + dst2 = _mm256_fmadd_ps(dst2, src23, weight03); + __m256 src33 = _mm256_set1_ps(*(src + 27)); + dst3 = _mm256_fmadd_ps(dst3, src33, weight03); + __m256 src43 = _mm256_set1_ps(*(src + 35)); + dst4 = _mm256_fmadd_ps(dst4, src43, weight03); + __m256 src53 = _mm256_set1_ps(*(src + 43)); + dst5 = _mm256_fmadd_ps(dst5, src53, weight03); + __m256 src63 = _mm256_set1_ps(*(src + 51)); + dst6 = _mm256_fmadd_ps(dst6, src63, weight03); + __m256 src73 = _mm256_set1_ps(*(src + 59)); + dst7 = _mm256_fmadd_ps(dst7, src73, weight03); + // bock4 + __m256 weight04 = _mm256_load_ps(weight + 32); + __m256 src04 = _mm256_set1_ps(*(src + 4)); + dst0 = _mm256_fmadd_ps(dst0, src04, weight04); + __m256 src14 = _mm256_set1_ps(*(src + 12)); + dst1 = _mm256_fmadd_ps(dst1, src14, weight04); + __m256 src24 = _mm256_set1_ps(*(src + 20)); + dst2 = _mm256_fmadd_ps(dst2, src24, weight04); + __m256 src34 = _mm256_set1_ps(*(src + 28)); + dst3 = _mm256_fmadd_ps(dst3, src34, weight04); + __m256 src44 = _mm256_set1_ps(*(src + 36)); + dst4 = _mm256_fmadd_ps(dst4, src44, weight04); + __m256 src54 = _mm256_set1_ps(*(src + 44)); + dst5 = _mm256_fmadd_ps(dst5, src54, weight04); + __m256 src64 = _mm256_set1_ps(*(src + 52)); + dst6 = _mm256_fmadd_ps(dst6, src64, weight04); + __m256 src74 = _mm256_set1_ps(*(src + 60)); + dst7 = _mm256_fmadd_ps(dst7, src74, weight04); + // bock5 + __m256 weight05 = _mm256_load_ps(weight + 40); + __m256 src05 = _mm256_set1_ps(*(src + 5)); + dst0 = _mm256_fmadd_ps(dst0, src05, weight05); + __m256 src15 = _mm256_set1_ps(*(src + 13)); + dst1 = _mm256_fmadd_ps(dst1, src15, weight05); + __m256 src25 = _mm256_set1_ps(*(src + 21)); + dst2 = _mm256_fmadd_ps(dst2, src25, weight05); + __m256 src35 = _mm256_set1_ps(*(src + 29)); + 
dst3 = _mm256_fmadd_ps(dst3, src35, weight05); + __m256 src45 = _mm256_set1_ps(*(src + 37)); + dst4 = _mm256_fmadd_ps(dst4, src45, weight05); + __m256 src55 = _mm256_set1_ps(*(src + 45)); + dst5 = _mm256_fmadd_ps(dst5, src55, weight05); + __m256 src65 = _mm256_set1_ps(*(src + 53)); + dst6 = _mm256_fmadd_ps(dst6, src65, weight05); + __m256 src75 = _mm256_set1_ps(*(src + 61)); + dst7 = _mm256_fmadd_ps(dst7, src75, weight05); + // bock6 + __m256 weight06 = _mm256_load_ps(weight + 48); + __m256 src06 = _mm256_set1_ps(*(src + 6)); + dst0 = _mm256_fmadd_ps(dst0, src06, weight06); + __m256 src16 = _mm256_set1_ps(*(src + 14)); + dst1 = _mm256_fmadd_ps(dst1, src16, weight06); + __m256 src26 = _mm256_set1_ps(*(src + 22)); + dst2 = _mm256_fmadd_ps(dst2, src26, weight06); + __m256 src36 = _mm256_set1_ps(*(src + 30)); + dst3 = _mm256_fmadd_ps(dst3, src36, weight06); + __m256 src46 = _mm256_set1_ps(*(src + 38)); + dst4 = _mm256_fmadd_ps(dst4, src46, weight06); + __m256 src56 = _mm256_set1_ps(*(src + 46)); + dst5 = _mm256_fmadd_ps(dst5, src56, weight06); + __m256 src66 = _mm256_set1_ps(*(src + 54)); + dst6 = _mm256_fmadd_ps(dst6, src66, weight06); + __m256 src76 = _mm256_set1_ps(*(src + 62)); + dst7 = _mm256_fmadd_ps(dst7, src76, weight06); + // bock7 + __m256 weight07 = _mm256_load_ps(weight + 56); + __m256 src07 = _mm256_set1_ps(*(src + 7)); + dst0 = _mm256_fmadd_ps(dst0, src07, weight07); + __m256 src17 = _mm256_set1_ps(*(src + 15)); + dst1 = _mm256_fmadd_ps(dst1, src17, weight07); + __m256 src27 = _mm256_set1_ps(*(src + 23)); + dst2 = _mm256_fmadd_ps(dst2, src27, weight07); + __m256 src37 = _mm256_set1_ps(*(src + 31)); + dst3 = _mm256_fmadd_ps(dst3, src37, weight07); + __m256 src47 = _mm256_set1_ps(*(src + 39)); + dst4 = _mm256_fmadd_ps(dst4, src47, weight07); + __m256 src57 = _mm256_set1_ps(*(src + 47)); + dst5 = _mm256_fmadd_ps(dst5, src57, weight07); + __m256 src67 = _mm256_set1_ps(*(src + 55)); + dst6 = _mm256_fmadd_ps(dst6, src67, weight07); + __m256 src77 = 
_mm256_set1_ps(*(src + 63)); + dst7 = _mm256_fmadd_ps(dst7, src77, weight07); + src = src + src_stride; + weight += 256; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_min_ps(dst0, relu6); + dst1 = _mm256_min_ps(dst1, relu6); + dst2 = _mm256_min_ps(dst2, relu6); + dst3 = _mm256_min_ps(dst3, relu6); + dst4 = _mm256_min_ps(dst4, relu6); + dst5 = _mm256_min_ps(dst5, relu6); + dst6 = _mm256_min_ps(dst6, relu6); + dst7 = _mm256_min_ps(dst7, relu6); + // relu + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + dst0 = _mm256_max_ps(dst0, relu); + dst1 = _mm256_max_ps(dst1, relu); + dst2 = _mm256_max_ps(dst2, relu); + dst3 = _mm256_max_ps(dst3, relu); + dst4 = _mm256_max_ps(dst4, relu); + dst5 = _mm256_max_ps(dst5, relu); + dst6 = _mm256_max_ps(dst6, relu); + dst7 = _mm256_max_ps(dst7, relu); + } + _mm256_store_ps(dst + 0 * src_stride + 0, dst0); + _mm256_store_ps(dst + 0 * src_stride + 8, dst1); + _mm256_store_ps(dst + 0 * src_stride + 16, dst2); + _mm256_store_ps(dst + 0 * src_stride + 24, dst3); + _mm256_store_ps(dst + 0 * src_stride + 32, dst4); + _mm256_store_ps(dst + 0 * src_stride + 40, dst5); + _mm256_store_ps(dst + 0 * src_stride + 48, dst6); + _mm256_store_ps(dst + 0 * src_stride + 56, dst7); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c new file mode 100644 index 00000000000..97ad1c3bd04 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c @@ -0,0 +1,259 @@ 
+/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, 
%%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 
98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps 
%%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + 
"add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c new file mode 100644 index 00000000000..3dc30cff4e3 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c @@ -0,0 +1,273 @@ +/** + * 
/*
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <stddef.h>
#include <x86intrin.h>

// Allow this kernel to be built in a translation unit that is not compiled with -mavx -mfma.
#if defined(__GNUC__) || defined(__clang__)
#define NNACL_GEMM_FMA_TARGET __attribute__((target("avx,fma")))
#else
#define NNACL_GEMM_FMA_TARGET
#endif

// nnacl gemm in x86 fma intrinsic code
//
// Computes a 9-row x 8-column output tile:
//   dst[r * 8 + c] = init[r][c] + sum over k of src[r * 8 + k] * weight[k * 8 + c]
// for every block of 8 `deep` elements (deep >> 3 blocks; a remainder below 8 is ignored).
//
//   dst        output tile, 9 rows of 8 floats (72 floats), read when inc_flag != 0
//   src        input; row r / depth k of the current block is src[r * 8 + k]; the pointer
//              advances by src_stride floats per block
//   weight     8 floats per depth step, 64 floats per block; must be 32-byte aligned
//   bias       8 floats (32-byte aligned) or NULL
//   act_flag   bit 0x02: clamp to [0, 6]; bit 0x01: clamp to [0, +inf)
//   inc_flag   non-zero: start from the values already stored in dst
//   row_block, col_block, dst_stride are not used by this fixed 9x8 variant
//
// NOTE(review): the inline-asm variant of this kernel reads inc_flag/act_flag bit-wise in a
// different way (activation gated on inc_flag & 0x2, relu6 on act_flag & 0x1) - confirm which
// encoding callers use before relying on either one.
NNACL_GEMM_FMA_TARGET
void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                           const size_t act_flag, const size_t row_block, const size_t col_block,
                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
                                           const size_t inc_flag) {
  (void)row_block;
  (void)col_block;
  (void)dst_stride;  // single 8-column tile: the column-tile stride is always multiplied by 0
  __m256 dst0;
  __m256 dst1;
  __m256 dst2;
  __m256 dst3;
  __m256 dst4;
  __m256 dst5;
  __m256 dst6;
  __m256 dst7;
  __m256 dst8;
  if (inc_flag) {
    // Continue accumulating into a partial result. Unaligned loads, like the asm variant (vmovups).
    dst0 = _mm256_loadu_ps(dst + 0);
    dst1 = _mm256_loadu_ps(dst + 8);
    dst2 = _mm256_loadu_ps(dst + 16);
    dst3 = _mm256_loadu_ps(dst + 24);
    dst4 = _mm256_loadu_ps(dst + 32);
    dst5 = _mm256_loadu_ps(dst + 40);
    dst6 = _mm256_loadu_ps(dst + 48);
    dst7 = _mm256_loadu_ps(dst + 56);
    dst8 = _mm256_loadu_ps(dst + 64);
  } else if (bias == NULL) {
    dst0 = _mm256_setzero_ps();
    dst1 = _mm256_setzero_ps();
    dst2 = _mm256_setzero_ps();
    dst3 = _mm256_setzero_ps();
    dst4 = _mm256_setzero_ps();
    dst5 = _mm256_setzero_ps();
    dst6 = _mm256_setzero_ps();
    dst7 = _mm256_setzero_ps();
    dst8 = _mm256_setzero_ps();
  } else {
    // Every row starts from the same 8 bias values.
    dst0 = _mm256_load_ps(bias);
    dst1 = dst0;
    dst2 = dst0;
    dst3 = dst0;
    dst4 = dst0;
    dst5 = dst0;
    dst6 = dst0;
    dst7 = dst0;
    dst8 = dst0;
  }
  const size_t deep_blocks = deep >> 3;
  for (size_t i = 0; i < deep_blocks; ++i) {
    // _mm256_fmadd_ps(a, b, c) returns a * b + c, so the accumulator must be the LAST argument.
    // block 0
    const __m256 weight00 = _mm256_load_ps(weight + 0);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[0]), weight00, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[8]), weight00, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[16]), weight00, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[24]), weight00, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[32]), weight00, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[40]), weight00, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[48]), weight00, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[56]), weight00, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[64]), weight00, dst8);
    // block 1
    const __m256 weight01 = _mm256_load_ps(weight + 8);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[1]), weight01, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[9]), weight01, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[17]), weight01, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[25]), weight01, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[33]), weight01, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[41]), weight01, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[49]), weight01, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[57]), weight01, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[65]), weight01, dst8);
    // block 2
    const __m256 weight02 = _mm256_load_ps(weight + 16);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[2]), weight02, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[10]), weight02, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[18]), weight02, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[26]), weight02, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[34]), weight02, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[42]), weight02, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[50]), weight02, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[58]), weight02, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[66]), weight02, dst8);
    // block 3
    const __m256 weight03 = _mm256_load_ps(weight + 24);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[3]), weight03, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[11]), weight03, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[19]), weight03, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[27]), weight03, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[35]), weight03, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[43]), weight03, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[51]), weight03, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[59]), weight03, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[67]), weight03, dst8);
    // block 4
    const __m256 weight04 = _mm256_load_ps(weight + 32);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[4]), weight04, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[12]), weight04, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[20]), weight04, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[28]), weight04, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[36]), weight04, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[44]), weight04, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[52]), weight04, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[60]), weight04, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[68]), weight04, dst8);
    // block 5
    const __m256 weight05 = _mm256_load_ps(weight + 40);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[5]), weight05, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[13]), weight05, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[21]), weight05, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[29]), weight05, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[37]), weight05, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[45]), weight05, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[53]), weight05, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[61]), weight05, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[69]), weight05, dst8);
    // block 6
    const __m256 weight06 = _mm256_load_ps(weight + 48);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[6]), weight06, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[14]), weight06, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[22]), weight06, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[30]), weight06, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[38]), weight06, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[46]), weight06, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[54]), weight06, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[62]), weight06, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[70]), weight06, dst8);
    // block 7
    const __m256 weight07 = _mm256_load_ps(weight + 56);
    dst0 = _mm256_fmadd_ps(_mm256_set1_ps(src[7]), weight07, dst0);
    dst1 = _mm256_fmadd_ps(_mm256_set1_ps(src[15]), weight07, dst1);
    dst2 = _mm256_fmadd_ps(_mm256_set1_ps(src[23]), weight07, dst2);
    dst3 = _mm256_fmadd_ps(_mm256_set1_ps(src[31]), weight07, dst3);
    dst4 = _mm256_fmadd_ps(_mm256_set1_ps(src[39]), weight07, dst4);
    dst5 = _mm256_fmadd_ps(_mm256_set1_ps(src[47]), weight07, dst5);
    dst6 = _mm256_fmadd_ps(_mm256_set1_ps(src[55]), weight07, dst6);
    dst7 = _mm256_fmadd_ps(_mm256_set1_ps(src[63]), weight07, dst7);
    dst8 = _mm256_fmadd_ps(_mm256_set1_ps(src[71]), weight07, dst8);
    src += src_stride;
    // One block consumes 8 depth steps x 8 columns = 64 floats (256 bytes). `weight` is a
    // float pointer, so it advances by 64, matching the 256-byte step of the asm variant.
    weight += 64;
  }
  if (act_flag & 0x02) {
    // relu6: clamp to [0, 6]
    const __m256 relu6 = _mm256_set1_ps(6.0f);
    const __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_max_ps(_mm256_min_ps(dst0, relu6), relu);
    dst1 = _mm256_max_ps(_mm256_min_ps(dst1, relu6), relu);
    dst2 = _mm256_max_ps(_mm256_min_ps(dst2, relu6), relu);
    dst3 = _mm256_max_ps(_mm256_min_ps(dst3, relu6), relu);
    dst4 = _mm256_max_ps(_mm256_min_ps(dst4, relu6), relu);
    dst5 = _mm256_max_ps(_mm256_min_ps(dst5, relu6), relu);
    dst6 = _mm256_max_ps(_mm256_min_ps(dst6, relu6), relu);
    dst7 = _mm256_max_ps(_mm256_min_ps(dst7, relu6), relu);
    dst8 = _mm256_max_ps(_mm256_min_ps(dst8, relu6), relu);
  }
  if (act_flag & 0x01) {
    // relu
    const __m256 relu = _mm256_setzero_ps();
    dst0 = _mm256_max_ps(dst0, relu);
    dst1 = _mm256_max_ps(dst1, relu);
    dst2 = _mm256_max_ps(dst2, relu);
    dst3 = _mm256_max_ps(dst3, relu);
    dst4 = _mm256_max_ps(dst4, relu);
    dst5 = _mm256_max_ps(dst5, relu);
    dst6 = _mm256_max_ps(dst6, relu);
    dst7 = _mm256_max_ps(dst7, relu);
    dst8 = _mm256_max_ps(dst8, relu);
  }
  // Write the tile back to the same positions it was (optionally) loaded from.
  _mm256_storeu_ps(dst + 0, dst0);
  _mm256_storeu_ps(dst + 8, dst1);
  _mm256_storeu_ps(dst + 16, dst2);
  _mm256_storeu_ps(dst + 24, dst3);
  _mm256_storeu_ps(dst + 32, dst4);
  _mm256_storeu_ps(dst + 40, dst5);
  _mm256_storeu_ps(dst + 48, dst6);
  _mm256_storeu_ps(dst + 56, dst7);
  _mm256_storeu_ps(dst + 64, dst8);
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst]), %%ymm0\n" + "vmovups 32(%[dst]), %%ymm1\n" + "vmovups 64(%[dst]), %%ymm2\n" + "vmovups 96(%[dst]), %%ymm3\n" + "vmovups 128(%[dst]), %%ymm4\n" + "vmovups 160(%[dst]), %%ymm5\n" + "vmovups 192(%[dst]), %%ymm6\n" + "vmovups 224(%[dst]), %%ymm7\n" + "vmovups 256(%[dst]), %%ymm8\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%ymm0\n" + "vmovaps 0(%[bias]), %%ymm1\n" + "vmovaps 0(%[bias]), %%ymm2\n" + "vmovaps 0(%[bias]), %%ymm3\n" + "vmovaps 0(%[bias]), %%ymm4\n" + "vmovaps 0(%[bias]), %%ymm5\n" + "vmovaps 0(%[bias]), %%ymm6\n" + "vmovaps 0(%[bias]), %%ymm7\n" + "vmovaps 0(%[bias]), %%ymm8\n" + "jmp 2f\n" + "1:\n" + "vxorps %%ymm0, %%ymm0, %%ymm0\n" + "vxorps %%ymm1, %%ymm1, %%ymm1\n" + "vxorps %%ymm2, %%ymm2, %%ymm2\n" + "vxorps %%ymm3, %%ymm3, %%ymm3\n" + "vxorps %%ymm4, %%ymm4, %%ymm4\n" + "vxorps %%ymm5, %%ymm5, %%ymm5\n" + "vxorps %%ymm6, %%ymm6, %%ymm6\n" + "vxorps %%ymm7, %%ymm7, %%ymm7\n" + "vxorps %%ymm8, %%ymm8, %%ymm8\n" + "2:\n" + : + : [ dst ] "r"(dst), [ bias ] 
"r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8"); + asm volatile( + "0:\n" + // block 0 + "vmovaps 0(%[weight]), %%ymm15\n" + "vbroadcastss 0(%[src]), %%ymm14\n" + "vbroadcastss 32(%[src]), %%ymm13\n" + "vbroadcastss 64(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 96(%[src]), %%ymm14\n" + "vbroadcastss 128(%[src]), %%ymm13\n" + "vbroadcastss 160(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 192(%[src]), %%ymm14\n" + "vbroadcastss 224(%[src]), %%ymm13\n" + "vbroadcastss 256(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 1 + "vmovaps 32(%[weight]), %%ymm15\n" + "vbroadcastss 1(%[src]), %%ymm14\n" + "vbroadcastss 33(%[src]), %%ymm13\n" + "vbroadcastss 65(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 97(%[src]), %%ymm14\n" + "vbroadcastss 129(%[src]), %%ymm13\n" + "vbroadcastss 161(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 193(%[src]), %%ymm14\n" + "vbroadcastss 225(%[src]), %%ymm13\n" + "vbroadcastss 257(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 2 + "vmovaps 64(%[weight]), %%ymm15\n" + "vbroadcastss 2(%[src]), %%ymm14\n" + "vbroadcastss 34(%[src]), %%ymm13\n" + "vbroadcastss 66(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, 
%%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 98(%[src]), %%ymm14\n" + "vbroadcastss 130(%[src]), %%ymm13\n" + "vbroadcastss 162(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 194(%[src]), %%ymm14\n" + "vbroadcastss 226(%[src]), %%ymm13\n" + "vbroadcastss 258(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 3 + "vmovaps 96(%[weight]), %%ymm15\n" + "vbroadcastss 3(%[src]), %%ymm14\n" + "vbroadcastss 35(%[src]), %%ymm13\n" + "vbroadcastss 67(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 99(%[src]), %%ymm14\n" + "vbroadcastss 131(%[src]), %%ymm13\n" + "vbroadcastss 163(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 195(%[src]), %%ymm14\n" + "vbroadcastss 227(%[src]), %%ymm13\n" + "vbroadcastss 259(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 4 + "vmovaps 128(%[weight]), %%ymm15\n" + "vbroadcastss 4(%[src]), %%ymm14\n" + "vbroadcastss 36(%[src]), %%ymm13\n" + "vbroadcastss 68(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 100(%[src]), %%ymm14\n" + "vbroadcastss 132(%[src]), %%ymm13\n" + "vbroadcastss 164(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 196(%[src]), %%ymm14\n" + "vbroadcastss 228(%[src]), %%ymm13\n" + "vbroadcastss 
260(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 5 + "vmovaps 160(%[weight]), %%ymm15\n" + "vbroadcastss 5(%[src]), %%ymm14\n" + "vbroadcastss 37(%[src]), %%ymm13\n" + "vbroadcastss 69(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 101(%[src]), %%ymm14\n" + "vbroadcastss 133(%[src]), %%ymm13\n" + "vbroadcastss 165(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 197(%[src]), %%ymm14\n" + "vbroadcastss 229(%[src]), %%ymm13\n" + "vbroadcastss 261(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 6 + "vmovaps 192(%[weight]), %%ymm15\n" + "vbroadcastss 6(%[src]), %%ymm14\n" + "vbroadcastss 38(%[src]), %%ymm13\n" + "vbroadcastss 70(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 102(%[src]), %%ymm14\n" + "vbroadcastss 134(%[src]), %%ymm13\n" + "vbroadcastss 166(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 198(%[src]), %%ymm14\n" + "vbroadcastss 230(%[src]), %%ymm13\n" + "vbroadcastss 262(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + // block 7 + "vmovaps 224(%[weight]), %%ymm15\n" + "vbroadcastss 7(%[src]), %%ymm14\n" + "vbroadcastss 39(%[src]), %%ymm13\n" + "vbroadcastss 71(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n" + 
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n" + "vbroadcastss 103(%[src]), %%ymm14\n" + "vbroadcastss 135(%[src]), %%ymm13\n" + "vbroadcastss 167(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n" + "vbroadcastss 199(%[src]), %%ymm14\n" + "vbroadcastss 231(%[src]), %%ymm13\n" + "vbroadcastss 263(%[src]), %%ymm12\n" + "vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n" + "vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n" + "vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n" + "dec %[deep]\n" + "add 256, %[weight]\n" + "add %[src_stride], %[src]\n" + "jg 0b\n" + + "movq %[inc_flag], %%rax\n" + "and $0x2, %%eax\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%ymm15, %%ymm15, %%ymm15\n" + "vmaxps %%ymm0, %%ymm15, %%ymm0\n" + "vmaxps %%ymm1, %%ymm15, %%ymm1\n" + "vmaxps %%ymm2, %%ymm15, %%ymm2\n" + "vmaxps %%ymm3, %%ymm15, %%ymm3\n" + "vmaxps %%ymm4, %%ymm15, %%ymm4\n" + "vmaxps %%ymm5, %%ymm15, %%ymm5\n" + "vmaxps %%ymm6, %%ymm15, %%ymm6\n" + "vmaxps %%ymm7, %%ymm15, %%ymm7\n" + "vmaxps %%ymm8, %%ymm15, %%ymm8\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm14\n" + "vpermps %%ymm14, %%ymm15, %%ymm14\n" + "vminps %%ymm0, %%ymm14, %%ymm0\n" + "vminps %%ymm1, %%ymm14, %%ymm1\n" + "vminps %%ymm2, %%ymm14, %%ymm2\n" + "vminps %%ymm3, %%ymm14, %%ymm3\n" + "vminps %%ymm4, %%ymm14, %%ymm4\n" + "vminps %%ymm5, %%ymm14, %%ymm5\n" + "vminps %%ymm6, %%ymm14, %%ymm6\n" + "vminps %%ymm7, %%ymm14, %%ymm7\n" + "vminps %%ymm8, %%ymm14, %%ymm8\n" + "3:\n" + "vmovups %%ymm0, 0(%[dst])\n" + "vmovups %%ymm1, 32(%[dst])\n" + "vmovups %%ymm2, 64(%[dst])\n" + "vmovups %%ymm3, 96(%[dst])\n" + "vmovups %%ymm4, 128(%[dst])\n" + "vmovups %%ymm5, 160(%[dst])\n" + "vmovups %%ymm6, 192(%[dst])\n" + "vmovups %%ymm7, 224(%[dst])\n" + "vmovups %%ymm8, 256(%[dst])\n" + : + : [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] 
"r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", + "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15"); +} diff --git a/mindspore/lite/experiment/HPC-generator/generate_hpc.sh b/mindspore/lite/experiment/HPC-generator/generate_hpc.sh new file mode 100644 index 00000000000..84768b26839 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/generate_hpc.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

# Regenerates every GEMM kernel source by running generator.py on the templates in
# ./template_file. Each entry of a *_TILES table is "<col_block> <max_row_block>";
# one kernel is generated for every row_block from <max_row_block> down to 1.

# Stop at the first failing generator run instead of silently producing a partial set.
set -euo pipefail

# generator.py, the templates and the output directories are addressed relative to this
# script, so run from its directory regardless of where it is invoked from.
cd "$(dirname "${BASH_SOURCE[0]}")"

FMA_TILES=("8 12" "16 6" "24 4" "32 3")
AVX512_TILES=("96 4" "80 5")

# generate_kernels TEMPLATE OUT_DIR ISA NAME_SUFFIX COL_BLOCK MAX_ROW_BLOCK
#   Writes ./OUT_DIR/nnacl_gemm_ISA_<row>x<COL_BLOCK>_kernel_NAME_SUFFIX.c for
#   row = MAX_ROW_BLOCK .. 1.
generate_kernels() {
  local template=$1 out_dir=$2 isa=$3 name_suffix=$4 col_block=$5 max_row_block=$6
  local row_block
  for ((row_block = max_row_block; row_block >= 1; --row_block)); do
    python3 generator.py -I "./template_file/${template}" \
      -A "row_block=${row_block}" "col_block=${col_block}" \
      -O "./${out_dir}/nnacl_gemm_${isa}_${row_block}x${col_block}_kernel_${name_suffix}.c"
  done
}

# generate gemm fma asm code
for tile in "${FMA_TILES[@]}"; do
  read -r col_block max_row_block <<< "${tile}"
  generate_kernels gemm_fma_nc8hw8_asm.c.in gemm_fma fma nc8hw8_fp32_asm "${col_block}" "${max_row_block}"
done

# generate gemm fma intrinsics code
for tile in "${FMA_TILES[@]}"; do
  read -r col_block max_row_block <<< "${tile}"
  generate_kernels gemm_fma_nc8hw8.c.in gemm_fma fma nc8hw8_fp32 "${col_block}" "${max_row_block}"
done

# generate gemm avx512 asm code
for tile in "${AVX512_TILES[@]}"; do
  read -r col_block max_row_block <<< "${tile}"
  generate_kernels gemm_avx512_nhwc_asm.c.in gemm_avx512 avx512 nhwc_fp32 "${col_block}" "${max_row_block}"
done
# ============================================================================
"""HPC generator.

Expands a C template file into C source text.  Each template line is
translated into exactly one line of a small Python program, and that program
is then executed with its ``print`` calls directed at an in-memory stream:

* a line whose first non-blank character is ``@`` is a Python statement
  (the ``@`` is dropped), e.g. ``@for row in range(0, row_block):``;
* a line containing ``@{expr}`` placeholders becomes a ``print`` of the line
  with every placeholder formatted through ``%d``;
* any other line is printed as it is.
"""

import sys
import os
import io
import argparse
from itertools import chain


def key_value_pair(line):
    """
    Split a ``KEY=VALUE`` command line define into its key and integer value.

    :param line: text of the form ``KEY=VALUE``; only the first ``=`` splits
    :return: tuple ``(key, value)`` where ``value`` is an ``int``
    :raises ValueError: if ``line`` has no ``=`` or ``VALUE`` is not an integer
    """
    key, value = line.split("=", 1)
    try:
        return key, int(value)
    except ValueError as err:
        # Raising (instead of printing and handing back the raw string) lets
        # argparse reject the argument; the generated code formats every
        # placeholder with %d, so a string value could only fail later.
        raise ValueError("value of '%s' must be an integer, but got '%s'" % (key, value)) from err


def get_indent(line):
    """
    Count the leading space characters of a line.

    :param line: text to inspect
    :return: number of consecutive ``" "`` characters at the start of ``line``
             (tabs are not counted)
    """
    return len(line) - len(line.lstrip(" "))


def print_line(line):
    """
    Translate one template line into one line of Python source.

    Reads and updates the module level ``python_indent`` (indent of the
    directive that opened the current run of Python statements, ``-1`` when no
    run is open) and ``generate_code_indent`` (the generated C text is
    indented by ``generate_code_indent + 4`` spaces).

    :param line: template line without its trailing newline; double quotes
                 and percent signs are expected to be escaped already by
                 ``generate_code``
    :return: the Python statement generated for ``line``
    """
    global python_indent
    global generate_code_indent

    stripped = line.strip()
    if not stripped:
        # Nothing to translate: emit an empty output line, keep the state.
        return "print(\"\", file=OUT_STREAM)"

    first_char = stripped[0]
    if first_char in ("}", ")"):
        # A closing brace or parenthesis always ends the open Python run.
        python_indent = -1

    split_str = line.split("@")
    indent = get_indent(line)

    if first_char != "@" and len(split_str) == 1:
        # Line without directive or placeholder.  It is copied verbatim when
        # no Python run is open or when it sits at the indent of the run's
        # first directive (which closes the run).
        if python_indent in (indent, -1):
            python_indent = -1
            if "{" in line or "asm volatile(" in line:
                generate_code_indent = indent
            if stripped.startswith("}") and "{" not in line:
                generate_code_indent -= 4
            if line == "}":
                # end of a function: restore the initial value for the next one
                generate_code_indent = -4
            return "\"".join(["print(", line, ", file=OUT_STREAM)"])

    if first_char == "@":
        # Python directive.  The first directive of a run fixes the indent
        # that the following statements are emitted relative to.
        if python_indent == -1:
            generate_code_indent = indent - 4
            python_indent = indent
        return split_str[0][python_indent:] + split_str[1]

    # Line nested below a directive and/or containing placeholders: build
    # <relative indent>print("<text with %d>" %(<expr>, ...), file=OUT_STREAM)
    head = split_str[0]
    statement_indent = head[python_indent:get_indent(head)]
    prefix = " " * (generate_code_indent + 4) + head.lstrip()
    suffix = " %("
    for piece in split_str[1:]:
        close = piece.find("}")
        # piece starts with "{expr}"; collect expr and put %d in its place
        suffix += piece[1:close] + ', '
        prefix += piece.replace(piece[0:close + 1], "%d")
    return "\"".join([statement_indent + "print(", prefix, suffix + "), file=OUT_STREAM)"])


def generate_code(template_file, exec_dict):
    """
    Generate code from a template file.

    :param template_file: template file path
    :param exec_dict: dict used as the globals of the generated program; it
                      provides the names used by the template and is modified
                      (``OUT_STREAM`` and every name the template assigns are
                      added to it)
    :return: the generated text
    """
    global python_indent
    global generate_code_indent
    # Every template starts from the initial state, so an earlier call that
    # stopped half way cannot influence this one.
    generate_code_indent = -4
    python_indent = -1

    generate_code_lines = []
    with open(template_file, "r", encoding="utf-8") as template:
        for raw_line in template:
            line = raw_line.replace("\n", "")
            stripped = line.strip()
            if not stripped:
                # Empty and whitespace-only lines both give an empty line.
                generate_code_lines.append("print(\"\", file=OUT_STREAM)")
                continue
            if stripped[0] != "@":
                # The text ends up inside a double quoted, %-formatted literal.
                line = line.replace("\"", "\\\"").replace("%", "%%")
            if "print" in line:
                line = line.replace("%%", "%")
            statement = print_line(line)
            if "%(" not in statement:
                # Not %-formatted: undo the doubling for asm operands "%[name]".
                statement = statement.replace("%%[", "%[")
            generate_code_lines.append(statement)

    # One generated line per template line, so passing the template path as
    # the file name makes error locations point at the offending template line.
    program = compile("\n".join(generate_code_lines), template_file, "exec")
    output_stream = io.StringIO()
    exec_dict["OUT_STREAM"] = output_stream
    # NOTE: the template is executed as Python code; only use trusted templates.
    exec(program, exec_dict)
    return output_stream.getvalue()


generate_code_indent = -4
python_indent = -1

parser = argparse.ArgumentParser(description="MSLite NNACL Code Generator")
parser.add_argument("-I", dest="Template_File", nargs=1, help="template file to generate code")
parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=key_value_pair, action="append",
                    help="Custom Parameters")
parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path")


def main(argv=None):
    """
    Command line entry: expand the template given by -I into the file given by -O.

    :param argv: argument list without the program name, ``sys.argv[1:]`` when None
    :return: None
    """
    parameters = parser.parse_args(sys.argv[1:] if argv is None else argv)
    if not parameters.Template_File or not parameters.Output_File:
        parser.error("both -I <template file> and -O <output file> are required")
    # "defines" stays None when -A is not given at all.
    exec_globals = dict(chain(*(parameters.defines or [])))
    output_path = parameters.Output_File[0]

    generate_code_str = generate_code(parameters.Template_File[0], exec_globals)
    if os.path.exists(output_path):
        os.remove(output_path)

    save_dir = os.path.dirname(output_path)
    if save_dir:
        # dirname is "" for a bare file name; makedirs also creates parents.
        os.makedirs(save_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as output_file:
        output_file.write(generate_code_str)


if __name__ == "__main__":
    main()
output_file.write(generate_code_str) diff --git a/mindspore/lite/experiment/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in b/mindspore/lite/experiment/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in new file mode 100644 index 00000000000..4f5f1f885b0 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in @@ -0,0 +1,169 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t deep, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @asm_flag_list = [] + @row_split_number = [row for row in range(3, row_block, 3)] + @for row in row_split_number: + const float *dst_@{row} = dst + @{row} * dst_stride; + @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + @col_split_num = col_block >> 4; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\\n" + "je 0f\\n" + @for row in range(0, row_block): + @tmp = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups @{col * 64}(%[dst_@{tmp}]), 
%%zmm@{row * col_split_num + col}\\n" + @else: + "vmovups @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vmovaps @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n" + "jmp 2f\\n" + "1:\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n" + "2:\\n" + : + @list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) + ); + @for row in row_split_number: + const float *src_@{row} = src + @{row} * dst_stride; + @asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")"); + asm volatile( + "0:\\n" + @loop_count = 8 + @for i in range(0, loop_count): + // block @{i} + @for col in range(0, col_split_num): + "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" + @if col_split_num == 6: + @if row_block == 4: + "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n" + "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n" + @for col in range(0, col_split_num): + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n" + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n" + "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n" + "vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n" + @for col in range(0, col_split_num): + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 
* col_split_num + col}\\n" + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n" + @else: + @for row in range(0, row_block): + @if row == 0: + "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n" + @elif col_split_num == 5: + @if row_block == 5: + "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n" + "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n" + @for col in range(0, col_split_num): + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n" + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n" + "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n" + "vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n" + @for col in range(0, col_split_num): + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n" + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n" + "vbroadcastss @{i * 4}(%[src_3], %[src_stride], 1), %%zmm@{31 - col_split_num}\\n" + @for col in range(0, col_split_num): + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n" + @else: + @for row in range(0, row_block): + @if row == 0: + "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n" + @for row in range(0, row_block): + @for col in range(0, col_split_num): + "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * 
col_split_num + col}\\n" + @elif col_split_num == 2: + @for row in range(int(row_block / 6)): + @for j in range(0, 6): + @tmp = int(j / 3) * 3 + @if j % 3 == 0: + "vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}]), %%zmm@{31 - col_split_num - j}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}], %[src_stride], @{j - tmp}), %%zmm@{31 - col_split_num - j}\\n" + @for col in range(0, col_split_num): + @for j in range(0, 6): + "fmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - j}, %%zmm@{(row * 6 + j) * col_split_num + col}\\n" + + "dec %[deep]\\n" + "add $@{col_block * 4 * 8}, %[weight]\\n" + "add $@{loop_count * 4}, %[src_0]\\n" + @for row in row_split_number: + "add $@{loop_count * 4}, %[src_@{row}]\\n" + "jg 0b\\n" + + "and $0x2, %[inc_flag]\\n" + "je 3f\\n" + "movq %[act_flag], %rax\\n" + "and $0x3, %eax\\n" + "je 3f\\n" + // relu + "vxorps %zmm31, %zmm31, %zmm31\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vmaxps %%zmm@{row + col * row_block}, %%zmm31, %%zmm@{row + col * row_block}\\n" + "and $0x1, %eax\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm30\\n" + "vbroadcastss %xmm30, %zmm30\\n" + @for col in range(0, col_split_num): + @for row in range(0, row_block): + "vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n" + "3:\\n" + @for row in range(0, row_block): + @tmp = int(row / 3) * 3 + @for col in range(0, col_split_num): + @if row % 3 == 0: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}])\\n" + @else: + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}),\\n" + : + @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @list.extend(asm_flag_list) + @print(" : " + ", ".join(list), file=OUT_STREAM) + 
@print(" : \"%rax\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM) + ); +} diff --git a/mindspore/lite/experiment/HPC-generator/template_file/gemm_fma_nc8hw8.c.in b/mindspore/lite/experiment/HPC-generator/template_file/gemm_fma_nc8hw8.c.in new file mode 100644 index 00000000000..641b1857e78 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/template_file/gemm_fma_nc8hw8.c.in @@ -0,0 +1,85 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma intrinsic code +void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t deep, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + __m256 dst@{j * row_block + i}; + if (inc_flag) { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{j * row_block + i} = _mm256_load_ps(dst + @{j} * dst_stride + @{i * 8}); + } else if (bias == NULL) { + @for i in range(0, row_block * col_block >> 3): + dst@{i} = _mm256_setzero_ps(); + } else { + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{j * row_block + i} = _mm256_load_ps(bias + @{j * 8}); + } + for (int i = 0; i < (deep >> 3); ++i) { + @for i in range(0, 8): + // bock@{i} + @if col_block == 32: + @for row in range(0, row_block): + __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i})); + @for col in range(0, col_block >> 3): + __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block}); + @for row in range(0, row_block): + dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i}); + @else: + @for col in range(0, col_block >> 3): + __m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block}); + @for row in range(0, row_block): + __m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i})); + @for col in range(0, col_block >> 3): + dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i}); + src = src + src_stride; + weight += @{8 * col_block * 4}; + } + if (act_flag & 0x02) { + // relu6 + __m256 relu6 = _mm256_set1_ps(6.0f); + __m256 relu = _mm256_setzero_ps(); + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): 
+ dst@{i + j * row_block} = _mm256_min_ps(dst@{i + j * row_block}, relu6); + // relu + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu); + } + if (act_flag & 0x01) { + // relu + __m256 relu = _mm256_setzero_ps(); + @for i in range(0, row_block): + @for j in range(0, col_block >> 3): + dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu); + } + @if col_block == 32: + @for j in range(0, col_block >> 3): + @for i in range(0, row_block): + _mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i}); + @else: + @for j in range(0, col_block >> 3): + @for i in range(0, row_block): + _mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i}); +} diff --git a/mindspore/lite/experiment/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in b/mindspore/lite/experiment/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in new file mode 100644 index 00000000000..70178cf583c --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/template_file/gemm_fma_nc8hw8_asm.c.in @@ -0,0 +1,149 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 fma asm code +void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, + const float *bias, const size_t act_flag, const size_t row_block, + const size_t col_block, const size_t deep, const size_t src_stride, + const size_t dst_stride, const size_t inc_flag) { + @if col_block == 32: + const float *dst_4 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + size_t src_stride_t = src_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\\n" + "je 0f\\n" + @for col in range(0, min((col_block >> 3), 3)): + @for row in range(0, row_block): + @if col == 0: + "vmovups @{row * 32}(%[dst]), %%ymm@{row + col * row_block}\\n" + @else: + "vmovups @{row * 32}(%[dst], %[dst_stride], @{col}), %%ymm@{row + col * row_block}\\n" + @if col_block == 32: + @for row in range(0, row_block): + "vmovups @{row * 32}(%[dst_4]), %%ymm@{row + (col + 1) * row_block}\\n" + "jmp 2f\\n" + "0:\\n" + "cmpq $0, %[bias]\\n" + "je 1f\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vmovaps @{col * 32}(%[bias]), %%ymm@{row + col * row_block}\\n" + "jmp 2f\\n" + "1:\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vxorps %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}\\n" + "2:\\n" + : + @list = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"] + @if col_block == 32: + @list.append("[dst_4] \"r\"(dst_4)") + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, row_block * col_block >> 3)]), file=OUT_STREAM) + ); + asm volatile( + "0:\\n" + @for i in range(0, 8): + // block @{i} + @if col_block == 32: + @for row in range(0, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - row}\\n" + @for col in range(0, 
col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - row_block}\\n" + @for row in range(0, row_block): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - row_block}, %%ymm@{15 - row}\\n" + @elif col_block == 24: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @for row in range(0, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n" + @for col in range(0, col_block >> 3): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n" + @elif col_block == 16: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @for row in range(0, row_block >> 1): + "vbroadcastss @{row * 64 + i}(%[src]), %%ymm@{14 - col}\\n" + "vbroadcastss @{row * 64 + 32 + i}(%[src]), %%ymm@{13 - col}\\n" + @for col in range(0, col_block >> 3): + @for j in range(0, 2): + "vfmadd231ps %%ymm@{row * 2 + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n" + @for row in range(row_block >> 1 << 1, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n" + @for col in range(0, col_block >> 3): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n" + @else: + @for col in range(0, col_block >> 3): + "vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n" + @split_num = 3 + @for row in range(0, int(row_block / split_num)): + @for j in range(0, split_num): + "vbroadcastss @{row * 96 + j * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - j}\\n" + @for col in range(0, col_block >> 3): + @for j in range(0, split_num): + "vfmadd231ps %%ymm@{row * split_num + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n" + @for row in range(int(row_block / split_num) * split_num, row_block): + "vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - (row - 
int(row_block / split_num) * split_num)}\\n" + @for col in range(0, col_block >> 3): + @for row in range(int(row_block / split_num) * split_num, row_block): + "vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}, %%ymm@{15 - col}\\n" + "dec %[deep]\\n" + "add @{col_block * 4 * 8}, %[weight]\\n" + "add %[src_stride], %[src]\\n" + "jg 0b\\n" + + "movq %[inc_flag], %rax\\n" + "and $0x2, %eax\\n" + "je 3f\\n" + "movq %[act_flag], %rax\\n" + "and $0x3, %eax\\n" + "je 3f\\n" + // relu + "vxorps %ymm15, %ymm15, %ymm15\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vmaxps %%ymm@{row + col * row_block}, %%ymm15, %%ymm@{row + col * row_block}\\n" + "and $0x1, %eax\\n" + "je 3f\\n" + // relu6 + "mov $0x40C00000, %eax\\n" + "vmovd %eax, %xmm14\\n" + "vpermps %ymm14, %ymm15, %ymm14\\n" + @for col in range(0, col_block >> 3): + @for row in range(0, row_block): + "vminps %%ymm@{row + col * row_block}, %%ymm14, %%ymm@{row + col * row_block}\\n" + "3:\\n" + @for col in range(0, min((col_block >> 3), 3)): + @for row in range(0, row_block): + @if col == 0: + "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst])\\n" + @else: + "vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst], %[dst_stride], @{col})\\n" + @if col_block == 32: + @for row in range(0, row_block): + "vmovups %%ymm@{row + (col + 1) * row_block}, @{row * 32}(%[dst_4])\\n" + : + @list = ["[src] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"] + @if col_block == 32: + @list.append("[dst_4] \"r\"(dst_4)") + @print(" : " + ", ".join(list), file=OUT_STREAM) + @print(" : \"%rax\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, 16)]), file=OUT_STREAM) + ); +}