diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index 0049451f750..f507535726f 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -121,4 +121,5 @@ #MindSpore Lite "mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "redefined-builtin" -"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "exec-used" \ No newline at end of file +"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "exec-used" +"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "global-variable-undefined" \ No newline at end of file diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index c5c4fe17373..20fb8320baf 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -149,3 +149,17 @@ mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_ mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32 mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32 mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32 +mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_kernel_nhwc_fp32 diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c 
b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..7be01fb1884 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c @@ -0,0 +1,537 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + const float *dst_9 = dst + 9 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n" + "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n" + "vmovups 0(%[dst_9]), %%zmm18\n" + "vmovups 64(%[dst_9]), %%zmm19\n" + "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n" + "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n" + "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n" + "vmovups 64(%[dst_9], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 0(%[bias]), %%zmm10\n" + "vmovaps 64(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "vmovaps 0(%[bias]), %%zmm14\n" + "vmovaps 64(%[bias]), %%zmm15\n" + "vmovaps 0(%[bias]), %%zmm16\n" + "vmovaps 64(%[bias]), %%zmm17\n" + "vmovaps 0(%[bias]), %%zmm18\n" + "vmovaps 64(%[bias]), %%zmm19\n" + "vmovaps 0(%[bias]), %%zmm20\n" + "vmovaps 64(%[bias]), %%zmm21\n" + "vmovaps 0(%[bias]), %%zmm22\n" + "vmovaps 64(%[bias]), %%zmm23\n" + "jmp 2f\n" + "1:\n" + 
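      // Accumulator setup for the 12x32 tile: bit 0 of inc_flag ("inc in deep") reloads the
      // existing dst values into zmm0-zmm23 so a previous deep slice can be accumulated onto;
      // otherwise a non-NULL bias broadcasts the same two 16-float bias columns into every
      // row's accumulator pair, and the fall-through label below zeroes all accumulators.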
"vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + const float *src_9 = src + 9 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 0(%[src_6]), %%zmm29\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_9]), %%zmm26\n" + "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + 
"vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 4(%[src_6]), %%zmm29\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_9]), %%zmm26\n" + "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 8(%[src_6]), %%zmm29\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_9]), %%zmm26\n" + "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 
12(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 12(%[src_6]), %%zmm29\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_9]), %%zmm26\n" + "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 16(%[src_6]), %%zmm29\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_9]), %%zmm26\n" + "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps 
%%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 20(%[src_6]), %%zmm29\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_9]), %%zmm26\n" + "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 24(%[src_6]), %%zmm29\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_9]), %%zmm26\n" + "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "fmadd231ps %%zmm31, %%zmm28, 
%%zmm2\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vbroadcastss 28(%[src_6]), %%zmm29\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_9]), %%zmm26\n" + "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n" + "fmadd231ps %%zmm31, %%zmm29, %%zmm12\n" + "fmadd231ps %%zmm31, %%zmm28, %%zmm14\n" + "fmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "fmadd231ps %%zmm31, %%zmm26, %%zmm18\n" + "fmadd231ps %%zmm31, %%zmm25, %%zmm20\n" + "fmadd231ps %%zmm31, %%zmm24, %%zmm22\n" + "fmadd231ps %%zmm30, %%zmm29, %%zmm13\n" + "fmadd231ps %%zmm30, %%zmm28, %%zmm15\n" + "fmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "fmadd231ps %%zmm30, %%zmm26, %%zmm19\n" + "fmadd231ps %%zmm30, %%zmm25, %%zmm21\n" + "fmadd231ps %%zmm30, %%zmm24, %%zmm23\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "add $32, %[src_6]\n" + "add $32, %[src_9]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "3:\n" + 
"vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n" + "vmovups %%zmm18, 0(%[dst_9])\n" + "vmovups %%zmm19, 64(%[dst_9])\n" + "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n" + "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n" + "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6), + [ src_9 ] "r"(src_9) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..12e606ef09c --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c @@ -0,0 +1,127 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", 
"%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..9eddd4f5299 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c @@ -0,0 +1,171 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, 
%%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "dec %[deep]\n" + "add $2048, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c 
b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c index fead3fbf07f..873f0d85767 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c @@ -22,7 +22,6 @@ void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const const size_t inc_flag) { size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -52,6 +51,7 @@ void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const : : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4"); + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -150,7 +150,6 @@ void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" - "dec %[deep]\n" "add $2560, %[weight]\n" "add $32, %[src_0]\n" diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c index dfa732aedea..324a3a2ce7b 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c @@ -22,7 +22,6 @@ void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const const size_t inc_flag) { size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -55,6 +54,7 @@ void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const : : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -169,7 +169,6 @@ void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" - "dec %[deep]\n" "add $3072, %[weight]\n" "add $32, %[src_0]\n" diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..955f5eacf1c --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c @@ -0,0 +1,163 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 6 + "vmovups 768(%[weight]), 
%%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..d4b5eaf8140 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c @@ -0,0 +1,235 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 128(%[bias]), %%zmm6\n" + "vmovaps 192(%[bias]), %%zmm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, 
%%zmm7\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "dec %[deep]\n" + "add $2048, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, 
%%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c index 49ee1fc9540..847f531ae1e 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c @@ -22,7 +22,6 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const const size_t inc_flag) { size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -67,6 +66,7 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const : : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -213,7 +213,6 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" - "dec %[deep]\n" "add $2560, %[weight]\n" "add $32, %[src_0]\n" @@ -258,11 +257,11 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const "vmovups %%zmm2, 128(%[dst_0])\n" "vmovups %%zmm3, 192(%[dst_0])\n" "vmovups %%zmm4, 256(%[dst_0])\n" - "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" : : [ src_0 ] "r"(src), 
[ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c index a36d357f577..1ffddc52517 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c @@ -22,7 +22,6 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const const size_t inc_flag) { size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -73,6 +72,7 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const : : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -243,7 +243,6 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" - "dec %[deep]\n" "add $3072, %[weight]\n" "add $32, %[src_0]\n" @@ -293,12 +292,12 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const "vmovups %%zmm3, 192(%[dst_0])\n" "vmovups %%zmm4, 256(%[dst_0])\n" "vmovups %%zmm5, 320(%[dst_0])\n" - "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" : : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..73acd676707 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), 
%%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ 
src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..bdd588c5219 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c @@ -0,0 +1,299 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 128(%[bias]), %%zmm6\n" + "vmovaps 192(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 128(%[bias]), %%zmm10\n" + "vmovaps 192(%[bias]), %%zmm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] 
"r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, 
%%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "dec %[deep]\n" + "add $2048, %[weight]\n" + "add $32, %[src_0]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + 
"je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c index 5dbb4874098..1979b37929b 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c @@ -22,7 +22,6 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const const size_t inc_flag) { size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -83,6 +82,7 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", "%zmm12", "%zmm13", "%zmm14"); + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -277,7 +277,6 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const "vfmadd231ps %%zmm29, 
%%zmm24, %%zmm12\n" "vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n" "vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n" - "dec %[deep]\n" "add $2560, %[weight]\n" "add $32, %[src_0]\n" @@ -332,16 +331,16 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const "vmovups %%zmm2, 128(%[dst_0])\n" "vmovups %%zmm3, 192(%[dst_0])\n" "vmovups %%zmm4, 256(%[dst_0])\n" - "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" : : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c index d84aa83254b..ac02c714e28 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c @@ -22,7 +22,6 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const const size_t inc_flag) { size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -92,6 +91,7 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag) : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17"); + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -318,7 +318,6 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const "vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n" "vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n" "vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n" - "dec %[deep]\n" "add $3072, %[weight]\n" "add $32, %[src_0]\n" @@ -380,18 +379,18 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const "vmovups %%zmm3, 192(%[dst_0])\n" "vmovups %%zmm4, 256(%[dst_0])\n" "vmovups %%zmm5, 320(%[dst_0])\n" - "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm9, 192(%[dst_0], 
%[dst_stride], 1),\n" - "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" : : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t) diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..f5f2607233e --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c @@ -0,0 +1,240 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" 
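+ // every unrolled block repeats this 4x32 pattern: zmm31/zmm30 carry two 16-float
+ // column panels of weight, zmm29-zmm26 broadcast one input value per row (rows 0-2
+ // at src_0 + {0,1,2} * src_stride, row 3 at src_3), and zmm0-zmm7 accumulate the
+ // 4x32 output tile; per deep step d this amounts to
+ //   acc[r][0..31] += src[r][d] * weight[d][0..31]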
+ "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // 
relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..9dba8c541d5 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c @@ -0,0 +1,369 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 128(%[bias]), %%zmm6\n" + "vmovaps 192(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 128(%[bias]), %%zmm10\n" + "vmovaps 192(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "vmovaps 128(%[bias]), %%zmm14\n" + "vmovaps 192(%[bias]), %%zmm15\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps 
%%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // 
block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "dec %[deep]\n" + "add $2048, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", 
"%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c index d3e8bf83423..c1a99afdc8a 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c @@ -23,7 +23,6 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const const float *dst_3 = dst + 3 * dst_stride; size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -100,7 +99,8 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const [ dst_3 ] "r"(dst_3) : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); - const float *src_3 = src + 3 * dst_stride; + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -112,7 +112,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 0(%[src_0]), %%zmm26\n" "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 0(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 0(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -142,7 +142,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 4(%[src_0]), %%zmm26\n" "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 4(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 4(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -172,7 +172,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 8(%[src_0]), %%zmm26\n" "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 8(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 8(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -202,7 +202,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 12(%[src_0]), %%zmm26\n" "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 12(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 12(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -232,7 +232,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 16(%[src_0]), %%zmm26\n" "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 16(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 16(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, 
%%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -262,7 +262,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 20(%[src_0]), %%zmm26\n" "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 20(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 20(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -292,7 +292,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 24(%[src_0]), %%zmm26\n" "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 24(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 24(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -322,7 +322,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 28(%[src_0]), %%zmm26\n" "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n" - "vbroadcastss 28(%[src_0], %[src_stride], 3), %%zmm23\n" + "vbroadcastss 28(%[src_3]), %%zmm23\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" @@ -343,7 +343,6 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n" - "dec %[deep]\n" "add $2560, %[weight]\n" "add $32, %[src_0]\n" @@ -409,16 +408,16 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const "vmovups %%zmm2, 128(%[dst_0])\n" "vmovups %%zmm3, 192(%[dst_0])\n" "vmovups %%zmm4, 256(%[dst_0])\n" - "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" "vmovups %%zmm15, 0(%[dst_3])\n" "vmovups %%zmm16, 64(%[dst_3])\n" "vmovups %%zmm17, 128(%[dst_3])\n" diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c index d945400880a..8426a36bd28 100644 --- 
a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c @@ -23,7 +23,6 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const const float *dst_3 = dst + 3 * dst_stride; size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -113,7 +112,8 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", "%zmm23"); - const float *src_3 = src + 3 * dst_stride; + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -126,30 +126,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 0(%[src_0]), %%zmm25\n" "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 0(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" // block 1 "vmovups 384(%[weight]), %%zmm31\n" @@ -161,30 +161,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 4(%[src_0]), %%zmm25\n" "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps 
%%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 4(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" // block 2 "vmovups 768(%[weight]), %%zmm31\n" @@ -196,30 +196,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 8(%[src_0]), %%zmm25\n" "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 8(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" // block 3 "vmovups 1152(%[weight]), %%zmm31\n" @@ -231,30 +231,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 12(%[src_0]), %%zmm25\n" "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, 
%%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 12(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" // block 4 "vmovups 1536(%[weight]), %%zmm31\n" @@ -266,30 +266,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 16(%[src_0]), %%zmm25\n" "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 16(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" // block 5 "vmovups 1920(%[weight]), %%zmm31\n" @@ -301,30 +301,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 20(%[src_0]), %%zmm25\n" "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, 
%%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 20(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" // block 6 "vmovups 2304(%[weight]), %%zmm31\n" @@ -336,30 +336,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 24(%[src_0]), %%zmm25\n" "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 24(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" // block 7 "vmovups 2688(%[weight]), %%zmm31\n" @@ -371,32 +371,31 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 28(%[src_0]), %%zmm25\n" "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, 
%%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n" "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" "vbroadcastss 28(%[src_3]), %%zmm24\n" "vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n" - "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" "vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n" - "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" "vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n" - "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" "vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n" - "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n" - "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n" + "vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n" "vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n" - "dec %[deep]\n" "add $3072, %[weight]\n" "add $32, %[src_0]\n" @@ -471,18 +470,18 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const "vmovups %%zmm3, 192(%[dst_0])\n" "vmovups %%zmm4, 256(%[dst_0])\n" "vmovups %%zmm5, 320(%[dst_0])\n" - "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n" "vmovups %%zmm18, 0(%[dst_3])\n" "vmovups %%zmm19, 64(%[dst_3])\n" "vmovups %%zmm20, 128(%[dst_3])\n" diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..e9220c42065 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c @@ -0,0 +1,276 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + 
"vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + 
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", 
"%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..6638b7a23a7 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c @@ -0,0 +1,433 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 128(%[bias]), %%zmm6\n" + "vmovaps 192(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 128(%[bias]), %%zmm10\n" + "vmovaps 192(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "vmovaps 128(%[bias]), %%zmm14\n" + "vmovaps 192(%[bias]), %%zmm15\n" + "vmovaps 0(%[bias]), %%zmm16\n" + "vmovaps 64(%[bias]), %%zmm17\n" + "vmovaps 128(%[bias]), %%zmm18\n" + "vmovaps 192(%[bias]), %%zmm19\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, 
%%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, 
%%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + 
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n" + 
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n" + "dec %[deep]\n" + "add $2048, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 
1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c index 9a0c58579b4..9ea67417bba 100644 --- a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c @@ -23,7 +23,6 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const const float *dst_3 = dst + 3 * dst_stride; size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; asm volatile( // inc in deep "and $0x1, %[inc_flag]\n" @@ -116,7 +115,8 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", "%zmm23", "%zmm24"); - const float *src_3 = src + 3 * dst_stride; + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; asm volatile( "0:\n" // block 0 @@ -128,33 +128,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 0(%[src_0]), %%zmm26\n" "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 0(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 0(%[src_3], %[src_stride], 1), 
%%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" // block 1 "vmovups 320(%[weight]), %%zmm31\n" "vmovups 384(%[weight]), %%zmm30\n" @@ -164,33 +164,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 4(%[src_0]), %%zmm26\n" "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 4(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" // block 2 "vmovups 640(%[weight]), %%zmm31\n" "vmovups 704(%[weight]), %%zmm30\n" @@ -200,33 +200,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 8(%[src_0]), %%zmm26\n" "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 8(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, 
%%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" // block 3 "vmovups 960(%[weight]), %%zmm31\n" "vmovups 1024(%[weight]), %%zmm30\n" @@ -236,33 +236,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 12(%[src_0]), %%zmm26\n" "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 12(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" // block 4 "vmovups 1280(%[weight]), %%zmm31\n" "vmovups 1344(%[weight]), %%zmm30\n" @@ -272,33 +272,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 16(%[src_0]), %%zmm26\n" "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" 
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 16(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" // block 5 "vmovups 1600(%[weight]), %%zmm31\n" "vmovups 1664(%[weight]), %%zmm30\n" @@ -308,33 +308,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 20(%[src_0]), %%zmm26\n" "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 20(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, 
%%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" // block 6 "vmovups 1920(%[weight]), %%zmm31\n" "vmovups 1984(%[weight]), %%zmm30\n" @@ -344,33 +344,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 24(%[src_0]), %%zmm26\n" "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 24(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" // block 7 "vmovups 2240(%[weight]), %%zmm31\n" "vmovups 2304(%[weight]), %%zmm30\n" @@ -380,34 +380,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vbroadcastss 28(%[src_0]), %%zmm26\n" "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n" "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n" "vbroadcastss 28(%[src_3]), %%zmm25\n" "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, 
%%zmm26, %%zmm14\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n" "vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n" "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n" - "vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n" - "vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n" - "vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n" - "vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n" - "vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n" - + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n" "dec %[deep]\n" "add $2560, %[weight]\n" "add $32, %[src_0]\n" @@ -483,26 +482,26 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const "vmovups %%zmm2, 128(%[dst_0])\n" "vmovups %%zmm3, 192(%[dst_0])\n" "vmovups %%zmm4, 256(%[dst_0])\n" - "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n" - "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n" - "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n" + "vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n" "vmovups %%zmm15, 0(%[dst_3])\n" "vmovups %%zmm16, 64(%[dst_3])\n" "vmovups %%zmm17, 128(%[dst_3])\n" "vmovups %%zmm18, 192(%[dst_3])\n" "vmovups %%zmm19, 256(%[dst_3])\n" - "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1),\n" - "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1),\n" - "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1),\n" - "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1),\n" - "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1),\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1)\n" : : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..a23d69ee6d1 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c @@ -0,0 +1,312 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may 
not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 0(%[bias]), %%zmm10\n" + "vmovaps 64(%[bias]), %%zmm11\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, 
%%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, 
%%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + 
"vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..a59d6a8585c --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c @@ -0,0 +1,498 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 128(%[dst_0]), %%zmm2\n" + "vmovups 192(%[dst_0]), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n" + "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n" + "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n" + "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n" + "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_3]), %%zmm12\n" + "vmovups 64(%[dst_3]), %%zmm13\n" + "vmovups 128(%[dst_3]), %%zmm14\n" + "vmovups 192(%[dst_3]), %%zmm15\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n" + "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n" + "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n" + "vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n" + "vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 128(%[bias]), %%zmm2\n" + "vmovaps 192(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 128(%[bias]), %%zmm6\n" + "vmovaps 192(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 128(%[bias]), %%zmm10\n" + "vmovaps 192(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "vmovaps 128(%[bias]), %%zmm14\n" + "vmovaps 192(%[bias]), %%zmm15\n" + "vmovaps 0(%[bias]), %%zmm16\n" + "vmovaps 64(%[bias]), %%zmm17\n" + "vmovaps 128(%[bias]), %%zmm18\n" + "vmovaps 192(%[bias]), %%zmm19\n" + "vmovaps 0(%[bias]), %%zmm20\n" + "vmovaps 64(%[bias]), %%zmm21\n" + "vmovaps 128(%[bias]), %%zmm22\n" + "vmovaps 192(%[bias]), %%zmm23\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "vxorps %%zmm16, %%zmm16, %%zmm16\n" + "vxorps %%zmm17, %%zmm17, %%zmm17\n" + "vxorps %%zmm18, %%zmm18, %%zmm18\n" + "vxorps %%zmm19, %%zmm19, %%zmm19\n" + "vxorps %%zmm20, %%zmm20, %%zmm20\n" + "vxorps %%zmm21, %%zmm21, %%zmm21\n" + "vxorps %%zmm22, %%zmm22, %%zmm22\n" + "vxorps %%zmm23, %%zmm23, %%zmm23\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] 
"r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22", + "%zmm23"); + const float *src_3 = src + 3 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vmovups 128(%[weight]), %%zmm29\n" + "vmovups 192(%[weight]), %%zmm28\n" + "vbroadcastss 0(%[src_0]), %%zmm27\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 0(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 1 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vmovups 384(%[weight]), %%zmm29\n" + "vmovups 448(%[weight]), %%zmm28\n" + "vbroadcastss 4(%[src_0]), %%zmm27\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 4(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 2 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 
576(%[weight]), %%zmm30\n" + "vmovups 640(%[weight]), %%zmm29\n" + "vmovups 704(%[weight]), %%zmm28\n" + "vbroadcastss 8(%[src_0]), %%zmm27\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 8(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 3 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vmovups 896(%[weight]), %%zmm29\n" + "vmovups 960(%[weight]), %%zmm28\n" + "vbroadcastss 12(%[src_0]), %%zmm27\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 12(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 4 + "vmovups 1024(%[weight]), %%zmm31\n" + "vmovups 1088(%[weight]), %%zmm30\n" + "vmovups 1152(%[weight]), %%zmm29\n" + "vmovups 1216(%[weight]), %%zmm28\n" + "vbroadcastss 16(%[src_0]), %%zmm27\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 16(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, 
%%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 5 + "vmovups 1280(%[weight]), %%zmm31\n" + "vmovups 1344(%[weight]), %%zmm30\n" + "vmovups 1408(%[weight]), %%zmm29\n" + "vmovups 1472(%[weight]), %%zmm28\n" + "vbroadcastss 20(%[src_0]), %%zmm27\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 20(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 6 + "vmovups 1536(%[weight]), %%zmm31\n" + "vmovups 1600(%[weight]), %%zmm30\n" + "vmovups 1664(%[weight]), %%zmm29\n" + "vmovups 1728(%[weight]), %%zmm28\n" + "vbroadcastss 24(%[src_0]), %%zmm27\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 24(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, 
%%zmm15\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + // block 7 + "vmovups 1792(%[weight]), %%zmm31\n" + "vmovups 1856(%[weight]), %%zmm30\n" + "vmovups 1920(%[weight]), %%zmm29\n" + "vmovups 1984(%[weight]), %%zmm28\n" + "vbroadcastss 28(%[src_0]), %%zmm27\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n" + "vbroadcastss 28(%[src_3]), %%zmm24\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n" + "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n" + "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n" + "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n" + "vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n" + "vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n" + "vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n" + "vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n" + "dec %[deep]\n" + "add $2048, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, %%zmm15\n" + "vmaxps %%zmm16, %%zmm31, %%zmm16\n" + "vmaxps %%zmm17, %%zmm31, %%zmm17\n" + "vmaxps %%zmm18, %%zmm31, %%zmm18\n" + "vmaxps %%zmm19, %%zmm31, %%zmm19\n" + "vmaxps %%zmm20, %%zmm31, %%zmm20\n" + "vmaxps %%zmm21, %%zmm31, %%zmm21\n" + "vmaxps %%zmm22, %%zmm31, %%zmm22\n" + "vmaxps %%zmm23, %%zmm31, %%zmm23\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + 
"vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "vminps %%zmm16, %%zmm30, %%zmm16\n" + "vminps %%zmm17, %%zmm30, %%zmm17\n" + "vminps %%zmm18, %%zmm30, %%zmm18\n" + "vminps %%zmm19, %%zmm30, %%zmm19\n" + "vminps %%zmm20, %%zmm30, %%zmm20\n" + "vminps %%zmm21, %%zmm30, %%zmm21\n" + "vminps %%zmm22, %%zmm30, %%zmm22\n" + "vminps %%zmm23, %%zmm30, %%zmm23\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 128(%[dst_0])\n" + "vmovups %%zmm3, 192(%[dst_0])\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_3])\n" + "vmovups %%zmm13, 64(%[dst_3])\n" + "vmovups %%zmm14, 128(%[dst_3])\n" + "vmovups %%zmm15, 192(%[dst_3])\n" + "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..fc094fd0f31 --- /dev/null +++ b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c @@ -0,0 +1,352 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 0(%[bias]), %%zmm10\n" + "vmovaps 64(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", "%zmm13"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, 
%%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), 
%%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, 
%%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "add $32, %[src_6]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c new file mode 100644 index 00000000000..5ef728e90d1 --- /dev/null +++ 
b/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c @@ -0,0 +1,388 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +// nnacl gemm in x86 avx512 asm code +void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias, + const size_t act_flag, const size_t row_block, const size_t col_block, + const size_t deep, const size_t src_stride, const size_t dst_stride, + const size_t inc_flag) { + const float *dst_3 = dst + 3 * dst_stride; + const float *dst_6 = dst + 6 * dst_stride; + size_t deep_t = deep >> 3; + size_t dst_stride_t = dst_stride << 2; + asm volatile( + // inc in deep + "and $0x1, %[inc_flag]\n" + "je 0f\n" + "vmovups 0(%[dst_0]), %%zmm0\n" + "vmovups 64(%[dst_0]), %%zmm1\n" + "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n" + "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n" + "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n" + "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n" + "vmovups 0(%[dst_3]), %%zmm6\n" + "vmovups 64(%[dst_3]), %%zmm7\n" + "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n" + "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n" + "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n" + "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n" + "vmovups 0(%[dst_6]), %%zmm12\n" + "vmovups 64(%[dst_6]), %%zmm13\n" + "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n" + "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n" + "jmp 2f\n" + "0:\n" + "cmpq $0, %[bias]\n" + "je 1f\n" + "vmovaps 0(%[bias]), %%zmm0\n" + "vmovaps 64(%[bias]), %%zmm1\n" + "vmovaps 0(%[bias]), %%zmm2\n" + "vmovaps 64(%[bias]), %%zmm3\n" + "vmovaps 0(%[bias]), %%zmm4\n" + "vmovaps 64(%[bias]), %%zmm5\n" + "vmovaps 0(%[bias]), %%zmm6\n" + "vmovaps 64(%[bias]), %%zmm7\n" + "vmovaps 0(%[bias]), %%zmm8\n" + "vmovaps 64(%[bias]), %%zmm9\n" + "vmovaps 0(%[bias]), %%zmm10\n" + "vmovaps 64(%[bias]), %%zmm11\n" + "vmovaps 0(%[bias]), %%zmm12\n" + "vmovaps 64(%[bias]), %%zmm13\n" + "vmovaps 0(%[bias]), %%zmm14\n" + "vmovaps 64(%[bias]), %%zmm15\n" + "jmp 2f\n" + "1:\n" + "vxorps %%zmm0, %%zmm0, %%zmm0\n" + "vxorps %%zmm1, %%zmm1, %%zmm1\n" + "vxorps %%zmm2, %%zmm2, %%zmm2\n" + "vxorps %%zmm3, %%zmm3, %%zmm3\n" + "vxorps %%zmm4, %%zmm4, %%zmm4\n" + "vxorps %%zmm5, %%zmm5, %%zmm5\n" + "vxorps %%zmm6, %%zmm6, %%zmm6\n" + "vxorps %%zmm7, %%zmm7, %%zmm7\n" + "vxorps %%zmm8, %%zmm8, %%zmm8\n" + "vxorps %%zmm9, %%zmm9, %%zmm9\n" + "vxorps %%zmm10, %%zmm10, %%zmm10\n" + "vxorps %%zmm11, %%zmm11, %%zmm11\n" + "vxorps %%zmm12, %%zmm12, %%zmm12\n" + "vxorps %%zmm13, %%zmm13, %%zmm13\n" + "vxorps %%zmm14, %%zmm14, %%zmm14\n" + "vxorps %%zmm15, %%zmm15, %%zmm15\n" + "2:\n" + : + : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6) + : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11", + "%zmm12", 
"%zmm13", "%zmm14", "%zmm15"); + const float *src_3 = src + 3 * src_stride; + const float *src_6 = src + 6 * src_stride; + size_t src_stride_t = src_stride << 2; + asm volatile( + "0:\n" + // block 0 + "vmovups 0(%[weight]), %%zmm31\n" + "vmovups 64(%[weight]), %%zmm30\n" + "vbroadcastss 0(%[src_0]), %%zmm29\n" + "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 0(%[src_3]), %%zmm26\n" + "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 0(%[src_6]), %%zmm23\n" + "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 1 + "vmovups 128(%[weight]), %%zmm31\n" + "vmovups 192(%[weight]), %%zmm30\n" + "vbroadcastss 4(%[src_0]), %%zmm29\n" + "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 4(%[src_3]), %%zmm26\n" + "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 4(%[src_6]), %%zmm23\n" + "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 2 + "vmovups 256(%[weight]), %%zmm31\n" + "vmovups 320(%[weight]), %%zmm30\n" + "vbroadcastss 8(%[src_0]), %%zmm29\n" + "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 8(%[src_3]), %%zmm26\n" + "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 8(%[src_6]), %%zmm23\n" + "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + 
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 3 + "vmovups 384(%[weight]), %%zmm31\n" + "vmovups 448(%[weight]), %%zmm30\n" + "vbroadcastss 12(%[src_0]), %%zmm29\n" + "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 12(%[src_3]), %%zmm26\n" + "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 12(%[src_6]), %%zmm23\n" + "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 4 + "vmovups 512(%[weight]), %%zmm31\n" + "vmovups 576(%[weight]), %%zmm30\n" + "vbroadcastss 16(%[src_0]), %%zmm29\n" + "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 16(%[src_3]), %%zmm26\n" + "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 16(%[src_6]), %%zmm23\n" + "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 5 + "vmovups 640(%[weight]), %%zmm31\n" + "vmovups 704(%[weight]), %%zmm30\n" + "vbroadcastss 20(%[src_0]), %%zmm29\n" + "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 20(%[src_3]), %%zmm26\n" + "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 20(%[src_6]), %%zmm23\n" + "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, 
%%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 6 + "vmovups 768(%[weight]), %%zmm31\n" + "vmovups 832(%[weight]), %%zmm30\n" + "vbroadcastss 24(%[src_0]), %%zmm29\n" + "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 24(%[src_3]), %%zmm26\n" + "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 24(%[src_6]), %%zmm23\n" + "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + // block 7 + "vmovups 896(%[weight]), %%zmm31\n" + "vmovups 960(%[weight]), %%zmm30\n" + "vbroadcastss 28(%[src_0]), %%zmm29\n" + "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n" + "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n" + "vbroadcastss 28(%[src_3]), %%zmm26\n" + "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n" + "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n" + "vbroadcastss 28(%[src_6]), %%zmm23\n" + "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n" + "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n" + "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n" + "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n" + "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n" + "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n" + "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n" + "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n" + "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n" + "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n" + "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n" + "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n" + "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n" + "vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n" + "vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n" + "vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n" + "vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n" + "dec %[deep]\n" + "add $1024, %[weight]\n" + "add $32, %[src_0]\n" + "add $32, %[src_3]\n" + "add $32, %[src_6]\n" + "jg 0b\n" + + "and $0x2, %[inc_flag]\n" + "je 3f\n" + "movq %[act_flag], %%rax\n" + "and $0x3, %%eax\n" + "je 3f\n" + // relu + "vxorps %%zmm31, %%zmm31, %%zmm31\n" + "vmaxps %%zmm0, %%zmm31, %%zmm0\n" + "vmaxps %%zmm1, %%zmm31, %%zmm1\n" + "vmaxps %%zmm2, %%zmm31, %%zmm2\n" + "vmaxps %%zmm3, %%zmm31, %%zmm3\n" + "vmaxps %%zmm4, %%zmm31, %%zmm4\n" + "vmaxps %%zmm5, %%zmm31, %%zmm5\n" + "vmaxps %%zmm6, %%zmm31, %%zmm6\n" + "vmaxps %%zmm7, %%zmm31, %%zmm7\n" + "vmaxps %%zmm8, %%zmm31, %%zmm8\n" + "vmaxps %%zmm9, %%zmm31, %%zmm9\n" + "vmaxps %%zmm10, %%zmm31, %%zmm10\n" + "vmaxps %%zmm11, %%zmm31, %%zmm11\n" + "vmaxps %%zmm12, %%zmm31, %%zmm12\n" + "vmaxps %%zmm13, %%zmm31, %%zmm13\n" + "vmaxps %%zmm14, %%zmm31, %%zmm14\n" + "vmaxps %%zmm15, %%zmm31, 
%%zmm15\n" + "and $0x1, %%eax\n" + "je 3f\n" + // relu6 + "mov $0x40C00000, %%eax\n" + "vmovd %%eax, %%xmm30\n" + "vbroadcastss %%xmm30, %%zmm30\n" + "vminps %%zmm0, %%zmm30, %%zmm0\n" + "vminps %%zmm1, %%zmm30, %%zmm1\n" + "vminps %%zmm2, %%zmm30, %%zmm2\n" + "vminps %%zmm3, %%zmm30, %%zmm3\n" + "vminps %%zmm4, %%zmm30, %%zmm4\n" + "vminps %%zmm5, %%zmm30, %%zmm5\n" + "vminps %%zmm6, %%zmm30, %%zmm6\n" + "vminps %%zmm7, %%zmm30, %%zmm7\n" + "vminps %%zmm8, %%zmm30, %%zmm8\n" + "vminps %%zmm9, %%zmm30, %%zmm9\n" + "vminps %%zmm10, %%zmm30, %%zmm10\n" + "vminps %%zmm11, %%zmm30, %%zmm11\n" + "vminps %%zmm12, %%zmm30, %%zmm12\n" + "vminps %%zmm13, %%zmm30, %%zmm13\n" + "vminps %%zmm14, %%zmm30, %%zmm14\n" + "vminps %%zmm15, %%zmm30, %%zmm15\n" + "3:\n" + "vmovups %%zmm0, 0(%[dst_0])\n" + "vmovups %%zmm1, 64(%[dst_0])\n" + "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n" + "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n" + "vmovups %%zmm6, 0(%[dst_3])\n" + "vmovups %%zmm7, 64(%[dst_3])\n" + "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n" + "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n" + "vmovups %%zmm12, 0(%[dst_6])\n" + "vmovups %%zmm13, 64(%[dst_6])\n" + "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n" + "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n" + : + : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t), + [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t), + [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6) + : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", + "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", + "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31"); +} diff --git a/mindspore/lite/experiment/HPC-generator/generate_hpc.sh b/mindspore/lite/experiment/HPC-generator/generate_hpc.sh index 84768b26839..dc44faafb54 100644 --- a/mindspore/lite/experiment/HPC-generator/generate_hpc.sh +++ b/mindspore/lite/experiment/HPC-generator/generate_hpc.sh @@ -15,64 +15,64 @@ # ============================================================================ # generate gemm fma asm code -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 
col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c - -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c - -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c - -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c - -# generate gemm fma intrinics code -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O 
./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c - -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c - -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c - -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c -python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O 
./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c +# +## generate gemm fma intrinics code +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 
col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c +# +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c +#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c # generate gemm avx512 asm code python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=96 -O 
./gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c @@ -85,3 +85,19 @@ python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c + +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=6 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=5 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c + +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=8 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=7 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=6 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=5 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c +python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c diff --git a/mindspore/lite/experiment/HPC-generator/generator.py b/mindspore/lite/experiment/HPC-generator/generator.py index 7e7edb56f03..f03b6de97ce 100644 --- a/mindspore/lite/experiment/HPC-generator/generator.py +++ b/mindspore/lite/experiment/HPC-generator/generator.py @@ -26,11 +26,18 @@ def key_value_pair(line): :param line: :return: """ - key, value = line.split("=", 1) + key = None + value = None + try: + key, value = line.split("=", 1) + except ValueError: + print("line must be format: key=value, but now is:", line) + sys.exit(1) try: value = int(value) except ValueError: - print("Error: you input value must be integer, but now is ", value) + print("Error: you input value must be integer, but now is:", value) + sys.exit(1) return key, value def 
get_indent(line): @@ -66,7 +73,7 @@ def print_line(line): generate_code_indent = get_indent(line) if line.strip().startswith("}") and "{" not in line: generate_code_indent -= 4 - if (len(line) == 1 and line[0] == "}"): + if len(line) == 1 and line[0] == "}": # modify next fun generate_code_indent generate_code_indent = -4 return "\"".join(result) @@ -107,7 +114,6 @@ def generate_code(template_file, exec_dict): line = line.replace("\n", "") if line.strip() and line.strip()[0] != "@": line = line.replace("\"", "\\\"") - if line.strip() and line.strip()[0] != "@": line = line.replace("%", "%%") if "print" in line: line = line.replace("%%", "%") @@ -118,12 +124,17 @@ def generate_code(template_file, exec_dict): if "%(" not in str: str = str.replace("%%[", "%[") generate_code_lines.append(str) - # print('\n'.join(generate_code_lines)) c = compile('\n'.join(generate_code_lines), '', 'exec') exec_dict["OUT_STREAM"] = output_stream exec(c, exec_dict) return output_stream.getvalue() +def check_python_version(): + if sys.version_info < (3, 6): + sys.stdout.write("At least python 3.6 is required, but now is " + str(sys.version_info.major) + "." + + str(sys.version_info.minor) + "\n") + sys.exit(1) + generate_code_indent = -4 python_indent = -1 @@ -134,6 +145,7 @@ parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=k parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path") if __name__ == "__main__": + check_python_version() parameters = parser.parse_args(sys.argv[1:]) exec_globals = dict(chain(*parameters.defines)) diff --git a/mindspore/lite/experiment/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in b/mindspore/lite/experiment/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in index 4f5f1f885b0..367b36659aa 100644 --- a/mindspore/lite/experiment/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in +++ b/mindspore/lite/experiment/HPC-generator/template_file/gemm_avx512_nhwc_asm.c.in @@ -20,6 +20,9 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co const float *bias, const size_t act_flag, const size_t row_block, const size_t col_block, const size_t deep, const size_t src_stride, const size_t dst_stride, const size_t inc_flag) { + @import math + @row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20} + @src_addr_stride = 3 @asm_flag_list = [] @row_split_number = [row for row in range(3, row_block, 3)] @for row in row_split_number: @@ -27,19 +30,18 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co @asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")"); size_t deep_t = deep >> 3; size_t dst_stride_t = dst_stride << 2; - size_t src_stride_t = src_stride << 2; @col_split_num = col_block >> 4; asm volatile( // inc in deep "and $0x1, %[inc_flag]\\n" "je 0f\\n" @for row in range(0, row_block): - @tmp = int(row / 3) * 3 + @src_addr = int(row / 3) * 3 @for col in range(0, col_split_num): @if row % 3 == 0: - "vmovups @{col * 64}(%[dst_@{tmp}]), %%zmm@{row * col_split_num + col}\\n" + "vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n" @else: - "vmovups @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}), %%zmm@{row * col_split_num + col}\\n" + "vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n" "jmp 2f\\n" "0:\\n" "cmpq $0, %[bias]\\n" @@ -60,8 +62,9 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co @print(" : " + 
", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM) ); @for row in row_split_number: - const float *src_@{row} = src + @{row} * dst_stride; + const float *src_@{row} = src + @{row} * src_stride; @asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")"); + size_t src_stride_t = src_stride << 2; asm volatile( "0:\\n" @loop_count = 8 @@ -69,63 +72,52 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co // block @{i} @for col in range(0, col_split_num): "vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n" - @if col_split_num == 6: - @if row_block == 4: - "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n" - "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n" + @if row_block * col_split_num + row_block + col_split_num <= 32: + @for row in range(0, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - row + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_block): + @src_index = 31 - col_split_num - row @for col in range(0, col_split_num): - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n" - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n" - "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n" - "vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n" - @for col in range(0, col_split_num): - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n" - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n" - @else: - @for row in range(0, row_block): - @if row == 0: - "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n" + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" + @else: + @row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num + @row_split_num = math.floor(row_block / row_stride) + @for row_index in range(0, row_split_num): + @row_split_start = row_index * row_stride + @for row in range(row_split_start, row_split_start + row_stride): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" @else: - "vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n" - @for row in range(0, row_block): + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(0, row_stride): + @src_index = 31 - col_split_num - row @for col in range(0, col_split_num): - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n" - @elif col_split_num == 5: - @if row_block == 5: - "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n" - "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n" + @weight_index = 31 - col + @dst_index = (row_split_start + row) * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, 
%%zmm@{src_index}, %%zmm@{dst_index}\\n" + @row_split_start = row_split_num * row_stride + @for row in range(row_split_start, row_block): + @src_addr = math.floor(row / src_addr_stride) * src_addr_stride + @src_index = 31 - col_split_num - (row - row_split_start) + @if row % src_addr_stride == 0: + "vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n" + @else: + "vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n" + @for row in range(row_split_start, row_block): + @src_index = 31 - col_split_num - (row - row_split_start) @for col in range(0, col_split_num): - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n" - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n" - "vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n" - "vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n" - @for col in range(0, col_split_num): - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n" - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n" - "vbroadcastss @{i * 4}(%[src_3], %[src_stride], 1), %%zmm@{31 - col_split_num}\\n" - @for col in range(0, col_split_num): - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n" - @else: - @for row in range(0, row_block): - @if row == 0: - "vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n" - @else: - "vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n" - @for row in range(0, row_block): - @for col in range(0, col_split_num): - "vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n" - @elif col_split_num == 2: - @for row in range(int(row_block / 6)): - @for j in range(0, 6): - @tmp = int(j / 3) * 3 - @if j % 3 == 0: - "vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}]), %%zmm@{31 - col_split_num - j}\\n" - @else: - "vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}], %[src_stride], @{j - tmp}), %%zmm@{31 - col_split_num - j}\\n" - @for col in range(0, col_split_num): - @for j in range(0, 6): - "fmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - j}, %%zmm@{(row * 6 + j) * col_split_num + col}\\n" - + @weight_index = 31 - col + @dst_index = row * col_split_num + col + "vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n" "dec %[deep]\\n" "add $@{col_block * 4 * 8}, %[weight]\\n" "add $@{loop_count * 4}, %[src_0]\\n" @@ -154,12 +146,12 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co "vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n" "3:\\n" @for row in range(0, row_block): - @tmp = int(row / 3) * 3 + @src_addr = int(row / 3) * 3 @for col in range(0, col_split_num): @if row % 3 == 0: - "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}])\\n" + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n" @else: - "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}),\\n" + "vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n" : @list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", 
"[dst_stride] \"r\"(dst_stride_t)"] @list.extend(asm_flag_list)