!27822 [MS][LITE][CPU] code generator avx

Merge pull request !27822 from liuzhongkai/code_generate1
This commit is contained in:
i-robot 2021-12-17 11:01:35 +00:00 committed by Gitee
commit 7b1b36daeb
29 changed files with 5033 additions and 408 deletions

View File

@ -121,4 +121,5 @@
#MindSpore Lite
"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "redefined-builtin"
"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "exec-used"
"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "global-variable-undefined"

View File

@ -149,3 +149,17 @@ mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x64_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x64_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x32_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x32_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x64_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_12x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_12x32_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_6x32_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_7x32_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_8x32_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x64_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x64_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x32_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x64_kernel_nhwc_fp32
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x32_kernel_nhwc_fp32

View File

@ -0,0 +1,537 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
/**
 * AVX512 GEMM micro-kernel computing a 12x32 fp32 output tile (NHWC layout).
 *
 * The tile lives entirely in registers: two 16-float zmm registers per output
 * row, rows 0-11 in zmm0-zmm23, addressed through dst / dst_3 / dst_6 / dst_9
 * plus 0/1/2 * dst_stride.
 *
 * @param dst        output tile base pointer (row 0); written back at the end
 * @param src        left-hand input, one scalar per row per deep step,
 *                   rows separated by src_stride floats
 * @param weight     packed right-hand input, 32 floats per deep step
 * @param bias       optional 32-float per-column bias; may be NULL
 * @param act_flag   activation selector: nonzero -> relu, bit0 also -> relu6
 * @param row_block  unused here (kernel is specialized to 12 rows)
 * @param col_block  unused here (kernel is specialized to 32 columns)
 * @param deep       reduction length; must be a multiple of 8 (the loop is
 *                   unrolled 8x: deep_t = deep >> 3)
 * @param src_stride distance between src rows, in floats
 * @param dst_stride distance between dst rows, in floats
 * @param inc_flag   bit0: initialize accumulators from dst (continue a
 *                   partial deep reduction); bit1: last deep chunk, so apply
 *                   the activation before storing
 *
 * NOTE(review): several operands declared as inputs (inc_flag, deep, weight,
 * src_*) are modified inside the asm; formally they should be "+r" outputs.
 * Kept as generated for consistency with the sibling kernels — confirm
 * against the HPC-generator templates.
 */
void nnacl_gemm_avx512_12x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                              const size_t act_flag, const size_t row_block, const size_t col_block,
                                              const size_t deep, const size_t src_stride, const size_t dst_stride,
                                              const size_t inc_flag) {
  const float *dst_3 = dst + 3 * dst_stride;
  const float *dst_6 = dst + 6 * dst_stride;
  const float *dst_9 = dst + 9 * dst_stride;
  size_t deep_t = deep >> 3;              // 8 deep steps per loop iteration
  size_t dst_stride_t = dst_stride << 2;  // float stride -> byte stride
  asm volatile(
    // bit0 of inc_flag: not the first deep chunk, so reload the partial
    // accumulators from dst instead of starting from bias/zero.
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
    "vmovups 0(%[dst_3]), %%zmm6\n"
    "vmovups 64(%[dst_3]), %%zmm7\n"
    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
    "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
    "vmovups 0(%[dst_6]), %%zmm12\n"
    "vmovups 64(%[dst_6]), %%zmm13\n"
    "vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n"
    "vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n"
    "vmovups 0(%[dst_6], %[dst_stride], 2), %%zmm16\n"
    "vmovups 64(%[dst_6], %[dst_stride], 2), %%zmm17\n"
    "vmovups 0(%[dst_9]), %%zmm18\n"
    "vmovups 64(%[dst_9]), %%zmm19\n"
    "vmovups 0(%[dst_9], %[dst_stride], 1), %%zmm20\n"
    "vmovups 64(%[dst_9], %[dst_stride], 1), %%zmm21\n"
    "vmovups 0(%[dst_9], %[dst_stride], 2), %%zmm22\n"
    "vmovups 64(%[dst_9], %[dst_stride], 2), %%zmm23\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    // bias path: replicate the 32 bias floats into every row's register pair
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 0(%[bias]), %%zmm2\n"
    "vmovaps 64(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "vmovaps 0(%[bias]), %%zmm6\n"
    "vmovaps 64(%[bias]), %%zmm7\n"
    "vmovaps 0(%[bias]), %%zmm8\n"
    "vmovaps 64(%[bias]), %%zmm9\n"
    "vmovaps 0(%[bias]), %%zmm10\n"
    "vmovaps 64(%[bias]), %%zmm11\n"
    "vmovaps 0(%[bias]), %%zmm12\n"
    "vmovaps 64(%[bias]), %%zmm13\n"
    "vmovaps 0(%[bias]), %%zmm14\n"
    "vmovaps 64(%[bias]), %%zmm15\n"
    "vmovaps 0(%[bias]), %%zmm16\n"
    "vmovaps 64(%[bias]), %%zmm17\n"
    "vmovaps 0(%[bias]), %%zmm18\n"
    "vmovaps 64(%[bias]), %%zmm19\n"
    "vmovaps 0(%[bias]), %%zmm20\n"
    "vmovaps 64(%[bias]), %%zmm21\n"
    "vmovaps 0(%[bias]), %%zmm22\n"
    "vmovaps 64(%[bias]), %%zmm23\n"
    "jmp 2f\n"
    "1:\n"
    // no dst reload, no bias: zero the accumulators
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
    "vxorps %%zmm15, %%zmm15, %%zmm15\n"
    "vxorps %%zmm16, %%zmm16, %%zmm16\n"
    "vxorps %%zmm17, %%zmm17, %%zmm17\n"
    "vxorps %%zmm18, %%zmm18, %%zmm18\n"
    "vxorps %%zmm19, %%zmm19, %%zmm19\n"
    "vxorps %%zmm20, %%zmm20, %%zmm20\n"
    "vxorps %%zmm21, %%zmm21, %%zmm21\n"
    "vxorps %%zmm22, %%zmm22, %%zmm22\n"
    "vxorps %%zmm23, %%zmm23, %%zmm23\n"
    "2:\n"
    :
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
      [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
      "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
      "%zmm23", "memory");  // "memory" clobber: this asm reads dst/bias memory
  const float *src_3 = src + 3 * src_stride;
  const float *src_6 = src + 6 * src_stride;
  const float *src_9 = src + 9 * src_stride;
  size_t src_stride_t = src_stride << 2;  // float stride -> byte stride
  // Reduction loop over deep_t; each "block k" handles deep step k of the
  // 8x unrolled body. zmm31/zmm30 hold the 32 weight columns for the step,
  // zmm24-zmm29 hold broadcast src scalars for six rows at a time.
  // Fix vs. the generated original: the FMA mnemonic must be "vfmadd231ps"
  // ("fmadd231ps" is not a valid instruction and fails to assemble).
  asm volatile(
    "0:\n"
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vbroadcastss 0(%[src_0]), %%zmm29\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 0(%[src_3]), %%zmm26\n"
    "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 0(%[src_6]), %%zmm29\n"
    "vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 0(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 0(%[src_9]), %%zmm26\n"
    "vbroadcastss 0(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 0(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // block 1
    "vmovups 128(%[weight]), %%zmm31\n"
    "vmovups 192(%[weight]), %%zmm30\n"
    "vbroadcastss 4(%[src_0]), %%zmm29\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 4(%[src_3]), %%zmm26\n"
    "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 4(%[src_6]), %%zmm29\n"
    "vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 4(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 4(%[src_9]), %%zmm26\n"
    "vbroadcastss 4(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 4(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // block 2
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vbroadcastss 8(%[src_0]), %%zmm29\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 8(%[src_3]), %%zmm26\n"
    "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 8(%[src_6]), %%zmm29\n"
    "vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 8(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 8(%[src_9]), %%zmm26\n"
    "vbroadcastss 8(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 8(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // block 3
    "vmovups 384(%[weight]), %%zmm31\n"
    "vmovups 448(%[weight]), %%zmm30\n"
    "vbroadcastss 12(%[src_0]), %%zmm29\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 12(%[src_3]), %%zmm26\n"
    "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 12(%[src_6]), %%zmm29\n"
    "vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 12(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 12(%[src_9]), %%zmm26\n"
    "vbroadcastss 12(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 12(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // block 4
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vbroadcastss 16(%[src_0]), %%zmm29\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 16(%[src_3]), %%zmm26\n"
    "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 16(%[src_6]), %%zmm29\n"
    "vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 16(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 16(%[src_9]), %%zmm26\n"
    "vbroadcastss 16(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 16(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // block 5
    "vmovups 640(%[weight]), %%zmm31\n"
    "vmovups 704(%[weight]), %%zmm30\n"
    "vbroadcastss 20(%[src_0]), %%zmm29\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 20(%[src_3]), %%zmm26\n"
    "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 20(%[src_6]), %%zmm29\n"
    "vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 20(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 20(%[src_9]), %%zmm26\n"
    "vbroadcastss 20(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 20(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // block 6
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vbroadcastss 24(%[src_0]), %%zmm29\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 24(%[src_3]), %%zmm26\n"
    "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 24(%[src_6]), %%zmm29\n"
    "vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 24(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 24(%[src_9]), %%zmm26\n"
    "vbroadcastss 24(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 24(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // block 7
    "vmovups 896(%[weight]), %%zmm31\n"
    "vmovups 960(%[weight]), %%zmm30\n"
    "vbroadcastss 28(%[src_0]), %%zmm29\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 28(%[src_3]), %%zmm26\n"
    "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    "vbroadcastss 28(%[src_6]), %%zmm29\n"
    "vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 28(%[src_6], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 28(%[src_9]), %%zmm26\n"
    "vbroadcastss 28(%[src_9], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 28(%[src_9], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm12\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm14\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm18\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm20\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm22\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm13\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm15\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm19\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm21\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm23\n"
    // advance to the next 8-deep chunk: 8 steps * 32 cols * 4 bytes of
    // weight, 8 steps * 4 bytes of each src row
    "dec %[deep]\n"
    "add $1024, %[weight]\n"
    "add $32, %[src_0]\n"
    "add $32, %[src_3]\n"
    "add $32, %[src_6]\n"
    "add $32, %[src_9]\n"
    "jg 0b\n"
    // bit1 of inc_flag: only the final deep chunk applies the activation
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
    "vmaxps %%zmm10, %%zmm31, %%zmm10\n"
    "vmaxps %%zmm11, %%zmm31, %%zmm11\n"
    "vmaxps %%zmm12, %%zmm31, %%zmm12\n"
    "vmaxps %%zmm13, %%zmm31, %%zmm13\n"
    "vmaxps %%zmm14, %%zmm31, %%zmm14\n"
    "vmaxps %%zmm15, %%zmm31, %%zmm15\n"
    "vmaxps %%zmm16, %%zmm31, %%zmm16\n"
    "vmaxps %%zmm17, %%zmm31, %%zmm17\n"
    "vmaxps %%zmm18, %%zmm31, %%zmm18\n"
    "vmaxps %%zmm19, %%zmm31, %%zmm19\n"
    "vmaxps %%zmm20, %%zmm31, %%zmm20\n"
    "vmaxps %%zmm21, %%zmm31, %%zmm21\n"
    "vmaxps %%zmm22, %%zmm31, %%zmm22\n"
    "vmaxps %%zmm23, %%zmm31, %%zmm23\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: clamp at 6.0f (0x40C00000)
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "vminps %%zmm6, %%zmm30, %%zmm6\n"
    "vminps %%zmm7, %%zmm30, %%zmm7\n"
    "vminps %%zmm8, %%zmm30, %%zmm8\n"
    "vminps %%zmm9, %%zmm30, %%zmm9\n"
    "vminps %%zmm10, %%zmm30, %%zmm10\n"
    "vminps %%zmm11, %%zmm30, %%zmm11\n"
    "vminps %%zmm12, %%zmm30, %%zmm12\n"
    "vminps %%zmm13, %%zmm30, %%zmm13\n"
    "vminps %%zmm14, %%zmm30, %%zmm14\n"
    "vminps %%zmm15, %%zmm30, %%zmm15\n"
    "vminps %%zmm16, %%zmm30, %%zmm16\n"
    "vminps %%zmm17, %%zmm30, %%zmm17\n"
    "vminps %%zmm18, %%zmm30, %%zmm18\n"
    "vminps %%zmm19, %%zmm30, %%zmm19\n"
    "vminps %%zmm20, %%zmm30, %%zmm20\n"
    "vminps %%zmm21, %%zmm30, %%zmm21\n"
    "vminps %%zmm22, %%zmm30, %%zmm22\n"
    "vminps %%zmm23, %%zmm30, %%zmm23\n"
    "3:\n"
    // write the 12x32 tile back to dst
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm6, 0(%[dst_3])\n"
    "vmovups %%zmm7, 64(%[dst_3])\n"
    "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n"
    "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n"
    "vmovups %%zmm12, 0(%[dst_6])\n"
    "vmovups %%zmm13, 64(%[dst_6])\n"
    "vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n"
    "vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n"
    "vmovups %%zmm16, 0(%[dst_6], %[dst_stride], 2)\n"
    "vmovups %%zmm17, 64(%[dst_6], %[dst_stride], 2)\n"
    "vmovups %%zmm18, 0(%[dst_9])\n"
    "vmovups %%zmm19, 64(%[dst_9])\n"
    "vmovups %%zmm20, 0(%[dst_9], %[dst_stride], 1)\n"
    "vmovups %%zmm21, 64(%[dst_9], %[dst_stride], 1)\n"
    "vmovups %%zmm22, 0(%[dst_9], %[dst_stride], 2)\n"
    "vmovups %%zmm23, 64(%[dst_9], %[dst_stride], 2)\n"
    :
    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
      [ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ dst_9 ] "r"(dst_9), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6),
      [ src_9 ] "r"(src_9)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31",
      "memory");  // "memory" clobber: this asm reads src/weight and writes dst
}

