forked from mindspore-Ecosystem/mindspore
code generator
This commit is contained in:
parent
2496637ead
commit
dfbe38aa5b
|
@ -118,3 +118,7 @@
|
|||
"mindspore/tests/st/ops/ascend/test_aicpu_ops/test_strided_slice_grad.py" "redefined-outer-name"
|
||||
"mindspore/tests/st/pynative/parser/test_parser_construct.py" "bad-super-call"
|
||||
"mindspore/tests/ut/python/optimizer/test_auto_grad.py" "broad-except"
|
||||
|
||||
#MindSpore Lite
|
||||
"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "redefined-builtin"
|
||||
"mindspore/mindspore/lite/experiment/HPC-generator/generator.py" "exec-used"
|
|
@ -96,3 +96,54 @@ mindspore/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/intrinsics/avx/Tiled
|
|||
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetNodeOutputType
|
||||
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetValueToProto
|
||||
mindspore/mindspore/ccsrc/debug/dump_proto.cc:mindspore::ProtoExporter::SetScalarToProto
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c:nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c:nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x96_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x80_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x96_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x96_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_2x80_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c:nnacl_gemm_avx512_3x96_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_5x80_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_1x80_kernel_nhwc_fp32
|
||||
mindspore/mindspore/lite/experiment/HPC-generator/gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c:nnacl_gemm_avx512_4x80_kernel_nhwc_fp32
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_1x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
// block 1
|
||||
"vmovups 320(%[weight]), %%zmm31\n"
|
||||
"vmovups 384(%[weight]), %%zmm30\n"
|
||||
"vmovups 448(%[weight]), %%zmm29\n"
|
||||
"vmovups 512(%[weight]), %%zmm28\n"
|
||||
"vmovups 576(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
// block 2
|
||||
"vmovups 640(%[weight]), %%zmm31\n"
|
||||
"vmovups 704(%[weight]), %%zmm30\n"
|
||||
"vmovups 768(%[weight]), %%zmm29\n"
|
||||
"vmovups 832(%[weight]), %%zmm28\n"
|
||||
"vmovups 896(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
// block 3
|
||||
"vmovups 960(%[weight]), %%zmm31\n"
|
||||
"vmovups 1024(%[weight]), %%zmm30\n"
|
||||
"vmovups 1088(%[weight]), %%zmm29\n"
|
||||
"vmovups 1152(%[weight]), %%zmm28\n"
|
||||
"vmovups 1216(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
// block 4
|
||||
"vmovups 1280(%[weight]), %%zmm31\n"
|
||||
"vmovups 1344(%[weight]), %%zmm30\n"
|
||||
"vmovups 1408(%[weight]), %%zmm29\n"
|
||||
"vmovups 1472(%[weight]), %%zmm28\n"
|
||||
"vmovups 1536(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
// block 5
|
||||
"vmovups 1600(%[weight]), %%zmm31\n"
|
||||
"vmovups 1664(%[weight]), %%zmm30\n"
|
||||
"vmovups 1728(%[weight]), %%zmm29\n"
|
||||
"vmovups 1792(%[weight]), %%zmm28\n"
|
||||
"vmovups 1856(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
// block 6
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
// block 7
|
||||
"vmovups 2240(%[weight]), %%zmm31\n"
|
||||
"vmovups 2304(%[weight]), %%zmm30\n"
|
||||
"vmovups 2368(%[weight]), %%zmm29\n"
|
||||
"vmovups 2432(%[weight]), %%zmm28\n"
|
||||
"vmovups 2496(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $2560, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
|
|
@ -0,0 +1,216 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_1x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 320(%[dst_0]), %%zmm5\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 320(%[bias]), %%zmm5\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vmovups 320(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
// block 1
|
||||
"vmovups 384(%[weight]), %%zmm31\n"
|
||||
"vmovups 448(%[weight]), %%zmm30\n"
|
||||
"vmovups 512(%[weight]), %%zmm29\n"
|
||||
"vmovups 576(%[weight]), %%zmm28\n"
|
||||
"vmovups 640(%[weight]), %%zmm27\n"
|
||||
"vmovups 704(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
// block 2
|
||||
"vmovups 768(%[weight]), %%zmm31\n"
|
||||
"vmovups 832(%[weight]), %%zmm30\n"
|
||||
"vmovups 896(%[weight]), %%zmm29\n"
|
||||
"vmovups 960(%[weight]), %%zmm28\n"
|
||||
"vmovups 1024(%[weight]), %%zmm27\n"
|
||||
"vmovups 1088(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
// block 3
|
||||
"vmovups 1152(%[weight]), %%zmm31\n"
|
||||
"vmovups 1216(%[weight]), %%zmm30\n"
|
||||
"vmovups 1280(%[weight]), %%zmm29\n"
|
||||
"vmovups 1344(%[weight]), %%zmm28\n"
|
||||
"vmovups 1408(%[weight]), %%zmm27\n"
|
||||
"vmovups 1472(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
// block 4
|
||||
"vmovups 1536(%[weight]), %%zmm31\n"
|
||||
"vmovups 1600(%[weight]), %%zmm30\n"
|
||||
"vmovups 1664(%[weight]), %%zmm29\n"
|
||||
"vmovups 1728(%[weight]), %%zmm28\n"
|
||||
"vmovups 1792(%[weight]), %%zmm27\n"
|
||||
"vmovups 1856(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
// block 5
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vmovups 2240(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
// block 6
|
||||
"vmovups 2304(%[weight]), %%zmm31\n"
|
||||
"vmovups 2368(%[weight]), %%zmm30\n"
|
||||
"vmovups 2432(%[weight]), %%zmm29\n"
|
||||
"vmovups 2496(%[weight]), %%zmm28\n"
|
||||
"vmovups 2560(%[weight]), %%zmm27\n"
|
||||
"vmovups 2624(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
// block 7
|
||||
"vmovups 2688(%[weight]), %%zmm31\n"
|
||||
"vmovups 2752(%[weight]), %%zmm30\n"
|
||||
"vmovups 2816(%[weight]), %%zmm29\n"
|
||||
"vmovups 2880(%[weight]), %%zmm28\n"
|
||||
"vmovups 2944(%[weight]), %%zmm27\n"
|
||||
"vmovups 3008(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $3072, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 320(%[dst_0])\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
|
|
@ -0,0 +1,272 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_2x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 0(%[bias]), %%zmm5\n"
|
||||
"vmovaps 64(%[bias]), %%zmm6\n"
|
||||
"vmovaps 128(%[bias]), %%zmm7\n"
|
||||
"vmovaps 192(%[bias]), %%zmm8\n"
|
||||
"vmovaps 256(%[bias]), %%zmm9\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
|
||||
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
|
||||
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
|
||||
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
// block 1
|
||||
"vmovups 320(%[weight]), %%zmm31\n"
|
||||
"vmovups 384(%[weight]), %%zmm30\n"
|
||||
"vmovups 448(%[weight]), %%zmm29\n"
|
||||
"vmovups 512(%[weight]), %%zmm28\n"
|
||||
"vmovups 576(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
// block 2
|
||||
"vmovups 640(%[weight]), %%zmm31\n"
|
||||
"vmovups 704(%[weight]), %%zmm30\n"
|
||||
"vmovups 768(%[weight]), %%zmm29\n"
|
||||
"vmovups 832(%[weight]), %%zmm28\n"
|
||||
"vmovups 896(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
// block 3
|
||||
"vmovups 960(%[weight]), %%zmm31\n"
|
||||
"vmovups 1024(%[weight]), %%zmm30\n"
|
||||
"vmovups 1088(%[weight]), %%zmm29\n"
|
||||
"vmovups 1152(%[weight]), %%zmm28\n"
|
||||
"vmovups 1216(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
// block 4
|
||||
"vmovups 1280(%[weight]), %%zmm31\n"
|
||||
"vmovups 1344(%[weight]), %%zmm30\n"
|
||||
"vmovups 1408(%[weight]), %%zmm29\n"
|
||||
"vmovups 1472(%[weight]), %%zmm28\n"
|
||||
"vmovups 1536(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
// block 5
|
||||
"vmovups 1600(%[weight]), %%zmm31\n"
|
||||
"vmovups 1664(%[weight]), %%zmm30\n"
|
||||
"vmovups 1728(%[weight]), %%zmm29\n"
|
||||
"vmovups 1792(%[weight]), %%zmm28\n"
|
||||
"vmovups 1856(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
// block 6
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
// block 7
|
||||
"vmovups 2240(%[weight]), %%zmm31\n"
|
||||
"vmovups 2304(%[weight]), %%zmm30\n"
|
||||
"vmovups 2368(%[weight]), %%zmm29\n"
|
||||
"vmovups 2432(%[weight]), %%zmm28\n"
|
||||
"vmovups 2496(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $2560, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
|
||||
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
|
||||
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
|
||||
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"vminps %%zmm6, %%zmm30, %%zmm6\n"
|
||||
"vminps %%zmm7, %%zmm30, %%zmm7\n"
|
||||
"vminps %%zmm8, %%zmm30, %%zmm8\n"
|
||||
"vminps %%zmm9, %%zmm30, %%zmm9\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
|
|
@ -0,0 +1,308 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_2x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 320(%[dst_0]), %%zmm5\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
|
||||
"vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 320(%[bias]), %%zmm5\n"
|
||||
"vmovaps 0(%[bias]), %%zmm6\n"
|
||||
"vmovaps 64(%[bias]), %%zmm7\n"
|
||||
"vmovaps 128(%[bias]), %%zmm8\n"
|
||||
"vmovaps 192(%[bias]), %%zmm9\n"
|
||||
"vmovaps 256(%[bias]), %%zmm10\n"
|
||||
"vmovaps 320(%[bias]), %%zmm11\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
|
||||
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
|
||||
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
|
||||
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
|
||||
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
|
||||
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vmovups 320(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
// block 1
|
||||
"vmovups 384(%[weight]), %%zmm31\n"
|
||||
"vmovups 448(%[weight]), %%zmm30\n"
|
||||
"vmovups 512(%[weight]), %%zmm29\n"
|
||||
"vmovups 576(%[weight]), %%zmm28\n"
|
||||
"vmovups 640(%[weight]), %%zmm27\n"
|
||||
"vmovups 704(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
// block 2
|
||||
"vmovups 768(%[weight]), %%zmm31\n"
|
||||
"vmovups 832(%[weight]), %%zmm30\n"
|
||||
"vmovups 896(%[weight]), %%zmm29\n"
|
||||
"vmovups 960(%[weight]), %%zmm28\n"
|
||||
"vmovups 1024(%[weight]), %%zmm27\n"
|
||||
"vmovups 1088(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
// block 3
|
||||
"vmovups 1152(%[weight]), %%zmm31\n"
|
||||
"vmovups 1216(%[weight]), %%zmm30\n"
|
||||
"vmovups 1280(%[weight]), %%zmm29\n"
|
||||
"vmovups 1344(%[weight]), %%zmm28\n"
|
||||
"vmovups 1408(%[weight]), %%zmm27\n"
|
||||
"vmovups 1472(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
// block 4
|
||||
"vmovups 1536(%[weight]), %%zmm31\n"
|
||||
"vmovups 1600(%[weight]), %%zmm30\n"
|
||||
"vmovups 1664(%[weight]), %%zmm29\n"
|
||||
"vmovups 1728(%[weight]), %%zmm28\n"
|
||||
"vmovups 1792(%[weight]), %%zmm27\n"
|
||||
"vmovups 1856(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
// block 5
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vmovups 2240(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
// block 6
|
||||
"vmovups 2304(%[weight]), %%zmm31\n"
|
||||
"vmovups 2368(%[weight]), %%zmm30\n"
|
||||
"vmovups 2432(%[weight]), %%zmm29\n"
|
||||
"vmovups 2496(%[weight]), %%zmm28\n"
|
||||
"vmovups 2560(%[weight]), %%zmm27\n"
|
||||
"vmovups 2624(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
// block 7
|
||||
"vmovups 2688(%[weight]), %%zmm31\n"
|
||||
"vmovups 2752(%[weight]), %%zmm30\n"
|
||||
"vmovups 2816(%[weight]), %%zmm29\n"
|
||||
"vmovups 2880(%[weight]), %%zmm28\n"
|
||||
"vmovups 2944(%[weight]), %%zmm27\n"
|
||||
"vmovups 3008(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $3072, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
|
||||
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
|
||||
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
|
||||
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
|
||||
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
|
||||
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"vminps %%zmm6, %%zmm30, %%zmm6\n"
|
||||
"vminps %%zmm7, %%zmm30, %%zmm7\n"
|
||||
"vminps %%zmm8, %%zmm30, %%zmm8\n"
|
||||
"vminps %%zmm9, %%zmm30, %%zmm9\n"
|
||||
"vminps %%zmm10, %%zmm30, %%zmm10\n"
|
||||
"vminps %%zmm11, %%zmm30, %%zmm11\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 320(%[dst_0])\n"
|
||||
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
|
|
@ -0,0 +1,351 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_3x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 0(%[bias]), %%zmm5\n"
|
||||
"vmovaps 64(%[bias]), %%zmm6\n"
|
||||
"vmovaps 128(%[bias]), %%zmm7\n"
|
||||
"vmovaps 192(%[bias]), %%zmm8\n"
|
||||
"vmovaps 256(%[bias]), %%zmm9\n"
|
||||
"vmovaps 0(%[bias]), %%zmm10\n"
|
||||
"vmovaps 64(%[bias]), %%zmm11\n"
|
||||
"vmovaps 128(%[bias]), %%zmm12\n"
|
||||
"vmovaps 192(%[bias]), %%zmm13\n"
|
||||
"vmovaps 256(%[bias]), %%zmm14\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
|
||||
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
|
||||
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
|
||||
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
|
||||
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
|
||||
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
|
||||
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
|
||||
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
|
||||
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
|
||||
"%zmm12", "%zmm13", "%zmm14");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
// block 1
|
||||
"vmovups 320(%[weight]), %%zmm31\n"
|
||||
"vmovups 384(%[weight]), %%zmm30\n"
|
||||
"vmovups 448(%[weight]), %%zmm29\n"
|
||||
"vmovups 512(%[weight]), %%zmm28\n"
|
||||
"vmovups 576(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
// block 2
|
||||
"vmovups 640(%[weight]), %%zmm31\n"
|
||||
"vmovups 704(%[weight]), %%zmm30\n"
|
||||
"vmovups 768(%[weight]), %%zmm29\n"
|
||||
"vmovups 832(%[weight]), %%zmm28\n"
|
||||
"vmovups 896(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
// block 3
|
||||
"vmovups 960(%[weight]), %%zmm31\n"
|
||||
"vmovups 1024(%[weight]), %%zmm30\n"
|
||||
"vmovups 1088(%[weight]), %%zmm29\n"
|
||||
"vmovups 1152(%[weight]), %%zmm28\n"
|
||||
"vmovups 1216(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
// block 4
|
||||
"vmovups 1280(%[weight]), %%zmm31\n"
|
||||
"vmovups 1344(%[weight]), %%zmm30\n"
|
||||
"vmovups 1408(%[weight]), %%zmm29\n"
|
||||
"vmovups 1472(%[weight]), %%zmm28\n"
|
||||
"vmovups 1536(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
// block 5
|
||||
"vmovups 1600(%[weight]), %%zmm31\n"
|
||||
"vmovups 1664(%[weight]), %%zmm30\n"
|
||||
"vmovups 1728(%[weight]), %%zmm29\n"
|
||||
"vmovups 1792(%[weight]), %%zmm28\n"
|
||||
"vmovups 1856(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
// block 6
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
// block 7
|
||||
"vmovups 2240(%[weight]), %%zmm31\n"
|
||||
"vmovups 2304(%[weight]), %%zmm30\n"
|
||||
"vmovups 2368(%[weight]), %%zmm29\n"
|
||||
"vmovups 2432(%[weight]), %%zmm28\n"
|
||||
"vmovups 2496(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $2560, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
|
||||
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
|
||||
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
|
||||
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
|
||||
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
|
||||
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
|
||||
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
|
||||
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
|
||||
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
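// 0x40C00000 is the IEEE-754 single-precision encoding of 6.0f, broadcast to every lane as the relu6 upper bound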
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"vminps %%zmm6, %%zmm30, %%zmm6\n"
|
||||
"vminps %%zmm7, %%zmm30, %%zmm7\n"
|
||||
"vminps %%zmm8, %%zmm30, %%zmm8\n"
|
||||
"vminps %%zmm9, %%zmm30, %%zmm9\n"
|
||||
"vminps %%zmm10, %%zmm30, %%zmm10\n"
|
||||
"vminps %%zmm11, %%zmm30, %%zmm11\n"
|
||||
"vminps %%zmm12, %%zmm30, %%zmm12\n"
|
||||
"vminps %%zmm13, %%zmm30, %%zmm13\n"
|
||||
"vminps %%zmm14, %%zmm30, %%zmm14\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
@@ -0,0 +1,401 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_3x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
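// 3x96 tile: 3 rows x 6 zmm columns of fp32 accumulators (zmm0-zmm17, 16 floats each);
// zmm26-zmm31 stage the current weight columns and zmm23-zmm25 the broadcast src rows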
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
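// the inner loop below consumes eight deep elements per pass (blocks 0-7), hence deep >> 3;
// the << 2 shifts convert the element strides into byte strides for the asm addressing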
|
||||
asm volatile(
|
||||
// inc in deep
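// bit 0 of inc_flag: reload the partial sums already in dst and keep accumulating;
// otherwise start from bias (or from zero when bias is NULL)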
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 320(%[dst_0]), %%zmm5\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
|
||||
"vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n"
|
||||
"vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 320(%[bias]), %%zmm5\n"
|
||||
"vmovaps 0(%[bias]), %%zmm6\n"
|
||||
"vmovaps 64(%[bias]), %%zmm7\n"
|
||||
"vmovaps 128(%[bias]), %%zmm8\n"
|
||||
"vmovaps 192(%[bias]), %%zmm9\n"
|
||||
"vmovaps 256(%[bias]), %%zmm10\n"
|
||||
"vmovaps 320(%[bias]), %%zmm11\n"
|
||||
"vmovaps 0(%[bias]), %%zmm12\n"
|
||||
"vmovaps 64(%[bias]), %%zmm13\n"
|
||||
"vmovaps 128(%[bias]), %%zmm14\n"
|
||||
"vmovaps 192(%[bias]), %%zmm15\n"
|
||||
"vmovaps 256(%[bias]), %%zmm16\n"
|
||||
"vmovaps 320(%[bias]), %%zmm17\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
|
||||
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
|
||||
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
|
||||
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
|
||||
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
|
||||
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
|
||||
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
|
||||
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
|
||||
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
|
||||
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
|
||||
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
|
||||
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
|
||||
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vmovups 320(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
// block 1
|
||||
"vmovups 384(%[weight]), %%zmm31\n"
|
||||
"vmovups 448(%[weight]), %%zmm30\n"
|
||||
"vmovups 512(%[weight]), %%zmm29\n"
|
||||
"vmovups 576(%[weight]), %%zmm28\n"
|
||||
"vmovups 640(%[weight]), %%zmm27\n"
|
||||
"vmovups 704(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
// block 2
|
||||
"vmovups 768(%[weight]), %%zmm31\n"
|
||||
"vmovups 832(%[weight]), %%zmm30\n"
|
||||
"vmovups 896(%[weight]), %%zmm29\n"
|
||||
"vmovups 960(%[weight]), %%zmm28\n"
|
||||
"vmovups 1024(%[weight]), %%zmm27\n"
|
||||
"vmovups 1088(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
// block 3
|
||||
"vmovups 1152(%[weight]), %%zmm31\n"
|
||||
"vmovups 1216(%[weight]), %%zmm30\n"
|
||||
"vmovups 1280(%[weight]), %%zmm29\n"
|
||||
"vmovups 1344(%[weight]), %%zmm28\n"
|
||||
"vmovups 1408(%[weight]), %%zmm27\n"
|
||||
"vmovups 1472(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
// block 4
|
||||
"vmovups 1536(%[weight]), %%zmm31\n"
|
||||
"vmovups 1600(%[weight]), %%zmm30\n"
|
||||
"vmovups 1664(%[weight]), %%zmm29\n"
|
||||
"vmovups 1728(%[weight]), %%zmm28\n"
|
||||
"vmovups 1792(%[weight]), %%zmm27\n"
|
||||
"vmovups 1856(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
// block 5
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vmovups 2240(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
// block 6
|
||||
"vmovups 2304(%[weight]), %%zmm31\n"
|
||||
"vmovups 2368(%[weight]), %%zmm30\n"
|
||||
"vmovups 2432(%[weight]), %%zmm29\n"
|
||||
"vmovups 2496(%[weight]), %%zmm28\n"
|
||||
"vmovups 2560(%[weight]), %%zmm27\n"
|
||||
"vmovups 2624(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
// block 7
|
||||
"vmovups 2688(%[weight]), %%zmm31\n"
|
||||
"vmovups 2752(%[weight]), %%zmm30\n"
|
||||
"vmovups 2816(%[weight]), %%zmm29\n"
|
||||
"vmovups 2880(%[weight]), %%zmm28\n"
|
||||
"vmovups 2944(%[weight]), %%zmm27\n"
|
||||
"vmovups 3008(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm23, %%zmm17\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $3072, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
|
||||
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
|
||||
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
|
||||
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
|
||||
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
|
||||
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
|
||||
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
|
||||
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
|
||||
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
|
||||
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
|
||||
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
|
||||
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"vminps %%zmm6, %%zmm30, %%zmm6\n"
|
||||
"vminps %%zmm7, %%zmm30, %%zmm7\n"
|
||||
"vminps %%zmm8, %%zmm30, %%zmm8\n"
|
||||
"vminps %%zmm9, %%zmm30, %%zmm9\n"
|
||||
"vminps %%zmm10, %%zmm30, %%zmm10\n"
|
||||
"vminps %%zmm11, %%zmm30, %%zmm11\n"
|
||||
"vminps %%zmm12, %%zmm30, %%zmm12\n"
|
||||
"vminps %%zmm13, %%zmm30, %%zmm13\n"
|
||||
"vminps %%zmm14, %%zmm30, %%zmm14\n"
|
||||
"vminps %%zmm15, %%zmm30, %%zmm15\n"
|
||||
"vminps %%zmm16, %%zmm30, %%zmm16\n"
|
||||
"vminps %%zmm17, %%zmm30, %%zmm17\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 320(%[dst_0])\n"
|
||||
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
@@ -0,0 +1,434 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_4x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
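// 4x80 tile: 4 rows x 5 zmm columns of fp32 accumulators (zmm0-zmm19);
// zmm27-zmm31 stage the weight columns and zmm23-zmm26 the broadcast src rows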
|
||||
const float *dst_3 = dst + 3 * dst_stride;
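// the SIB scale can only be 1, 2, 4 or 8, so row 3 (dst_stride * 3) needs its own base pointer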
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
|
||||
"vmovups 0(%[dst_3]), %%zmm15\n"
|
||||
"vmovups 64(%[dst_3]), %%zmm16\n"
|
||||
"vmovups 128(%[dst_3]), %%zmm17\n"
|
||||
"vmovups 192(%[dst_3]), %%zmm18\n"
|
||||
"vmovups 256(%[dst_3]), %%zmm19\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 0(%[bias]), %%zmm5\n"
|
||||
"vmovaps 64(%[bias]), %%zmm6\n"
|
||||
"vmovaps 128(%[bias]), %%zmm7\n"
|
||||
"vmovaps 192(%[bias]), %%zmm8\n"
|
||||
"vmovaps 256(%[bias]), %%zmm9\n"
|
||||
"vmovaps 0(%[bias]), %%zmm10\n"
|
||||
"vmovaps 64(%[bias]), %%zmm11\n"
|
||||
"vmovaps 128(%[bias]), %%zmm12\n"
|
||||
"vmovaps 192(%[bias]), %%zmm13\n"
|
||||
"vmovaps 256(%[bias]), %%zmm14\n"
|
||||
"vmovaps 0(%[bias]), %%zmm15\n"
|
||||
"vmovaps 64(%[bias]), %%zmm16\n"
|
||||
"vmovaps 128(%[bias]), %%zmm17\n"
|
||||
"vmovaps 192(%[bias]), %%zmm18\n"
|
||||
"vmovaps 256(%[bias]), %%zmm19\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
|
||||
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
|
||||
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
|
||||
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
|
||||
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
|
||||
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
|
||||
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
|
||||
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
|
||||
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
|
||||
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
|
||||
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
|
||||
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
|
||||
"vxorps %%zmm18, %%zmm18, %%zmm18\n"
|
||||
"vxorps %%zmm19, %%zmm19, %%zmm19\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
|
||||
[ dst_3 ] "r"(dst_3)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
|
||||
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19");
|
||||
const float *src_3 = src + 3 * src_stride;
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
// block 1
|
||||
"vmovups 320(%[weight]), %%zmm31\n"
|
||||
"vmovups 384(%[weight]), %%zmm30\n"
|
||||
"vmovups 448(%[weight]), %%zmm29\n"
|
||||
"vmovups 512(%[weight]), %%zmm28\n"
|
||||
"vmovups 576(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
// block 2
|
||||
"vmovups 640(%[weight]), %%zmm31\n"
|
||||
"vmovups 704(%[weight]), %%zmm30\n"
|
||||
"vmovups 768(%[weight]), %%zmm29\n"
|
||||
"vmovups 832(%[weight]), %%zmm28\n"
|
||||
"vmovups 896(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
// block 3
|
||||
"vmovups 960(%[weight]), %%zmm31\n"
|
||||
"vmovups 1024(%[weight]), %%zmm30\n"
|
||||
"vmovups 1088(%[weight]), %%zmm29\n"
|
||||
"vmovups 1152(%[weight]), %%zmm28\n"
|
||||
"vmovups 1216(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
// block 4
|
||||
"vmovups 1280(%[weight]), %%zmm31\n"
|
||||
"vmovups 1344(%[weight]), %%zmm30\n"
|
||||
"vmovups 1408(%[weight]), %%zmm29\n"
|
||||
"vmovups 1472(%[weight]), %%zmm28\n"
|
||||
"vmovups 1536(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
// block 5
|
||||
"vmovups 1600(%[weight]), %%zmm31\n"
|
||||
"vmovups 1664(%[weight]), %%zmm30\n"
|
||||
"vmovups 1728(%[weight]), %%zmm29\n"
|
||||
"vmovups 1792(%[weight]), %%zmm28\n"
|
||||
"vmovups 1856(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
// block 6
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
// block 7
|
||||
"vmovups 2240(%[weight]), %%zmm31\n"
|
||||
"vmovups 2304(%[weight]), %%zmm30\n"
|
||||
"vmovups 2368(%[weight]), %%zmm29\n"
|
||||
"vmovups 2432(%[weight]), %%zmm28\n"
|
||||
"vmovups 2496(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm24\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 3), %%zmm23\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm23, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm23, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm23, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm23, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm23, %%zmm19\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $2560, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"add $32, %[src_3]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
|
||||
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
|
||||
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
|
||||
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
|
||||
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
|
||||
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
|
||||
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
|
||||
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
|
||||
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
|
||||
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
|
||||
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
|
||||
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
|
||||
"vmaxps %%zmm18, %%zmm31, %%zmm18\n"
|
||||
"vmaxps %%zmm19, %%zmm31, %%zmm19\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"vminps %%zmm6, %%zmm30, %%zmm6\n"
|
||||
"vminps %%zmm7, %%zmm30, %%zmm7\n"
|
||||
"vminps %%zmm8, %%zmm30, %%zmm8\n"
|
||||
"vminps %%zmm9, %%zmm30, %%zmm9\n"
|
||||
"vminps %%zmm10, %%zmm30, %%zmm10\n"
|
||||
"vminps %%zmm11, %%zmm30, %%zmm11\n"
|
||||
"vminps %%zmm12, %%zmm30, %%zmm12\n"
|
||||
"vminps %%zmm13, %%zmm30, %%zmm13\n"
|
||||
"vminps %%zmm14, %%zmm30, %%zmm14\n"
|
||||
"vminps %%zmm15, %%zmm30, %%zmm15\n"
|
||||
"vminps %%zmm16, %%zmm30, %%zmm16\n"
|
||||
"vminps %%zmm17, %%zmm30, %%zmm17\n"
|
||||
"vminps %%zmm18, %%zmm30, %%zmm18\n"
|
||||
"vminps %%zmm19, %%zmm30, %%zmm19\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm15, 0(%[dst_3])\n"
|
||||
"vmovups %%zmm16, 64(%[dst_3])\n"
|
||||
"vmovups %%zmm17, 128(%[dst_3])\n"
|
||||
"vmovups %%zmm18, 192(%[dst_3])\n"
|
||||
"vmovups %%zmm19, 256(%[dst_3])\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
|
||||
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
@@ -0,0 +1,499 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_4x96_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
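// 4x96 tile: 4 rows x 6 zmm columns of fp32 accumulators (zmm0-zmm23); with only eight
// registers left, each block reuses zmm26-zmm31 for the weights and broadcasts the four
// src rows two at a time through zmm24/zmm25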
|
||||
const float *dst_3 = dst + 3 * dst_stride;
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 320(%[dst_0]), %%zmm5\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm6\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm7\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm8\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm9\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm10\n"
|
||||
"vmovups 320(%[dst_0], %[dst_stride], 1), %%zmm11\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm12\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm13\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm14\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm15\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm16\n"
|
||||
"vmovups 320(%[dst_0], %[dst_stride], 2), %%zmm17\n"
|
||||
"vmovups 0(%[dst_3]), %%zmm18\n"
|
||||
"vmovups 64(%[dst_3]), %%zmm19\n"
|
||||
"vmovups 128(%[dst_3]), %%zmm20\n"
|
||||
"vmovups 192(%[dst_3]), %%zmm21\n"
|
||||
"vmovups 256(%[dst_3]), %%zmm22\n"
|
||||
"vmovups 320(%[dst_3]), %%zmm23\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 320(%[bias]), %%zmm5\n"
|
||||
"vmovaps 0(%[bias]), %%zmm6\n"
|
||||
"vmovaps 64(%[bias]), %%zmm7\n"
|
||||
"vmovaps 128(%[bias]), %%zmm8\n"
|
||||
"vmovaps 192(%[bias]), %%zmm9\n"
|
||||
"vmovaps 256(%[bias]), %%zmm10\n"
|
||||
"vmovaps 320(%[bias]), %%zmm11\n"
|
||||
"vmovaps 0(%[bias]), %%zmm12\n"
|
||||
"vmovaps 64(%[bias]), %%zmm13\n"
|
||||
"vmovaps 128(%[bias]), %%zmm14\n"
|
||||
"vmovaps 192(%[bias]), %%zmm15\n"
|
||||
"vmovaps 256(%[bias]), %%zmm16\n"
|
||||
"vmovaps 320(%[bias]), %%zmm17\n"
|
||||
"vmovaps 0(%[bias]), %%zmm18\n"
|
||||
"vmovaps 64(%[bias]), %%zmm19\n"
|
||||
"vmovaps 128(%[bias]), %%zmm20\n"
|
||||
"vmovaps 192(%[bias]), %%zmm21\n"
|
||||
"vmovaps 256(%[bias]), %%zmm22\n"
|
||||
"vmovaps 320(%[bias]), %%zmm23\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
|
||||
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
|
||||
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
|
||||
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
|
||||
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
|
||||
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
|
||||
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
|
||||
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
|
||||
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
|
||||
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
|
||||
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
|
||||
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
|
||||
"vxorps %%zmm18, %%zmm18, %%zmm18\n"
|
||||
"vxorps %%zmm19, %%zmm19, %%zmm19\n"
|
||||
"vxorps %%zmm20, %%zmm20, %%zmm20\n"
|
||||
"vxorps %%zmm21, %%zmm21, %%zmm21\n"
|
||||
"vxorps %%zmm22, %%zmm22, %%zmm22\n"
|
||||
"vxorps %%zmm23, %%zmm23, %%zmm23\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
|
||||
[ dst_3 ] "r"(dst_3)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
|
||||
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
|
||||
"%zmm23");
|
||||
const float *src_3 = src + 3 * src_stride;
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vmovups 320(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 0(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
// block 1
|
||||
"vmovups 384(%[weight]), %%zmm31\n"
|
||||
"vmovups 448(%[weight]), %%zmm30\n"
|
||||
"vmovups 512(%[weight]), %%zmm29\n"
|
||||
"vmovups 576(%[weight]), %%zmm28\n"
|
||||
"vmovups 640(%[weight]), %%zmm27\n"
|
||||
"vmovups 704(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 4(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
// block 2
|
||||
"vmovups 768(%[weight]), %%zmm31\n"
|
||||
"vmovups 832(%[weight]), %%zmm30\n"
|
||||
"vmovups 896(%[weight]), %%zmm29\n"
|
||||
"vmovups 960(%[weight]), %%zmm28\n"
|
||||
"vmovups 1024(%[weight]), %%zmm27\n"
|
||||
"vmovups 1088(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 8(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
// block 3
|
||||
"vmovups 1152(%[weight]), %%zmm31\n"
|
||||
"vmovups 1216(%[weight]), %%zmm30\n"
|
||||
"vmovups 1280(%[weight]), %%zmm29\n"
|
||||
"vmovups 1344(%[weight]), %%zmm28\n"
|
||||
"vmovups 1408(%[weight]), %%zmm27\n"
|
||||
"vmovups 1472(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 12(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
// block 4
|
||||
"vmovups 1536(%[weight]), %%zmm31\n"
|
||||
"vmovups 1600(%[weight]), %%zmm30\n"
|
||||
"vmovups 1664(%[weight]), %%zmm29\n"
|
||||
"vmovups 1728(%[weight]), %%zmm28\n"
|
||||
"vmovups 1792(%[weight]), %%zmm27\n"
|
||||
"vmovups 1856(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 16(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
// block 5
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vmovups 2240(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 20(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
// block 6
|
||||
"vmovups 2304(%[weight]), %%zmm31\n"
|
||||
"vmovups 2368(%[weight]), %%zmm30\n"
|
||||
"vmovups 2432(%[weight]), %%zmm29\n"
|
||||
"vmovups 2496(%[weight]), %%zmm28\n"
|
||||
"vmovups 2560(%[weight]), %%zmm27\n"
|
||||
"vmovups 2624(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 24(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
// block 7
|
||||
"vmovups 2688(%[weight]), %%zmm31\n"
|
||||
"vmovups 2752(%[weight]), %%zmm30\n"
|
||||
"vmovups 2816(%[weight]), %%zmm29\n"
|
||||
"vmovups 2880(%[weight]), %%zmm28\n"
|
||||
"vmovups 2944(%[weight]), %%zmm27\n"
|
||||
"vmovups 3008(%[weight]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm25\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm9\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm11\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm25\n"
|
||||
"vbroadcastss 28(%[src_3]), %%zmm24\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm24, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm24, %%zmm19\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm24, %%zmm20\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm24, %%zmm21\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm24, %%zmm22\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm26, %%zmm24, %%zmm23\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $3072, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"add $32, %[src_3]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
|
||||
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
|
||||
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
|
||||
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
|
||||
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
|
||||
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
|
||||
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
|
||||
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
|
||||
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
|
||||
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
|
||||
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
|
||||
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
|
||||
"vmaxps %%zmm18, %%zmm31, %%zmm18\n"
|
||||
"vmaxps %%zmm19, %%zmm31, %%zmm19\n"
|
||||
"vmaxps %%zmm20, %%zmm31, %%zmm20\n"
|
||||
"vmaxps %%zmm21, %%zmm31, %%zmm21\n"
|
||||
"vmaxps %%zmm22, %%zmm31, %%zmm22\n"
|
||||
"vmaxps %%zmm23, %%zmm31, %%zmm23\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"vminps %%zmm6, %%zmm30, %%zmm6\n"
|
||||
"vminps %%zmm7, %%zmm30, %%zmm7\n"
|
||||
"vminps %%zmm8, %%zmm30, %%zmm8\n"
|
||||
"vminps %%zmm9, %%zmm30, %%zmm9\n"
|
||||
"vminps %%zmm10, %%zmm30, %%zmm10\n"
|
||||
"vminps %%zmm11, %%zmm30, %%zmm11\n"
|
||||
"vminps %%zmm12, %%zmm30, %%zmm12\n"
|
||||
"vminps %%zmm13, %%zmm30, %%zmm13\n"
|
||||
"vminps %%zmm14, %%zmm30, %%zmm14\n"
|
||||
"vminps %%zmm15, %%zmm30, %%zmm15\n"
|
||||
"vminps %%zmm16, %%zmm30, %%zmm16\n"
|
||||
"vminps %%zmm17, %%zmm30, %%zmm17\n"
|
||||
"vminps %%zmm18, %%zmm30, %%zmm18\n"
|
||||
"vminps %%zmm19, %%zmm30, %%zmm19\n"
|
||||
"vminps %%zmm20, %%zmm30, %%zmm20\n"
|
||||
"vminps %%zmm21, %%zmm30, %%zmm21\n"
|
||||
"vminps %%zmm22, %%zmm30, %%zmm22\n"
|
||||
"vminps %%zmm23, %%zmm30, %%zmm23\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 320(%[dst_0])\n"
|
||||
"vmovups %%zmm6, 0(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm7, 64(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm8, 128(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm9, 192(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm10, 256(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm11, 320(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm12, 0(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm13, 64(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm14, 128(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm15, 192(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm16, 256(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm17, 320(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm18, 0(%[dst_3])\n"
|
||||
"vmovups %%zmm19, 64(%[dst_3])\n"
|
||||
"vmovups %%zmm20, 128(%[dst_3])\n"
|
||||
"vmovups %%zmm21, 192(%[dst_3])\n"
|
||||
"vmovups %%zmm22, 256(%[dst_3])\n"
|
||||
"vmovups %%zmm23, 320(%[dst_3])\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
|
||||
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
@@ -0,0 +1,513 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_5x80_kernel_nhwc_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
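// 5 rows x 80 output channels: zmm0-zmm4 accumulate row 0, zmm5-zmm9 row 1, zmm10-zmm14 row 2,
// zmm15-zmm19 row 3, zmm20-zmm24 row 4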
|
||||
const float *dst_3 = dst + 3 * dst_stride;
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst_0]), %%zmm0\n"
|
||||
"vmovups 64(%[dst_0]), %%zmm1\n"
|
||||
"vmovups 128(%[dst_0]), %%zmm2\n"
|
||||
"vmovups 192(%[dst_0]), %%zmm3\n"
|
||||
"vmovups 256(%[dst_0]), %%zmm4\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 1), %%zmm5\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 1), %%zmm6\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 1), %%zmm7\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 1), %%zmm8\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 1), %%zmm9\n"
|
||||
"vmovups 0(%[dst_0], %[dst_stride], 2), %%zmm10\n"
|
||||
"vmovups 64(%[dst_0], %[dst_stride], 2), %%zmm11\n"
|
||||
"vmovups 128(%[dst_0], %[dst_stride], 2), %%zmm12\n"
|
||||
"vmovups 192(%[dst_0], %[dst_stride], 2), %%zmm13\n"
|
||||
"vmovups 256(%[dst_0], %[dst_stride], 2), %%zmm14\n"
|
||||
"vmovups 0(%[dst_3]), %%zmm15\n"
|
||||
"vmovups 64(%[dst_3]), %%zmm16\n"
|
||||
"vmovups 128(%[dst_3]), %%zmm17\n"
|
||||
"vmovups 192(%[dst_3]), %%zmm18\n"
|
||||
"vmovups 256(%[dst_3]), %%zmm19\n"
|
||||
"vmovups 0(%[dst_3], %[dst_stride], 1), %%zmm20\n"
|
||||
"vmovups 64(%[dst_3], %[dst_stride], 1), %%zmm21\n"
|
||||
"vmovups 128(%[dst_3], %[dst_stride], 1), %%zmm22\n"
|
||||
"vmovups 192(%[dst_3], %[dst_stride], 1), %%zmm23\n"
|
||||
"vmovups 256(%[dst_3], %[dst_stride], 1), %%zmm24\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%zmm0\n"
|
||||
"vmovaps 64(%[bias]), %%zmm1\n"
|
||||
"vmovaps 128(%[bias]), %%zmm2\n"
|
||||
"vmovaps 192(%[bias]), %%zmm3\n"
|
||||
"vmovaps 256(%[bias]), %%zmm4\n"
|
||||
"vmovaps 0(%[bias]), %%zmm5\n"
|
||||
"vmovaps 64(%[bias]), %%zmm6\n"
|
||||
"vmovaps 128(%[bias]), %%zmm7\n"
|
||||
"vmovaps 192(%[bias]), %%zmm8\n"
|
||||
"vmovaps 256(%[bias]), %%zmm9\n"
|
||||
"vmovaps 0(%[bias]), %%zmm10\n"
|
||||
"vmovaps 64(%[bias]), %%zmm11\n"
|
||||
"vmovaps 128(%[bias]), %%zmm12\n"
|
||||
"vmovaps 192(%[bias]), %%zmm13\n"
|
||||
"vmovaps 256(%[bias]), %%zmm14\n"
|
||||
"vmovaps 0(%[bias]), %%zmm15\n"
|
||||
"vmovaps 64(%[bias]), %%zmm16\n"
|
||||
"vmovaps 128(%[bias]), %%zmm17\n"
|
||||
"vmovaps 192(%[bias]), %%zmm18\n"
|
||||
"vmovaps 256(%[bias]), %%zmm19\n"
|
||||
"vmovaps 0(%[bias]), %%zmm20\n"
|
||||
"vmovaps 64(%[bias]), %%zmm21\n"
|
||||
"vmovaps 128(%[bias]), %%zmm22\n"
|
||||
"vmovaps 192(%[bias]), %%zmm23\n"
|
||||
"vmovaps 256(%[bias]), %%zmm24\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%zmm0, %%zmm0, %%zmm0\n"
|
||||
"vxorps %%zmm1, %%zmm1, %%zmm1\n"
|
||||
"vxorps %%zmm2, %%zmm2, %%zmm2\n"
|
||||
"vxorps %%zmm3, %%zmm3, %%zmm3\n"
|
||||
"vxorps %%zmm4, %%zmm4, %%zmm4\n"
|
||||
"vxorps %%zmm5, %%zmm5, %%zmm5\n"
|
||||
"vxorps %%zmm6, %%zmm6, %%zmm6\n"
|
||||
"vxorps %%zmm7, %%zmm7, %%zmm7\n"
|
||||
"vxorps %%zmm8, %%zmm8, %%zmm8\n"
|
||||
"vxorps %%zmm9, %%zmm9, %%zmm9\n"
|
||||
"vxorps %%zmm10, %%zmm10, %%zmm10\n"
|
||||
"vxorps %%zmm11, %%zmm11, %%zmm11\n"
|
||||
"vxorps %%zmm12, %%zmm12, %%zmm12\n"
|
||||
"vxorps %%zmm13, %%zmm13, %%zmm13\n"
|
||||
"vxorps %%zmm14, %%zmm14, %%zmm14\n"
|
||||
"vxorps %%zmm15, %%zmm15, %%zmm15\n"
|
||||
"vxorps %%zmm16, %%zmm16, %%zmm16\n"
|
||||
"vxorps %%zmm17, %%zmm17, %%zmm17\n"
|
||||
"vxorps %%zmm18, %%zmm18, %%zmm18\n"
|
||||
"vxorps %%zmm19, %%zmm19, %%zmm19\n"
|
||||
"vxorps %%zmm20, %%zmm20, %%zmm20\n"
|
||||
"vxorps %%zmm21, %%zmm21, %%zmm21\n"
|
||||
"vxorps %%zmm22, %%zmm22, %%zmm22\n"
|
||||
"vxorps %%zmm23, %%zmm23, %%zmm23\n"
|
||||
"vxorps %%zmm24, %%zmm24, %%zmm24\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst_0 ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
|
||||
[ dst_3 ] "r"(dst_3)
|
||||
: "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10", "%zmm11",
|
||||
"%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21", "%zmm22",
|
||||
"%zmm23", "%zmm24");
|
||||
const float *src_3 = src + 3 * src_stride;
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovups 0(%[weight]), %%zmm31\n"
|
||||
"vmovups 64(%[weight]), %%zmm30\n"
|
||||
"vmovups 128(%[weight]), %%zmm29\n"
|
||||
"vmovups 192(%[weight]), %%zmm28\n"
|
||||
"vmovups 256(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 0(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 0(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 0(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 0(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
// block 1
|
||||
"vmovups 320(%[weight]), %%zmm31\n"
|
||||
"vmovups 384(%[weight]), %%zmm30\n"
|
||||
"vmovups 448(%[weight]), %%zmm29\n"
|
||||
"vmovups 512(%[weight]), %%zmm28\n"
|
||||
"vmovups 576(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 4(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 4(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 4(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 4(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
// block 2
|
||||
"vmovups 640(%[weight]), %%zmm31\n"
|
||||
"vmovups 704(%[weight]), %%zmm30\n"
|
||||
"vmovups 768(%[weight]), %%zmm29\n"
|
||||
"vmovups 832(%[weight]), %%zmm28\n"
|
||||
"vmovups 896(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 8(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 8(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 8(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 8(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
// block 3
|
||||
"vmovups 960(%[weight]), %%zmm31\n"
|
||||
"vmovups 1024(%[weight]), %%zmm30\n"
|
||||
"vmovups 1088(%[weight]), %%zmm29\n"
|
||||
"vmovups 1152(%[weight]), %%zmm28\n"
|
||||
"vmovups 1216(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 12(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 12(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 12(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 12(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
// block 4
|
||||
"vmovups 1280(%[weight]), %%zmm31\n"
|
||||
"vmovups 1344(%[weight]), %%zmm30\n"
|
||||
"vmovups 1408(%[weight]), %%zmm29\n"
|
||||
"vmovups 1472(%[weight]), %%zmm28\n"
|
||||
"vmovups 1536(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 16(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 16(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 16(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 16(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
// block 5
|
||||
"vmovups 1600(%[weight]), %%zmm31\n"
|
||||
"vmovups 1664(%[weight]), %%zmm30\n"
|
||||
"vmovups 1728(%[weight]), %%zmm29\n"
|
||||
"vmovups 1792(%[weight]), %%zmm28\n"
|
||||
"vmovups 1856(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 20(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 20(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 20(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 20(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
// block 6
|
||||
"vmovups 1920(%[weight]), %%zmm31\n"
|
||||
"vmovups 1984(%[weight]), %%zmm30\n"
|
||||
"vmovups 2048(%[weight]), %%zmm29\n"
|
||||
"vmovups 2112(%[weight]), %%zmm28\n"
|
||||
"vmovups 2176(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 24(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 24(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 24(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 24(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
// block 7
|
||||
"vmovups 2240(%[weight]), %%zmm31\n"
|
||||
"vmovups 2304(%[weight]), %%zmm30\n"
|
||||
"vmovups 2368(%[weight]), %%zmm29\n"
|
||||
"vmovups 2432(%[weight]), %%zmm28\n"
|
||||
"vmovups 2496(%[weight]), %%zmm27\n"
|
||||
"vbroadcastss 28(%[src_0]), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 1), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm0\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm5\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm1\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm6\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm2\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm7\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm3\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm8\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm4\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm9\n"
|
||||
"vbroadcastss 28(%[src_0], %[src_stride], 2), %%zmm26\n"
|
||||
"vbroadcastss 28(%[src_3]), %%zmm25\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm25, %%zmm15\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm25, %%zmm16\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm25, %%zmm17\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm25, %%zmm18\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm25, %%zmm19\n"
|
||||
"vbroadcastss 28(%[src_3], %[src_stride], 1), %%zmm26\n"
|
||||
"vfmadd231ps %%zmm31, %%zmm26, %%zmm10\n"
|
||||
"vfmadd231ps %%zmm30, %%zmm26, %%zmm11\n"
|
||||
"vfmadd231ps %%zmm29, %%zmm26, %%zmm12\n"
|
||||
"vfmadd231ps %%zmm28, %%zmm26, %%zmm13\n"
|
||||
"vfmadd231ps %%zmm27, %%zmm26, %%zmm14\n"
|
||||
|
||||
"dec %[deep]\n"
|
||||
"add $2560, %[weight]\n"
|
||||
"add $32, %[src_0]\n"
|
||||
"add $32, %[src_3]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%zmm31, %%zmm31, %%zmm31\n"
|
||||
"vmaxps %%zmm0, %%zmm31, %%zmm0\n"
|
||||
"vmaxps %%zmm1, %%zmm31, %%zmm1\n"
|
||||
"vmaxps %%zmm2, %%zmm31, %%zmm2\n"
|
||||
"vmaxps %%zmm3, %%zmm31, %%zmm3\n"
|
||||
"vmaxps %%zmm4, %%zmm31, %%zmm4\n"
|
||||
"vmaxps %%zmm5, %%zmm31, %%zmm5\n"
|
||||
"vmaxps %%zmm6, %%zmm31, %%zmm6\n"
|
||||
"vmaxps %%zmm7, %%zmm31, %%zmm7\n"
|
||||
"vmaxps %%zmm8, %%zmm31, %%zmm8\n"
|
||||
"vmaxps %%zmm9, %%zmm31, %%zmm9\n"
|
||||
"vmaxps %%zmm10, %%zmm31, %%zmm10\n"
|
||||
"vmaxps %%zmm11, %%zmm31, %%zmm11\n"
|
||||
"vmaxps %%zmm12, %%zmm31, %%zmm12\n"
|
||||
"vmaxps %%zmm13, %%zmm31, %%zmm13\n"
|
||||
"vmaxps %%zmm14, %%zmm31, %%zmm14\n"
|
||||
"vmaxps %%zmm15, %%zmm31, %%zmm15\n"
|
||||
"vmaxps %%zmm16, %%zmm31, %%zmm16\n"
|
||||
"vmaxps %%zmm17, %%zmm31, %%zmm17\n"
|
||||
"vmaxps %%zmm18, %%zmm31, %%zmm18\n"
|
||||
"vmaxps %%zmm19, %%zmm31, %%zmm19\n"
|
||||
"vmaxps %%zmm20, %%zmm31, %%zmm20\n"
|
||||
"vmaxps %%zmm21, %%zmm31, %%zmm21\n"
|
||||
"vmaxps %%zmm22, %%zmm31, %%zmm22\n"
|
||||
"vmaxps %%zmm23, %%zmm31, %%zmm23\n"
|
||||
"vmaxps %%zmm24, %%zmm31, %%zmm24\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm30\n"
|
||||
"vbroadcastss %%xmm30, %%zmm30\n"
|
||||
"vminps %%zmm0, %%zmm30, %%zmm0\n"
|
||||
"vminps %%zmm1, %%zmm30, %%zmm1\n"
|
||||
"vminps %%zmm2, %%zmm30, %%zmm2\n"
|
||||
"vminps %%zmm3, %%zmm30, %%zmm3\n"
|
||||
"vminps %%zmm4, %%zmm30, %%zmm4\n"
|
||||
"vminps %%zmm5, %%zmm30, %%zmm5\n"
|
||||
"vminps %%zmm6, %%zmm30, %%zmm6\n"
|
||||
"vminps %%zmm7, %%zmm30, %%zmm7\n"
|
||||
"vminps %%zmm8, %%zmm30, %%zmm8\n"
|
||||
"vminps %%zmm9, %%zmm30, %%zmm9\n"
|
||||
"vminps %%zmm10, %%zmm30, %%zmm10\n"
|
||||
"vminps %%zmm11, %%zmm30, %%zmm11\n"
|
||||
"vminps %%zmm12, %%zmm30, %%zmm12\n"
|
||||
"vminps %%zmm13, %%zmm30, %%zmm13\n"
|
||||
"vminps %%zmm14, %%zmm30, %%zmm14\n"
|
||||
"vminps %%zmm15, %%zmm30, %%zmm15\n"
|
||||
"vminps %%zmm16, %%zmm30, %%zmm16\n"
|
||||
"vminps %%zmm17, %%zmm30, %%zmm17\n"
|
||||
"vminps %%zmm18, %%zmm30, %%zmm18\n"
|
||||
"vminps %%zmm19, %%zmm30, %%zmm19\n"
|
||||
"vminps %%zmm20, %%zmm30, %%zmm20\n"
|
||||
"vminps %%zmm21, %%zmm30, %%zmm21\n"
|
||||
"vminps %%zmm22, %%zmm30, %%zmm22\n"
|
||||
"vminps %%zmm23, %%zmm30, %%zmm23\n"
|
||||
"vminps %%zmm24, %%zmm30, %%zmm24\n"
|
||||
"3:\n"
|
||||
"vmovups %%zmm0, 0(%[dst_0])\n"
|
||||
"vmovups %%zmm1, 64(%[dst_0])\n"
|
||||
"vmovups %%zmm2, 128(%[dst_0])\n"
|
||||
"vmovups %%zmm3, 192(%[dst_0])\n"
|
||||
"vmovups %%zmm4, 256(%[dst_0])\n"
|
||||
"vmovups %%zmm5, 0(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm6, 64(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm7, 128(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm8, 192(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm9, 256(%[dst_0], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm10, 0(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm11, 64(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm12, 128(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm13, 192(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm14, 256(%[dst_0], %[dst_stride], 2),\n"
|
||||
"vmovups %%zmm15, 0(%[dst_3])\n"
|
||||
"vmovups %%zmm16, 64(%[dst_3])\n"
|
||||
"vmovups %%zmm17, 128(%[dst_3])\n"
|
||||
"vmovups %%zmm18, 192(%[dst_3])\n"
|
||||
"vmovups %%zmm19, 256(%[dst_3])\n"
|
||||
"vmovups %%zmm20, 0(%[dst_3], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm21, 64(%[dst_3], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm22, 128(%[dst_3], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm23, 192(%[dst_3], %[dst_stride], 1),\n"
|
||||
"vmovups %%zmm24, 256(%[dst_3], %[dst_stride], 1),\n"
|
||||
:
|
||||
: [ src_0 ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst_0 ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
|
||||
[ dst_3 ] "r"(dst_3), [ src_3 ] "r"(src_3)
|
||||
: "%rax", "%zmm0", "%zmm1", "%zmm2", "%zmm3", "%zmm4", "%zmm5", "%zmm6", "%zmm7", "%zmm8", "%zmm9", "%zmm10",
|
||||
"%zmm11", "%zmm12", "%zmm13", "%zmm14", "%zmm15", "%zmm16", "%zmm17", "%zmm18", "%zmm19", "%zmm20", "%zmm21",
|
||||
"%zmm22", "%zmm23", "%zmm24", "%zmm25", "%zmm26", "%zmm27", "%zmm28", "%zmm29", "%zmm30", "%zmm31");
|
||||
}
@@ -0,0 +1,297 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
__m256 dst5;
|
||||
__m256 dst6;
|
||||
__m256 dst7;
|
||||
__m256 dst8;
|
||||
__m256 dst9;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
|
||||
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
|
||||
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
|
||||
dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
dst9 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 0);
|
||||
dst8 = _mm256_load_ps(bias + 0);
|
||||
dst9 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
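// each pass consumes 8 depth elements, unrolled as blocks 0-7 below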
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
__m256 src60 = _mm256_set1_ps(*(src + 48));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
|
||||
__m256 src70 = _mm256_set1_ps(*(src + 56));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
|
||||
__m256 src80 = _mm256_set1_ps(*(src + 64));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
|
||||
__m256 src90 = _mm256_set1_ps(*(src + 72));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
__m256 src61 = _mm256_set1_ps(*(src + 49));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
|
||||
__m256 src71 = _mm256_set1_ps(*(src + 57));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
|
||||
__m256 src81 = _mm256_set1_ps(*(src + 65));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
|
||||
__m256 src91 = _mm256_set1_ps(*(src + 73));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src91, weight01);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
__m256 src62 = _mm256_set1_ps(*(src + 50));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
|
||||
__m256 src72 = _mm256_set1_ps(*(src + 58));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
|
||||
__m256 src82 = _mm256_set1_ps(*(src + 66));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
|
||||
__m256 src92 = _mm256_set1_ps(*(src + 74));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src92, weight02);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
__m256 src63 = _mm256_set1_ps(*(src + 51));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
|
||||
__m256 src73 = _mm256_set1_ps(*(src + 59));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
|
||||
__m256 src83 = _mm256_set1_ps(*(src + 67));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
|
||||
__m256 src93 = _mm256_set1_ps(*(src + 75));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src93, weight03);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
__m256 src64 = _mm256_set1_ps(*(src + 52));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
|
||||
__m256 src74 = _mm256_set1_ps(*(src + 60));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
|
||||
__m256 src84 = _mm256_set1_ps(*(src + 68));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
|
||||
__m256 src94 = _mm256_set1_ps(*(src + 76));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src94, weight04);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
__m256 src65 = _mm256_set1_ps(*(src + 53));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
|
||||
__m256 src75 = _mm256_set1_ps(*(src + 61));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
|
||||
__m256 src85 = _mm256_set1_ps(*(src + 69));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
|
||||
__m256 src95 = _mm256_set1_ps(*(src + 77));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src95, weight05);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
__m256 src66 = _mm256_set1_ps(*(src + 54));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
|
||||
__m256 src76 = _mm256_set1_ps(*(src + 62));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
|
||||
__m256 src86 = _mm256_set1_ps(*(src + 70));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
|
||||
__m256 src96 = _mm256_set1_ps(*(src + 78));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src96, weight06);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
__m256 src67 = _mm256_set1_ps(*(src + 55));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
|
||||
__m256 src77 = _mm256_set1_ps(*(src + 63));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
|
||||
__m256 src87 = _mm256_set1_ps(*(src + 71));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
|
||||
__m256 src97 = _mm256_set1_ps(*(src + 79));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src97, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
dst9 = _mm256_min_ps(dst9, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
_mm256_store_ps(dst + 0 * dst_stride + 24, dst3);
_mm256_store_ps(dst + 0 * dst_stride + 32, dst4);
_mm256_store_ps(dst + 0 * dst_stride + 40, dst5);
_mm256_store_ps(dst + 0 * dst_stride + 48, dst6);
_mm256_store_ps(dst + 0 * dst_stride + 56, dst7);
_mm256_store_ps(dst + 0 * dst_stride + 64, dst8);
_mm256_store_ps(dst + 0 * dst_stride + 72, dst9);
|
||||
}
@@ -0,0 +1,303 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
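// deep is unrolled by 8 per asm loop pass; << 2 converts the float-element strides to byte strides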
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"vmovups 192(%[dst]), %%ymm6\n"
|
||||
"vmovups 224(%[dst]), %%ymm7\n"
|
||||
"vmovups 256(%[dst]), %%ymm8\n"
|
||||
"vmovups 288(%[dst]), %%ymm9\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"vmovaps 0(%[bias]), %%ymm6\n"
|
||||
"vmovaps 0(%[bias]), %%ymm7\n"
|
||||
"vmovaps 0(%[bias]), %%ymm8\n"
|
||||
"vmovaps 0(%[bias]), %%ymm9\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 192(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 224(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 256(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 288(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 193(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 225(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 257(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 289(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 194(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 226(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 258(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 290(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 195(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 227(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 259(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 291(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 196(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 228(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 260(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 292(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 197(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 229(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 261(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 293(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 198(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 230(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 262(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 294(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 199(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 231(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 263(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 295(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
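// post-loop: activation runs only when inc_flag bit 1 is set (presumably the final accumulation pass) and act_flag requests ReLU/ReLU6.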
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
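// 0x40C00000 is the bit pattern of 6.0f; vmovd + vpermps (ymm15 is still zero) broadcast it across ymm14.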
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"vminps %%ymm9, %%ymm14, %%ymm9\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
"vmovups %%ymm6, 192(%[dst])\n"
|
||||
"vmovups %%ymm7, 224(%[dst])\n"
|
||||
"vmovups %%ymm8, 256(%[dst])\n"
|
||||
"vmovups %%ymm9, 288(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,321 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
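// 11x8 tile kept in dst0..dst10, one __m256 (8 fp32) per output row.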
void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
__m256 dst5;
|
||||
__m256 dst6;
|
||||
__m256 dst7;
|
||||
__m256 dst8;
|
||||
__m256 dst9;
|
||||
__m256 dst10;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
|
||||
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
|
||||
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
|
||||
dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
|
||||
dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
dst9 = _mm256_setzero_ps();
|
||||
dst10 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 0);
|
||||
dst8 = _mm256_load_ps(bias + 0);
|
||||
dst9 = _mm256_load_ps(bias + 0);
|
||||
dst10 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
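// each pass consumes 8 consecutive deep elements (block 0..block 7 below); src then advances by src_stride floats and weight by a fixed step.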
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
__m256 src60 = _mm256_set1_ps(*(src + 48));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
|
||||
__m256 src70 = _mm256_set1_ps(*(src + 56));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
|
||||
__m256 src80 = _mm256_set1_ps(*(src + 64));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
|
||||
__m256 src90 = _mm256_set1_ps(*(src + 72));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
|
||||
__m256 src100 = _mm256_set1_ps(*(src + 80));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src100, weight00);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
__m256 src61 = _mm256_set1_ps(*(src + 49));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
|
||||
__m256 src71 = _mm256_set1_ps(*(src + 57));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
|
||||
__m256 src81 = _mm256_set1_ps(*(src + 65));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
|
||||
__m256 src91 = _mm256_set1_ps(*(src + 73));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src91, weight01);
|
||||
__m256 src101 = _mm256_set1_ps(*(src + 81));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src101, weight01);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
__m256 src62 = _mm256_set1_ps(*(src + 50));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
|
||||
__m256 src72 = _mm256_set1_ps(*(src + 58));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
|
||||
__m256 src82 = _mm256_set1_ps(*(src + 66));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
|
||||
__m256 src92 = _mm256_set1_ps(*(src + 74));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src92, weight02);
|
||||
__m256 src102 = _mm256_set1_ps(*(src + 82));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src102, weight02);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
__m256 src63 = _mm256_set1_ps(*(src + 51));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
|
||||
__m256 src73 = _mm256_set1_ps(*(src + 59));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
|
||||
__m256 src83 = _mm256_set1_ps(*(src + 67));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
|
||||
__m256 src93 = _mm256_set1_ps(*(src + 75));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src93, weight03);
|
||||
__m256 src103 = _mm256_set1_ps(*(src + 83));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src103, weight03);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
__m256 src64 = _mm256_set1_ps(*(src + 52));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
|
||||
__m256 src74 = _mm256_set1_ps(*(src + 60));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
|
||||
__m256 src84 = _mm256_set1_ps(*(src + 68));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
|
||||
__m256 src94 = _mm256_set1_ps(*(src + 76));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src94, weight04);
|
||||
__m256 src104 = _mm256_set1_ps(*(src + 84));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src104, weight04);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
__m256 src65 = _mm256_set1_ps(*(src + 53));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
|
||||
__m256 src75 = _mm256_set1_ps(*(src + 61));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
|
||||
__m256 src85 = _mm256_set1_ps(*(src + 69));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
|
||||
__m256 src95 = _mm256_set1_ps(*(src + 77));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src95, weight05);
|
||||
__m256 src105 = _mm256_set1_ps(*(src + 85));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src105, weight05);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
__m256 src66 = _mm256_set1_ps(*(src + 54));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
|
||||
__m256 src76 = _mm256_set1_ps(*(src + 62));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
|
||||
__m256 src86 = _mm256_set1_ps(*(src + 70));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
|
||||
__m256 src96 = _mm256_set1_ps(*(src + 78));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src96, weight06);
|
||||
__m256 src106 = _mm256_set1_ps(*(src + 86));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src106, weight06);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
__m256 src67 = _mm256_set1_ps(*(src + 55));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
|
||||
__m256 src77 = _mm256_set1_ps(*(src + 63));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
|
||||
__m256 src87 = _mm256_set1_ps(*(src + 71));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
|
||||
__m256 src97 = _mm256_set1_ps(*(src + 79));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src97, weight07);
|
||||
__m256 src107 = _mm256_set1_ps(*(src + 87));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src107, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
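// act_flag bit 1 selects ReLU6 (clamp to [0, 6]); bit 0 selects plain ReLU.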
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
dst9 = _mm256_min_ps(dst9, relu6);
|
||||
dst10 = _mm256_min_ps(dst10, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 64, dst8);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 72, dst9);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 80, dst10);
|
||||
}
|
|
@@ -0,0 +1,325 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"vmovups 192(%[dst]), %%ymm6\n"
|
||||
"vmovups 224(%[dst]), %%ymm7\n"
|
||||
"vmovups 256(%[dst]), %%ymm8\n"
|
||||
"vmovups 288(%[dst]), %%ymm9\n"
|
||||
"vmovups 320(%[dst]), %%ymm10\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"vmovaps 0(%[bias]), %%ymm6\n"
|
||||
"vmovaps 0(%[bias]), %%ymm7\n"
|
||||
"vmovaps 0(%[bias]), %%ymm8\n"
|
||||
"vmovaps 0(%[bias]), %%ymm9\n"
|
||||
"vmovaps 0(%[bias]), %%ymm10\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
|
||||
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 192(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 224(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 256(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 288(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 320(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 193(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 225(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 257(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 289(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 321(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 194(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 226(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 258(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 290(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 322(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 195(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 227(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 259(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 291(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 323(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 196(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 228(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 260(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 292(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 324(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 197(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 229(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 261(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 293(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 325(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 198(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 230(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 262(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 294(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 326(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 199(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 231(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 263(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 295(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 327(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
|
||||
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"vminps %%ymm9, %%ymm14, %%ymm9\n"
|
||||
"vminps %%ymm10, %%ymm14, %%ymm10\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
"vmovups %%ymm6, 192(%[dst])\n"
|
||||
"vmovups %%ymm7, 224(%[dst])\n"
|
||||
"vmovups %%ymm8, 256(%[dst])\n"
|
||||
"vmovups %%ymm9, 288(%[dst])\n"
|
||||
"vmovups %%ymm10, 320(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,345 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
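// 12-row variant: accumulators dst0..dst11, otherwise the same structure as the 11x8 kernel above.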
void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
__m256 dst5;
|
||||
__m256 dst6;
|
||||
__m256 dst7;
|
||||
__m256 dst8;
|
||||
__m256 dst9;
|
||||
__m256 dst10;
|
||||
__m256 dst11;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
|
||||
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
|
||||
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
|
||||
dst9 = _mm256_load_ps(dst + 0 * dst_stride + 72);
|
||||
dst10 = _mm256_load_ps(dst + 0 * dst_stride + 80);
|
||||
dst11 = _mm256_load_ps(dst + 0 * dst_stride + 88);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
dst9 = _mm256_setzero_ps();
|
||||
dst10 = _mm256_setzero_ps();
|
||||
dst11 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 0);
|
||||
dst8 = _mm256_load_ps(bias + 0);
|
||||
dst9 = _mm256_load_ps(bias + 0);
|
||||
dst10 = _mm256_load_ps(bias + 0);
|
||||
dst11 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
__m256 src60 = _mm256_set1_ps(*(src + 48));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
|
||||
__m256 src70 = _mm256_set1_ps(*(src + 56));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
|
||||
__m256 src80 = _mm256_set1_ps(*(src + 64));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
|
||||
__m256 src90 = _mm256_set1_ps(*(src + 72));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src90, weight00);
|
||||
__m256 src100 = _mm256_set1_ps(*(src + 80));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src100, weight00);
|
||||
__m256 src110 = _mm256_set1_ps(*(src + 88));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src110, weight00);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
__m256 src61 = _mm256_set1_ps(*(src + 49));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
|
||||
__m256 src71 = _mm256_set1_ps(*(src + 57));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
|
||||
__m256 src81 = _mm256_set1_ps(*(src + 65));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
|
||||
__m256 src91 = _mm256_set1_ps(*(src + 73));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src91, weight01);
|
||||
__m256 src101 = _mm256_set1_ps(*(src + 81));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src101, weight01);
|
||||
__m256 src111 = _mm256_set1_ps(*(src + 89));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src111, weight01);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
__m256 src62 = _mm256_set1_ps(*(src + 50));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
|
||||
__m256 src72 = _mm256_set1_ps(*(src + 58));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
|
||||
__m256 src82 = _mm256_set1_ps(*(src + 66));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
|
||||
__m256 src92 = _mm256_set1_ps(*(src + 74));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src92, weight02);
|
||||
__m256 src102 = _mm256_set1_ps(*(src + 82));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src102, weight02);
|
||||
__m256 src112 = _mm256_set1_ps(*(src + 90));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src112, weight02);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
__m256 src63 = _mm256_set1_ps(*(src + 51));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
|
||||
__m256 src73 = _mm256_set1_ps(*(src + 59));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
|
||||
__m256 src83 = _mm256_set1_ps(*(src + 67));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
|
||||
__m256 src93 = _mm256_set1_ps(*(src + 75));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src93, weight03);
|
||||
__m256 src103 = _mm256_set1_ps(*(src + 83));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src103, weight03);
|
||||
__m256 src113 = _mm256_set1_ps(*(src + 91));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src113, weight03);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
__m256 src64 = _mm256_set1_ps(*(src + 52));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
|
||||
__m256 src74 = _mm256_set1_ps(*(src + 60));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
|
||||
__m256 src84 = _mm256_set1_ps(*(src + 68));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
|
||||
__m256 src94 = _mm256_set1_ps(*(src + 76));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src94, weight04);
|
||||
__m256 src104 = _mm256_set1_ps(*(src + 84));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src104, weight04);
|
||||
__m256 src114 = _mm256_set1_ps(*(src + 92));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src114, weight04);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
__m256 src65 = _mm256_set1_ps(*(src + 53));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
|
||||
__m256 src75 = _mm256_set1_ps(*(src + 61));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
|
||||
__m256 src85 = _mm256_set1_ps(*(src + 69));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
|
||||
__m256 src95 = _mm256_set1_ps(*(src + 77));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src95, weight05);
|
||||
__m256 src105 = _mm256_set1_ps(*(src + 85));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src105, weight05);
|
||||
__m256 src115 = _mm256_set1_ps(*(src + 93));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src115, weight05);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
__m256 src66 = _mm256_set1_ps(*(src + 54));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
|
||||
__m256 src76 = _mm256_set1_ps(*(src + 62));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
|
||||
__m256 src86 = _mm256_set1_ps(*(src + 70));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
|
||||
__m256 src96 = _mm256_set1_ps(*(src + 78));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src96, weight06);
|
||||
__m256 src106 = _mm256_set1_ps(*(src + 86));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src106, weight06);
|
||||
__m256 src116 = _mm256_set1_ps(*(src + 94));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src116, weight06);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
__m256 src67 = _mm256_set1_ps(*(src + 55));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
|
||||
__m256 src77 = _mm256_set1_ps(*(src + 63));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
|
||||
__m256 src87 = _mm256_set1_ps(*(src + 71));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
|
||||
__m256 src97 = _mm256_set1_ps(*(src + 79));
|
||||
dst9 = _mm256_fmadd_ps(dst9, src97, weight07);
|
||||
__m256 src107 = _mm256_set1_ps(*(src + 87));
|
||||
dst10 = _mm256_fmadd_ps(dst10, src107, weight07);
|
||||
__m256 src117 = _mm256_set1_ps(*(src + 95));
|
||||
dst11 = _mm256_fmadd_ps(dst11, src117, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
dst9 = _mm256_min_ps(dst9, relu6);
|
||||
dst10 = _mm256_min_ps(dst10, relu6);
|
||||
dst11 = _mm256_min_ps(dst11, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 64, dst8);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 72, dst9);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 80, dst10);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 88, dst11);
|
||||
}
|
|
@@ -0,0 +1,347 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
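// asm variant of the 12x8 kernel: ymm0-ymm11 are the row accumulators, ymm12-ymm15 temporaries.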
void nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"vmovups 192(%[dst]), %%ymm6\n"
|
||||
"vmovups 224(%[dst]), %%ymm7\n"
|
||||
"vmovups 256(%[dst]), %%ymm8\n"
|
||||
"vmovups 288(%[dst]), %%ymm9\n"
|
||||
"vmovups 320(%[dst]), %%ymm10\n"
|
||||
"vmovups 352(%[dst]), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"vmovaps 0(%[bias]), %%ymm6\n"
|
||||
"vmovaps 0(%[bias]), %%ymm7\n"
|
||||
"vmovaps 0(%[bias]), %%ymm8\n"
|
||||
"vmovaps 0(%[bias]), %%ymm9\n"
|
||||
"vmovaps 0(%[bias]), %%ymm10\n"
|
||||
"vmovaps 0(%[bias]), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
|
||||
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
|
||||
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 192(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 224(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 256(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 288(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 320(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 352(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 193(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 225(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 257(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 289(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 321(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 353(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 194(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 226(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 258(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 290(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 322(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 354(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 195(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 227(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 259(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 291(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 323(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 355(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 196(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 228(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 260(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 292(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 324(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 356(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 197(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 229(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 261(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 293(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 325(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 357(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 198(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 230(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 262(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 294(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 326(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 358(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 199(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 231(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 263(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 295(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 327(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 359(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
|
||||
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
|
||||
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
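// 0x40C00000 is the IEEE-754 bit pattern of 6.0f; vpermps with the all-zero index in ymm15
// broadcasts it to every lane before the vminps clamp.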
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"vminps %%ymm9, %%ymm14, %%ymm9\n"
|
||||
"vminps %%ymm10, %%ymm14, %%ymm10\n"
|
||||
"vminps %%ymm11, %%ymm14, %%ymm11\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
"vmovups %%ymm6, 192(%[dst])\n"
|
||||
"vmovups %%ymm7, 224(%[dst])\n"
|
||||
"vmovups %%ymm8, 256(%[dst])\n"
|
||||
"vmovups %%ymm9, 288(%[dst])\n"
|
||||
"vmovups %%ymm10, 320(%[dst])\n"
|
||||
"vmovups %%ymm11, 352(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
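// When inc_flag is set the kernel reloads dst and keeps accumulating; otherwise it starts from
// bias, or from zero when no bias is given.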
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 8);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 16);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 24);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 40);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 56);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 64);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 72);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 88);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 104);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 112);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 120);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
|
||||
src = src + src_stride;
|
||||
weight += 512;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 0, dst1);
|
||||
}
|
|
@@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 32(%[bias]), %%ymm1\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
// block 1
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vmovaps 96(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
// block 2
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vmovaps 160(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
// block 3
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
// block 4
|
||||
"vmovaps 256(%[weight]), %%ymm15\n"
|
||||
"vmovaps 288(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
// block 5
|
||||
"vmovaps 320(%[weight]), %%ymm15\n"
|
||||
"vmovaps 352(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
// block 6
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
// block 7
|
||||
"vmovaps 448(%[weight]), %%ymm15\n"
|
||||
"vmovaps 480(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"dec %[deep]\n"
|
||||
"add 512, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 8);
|
||||
dst2 = _mm256_load_ps(bias + 16);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 weight20 = _mm256_load_ps(weight + 16);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src00, weight20);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 24);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight21 = _mm256_load_ps(weight + 40);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src01, weight21);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 56);
|
||||
__m256 weight22 = _mm256_load_ps(weight + 64);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src02, weight22);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 72);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight23 = _mm256_load_ps(weight + 88);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src03, weight23);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 104);
|
||||
__m256 weight24 = _mm256_load_ps(weight + 112);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src04, weight24);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 120);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 128);
|
||||
__m256 weight25 = _mm256_load_ps(weight + 136);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src05, weight25);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 144);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 152);
|
||||
__m256 weight26 = _mm256_load_ps(weight + 160);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src06, weight26);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 168);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 176);
|
||||
__m256 weight27 = _mm256_load_ps(weight + 184);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src07, weight27);
|
||||
src = src + src_stride;
|
||||
weight += 768;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 0, dst1);
|
||||
_mm256_store_ps(dst + 2 * src_stride + 0, dst2);
|
||||
}
|
|
@@ -0,0 +1,149 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 32(%[bias]), %%ymm1\n"
|
||||
"vmovaps 64(%[bias]), %%ymm2\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vmovaps 64(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
// block 1
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vmovaps 128(%[weight]), %%ymm14\n"
|
||||
"vmovaps 160(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
// block 2
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vmovaps 256(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
// block 3
|
||||
"vmovaps 288(%[weight]), %%ymm15\n"
|
||||
"vmovaps 320(%[weight]), %%ymm14\n"
|
||||
"vmovaps 352(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
// block 4
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vmovaps 448(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
// block 5
|
||||
"vmovaps 480(%[weight]), %%ymm15\n"
|
||||
"vmovaps 512(%[weight]), %%ymm14\n"
|
||||
"vmovaps 544(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
// block 6
|
||||
"vmovaps 576(%[weight]), %%ymm15\n"
|
||||
"vmovaps 608(%[weight]), %%ymm14\n"
|
||||
"vmovaps 640(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
// block 7
|
||||
"vmovaps 672(%[weight]), %%ymm15\n"
|
||||
"vmovaps 704(%[weight]), %%ymm14\n"
|
||||
"vmovaps 736(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"dec %[deep]\n"
|
||||
"add 768, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,153 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst2 = _mm256_load_ps(dst + 2 * dst_stride + 0);
|
||||
dst3 = _mm256_load_ps(dst + 3 * dst_stride + 0);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 8);
|
||||
dst2 = _mm256_load_ps(bias + 16);
|
||||
dst3 = _mm256_load_ps(bias + 24);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src00, weight10);
|
||||
__m256 weight20 = _mm256_load_ps(weight + 16);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src00, weight20);
|
||||
__m256 weight30 = _mm256_load_ps(weight + 24);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src00, weight30);
|
||||
// block1
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
__m256 weight01 = _mm256_load_ps(weight + 32);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 40);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src01, weight11);
|
||||
__m256 weight21 = _mm256_load_ps(weight + 48);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src01, weight21);
|
||||
__m256 weight31 = _mm256_load_ps(weight + 56);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src01, weight31);
|
||||
// block2
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
__m256 weight02 = _mm256_load_ps(weight + 64);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 72);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src02, weight12);
|
||||
__m256 weight22 = _mm256_load_ps(weight + 80);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src02, weight22);
|
||||
__m256 weight32 = _mm256_load_ps(weight + 88);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src02, weight32);
|
||||
// block3
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
__m256 weight03 = _mm256_load_ps(weight + 96);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 104);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src03, weight13);
|
||||
__m256 weight23 = _mm256_load_ps(weight + 112);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src03, weight23);
|
||||
__m256 weight33 = _mm256_load_ps(weight + 120);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src03, weight33);
|
||||
// block4
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
__m256 weight04 = _mm256_load_ps(weight + 128);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 136);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src04, weight14);
|
||||
__m256 weight24 = _mm256_load_ps(weight + 144);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src04, weight24);
|
||||
__m256 weight34 = _mm256_load_ps(weight + 152);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src04, weight34);
|
||||
// block5
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
__m256 weight05 = _mm256_load_ps(weight + 160);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 168);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src05, weight15);
|
||||
__m256 weight25 = _mm256_load_ps(weight + 176);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src05, weight25);
|
||||
__m256 weight35 = _mm256_load_ps(weight + 184);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src05, weight35);
|
||||
// block6
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
__m256 weight06 = _mm256_load_ps(weight + 192);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 200);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src06, weight16);
|
||||
__m256 weight26 = _mm256_load_ps(weight + 208);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src06, weight26);
|
||||
__m256 weight36 = _mm256_load_ps(weight + 216);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src06, weight36);
|
||||
// block7
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
__m256 weight07 = _mm256_load_ps(weight + 224);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 232);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src07, weight17);
|
||||
__m256 weight27 = _mm256_load_ps(weight + 240);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src07, weight27);
|
||||
__m256 weight37 = _mm256_load_ps(weight + 248);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src07, weight37);
|
||||
src = src + src_stride;
|
||||
weight += 1024;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 0, dst1);
|
||||
_mm256_store_ps(dst + 2 * src_stride + 0, dst2);
|
||||
_mm256_store_ps(dst + 3 * src_stride + 0, dst3);
|
||||
}
|
|
@@ -0,0 +1,174 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
const float *dst_4 = dst + 3 * dst_stride;
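// The first three 8-column blocks are reached with scaled addressing on dst_stride (scales 1 and 2);
// a scale of 3 is not encodable in x86 addressing, so the fourth block gets its own dst_4 base pointer.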
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm1\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm2\n"
|
||||
"vmovups 0(%[dst_4]), %%ymm3\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 32(%[bias]), %%ymm1\n"
|
||||
"vmovaps 64(%[bias]), %%ymm2\n"
|
||||
"vmovaps 96(%[bias]), %%ymm3\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
|
||||
[ dst_4 ] "r"(dst_4)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vbroadcastss 0(%[src]), %%ymm15\n"
|
||||
"vmovaps 0(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 64(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 96(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 1
|
||||
"vbroadcastss 1(%[src]), %%ymm15\n"
|
||||
"vmovaps 128(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 160(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 192(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 2
|
||||
"vbroadcastss 2(%[src]), %%ymm15\n"
|
||||
"vmovaps 256(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 288(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 320(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 352(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 3
|
||||
"vbroadcastss 3(%[src]), %%ymm15\n"
|
||||
"vmovaps 384(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 448(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 480(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 4
|
||||
"vbroadcastss 4(%[src]), %%ymm15\n"
|
||||
"vmovaps 512(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 544(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 576(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 608(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 5
|
||||
"vbroadcastss 5(%[src]), %%ymm15\n"
|
||||
"vmovaps 640(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 672(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 704(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 736(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 6
|
||||
"vbroadcastss 6(%[src]), %%ymm15\n"
|
||||
"vmovaps 768(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 800(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 832(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 864(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 7
|
||||
"vbroadcastss 7(%[src]), %%ymm15\n"
|
||||
"vmovaps 896(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 928(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 960(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm14, %%ymm15\n"
|
||||
"vmovaps 992(%[weight]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 1024, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm3, 0(%[dst_4])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
|
||||
[ dst_4 ] "r"(dst_4)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,81 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
}
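A minimal driver sketch for the 1x8 intrinsic kernel above, assuming the strides are counted in floats, deep is a multiple of 8, and every buffer is 32-byte aligned as the aligned AVX loads and stores require; the buffer sizes, fill values, and the row_block/col_block arguments (which this kernel body never reads) are illustrative assumptions, and the file is expected to be compiled together with the kernel source with FMA enabled.
#include <stdio.h>
#include <stdlib.h>

void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
                                           const size_t act_flag, const size_t row_block, const size_t col_block,
                                           const size_t deep, const size_t src_stride, const size_t dst_stride,
                                           const size_t inc_flag);

int main(void) {
  const size_t deep = 8;  // one unrolled block of eight deep steps
  float *dst = aligned_alloc(32, 8 * sizeof(float));
  float *src = aligned_alloc(32, 8 * sizeof(float));
  float *weight = aligned_alloc(32, 256 * sizeof(float));  // oversized so the trailing pointer bump stays in range
  float *bias = aligned_alloc(32, 8 * sizeof(float));
  if (dst == NULL || src == NULL || weight == NULL || bias == NULL) return 1;
  for (int i = 0; i < 8; ++i) {
    src[i] = 1.0f;
    bias[i] = 0.5f;
  }
  for (int i = 0; i < 256; ++i) {
    weight[i] = 0.01f;
  }
  // act_flag = 0: no activation; inc_flag = 0: start from bias instead of reloading dst.
  nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(dst, src, weight, bias, 0, 1, 8, deep, 8, 8, 0);
  for (int i = 0; i < 8; ++i) {
    printf("%f ", dst[i]);
  }
  printf("\n");
  free(dst);
  free(src);
  free(weight);
  free(bias);
  return 0;
}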
|
|
@@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
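The generated kernels all share one signature and differ only in the row x col tile they cover, so a caller can pick one from the tile size; the selector below is a hypothetical sketch (its name and the table idea are assumptions), listing only the four 1xN kernels that appear above and leaving the choice between the intrinsic and asm variants of each name to the build.
#include <stddef.h>

typedef void (*GemmFmaKernel)(float *dst, const float *src, const float *weight, const float *bias,
                              size_t act_flag, size_t row_block, size_t col_block, size_t deep,
                              size_t src_stride, size_t dst_stride, size_t inc_flag);

void nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *, size_t, size_t,
                                           size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *, size_t, size_t,
                                            size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *, size_t, size_t,
                                            size_t, size_t, size_t, size_t, size_t);
void nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32(float *, const float *, const float *, const float *, size_t, size_t,
                                            size_t, size_t, size_t, size_t, size_t);

// Hypothetical selector: return the single-row kernel whose column tile matches col_block.
static GemmFmaKernel SelectRow1Kernel(size_t col_block) {
  switch (col_block) {
    case 8:
      return nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32;
    case 16:
      return nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32;
    case 24:
      return nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32;
    case 32:
      return nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32;
    default:
      return NULL;
  }
}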
|
|
@@ -0,0 +1,145 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst2;
|
||||
__m256 dst1;
|
||||
__m256 dst3;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 8);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 8);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src00, weight10);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src10, weight10);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 16);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 24);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src01, weight11);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src11, weight11);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 40);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src02, weight12);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src12, weight12);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 56);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src03, weight13);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src13, weight13);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 64);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 72);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src04, weight14);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src14, weight14);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 88);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src05, weight15);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src15, weight15);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 104);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src06, weight16);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src16, weight16);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 112);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 120);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src07, weight17);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src17, weight17);
|
||||
src = src + src_stride;
|
||||
weight += 128;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 0, dst2);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 8, dst3);
|
||||
}
|
|
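Taken together, the loads and stores of the intrinsic kernel above imply a packing where src row r, depth d lives at src[r * 8 + d], the weights for depth d and output column c live at weight[d * 16 + c], and the two 8-wide output column blocks are dst_stride floats apart. The scalar sketch below is an illustrative reconstruction of that indexing, assuming the conventional dst += src * weight accumulation of a GEMM micro-kernel and deep a multiple of 8; it is not code from this commit.

#include <stddef.h>

static void gemm_2x16_scalar_sketch(float *dst, const float *src, const float *weight,
                                    size_t deep, size_t src_stride, size_t dst_stride) {
  for (size_t k = 0; k < deep; k += 8, src += src_stride, weight += 128) {
    for (size_t d = 0; d < 8; ++d) {      /* the eight unrolled "block d" steps */
      for (size_t r = 0; r < 2; ++r) {    /* row block of 2 */
        for (size_t c = 0; c < 16; ++c) { /* column block of 16 = 2 x 8 floats */
          dst[(c / 8) * dst_stride + r * 8 + c % 8] += src[r * 8 + d] * weight[d * 16 + c];
        }
      }
    }
  }
}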
@@ -0,0 +1,163 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 32(%[bias]), %%ymm2\n"
|
||||
"vmovaps 32(%[bias]), %%ymm3\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
// block 1
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vmovaps 96(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
// block 2
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vmovaps 160(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
// block 3
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
// block 4
|
||||
"vmovaps 256(%[weight]), %%ymm15\n"
|
||||
"vmovaps 288(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
// block 5
|
||||
"vmovaps 320(%[weight]), %%ymm15\n"
|
||||
"vmovaps 352(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
// block 6
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
// block 7
|
||||
"vmovaps 448(%[weight]), %%ymm15\n"
|
||||
"vmovaps 480(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"dec %[deep]\n"
|
||||
"add 512, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
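Both variants of the 2x16 kernel take deep, src_stride, and dst_stride in float elements; the asm version converts the strides to byte strides internally (dst_stride << 2, src_stride << 2). The minimal harness below is an illustrative calling sketch, not part of the commit: buffer contents, packing, and parameter values are assumptions, and weight and bias are 32-byte aligned because the kernel reads them with vmovaps / _mm256_load_ps.

#include <stdio.h>
#include <stdlib.h>

/* prototype of the generated kernel, repeated here so the sketch is self-contained */
void nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
                                            const float *bias, const size_t act_flag,
                                            const size_t row_block, const size_t col_block,
                                            const size_t deep, const size_t src_stride,
                                            const size_t dst_stride, const size_t inc_flag);

int main(void) {
  const size_t deep = 8, src_stride = 16, dst_stride = 16;        /* strides in floats */
  float *weight = aligned_alloc(32, 128 * sizeof(float));         /* 8 deep x 16 columns */
  float *bias = aligned_alloc(32, 16 * sizeof(float));
  float *src = aligned_alloc(32, 16 * sizeof(float));             /* 2 rows x 8 deep */
  float *dst = aligned_alloc(32, 32 * sizeof(float));             /* 2 column blocks x 2 rows x 8 */
  if (weight == NULL || bias == NULL || src == NULL || dst == NULL) return 1;
  for (int i = 0; i < 128; ++i) weight[i] = 0.01f * i;
  for (int i = 0; i < 16; ++i) { bias[i] = 0.0f; src[i] = 1.0f; }
  for (int i = 0; i < 32; ++i) dst[i] = 0.0f;
  /* act_flag = 0: no activation; inc_flag = 0: start from bias rather than accumulate */
  nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32(dst, src, weight, bias, 0, 2, 16, deep, src_stride,
                                         dst_stride, 0);
  printf("dst[0] = %f\n", dst[0]);
  free(weight); free(bias); free(src); free(dst);
  return 0;
}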
@@ -0,0 +1,185 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst2;
|
||||
__m256 dst4;
|
||||
__m256 dst1;
|
||||
__m256 dst3;
|
||||
__m256 dst5;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 8);
|
||||
dst4 = _mm256_load_ps(bias + 16);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 8);
|
||||
dst5 = _mm256_load_ps(bias + 16);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 weight20 = _mm256_load_ps(weight + 16);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src00, weight10);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src00, weight20);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src10, weight10);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src10, weight20);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 24);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight21 = _mm256_load_ps(weight + 40);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src01, weight11);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src01, weight21);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src11, weight11);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src11, weight21);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 56);
|
||||
__m256 weight22 = _mm256_load_ps(weight + 64);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src02, weight12);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src02, weight22);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src12, weight12);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src12, weight22);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 72);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight23 = _mm256_load_ps(weight + 88);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src03, weight13);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src03, weight23);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src13, weight13);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src13, weight23);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 104);
|
||||
__m256 weight24 = _mm256_load_ps(weight + 112);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src04, weight14);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src04, weight24);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src14, weight14);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src14, weight24);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 120);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 128);
|
||||
__m256 weight25 = _mm256_load_ps(weight + 136);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src05, weight15);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src05, weight25);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src15, weight15);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src15, weight25);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 144);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 152);
|
||||
__m256 weight26 = _mm256_load_ps(weight + 160);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src06, weight16);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src06, weight26);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src16, weight16);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src16, weight26);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 168);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 176);
|
||||
__m256 weight27 = _mm256_load_ps(weight + 184);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src07, weight17);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src07, weight27);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src17, weight17);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src17, weight27);
|
||||
src = src + src_stride;
|
||||
weight += 192;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 0, dst2);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 8, dst3);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 0, dst4);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 8, dst5);
|
||||
}
|
|
@@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 32(%[bias]), %%ymm2\n"
|
||||
"vmovaps 32(%[bias]), %%ymm3\n"
|
||||
"vmovaps 64(%[bias]), %%ymm4\n"
|
||||
"vmovaps 64(%[bias]), %%ymm5\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vmovaps 64(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
// block 1
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vmovaps 128(%[weight]), %%ymm14\n"
|
||||
"vmovaps 160(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
// block 2
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vmovaps 256(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
// block 3
|
||||
"vmovaps 288(%[weight]), %%ymm15\n"
|
||||
"vmovaps 320(%[weight]), %%ymm14\n"
|
||||
"vmovaps 352(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
// block 4
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vmovaps 448(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
// block 5
|
||||
"vmovaps 480(%[weight]), %%ymm15\n"
|
||||
"vmovaps 512(%[weight]), %%ymm14\n"
|
||||
"vmovaps 544(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
// block 6
|
||||
"vmovaps 576(%[weight]), %%ymm15\n"
|
||||
"vmovaps 608(%[weight]), %%ymm14\n"
|
||||
"vmovaps 640(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
// block 7
|
||||
"vmovaps 672(%[weight]), %%ymm15\n"
|
||||
"vmovaps 704(%[weight]), %%ymm14\n"
|
||||
"vmovaps 736(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"dec %[deep]\n"
|
||||
"add 768, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,225 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst2;
|
||||
__m256 dst4;
|
||||
__m256 dst6;
|
||||
__m256 dst1;
|
||||
__m256 dst3;
|
||||
__m256 dst5;
|
||||
__m256 dst7;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst2 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst4 = _mm256_load_ps(dst + 2 * dst_stride + 0);
|
||||
dst6 = _mm256_load_ps(dst + 3 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst5 = _mm256_load_ps(dst + 2 * dst_stride + 8);
|
||||
dst7 = _mm256_load_ps(dst + 3 * dst_stride + 8);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 8);
|
||||
dst4 = _mm256_load_ps(bias + 16);
|
||||
dst6 = _mm256_load_ps(bias + 24);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 8);
|
||||
dst5 = _mm256_load_ps(bias + 16);
|
||||
dst7 = _mm256_load_ps(bias + 24);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src00, weight10);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src10, weight10);
|
||||
__m256 weight20 = _mm256_load_ps(weight + 16);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src00, weight20);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src10, weight20);
|
||||
__m256 weight30 = _mm256_load_ps(weight + 24);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src00, weight30);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src10, weight30);
|
||||
// block1
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
__m256 weight01 = _mm256_load_ps(weight + 32);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 40);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src01, weight11);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src11, weight11);
|
||||
__m256 weight21 = _mm256_load_ps(weight + 48);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src01, weight21);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src11, weight21);
|
||||
__m256 weight31 = _mm256_load_ps(weight + 56);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src01, weight31);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src11, weight31);
|
||||
// block2
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
__m256 weight02 = _mm256_load_ps(weight + 64);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 72);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src02, weight12);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src12, weight12);
|
||||
__m256 weight22 = _mm256_load_ps(weight + 80);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src02, weight22);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src12, weight22);
|
||||
__m256 weight32 = _mm256_load_ps(weight + 88);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src02, weight32);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src12, weight32);
|
||||
// block3
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
__m256 weight03 = _mm256_load_ps(weight + 96);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 104);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src03, weight13);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src13, weight13);
|
||||
__m256 weight23 = _mm256_load_ps(weight + 112);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src03, weight23);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src13, weight23);
|
||||
__m256 weight33 = _mm256_load_ps(weight + 120);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src03, weight33);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src13, weight33);
|
||||
// block4
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
__m256 weight04 = _mm256_load_ps(weight + 128);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 136);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src04, weight14);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src14, weight14);
|
||||
__m256 weight24 = _mm256_load_ps(weight + 144);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src04, weight24);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src14, weight24);
|
||||
__m256 weight34 = _mm256_load_ps(weight + 152);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src04, weight34);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src14, weight34);
|
||||
// block5
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
__m256 weight05 = _mm256_load_ps(weight + 160);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 168);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src05, weight15);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src15, weight15);
|
||||
__m256 weight25 = _mm256_load_ps(weight + 176);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src05, weight25);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src15, weight25);
|
||||
__m256 weight35 = _mm256_load_ps(weight + 184);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src05, weight35);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src15, weight35);
|
||||
// block6
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
__m256 weight06 = _mm256_load_ps(weight + 192);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 200);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src06, weight16);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src16, weight16);
|
||||
__m256 weight26 = _mm256_load_ps(weight + 208);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src06, weight26);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src16, weight26);
|
||||
__m256 weight36 = _mm256_load_ps(weight + 216);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src06, weight36);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src16, weight36);
|
||||
// block7
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
__m256 weight07 = _mm256_load_ps(weight + 224);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 232);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src07, weight17);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src17, weight17);
|
||||
__m256 weight27 = _mm256_load_ps(weight + 240);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src07, weight27);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src17, weight27);
|
||||
__m256 weight37 = _mm256_load_ps(weight + 248);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src07, weight37);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src17, weight37);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 0, dst2);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 8, dst3);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 0, dst4);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 8, dst5);
|
||||
_mm256_store_ps(dst + 3 * dst_stride + 0, dst6);
|
||||
_mm256_store_ps(dst + 3 * dst_stride + 8, dst7);
|
||||
}
|
|
@@ -0,0 +1,238 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
const float *dst_4 = dst + 3 * dst_stride;
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm2\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm3\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm4\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm5\n"
|
||||
"vmovups 0(%[dst_4]), %%ymm6\n"
|
||||
"vmovups 32(%[dst_4]), %%ymm7\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 32(%[bias]), %%ymm2\n"
|
||||
"vmovaps 32(%[bias]), %%ymm3\n"
|
||||
"vmovaps 64(%[bias]), %%ymm4\n"
|
||||
"vmovaps 64(%[bias]), %%ymm5\n"
|
||||
"vmovaps 96(%[bias]), %%ymm6\n"
|
||||
"vmovaps 96(%[bias]), %%ymm7\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
|
||||
[ dst_4 ] "r"(dst_4)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vbroadcastss 0(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm14\n"
|
||||
"vmovaps 0(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 32(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 64(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 96(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
// block 1
|
||||
"vbroadcastss 1(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm14\n"
|
||||
"vmovaps 128(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 160(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 192(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 224(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
// block 2
|
||||
"vbroadcastss 2(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm14\n"
|
||||
"vmovaps 256(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 288(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 320(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 352(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
// block 3
|
||||
"vbroadcastss 3(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm14\n"
|
||||
"vmovaps 384(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 416(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 448(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 480(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
// block 4
|
||||
"vbroadcastss 4(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm14\n"
|
||||
"vmovaps 512(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 544(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 576(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 608(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
// block 5
|
||||
"vbroadcastss 5(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm14\n"
|
||||
"vmovaps 640(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 672(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 704(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 736(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
// block 6
|
||||
"vbroadcastss 6(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm14\n"
|
||||
"vmovaps 768(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 800(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 832(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 864(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
// block 7
|
||||
"vbroadcastss 7(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm14\n"
|
||||
"vmovaps 896(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 928(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 960(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vmovaps 992(%[weight]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"dec %[deep]\n"
|
||||
"add 1024, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm3, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm6, 0(%[dst_4])\n"
|
||||
"vmovups %%ymm7, 32(%[dst_4])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
|
||||
[ dst_4 ] "r"(dst_4)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,105 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 64;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
}
|
|
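As a sanity check on the pointer advance of the 2x8 variant: one pass over the eight unrolled blocks consumes 8 x 8 = 64 weight floats, i.e. 256 bytes, which matches the $256 the asm twin adds to %[weight]. The compile-time restatement below is illustrative only and not part of the generated file.

#include <assert.h>
#include <stddef.h>
#include <stdio.h>

enum { COL_BLOCK = 8, DEEP_UNROLL = 8 };
static_assert(COL_BLOCK * DEEP_UNROLL == 64, "weight floats consumed per 8-deep iteration");
static_assert(COL_BLOCK * DEEP_UNROLL * sizeof(float) == 256, "weight bytes per 8-deep iteration");

int main(void) {
  printf("weight advance per iteration: %d floats (%zu bytes)\n",
         COL_BLOCK * DEEP_UNROLL, (size_t)(COL_BLOCK * DEEP_UNROLL) * sizeof(float));
  return 0;
}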
@@ -0,0 +1,127 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,185 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
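// 3x16 tile: 3 rows of the row block by 16 output channels, held as two
// 8-float __m256 accumulators per row (dst0..dst5). The main loop consumes
// 8 depth elements per iteration (deep >> 3 passes). inc_flag selects the
// accumulator initialisation (reload dst / bias / zero) and act_flag bit 0
// requests ReLU, bit 1 ReLU6, before the tile is stored back to dst.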
|
||||
void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst3;
|
||||
__m256 dst1;
|
||||
__m256 dst4;
|
||||
__m256 dst2;
|
||||
__m256 dst5;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 8);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 8);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 8);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
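// Within one 8-deep slice, row r / depth d of the packed src lives at
// src[r * 8 + d]; after the slice, src advances by src_stride floats.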
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src00, weight10);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src10, weight10);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src20, weight10);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 16);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 24);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src01, weight11);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src11, weight11);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src21, weight11);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 40);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src02, weight12);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src12, weight12);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src22, weight12);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 56);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src03, weight13);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src13, weight13);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src23, weight13);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 64);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 72);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src04, weight14);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src14, weight14);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src24, weight14);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 88);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src05, weight15);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src15, weight15);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src25, weight15);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 104);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src06, weight16);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src16, weight16);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src26, weight16);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 112);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 120);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src07, weight17);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src17, weight17);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src27, weight17);
|
||||
src = src + src_stride;
|
||||
weight += 512;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
_mm256_store_ps(dst + 1 * dst_stride + 0, dst3);
_mm256_store_ps(dst + 1 * dst_stride + 8, dst4);
_mm256_store_ps(dst + 1 * dst_stride + 16, dst5);
|
||||
}
|
|
@@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
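// Inline-assembly variant of the 3x16 tile: ymm0..ymm5 carry the six
// accumulators, the first asm block initialises them (dst / bias / zero
// according to inc_flag), and the second runs the 8-way unrolled depth loop
// and the optional ReLU/ReLU6 epilogue before writing the tile back to dst.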
|
||||
void nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 32(%[bias]), %%ymm3\n"
|
||||
"vmovaps 32(%[bias]), %%ymm4\n"
|
||||
"vmovaps 32(%[bias]), %%ymm5\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
// block 1
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vmovaps 96(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
// block 2
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vmovaps 160(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
// block 3
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
// block 4
|
||||
"vmovaps 256(%[weight]), %%ymm15\n"
|
||||
"vmovaps 288(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
// block 5
|
||||
"vmovaps 320(%[weight]), %%ymm15\n"
|
||||
"vmovaps 352(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
// block 6
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
// block 7
|
||||
"vmovaps 448(%[weight]), %%ymm15\n"
|
||||
"vmovaps 480(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"dec %[deep]\n"
|
||||
"add 512, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,241 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
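// 3x24 tile: 3 rows by 24 output channels, i.e. three 8-float __m256
// accumulators per row (dst0..dst8); otherwise the structure matches the
// 3x16 intrinsic kernel (8-way unrolled depth loop, inc_flag-controlled
// initialisation, act_flag-controlled ReLU/ReLU6 epilogue).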
|
||||
void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst3;
|
||||
__m256 dst6;
|
||||
__m256 dst1;
|
||||
__m256 dst4;
|
||||
__m256 dst7;
|
||||
__m256 dst2;
|
||||
__m256 dst5;
|
||||
__m256 dst8;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
|
||||
dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 8);
|
||||
dst6 = _mm256_load_ps(bias + 16);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 8);
|
||||
dst7 = _mm256_load_ps(bias + 16);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 8);
|
||||
dst8 = _mm256_load_ps(bias + 16);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 weight20 = _mm256_load_ps(weight + 16);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src00, weight10);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src00, weight20);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src10, weight10);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src10, weight20);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src20, weight10);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src20, weight20);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 24);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight21 = _mm256_load_ps(weight + 40);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src01, weight11);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src01, weight21);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src11, weight11);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src11, weight21);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src21, weight11);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src21, weight21);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 56);
|
||||
__m256 weight22 = _mm256_load_ps(weight + 64);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src02, weight12);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src02, weight22);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src12, weight12);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src12, weight22);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src22, weight12);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src22, weight22);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 72);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight23 = _mm256_load_ps(weight + 88);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src03, weight13);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src03, weight23);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src13, weight13);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src13, weight23);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src23, weight13);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src23, weight23);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 104);
|
||||
__m256 weight24 = _mm256_load_ps(weight + 112);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src04, weight14);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src04, weight24);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src14, weight14);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src14, weight24);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src24, weight14);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src24, weight24);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 120);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 128);
|
||||
__m256 weight25 = _mm256_load_ps(weight + 136);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src05, weight15);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src05, weight25);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src15, weight15);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src15, weight25);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src25, weight15);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src25, weight25);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 144);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 152);
|
||||
__m256 weight26 = _mm256_load_ps(weight + 160);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src06, weight16);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src06, weight26);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src16, weight16);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src16, weight26);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src26, weight16);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src26, weight26);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 168);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 176);
|
||||
__m256 weight27 = _mm256_load_ps(weight + 184);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src07, weight17);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src07, weight27);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src17, weight17);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src17, weight27);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src27, weight17);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src27, weight27);
|
||||
src = src + src_stride;
|
||||
weight += 768;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
_mm256_store_ps(dst + 1 * dst_stride + 0, dst3);
_mm256_store_ps(dst + 1 * dst_stride + 8, dst4);
_mm256_store_ps(dst + 1 * dst_stride + 16, dst5);
_mm256_store_ps(dst + 2 * dst_stride + 0, dst6);
_mm256_store_ps(dst + 2 * dst_stride + 8, dst7);
_mm256_store_ps(dst + 2 * dst_stride + 16, dst8);
|
||||
}
|
|
@@ -0,0 +1,249 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
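// Assembly variant of the 3x24 tile: the nine accumulators live in
// ymm0..ymm8, and each depth step loads three 8-float weight vectors and
// applies them to one broadcast src element per row.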
|
||||
void nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 32(%[bias]), %%ymm3\n"
|
||||
"vmovaps 32(%[bias]), %%ymm4\n"
|
||||
"vmovaps 32(%[bias]), %%ymm5\n"
|
||||
"vmovaps 64(%[bias]), %%ymm6\n"
|
||||
"vmovaps 64(%[bias]), %%ymm7\n"
|
||||
"vmovaps 64(%[bias]), %%ymm8\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vmovaps 64(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
// block 1
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vmovaps 128(%[weight]), %%ymm14\n"
|
||||
"vmovaps 160(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
// block 2
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vmovaps 256(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
// block 3
|
||||
"vmovaps 288(%[weight]), %%ymm15\n"
|
||||
"vmovaps 320(%[weight]), %%ymm14\n"
|
||||
"vmovaps 352(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
// block 4
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vmovaps 448(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
// block 5
|
||||
"vmovaps 480(%[weight]), %%ymm15\n"
|
||||
"vmovaps 512(%[weight]), %%ymm14\n"
|
||||
"vmovaps 544(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
// block 6
|
||||
"vmovaps 576(%[weight]), %%ymm15\n"
|
||||
"vmovaps 608(%[weight]), %%ymm14\n"
|
||||
"vmovaps 640(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
// block 7
|
||||
"vmovaps 672(%[weight]), %%ymm15\n"
|
||||
"vmovaps 704(%[weight]), %%ymm14\n"
|
||||
"vmovaps 736(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"dec %[deep]\n"
|
||||
"add 768, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,297 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
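// 3x32 tile: 3 rows by 32 output channels, four 8-float __m256 accumulators
// per row (dst0..dst11); depth is again consumed 8 elements per iteration and
// the epilogue applies ReLU/ReLU6 according to act_flag before the stores.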
|
||||
void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst3;
|
||||
__m256 dst6;
|
||||
__m256 dst9;
|
||||
__m256 dst1;
|
||||
__m256 dst4;
|
||||
__m256 dst7;
|
||||
__m256 dst10;
|
||||
__m256 dst2;
|
||||
__m256 dst5;
|
||||
__m256 dst8;
|
||||
__m256 dst11;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst3 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst6 = _mm256_load_ps(dst + 2 * dst_stride + 0);
|
||||
dst9 = _mm256_load_ps(dst + 3 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst7 = _mm256_load_ps(dst + 2 * dst_stride + 8);
|
||||
dst10 = _mm256_load_ps(dst + 3 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 16);
|
||||
dst8 = _mm256_load_ps(dst + 2 * dst_stride + 16);
|
||||
dst11 = _mm256_load_ps(dst + 3 * dst_stride + 16);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
dst9 = _mm256_setzero_ps();
|
||||
dst10 = _mm256_setzero_ps();
|
||||
dst11 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 8);
|
||||
dst6 = _mm256_load_ps(bias + 16);
|
||||
dst9 = _mm256_load_ps(bias + 24);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 8);
|
||||
dst7 = _mm256_load_ps(bias + 16);
|
||||
dst10 = _mm256_load_ps(bias + 24);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 8);
|
||||
dst8 = _mm256_load_ps(bias + 16);
|
||||
dst11 = _mm256_load_ps(bias + 24);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block 0
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src00, weight10);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src10, weight10);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src20, weight10);
|
||||
__m256 weight20 = _mm256_load_ps(weight + 16);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src00, weight20);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src10, weight20);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src20, weight20);
|
||||
__m256 weight30 = _mm256_load_ps(weight + 24);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src00, weight30);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src10, weight30);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src20, weight30);
|
||||
// block 1
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
__m256 weight01 = _mm256_load_ps(weight + 32);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 40);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src01, weight11);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src11, weight11);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src21, weight11);
|
||||
__m256 weight21 = _mm256_load_ps(weight + 48);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src01, weight21);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src11, weight21);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src21, weight21);
|
||||
__m256 weight31 = _mm256_load_ps(weight + 56);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src01, weight31);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src11, weight31);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src21, weight31);
|
||||
// block 2
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
__m256 weight02 = _mm256_load_ps(weight + 64);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 72);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src02, weight12);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src12, weight12);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src22, weight12);
|
||||
__m256 weight22 = _mm256_load_ps(weight + 80);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src02, weight22);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src12, weight22);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src22, weight22);
|
||||
__m256 weight32 = _mm256_load_ps(weight + 88);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src02, weight32);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src12, weight32);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src22, weight32);
|
||||
// block 3
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
__m256 weight03 = _mm256_load_ps(weight + 96);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 104);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src03, weight13);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src13, weight13);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src23, weight13);
|
||||
__m256 weight23 = _mm256_load_ps(weight + 112);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src03, weight23);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src13, weight23);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src23, weight23);
|
||||
__m256 weight33 = _mm256_load_ps(weight + 120);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src03, weight33);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src13, weight33);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src23, weight33);
|
||||
// block 4
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
__m256 weight04 = _mm256_load_ps(weight + 128);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 136);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src04, weight14);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src14, weight14);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src24, weight14);
|
||||
__m256 weight24 = _mm256_load_ps(weight + 144);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src04, weight24);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src14, weight24);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src24, weight24);
|
||||
__m256 weight34 = _mm256_load_ps(weight + 152);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src04, weight34);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src14, weight34);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src24, weight34);
|
||||
// block 5
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
__m256 weight05 = _mm256_load_ps(weight + 160);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 168);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src05, weight15);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src15, weight15);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src25, weight15);
|
||||
__m256 weight25 = _mm256_load_ps(weight + 176);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src05, weight25);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src15, weight25);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src25, weight25);
|
||||
__m256 weight35 = _mm256_load_ps(weight + 184);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src05, weight35);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src15, weight35);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src25, weight35);
|
||||
// block 6
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
__m256 weight06 = _mm256_load_ps(weight + 192);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 200);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src06, weight16);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src16, weight16);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src26, weight16);
|
||||
__m256 weight26 = _mm256_load_ps(weight + 208);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src06, weight26);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src16, weight26);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src26, weight26);
|
||||
__m256 weight36 = _mm256_load_ps(weight + 216);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src06, weight36);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src16, weight36);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src26, weight36);
|
||||
// block 7
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
__m256 weight07 = _mm256_load_ps(weight + 224);
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 232);
|
||||
dst3 = _mm256_fmadd_ps(dst3, src07, weight17);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src17, weight17);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src27, weight17);
|
||||
__m256 weight27 = _mm256_load_ps(weight + 240);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src07, weight27);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src17, weight27);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src27, weight27);
|
||||
__m256 weight37 = _mm256_load_ps(weight + 248);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src07, weight37);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src17, weight37);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src27, weight37);
|
||||
src = src + src_stride;
|
||||
weight += 1024;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst9 = _mm256_min_ps(dst9, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst10 = _mm256_min_ps(dst10, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
dst11 = _mm256_min_ps(dst11, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
_mm256_store_ps(dst + 1 * dst_stride + 0, dst3);
_mm256_store_ps(dst + 1 * dst_stride + 8, dst4);
_mm256_store_ps(dst + 1 * dst_stride + 16, dst5);
_mm256_store_ps(dst + 2 * dst_stride + 0, dst6);
_mm256_store_ps(dst + 2 * dst_stride + 8, dst7);
_mm256_store_ps(dst + 2 * dst_stride + 16, dst8);
_mm256_store_ps(dst + 3 * dst_stride + 0, dst9);
_mm256_store_ps(dst + 3 * dst_stride + 8, dst10);
_mm256_store_ps(dst + 3 * dst_stride + 16, dst11);
|
||||
}
|
|
@@ -0,0 +1,302 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
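// Assembly variant of the 3x32 tile: ymm0..ymm11 hold the twelve accumulators.
// The fourth 8-channel output block is addressed through a separate pointer
// (dst_4 = dst + 3 * dst_stride) because 3 * dst_stride cannot be expressed
// with a single scaled-index addressing mode.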
|
||||
void nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
const float *dst_4 = dst + 3 * dst_stride;
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm3\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm4\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm5\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm6\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm7\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 2), %%ymm8\n"
|
||||
"vmovups 0(%[dst_4]), %%ymm9\n"
|
||||
"vmovups 32(%[dst_4]), %%ymm10\n"
|
||||
"vmovups 64(%[dst_4]), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 32(%[bias]), %%ymm3\n"
|
||||
"vmovaps 32(%[bias]), %%ymm4\n"
|
||||
"vmovaps 32(%[bias]), %%ymm5\n"
|
||||
"vmovaps 64(%[bias]), %%ymm6\n"
|
||||
"vmovaps 64(%[bias]), %%ymm7\n"
|
||||
"vmovaps 64(%[bias]), %%ymm8\n"
|
||||
"vmovaps 96(%[bias]), %%ymm9\n"
|
||||
"vmovaps 96(%[bias]), %%ymm10\n"
|
||||
"vmovaps 96(%[bias]), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
|
||||
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
|
||||
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag),
|
||||
[ dst_4 ] "r"(dst_4)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vbroadcastss 0(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm13\n"
|
||||
"vmovaps 0(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 32(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 64(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 96(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 1
|
||||
"vbroadcastss 1(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm13\n"
|
||||
"vmovaps 128(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 160(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 192(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 224(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 2
|
||||
"vbroadcastss 2(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm13\n"
|
||||
"vmovaps 256(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 288(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 320(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 352(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 3
|
||||
"vbroadcastss 3(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm13\n"
|
||||
"vmovaps 384(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 416(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 448(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 480(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 4
|
||||
"vbroadcastss 4(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm13\n"
|
||||
"vmovaps 512(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 544(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 576(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 608(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 5
|
||||
"vbroadcastss 5(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm13\n"
|
||||
"vmovaps 640(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 672(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 704(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 736(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 6
|
||||
"vbroadcastss 6(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm13\n"
|
||||
"vmovaps 768(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 800(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 832(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 864(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 7
|
||||
"vbroadcastss 7(%[src]), %%ymm15\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm13\n"
|
||||
"vmovaps 896(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 928(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 960(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vmovaps 992(%[weight]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
"dec %[deep]\n"
|
||||
"add 1024, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
|
||||
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
|
||||
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"vminps %%ymm9, %%ymm14, %%ymm9\n"
|
||||
"vminps %%ymm10, %%ymm14, %%ymm10\n"
|
||||
"vminps %%ymm11, %%ymm14, %%ymm11\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm4, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm5, 64(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm6, 0(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm7, 32(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm8, 64(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm9, 0(%[dst_4])\n"
|
||||
"vmovups %%ymm10, 32(%[dst_4])\n"
|
||||
"vmovups %%ymm11, 64(%[dst_4])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t),
|
||||
[ dst_4 ] "r"(dst_4)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,129 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
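// 3x8 output tile; act_flag bit 0 applies relu, bit 1 applies relu6, inc_flag non-zero accumulates onto the
// existing dst tile, bias may be NULL, and src_stride/dst_stride are given in float elements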
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
}
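// main loop: each pass consumes 8 depth elements (deep is assumed to be a multiple of 8) and advances src by src_stride floats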
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 64;  // advance by the 64 weight floats consumed above (8 depth steps x 8 columns)
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
|
||||
}
|
|
@@ -0,0 +1,149 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,225 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst4;
|
||||
__m256 dst1;
|
||||
__m256 dst5;
|
||||
__m256 dst2;
|
||||
__m256 dst6;
|
||||
__m256 dst3;
|
||||
__m256 dst7;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 8);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 8);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 8);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 8);
|
||||
}
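// accumulator layout: dst0-dst3 are the four rows of the first 8-channel column group, dst4-dst7 the four rows
// of the second group located dst_stride floats further on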
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src00, weight10);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src10, weight10);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src20, weight10);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src30, weight10);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 16);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 24);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src01, weight11);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src11, weight11);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src21, weight11);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src31, weight11);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 40);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src02, weight12);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src12, weight12);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src22, weight12);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src32, weight12);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 56);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src03, weight13);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src13, weight13);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src23, weight13);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src33, weight13);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 64);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 72);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src04, weight14);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src14, weight14);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src24, weight14);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src34, weight14);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 88);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src05, weight15);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src15, weight15);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src25, weight15);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src35, weight15);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 104);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src06, weight16);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src16, weight16);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src26, weight16);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src36, weight16);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 112);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 120);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src07, weight17);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src17, weight17);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src27, weight17);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src37, weight17);
|
||||
src = src + src_stride;
|
||||
weight += 128;  // advance by the 128 weight floats consumed above (8 depth steps x 16 columns)
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 0, dst4);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 8, dst5);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 16, dst6);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 24, dst7);
|
||||
}
|
|
@@ -0,0 +1,235 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n"
|
||||
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 32(%[bias]), %%ymm4\n"
|
||||
"vmovaps 32(%[bias]), %%ymm5\n"
|
||||
"vmovaps 32(%[bias]), %%ymm6\n"
|
||||
"vmovaps 32(%[bias]), %%ymm7\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
// block 1
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vmovaps 96(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
// block 2
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vmovaps 160(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
// block 3
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
// block 4
|
||||
"vmovaps 256(%[weight]), %%ymm15\n"
|
||||
"vmovaps 288(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
// block 5
|
||||
"vmovaps 320(%[weight]), %%ymm15\n"
|
||||
"vmovaps 352(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
// block 6
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
// block 7
|
||||
"vmovaps 448(%[weight]), %%ymm15\n"
|
||||
"vmovaps 480(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"dec %[deep]\n"
|
||||
"add 512, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,297 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst4;
|
||||
__m256 dst8;
|
||||
__m256 dst1;
|
||||
__m256 dst5;
|
||||
__m256 dst9;
|
||||
__m256 dst2;
|
||||
__m256 dst6;
|
||||
__m256 dst10;
|
||||
__m256 dst3;
|
||||
__m256 dst7;
|
||||
__m256 dst11;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst4 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst8 = _mm256_load_ps(dst + 2 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst9 = _mm256_load_ps(dst + 2 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 16);
|
||||
dst10 = _mm256_load_ps(dst + 2 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 24);
|
||||
dst11 = _mm256_load_ps(dst + 2 * dst_stride + 24);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
dst9 = _mm256_setzero_ps();
|
||||
dst10 = _mm256_setzero_ps();
|
||||
dst11 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 8);
|
||||
dst8 = _mm256_load_ps(bias + 16);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 8);
|
||||
dst9 = _mm256_load_ps(bias + 16);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 8);
|
||||
dst10 = _mm256_load_ps(bias + 16);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 8);
|
||||
dst11 = _mm256_load_ps(bias + 16);
|
||||
}
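// the 4x24 tile needs 12 accumulators (dst0-dst11): three 8-channel column groups, spaced dst_stride floats apart, of four rows each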
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 weight20 = _mm256_load_ps(weight + 16);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src00, weight10);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src00, weight20);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src10, weight10);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src10, weight20);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src20, weight10);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src20, weight20);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src30, weight10);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src30, weight20);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 24);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight21 = _mm256_load_ps(weight + 40);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src01, weight11);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src01, weight21);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src11, weight11);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src11, weight21);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src21, weight11);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src21, weight21);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src31, weight11);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src31, weight21);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 56);
|
||||
__m256 weight22 = _mm256_load_ps(weight + 64);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src02, weight12);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src02, weight22);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src12, weight12);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src12, weight22);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src22, weight12);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src22, weight22);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src32, weight12);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src32, weight22);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 72);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight23 = _mm256_load_ps(weight + 88);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src03, weight13);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src03, weight23);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src13, weight13);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src13, weight23);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src23, weight13);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src23, weight23);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src33, weight13);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src33, weight23);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 104);
|
||||
__m256 weight24 = _mm256_load_ps(weight + 112);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src04, weight14);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src04, weight24);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src14, weight14);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src14, weight24);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src24, weight14);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src24, weight24);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src34, weight14);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src34, weight24);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 120);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 128);
|
||||
__m256 weight25 = _mm256_load_ps(weight + 136);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src05, weight15);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src05, weight25);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src15, weight15);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src15, weight25);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src25, weight15);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src25, weight25);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src35, weight15);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src35, weight25);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 144);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 152);
|
||||
__m256 weight26 = _mm256_load_ps(weight + 160);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src06, weight16);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src06, weight26);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src16, weight16);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src16, weight26);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src26, weight16);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src26, weight26);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src36, weight16);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src36, weight26);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 168);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 176);
|
||||
__m256 weight27 = _mm256_load_ps(weight + 184);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst4 = _mm256_fmadd_ps(dst4, src07, weight17);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src07, weight27);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src17, weight17);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src17, weight27);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src27, weight17);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src27, weight27);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src37, weight17);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src37, weight27);
|
||||
src = src + src_stride;
|
||||
weight += 192;  // advance by the 192 weight floats consumed above (8 depth steps x 24 columns)
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst9 = _mm256_min_ps(dst9, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst10 = _mm256_min_ps(dst10, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst11 = _mm256_min_ps(dst11, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 0, dst4);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 8, dst5);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 16, dst6);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 24, dst7);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 0, dst8);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 8, dst9);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 16, dst10);
|
||||
_mm256_store_ps(dst + 2 * dst_stride + 24, dst11);
|
||||
}
|
|
@@ -0,0 +1,299 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm4\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm5\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm6\n"
|
||||
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm7\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 2), %%ymm8\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 2), %%ymm9\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 2), %%ymm10\n"
|
||||
"vmovups 96(%[dst], %[dst_stride], 2), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 32(%[bias]), %%ymm4\n"
|
||||
"vmovaps 32(%[bias]), %%ymm5\n"
|
||||
"vmovaps 32(%[bias]), %%ymm6\n"
|
||||
"vmovaps 32(%[bias]), %%ymm7\n"
|
||||
"vmovaps 64(%[bias]), %%ymm8\n"
|
||||
"vmovaps 64(%[bias]), %%ymm9\n"
|
||||
"vmovaps 64(%[bias]), %%ymm10\n"
|
||||
"vmovaps 64(%[bias]), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
|
||||
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
|
||||
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vmovaps 64(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 1
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vmovaps 128(%[weight]), %%ymm14\n"
|
||||
"vmovaps 160(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 2
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vmovaps 256(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 3
|
||||
"vmovaps 288(%[weight]), %%ymm15\n"
|
||||
"vmovaps 320(%[weight]), %%ymm14\n"
|
||||
"vmovaps 352(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 4
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vmovaps 448(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 5
|
||||
"vmovaps 480(%[weight]), %%ymm15\n"
|
||||
"vmovaps 512(%[weight]), %%ymm14\n"
|
||||
"vmovaps 544(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 6
|
||||
"vmovaps 576(%[weight]), %%ymm15\n"
|
||||
"vmovaps 608(%[weight]), %%ymm14\n"
|
||||
"vmovaps 640(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
// block 7
|
||||
"vmovaps 672(%[weight]), %%ymm15\n"
|
||||
"vmovaps 704(%[weight]), %%ymm14\n"
|
||||
"vmovaps 736(%[weight]), %%ymm13\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm12, %%ymm13\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm13\n"
|
||||
"dec %[deep]\n"
"add $768, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
|
||||
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
|
||||
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"vminps %%ymm9, %%ymm14, %%ymm9\n"
|
||||
"vminps %%ymm10, %%ymm14, %%ymm10\n"
|
||||
"vminps %%ymm11, %%ymm14, %%ymm11\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm5, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm6, 64(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm7, 96(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm8, 0(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm9, 32(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm10, 64(%[dst], %[dst_stride], 2)\n"
|
||||
"vmovups %%ymm11, 96(%[dst], %[dst_stride], 2)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
@@ -0,0 +1,153 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
}
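
// Illustrative usage sketch, not part of the generated kernel above: a hedged example of how
// this 4x8 intrinsic kernel might be invoked. The packed-layout assumptions (src_stride of 32
// floats, dst_stride of 8 floats), the flag values, and the example_call_4x8 wrapper itself are
// assumptions for illustration only, not MindSpore's calling convention.
#ifdef NNACL_GEMM_USAGE_SKETCH
static void example_call_4x8(float *dst, const float *src, const float *weight, const float *bias) {
  const size_t deep = 8;         // one 8-deep packed block, i.e. a single pass of the main loop
  const size_t src_stride = 32;  // assumed: 4 rows * 8 deep floats consumed per loop iteration
  const size_t dst_stride = 8;   // assumed: one 8-float output block per row
  const size_t act_flag = 0x01;  // bit 0x01 applies plain ReLU in the epilogue above
  const size_t inc_flag = 0;     // 0: seed the accumulators from bias instead of reloading dst
  nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(dst, src, weight, bias, act_flag, /*row_block=*/4,
                                        /*col_block=*/8, deep, src_stride, dst_stride, inc_flag);
}
#endif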
@@ -0,0 +1,171 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"dec %[deep]\n"
"add $256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
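
// Side note, illustration only (not generated code): the relu6 path in the asm kernel above
// materializes the constant 6.0f from its IEEE-754 bit pattern 0x40C00000 (mov to eax, vmovd,
// then a broadcast) before clamping with vminps. A minimal intrinsic equivalent of that
// broadcast, under the same bit-pattern assumption, is sketched below; the helper name is
// hypothetical.
#ifdef NNACL_GEMM_USAGE_SKETCH
static inline __m256 nnacl_broadcast_six(void) {
  return _mm256_castsi256_ps(_mm256_set1_epi32(0x40C00000));  // eight float lanes of 6.0f
}
#endif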
@@ -0,0 +1,265 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst5;
|
||||
__m256 dst1;
|
||||
__m256 dst6;
|
||||
__m256 dst2;
|
||||
__m256 dst7;
|
||||
__m256 dst3;
|
||||
__m256 dst8;
|
||||
__m256 dst4;
|
||||
__m256 dst9;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst5 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst8 = _mm256_load_ps(dst + 1 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst9 = _mm256_load_ps(dst + 1 * dst_stride + 32);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
dst9 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 8);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 8);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 8);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst8 = _mm256_load_ps(bias + 8);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst9 = _mm256_load_ps(bias + 8);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src00, weight10);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src10, weight10);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src20, weight10);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src30, weight10);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src40, weight10);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 16);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 24);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src01, weight11);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src11, weight11);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src21, weight11);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src31, weight11);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src41, weight11);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 40);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src02, weight12);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src12, weight12);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src22, weight12);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src32, weight12);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src42, weight12);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 56);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src03, weight13);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src13, weight13);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src23, weight13);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src33, weight13);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src43, weight13);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 64);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 72);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src04, weight14);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src14, weight14);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src24, weight14);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src34, weight14);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src44, weight14);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 88);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src05, weight15);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src15, weight15);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src25, weight15);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src35, weight15);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src45, weight15);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 104);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src06, weight16);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src16, weight16);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src26, weight16);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src36, weight16);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src46, weight16);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 112);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 120);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst5 = _mm256_fmadd_ps(dst5, src07, weight17);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src17, weight17);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src27, weight17);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src37, weight17);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src47, weight17);
|
||||
src = src + src_stride;
|
||||
weight += 512;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst9 = _mm256_min_ps(dst9, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 0, dst5);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 8, dst6);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 16, dst7);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 24, dst8);
|
||||
_mm256_store_ps(dst + 1 * src_stride + 32, dst9);
|
||||
}
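
// Reference sketch, illustration only (not generated code): a scalar model of the activation
// epilogue used by the intrinsic kernels in this commit, mirroring the flag handling visible
// above: bit 0x02 applies ReLU6 (clamp to [0, 6]) and bit 0x01 applies plain ReLU. The helper
// name is hypothetical.
#ifdef NNACL_GEMM_USAGE_SKETCH
static inline float nnacl_act_epilogue_scalar(float value, size_t act_flag) {
  if (act_flag & 0x02) {
    value = value > 6.0f ? 6.0f : value;  // relu6: upper clamp at 6
    value = value < 0.0f ? 0.0f : value;  // relu6: lower clamp at 0
  }
  if (act_flag & 0x01) {
    value = value < 0.0f ? 0.0f : value;  // plain relu
  }
  return value;
}
#endif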
@@ -0,0 +1,271 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm5\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm6\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm7\n"
|
||||
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm8\n"
|
||||
"vmovups 128(%[dst], %[dst_stride], 1), %%ymm9\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 32(%[bias]), %%ymm5\n"
|
||||
"vmovaps 32(%[bias]), %%ymm6\n"
|
||||
"vmovaps 32(%[bias]), %%ymm7\n"
|
||||
"vmovaps 32(%[bias]), %%ymm8\n"
|
||||
"vmovaps 32(%[bias]), %%ymm9\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
// block 1
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vmovaps 96(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
// block 2
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vmovaps 160(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
// block 3
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
// block 4
|
||||
"vmovaps 256(%[weight]), %%ymm15\n"
|
||||
"vmovaps 288(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
// block 5
|
||||
"vmovaps 320(%[weight]), %%ymm15\n"
|
||||
"vmovaps 352(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
// block 6
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
// block 7
|
||||
"vmovaps 448(%[weight]), %%ymm15\n"
|
||||
"vmovaps 480(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm13, %%ymm14\n"
|
||||
"dec %[deep]\n"
"add $512, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"vminps %%ymm9, %%ymm14, %%ymm9\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm6, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm7, 64(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm8, 96(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm9, 128(%[dst], %[dst_stride], 1)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
@@ -0,0 +1,177 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block 0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
// block 1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
// block 2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
// block 3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
// block 4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
// block 5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
// block 6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
// block 7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
}
@@ -0,0 +1,193 @@
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
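// Summary of the kernel below, inferred from the code rather than any spec:
// dst/src/weight point at NC8HW8-packed fp32 tiles, deep is assumed to be a
// multiple of 8 (each loop pass unrolls 8 depth steps), src_stride/dst_stride
// are float counts scaled to byte offsets with << 2, inc_flag bit 0 selects
// whether the accumulators start from dst, bias or zero, and act_flag
// (checked together with inc_flag bit 1) enables the trailing ReLU/ReLU6 clamp.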
|
||||
void nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add $256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
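// 0x40C00000 is the IEEE-754 bit pattern of 6.0f; vmovd puts it in lane 0 and
// vpermps with the all-zero index vector in ymm15 broadcasts it to every lane.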
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,305 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
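// Tiling note (inferred): the 6x16 tile is kept as two 8-wide channel strips,
// dst0..dst5 for channels 0-7 of rows 0-5 and dst6..dst11 for channels 8-15,
// so each fmadd pairs one broadcast src value with one 8-float weight strip.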
|
||||
void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst6;
|
||||
__m256 dst1;
|
||||
__m256 dst7;
|
||||
__m256 dst2;
|
||||
__m256 dst8;
|
||||
__m256 dst3;
|
||||
__m256 dst9;
|
||||
__m256 dst4;
|
||||
__m256 dst10;
|
||||
__m256 dst5;
|
||||
__m256 dst11;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst6 = _mm256_load_ps(dst + 1 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst7 = _mm256_load_ps(dst + 1 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst8 = _mm256_load_ps(dst + 1 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst9 = _mm256_load_ps(dst + 1 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst10 = _mm256_load_ps(dst + 1 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
dst11 = _mm256_load_ps(dst + 1 * dst_stride + 40);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
dst9 = _mm256_setzero_ps();
|
||||
dst10 = _mm256_setzero_ps();
|
||||
dst11 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 8);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 8);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst8 = _mm256_load_ps(bias + 8);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst9 = _mm256_load_ps(bias + 8);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst10 = _mm256_load_ps(bias + 8);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
dst11 = _mm256_load_ps(bias + 8);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 weight10 = _mm256_load_ps(weight + 8);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src00, weight10);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src10, weight10);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src20, weight10);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src30, weight10);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src40, weight10);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src50, weight10);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 16);
|
||||
__m256 weight11 = _mm256_load_ps(weight + 24);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src01, weight11);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src11, weight11);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src21, weight11);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src31, weight11);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src41, weight11);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src51, weight11);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 32);
|
||||
__m256 weight12 = _mm256_load_ps(weight + 40);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src02, weight12);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src12, weight12);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src22, weight12);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src32, weight12);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src42, weight12);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src52, weight12);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 48);
|
||||
__m256 weight13 = _mm256_load_ps(weight + 56);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src03, weight13);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src13, weight13);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src23, weight13);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src33, weight13);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src43, weight13);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src53, weight13);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 64);
|
||||
__m256 weight14 = _mm256_load_ps(weight + 72);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src04, weight14);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src14, weight14);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src24, weight14);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src34, weight14);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src44, weight14);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src54, weight14);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 80);
|
||||
__m256 weight15 = _mm256_load_ps(weight + 88);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src05, weight15);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src15, weight15);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src25, weight15);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src35, weight15);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src45, weight15);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src55, weight15);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 96);
|
||||
__m256 weight16 = _mm256_load_ps(weight + 104);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src06, weight16);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src16, weight16);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src26, weight16);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src36, weight16);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src46, weight16);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src56, weight16);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 112);
|
||||
__m256 weight17 = _mm256_load_ps(weight + 120);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
dst6 = _mm256_fmadd_ps(dst6, src07, weight17);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
dst7 = _mm256_fmadd_ps(dst7, src17, weight17);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
dst8 = _mm256_fmadd_ps(dst8, src27, weight17);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
dst9 = _mm256_fmadd_ps(dst9, src37, weight17);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
dst10 = _mm256_fmadd_ps(dst10, src47, weight17);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
dst11 = _mm256_fmadd_ps(dst11, src57, weight17);
|
||||
src = src + src_stride;
|
||||
weight += 512;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst9 = _mm256_min_ps(dst9, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst10 = _mm256_min_ps(dst10, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst11 = _mm256_min_ps(dst11, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst9 = _mm256_max_ps(dst9, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst10 = _mm256_max_ps(dst10, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst11 = _mm256_max_ps(dst11, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 0 * dst_stride + 40, dst5);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 0, dst6);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 8, dst7);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 16, dst8);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 24, dst9);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 32, dst10);
|
||||
_mm256_store_ps(dst + 1 * dst_stride + 40, dst11);
|
||||
}
|
|
@@ -0,0 +1,307 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
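// Structure note (inferred): the first asm block only initializes the twelve
// accumulators ymm0-ymm11 from dst, bias or zero; the second block runs the
// unrolled depth loop with ymm12-ymm15 as scratch for broadcast src values and
// weight strips, then applies the optional activation and stores the tile.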
|
||||
void nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"vmovups 0(%[dst], %[dst_stride], 1), %%ymm6\n"
|
||||
"vmovups 32(%[dst], %[dst_stride], 1), %%ymm7\n"
|
||||
"vmovups 64(%[dst], %[dst_stride], 1), %%ymm8\n"
|
||||
"vmovups 96(%[dst], %[dst_stride], 1), %%ymm9\n"
|
||||
"vmovups 128(%[dst], %[dst_stride], 1), %%ymm10\n"
|
||||
"vmovups 160(%[dst], %[dst_stride], 1), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"vmovaps 32(%[bias]), %%ymm6\n"
|
||||
"vmovaps 32(%[bias]), %%ymm7\n"
|
||||
"vmovaps 32(%[bias]), %%ymm8\n"
|
||||
"vmovaps 32(%[bias]), %%ymm9\n"
|
||||
"vmovaps 32(%[bias]), %%ymm10\n"
|
||||
"vmovaps 32(%[bias]), %%ymm11\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"vxorps %%ymm9, %%ymm9, %%ymm9\n"
|
||||
"vxorps %%ymm10, %%ymm10, %%ymm10\n"
|
||||
"vxorps %%ymm11, %%ymm11, %%ymm11\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10", "%ymm11");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vmovaps 32(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
// block 1
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vmovaps 96(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
// block 2
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vmovaps 160(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
// block 3
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vmovaps 224(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
// block 4
|
||||
"vmovaps 256(%[weight]), %%ymm15\n"
|
||||
"vmovaps 288(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
// block 5
|
||||
"vmovaps 320(%[weight]), %%ymm15\n"
|
||||
"vmovaps 352(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
// block 6
|
||||
"vmovaps 384(%[weight]), %%ymm15\n"
|
||||
"vmovaps 416(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
// block 7
|
||||
"vmovaps 448(%[weight]), %%ymm15\n"
|
||||
"vmovaps 480(%[weight]), %%ymm14\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm9, %%ymm12, %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm10, %%ymm13, %%ymm14\n"
|
||||
"vfmadd231ps %%ymm11, %%ymm12, %%ymm14\n"
|
||||
"dec %[deep]\n"
|
||||
"add $512, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"vmaxps %%ymm9, %%ymm15, %%ymm9\n"
|
||||
"vmaxps %%ymm10, %%ymm15, %%ymm10\n"
|
||||
"vmaxps %%ymm11, %%ymm15, %%ymm11\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"vminps %%ymm9, %%ymm14, %%ymm9\n"
|
||||
"vminps %%ymm10, %%ymm14, %%ymm10\n"
|
||||
"vminps %%ymm11, %%ymm14, %%ymm11\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
"vmovups %%ymm6, 0(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm7, 32(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm8, 64(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm9, 96(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm10, 128(%[dst], %[dst_stride], 1)\n"
|
||||
"vmovups %%ymm11, 160(%[dst], %[dst_stride], 1)\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,201 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
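// Accumulator setup (inferred): a set inc_flag reloads the partial tile from
// dst so a deep GEMM can be split across several calls, a non-NULL bias seeds
// every row with the same 8 bias values, and otherwise the tile starts at zero.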
|
||||
void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
__m256 dst5;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
|
||||
}
|
|
@@ -0,0 +1,215 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
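// Loop control note (inferred): deep_t counts 8-deep chunks, so one pass over
// blocks 0-7 decrements it once; src advances by src_stride_t bytes and weight
// by 256 bytes (8 depth steps x 8 output channels) per pass until deep_t is 0.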
|
||||
void nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add $256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,225 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
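// Interface note (inferred): row_block and col_block belong to the shared
// kernel signature but are unused here; the 7x8 tile shape is baked into the
// generated body, which keeps one 8-float accumulator per row in dst0..dst6.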
|
||||
void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
__m256 dst5;
|
||||
__m256 dst6;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
__m256 src60 = _mm256_set1_ps(*(src + 48));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
__m256 src61 = _mm256_set1_ps(*(src + 49));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
__m256 src62 = _mm256_set1_ps(*(src + 50));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
__m256 src63 = _mm256_set1_ps(*(src + 51));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
__m256 src64 = _mm256_set1_ps(*(src + 52));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
__m256 src65 = _mm256_set1_ps(*(src + 53));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
__m256 src66 = _mm256_set1_ps(*(src + 54));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
__m256 src67 = _mm256_set1_ps(*(src + 55));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
|
||||
}
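As a reading aid, a hypothetical call to the 7x8 micro-kernel above might look like the sketch below; the wrapper name and the argument values are illustrative only and are not part of the generated sources.
// Hypothetical single-tile call (illustrative values only): compute a 7x8
// output tile over 'deep' accumulation steps, seed the accumulators from the
// per-column bias (inc_flag = 0) and apply relu (act_flag bit 0).
void example_call_7x8(float *dst, const float *src, const float *weight,
                      const float *bias, size_t deep, size_t src_stride,
                      size_t dst_stride) {
  nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(dst, src, weight, bias,
                                        /*act_flag=*/0x01,
                                        /*row_block=*/7, /*col_block=*/8,
                                        deep, src_stride, dst_stride,
                                        /*inc_flag=*/0);
}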
|
|
@@ -0,0 +1,237 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"vmovups 192(%[dst]), %%ymm6\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"vmovaps 0(%[bias]), %%ymm6\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 192(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 193(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 194(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 195(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 196(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 197(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 198(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 199(%[src]), %%ymm14\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
"vmovups %%ymm6, 192(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
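Both the intrinsic and the inline-asm variants above seed their accumulators the same way (the asm tests bit 0 of inc_flag, the intrinsic any non-zero value). A scalar paraphrase of that convention is sketched below; the helper name is illustrative and not part of the generated files.
// Scalar paraphrase of the accumulator initialisation used by both variants:
// a set inc_flag resumes a partial accumulation from dst, otherwise the
// accumulator starts from the bias value of its output column (the same
// 8 floats are reused for every row), or from zero when no bias is given.
static inline float init_acc(const float *dst_elem, const float *bias,
                             size_t inc_flag, int col /* 0..7 in the block */) {
  if (inc_flag & 0x1) {
    return *dst_elem;
  } else if (bias != NULL) {
    return bias[col];
  }
  return 0.0f;
}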
|
|
@@ -0,0 +1,249 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
__m256 dst5;
|
||||
__m256 dst6;
|
||||
__m256 dst7;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
|
||||
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
__m256 src60 = _mm256_set1_ps(*(src + 48));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
|
||||
__m256 src70 = _mm256_set1_ps(*(src + 56));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
__m256 src61 = _mm256_set1_ps(*(src + 49));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
|
||||
__m256 src71 = _mm256_set1_ps(*(src + 57));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
__m256 src62 = _mm256_set1_ps(*(src + 50));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
|
||||
__m256 src72 = _mm256_set1_ps(*(src + 58));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
__m256 src63 = _mm256_set1_ps(*(src + 51));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
|
||||
__m256 src73 = _mm256_set1_ps(*(src + 59));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
__m256 src64 = _mm256_set1_ps(*(src + 52));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
|
||||
__m256 src74 = _mm256_set1_ps(*(src + 60));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
__m256 src65 = _mm256_set1_ps(*(src + 53));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
|
||||
__m256 src75 = _mm256_set1_ps(*(src + 61));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
__m256 src66 = _mm256_set1_ps(*(src + 54));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
|
||||
__m256 src76 = _mm256_set1_ps(*(src + 62));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
__m256 src67 = _mm256_set1_ps(*(src + 55));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
|
||||
__m256 src77 = _mm256_set1_ps(*(src + 63));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
|
||||
}
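The activation handling at the tail of each intrinsic kernel follows the same flag convention; a scalar equivalent is sketched below for clarity (the helper name is illustrative and not part of the generated sources).
// Scalar equivalent of the post-GEMM activation applied by the kernels above:
// bit 1 of act_flag requests relu6 (clamp to [0, 6]), bit 0 requests relu.
static inline float apply_act(float v, size_t act_flag) {
  if (act_flag & 0x02) {
    v = v > 6.0f ? 6.0f : v;  // relu6 upper clamp
    v = v < 0.0f ? 0.0f : v;
  }
  if (act_flag & 0x01) {
    v = v < 0.0f ? 0.0f : v;  // relu
  }
  return v;
}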
|
|
@@ -0,0 +1,259 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"vmovups 192(%[dst]), %%ymm6\n"
|
||||
"vmovups 224(%[dst]), %%ymm7\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"vmovaps 0(%[bias]), %%ymm6\n"
|
||||
"vmovaps 0(%[bias]), %%ymm7\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 192(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 224(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 193(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 225(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 194(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 226(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 195(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 227(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 196(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 228(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 197(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 229(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 198(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 230(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 199(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 231(%[src]), %%ymm13\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
"vmovups %%ymm6, 192(%[dst])\n"
|
||||
"vmovups %%ymm7, 224(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@@ -0,0 +1,273 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
__m256 dst0;
|
||||
__m256 dst1;
|
||||
__m256 dst2;
|
||||
__m256 dst3;
|
||||
__m256 dst4;
|
||||
__m256 dst5;
|
||||
__m256 dst6;
|
||||
__m256 dst7;
|
||||
__m256 dst8;
|
||||
if (inc_flag) {
|
||||
dst0 = _mm256_load_ps(dst + 0 * dst_stride + 0);
|
||||
dst1 = _mm256_load_ps(dst + 0 * dst_stride + 8);
|
||||
dst2 = _mm256_load_ps(dst + 0 * dst_stride + 16);
|
||||
dst3 = _mm256_load_ps(dst + 0 * dst_stride + 24);
|
||||
dst4 = _mm256_load_ps(dst + 0 * dst_stride + 32);
|
||||
dst5 = _mm256_load_ps(dst + 0 * dst_stride + 40);
|
||||
dst6 = _mm256_load_ps(dst + 0 * dst_stride + 48);
|
||||
dst7 = _mm256_load_ps(dst + 0 * dst_stride + 56);
|
||||
dst8 = _mm256_load_ps(dst + 0 * dst_stride + 64);
|
||||
} else if (bias == NULL) {
|
||||
dst0 = _mm256_setzero_ps();
|
||||
dst1 = _mm256_setzero_ps();
|
||||
dst2 = _mm256_setzero_ps();
|
||||
dst3 = _mm256_setzero_ps();
|
||||
dst4 = _mm256_setzero_ps();
|
||||
dst5 = _mm256_setzero_ps();
|
||||
dst6 = _mm256_setzero_ps();
|
||||
dst7 = _mm256_setzero_ps();
|
||||
dst8 = _mm256_setzero_ps();
|
||||
} else {
|
||||
dst0 = _mm256_load_ps(bias + 0);
|
||||
dst1 = _mm256_load_ps(bias + 0);
|
||||
dst2 = _mm256_load_ps(bias + 0);
|
||||
dst3 = _mm256_load_ps(bias + 0);
|
||||
dst4 = _mm256_load_ps(bias + 0);
|
||||
dst5 = _mm256_load_ps(bias + 0);
|
||||
dst6 = _mm256_load_ps(bias + 0);
|
||||
dst7 = _mm256_load_ps(bias + 0);
|
||||
dst8 = _mm256_load_ps(bias + 0);
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
// block0
|
||||
__m256 weight00 = _mm256_load_ps(weight + 0);
|
||||
__m256 src00 = _mm256_set1_ps(*(src + 0));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src00, weight00);
|
||||
__m256 src10 = _mm256_set1_ps(*(src + 8));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src10, weight00);
|
||||
__m256 src20 = _mm256_set1_ps(*(src + 16));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src20, weight00);
|
||||
__m256 src30 = _mm256_set1_ps(*(src + 24));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src30, weight00);
|
||||
__m256 src40 = _mm256_set1_ps(*(src + 32));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src40, weight00);
|
||||
__m256 src50 = _mm256_set1_ps(*(src + 40));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src50, weight00);
|
||||
__m256 src60 = _mm256_set1_ps(*(src + 48));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src60, weight00);
|
||||
__m256 src70 = _mm256_set1_ps(*(src + 56));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src70, weight00);
|
||||
__m256 src80 = _mm256_set1_ps(*(src + 64));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src80, weight00);
|
||||
// block1
|
||||
__m256 weight01 = _mm256_load_ps(weight + 8);
|
||||
__m256 src01 = _mm256_set1_ps(*(src + 1));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src01, weight01);
|
||||
__m256 src11 = _mm256_set1_ps(*(src + 9));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src11, weight01);
|
||||
__m256 src21 = _mm256_set1_ps(*(src + 17));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src21, weight01);
|
||||
__m256 src31 = _mm256_set1_ps(*(src + 25));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src31, weight01);
|
||||
__m256 src41 = _mm256_set1_ps(*(src + 33));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src41, weight01);
|
||||
__m256 src51 = _mm256_set1_ps(*(src + 41));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src51, weight01);
|
||||
__m256 src61 = _mm256_set1_ps(*(src + 49));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src61, weight01);
|
||||
__m256 src71 = _mm256_set1_ps(*(src + 57));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src71, weight01);
|
||||
__m256 src81 = _mm256_set1_ps(*(src + 65));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src81, weight01);
|
||||
// block2
|
||||
__m256 weight02 = _mm256_load_ps(weight + 16);
|
||||
__m256 src02 = _mm256_set1_ps(*(src + 2));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src02, weight02);
|
||||
__m256 src12 = _mm256_set1_ps(*(src + 10));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src12, weight02);
|
||||
__m256 src22 = _mm256_set1_ps(*(src + 18));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src22, weight02);
|
||||
__m256 src32 = _mm256_set1_ps(*(src + 26));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src32, weight02);
|
||||
__m256 src42 = _mm256_set1_ps(*(src + 34));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src42, weight02);
|
||||
__m256 src52 = _mm256_set1_ps(*(src + 42));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src52, weight02);
|
||||
__m256 src62 = _mm256_set1_ps(*(src + 50));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src62, weight02);
|
||||
__m256 src72 = _mm256_set1_ps(*(src + 58));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src72, weight02);
|
||||
__m256 src82 = _mm256_set1_ps(*(src + 66));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src82, weight02);
|
||||
// block3
|
||||
__m256 weight03 = _mm256_load_ps(weight + 24);
|
||||
__m256 src03 = _mm256_set1_ps(*(src + 3));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src03, weight03);
|
||||
__m256 src13 = _mm256_set1_ps(*(src + 11));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src13, weight03);
|
||||
__m256 src23 = _mm256_set1_ps(*(src + 19));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src23, weight03);
|
||||
__m256 src33 = _mm256_set1_ps(*(src + 27));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src33, weight03);
|
||||
__m256 src43 = _mm256_set1_ps(*(src + 35));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src43, weight03);
|
||||
__m256 src53 = _mm256_set1_ps(*(src + 43));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src53, weight03);
|
||||
__m256 src63 = _mm256_set1_ps(*(src + 51));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src63, weight03);
|
||||
__m256 src73 = _mm256_set1_ps(*(src + 59));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src73, weight03);
|
||||
__m256 src83 = _mm256_set1_ps(*(src + 67));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src83, weight03);
|
||||
// block4
|
||||
__m256 weight04 = _mm256_load_ps(weight + 32);
|
||||
__m256 src04 = _mm256_set1_ps(*(src + 4));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src04, weight04);
|
||||
__m256 src14 = _mm256_set1_ps(*(src + 12));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src14, weight04);
|
||||
__m256 src24 = _mm256_set1_ps(*(src + 20));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src24, weight04);
|
||||
__m256 src34 = _mm256_set1_ps(*(src + 28));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src34, weight04);
|
||||
__m256 src44 = _mm256_set1_ps(*(src + 36));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src44, weight04);
|
||||
__m256 src54 = _mm256_set1_ps(*(src + 44));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src54, weight04);
|
||||
__m256 src64 = _mm256_set1_ps(*(src + 52));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src64, weight04);
|
||||
__m256 src74 = _mm256_set1_ps(*(src + 60));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src74, weight04);
|
||||
__m256 src84 = _mm256_set1_ps(*(src + 68));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src84, weight04);
|
||||
// block5
|
||||
__m256 weight05 = _mm256_load_ps(weight + 40);
|
||||
__m256 src05 = _mm256_set1_ps(*(src + 5));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src05, weight05);
|
||||
__m256 src15 = _mm256_set1_ps(*(src + 13));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src15, weight05);
|
||||
__m256 src25 = _mm256_set1_ps(*(src + 21));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src25, weight05);
|
||||
__m256 src35 = _mm256_set1_ps(*(src + 29));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src35, weight05);
|
||||
__m256 src45 = _mm256_set1_ps(*(src + 37));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src45, weight05);
|
||||
__m256 src55 = _mm256_set1_ps(*(src + 45));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src55, weight05);
|
||||
__m256 src65 = _mm256_set1_ps(*(src + 53));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src65, weight05);
|
||||
__m256 src75 = _mm256_set1_ps(*(src + 61));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src75, weight05);
|
||||
__m256 src85 = _mm256_set1_ps(*(src + 69));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src85, weight05);
|
||||
// block6
|
||||
__m256 weight06 = _mm256_load_ps(weight + 48);
|
||||
__m256 src06 = _mm256_set1_ps(*(src + 6));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src06, weight06);
|
||||
__m256 src16 = _mm256_set1_ps(*(src + 14));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src16, weight06);
|
||||
__m256 src26 = _mm256_set1_ps(*(src + 22));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src26, weight06);
|
||||
__m256 src36 = _mm256_set1_ps(*(src + 30));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src36, weight06);
|
||||
__m256 src46 = _mm256_set1_ps(*(src + 38));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src46, weight06);
|
||||
__m256 src56 = _mm256_set1_ps(*(src + 46));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src56, weight06);
|
||||
__m256 src66 = _mm256_set1_ps(*(src + 54));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src66, weight06);
|
||||
__m256 src76 = _mm256_set1_ps(*(src + 62));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src76, weight06);
|
||||
__m256 src86 = _mm256_set1_ps(*(src + 70));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src86, weight06);
|
||||
// block7
|
||||
__m256 weight07 = _mm256_load_ps(weight + 56);
|
||||
__m256 src07 = _mm256_set1_ps(*(src + 7));
|
||||
dst0 = _mm256_fmadd_ps(dst0, src07, weight07);
|
||||
__m256 src17 = _mm256_set1_ps(*(src + 15));
|
||||
dst1 = _mm256_fmadd_ps(dst1, src17, weight07);
|
||||
__m256 src27 = _mm256_set1_ps(*(src + 23));
|
||||
dst2 = _mm256_fmadd_ps(dst2, src27, weight07);
|
||||
__m256 src37 = _mm256_set1_ps(*(src + 31));
|
||||
dst3 = _mm256_fmadd_ps(dst3, src37, weight07);
|
||||
__m256 src47 = _mm256_set1_ps(*(src + 39));
|
||||
dst4 = _mm256_fmadd_ps(dst4, src47, weight07);
|
||||
__m256 src57 = _mm256_set1_ps(*(src + 47));
|
||||
dst5 = _mm256_fmadd_ps(dst5, src57, weight07);
|
||||
__m256 src67 = _mm256_set1_ps(*(src + 55));
|
||||
dst6 = _mm256_fmadd_ps(dst6, src67, weight07);
|
||||
__m256 src77 = _mm256_set1_ps(*(src + 63));
|
||||
dst7 = _mm256_fmadd_ps(dst7, src77, weight07);
|
||||
__m256 src87 = _mm256_set1_ps(*(src + 71));
|
||||
dst8 = _mm256_fmadd_ps(dst8, src87, weight07);
|
||||
src = src + src_stride;
|
||||
weight += 256;
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_min_ps(dst0, relu6);
|
||||
dst1 = _mm256_min_ps(dst1, relu6);
|
||||
dst2 = _mm256_min_ps(dst2, relu6);
|
||||
dst3 = _mm256_min_ps(dst3, relu6);
|
||||
dst4 = _mm256_min_ps(dst4, relu6);
|
||||
dst5 = _mm256_min_ps(dst5, relu6);
|
||||
dst6 = _mm256_min_ps(dst6, relu6);
|
||||
dst7 = _mm256_min_ps(dst7, relu6);
|
||||
dst8 = _mm256_min_ps(dst8, relu6);
|
||||
// relu
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
dst0 = _mm256_max_ps(dst0, relu);
|
||||
dst1 = _mm256_max_ps(dst1, relu);
|
||||
dst2 = _mm256_max_ps(dst2, relu);
|
||||
dst3 = _mm256_max_ps(dst3, relu);
|
||||
dst4 = _mm256_max_ps(dst4, relu);
|
||||
dst5 = _mm256_max_ps(dst5, relu);
|
||||
dst6 = _mm256_max_ps(dst6, relu);
|
||||
dst7 = _mm256_max_ps(dst7, relu);
|
||||
dst8 = _mm256_max_ps(dst8, relu);
|
||||
}
|
||||
_mm256_store_ps(dst + 0 * src_stride + 0, dst0);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 8, dst1);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 16, dst2);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 24, dst3);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 32, dst4);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 40, dst5);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 48, dst6);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 56, dst7);
|
||||
_mm256_store_ps(dst + 0 * src_stride + 64, dst8);
|
||||
}
|
|
@@ -0,0 +1,281 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight, const float *bias,
|
||||
const size_t act_flag, const size_t row_block, const size_t col_block,
|
||||
const size_t deep, const size_t src_stride, const size_t dst_stride,
|
||||
const size_t inc_flag) {
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\n"
|
||||
"je 0f\n"
|
||||
"vmovups 0(%[dst]), %%ymm0\n"
|
||||
"vmovups 32(%[dst]), %%ymm1\n"
|
||||
"vmovups 64(%[dst]), %%ymm2\n"
|
||||
"vmovups 96(%[dst]), %%ymm3\n"
|
||||
"vmovups 128(%[dst]), %%ymm4\n"
|
||||
"vmovups 160(%[dst]), %%ymm5\n"
|
||||
"vmovups 192(%[dst]), %%ymm6\n"
|
||||
"vmovups 224(%[dst]), %%ymm7\n"
|
||||
"vmovups 256(%[dst]), %%ymm8\n"
|
||||
"jmp 2f\n"
|
||||
"0:\n"
|
||||
"cmpq $0, %[bias]\n"
|
||||
"je 1f\n"
|
||||
"vmovaps 0(%[bias]), %%ymm0\n"
|
||||
"vmovaps 0(%[bias]), %%ymm1\n"
|
||||
"vmovaps 0(%[bias]), %%ymm2\n"
|
||||
"vmovaps 0(%[bias]), %%ymm3\n"
|
||||
"vmovaps 0(%[bias]), %%ymm4\n"
|
||||
"vmovaps 0(%[bias]), %%ymm5\n"
|
||||
"vmovaps 0(%[bias]), %%ymm6\n"
|
||||
"vmovaps 0(%[bias]), %%ymm7\n"
|
||||
"vmovaps 0(%[bias]), %%ymm8\n"
|
||||
"jmp 2f\n"
|
||||
"1:\n"
|
||||
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
|
||||
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
|
||||
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
|
||||
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
|
||||
"vxorps %%ymm4, %%ymm4, %%ymm4\n"
|
||||
"vxorps %%ymm5, %%ymm5, %%ymm5\n"
|
||||
"vxorps %%ymm6, %%ymm6, %%ymm6\n"
|
||||
"vxorps %%ymm7, %%ymm7, %%ymm7\n"
|
||||
"vxorps %%ymm8, %%ymm8, %%ymm8\n"
|
||||
"2:\n"
|
||||
:
|
||||
: [ dst ] "r"(dst), [ bias ] "r"(bias), [ dst_stride ] "r"(dst_stride_t), [ inc_flag ] "r"(inc_flag)
|
||||
: "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8");
|
||||
asm volatile(
|
||||
"0:\n"
|
||||
// block 0
|
||||
"vmovaps 0(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 0(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 32(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 64(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 96(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 128(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 160(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 192(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 224(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 256(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
// block 1
|
||||
"vmovaps 32(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 1(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 33(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 65(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 97(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 129(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 161(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 193(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 225(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 257(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
// block 2
|
||||
"vmovaps 64(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 2(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 34(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 66(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 98(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 130(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 162(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 194(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 226(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 258(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
// block 3
|
||||
"vmovaps 96(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 3(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 35(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 67(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 99(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 131(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 163(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 195(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 227(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 259(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
// block 4
|
||||
"vmovaps 128(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 4(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 36(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 68(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 100(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 132(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 164(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 196(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 228(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 260(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
// block 5
|
||||
"vmovaps 160(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 5(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 37(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 69(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 101(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 133(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 165(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 197(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 229(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 261(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
// block 6
|
||||
"vmovaps 192(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 6(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 38(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 70(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 102(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 134(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 166(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 198(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 230(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 262(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
// block 7
|
||||
"vmovaps 224(%[weight]), %%ymm15\n"
|
||||
"vbroadcastss 7(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 39(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 71(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm0, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm1, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm2, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 103(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 135(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 167(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm3, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm4, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm5, %%ymm12, %%ymm15\n"
|
||||
"vbroadcastss 199(%[src]), %%ymm14\n"
|
||||
"vbroadcastss 231(%[src]), %%ymm13\n"
|
||||
"vbroadcastss 263(%[src]), %%ymm12\n"
|
||||
"vfmadd231ps %%ymm6, %%ymm14, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm7, %%ymm13, %%ymm15\n"
|
||||
"vfmadd231ps %%ymm8, %%ymm12, %%ymm15\n"
|
||||
"dec %[deep]\n"
|
||||
"add 256, %[weight]\n"
|
||||
"add %[src_stride], %[src]\n"
|
||||
"jg 0b\n"
|
||||
|
||||
"movq %[inc_flag], %%rax\n"
|
||||
"and $0x2, %%eax\n"
|
||||
"je 3f\n"
|
||||
"movq %[act_flag], %%rax\n"
|
||||
"and $0x3, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu
|
||||
"vxorps %%ymm15, %%ymm15, %%ymm15\n"
|
||||
"vmaxps %%ymm0, %%ymm15, %%ymm0\n"
|
||||
"vmaxps %%ymm1, %%ymm15, %%ymm1\n"
|
||||
"vmaxps %%ymm2, %%ymm15, %%ymm2\n"
|
||||
"vmaxps %%ymm3, %%ymm15, %%ymm3\n"
|
||||
"vmaxps %%ymm4, %%ymm15, %%ymm4\n"
|
||||
"vmaxps %%ymm5, %%ymm15, %%ymm5\n"
|
||||
"vmaxps %%ymm6, %%ymm15, %%ymm6\n"
|
||||
"vmaxps %%ymm7, %%ymm15, %%ymm7\n"
|
||||
"vmaxps %%ymm8, %%ymm15, %%ymm8\n"
|
||||
"and $0x1, %%eax\n"
|
||||
"je 3f\n"
|
||||
// relu6
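// 0x40C00000 is the IEEE-754 encoding of 6.0f; vpermps with the zeroed ymm15 as the index register copies it into every lane of ymm14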
|
||||
"mov $0x40C00000, %%eax\n"
|
||||
"vmovd %%eax, %%xmm14\n"
|
||||
"vpermps %%ymm14, %%ymm15, %%ymm14\n"
|
||||
"vminps %%ymm0, %%ymm14, %%ymm0\n"
|
||||
"vminps %%ymm1, %%ymm14, %%ymm1\n"
|
||||
"vminps %%ymm2, %%ymm14, %%ymm2\n"
|
||||
"vminps %%ymm3, %%ymm14, %%ymm3\n"
|
||||
"vminps %%ymm4, %%ymm14, %%ymm4\n"
|
||||
"vminps %%ymm5, %%ymm14, %%ymm5\n"
|
||||
"vminps %%ymm6, %%ymm14, %%ymm6\n"
|
||||
"vminps %%ymm7, %%ymm14, %%ymm7\n"
|
||||
"vminps %%ymm8, %%ymm14, %%ymm8\n"
|
||||
"3:\n"
|
||||
"vmovups %%ymm0, 0(%[dst])\n"
|
||||
"vmovups %%ymm1, 32(%[dst])\n"
|
||||
"vmovups %%ymm2, 64(%[dst])\n"
|
||||
"vmovups %%ymm3, 96(%[dst])\n"
|
||||
"vmovups %%ymm4, 128(%[dst])\n"
|
||||
"vmovups %%ymm5, 160(%[dst])\n"
|
||||
"vmovups %%ymm6, 192(%[dst])\n"
|
||||
"vmovups %%ymm7, 224(%[dst])\n"
|
||||
"vmovups %%ymm8, 256(%[dst])\n"
|
||||
:
|
||||
: [ src ] "r"(src), [ src_stride ] "r"(src_stride_t), [ weight ] "r"(weight), [ deep ] "r"(deep_t),
|
||||
[ inc_flag ] "r"(inc_flag), [ act_flag ] "r"(act_flag), [ dst ] "r"(dst), [ dst_stride ] "r"(dst_stride_t)
|
||||
: "%rax", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4", "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
|
||||
"%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
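# Each generator.py call below renders one kernel from a template:
#   -I selects the template file, -A passes row_block/col_block into the template's
#   embedded Python, and -O names the generated C source (see generator.py in this commit).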
|
||||
|
||||
# generate gemm fma asm code
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32_asm.c
|
||||
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32_asm.c
|
||||
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32_asm.c
|
||||
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32_asm.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8_asm.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32_asm.c
|
||||
|
||||
# generate gemm fma intrinsics code
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=12 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_12x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=11 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_11x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=10 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_10x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=9 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_9x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=8 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_8x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=7 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_7x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_6x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_5x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_4x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_3x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_2x8_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=8 -O ./gemm_fma/nnacl_gemm_fma_1x8_kernel_nc8hw8_fp32.c
|
||||
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=6 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_6x16_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=5 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_5x16_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_4x16_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_3x16_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_2x16_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=16 -O ./gemm_fma/nnacl_gemm_fma_1x16_kernel_nc8hw8_fp32.c
|
||||
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=4 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_4x24_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_3x24_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_2x24_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=24 -O ./gemm_fma/nnacl_gemm_fma_1x24_kernel_nc8hw8_fp32.c
|
||||
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=3 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_3x32_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=2 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_2x32_kernel_nc8hw8_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_fma_nc8hw8.c.in -A row_block=1 col_block=32 -O ./gemm_fma/nnacl_gemm_fma_1x32_kernel_nc8hw8_fp32.c
|
||||
|
||||
# generate gemm avx512 asm code
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_4x96_kernel_nhwc_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_3x96_kernel_nhwc_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_2x96_kernel_nhwc_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=96 -O ./gemm_avx512/nnacl_gemm_avx512_1x96_kernel_nhwc_fp32.c
|
||||
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=5 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_5x80_kernel_nhwc_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=4 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_4x80_kernel_nhwc_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=3 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_3x80_kernel_nhwc_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=2 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_2x80_kernel_nhwc_fp32.c
|
||||
python3 generator.py -I ./template_file/gemm_avx512_nhwc_asm.c.in -A row_block=1 col_block=80 -O ./gemm_avx512/nnacl_gemm_avx512_1x80_kernel_nhwc_fp32.c
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""HPC generator"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import io
|
||||
import argparse
|
||||
from itertools import chain
|
||||
|
||||
def key_value_pair(line):
|
||||
"""
|
||||
split a KEY=VALUE define into a (key, integer value) pair
|
||||
:param line:
|
||||
:return:
|
||||
"""
|
||||
key, value = line.split("=", 1)
|
||||
try:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
print("Error: you input value must be integer, but now is ", value)
|
||||
return key, value
|
||||
|
||||
def get_indent(line):
|
||||
"""
|
||||
get indent length
|
||||
:param line:
|
||||
:return:
|
||||
"""
|
||||
index = 0
|
||||
for i in line:
|
||||
if i == " ":
|
||||
index += 1
|
||||
else:
|
||||
break
|
||||
return index
|
||||
|
||||
def print_line(line):
|
||||
"""
|
||||
Convert one template line into the Python print statement that will emit it into the generated code
|
||||
:param line:
|
||||
:return:
|
||||
"""
|
||||
global python_indent
|
||||
global generate_code_indent
|
||||
if line.strip()[0] == "}" or line.strip()[0] == ")":
|
||||
python_indent = -1
|
||||
split_str = line.split("@")
|
||||
if line.strip()[0] != "@" and len(split_str) == 1:
|
||||
if get_indent(line) == python_indent or python_indent == -1:
|
||||
result = ["print(", line, ", file=OUT_STREAM)"]
|
||||
python_indent = -1
|
||||
if "{" in line or "asm volatile(" in line:
|
||||
generate_code_indent = get_indent(line)
|
||||
if line.strip().startswith("}") and "{" not in line:
|
||||
generate_code_indent -= 4
|
||||
if (len(line) == 1 and line[0] == "}"):
|
||||
# reset generate_code_indent for the next function
|
||||
generate_code_indent = -4
|
||||
return "\"".join(result)
|
||||
|
||||
if line.strip()[0] == '@':
|
||||
# get python indent and first generate_code_indent
|
||||
if python_indent == -1:
|
||||
generate_code_indent = get_indent(line) - 4
|
||||
python_indent = get_indent(line)
|
||||
result = split_str[0][python_indent:] + split_str[1]
|
||||
return result
|
||||
|
||||
index = get_indent(split_str[0])
|
||||
result = [split_str[0][python_indent:index] + "print("]
|
||||
Prefix = " " * (generate_code_indent + 4) + split_str[0].lstrip()
|
||||
|
||||
Suffix = " %("
|
||||
for str_tmp in split_str[1:]:
|
||||
second = str_tmp.find("}")
|
||||
Suffix += str_tmp[1:second] + ', '
|
||||
str_tmp = str_tmp.replace(str_tmp[0:second + 1], "%d")
|
||||
Prefix += str_tmp
|
||||
result.append(Prefix)
|
||||
result.append(Suffix + "), file=OUT_STREAM)")
|
||||
return "\"".join(result)
|
||||
|
||||
def generate_code(template_file, exec_dict):
|
||||
"""
|
||||
generate high-performance kernel code from a template
|
||||
:param template_file: template file path
|
||||
:param exec_dict: globals passed to exec, e.g. the row_block/col_block template parameters
|
||||
:return: generated source code as a string
|
||||
"""
|
||||
output_stream = io.StringIO()
|
||||
with open(template_file, 'r') as f:
|
||||
generate_code_lines = []
|
||||
for line in f:
|
||||
line = line.replace("\n", "")
|
||||
if line.strip() and line.strip()[0] != "@":
|
||||
line = line.replace("\"", "\\\"")
|
||||
if line.strip() and line.strip()[0] != "@":
|
||||
line = line.replace("%", "%%")
|
||||
if "print" in line:
|
||||
line = line.replace("%%", "%")
|
||||
if not line:
|
||||
generate_code_lines.append("print(" + "\"" + line + "\"" + ", file=OUT_STREAM)")
|
||||
else:
|
||||
str = print_line(line)
|
||||
if "%(" not in str:
|
||||
str = str.replace("%%[", "%[")
|
||||
generate_code_lines.append(str)
|
||||
# print('\n'.join(generate_code_lines))
|
||||
c = compile('\n'.join(generate_code_lines), '', 'exec')
|
||||
exec_dict["OUT_STREAM"] = output_stream
|
||||
exec(c, exec_dict)
|
||||
return output_stream.getvalue()
|
||||
|
||||
generate_code_indent = -4
|
||||
python_indent = -1
|
||||
|
||||
parser = argparse.ArgumentParser(description="MSLite NNACL Code Generator")
|
||||
parser.add_argument("-I", dest="Template_File", nargs=1, help="template file to generate code")
|
||||
parser.add_argument("-A", dest="defines", metavar="KEY=VALUE", nargs="*", type=key_value_pair, action="append",
|
||||
help="Custom Parameters")
|
||||
parser.add_argument("-O", dest="Output_File", nargs=1, help="generate code output file path")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parameters = parser.parse_args(sys.argv[1:])
|
||||
exec_globals = dict(chain(*parameters.defines))
|
||||
|
||||
generate_code_str = generate_code(parameters.Template_File[0], exec_globals)
|
||||
if os.path.exists(parameters.Output_File[0]):
|
||||
os.remove(parameters.Output_File[0])
|
||||
|
||||
saveDir = os.path.dirname(parameters.Output_File[0])
|
||||
if not os.path.exists(saveDir):
|
||||
os.mkdir(saveDir)
|
||||
with open(parameters.Output_File[0], "w", encoding='utf-8') as output_file:
|
||||
output_file.write(generate_code_str)
|
|
@ -0,0 +1,169 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 avx512 asm code
|
||||
void nnacl_gemm_avx512_@{row_block}x@{col_block}_kernel_nhwc_fp32(float *dst, const float *src, const float *weight,
|
||||
const float *bias, const size_t act_flag, const size_t row_block,
|
||||
const size_t col_block, const size_t deep, const size_t src_stride,
|
||||
const size_t dst_stride, const size_t inc_flag) {
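// register plan: zmm0 .. zmm(row_block * col_block / 16 - 1) hold the output tile, zmm31 and downward are scratch for weight loads and broadcast src values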
|
||||
@asm_flag_list = []
|
||||
@row_split_number = [row for row in range(3, row_block, 3)]
|
||||
@for row in row_split_number:
|
||||
const float *dst_@{row} = dst + @{row} * dst_stride;
|
||||
@asm_flag_list.append("[dst_" + str(row) + "] " + "\"r\"(dst_" + str(row) + ")");
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
@col_split_num = col_block >> 4;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\\n"
|
||||
"je 0f\\n"
|
||||
@for row in range(0, row_block):
|
||||
@tmp = int(row / 3) * 3
|
||||
@for col in range(0, col_split_num):
|
||||
@if row % 3 == 0:
|
||||
"vmovups @{col * 64}(%[dst_@{tmp}]), %%zmm@{row * col_split_num + col}\\n"
|
||||
@else:
|
||||
"vmovups @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}), %%zmm@{row * col_split_num + col}\\n"
|
||||
"jmp 2f\\n"
|
||||
"0:\\n"
|
||||
"cmpq $0, %[bias]\\n"
|
||||
"je 1f\\n"
|
||||
@for row in range(0, row_block):
|
||||
@for col in range(0, col_split_num):
|
||||
"vmovaps @{col * 64}(%[bias]), %%zmm@{row * col_split_num + col}\\n"
|
||||
"jmp 2f\\n"
|
||||
"1:\\n"
|
||||
@for row in range(0, row_block):
|
||||
@for col in range(0, col_split_num):
|
||||
"vxorps %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}, %%zmm@{row * col_split_num + col}\\n"
|
||||
"2:\\n"
|
||||
:
|
||||
@list = ["[dst_0] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"]
|
||||
@list.extend(asm_flag_list)
|
||||
@print(" : " + ", ".join(list), file=OUT_STREAM)
|
||||
@print(" : " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, row_block * col_block >>4)]), file=OUT_STREAM)
|
||||
);
|
||||
@for row in row_split_number:
|
||||
const float *src_@{row} = src + @{row} * src_stride;
|
||||
@asm_flag_list.append("[src_" + str(row) + "] " + "\"r\"(src_" + str(row) + ")");
|
||||
asm volatile(
|
||||
"0:\\n"
|
||||
@loop_count = 8
|
||||
@for i in range(0, loop_count):
|
||||
// block @{i}
|
||||
@for col in range(0, col_split_num):
|
||||
"vmovups @{col * 64 + i * col_block * 4}(%[weight]), %%zmm@{31 - col}\\n"
|
||||
@if col_split_num == 6:
|
||||
@if row_block == 4:
|
||||
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n"
|
||||
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
|
||||
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
|
||||
"vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
|
||||
@else:
|
||||
@for row in range(0, row_block):
|
||||
@if row == 0:
|
||||
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n"
|
||||
@else:
|
||||
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
|
||||
@for row in range(0, row_block):
|
||||
@for col in range(0, col_split_num):
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
|
||||
@elif col_split_num == 5:
|
||||
@if row_block == 5:
|
||||
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num}\\n"
|
||||
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 1), %%zmm@{31 - col_split_num - 1}\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{0 * col_split_num + col}\\n"
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{1 * col_split_num + col}\\n"
|
||||
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], 2), %%zmm@{31 - col_split_num}\\n"
|
||||
"vbroadcastss @{i * 4}(%[src_3]), %%zmm@{31 - col_split_num - 1}\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - 1}, %%zmm@{3 * col_split_num + col}\\n"
|
||||
"vbroadcastss @{i * 4}(%[src_3], %[src_stride], 1), %%zmm@{31 - col_split_num}\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num}, %%zmm@{2 * col_split_num + col}\\n"
|
||||
@else:
|
||||
@for row in range(0, row_block):
|
||||
@if row == 0:
|
||||
"vbroadcastss @{i * 4}(%[src_0]), %%zmm@{31 - col_split_num - row}\\n"
|
||||
@else:
|
||||
"vbroadcastss @{i * 4}(%[src_0], %[src_stride], @{row}), %%zmm@{31 - col_split_num - row}\\n"
|
||||
@for row in range(0, row_block):
|
||||
@for col in range(0, col_split_num):
|
||||
"vfmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - row}, %%zmm@{row * col_split_num + col}\\n"
|
||||
@elif col_split_num == 2:
|
||||
@for row in range(int(row_block / 6)):
|
||||
@for j in range(0, 6):
|
||||
@tmp = int(j / 3) * 3
|
||||
@if j % 3 == 0:
|
||||
"vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}]), %%zmm@{31 - col_split_num - j}\\n"
|
||||
@else:
|
||||
"vbroadcastss @{i * 4}(%[src_@{row * 6 + tmp}], %[src_stride], @{j - tmp}), %%zmm@{31 - col_split_num - j}\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
@for j in range(0, 6):
|
||||
"fmadd231ps %%zmm@{31 - col}, %%zmm@{31 - col_split_num - j}, %%zmm@{(row * 6 + j) * col_split_num + col}\\n"
|
||||
|
||||
"dec %[deep]\\n"
|
||||
"add $@{col_block * 4 * 8}, %[weight]\\n"
|
||||
"add $@{loop_count * 4}, %[src_0]\\n"
|
||||
@for row in row_split_number:
|
||||
"add $@{loop_count * 4}, %[src_@{row}]\\n"
|
||||
"jg 0b\\n"
|
||||
|
||||
"and $0x2, %[inc_flag]\\n"
|
||||
"je 3f\\n"
|
||||
"movq %[act_flag], %rax\\n"
|
||||
"and $0x3, %eax\\n"
|
||||
"je 3f\\n"
|
||||
// relu
|
||||
"vxorps %zmm31, %zmm31, %zmm31\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
@for row in range(0, row_block):
|
||||
"vmaxps %%zmm@{row + col * row_block}, %%zmm31, %%zmm@{row + col * row_block}\\n"
|
||||
"and $0x1, %eax\\n"
|
||||
"je 3f\\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %eax\\n"
|
||||
"vmovd %eax, %xmm30\\n"
|
||||
"vbroadcastss %xmm30, %zmm30\\n"
|
||||
@for col in range(0, col_split_num):
|
||||
@for row in range(0, row_block):
|
||||
"vminps %%zmm@{row + col * row_block}, %%zmm30, %%zmm@{row + col * row_block}\\n"
|
||||
"3:\\n"
|
||||
@for row in range(0, row_block):
|
||||
@tmp = int(row / 3) * 3
|
||||
@for col in range(0, col_split_num):
|
||||
@if row % 3 == 0:
|
||||
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}])\\n"
|
||||
@else:
|
||||
"vmovups %%zmm@{row * col_split_num + col}, @{col * 64}(%[dst_@{tmp}], %[dst_stride], @{row - tmp}),\\n"
|
||||
:
|
||||
@list = ["[src_0] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst_0] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"]
|
||||
@list.extend(asm_flag_list)
|
||||
@print(" : " + ", ".join(list), file=OUT_STREAM)
|
||||
@print(" : \"%rax\", " + ", ".join(["\"%zmm" + str(i) + "\"" for i in range(0, 32)]), file=OUT_STREAM)
|
||||
);
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma intrinsic code
|
||||
void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
|
||||
const float *bias, const size_t act_flag, const size_t row_block,
|
||||
const size_t col_block, const size_t deep, const size_t src_stride,
|
||||
const size_t dst_stride, const size_t inc_flag) {
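// accumulator dst[j * row_block + i] covers row i of 8-float column block j; it is primed from dst when inc_flag is set, from bias when bias is non-NULL, and zeroed otherwise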
|
||||
@for i in range(0, row_block):
|
||||
@for j in range(0, col_block >> 3):
|
||||
__m256 dst@{j * row_block + i};
|
||||
if (inc_flag) {
|
||||
@for i in range(0, row_block):
|
||||
@for j in range(0, col_block >> 3):
|
||||
dst@{j * row_block + i} = _mm256_load_ps(dst + @{j} * dst_stride + @{i * 8});
|
||||
} else if (bias == NULL) {
|
||||
@for i in range(0, row_block * col_block >> 3):
|
||||
dst@{i} = _mm256_setzero_ps();
|
||||
} else {
|
||||
@for i in range(0, row_block):
|
||||
@for j in range(0, col_block >> 3):
|
||||
dst@{j * row_block + i} = _mm256_load_ps(bias + @{j * 8});
|
||||
}
|
||||
for (int i = 0; i < (deep >> 3); ++i) {
|
||||
@for i in range(0, 8):
|
||||
// block @{i}
|
||||
@if col_block == 32:
|
||||
@for row in range(0, row_block):
|
||||
__m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i}));
|
||||
@for col in range(0, col_block >> 3):
|
||||
__m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block});
|
||||
@for row in range(0, row_block):
|
||||
dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i});
|
||||
@else:
|
||||
@for col in range(0, col_block >> 3):
|
||||
__m256 weight@{col}@{i} = _mm256_load_ps(weight + @{col * 8 + i * col_block});
|
||||
@for row in range(0, row_block):
|
||||
__m256 src@{row}@{i} = _mm256_set1_ps(*(src + @{row * 8 + i}));
|
||||
@for col in range(0, col_block >> 3):
|
||||
dst@{row + col * row_block} = _mm256_fmadd_ps(dst@{row + col * row_block}, src@{row}@{i}, weight@{col}@{i});
|
||||
src = src + src_stride;
|
||||
weight += @{8 * col_block * 4};
|
||||
}
|
||||
if (act_flag & 0x02) {
|
||||
// relu6
|
||||
__m256 relu6 = _mm256_set1_ps(6.0f);
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
@for i in range(0, row_block):
|
||||
@for j in range(0, col_block >> 3):
|
||||
dst@{i + j * row_block} = _mm256_min_ps(dst@{i + j * row_block}, relu6);
|
||||
// relu
|
||||
@for i in range(0, row_block):
|
||||
@for j in range(0, col_block >> 3):
|
||||
dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu);
|
||||
}
|
||||
if (act_flag & 0x01) {
|
||||
// relu
|
||||
__m256 relu = _mm256_setzero_ps();
|
||||
@for i in range(0, row_block):
|
||||
@for j in range(0, col_block >> 3):
|
||||
dst@{i + j * row_block} = _mm256_max_ps(dst@{i + j * row_block}, relu);
|
||||
}
|
||||
@if col_block == 32:
|
||||
@for j in range(0, col_block >> 3):
|
||||
@for i in range(0, row_block):
|
||||
_mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i});
|
||||
@else:
|
||||
@for j in range(0, col_block >> 3):
|
||||
@for i in range(0, row_block):
|
||||
_mm256_store_ps(dst + @{j} * src_stride + @{i * 8}, dst@{j * row_block + i});
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <x86intrin.h>
|
||||
|
||||
// nnacl gemm in x86 fma asm code
|
||||
void nnacl_gemm_fma_@{row_block}x@{col_block}_kernel_nc8hw8_fp32(float *dst, const float *src, const float *weight,
|
||||
const float *bias, const size_t act_flag, const size_t row_block,
|
||||
const size_t col_block, const size_t deep, const size_t src_stride,
|
||||
const size_t dst_stride, const size_t inc_flag) {
|
||||
@if col_block == 32:
|
||||
const float *dst_4 = dst + 3 * dst_stride;
|
||||
size_t deep_t = deep >> 3;
|
||||
size_t dst_stride_t = dst_stride << 2;
|
||||
size_t src_stride_t = src_stride << 2;
|
||||
asm volatile(
|
||||
// inc in deep
|
||||
"and $0x1, %[inc_flag]\\n"
|
||||
"je 0f\\n"
|
||||
@for col in range(0, min((col_block >> 3), 3)):
|
||||
@for row in range(0, row_block):
|
||||
@if col == 0:
|
||||
"vmovups @{row * 32}(%[dst]), %%ymm@{row + col * row_block}\\n"
|
||||
@else:
|
||||
"vmovups @{row * 32}(%[dst], %[dst_stride], @{col}), %%ymm@{row + col * row_block}\\n"
|
||||
@if col_block == 32:
|
||||
@for row in range(0, row_block):
|
||||
"vmovups @{row * 32}(%[dst_4]), %%ymm@{row + (col + 1) * row_block}\\n"
|
||||
"jmp 2f\\n"
|
||||
"0:\\n"
|
||||
"cmpq $0, %[bias]\\n"
|
||||
"je 1f\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
@for row in range(0, row_block):
|
||||
"vmovaps @{col * 32}(%[bias]), %%ymm@{row + col * row_block}\\n"
|
||||
"jmp 2f\\n"
|
||||
"1:\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
@for row in range(0, row_block):
|
||||
"vxorps %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}, %%ymm@{row + col * row_block}\\n"
|
||||
"2:\\n"
|
||||
:
|
||||
@list = ["[dst] \"r\"(dst)", "[bias] \"r\"(bias)", "[dst_stride] \"r\"(dst_stride_t)", "[inc_flag] \"r\"(inc_flag)"]
|
||||
@if col_block == 32:
|
||||
@list.append("[dst_4] \"r\"(dst_4)")
|
||||
@print(" : " + ", ".join(list), file=OUT_STREAM)
|
||||
@print(" : " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, row_block * col_block >> 3)]), file=OUT_STREAM)
|
||||
);
|
||||
asm volatile(
|
||||
"0:\\n"
|
||||
@for i in range(0, 8):
|
||||
// block @{i}
|
||||
@if col_block == 32:
|
||||
@for row in range(0, row_block):
|
||||
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - row}\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - row_block}\\n"
|
||||
@for row in range(0, row_block):
|
||||
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - row_block}, %%ymm@{15 - row}\\n"
|
||||
@elif col_block == 24:
|
||||
@for col in range(0, col_block >> 3):
|
||||
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
|
||||
@for row in range(0, row_block):
|
||||
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n"
|
||||
@elif col_block == 16:
|
||||
@for col in range(0, col_block >> 3):
|
||||
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
|
||||
@for row in range(0, row_block >> 1):
|
||||
"vbroadcastss @{row * 64 + i}(%[src]), %%ymm@{14 - col}\\n"
|
||||
"vbroadcastss @{row * 64 + 32 + i}(%[src]), %%ymm@{13 - col}\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
@for j in range(0, 2):
|
||||
"vfmadd231ps %%ymm@{row * 2 + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n"
|
||||
@for row in range(row_block >> 1 << 1, row_block):
|
||||
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{14 - col}\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3)}, %%ymm@{15 - col}\\n"
|
||||
@else:
|
||||
@for col in range(0, col_block >> 3):
|
||||
"vmovaps @{col * 32 + i * col_block * 4}(%[weight]), %%ymm@{15 - col}\\n"
|
||||
@split_num = 3
|
||||
@for row in range(0, int(row_block / split_num)):
|
||||
@for j in range(0, split_num):
|
||||
"vbroadcastss @{row * 96 + j * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - j}\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
@for j in range(0, split_num):
|
||||
"vfmadd231ps %%ymm@{row * split_num + j + col * row_block}, %%ymm@{15 - (col_block >> 3) - j}, %%ymm@{15 - col}\\n"
|
||||
@for row in range(int(row_block / split_num) * split_num, row_block):
|
||||
"vbroadcastss @{row * 32 + i}(%[src]), %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
@for row in range(int(row_block / split_num) * split_num, row_block):
|
||||
"vfmadd231ps %%ymm@{row + col * row_block}, %%ymm@{15 - (col_block >> 3) - (row - int(row_block / split_num) * split_num)}, %%ymm@{15 - col}\\n"
|
||||
"dec %[deep]\\n"
|
||||
"add @{col_block * 4 * 8}, %[weight]\\n"
|
||||
"add %[src_stride], %[src]\\n"
|
||||
"jg 0b\\n"
|
||||
|
||||
"movq %[inc_flag], %rax\\n"
|
||||
"and $0x2, %eax\\n"
|
||||
"je 3f\\n"
|
||||
"movq %[act_flag], %rax\\n"
|
||||
"and $0x3, %eax\\n"
|
||||
"je 3f\\n"
|
||||
// relu
|
||||
"vxorps %ymm15, %ymm15, %ymm15\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
@for row in range(0, row_block):
|
||||
"vmaxps %%ymm@{row + col * row_block}, %%ymm15, %%ymm@{row + col * row_block}\\n"
|
||||
"and $0x1, %eax\\n"
|
||||
"je 3f\\n"
|
||||
// relu6
|
||||
"mov $0x40C00000, %eax\\n"
|
||||
"vmovd %eax, %xmm14\\n"
|
||||
"vpermps %ymm14, %ymm15, %ymm14\\n"
|
||||
@for col in range(0, col_block >> 3):
|
||||
@for row in range(0, row_block):
|
||||
"vminps %%ymm@{row + col * row_block}, %%ymm14, %%ymm@{row + col * row_block}\\n"
|
||||
"3:\\n"
|
||||
@for col in range(0, min((col_block >> 3), 3)):
|
||||
@for row in range(0, row_block):
|
||||
@if col == 0:
|
||||
"vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst])\\n"
|
||||
@else:
|
||||
"vmovups %%ymm@{row + col * row_block}, @{row * 32}(%[dst], %[dst_stride], @{col})\\n"
|
||||
@if col_block == 32:
|
||||
@for row in range(0, row_block):
|
||||
"vmovups %%ymm@{row + (col + 1) * row_block}, @{row * 32}(%[dst_4])\\n"
|
||||
:
|
||||
@list = ["[src] \"r\"(src)", "[src_stride] \"r\"(src_stride_t)", "[weight] \"r\"(weight)", "[deep] \"r\"(deep_t)", "[inc_flag] \"r\"(inc_flag)", "[act_flag] \"r\"(act_flag)", "[dst] \"r\"(dst)", "[dst_stride] \"r\"(dst_stride_t)"]
|
||||
@if col_block == 32:
|
||||
@list.append("[dst_4] \"r\"(dst_4)")
|
||||
@print(" : " + ", ".join(list), file=OUT_STREAM)
|
||||
@print(" : \"%rax\", " + ", ".join(["\"%ymm" + str(i) + "\"" for i in range(0, 16)]), file=OUT_STREAM)
|
||||
);
|
||||
}
|