View File

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
// AVX512 GEMM micro-kernel computing a 1x32 fp32 output tile (NHWC layout).
// The single output row is held in zmm0/zmm1 (2 x 16 floats).
//   inc_flag bit0: initialize the accumulators from the existing dst values
//                  (continuing a partial reduction over deep);
//   inc_flag bit1: this is the last deep chunk, so apply the activation.
//   act_flag: nonzero -> relu; bit0 additionally -> relu6 (clamp at 6.0f).
// `deep` must be a multiple of 8: the reduction loop is unrolled 8x.
// row_block/col_block are unused (kernel is specialized to 1x32).
// NOTE(review): the asm modifies operands declared as inputs (inc_flag,
// deep, weight, src_0); formally they should be "+r" outputs, and the asm
// that touches dst memory should carry a "memory" clobber — generated code,
// confirm against the HPC-generator templates.
void nnacl_gemm_avx512_1x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  size_t deep_t = deep >> 3;              // 8 deep steps per loop iteration
  size_t dst_stride_t = dst_stride << 2;  // float stride -> byte stride
  asm volatile(
    // inc in deep: bit0 set means accumulate on top of the current dst tile
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "jmp 2f\n"
    "0:\n"
    // bias == NULL -> zero-initialize; otherwise start from the 32 bias floats
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "2:\n"
    :
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
    : "%zmm0", "%zmm1");
  size_t src_stride_t = src_stride << 2;  // float stride -> byte stride
  // Reduction loop: each "block k" is deep step k of the 8x unrolled body.
  // zmm31/zmm30 hold the 32 weight columns for the step; zmm29 broadcasts
  // the single src scalar; vfmadd231ps accumulates into zmm0/zmm1.
  asm volatile(
    "0:\n"
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vbroadcastss 0(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // block 1
    "vmovups 128(%[weight]), %%zmm31\n"
    "vmovups 192(%[weight]), %%zmm30\n"
    "vbroadcastss 4(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // block 2
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vbroadcastss 8(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // block 3
    "vmovups 384(%[weight]), %%zmm31\n"
    "vmovups 448(%[weight]), %%zmm30\n"
    "vbroadcastss 12(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // block 4
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vbroadcastss 16(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // block 5
    "vmovups 640(%[weight]), %%zmm31\n"
    "vmovups 704(%[weight]), %%zmm30\n"
    "vbroadcastss 20(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // block 6
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vbroadcastss 24(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // block 7
    "vmovups 896(%[weight]), %%zmm31\n"
    "vmovups 960(%[weight]), %%zmm30\n"
    "vbroadcastss 28(%[src_0]), %%zmm29\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    // next 8-deep chunk: 8 * 32 cols * 4 bytes of weight, 8 * 4 bytes of src
    "dec %[deep]\n"
    "add $1024, %[weight]\n"
    "add $32, %[src_0]\n"
    "jg 0b\n"
    // bit1 of inc_flag: only the final deep chunk applies the activation
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: clamp at 6.0f (0x40C00000 is the IEEE-754 encoding of 6.0f)
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "3:\n"
    // write the 1x32 tile back to dst
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    :
    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,171 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_1x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vbroadcastss 0(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
// block 1
"vmovups 256(%[weight]), %%zmm31\n"
"vmovups 320(%[weight]), %%zmm30\n"
"vmovups 384(%[weight]), %%zmm29\n"
"vmovups 448(%[weight]), %%zmm28\n"
"vbroadcastss 4(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
// block 2
"vmovups 512(%[weight]), %%zmm31\n"
"vmovups 576(%[weight]), %%zmm30\n"
"vmovups 640(%[weight]), %%zmm29\n"
"vmovups 704(%[weight]), %%zmm28\n"
"vbroadcastss 8(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
// block 3
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vmovups 896(%[weight]), %%zmm29\n"
"vmovups 960(%[weight]), %%zmm28\n"
"vbroadcastss 12(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
// block 4
"vmovups 1024(%[weight]), %%zmm31\n"
"vmovups 1088(%[weight]), %%zmm30\n"
"vmovups 1152(%[weight]), %%zmm29\n"
"vmovups 1216(%[weight]), %%zmm28\n"
"vbroadcastss 16(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
// block 5
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vbroadcastss 20(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
// block 6
"vmovups 1536(%[weight]), %%zmm31\n"
"vmovups 1600(%[weight]), %%zmm30\n"
"vmovups 1664(%[weight]), %%zmm29\n"
"vmovups 1728(%[weight]), %%zmm28\n"
"vbroadcastss 24(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
// block 7
"vmovups 1792(%[weight]), %%zmm31\n"
"vmovups 1856(%[weight]), %%zmm30\n"
"vmovups 1920(%[weight]), %%zmm29\n"
"vmovups 1984(%[weight]), %%zmm28\n"
"vbroadcastss 28(%[src_0]), %%zmm27\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"dec %[deep]\n"
"add $2048, %[weight]\n"
"add $32, %[src_0]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -22,7 +22,6 @@ void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -52,6 +51,7 @@ void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -150,7 +150,6 @@ void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"

View File

@ -22,7 +22,6 @@ void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -55,6 +54,7 @@ void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -169,7 +169,6 @@ void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"

View File

@ -0,0 +1,163 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_2x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  // 2x32 AVX512 GEMM micro-kernel (NHWC): computes a 2-row by 32-column tile of
  // dst = src * weight (+ bias), with an optional fused ReLU/ReLU6.
  //   deep       : number of accumulation steps; must be a multiple of 8 (loop body is 8x unrolled)
  //   src_stride : row stride of src, in floats
  //   dst_stride : row stride of dst, in floats
  //   inc_flag   : bit0 = not the first deep chunk (reload partial sums from dst instead of bias),
  //                bit1 = last deep chunk (apply the activation before storing)
  //   act_flag   : low 2 bits select activation: non-zero -> ReLU; bit0 also set -> clamp at 6 (ReLU6)
  // bias, when non-NULL, must be 64-byte aligned (loaded with vmovaps, which faults on
  // unaligned addresses); row_block/col_block are part of the generated-kernel ABI and unused here.
  size_t deep_t = deep >> 3;              // unrolled-loop trip count: deep / 8
  size_t dst_stride_t = dst_stride << 2;  // row stride in bytes
  size_t inc_flag_init = inc_flag;        // local copy: the asm destroys it with "and $0x1"
  asm volatile(
    // Initialize accumulators zmm0..zmm3 (2 rows x 2 zmm): from dst (continuation),
    // from bias (replicated per row), or zero.
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 0(%[bias]), %%zmm2\n"
    "vmovaps 64(%[bias]), %%zmm3\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "2:\n"
    : [ inc_flag ] "+r"(inc_flag_init)  // "+r": the asm overwrites this register
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "memory");
  size_t src_stride_t = src_stride << 2;  // src row stride in bytes
  // Local copies for every pointer/counter the asm advances in place; declaring them
  // as "+r" outputs (instead of read-only inputs) keeps the compiler informed.
  const float *src_ptr = src;
  const float *weight_ptr = weight;
  size_t inc_flag_act = inc_flag;
  asm volatile(
    "0:\n"
    // Each "block" handles one deep step: 2 zmm of weights, one broadcast scalar per src row.
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vbroadcastss 0(%[src_0]), %%zmm29\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    // block 1
    "vmovups 128(%[weight]), %%zmm31\n"
    "vmovups 192(%[weight]), %%zmm30\n"
    "vbroadcastss 4(%[src_0]), %%zmm29\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    // block 2
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vbroadcastss 8(%[src_0]), %%zmm29\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    // block 3
    "vmovups 384(%[weight]), %%zmm31\n"
    "vmovups 448(%[weight]), %%zmm30\n"
    "vbroadcastss 12(%[src_0]), %%zmm29\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    // block 4
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vbroadcastss 16(%[src_0]), %%zmm29\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    // block 5
    "vmovups 640(%[weight]), %%zmm31\n"
    "vmovups 704(%[weight]), %%zmm30\n"
    "vbroadcastss 20(%[src_0]), %%zmm29\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    // block 6
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vbroadcastss 24(%[src_0]), %%zmm29\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    // block 7
    "vmovups 896(%[weight]), %%zmm31\n"
    "vmovups 960(%[weight]), %%zmm30\n"
    "vbroadcastss 28(%[src_0]), %%zmm29\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "add $1024, %[weight]\n"
    "add $32, %[src_0]\n"
    // BUGFIX: decrement must be the LAST flag-setting instruction before "jg";
    // previously "dec" ran before the two "add"s, so "jg" tested the flags of
    // "add $32, %[src_0]" (always positive) and the loop never terminated on the counter.
    "dec %[deep]\n"
    "jg 0b\n"
    // Activation is applied only on the last deep chunk (inc_flag bit1).
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: 0x40C00000 is the IEEE-754 encoding of 6.0f
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "3:\n"
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
    : [ src_0 ] "+r"(src_ptr), [ weight ] "+r"(weight_ptr), [ deep ] "+r"(deep_t), [ inc_flag ] "+r"(inc_flag_act)
    : [ src_stride ] "r"(src_stride_t), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst),
      [ dst_stride ] "r"(dst_stride_t)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31", "memory");
}

View File

@ -0,0 +1,235 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_2x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  // 2x64 AVX512 GEMM micro-kernel (NHWC): computes a 2-row by 64-column tile of
  // dst = src * weight (+ bias), with an optional fused ReLU/ReLU6.
  //   deep       : number of accumulation steps; must be a multiple of 8 (loop body is 8x unrolled)
  //   src_stride : row stride of src, in floats
  //   dst_stride : row stride of dst, in floats
  //   inc_flag   : bit0 = not the first deep chunk (reload partial sums from dst instead of bias),
  //                bit1 = last deep chunk (apply the activation before storing)
  //   act_flag   : low 2 bits select activation: non-zero -> ReLU; bit0 also set -> clamp at 6 (ReLU6)
  // bias, when non-NULL, must be 64-byte aligned (loaded with vmovaps, which faults on
  // unaligned addresses); row_block/col_block are part of the generated-kernel ABI and unused here.
  size_t deep_t = deep >> 3;              // unrolled-loop trip count: deep / 8
  size_t dst_stride_t = dst_stride << 2;  // row stride in bytes
  size_t inc_flag_init = inc_flag;        // local copy: the asm destroys it with "and $0x1"
  asm volatile(
    // Initialize accumulators zmm0..zmm7 (2 rows x 4 zmm): from dst (continuation),
    // from bias (replicated per row), or zero.
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 128(%[dst_0]), %%zmm2\n"
    "vmovups 192(%[dst_0]), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 128(%[bias]), %%zmm2\n"
    "vmovaps 192(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "vmovaps 128(%[bias]), %%zmm6\n"
    "vmovaps 192(%[bias]), %%zmm7\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
    "2:\n"
    : [ inc_flag ] "+r"(inc_flag_init)  // "+r": the asm overwrites this register
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "memory");
  size_t src_stride_t = src_stride << 2;  // src row stride in bytes
  // Local copies for every pointer/counter the asm advances in place; declaring them
  // as "+r" outputs (instead of read-only inputs) keeps the compiler informed.
  const float *src_ptr = src;
  const float *weight_ptr = weight;
  size_t inc_flag_act = inc_flag;
  asm volatile(
    "0:\n"
    // Each "block" handles one deep step: 4 zmm of weights, one broadcast scalar per src row.
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vmovups 128(%[weight]), %%zmm29\n"
    "vmovups 192(%[weight]), %%zmm28\n"
    "vbroadcastss 0(%[src_0]), %%zmm27\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    // block 1
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vmovups 384(%[weight]), %%zmm29\n"
    "vmovups 448(%[weight]), %%zmm28\n"
    "vbroadcastss 4(%[src_0]), %%zmm27\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    // block 2
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vmovups 640(%[weight]), %%zmm29\n"
    "vmovups 704(%[weight]), %%zmm28\n"
    "vbroadcastss 8(%[src_0]), %%zmm27\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    // block 3
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vmovups 896(%[weight]), %%zmm29\n"
    "vmovups 960(%[weight]), %%zmm28\n"
    "vbroadcastss 12(%[src_0]), %%zmm27\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    // block 4
    "vmovups 1024(%[weight]), %%zmm31\n"
    "vmovups 1088(%[weight]), %%zmm30\n"
    "vmovups 1152(%[weight]), %%zmm29\n"
    "vmovups 1216(%[weight]), %%zmm28\n"
    "vbroadcastss 16(%[src_0]), %%zmm27\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    // block 5
    "vmovups 1280(%[weight]), %%zmm31\n"
    "vmovups 1344(%[weight]), %%zmm30\n"
    "vmovups 1408(%[weight]), %%zmm29\n"
    "vmovups 1472(%[weight]), %%zmm28\n"
    "vbroadcastss 20(%[src_0]), %%zmm27\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    // block 6
    "vmovups 1536(%[weight]), %%zmm31\n"
    "vmovups 1600(%[weight]), %%zmm30\n"
    "vmovups 1664(%[weight]), %%zmm29\n"
    "vmovups 1728(%[weight]), %%zmm28\n"
    "vbroadcastss 24(%[src_0]), %%zmm27\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    // block 7
    "vmovups 1792(%[weight]), %%zmm31\n"
    "vmovups 1856(%[weight]), %%zmm30\n"
    "vmovups 1920(%[weight]), %%zmm29\n"
    "vmovups 1984(%[weight]), %%zmm28\n"
    "vbroadcastss 28(%[src_0]), %%zmm27\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "add $2048, %[weight]\n"
    "add $32, %[src_0]\n"
    // BUGFIX: decrement must be the LAST flag-setting instruction before "jg";
    // previously "dec" ran before the two "add"s, so "jg" tested the flags of
    // "add $32, %[src_0]" (always positive) and the loop never terminated on the counter.
    "dec %[deep]\n"
    "jg 0b\n"
    // Activation is applied only on the last deep chunk (inc_flag bit1).
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: 0x40C00000 is the IEEE-754 encoding of 6.0f
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "vminps %%zmm6, %%zmm30, %%zmm6\n"
    "vminps %%zmm7, %%zmm30, %%zmm7\n"
    "3:\n"
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 128(%[dst_0])\n"
    "vmovups %%zmm3, 192(%[dst_0])\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n"
    : [ src_0 ] "+r"(src_ptr), [ weight ] "+r"(weight_ptr), [ deep ] "+r"(deep_t), [ inc_flag ] "+r"(inc_flag_act)
    : [ src_stride ] "r"(src_stride_t), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst),
      [ dst_stride ] "r"(dst_stride_t)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31", "memory");
}

View File

@ -22,7 +22,6 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -67,6 +66,7 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -213,7 +213,6 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
@ -258,11 +257,11 @@ void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)

View File

@ -22,7 +22,6 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -73,6 +72,7 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -243,7 +243,6 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"
@ -293,12 +292,12 @@ void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 320(%[dst_0])\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)

View File

@ -0,0 +1,199 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_3x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  // 3x32 AVX512 GEMM micro-kernel (NHWC): computes a 3-row by 32-column tile of
  // dst = src * weight (+ bias), with an optional fused ReLU/ReLU6.
  //   deep       : number of accumulation steps; must be a multiple of 8 (loop body is 8x unrolled)
  //   src_stride : row stride of src, in floats
  //   dst_stride : row stride of dst, in floats
  //   inc_flag   : bit0 = not the first deep chunk (reload partial sums from dst instead of bias),
  //                bit1 = last deep chunk (apply the activation before storing)
  //   act_flag   : low 2 bits select activation: non-zero -> ReLU; bit0 also set -> clamp at 6 (ReLU6)
  // bias, when non-NULL, must be 64-byte aligned (loaded with vmovaps, which faults on
  // unaligned addresses); row_block/col_block are part of the generated-kernel ABI and unused here.
  size_t deep_t = deep >> 3;              // unrolled-loop trip count: deep / 8
  size_t dst_stride_t = dst_stride << 2;  // row stride in bytes
  size_t inc_flag_init = inc_flag;        // local copy: the asm destroys it with "and $0x1"
  asm volatile(
    // Initialize accumulators zmm0..zmm5 (3 rows x 2 zmm): from dst (continuation),
    // from bias (replicated per row), or zero.
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 0(%[bias]), %%zmm2\n"
    "vmovaps 64(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "2:\n"
    : [ inc_flag ] "+r"(inc_flag_init)  // "+r": the asm overwrites this register
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "memory");
  size_t src_stride_t = src_stride << 2;  // src row stride in bytes
  // Local copies for every pointer/counter the asm advances in place; declaring them
  // as "+r" outputs (instead of read-only inputs) keeps the compiler informed.
  const float *src_ptr = src;
  const float *weight_ptr = weight;
  size_t inc_flag_act = inc_flag;
  asm volatile(
    "0:\n"
    // Each "block" handles one deep step: 2 zmm of weights, one broadcast scalar per src row
    // (rows addressed via src_stride with scale 1 and 2).
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vbroadcastss 0(%[src_0]), %%zmm29\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    // block 1
    "vmovups 128(%[weight]), %%zmm31\n"
    "vmovups 192(%[weight]), %%zmm30\n"
    "vbroadcastss 4(%[src_0]), %%zmm29\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    // block 2
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vbroadcastss 8(%[src_0]), %%zmm29\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    // block 3
    "vmovups 384(%[weight]), %%zmm31\n"
    "vmovups 448(%[weight]), %%zmm30\n"
    "vbroadcastss 12(%[src_0]), %%zmm29\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    // block 4
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vbroadcastss 16(%[src_0]), %%zmm29\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    // block 5
    "vmovups 640(%[weight]), %%zmm31\n"
    "vmovups 704(%[weight]), %%zmm30\n"
    "vbroadcastss 20(%[src_0]), %%zmm29\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    // block 6
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vbroadcastss 24(%[src_0]), %%zmm29\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    // block 7
    "vmovups 896(%[weight]), %%zmm31\n"
    "vmovups 960(%[weight]), %%zmm30\n"
    "vbroadcastss 28(%[src_0]), %%zmm29\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "add $1024, %[weight]\n"
    "add $32, %[src_0]\n"
    // BUGFIX: decrement must be the LAST flag-setting instruction before "jg";
    // previously "dec" ran before the two "add"s, so "jg" tested the flags of
    // "add $32, %[src_0]" (always positive) and the loop never terminated on the counter.
    "dec %[deep]\n"
    "jg 0b\n"
    // Activation is applied only on the last deep chunk (inc_flag bit1).
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: 0x40C00000 is the IEEE-754 encoding of 6.0f
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "3:\n"
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
    : [ src_0 ] "+r"(src_ptr), [ weight ] "+r"(weight_ptr), [ deep ] "+r"(deep_t), [ inc_flag ] "+r"(inc_flag_act)
    : [ src_stride ] "r"(src_stride_t), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst),
      [ dst_stride ] "r"(dst_stride_t)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31", "memory");
}

View File

@ -0,0 +1,299 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
//
// Single-tile fp32 GEMM kernel: 3 rows x 64 output channels, NHWC layout.
// The 3x64 tile is held in 12 zmm accumulators (zmm0-zmm11: 4 registers of
// 16 floats per row); the reduction ("deep") dimension is unrolled 8x per
// loop iteration, consuming 64*8 weight floats and 8 src floats per row.
//
// dst, src     : tile base pointers; rows are dst_stride / src_stride floats apart
// weight       : packed weights, 64 floats per deep step
// bias         : 64-float per-channel bias, or NULL for zero initialization
// act_flag     : (act_flag & 0x3) != 0 -> ReLU; additionally bit 0 -> clamp at 6 (ReLU6)
// row_block, col_block : unused here (tile shape is fixed at 3x64)
// deep         : reduction length, processed 8 at a time (deep >> 3);
//                NOTE(review): any remainder deep % 8 is silently dropped —
//                confirm callers always pass a multiple of 8.
// inc_flag     : bit 0 -> accumulate onto existing dst values instead of bias init;
//                bit 1 -> apply activation (i.e. this is the final deep slice)
void nnacl_gemm_avx512_3x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  size_t deep_t = deep >> 3;              // number of 8-deep unrolled iterations
  size_t dst_stride_t = dst_stride << 2;  // dst row stride in bytes (floats * 4)
  // Prologue: initialize accumulators zmm0..zmm11 from dst (inc_flag bit 0),
  // from bias (replicated per row), or to zero.
  asm volatile(
    // inc in deep
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 128(%[dst_0]), %%zmm2\n"
    "vmovups 192(%[dst_0]), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n"
    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n"
    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n"
    "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    // same 64-float bias vector replicated into each of the 3 rows
    // NOTE(review): vmovaps requires 64-byte alignment of bias — confirm the
    // allocator guarantees it.
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 128(%[bias]), %%zmm2\n"
    "vmovaps 192(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "vmovaps 128(%[bias]), %%zmm6\n"
    "vmovaps 192(%[bias]), %%zmm7\n"
    "vmovaps 0(%[bias]), %%zmm8\n"
    "vmovaps 64(%[bias]), %%zmm9\n"
    "vmovaps 128(%[bias]), %%zmm10\n"
    "vmovaps 192(%[bias]), %%zmm11\n"
    "jmp 2f\n"
    "1:\n"
    // no bias: zero all accumulators
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
    "2:\n"
    :
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11");
  size_t src_stride_t = src_stride << 2;  // src row stride in bytes
  // Main reduction loop: each "block k" handles deep element k of the 8x
  // unrolled step — load 4 zmm of weights, broadcast one src scalar per row,
  // then 12 FMAs into the accumulators.
  asm volatile(
    "0:\n"
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vmovups 128(%[weight]), %%zmm29\n"
    "vmovups 192(%[weight]), %%zmm28\n"
    "vbroadcastss 0(%[src_0]), %%zmm27\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // block 1
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vmovups 384(%[weight]), %%zmm29\n"
    "vmovups 448(%[weight]), %%zmm28\n"
    "vbroadcastss 4(%[src_0]), %%zmm27\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // block 2
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vmovups 640(%[weight]), %%zmm29\n"
    "vmovups 704(%[weight]), %%zmm28\n"
    "vbroadcastss 8(%[src_0]), %%zmm27\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // block 3
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vmovups 896(%[weight]), %%zmm29\n"
    "vmovups 960(%[weight]), %%zmm28\n"
    "vbroadcastss 12(%[src_0]), %%zmm27\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // block 4
    "vmovups 1024(%[weight]), %%zmm31\n"
    "vmovups 1088(%[weight]), %%zmm30\n"
    "vmovups 1152(%[weight]), %%zmm29\n"
    "vmovups 1216(%[weight]), %%zmm28\n"
    "vbroadcastss 16(%[src_0]), %%zmm27\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // block 5
    "vmovups 1280(%[weight]), %%zmm31\n"
    "vmovups 1344(%[weight]), %%zmm30\n"
    "vmovups 1408(%[weight]), %%zmm29\n"
    "vmovups 1472(%[weight]), %%zmm28\n"
    "vbroadcastss 20(%[src_0]), %%zmm27\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // block 6
    "vmovups 1536(%[weight]), %%zmm31\n"
    "vmovups 1600(%[weight]), %%zmm30\n"
    "vmovups 1664(%[weight]), %%zmm29\n"
    "vmovups 1728(%[weight]), %%zmm28\n"
    "vbroadcastss 24(%[src_0]), %%zmm27\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // block 7
    "vmovups 1792(%[weight]), %%zmm31\n"
    "vmovups 1856(%[weight]), %%zmm30\n"
    "vmovups 1920(%[weight]), %%zmm29\n"
    "vmovups 1984(%[weight]), %%zmm28\n"
    "vbroadcastss 28(%[src_0]), %%zmm27\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    // advance pointers: 8 deep steps * 64 ch * 4 B = 2048 B of weights,
    // 8 deep steps * 4 B = 32 B of src per row
    // NOTE(review): the two 'add's below overwrite the flags set by 'dec', so
    // 'jg' actually tests the result of 'add $32, %[src_0]' rather than the
    // decremented deep counter — verify loop termination against the
    // generator's intent.
    "dec %[deep]\n"
    "add $2048, %[weight]\n"
    "add $32, %[src_0]\n"
    "jg 0b\n"
    // epilogue: optional activation (only when inc_flag bit 1 is set), then
    // store the 3x64 tile back to dst
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu: clamp accumulators at zero
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
    "vmaxps %%zmm10, %%zmm31, %%zmm10\n"
    "vmaxps %%zmm11, %%zmm31, %%zmm11\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: additionally clamp at 6.0f (0x40C00000 is the IEEE-754 bits of 6.0f)
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "vminps %%zmm6, %%zmm30, %%zmm6\n"
    "vminps %%zmm7, %%zmm30, %%zmm7\n"
    "vminps %%zmm8, %%zmm30, %%zmm8\n"
    "vminps %%zmm9, %%zmm30, %%zmm9\n"
    "vminps %%zmm10, %%zmm30, %%zmm10\n"
    "vminps %%zmm11, %%zmm30, %%zmm11\n"
    "3:\n"
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 128(%[dst_0])\n"
    "vmovups %%zmm3, 192(%[dst_0])\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n"
    // NOTE(review): this asm stores through [dst_0] and mutates the registers
    // bound to [weight], [src_0], [deep] and [inc_flag], yet every operand is
    // an input and there is no "memory" clobber — presumably safe only because
    // nothing after the asm re-reads those locals; verify against the GCC
    // extended-asm rules.
    :
    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -22,7 +22,6 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -83,6 +82,7 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -277,7 +277,6 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
@ -332,16 +331,16 @@ void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)

View File

@ -22,7 +22,6 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const
const size_t inc_flag) {
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -92,6 +91,7 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -318,7 +318,6 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"
@ -380,18 +379,18 @@ void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 320(%[dst_0])\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)

View File

@ -0,0 +1,240 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
//
// Single-tile fp32 GEMM kernel: 4 rows x 32 output channels, NHWC layout.
// The 4x32 tile is held in 8 zmm accumulators (zmm0-zmm7: 2 registers of
// 16 floats per row). Rows 0-2 are addressed via dst_stride scaling; the 4th
// row needs its own base pointer (dst_3 / src_3) because x86 addressing can
// only scale by 1, 2, 4 or 8. The reduction ("deep") dimension is unrolled 8x.
//
// dst, src     : tile base pointers; rows are dst_stride / src_stride floats apart
// weight       : packed weights, 32 floats per deep step
// bias         : 32-float per-channel bias, or NULL for zero initialization
// act_flag     : (act_flag & 0x3) != 0 -> ReLU; additionally bit 0 -> clamp at 6 (ReLU6)
// row_block, col_block : unused here (tile shape is fixed at 4x32)
// deep         : reduction length, processed 8 at a time (deep >> 3);
//                NOTE(review): any remainder deep % 8 is silently dropped —
//                confirm callers always pass a multiple of 8.
// inc_flag     : bit 0 -> accumulate onto existing dst values instead of bias init;
//                bit 1 -> apply activation (i.e. this is the final deep slice)
void nnacl_gemm_avx512_4x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  // base of row 3 (the 4th row)
  // NOTE(review): declared const float * although the epilogue asm stores
  // through it — consider float * for honesty.
  const float *dst_3 = dst + 3 * dst_stride;
  size_t deep_t = deep >> 3;              // number of 8-deep unrolled iterations
  size_t dst_stride_t = dst_stride << 2;  // dst row stride in bytes (floats * 4)
  // Prologue: initialize accumulators zmm0..zmm7 from dst (inc_flag bit 0),
  // from bias (replicated per row), or to zero.
  asm volatile(
    // inc in deep
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
    "vmovups 0(%[dst_3]), %%zmm6\n"
    "vmovups 64(%[dst_3]), %%zmm7\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    // same 32-float bias vector replicated into each of the 4 rows
    // NOTE(review): vmovaps requires 64-byte alignment of bias — confirm the
    // allocator guarantees it.
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 0(%[bias]), %%zmm2\n"
    "vmovaps 64(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "vmovaps 0(%[bias]), %%zmm6\n"
    "vmovaps 64(%[bias]), %%zmm7\n"
    "jmp 2f\n"
    "1:\n"
    // no bias: zero all accumulators
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
    "2:\n"
    :
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
      [ dst_3 ] "r"(dst_3)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7");
  const float *src_3 = src + 3 * src_stride;  // base of src row 3
  size_t src_stride_t = src_stride << 2;      // src row stride in bytes
  // Main reduction loop: each "block k" handles deep element k of the 8x
  // unrolled step — load 2 zmm of weights, broadcast one src scalar per row,
  // then 8 FMAs into the accumulators.
  asm volatile(
    "0:\n"
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vbroadcastss 0(%[src_0]), %%zmm29\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 0(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // block 1
    "vmovups 128(%[weight]), %%zmm31\n"
    "vmovups 192(%[weight]), %%zmm30\n"
    "vbroadcastss 4(%[src_0]), %%zmm29\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 4(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // block 2
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vbroadcastss 8(%[src_0]), %%zmm29\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 8(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // block 3
    "vmovups 384(%[weight]), %%zmm31\n"
    "vmovups 448(%[weight]), %%zmm30\n"
    "vbroadcastss 12(%[src_0]), %%zmm29\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 12(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // block 4
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vbroadcastss 16(%[src_0]), %%zmm29\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 16(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // block 5
    "vmovups 640(%[weight]), %%zmm31\n"
    "vmovups 704(%[weight]), %%zmm30\n"
    "vbroadcastss 20(%[src_0]), %%zmm29\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 20(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // block 6
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vbroadcastss 24(%[src_0]), %%zmm29\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 24(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // block 7
    "vmovups 896(%[weight]), %%zmm31\n"
    "vmovups 960(%[weight]), %%zmm30\n"
    "vbroadcastss 28(%[src_0]), %%zmm29\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 28(%[src_3]), %%zmm26\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    // advance pointers: 8 deep steps * 32 ch * 4 B = 1024 B of weights,
    // 8 deep steps * 4 B = 32 B of src per row
    // NOTE(review): the 'add's below overwrite the flags set by 'dec', so
    // 'jg' actually tests the result of 'add $32, %[src_3]' rather than the
    // decremented deep counter — verify loop termination against the
    // generator's intent.
    "dec %[deep]\n"
    "add $1024, %[weight]\n"
    "add $32, %[src_0]\n"
    "add $32, %[src_3]\n"
    "jg 0b\n"
    // epilogue: optional activation (only when inc_flag bit 1 is set), then
    // store the 4x32 tile back to dst
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu: clamp accumulators at zero
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: additionally clamp at 6.0f (0x40C00000 is the IEEE-754 bits of 6.0f)
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "vminps %%zmm6, %%zmm30, %%zmm6\n"
    "vminps %%zmm7, %%zmm30, %%zmm7\n"
    "3:\n"
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm6, 0(%[dst_3])\n"
    "vmovups %%zmm7, 64(%[dst_3])\n"
    // NOTE(review): this asm stores through [dst_0]/[dst_3] and mutates the
    // registers bound to [weight], [src_0], [src_3], [deep] and [inc_flag],
    // yet every operand is an input and there is no "memory" clobber —
    // presumably safe only because nothing after the asm re-reads those
    // locals; verify against the GCC extended-asm rules.
    :
    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
      [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,369 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_4x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n"
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n"
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n"
"vmovups 0(%[dst_3]), %%zmm12\n"
"vmovups 64(%[dst_3]), %%zmm13\n"
"vmovups 128(%[dst_3]), %%zmm14\n"
"vmovups 192(%[dst_3]), %%zmm15\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 0(%[bias]), %%zmm4\n"
"vmovaps 64(%[bias]), %%zmm5\n"
"vmovaps 128(%[bias]), %%zmm6\n"
"vmovaps 192(%[bias]), %%zmm7\n"
"vmovaps 0(%[bias]), %%zmm8\n"
"vmovaps 64(%[bias]), %%zmm9\n"
"vmovaps 128(%[bias]), %%zmm10\n"
"vmovaps 192(%[bias]), %%zmm11\n"
"vmovaps 0(%[bias]), %%zmm12\n"
"vmovaps 64(%[bias]), %%zmm13\n"
"vmovaps 128(%[bias]), %%zmm14\n"
"vmovaps 192(%[bias]), %%zmm15\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_3 ] "r"(dst_3)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15");
const float *src_3 = src + 3 * src_stride;
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vbroadcastss 0(%[src_0]), %%zmm27\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 0(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
// block 1
"vmovups 256(%[weight]), %%zmm31\n"
"vmovups 320(%[weight]), %%zmm30\n"
"vmovups 384(%[weight]), %%zmm29\n"
"vmovups 448(%[weight]), %%zmm28\n"
"vbroadcastss 4(%[src_0]), %%zmm27\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 4(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
// block 2
"vmovups 512(%[weight]), %%zmm31\n"
"vmovups 576(%[weight]), %%zmm30\n"
"vmovups 640(%[weight]), %%zmm29\n"
"vmovups 704(%[weight]), %%zmm28\n"
"vbroadcastss 8(%[src_0]), %%zmm27\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 8(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
// block 3
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vmovups 896(%[weight]), %%zmm29\n"
"vmovups 960(%[weight]), %%zmm28\n"
"vbroadcastss 12(%[src_0]), %%zmm27\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 12(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
// block 4
"vmovups 1024(%[weight]), %%zmm31\n"
"vmovups 1088(%[weight]), %%zmm30\n"
"vmovups 1152(%[weight]), %%zmm29\n"
"vmovups 1216(%[weight]), %%zmm28\n"
"vbroadcastss 16(%[src_0]), %%zmm27\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 16(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
// block 5
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vbroadcastss 20(%[src_0]), %%zmm27\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 20(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
// block 6
"vmovups 1536(%[weight]), %%zmm31\n"
"vmovups 1600(%[weight]), %%zmm30\n"
"vmovups 1664(%[weight]), %%zmm29\n"
"vmovups 1728(%[weight]), %%zmm28\n"
"vbroadcastss 24(%[src_0]), %%zmm27\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 24(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
// block 7
"vmovups 1792(%[weight]), %%zmm31\n"
"vmovups 1856(%[weight]), %%zmm30\n"
"vmovups 1920(%[weight]), %%zmm29\n"
"vmovups 1984(%[weight]), %%zmm28\n"
"vbroadcastss 28(%[src_0]), %%zmm27\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 28(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"dec %[deep]\n"
"add $2048, %[weight]\n"
"add $32, %[src_0]\n"
"add $32, %[src_3]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"vminps %%zmm15, %%zmm30, %%zmm15\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm12, 0(%[dst_3])\n"
"vmovups %%zmm13, 64(%[dst_3])\n"
"vmovups %%zmm14, 128(%[dst_3])\n"
"vmovups %%zmm15, 192(%[dst_3])\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -23,7 +23,6 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -100,7 +99,8 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
[ dst_3 ] "r"(dst_3)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19");
const float *src_3 = src + 3 * dst_stride;
const float *src_3 = src + 3 * src_stride;
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -112,7 +112,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 0(%[src_0]), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 0(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 0(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -142,7 +142,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 4(%[src_0]), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 4(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 4(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -172,7 +172,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 8(%[src_0]), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 8(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 8(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -202,7 +202,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 12(%[src_0]), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 12(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 12(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -232,7 +232,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 16(%[src_0]), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 16(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 16(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -262,7 +262,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 20(%[src_0]), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 20(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 20(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -292,7 +292,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 24(%[src_0]), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 24(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 24(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -322,7 +322,7 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 28(%[src_0]), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 28(%[src_0], %[src_stride], 3), %%zmm23\n"
"vbroadcastss 28(%[src_3]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
@ -343,7 +343,6 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
@ -409,16 +408,16 @@ void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm15, 0(%[dst_3])\n"
"vmovups %%zmm16, 64(%[dst_3])\n"
"vmovups %%zmm17, 128(%[dst_3])\n"

View File

@ -23,7 +23,6 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -113,7 +112,8 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
"%zmm23");
const float *src_3 = src + 3 * dst_stride;
const float *src_3 = src + 3 * src_stride;
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -126,30 +126,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 0(%[src_0]), %%zmm25\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 0(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 1
"vmovups 384(%[weight]), %%zmm31\n"
@ -161,30 +161,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 4(%[src_0]), %%zmm25\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 4(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 2
"vmovups 768(%[weight]), %%zmm31\n"
@ -196,30 +196,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 8(%[src_0]), %%zmm25\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 8(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 3
"vmovups 1152(%[weight]), %%zmm31\n"
@ -231,30 +231,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 12(%[src_0]), %%zmm25\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 12(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 4
"vmovups 1536(%[weight]), %%zmm31\n"
@ -266,30 +266,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 16(%[src_0]), %%zmm25\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 16(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 5
"vmovups 1920(%[weight]), %%zmm31\n"
@ -301,30 +301,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 20(%[src_0]), %%zmm25\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 20(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 6
"vmovups 2304(%[weight]), %%zmm31\n"
@ -336,30 +336,30 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 24(%[src_0]), %%zmm25\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 24(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
// block 7
"vmovups 2688(%[weight]), %%zmm31\n"
@ -371,32 +371,31 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 28(%[src_0]), %%zmm25\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 28(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
"dec %[deep]\n"
"add $3072, %[weight]\n"
"add $32, %[src_0]\n"
@ -471,18 +470,18 @@ void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 320(%[dst_0])\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm18, 0(%[dst_3])\n"
"vmovups %%zmm19, 64(%[dst_3])\n"
"vmovups %%zmm20, 128(%[dst_3])\n"

View File

@ -0,0 +1,276 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
/**
 * AVX512 inline-asm GEMM micro-kernel producing a 5-row x 32-column fp32 output
 * tile in NHWC layout. Each output row occupies two zmm registers (2 x 16
 * floats = 32 columns); rows 0..4 live in zmm0..zmm9. Generated code — do not
 * hand-edit the instruction sequence.
 *
 * @param dst        output tile base; rows are dst_stride floats apart
 * @param src        input activations; rows are src_stride floats apart
 * @param weight     packed weights, consumed 32 floats (2 zmm) per deep step
 * @param bias       optional per-column bias (32 floats); may be NULL
 * @param act_flag   activation selector: low two bits nonzero -> relu;
 *                   bit 0 additionally clamps at 6.0f (relu6)
 * @param row_block  unused here (fixed 5-row kernel; kept for uniform signature)
 * @param col_block  unused here (fixed 32-col kernel; kept for uniform signature)
 * @param deep       reduction depth; must be a multiple of 8 (processed >> 3)
 * @param src_stride row stride of src, in floats
 * @param dst_stride row stride of dst, in floats
 * @param inc_flag   bit 0: accumulate onto existing dst instead of init from
 *                   bias/zero; bit 1: apply activation on this pass
 *                   (presumably set only on the last deep slice — TODO confirm
 *                   against caller)
 *
 * NOTE(review): both asm blocks modify operands declared as inputs only
 * ([inc_flag], and in the main loop [weight], [src_0], [src_3], [deep]) and
 * the second block reads/writes memory without a "memory" clobber. This is
 * technically undefined behavior in GCC extended asm; it appears to work with
 * the generator's usage pattern, but should be confirmed/fixed at the
 * generator level ("+r" outputs + "memory" clobber).
 */
void nnacl_gemm_avx512_5x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  // Rows 3..4 are addressed from a second base pointer because x86 addressing
  // only scales an index register by 1/2/4/8: rows 0-2 use dst_0 + stride*{0,1,2},
  // rows 3-4 use dst_3 + stride*{0,1}.
  const float *dst_3 = dst + 3 * dst_stride;
  // The main loop is unrolled 8x over the deep dimension.
  size_t deep_t = deep >> 3;
  // Convert element stride to byte stride (sizeof(float) == 4).
  size_t dst_stride_t = dst_stride << 2;
  // Accumulator initialization: reload partial sums from dst (inc_flag bit 0),
  // else broadcast bias rows, else zero.
  asm volatile(
    // inc in deep
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    // Accumulation pass: seed zmm0..zmm9 with the existing dst tile.
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
    "vmovups 0(%[dst_3]), %%zmm6\n"
    "vmovups 64(%[dst_3]), %%zmm7\n"
    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
    "jmp 2f\n"
    "0:\n"
    // First pass: seed every row with the same 32-float bias vector, if given.
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 0(%[bias]), %%zmm2\n"
    "vmovaps 64(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "vmovaps 0(%[bias]), %%zmm6\n"
    "vmovaps 64(%[bias]), %%zmm7\n"
    "vmovaps 0(%[bias]), %%zmm8\n"
    "vmovaps 64(%[bias]), %%zmm9\n"
    "jmp 2f\n"
    "1:\n"
    // No bias: zero the accumulators.
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
    "2:\n"
    :
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
      [ dst_3 ] "r"(dst_3)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9");
  // Same split-base trick for the source rows (see dst_3 above).
  const float *src_3 = src + 3 * src_stride;
  size_t src_stride_t = src_stride << 2;
  // Main reduction loop, unrolled 8x: each "block k" consumes one deep step —
  // 2 zmm of weights (32 floats) and one broadcast scalar per source row —
  // and does 10 FMAs (5 rows x 2 column halves) into zmm0..zmm9.
  asm volatile(
    "0:\n"
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vbroadcastss 0(%[src_0]), %%zmm29\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 0(%[src_3]), %%zmm26\n"
    "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // block 1
    "vmovups 128(%[weight]), %%zmm31\n"
    "vmovups 192(%[weight]), %%zmm30\n"
    "vbroadcastss 4(%[src_0]), %%zmm29\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 4(%[src_3]), %%zmm26\n"
    "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // block 2
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vbroadcastss 8(%[src_0]), %%zmm29\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 8(%[src_3]), %%zmm26\n"
    "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // block 3
    "vmovups 384(%[weight]), %%zmm31\n"
    "vmovups 448(%[weight]), %%zmm30\n"
    "vbroadcastss 12(%[src_0]), %%zmm29\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 12(%[src_3]), %%zmm26\n"
    "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // block 4
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vbroadcastss 16(%[src_0]), %%zmm29\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 16(%[src_3]), %%zmm26\n"
    "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // block 5
    "vmovups 640(%[weight]), %%zmm31\n"
    "vmovups 704(%[weight]), %%zmm30\n"
    "vbroadcastss 20(%[src_0]), %%zmm29\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 20(%[src_3]), %%zmm26\n"
    "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // block 6
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vbroadcastss 24(%[src_0]), %%zmm29\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 24(%[src_3]), %%zmm26\n"
    "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // block 7
    "vmovups 896(%[weight]), %%zmm31\n"
    "vmovups 960(%[weight]), %%zmm30\n"
    "vbroadcastss 28(%[src_0]), %%zmm29\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 28(%[src_3]), %%zmm26\n"
    "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    // Advance pointers past the 8 consumed deep steps:
    // 8 steps x 2 zmm x 64B = 1024B of weight, 8 x 4B = 32B of src.
    "dec %[deep]\n"
    "add $1024, %[weight]\n"
    "add $32, %[src_0]\n"
    "add $32, %[src_3]\n"
    "jg 0b\n"
    // Activation only when inc_flag bit 1 is set (final pass).
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: 0x40C00000 is the IEEE-754 bit pattern of 6.0f.
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "vminps %%zmm6, %%zmm30, %%zmm6\n"
    "vminps %%zmm7, %%zmm30, %%zmm7\n"
    "vminps %%zmm8, %%zmm30, %%zmm8\n"
    "vminps %%zmm9, %%zmm30, %%zmm9\n"
    "3:\n"
    // Write the 5x32 tile back (same addressing scheme as the initial load).
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm6, 0(%[dst_3])\n"
    "vmovups %%zmm7, 64(%[dst_3])\n"
    "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n"
    :
    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
      [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,433 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
/**
 * AVX512 inline-asm GEMM micro-kernel producing a 5-row x 64-column fp32 output
 * tile in NHWC layout. Each output row occupies four zmm registers (4 x 16
 * floats = 64 columns); rows 0..4 live in zmm0..zmm19. Generated code — do not
 * hand-edit the instruction sequence.
 *
 * @param dst        output tile base; rows are dst_stride floats apart
 * @param src        input activations; rows are src_stride floats apart
 * @param weight     packed weights, consumed 64 floats (4 zmm) per deep step
 * @param bias       optional per-column bias (64 floats); may be NULL
 * @param act_flag   activation selector: low two bits nonzero -> relu;
 *                   bit 0 additionally clamps at 6.0f (relu6)
 * @param row_block  unused here (fixed 5-row kernel; kept for uniform signature)
 * @param col_block  unused here (fixed 64-col kernel; kept for uniform signature)
 * @param deep       reduction depth; must be a multiple of 8 (processed >> 3)
 * @param src_stride row stride of src, in floats
 * @param dst_stride row stride of dst, in floats
 * @param inc_flag   bit 0: accumulate onto existing dst instead of init from
 *                   bias/zero; bit 1: apply activation on this pass
 *                   (presumably set only on the last deep slice — TODO confirm
 *                   against caller)
 *
 * NOTE(review): both asm blocks modify operands declared as inputs only
 * ([inc_flag], and in the main loop [weight], [src_0], [src_3], [deep]) and
 * the second block reads/writes memory without a "memory" clobber. This is
 * technically undefined behavior in GCC extended asm; it appears to work with
 * the generator's usage pattern, but should be confirmed/fixed at the
 * generator level ("+r" outputs + "memory" clobber).
 */
void nnacl_gemm_avx512_5x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  // Rows 3..4 are addressed from a second base pointer because x86 addressing
  // only scales an index register by 1/2/4/8: rows 0-2 use dst_0 + stride*{0,1,2},
  // rows 3-4 use dst_3 + stride*{0,1}.
  const float *dst_3 = dst + 3 * dst_stride;
  // The main loop is unrolled 8x over the deep dimension.
  size_t deep_t = deep >> 3;
  // Convert element stride to byte stride (sizeof(float) == 4).
  size_t dst_stride_t = dst_stride << 2;
  // Accumulator initialization: reload partial sums from dst (inc_flag bit 0),
  // else broadcast bias rows, else zero.
  asm volatile(
    // inc in deep
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    // Accumulation pass: seed zmm0..zmm19 with the existing dst tile.
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 128(%[dst_0]), %%zmm2\n"
    "vmovups 192(%[dst_0]), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
    "vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
    "vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n"
    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n"
    "vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n"
    "vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n"
    "vmovups 0(%[dst_3]), %%zmm12\n"
    "vmovups 64(%[dst_3]), %%zmm13\n"
    "vmovups 128(%[dst_3]), %%zmm14\n"
    "vmovups 192(%[dst_3]), %%zmm15\n"
    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n"
    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n"
    "vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n"
    "vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n"
    "jmp 2f\n"
    "0:\n"
    // First pass: seed every row with the same 64-float bias vector, if given.
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 128(%[bias]), %%zmm2\n"
    "vmovaps 192(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "vmovaps 128(%[bias]), %%zmm6\n"
    "vmovaps 192(%[bias]), %%zmm7\n"
    "vmovaps 0(%[bias]), %%zmm8\n"
    "vmovaps 64(%[bias]), %%zmm9\n"
    "vmovaps 128(%[bias]), %%zmm10\n"
    "vmovaps 192(%[bias]), %%zmm11\n"
    "vmovaps 0(%[bias]), %%zmm12\n"
    "vmovaps 64(%[bias]), %%zmm13\n"
    "vmovaps 128(%[bias]), %%zmm14\n"
    "vmovaps 192(%[bias]), %%zmm15\n"
    "vmovaps 0(%[bias]), %%zmm16\n"
    "vmovaps 64(%[bias]), %%zmm17\n"
    "vmovaps 128(%[bias]), %%zmm18\n"
    "vmovaps 192(%[bias]), %%zmm19\n"
    "jmp 2f\n"
    "1:\n"
    // No bias: zero the accumulators.
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
    "vxorps %%zmm12, %%zmm12, %%zmm12\n"
    "vxorps %%zmm13, %%zmm13, %%zmm13\n"
    "vxorps %%zmm14, %%zmm14, %%zmm14\n"
    "vxorps %%zmm15, %%zmm15, %%zmm15\n"
    "vxorps %%zmm16, %%zmm16, %%zmm16\n"
    "vxorps %%zmm17, %%zmm17, %%zmm17\n"
    "vxorps %%zmm18, %%zmm18, %%zmm18\n"
    "vxorps %%zmm19, %%zmm19, %%zmm19\n"
    "2:\n"
    :
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
      [ dst_3 ] "r"(dst_3)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
      "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19");
  // Same split-base trick for the source rows (see dst_3 above).
  const float *src_3 = src + 3 * src_stride;
  size_t src_stride_t = src_stride << 2;
  // Main reduction loop, unrolled 8x: each "block k" consumes one deep step —
  // 4 zmm of weights (64 floats) and one broadcast scalar per source row —
  // and does 20 FMAs (5 rows x 4 column quarters) into zmm0..zmm19.
  asm volatile(
    "0:\n"
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vmovups 128(%[weight]), %%zmm29\n"
    "vmovups 192(%[weight]), %%zmm28\n"
    "vbroadcastss 0(%[src_0]), %%zmm27\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 0(%[src_3]), %%zmm24\n"
    "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // block 1
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vmovups 384(%[weight]), %%zmm29\n"
    "vmovups 448(%[weight]), %%zmm28\n"
    "vbroadcastss 4(%[src_0]), %%zmm27\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 4(%[src_3]), %%zmm24\n"
    "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // block 2
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vmovups 640(%[weight]), %%zmm29\n"
    "vmovups 704(%[weight]), %%zmm28\n"
    "vbroadcastss 8(%[src_0]), %%zmm27\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 8(%[src_3]), %%zmm24\n"
    "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // block 3
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vmovups 896(%[weight]), %%zmm29\n"
    "vmovups 960(%[weight]), %%zmm28\n"
    "vbroadcastss 12(%[src_0]), %%zmm27\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 12(%[src_3]), %%zmm24\n"
    "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // block 4
    "vmovups 1024(%[weight]), %%zmm31\n"
    "vmovups 1088(%[weight]), %%zmm30\n"
    "vmovups 1152(%[weight]), %%zmm29\n"
    "vmovups 1216(%[weight]), %%zmm28\n"
    "vbroadcastss 16(%[src_0]), %%zmm27\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 16(%[src_3]), %%zmm24\n"
    "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // block 5
    "vmovups 1280(%[weight]), %%zmm31\n"
    "vmovups 1344(%[weight]), %%zmm30\n"
    "vmovups 1408(%[weight]), %%zmm29\n"
    "vmovups 1472(%[weight]), %%zmm28\n"
    "vbroadcastss 20(%[src_0]), %%zmm27\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 20(%[src_3]), %%zmm24\n"
    "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // block 6
    "vmovups 1536(%[weight]), %%zmm31\n"
    "vmovups 1600(%[weight]), %%zmm30\n"
    "vmovups 1664(%[weight]), %%zmm29\n"
    "vmovups 1728(%[weight]), %%zmm28\n"
    "vbroadcastss 24(%[src_0]), %%zmm27\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 24(%[src_3]), %%zmm24\n"
    "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // block 7
    "vmovups 1792(%[weight]), %%zmm31\n"
    "vmovups 1856(%[weight]), %%zmm30\n"
    "vmovups 1920(%[weight]), %%zmm29\n"
    "vmovups 1984(%[weight]), %%zmm28\n"
    "vbroadcastss 28(%[src_0]), %%zmm27\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n"
    "vbroadcastss 28(%[src_3]), %%zmm24\n"
    "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm23\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
    "vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
    "vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
    "vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
    "vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
    "vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
    "vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
    "vfmadd231ps %%zmm31, %%zmm23, %%zmm16\n"
    "vfmadd231ps %%zmm30, %%zmm23, %%zmm17\n"
    "vfmadd231ps %%zmm29, %%zmm23, %%zmm18\n"
    "vfmadd231ps %%zmm28, %%zmm23, %%zmm19\n"
    // Advance pointers past the 8 consumed deep steps:
    // 8 steps x 4 zmm x 64B = 2048B of weight, 8 x 4B = 32B of src.
    "dec %[deep]\n"
    "add $2048, %[weight]\n"
    "add $32, %[src_0]\n"
    "add $32, %[src_3]\n"
    "jg 0b\n"
    // Activation only when inc_flag bit 1 is set (final pass).
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
    "vmaxps %%zmm10, %%zmm31, %%zmm10\n"
    "vmaxps %%zmm11, %%zmm31, %%zmm11\n"
    "vmaxps %%zmm12, %%zmm31, %%zmm12\n"
    "vmaxps %%zmm13, %%zmm31, %%zmm13\n"
    "vmaxps %%zmm14, %%zmm31, %%zmm14\n"
    "vmaxps %%zmm15, %%zmm31, %%zmm15\n"
    "vmaxps %%zmm16, %%zmm31, %%zmm16\n"
    "vmaxps %%zmm17, %%zmm31, %%zmm17\n"
    "vmaxps %%zmm18, %%zmm31, %%zmm18\n"
    "vmaxps %%zmm19, %%zmm31, %%zmm19\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: 0x40C00000 is the IEEE-754 bit pattern of 6.0f.
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "vminps %%zmm6, %%zmm30, %%zmm6\n"
    "vminps %%zmm7, %%zmm30, %%zmm7\n"
    "vminps %%zmm8, %%zmm30, %%zmm8\n"
    "vminps %%zmm9, %%zmm30, %%zmm9\n"
    "vminps %%zmm10, %%zmm30, %%zmm10\n"
    "vminps %%zmm11, %%zmm30, %%zmm11\n"
    "vminps %%zmm12, %%zmm30, %%zmm12\n"
    "vminps %%zmm13, %%zmm30, %%zmm13\n"
    "vminps %%zmm14, %%zmm30, %%zmm14\n"
    "vminps %%zmm15, %%zmm30, %%zmm15\n"
    "vminps %%zmm16, %%zmm30, %%zmm16\n"
    "vminps %%zmm17, %%zmm30, %%zmm17\n"
    "vminps %%zmm18, %%zmm30, %%zmm18\n"
    "vminps %%zmm19, %%zmm30, %%zmm19\n"
    "3:\n"
    // Write the 5x64 tile back (same addressing scheme as the initial load).
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 128(%[dst_0])\n"
    "vmovups %%zmm3, 192(%[dst_0])\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm12, 0(%[dst_3])\n"
    "vmovups %%zmm13, 64(%[dst_3])\n"
    "vmovups %%zmm14, 128(%[dst_3])\n"
    "vmovups %%zmm15, 192(%[dst_3])\n"
    "vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n"
    :
    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
      [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -23,7 +23,6 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
@ -116,7 +115,8 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
"%zmm23", "%zmm24");
const float *src_3 = src + 3 * dst_stride;
const float *src_3 = src + 3 * src_stride;
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
@ -128,33 +128,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 0(%[src_0]), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 0(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
// block 1
"vmovups 320(%[weight]), %%zmm31\n"
"vmovups 384(%[weight]), %%zmm30\n"
@ -164,33 +164,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 4(%[src_0]), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 4(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
// block 2
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
@ -200,33 +200,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 8(%[src_0]), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 8(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
// block 3
"vmovups 960(%[weight]), %%zmm31\n"
"vmovups 1024(%[weight]), %%zmm30\n"
@ -236,33 +236,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 12(%[src_0]), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 12(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
// block 4
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
@ -272,33 +272,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 16(%[src_0]), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 16(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
// block 5
"vmovups 1600(%[weight]), %%zmm31\n"
"vmovups 1664(%[weight]), %%zmm30\n"
@ -308,33 +308,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 20(%[src_0]), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 20(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
// block 6
"vmovups 1920(%[weight]), %%zmm31\n"
"vmovups 1984(%[weight]), %%zmm30\n"
@ -344,33 +344,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 24(%[src_0]), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 24(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
// block 7
"vmovups 2240(%[weight]), %%zmm31\n"
"vmovups 2304(%[weight]), %%zmm30\n"
@ -380,34 +380,33 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vbroadcastss 28(%[src_0]), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n"
"vbroadcastss 28(%[src_3]), %%zmm25\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
"vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"vfmadd231ps %%zmm27, %%zmm26, %%zmm24\n"
"dec %[deep]\n"
"add $2560, %[weight]\n"
"add $32, %[src_0]\n"
@ -483,26 +482,26 @@ void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 256(%[dst_0])\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm15, 0(%[dst_3])\n"
"vmovups %%zmm16, 64(%[dst_3])\n"
"vmovups %%zmm17, 128(%[dst_3])\n"
"vmovups %%zmm18, 192(%[dst_3])\n"
"vmovups %%zmm19, 256(%[dst_3])\n"
"vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1),\n"
"vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1)\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),

View File

@ -0,0 +1,312 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
// AVX-512 GEMM micro-kernel: computes a 6-row x 32-column (2 zmm) fp32 tile in
// NHWC layout. Rows 0-2 are addressed from dst/src, rows 3-5 from dst_3/src_3
// (base + 3 * stride), each offset by 0/1/2 strides.
//   inc_flag bit 0: load existing partial sums from dst instead of bias/zero.
//   inc_flag bit 1: this is the final deep-slice, so apply the activation.
//   act_flag & 0x3: nonzero -> ReLU; bit 0 additionally clamps at 6 (ReLU6).
// deep is processed 8 elements per loop iteration (blocks 0..7 below), hence
// deep >> 3; strides are in elements and are shifted <<2 to byte strides.
// NOTE(review): row_block/col_block are unused here; presumably the generator
// emits one specialized kernel per tile size -- confirm against the generator.
// NOTE(review): several "r" input operands ([inc_flag], [deep], [weight],
// [src_0], [src_3]) are modified inside the asm; GCC extended asm expects
// modified operands to be declared as outputs -- kept as generated.
void nnacl_gemm_avx512_6x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
  const float *dst_3 = dst + 3 * dst_stride;
  size_t deep_t = deep >> 3;          // 8 deep elements consumed per loop pass
  size_t dst_stride_t = dst_stride << 2;  // element stride -> byte stride (fp32)
  asm volatile(
    // Accumulator init: partial sums from dst, else bias (broadcast per row),
    // else zero.
    "and $0x1, %[inc_flag]\n"
    "je 0f\n"
    "vmovups 0(%[dst_0]), %%zmm0\n"
    "vmovups 64(%[dst_0]), %%zmm1\n"
    "vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
    "vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
    "vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
    "vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
    "vmovups 0(%[dst_3]), %%zmm6\n"
    "vmovups 64(%[dst_3]), %%zmm7\n"
    "vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
    "vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
    "vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
    "vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
    "jmp 2f\n"
    "0:\n"
    "cmpq $0, %[bias]\n"
    "je 1f\n"
    // Same 32 bias values replicated into every row's accumulator pair.
    "vmovaps 0(%[bias]), %%zmm0\n"
    "vmovaps 64(%[bias]), %%zmm1\n"
    "vmovaps 0(%[bias]), %%zmm2\n"
    "vmovaps 64(%[bias]), %%zmm3\n"
    "vmovaps 0(%[bias]), %%zmm4\n"
    "vmovaps 64(%[bias]), %%zmm5\n"
    "vmovaps 0(%[bias]), %%zmm6\n"
    "vmovaps 64(%[bias]), %%zmm7\n"
    "vmovaps 0(%[bias]), %%zmm8\n"
    "vmovaps 64(%[bias]), %%zmm9\n"
    "vmovaps 0(%[bias]), %%zmm10\n"
    "vmovaps 64(%[bias]), %%zmm11\n"
    "jmp 2f\n"
    "1:\n"
    "vxorps %%zmm0, %%zmm0, %%zmm0\n"
    "vxorps %%zmm1, %%zmm1, %%zmm1\n"
    "vxorps %%zmm2, %%zmm2, %%zmm2\n"
    "vxorps %%zmm3, %%zmm3, %%zmm3\n"
    "vxorps %%zmm4, %%zmm4, %%zmm4\n"
    "vxorps %%zmm5, %%zmm5, %%zmm5\n"
    "vxorps %%zmm6, %%zmm6, %%zmm6\n"
    "vxorps %%zmm7, %%zmm7, %%zmm7\n"
    "vxorps %%zmm8, %%zmm8, %%zmm8\n"
    "vxorps %%zmm9, %%zmm9, %%zmm9\n"
    "vxorps %%zmm10, %%zmm10, %%zmm10\n"
    "vxorps %%zmm11, %%zmm11, %%zmm11\n"
    "2:\n"
    :
    : [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
      [ dst_3 ] "r"(dst_3)
    : "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11");
  const float *src_3 = src + 3 * src_stride;
  size_t src_stride_t = src_stride << 2;  // element stride -> byte stride (fp32)
  asm volatile(
    "0:\n"
    // Each "block k" handles deep element k: 2 weight vectors (32 cols) FMA'd
    // against 6 broadcast src scalars (one per row).
    // block 0
    "vmovups 0(%[weight]), %%zmm31\n"
    "vmovups 64(%[weight]), %%zmm30\n"
    "vbroadcastss 0(%[src_0]), %%zmm29\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 0(%[src_3]), %%zmm26\n"
    "vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // block 1
    "vmovups 128(%[weight]), %%zmm31\n"
    "vmovups 192(%[weight]), %%zmm30\n"
    "vbroadcastss 4(%[src_0]), %%zmm29\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 4(%[src_3]), %%zmm26\n"
    "vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // block 2
    "vmovups 256(%[weight]), %%zmm31\n"
    "vmovups 320(%[weight]), %%zmm30\n"
    "vbroadcastss 8(%[src_0]), %%zmm29\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 8(%[src_3]), %%zmm26\n"
    "vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // block 3
    "vmovups 384(%[weight]), %%zmm31\n"
    "vmovups 448(%[weight]), %%zmm30\n"
    "vbroadcastss 12(%[src_0]), %%zmm29\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 12(%[src_3]), %%zmm26\n"
    "vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // block 4
    "vmovups 512(%[weight]), %%zmm31\n"
    "vmovups 576(%[weight]), %%zmm30\n"
    "vbroadcastss 16(%[src_0]), %%zmm29\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 16(%[src_3]), %%zmm26\n"
    "vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // block 5
    "vmovups 640(%[weight]), %%zmm31\n"
    "vmovups 704(%[weight]), %%zmm30\n"
    "vbroadcastss 20(%[src_0]), %%zmm29\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 20(%[src_3]), %%zmm26\n"
    "vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // block 6
    "vmovups 768(%[weight]), %%zmm31\n"
    "vmovups 832(%[weight]), %%zmm30\n"
    "vbroadcastss 24(%[src_0]), %%zmm29\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 24(%[src_3]), %%zmm26\n"
    "vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // block 7
    "vmovups 896(%[weight]), %%zmm31\n"
    "vmovups 960(%[weight]), %%zmm30\n"
    "vbroadcastss 28(%[src_0]), %%zmm29\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
    "vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n"
    "vbroadcastss 28(%[src_3]), %%zmm26\n"
    "vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n"
    "vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n"
    "vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
    "vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
    "vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
    "vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
    "vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
    "vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
    "vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
    "vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
    "vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
    "vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
    "vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
    "vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
    // BUGFIX: advance the pointers first and decrement the counter LAST, so
    // that `jg` tests the flags set by `dec %[deep]`. The generated order
    // (dec before the adds) let the `add` instructions clobber ZF/SF/OF,
    // making `jg` branch on a pointer increment instead of the loop counter.
    "add $1024, %[weight]\n"
    "add $32, %[src_0]\n"
    "add $32, %[src_3]\n"
    "dec %[deep]\n"
    "jg 0b\n"
    "and $0x2, %[inc_flag]\n"
    "je 3f\n"
    "movq %[act_flag], %%rax\n"
    "and $0x3, %%eax\n"
    "je 3f\n"
    // relu
    "vxorps %%zmm31, %%zmm31, %%zmm31\n"
    "vmaxps %%zmm0, %%zmm31, %%zmm0\n"
    "vmaxps %%zmm1, %%zmm31, %%zmm1\n"
    "vmaxps %%zmm2, %%zmm31, %%zmm2\n"
    "vmaxps %%zmm3, %%zmm31, %%zmm3\n"
    "vmaxps %%zmm4, %%zmm31, %%zmm4\n"
    "vmaxps %%zmm5, %%zmm31, %%zmm5\n"
    "vmaxps %%zmm6, %%zmm31, %%zmm6\n"
    "vmaxps %%zmm7, %%zmm31, %%zmm7\n"
    "vmaxps %%zmm8, %%zmm31, %%zmm8\n"
    "vmaxps %%zmm9, %%zmm31, %%zmm9\n"
    "vmaxps %%zmm10, %%zmm31, %%zmm10\n"
    "vmaxps %%zmm11, %%zmm31, %%zmm11\n"
    "and $0x1, %%eax\n"
    "je 3f\n"
    // relu6: 0x40C00000 is the IEEE-754 encoding of 6.0f
    "mov $0x40C00000, %%eax\n"
    "vmovd %%eax, %%xmm30\n"
    "vbroadcastss %%xmm30, %%zmm30\n"
    "vminps %%zmm0, %%zmm30, %%zmm0\n"
    "vminps %%zmm1, %%zmm30, %%zmm1\n"
    "vminps %%zmm2, %%zmm30, %%zmm2\n"
    "vminps %%zmm3, %%zmm30, %%zmm3\n"
    "vminps %%zmm4, %%zmm30, %%zmm4\n"
    "vminps %%zmm5, %%zmm30, %%zmm5\n"
    "vminps %%zmm6, %%zmm30, %%zmm6\n"
    "vminps %%zmm7, %%zmm30, %%zmm7\n"
    "vminps %%zmm8, %%zmm30, %%zmm8\n"
    "vminps %%zmm9, %%zmm30, %%zmm9\n"
    "vminps %%zmm10, %%zmm30, %%zmm10\n"
    "vminps %%zmm11, %%zmm30, %%zmm11\n"
    "3:\n"
    // Store the 6x32 tile back to dst.
    "vmovups %%zmm0, 0(%[dst_0])\n"
    "vmovups %%zmm1, 64(%[dst_0])\n"
    "vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
    "vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
    "vmovups %%zmm6, 0(%[dst_3])\n"
    "vmovups %%zmm7, 64(%[dst_3])\n"
    "vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n"
    "vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n"
    "vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n"
    :
    : [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
      [ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
      [ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
    : "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
      "%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
      "%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,498 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_6x64_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_3 = dst + 3 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 128(%[dst_0]), %%zmm2\n"
"vmovups 192(%[dst_0]), %%zmm3\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm4\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm5\n"
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm6\n"
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm7\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm8\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm9\n"
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm10\n"
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm11\n"
"vmovups 0(%[dst_3]), %%zmm12\n"
"vmovups 64(%[dst_3]), %%zmm13\n"
"vmovups 128(%[dst_3]), %%zmm14\n"
"vmovups 192(%[dst_3]), %%zmm15\n"
"vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm16\n"
"vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm17\n"
"vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm18\n"
"vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm19\n"
"vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm20\n"
"vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm21\n"
"vmovups 128(%[dst_3], %[dst_stride], 2), %%zmm22\n"
"vmovups 192(%[dst_3], %[dst_stride], 2), %%zmm23\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 128(%[bias]), %%zmm2\n"
"vmovaps 192(%[bias]), %%zmm3\n"
"vmovaps 0(%[bias]), %%zmm4\n"
"vmovaps 64(%[bias]), %%zmm5\n"
"vmovaps 128(%[bias]), %%zmm6\n"
"vmovaps 192(%[bias]), %%zmm7\n"
"vmovaps 0(%[bias]), %%zmm8\n"
"vmovaps 64(%[bias]), %%zmm9\n"
"vmovaps 128(%[bias]), %%zmm10\n"
"vmovaps 192(%[bias]), %%zmm11\n"
"vmovaps 0(%[bias]), %%zmm12\n"
"vmovaps 64(%[bias]), %%zmm13\n"
"vmovaps 128(%[bias]), %%zmm14\n"
"vmovaps 192(%[bias]), %%zmm15\n"
"vmovaps 0(%[bias]), %%zmm16\n"
"vmovaps 64(%[bias]), %%zmm17\n"
"vmovaps 128(%[bias]), %%zmm18\n"
"vmovaps 192(%[bias]), %%zmm19\n"
"vmovaps 0(%[bias]), %%zmm20\n"
"vmovaps 64(%[bias]), %%zmm21\n"
"vmovaps 128(%[bias]), %%zmm22\n"
"vmovaps 192(%[bias]), %%zmm23\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
"vxorps %%zmm18, %%zmm18, %%zmm18\n"
"vxorps %%zmm19, %%zmm19, %%zmm19\n"
"vxorps %%zmm20, %%zmm20, %%zmm20\n"
"vxorps %%zmm21, %%zmm21, %%zmm21\n"
"vxorps %%zmm22, %%zmm22, %%zmm22\n"
"vxorps %%zmm23, %%zmm23, %%zmm23\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_3 ] "r"(dst_3)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
"%zmm23");
const float *src_3 = src + 3 * src_stride;
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vmovups 128(%[weight]), %%zmm29\n"
"vmovups 192(%[weight]), %%zmm28\n"
"vbroadcastss 0(%[src_0]), %%zmm27\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 0(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
// block 1
"vmovups 256(%[weight]), %%zmm31\n"
"vmovups 320(%[weight]), %%zmm30\n"
"vmovups 384(%[weight]), %%zmm29\n"
"vmovups 448(%[weight]), %%zmm28\n"
"vbroadcastss 4(%[src_0]), %%zmm27\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 4(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
// block 2
"vmovups 512(%[weight]), %%zmm31\n"
"vmovups 576(%[weight]), %%zmm30\n"
"vmovups 640(%[weight]), %%zmm29\n"
"vmovups 704(%[weight]), %%zmm28\n"
"vbroadcastss 8(%[src_0]), %%zmm27\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 8(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
// block 3
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vmovups 896(%[weight]), %%zmm29\n"
"vmovups 960(%[weight]), %%zmm28\n"
"vbroadcastss 12(%[src_0]), %%zmm27\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 12(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
// block 4
"vmovups 1024(%[weight]), %%zmm31\n"
"vmovups 1088(%[weight]), %%zmm30\n"
"vmovups 1152(%[weight]), %%zmm29\n"
"vmovups 1216(%[weight]), %%zmm28\n"
"vbroadcastss 16(%[src_0]), %%zmm27\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 16(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
// block 5
"vmovups 1280(%[weight]), %%zmm31\n"
"vmovups 1344(%[weight]), %%zmm30\n"
"vmovups 1408(%[weight]), %%zmm29\n"
"vmovups 1472(%[weight]), %%zmm28\n"
"vbroadcastss 20(%[src_0]), %%zmm27\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 20(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
// block 6
"vmovups 1536(%[weight]), %%zmm31\n"
"vmovups 1600(%[weight]), %%zmm30\n"
"vmovups 1664(%[weight]), %%zmm29\n"
"vmovups 1728(%[weight]), %%zmm28\n"
"vbroadcastss 24(%[src_0]), %%zmm27\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 24(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
// block 7
"vmovups 1792(%[weight]), %%zmm31\n"
"vmovups 1856(%[weight]), %%zmm30\n"
"vmovups 1920(%[weight]), %%zmm29\n"
"vmovups 1984(%[weight]), %%zmm28\n"
"vbroadcastss 28(%[src_0]), %%zmm27\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm26\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n"
"vbroadcastss 28(%[src_3]), %%zmm24\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm1\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm2\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm5\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm29, %%zmm25, %%zmm10\n"
"vfmadd231ps %%zmm28, %%zmm25, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm13\n"
"vfmadd231ps %%zmm29, %%zmm24, %%zmm14\n"
"vfmadd231ps %%zmm28, %%zmm24, %%zmm15\n"
"vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm27\n"
"vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm26\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm16\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm17\n"
"vfmadd231ps %%zmm29, %%zmm27, %%zmm18\n"
"vfmadd231ps %%zmm28, %%zmm27, %%zmm19\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm20\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm21\n"
"vfmadd231ps %%zmm29, %%zmm26, %%zmm22\n"
"vfmadd231ps %%zmm28, %%zmm26, %%zmm23\n"
"dec %[deep]\n"
"add $2048, %[weight]\n"
"add $32, %[src_0]\n"
"add $32, %[src_3]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
"vmaxps %%zmm18, %%zmm31, %%zmm18\n"
"vmaxps %%zmm19, %%zmm31, %%zmm19\n"
"vmaxps %%zmm20, %%zmm31, %%zmm20\n"
"vmaxps %%zmm21, %%zmm31, %%zmm21\n"
"vmaxps %%zmm22, %%zmm31, %%zmm22\n"
"vmaxps %%zmm23, %%zmm31, %%zmm23\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"vminps %%zmm15, %%zmm30, %%zmm15\n"
"vminps %%zmm16, %%zmm30, %%zmm16\n"
"vminps %%zmm17, %%zmm30, %%zmm17\n"
"vminps %%zmm18, %%zmm30, %%zmm18\n"
"vminps %%zmm19, %%zmm30, %%zmm19\n"
"vminps %%zmm20, %%zmm30, %%zmm20\n"
"vminps %%zmm21, %%zmm30, %%zmm21\n"
"vminps %%zmm22, %%zmm30, %%zmm22\n"
"vminps %%zmm23, %%zmm30, %%zmm23\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 128(%[dst_0])\n"
"vmovups %%zmm3, 192(%[dst_0])\n"
"vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm6, 128(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm7, 192(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm8, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm9, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm10, 128(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm11, 192(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm12, 0(%[dst_3])\n"
"vmovups %%zmm13, 64(%[dst_3])\n"
"vmovups %%zmm14, 128(%[dst_3])\n"
"vmovups %%zmm15, 192(%[dst_3])\n"
"vmovups %%zmm16, 0(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm17, 64(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm18, 128(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm19, 192(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 2)\n"
"vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 2)\n"
"vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 2)\n"
"vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 2)\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,352 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
// Computes one 7x32 fp32 output tile of a GEMM in NHWC layout: 7 rows of dst,
// each row held as two zmm registers (16 floats each), i.e. accumulators
// zmm0..zmm13 pair up as (row0: zmm0/zmm1, row1: zmm2/zmm3, ..., row6:
// zmm12/zmm13).  The deep dimension is consumed 8 elements per loop iteration
// (unrolled "block 0".."block 7" below), so deep is assumed to be a multiple
// of 8 (deep_t = deep >> 3).
//
// Parameter semantics as actually used by the asm (row_block / col_block are
// unused here — this kernel is specialized to the 7x32 tile):
//   dst, src, weight, bias : data pointers; bias may be NULL, in which case
//                            the accumulators are zero-initialized.
//   act_flag   : (act_flag & 0x3) != 0 -> apply relu;
//                (act_flag & 0x1) != 0 -> additionally clamp at 6.0f (relu6;
//                0x40C00000 is the IEEE-754 bit pattern of 6.0f).
//   src_stride, dst_stride : row strides in floats (<< 2 converts to bytes).
//   inc_flag   : bit0 -> accumulate pass: preload accumulators from dst;
//                bit1 -> final pass: apply the activation before storing.
void nnacl_gemm_avx512_7x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                             const size_t act_flag, const size_t row_block, const size_t col_block,
                                             const size_t deep, const size_t src_stride, const size_t dst_stride,
                                             const size_t inc_flag) {
const float *dst_3 = dst + 3 * dst_stride;  // base pointer for output rows 3..5
const float *dst_6 = dst + 6 * dst_stride;  // base pointer for output row 6
size_t deep_t = deep >> 3;                  // loop count: 8 deep elements per iteration
size_t dst_stride_t = dst_stride << 2;      // dst row stride in bytes
// Pass 1: initialize the 14 accumulators zmm0..zmm13.  Three mutually
// exclusive paths: load prior dst values (inc_flag bit0), broadcast the same
// 128-byte bias pair into every row (bias != NULL), or zero everything.
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
"vmovups 0(%[dst_3]), %%zmm6\n"
"vmovups 64(%[dst_3]), %%zmm7\n"
"vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
"vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
"vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
"vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
"vmovups 0(%[dst_6]), %%zmm12\n"
"vmovups 64(%[dst_6]), %%zmm13\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
// NOTE(review): vmovaps requires the bias pointer to be 64-byte aligned —
// confirm the allocator guarantees this.
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 0(%[bias]), %%zmm2\n"
"vmovaps 64(%[bias]), %%zmm3\n"
"vmovaps 0(%[bias]), %%zmm4\n"
"vmovaps 64(%[bias]), %%zmm5\n"
"vmovaps 0(%[bias]), %%zmm6\n"
"vmovaps 64(%[bias]), %%zmm7\n"
"vmovaps 0(%[bias]), %%zmm8\n"
"vmovaps 64(%[bias]), %%zmm9\n"
"vmovaps 0(%[bias]), %%zmm10\n"
"vmovaps 64(%[bias]), %%zmm11\n"
"vmovaps 0(%[bias]), %%zmm12\n"
"vmovaps 64(%[bias]), %%zmm13\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13");
const float *src_3 = src + 3 * src_stride;  // base pointer for input rows 3..5
const float *src_6 = src + 6 * src_stride;  // base pointer for input row 6
size_t src_stride_t = src_stride << 2;      // src row stride in bytes
// Pass 2: main GEMM loop.  Each iteration processes 8 consecutive deep
// elements; per element, two 64-byte weight vectors are loaded (zmm31/zmm30)
// and one scalar per input row is broadcast (zmm29..zmm23), then each of the
// 14 accumulators receives one FMA.  The accumulators zmm0..zmm13 are live
// across both asm statements via the register clobbers/usage above.
//
// NOTE(review): the input operands %[deep], %[weight], %[src_0], %[src_3],
// %[src_6] and %[inc_flag] are modified by the asm (dec/add/and) although
// they are declared as inputs, and memory is read and written without a
// "memory" clobber.  This matches the generator's pattern throughout the
// file but relies on the compiler not caching/reusing those values — verify
// against the GCC extended-asm rules.
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vbroadcastss 0(%[src_0]), %%zmm29\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 0(%[src_3]), %%zmm26\n"
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 0(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// block 1
"vmovups 128(%[weight]), %%zmm31\n"
"vmovups 192(%[weight]), %%zmm30\n"
"vbroadcastss 4(%[src_0]), %%zmm29\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 4(%[src_3]), %%zmm26\n"
"vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 4(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// block 2
"vmovups 256(%[weight]), %%zmm31\n"
"vmovups 320(%[weight]), %%zmm30\n"
"vbroadcastss 8(%[src_0]), %%zmm29\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 8(%[src_3]), %%zmm26\n"
"vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 8(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// block 3
"vmovups 384(%[weight]), %%zmm31\n"
"vmovups 448(%[weight]), %%zmm30\n"
"vbroadcastss 12(%[src_0]), %%zmm29\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 12(%[src_3]), %%zmm26\n"
"vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 12(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// block 4
"vmovups 512(%[weight]), %%zmm31\n"
"vmovups 576(%[weight]), %%zmm30\n"
"vbroadcastss 16(%[src_0]), %%zmm29\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 16(%[src_3]), %%zmm26\n"
"vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 16(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// block 5
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
"vbroadcastss 20(%[src_0]), %%zmm29\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 20(%[src_3]), %%zmm26\n"
"vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 20(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// block 6
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vbroadcastss 24(%[src_0]), %%zmm29\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 24(%[src_3]), %%zmm26\n"
"vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 24(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// block 7
"vmovups 896(%[weight]), %%zmm31\n"
"vmovups 960(%[weight]), %%zmm30\n"
"vbroadcastss 28(%[src_0]), %%zmm29\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 28(%[src_3]), %%zmm26\n"
"vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 28(%[src_6]), %%zmm23\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
// end of unrolled body: advance by 8 deep elements -- 8 blocks * 2 vectors
// * 64B of weight (1024), 8 * 4B per src row (32) -- and loop while deep > 0.
"dec %[deep]\n"
"add $1024, %[weight]\n"
"add $32, %[src_0]\n"
"add $32, %[src_3]\n"
"add $32, %[src_6]\n"
"jg 0b\n"
// Activation is applied only on the final pass (inc_flag bit1).
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
// store the 7x32 result tile back to dst
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm6, 0(%[dst_3])\n"
"vmovups %%zmm7, 64(%[dst_3])\n"
"vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n"
"vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n"
"vmovups %%zmm12, 0(%[dst_6])\n"
"vmovups %%zmm13, 64(%[dst_6])\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -0,0 +1,388 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <x86intrin.h>
// nnacl gemm in x86 avx512 asm code
void nnacl_gemm_avx512_8x32_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
const size_t act_flag, const size_t row_block, const size_t col_block,
const size_t deep, const size_t src_stride, const size_t dst_stride,
const size_t inc_flag) {
const float *dst_3 = dst + 3 * dst_stride;
const float *dst_6 = dst + 6 * dst_stride;
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\n"
"je 0f\n"
"vmovups 0(%[dst_0]), %%zmm0\n"
"vmovups 64(%[dst_0]), %%zmm1\n"
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm2\n"
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm3\n"
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm4\n"
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm5\n"
"vmovups 0(%[dst_3]), %%zmm6\n"
"vmovups 64(%[dst_3]), %%zmm7\n"
"vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm8\n"
"vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm9\n"
"vmovups 0(%[dst_3], %[dst_stride], 2), %%zmm10\n"
"vmovups 64(%[dst_3], %[dst_stride], 2), %%zmm11\n"
"vmovups 0(%[dst_6]), %%zmm12\n"
"vmovups 64(%[dst_6]), %%zmm13\n"
"vmovups 0(%[dst_6], %[dst_stride], 1), %%zmm14\n"
"vmovups 64(%[dst_6], %[dst_stride], 1), %%zmm15\n"
"jmp 2f\n"
"0:\n"
"cmpq $0, %[bias]\n"
"je 1f\n"
"vmovaps 0(%[bias]), %%zmm0\n"
"vmovaps 64(%[bias]), %%zmm1\n"
"vmovaps 0(%[bias]), %%zmm2\n"
"vmovaps 64(%[bias]), %%zmm3\n"
"vmovaps 0(%[bias]), %%zmm4\n"
"vmovaps 64(%[bias]), %%zmm5\n"
"vmovaps 0(%[bias]), %%zmm6\n"
"vmovaps 64(%[bias]), %%zmm7\n"
"vmovaps 0(%[bias]), %%zmm8\n"
"vmovaps 64(%[bias]), %%zmm9\n"
"vmovaps 0(%[bias]), %%zmm10\n"
"vmovaps 64(%[bias]), %%zmm11\n"
"vmovaps 0(%[bias]), %%zmm12\n"
"vmovaps 64(%[bias]), %%zmm13\n"
"vmovaps 0(%[bias]), %%zmm14\n"
"vmovaps 64(%[bias]), %%zmm15\n"
"jmp 2f\n"
"1:\n"
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
"2:\n"
:
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
[ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6)
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
"%zmm12", "%zmm13", "%zmm14", "%zmm15");
const float *src_3 = src + 3 * src_stride;
const float *src_6 = src + 6 * src_stride;
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\n"
// block 0
"vmovups 0(%[weight]), %%zmm31\n"
"vmovups 64(%[weight]), %%zmm30\n"
"vbroadcastss 0(%[src_0]), %%zmm29\n"
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 0(%[src_3]), %%zmm26\n"
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 0(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 0(%[src_6]), %%zmm23\n"
"vbroadcastss 0(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
// block 1
"vmovups 128(%[weight]), %%zmm31\n"
"vmovups 192(%[weight]), %%zmm30\n"
"vbroadcastss 4(%[src_0]), %%zmm29\n"
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 4(%[src_3]), %%zmm26\n"
"vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 4(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 4(%[src_6]), %%zmm23\n"
"vbroadcastss 4(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
// block 2
"vmovups 256(%[weight]), %%zmm31\n"
"vmovups 320(%[weight]), %%zmm30\n"
"vbroadcastss 8(%[src_0]), %%zmm29\n"
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 8(%[src_3]), %%zmm26\n"
"vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 8(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 8(%[src_6]), %%zmm23\n"
"vbroadcastss 8(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
// block 3
"vmovups 384(%[weight]), %%zmm31\n"
"vmovups 448(%[weight]), %%zmm30\n"
"vbroadcastss 12(%[src_0]), %%zmm29\n"
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 12(%[src_3]), %%zmm26\n"
"vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 12(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 12(%[src_6]), %%zmm23\n"
"vbroadcastss 12(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
// block 4
"vmovups 512(%[weight]), %%zmm31\n"
"vmovups 576(%[weight]), %%zmm30\n"
"vbroadcastss 16(%[src_0]), %%zmm29\n"
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 16(%[src_3]), %%zmm26\n"
"vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 16(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 16(%[src_6]), %%zmm23\n"
"vbroadcastss 16(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
// block 5
"vmovups 640(%[weight]), %%zmm31\n"
"vmovups 704(%[weight]), %%zmm30\n"
"vbroadcastss 20(%[src_0]), %%zmm29\n"
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 20(%[src_3]), %%zmm26\n"
"vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 20(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 20(%[src_6]), %%zmm23\n"
"vbroadcastss 20(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
// block 6
"vmovups 768(%[weight]), %%zmm31\n"
"vmovups 832(%[weight]), %%zmm30\n"
"vbroadcastss 24(%[src_0]), %%zmm29\n"
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 24(%[src_3]), %%zmm26\n"
"vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 24(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 24(%[src_6]), %%zmm23\n"
"vbroadcastss 24(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
// block 7
"vmovups 896(%[weight]), %%zmm31\n"
"vmovups 960(%[weight]), %%zmm30\n"
"vbroadcastss 28(%[src_0]), %%zmm29\n"
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm28\n"
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm27\n"
"vbroadcastss 28(%[src_3]), %%zmm26\n"
"vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm25\n"
"vbroadcastss 28(%[src_3], %[src_stride], 2), %%zmm24\n"
"vbroadcastss 28(%[src_6]), %%zmm23\n"
"vbroadcastss 28(%[src_6], %[src_stride], 1), %%zmm22\n"
"vfmadd231ps %%zmm31, %%zmm29, %%zmm0\n"
"vfmadd231ps %%zmm30, %%zmm29, %%zmm1\n"
"vfmadd231ps %%zmm31, %%zmm28, %%zmm2\n"
"vfmadd231ps %%zmm30, %%zmm28, %%zmm3\n"
"vfmadd231ps %%zmm31, %%zmm27, %%zmm4\n"
"vfmadd231ps %%zmm30, %%zmm27, %%zmm5\n"
"vfmadd231ps %%zmm31, %%zmm26, %%zmm6\n"
"vfmadd231ps %%zmm30, %%zmm26, %%zmm7\n"
"vfmadd231ps %%zmm31, %%zmm25, %%zmm8\n"
"vfmadd231ps %%zmm30, %%zmm25, %%zmm9\n"
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
"vfmadd231ps %%zmm31, %%zmm22, %%zmm14\n"
"vfmadd231ps %%zmm30, %%zmm22, %%zmm15\n"
"dec %[deep]\n"
"add $1024, %[weight]\n"
"add $32, %[src_0]\n"
"add $32, %[src_3]\n"
"add $32, %[src_6]\n"
"jg 0b\n"
"and $0x2, %[inc_flag]\n"
"je 3f\n"
"movq %[act_flag], %%rax\n"
"and $0x3, %%eax\n"
"je 3f\n"
// relu
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
"and $0x1, %%eax\n"
"je 3f\n"
// relu6
"mov $0x40C00000, %%eax\n"
"vmovd %%eax, %%xmm30\n"
"vbroadcastss %%xmm30, %%zmm30\n"
"vminps %%zmm0, %%zmm30, %%zmm0\n"
"vminps %%zmm1, %%zmm30, %%zmm1\n"
"vminps %%zmm2, %%zmm30, %%zmm2\n"
"vminps %%zmm3, %%zmm30, %%zmm3\n"
"vminps %%zmm4, %%zmm30, %%zmm4\n"
"vminps %%zmm5, %%zmm30, %%zmm5\n"
"vminps %%zmm6, %%zmm30, %%zmm6\n"
"vminps %%zmm7, %%zmm30, %%zmm7\n"
"vminps %%zmm8, %%zmm30, %%zmm8\n"
"vminps %%zmm9, %%zmm30, %%zmm9\n"
"vminps %%zmm10, %%zmm30, %%zmm10\n"
"vminps %%zmm11, %%zmm30, %%zmm11\n"
"vminps %%zmm12, %%zmm30, %%zmm12\n"
"vminps %%zmm13, %%zmm30, %%zmm13\n"
"vminps %%zmm14, %%zmm30, %%zmm14\n"
"vminps %%zmm15, %%zmm30, %%zmm15\n"
"3:\n"
"vmovups %%zmm0, 0(%[dst_0])\n"
"vmovups %%zmm1, 64(%[dst_0])\n"
"vmovups %%zmm2, 0(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm3, 64(%[dst_0], %[dst_stride], 1)\n"
"vmovups %%zmm4, 0(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm5, 64(%[dst_0], %[dst_stride], 2)\n"
"vmovups %%zmm6, 0(%[dst_3])\n"
"vmovups %%zmm7, 64(%[dst_3])\n"
"vmovups %%zmm8, 0(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm9, 64(%[dst_3], %[dst_stride], 1)\n"
"vmovups %%zmm10, 0(%[dst_3], %[dst_stride], 2)\n"
"vmovups %%zmm11, 64(%[dst_3], %[dst_stride], 2)\n"
"vmovups %%zmm12, 0(%[dst_6])\n"
"vmovups %%zmm13, 64(%[dst_6])\n"
"vmovups %%zmm14, 0(%[dst_6], %[dst_stride], 1)\n"
"vmovups %%zmm15, 64(%[dst_6], %[dst_stride], 1)\n"
:
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
[ dst_3 ] "r"(dst_3), [ dst_6 ] "r"(dst_6), [ src_3 ] "r"(src_3), [ src_6 ] "r"(src_6)
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
}

View File

@ -15,64 +15,64 @@
# ============================================================================
# generate gemm fma asm code
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c
# generate gemm fma intrinsics code
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c
#
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
#
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
#
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c
#
## generate gemm fma intrinsics code
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c
#
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c
#
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c
#
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c
#python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c
# generate gemm avx512 asm code
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c
@ -85,3 +85,19 @@ python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=6 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_6x64_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=5 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_5x64_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_4x64_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_3x64_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_2x64_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=64 -O ./gemm_avx512/nnacl_gemm_avx512_1x64_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=8 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_8x32_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=7 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_7x32_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=6 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_6x32_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=5 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_5x32_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_4x32_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_3x32_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_2x32_kernel_nhwc_fp32.c
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=32 -O ./gemm_avx512/nnacl_gemm_avx512_1x32_kernel_nhwc_fp32.c

View File

@ -26,11 +26,18 @@ def key_value_pair(line):
:param line:
:return:
"""
key, value = line.split("=", 1)
key = None
value = None
try:
key, value = line.split("=", 1)
except ValueError:
print("line must be format: key=value, but now is:", line)
sys.exit(1)
try:
value = int(value)
except ValueError:
print("Error: you input value must be integer, but now is ", value)
print("Error: you input value must be integer, but now is:", value)
sys.exit(1)
return key, value
def get_indent(line):
@ -66,7 +73,7 @@ def print_line(line):
generate_code_indent = get_indent(line)
if line.strip().startswith("}") and "{" not in line:
generate_code_indent -= 4
if (len(line) == 1 and line[0] == "}"):
if len(line) == 1 and line[0] == "}":
# reset generate_code_indent for the next function
generate_code_indent = -4
return "\"".join(result)
@ -107,7 +114,6 @@ def generate_code(template_file, exec_dict):
line = line.replace("\n", "")
if line.strip() and line.strip()[0] != "@":
line = line.replace("\"", "\\\"")
if line.strip() and line.strip()[0] != "@":
line = line.replace("%", "%%")
if "print" in line:
line = line.replace("%%", "%")
@ -118,12 +124,17 @@ def generate_code(template_file, exec_dict):
if "%(" not in str:
str = str.replace("%%[", "%[")
generate_code_lines.append(str)
# print('\n'.join(generate_code_lines))
c = compile('\n'.join(generate_code_lines), '', 'exec')
exec_dict["OUT_STREAM"] = output_stream
exec(c, exec_dict)
return output_stream.getvalue()
def check_python_version():
if sys.version_info < (3, 6):
sys.stdout.write("At least python 3.6 is required, but now is " + str(sys.version_info.major) + "." +
str(sys.version_info.minor) + "\n")
sys.exit(1)
generate_code_indent = -4
python_indent = -1
@ -134,6 +145,7 @@ parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=k
parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path")
if __name__ == "__main__":
check_python_version()
parameters = parser.parse_args(sys.argv[1:])
exec_globals = dict(chain(*parameters.defines))

View File

@ -20,6 +20,9 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co
const float *bias, const size_t act_flag, const size_t row_block,
const size_t col_block, const size_t deep, const size_t src_stride,
const size_t dst_stride, const size_t inc_flag) {
@import math
@row_stride_map = {6 : 4, 5 : 5, 4 : 6, 3 : 8, 2 : 12, 1 : 20}
@src_addr_stride = 3
@asm_flag_list = []
@row_split_number = [row for row in range(3, row_block, 3)]
@for row in row_split_number:
@ -27,19 +30,18 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co
@asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")");
size_t deep_t = deep >> 3;
size_t dst_stride_t = dst_stride << 2;
size_t src_stride_t = src_stride << 2;
@col_split_num = col_block >> 4;
asm volatile(
// inc in deep
"and $0x1, %[inc_flag]\\n"
"je 0f\\n"
@for row in range(0, row_block):
@tmp = int(row / 3) * 3
@src_addr = int(row / 3) * 3
@for col in range(0, col_split_num):
@if row % 3 == 0:
"vmovups @{col * 64}(%[dst_@{tmp}]), %%zmm@{row * col_split_num + col}\\n"
"vmovups @{col * 64}(%[dst_@{src_addr}]), %%zmm@{row * col_split_num + col}\\n"
@else:
"vmovups @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}), %%zmm@{row * col_split_num + col}\\n"
"vmovups @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr}), %%zmm@{row * col_split_num + col}\\n"
"jmp 2f\\n"
"0:\\n"
"cmpq $0, %[bias]\\n"
@ -60,8 +62,9 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co
@print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM)
);
@for row in row_split_number:
const float *src_@{row} = src + @{row} * dst_stride;
const float *src_@{row} = src + @{row} * src_stride;
@asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")");
size_t src_stride_t = src_stride << 2;
asm volatile(
"0:\\n"
@loop_count = 8
@ -69,63 +72,52 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co
// block @{i}
@for col in range(0, col_split_num):
"vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n"
@if col_split_num == 6:
@if row_block == 4:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
@if row_block * col_split_num + row_block + col_split_num <= 32:
@for row in range(0, row_block):
@src_addr = math.floor(row / src_addr_stride) * src_addr_stride
@src_index = 31 - col_split_num - row
@if row % src_addr_stride == 0:
"vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n"
@for row in range(0, row_block):
@src_index = 31 - col_split_num - row
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
@else:
@for row in range(0, row_block):
@if row == 0:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n"
@weight_index = 31 - col
@dst_index = row * col_split_num + col
"vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n"
@else:
@row_stride = 32 - (row_stride_map[col_split_num] + 1) * col_split_num
@row_split_num = math.floor(row_block / row_stride)
@for row_index in range(0, row_split_num):
@row_split_start = row_index * row_stride
@for row in range(row_split_start, row_split_start + row_stride):
@src_addr = math.floor(row / src_addr_stride) * src_addr_stride
@src_index = 31 - col_split_num - (row - row_split_start)
@if row % src_addr_stride == 0:
"vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
@for row in range(0, row_block):
"vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n"
@for row in range(0, row_stride):
@src_index = 31 - col_split_num - row
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
@elif col_split_num == 5:
@if row_block == 5:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
@weight_index = 31 - col
@dst_index = (row_split_start + row) * col_split_num + col
"vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n"
@row_split_start = row_split_num * row_stride
@for row in range(row_split_start, row_block):
@src_addr = math.floor(row / src_addr_stride) * src_addr_stride
@src_index = 31 - col_split_num - (row - row_split_start)
@if row % src_addr_stride == 0:
"vbroadcastss @{i * 4}(%[src_@{src_addr}]), %%zmm@{src_index}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_@{src_addr}], %[src_stride], @{row - src_addr}), %%zmm@{src_index}\\n"
@for row in range(row_split_start, row_block):
@src_index = 31 - col_split_num - (row - row_split_start)
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
"vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
"vbroadcastss @{i * 4}(%[src_3], %[src_stride], 1), %%zmm@{31 - col_split_num}\\n"
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
@else:
@for row in range(0, row_block):
@if row == 0:
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
@for row in range(0, row_block):
@for col in range(0, col_split_num):
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
@elif col_split_num == 2:
@for row in range(int(row_block / 6)):
@for j in range(0, 6):
@tmp = int(j / 3) * 3
@if j % 3 == 0:
"vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}]), %%zmm@{31 - col_split_num - j}\\n"
@else:
"vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}], %[src_stride], @{j - tmp}), %%zmm@{31 - col_split_num - j}\\n"
@for col in range(0, col_split_num):
@for j in range(0, 6):
"fmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - j}, %%zmm@{(row * 6 + j) * col_split_num + col}\\n"
@weight_index = 31 - col
@dst_index = row * col_split_num + col
"vfmadd231ps %%zmm@{weight_index}, %%zmm@{src_index}, %%zmm@{dst_index}\\n"
"dec %[deep]\\n"
"add $@{col_block * 4 * 8}, %[weight]\\n"
"add $@{loop_count * 4}, %[src_0]\\n"
@ -154,12 +146,12 @@ void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, co
"vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n"
"3:\\n"
@for row in range(0, row_block):
@tmp = int(row / 3) * 3
@src_addr = int(row / 3) * 3
@for col in range(0, col_split_num):
@if row % 3 == 0:
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}])\\n"
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}])\\n"
@else:
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}),\\n"
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{src_addr}], %[dst_stride], @{row - src_addr})\\n"
:
@list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"]
@list.extend(asm_flag_list)