matvecmul

lzk 2021-06-09 20:39:55 -07:00
parent f8f54091c3
commit 872e3acc9d
10 changed files with 718 additions and 143 deletions

View File

@ -25,7 +25,7 @@
#ifdef __cplusplus
extern "C" {
#endif
typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, int row, int col);
#ifdef ENABLE_ARM64
typedef void (*MatmulFloatOptFuncPtr)(const float *a, const float *b, float *c, const float *bias, int act_type,
int depth, int row, int col, size_t stride, size_t write_mode);

View File

@ -15,7 +15,9 @@
*/
#include "nnacl/fp32/matmul_fp32.h"
#ifdef ENABLE_SSE
#include <x86intrin.h>
#endif
void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; ++r) {
for (int c = 0; c < col; ++c) {
@ -114,6 +116,20 @@ void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col)
return;
}
void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
// Not necessarily aligned to 32: the tail group falls back to 24, 16 or 8 rows when a full 32 is not available.
int row_block_num = UP_DIV(row, C8NUM);
int row_block = C4NUM;
for (int i = 0; i < row_block_num; i += row_block) {
row_block = MSMIN(C4NUM, row_block_num - i); // max_tile = 4
int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
for (int oc = 0; oc < col; ++oc) {
memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float));
dst_ptr += row_block * C8NUM;
}
}
}
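Side note, not part of the diff: the loop above groups the 8-wide sub-blocks in sets of at most four, so the effective packing width is 32 and only the tail group shrinks to 24, 16 or 8. A minimal standalone sketch of that grouping, reusing the UP_DIV/MSMIN/C8NUM/C4NUM definitions from nnacl/op_base.h:
#include <stdio.h>
#define C8NUM 8
#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))
/* Print the block widths RowMajor2Row32Major would use for a given row count. */
static void print_row32_blocks(int row) {
  int row_block_num = UP_DIV(row, C8NUM);
  int row_block = C4NUM;
  for (int i = 0; i < row_block_num; i += row_block) {
    row_block = MSMIN(C4NUM, row_block_num - i);  /* at most 4 sub-blocks of 8 */
    int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
    printf("sub-block %d: packed width %d, valid rows %d\n", i, row_block * C8NUM, row_remainder);
  }
}
int main(void) {
  print_row32_blocks(44);  /* expected: width 32 (32 valid rows), then width 16 (12 valid rows) */
  return 0;
}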
#ifdef ENABLE_ARM64
void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) {
size_t stride = col * sizeof(float);
@ -237,7 +253,7 @@ void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) {
return;
}
#endif
void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
const float *src_r = src_ptr;
float *dst_r = dst_ptr;
size_t ri = 0;
@ -508,7 +524,7 @@ void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
}
#endif
#endif
void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
size_t row8 = row / C8NUM * C8NUM;
#ifdef ENABLE_ARM64
size_t col_skip = col / C8NUM * C8NUM;
@ -591,7 +607,7 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
return;
}
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
size_t row16 = row / C16NUM * C16NUM;
size_t col_skip = col / C4NUM * C4NUM;
int skip_size = C4NUM;
@ -638,7 +654,25 @@ void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_
return;
}
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
// Not necessarily aligned to 32: the tail group falls back to 24, 16 or 8 columns when a full 32 is not available.
int col_block_num = UP_DIV(col, C8NUM);
int col_block = C4NUM;
for (int i = 0; i < col_block_num; i += col_block) {
col_block = MSMIN(C4NUM, col_block_num - i); // max_tile = 4
int index = i * row * C8NUM;
int col_remainder = MSMIN(C8NUM * col_block, col - i * C8NUM);
for (int ir = 0; ir < row; ++ir) {
for (int oc = 0; oc < col_remainder; ++oc) {
int oc_index = oc * row + ir + index;
dst_ptr[oc] = src_ptr[oc_index];
}
dst_ptr += col_block * C8NUM;
}
}
}
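For orientation only, again not part of the diff: when col is an exact multiple of 32, the arithmetic in RowMajor2Col32Major reduces to a single index mapping, which the hypothetical reference below spells out; the 24/16/8 tail handling is deliberately omitted.
#define C32NUM 32
/* Reference packing for col % 32 == 0: group g holds 32 consecutive channels,
 * interleaved by the row index inside the group.  The output matches the
 * optimized RowMajor2Col32Major whenever no tail fallback is needed. */
static void RowMajor2Col32Major_ref(const float *src, float *dst, int row, int col) {
  for (int g = 0; g < col / C32NUM; ++g) {
    for (int ir = 0; ir < row; ++ir) {
      for (int oc = 0; oc < C32NUM; ++oc) {
        dst[(g * row + ir) * C32NUM + oc] = src[(g * C32NUM + oc) * row + ir];
      }
    }
  }
}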
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
size_t totalRow = UP_ROUND(row, C6NUM);
size_t row6 = row / C6NUM * C6NUM;
size_t col8 = col / C8NUM * C8NUM;
@ -737,7 +771,7 @@ void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
return;
}
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
size_t total_row = UP_ROUND(row, C4NUM);
size_t row4 = row / C4NUM * C4NUM;
size_t col4 = col / C4NUM * C4NUM;
@ -845,7 +879,6 @@ void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias,
if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value);
c[ci] = value;
}
return;
}
#endif
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
@ -908,7 +941,6 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
}
}
}
return;
}
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
@ -943,3 +975,456 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
}
#ifdef ENABLE_AVX
void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
int col_align) {
// process up to 32 output channels per iteration
int col_block = C32NUM;
int act_flag = 0;
if (act_type == ActType_Relu6) {
act_flag += 1;
}
if (act_type == ActType_Relu || act_type == ActType_Relu6) {
act_flag += 2;
}
MatVecMulKernel kernel[4] = {MatVecMul1x8Kernel, MatVecMul1x16Kernel, MatVecMul1x24Kernel, MatVecMul1x32Kernel};
const float *bias_data = bias;
for (int col_index = 0; col_index < cur_col; col_index += col_block) {
col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block;
kernel[(col_block >> 3) - 1](c + col_index, a, b + col_index * depth, bias_data, act_flag, 1, col_block >> 3,
col_align, depth);
if (bias_data != NULL) {
bias_data += col_block;
}
}
}
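A note on the dispatch and flag encoding above (not part of the diff): col_block is always a multiple of 8 here, so (col_block >> 3) - 1 maps 8/16/24/32 remaining columns to the 1x8/1x16/1x24/1x32 kernels, and act_flag packs the activation into two bits, bit 1 for the Relu lower clamp and bit 0 for the additional Relu6 upper clamp. A scalar sketch of how each output element is post-processed:
#include <stddef.h>
/* Scalar view of the act_flag bits consumed by the AVX kernels below. */
static float apply_act_flag(float value, size_t act_flag) {
  if (act_flag & 0x2) {            /* ActType_Relu or ActType_Relu6: clamp at 0 */
    value = value > 0.0f ? value : 0.0f;
  }
  if (act_flag & 0x1) {            /* ActType_Relu6 only: clamp at 6 */
    value = value < 6.0f ? value : 6.0f;
  }
  return value;
}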
void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"vmovups 0x60(%2), %%ymm3\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"vxorps %%ymm3, %%ymm3, %%ymm3\n"
"1:\n" // deep_c8
"movq %3, %%rcx\n"
"shr $3, %%ecx\n"
"je 3f\n"
"2:\n"
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
"vbroadcastss 4(%0), %%ymm4\n"
"vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 192(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 224(%1), %%ymm4, %%ymm3\n"
"vbroadcastss 8(%0), %%ymm4\n"
"vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 320(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 352(%1), %%ymm4, %%ymm3\n"
"vbroadcastss 12(%0), %%ymm4\n"
"vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 480(%1), %%ymm4, %%ymm3\n"
"vbroadcastss 16(%0), %%ymm4\n"
"vfmadd231ps 512(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 544(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 576(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 608(%1), %%ymm4, %%ymm3\n"
"vbroadcastss 20(%0), %%ymm4\n"
"vfmadd231ps 640(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 672(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 704(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 736(%1), %%ymm4, %%ymm3\n"
"vbroadcastss 24(%0), %%ymm4\n"
"vfmadd231ps 768(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 800(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 832(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 864(%1), %%ymm4, %%ymm3\n"
"vbroadcastss 28(%0), %%ymm4\n"
"vfmadd231ps 896(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 928(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 960(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 992(%1), %%ymm4, %%ymm3\n"
"addq $1024, %1\n"
"addq $32, %0\n"
"dec %%ecx\n"
"jg 2b\n"
"3:\n"
"and $7, %3\n" // deep_remainder
"je 5f\n"
"4:\n"
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
"vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
"addq $128, %1\n"
"addq $4, %0\n"
"dec %3\n"
"jg 4b\n"
"5:\n"
"and $0x3, %%eax\n" // act_type
"je 6f\n"
// Relu
"vxorps %%ymm12, %%ymm12, %%ymm12\n"
"vmaxps %%ymm12, %%ymm0, %%ymm0\n"
"vmaxps %%ymm12, %%ymm1, %%ymm1\n"
"vmaxps %%ymm12, %%ymm2, %%ymm2\n"
"vmaxps %%ymm12, %%ymm3, %%ymm3\n"
"and $0x1, %%eax\n"
"je 6f\n"
// relu6
"mov $0x40C00000, %%ecx\n"
"vmovd %%ecx, %%xmm14\n"
"vpermps %%ymm14, %%ymm12, %%ymm14\n"
"vminps %%ymm14, %%ymm0, %%ymm0\n"
"vminps %%ymm14, %%ymm1, %%ymm1\n"
"vminps %%ymm14, %%ymm2, %%ymm2\n"
"vminps %%ymm14, %%ymm3, %%ymm3\n"
"6:\n"
"vmovups %%ymm0, (%5)\n" // dst_0
"vmovups %%ymm1, 0x20(%5)\n"
"vmovups %%ymm2, 0x40(%5)\n"
"vmovups %%ymm3, 0x60(%5)\n"
:
: "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
: "%rcx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm4", "%ymm14");
}
void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"vmovups 0x40(%2), %%ymm2\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"vxorps %%ymm2, %%ymm2, %%ymm2\n"
"1:\n" // deep
"movq %3, %%rcx\n"
"shr $3, %%ecx\n"
"je 3f\n"
"2:\n"
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
"vbroadcastss 4(%0), %%ymm4\n"
"vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 128(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 160(%1), %%ymm4, %%ymm2\n"
"vbroadcastss 8(%0), %%ymm4\n"
"vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 256(%1), %%ymm4, %%ymm2\n"
"vbroadcastss 12(%0), %%ymm4\n"
"vfmadd231ps 288(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 320(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 352(%1), %%ymm4, %%ymm2\n"
"vbroadcastss 16(%0), %%ymm4\n"
"vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
"vbroadcastss 20(%0), %%ymm4\n"
"vfmadd231ps 480(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 512(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 544(%1), %%ymm4, %%ymm2\n"
"vbroadcastss 24(%0), %%ymm4\n"
"vfmadd231ps 576(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 608(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 640(%1), %%ymm4, %%ymm2\n"
"vbroadcastss 28(%0), %%ymm4\n"
"vfmadd231ps 672(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 704(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 736(%1), %%ymm4, %%ymm2\n"
"addq $768, %1\n"
"addq $32, %0\n"
"dec %%ecx\n"
"jg 2b\n"
"3:\n"
"and $7, %3\n" // deep_remainder
"je 5f\n"
"4:\n"
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
"vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
"addq $96, %1\n"
"addq $4, %0\n"
"dec %3\n"
"jg 4b\n"
"5:\n"
"and $0x3, %%eax\n" // act_type
"je 6f\n"
// Relu
"vxorps %%ymm12, %%ymm12, %%ymm12\n"
"vmaxps %%ymm12, %%ymm0, %%ymm0\n"
"vmaxps %%ymm12, %%ymm1, %%ymm1\n"
"vmaxps %%ymm12, %%ymm2, %%ymm2\n"
"and $0x1, %%eax\n"
"je 6f\n"
// relu6
"mov $0x40C00000, %%ecx\n"
"vmovd %%ecx, %%xmm14\n"
"vpermps %%ymm14, %%ymm12, %%ymm14\n"
"vminps %%ymm14, %%ymm0, %%ymm0\n"
"vminps %%ymm14, %%ymm1, %%ymm1\n"
"vminps %%ymm14, %%ymm2, %%ymm2\n"
"6:\n"
"vmovups %%ymm0, (%5)\n" // dst_0
"vmovups %%ymm1, 0x20(%5)\n"
"vmovups %%ymm2, 0x40(%5)\n"
:
: "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
: "%rcx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm4", "%ymm14");
}
void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"vmovups 0x20(%2), %%ymm1\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"vxorps %%ymm1, %%ymm1, %%ymm1\n"
"1:\n"
"movq %3, %%rcx\n"
"shr $3, %%ecx\n"
"je 3f\n"
"2:\n" // deep_c8
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
"vbroadcastss 4(%0), %%ymm4\n"
"vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 96(%1), %%ymm4, %%ymm1\n"
"vbroadcastss 8(%0), %%ymm4\n"
"vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
"vbroadcastss 12(%0), %%ymm4\n"
"vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
"vbroadcastss 16(%0), %%ymm4\n"
"vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
"vbroadcastss 20(%0), %%ymm4\n"
"vfmadd231ps 320(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 352(%1), %%ymm4, %%ymm1\n"
"vbroadcastss 24(%0), %%ymm4\n"
"vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
"vbroadcastss 28(%0), %%ymm4\n"
"vfmadd231ps 448(%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 480(%1), %%ymm4, %%ymm1\n"
"addq $512, %1\n"
"addq $32, %0\n"
"dec %%ecx\n"
"jg 2b\n"
"3:\n"
"and $7, %3\n"
"je 5f\n"
"4:\n"
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
"addq $64, %1\n"
"addq $4, %0\n"
"dec %3\n"
"jg 4b\n"
"5:\n"
"and $0x3, %%eax\n" // act_type
"je 6f\n"
// Relu
"vxorps %%ymm12, %%ymm12, %%ymm12\n"
"vmaxps %%ymm12, %%ymm0, %%ymm0\n"
"vmaxps %%ymm12, %%ymm1, %%ymm1\n"
"and $0x1, %%eax\n"
"je 6f\n"
// relu6
"mov $0x40C00000, %%ecx\n"
"vmovd %%ecx, %%xmm14\n"
"vpermps %%ymm14, %%ymm12, %%ymm14\n"
"vminps %%ymm14, %%ymm0, %%ymm0\n"
"vminps %%ymm14, %%ymm1, %%ymm1\n"
"6:\n"
"vmovups %%ymm0, (%5)\n" // dst_0
"vmovups %%ymm1, 0x20(%5)\n"
:
: "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
: "%ecx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
}
void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
asm volatile(
"cmpq $0, %2\n"
"je 0f\n"
"vmovups (%2), %%ymm0\n"
"jmp 1f\n"
"0:\n"
"vxorps %%ymm0, %%ymm0, %%ymm0\n"
"1:\n"
"movq %3, %%rcx\n"
"shr $3, %%ecx\n"
"je 3f\n"
"2:\n" // deep_c8
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"vbroadcastss 4(%0), %%ymm4\n"
"vfmadd231ps 32(%1), %%ymm4, %%ymm0\n"
"vbroadcastss 8(%0), %%ymm4\n"
"vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
"vbroadcastss 12(%0), %%ymm4\n"
"vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
"vbroadcastss 16(%0), %%ymm4\n"
"vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
"vbroadcastss 20(%0), %%ymm4\n"
"vfmadd231ps 160(%1), %%ymm4, %%ymm0\n"
"vbroadcastss 24(%0), %%ymm4\n"
"vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
"vbroadcastss 28(%0), %%ymm4\n"
"vfmadd231ps 224(%1), %%ymm4, %%ymm0\n"
"addq $256, %1\n"
"addq $32, %0\n"
"dec %%ecx\n"
"jg 2b\n"
"3:\n"
"and $7, %3\n"
"je 5f\n"
"4:\n"
"vbroadcastss (%0), %%ymm4\n"
"vfmadd231ps (%1), %%ymm4, %%ymm0\n"
"addq $32, %1\n"
"addq $4, %0\n"
"dec %3\n"
"jg 4b\n"
"5:\n"
"and $0x3, %%eax\n" // act_type
"je 6f\n"
// Relu
"vxorps %%ymm12, %%ymm12, %%ymm12\n"
"vmaxps %%ymm12, %%ymm0, %%ymm0\n"
"and $0x1, %%eax\n"
"je 6f\n"
// relu6
"mov $0x40C00000, %%ecx\n"
"vmovd %%ecx, %%xmm14\n"
"vpermps %%ymm14, %%ymm12, %%ymm14\n"
"vminps %%ymm14, %%ymm0, %%ymm0\n"
"6:\n"
"vmovups %%ymm0, (%5)\n" // dst_0
:
: "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst) // 5
: "%ecx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
}
#ifdef ENABLE_DEBUG
void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
__m256 dst_data[12];
const float *src_sw[12];
__m256 weight_data[4];
for (int i = 0; i < 4; ++i) {
weight_data[i] = _mm256_set1_ps(0.0f);
}
for (int i = 0; i < row_block; ++i) {
if (bias != NULL) {
for (int j = 0; j < col_block; ++j) {
dst_data[i * col_block + j] = _mm256_loadu_ps(bias + j * 8);
}
} else {
for (int j = 0; j < col_block; ++j) {
dst_data[i * col_block + j] = _mm256_set1_ps(0.0f);
}
}
src_sw[i] = src + i * deep;
}
const float *weight_kernel = weight;
for (int ic = 0; ic < deep; ++ic) {
for (int j = 0; j < col_block; ++j) {
weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
}
for (int i = 0; i < row_block; ++i) {
for (int j = 0; j < col_block; ++j) {
dst_data[i * col_block + j] =
_mm256_fmadd_ps(_mm256_set1_ps(src_sw[i][ic]), weight_data[j], dst_data[i * col_block + j]);
}
}
weight_kernel += C8NUM * col_block;
} // ic loop
// add bias and relu
for (int i = 0; i < row_block; ++i) {
for (int j = 0; j < col_block; ++j) {
if (0x1 & act_flag) { // relu6
dst_data[i * col_block + j] = _mm256_min_ps(dst_data[i * col_block + j], _mm256_set1_ps(6.0f));
}
if (0x2 & act_flag) { // relu
dst_data[i * col_block + j] = _mm256_max_ps(dst_data[i * col_block + j], _mm256_set1_ps(0.0f));
}
_mm256_storeu_ps(dst + i * col_algin + j * C8NUM, dst_data[i * col_block + j]);
}
}
}
#endif
#endif

View File

@ -46,11 +46,13 @@ void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col);
#ifdef ENABLE_ARM64
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
@ -78,6 +80,22 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
#ifdef ENABLE_AVX
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth,
size_t row, size_t col, size_t stride, size_t write_mode);
typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep);
void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
int col_align);
void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep);
void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep);
void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep);
void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep);
#ifdef ENABLE_DEBUG
void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
size_t row_block, size_t col_block, size_t col_algin, size_t deep);
#endif
#endif
#endif
void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
@ -86,5 +104,4 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_FP32_MATMUL_H_

View File

@ -31,6 +31,7 @@
#define C8NUM 8
#define C12NUM 12
#define C16NUM 16
#define C32NUM 32
#define TILE_NUM 8
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

View File

@ -79,6 +79,10 @@ class MS_API Allocator {
///
/// \return Pointer of ready memory.
virtual void *Prepare(void *ptr) { return ptr; }
protected:
// memory alignment in bytes
size_t aligned_size_ = 32;
};
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_ALLOCATOR_H_

View File

@ -23,7 +23,7 @@ std::shared_ptr<Allocator> Allocator::Create() {
return std::shared_ptr<Allocator>(new (std::nothrow) DefaultAllocator());
}
DefaultAllocator::DefaultAllocator() = default;
DefaultAllocator::DefaultAllocator(size_t aligned_size) { aligned_size_ = aligned_size; }
DefaultAllocator::~DefaultAllocator() { Clear(); }
@ -69,7 +69,7 @@ void *DefaultAllocator::Malloc(size_t size) {
return membuf->buf;
}
std::unique_ptr<MemBuf> membuf(reinterpret_cast<MemBuf *>(malloc(sizeof(MemBuf) + size)));
std::unique_ptr<MemBuf> membuf(reinterpret_cast<MemBuf *>(malloc(sizeof(MemBuf) + size + aligned_size_)));
if (membuf == nullptr) {
MS_LOG(ERROR) << "malloc membuf return nullptr";
UnLock();
@ -78,7 +78,10 @@ void *DefaultAllocator::Malloc(size_t size) {
this->total_size_ += size;
membuf->ref_count_ = 0;
membuf->size = size;
membuf->buf = reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf);
auto aligned_bytes =
reinterpret_cast<size_t>((reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf))) % aligned_size_;
aligned_bytes = aligned_bytes == 0 ? 0 : aligned_size_ - aligned_bytes;
membuf->buf = reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf) + aligned_bytes;
auto bufPtr = membuf->buf;
allocatedList_[bufPtr] = membuf.release();
UnLock();

View File

@ -35,7 +35,7 @@ struct AllocatorContext {
class DefaultAllocator : public Allocator {
public:
DefaultAllocator();
explicit DefaultAllocator(size_t aligned_size = 32);
~DefaultAllocator() override;
void SetContext(const AllocatorContext &ctx);
void *Malloc(size_t size) override;

View File

@ -27,13 +27,13 @@ namespace mindspore::kernel {
int FullconnectionCPUKernel::Init() {
MatmulFp32BaseCPUKernel::InitParameter();
if (params_->a_const_ == true) {
if (params_->a_const_) {
auto a_shape = in_tensors_.at(0)->shape();
params_->row_ = a_shape[0];
params_->deep_ = a_shape[1];
}
if (params_->b_const_ == true) {
if (params_->b_const_) {
auto b_shape = in_tensors_.at(1)->shape();
params_->col_ = b_shape[0];
params_->deep_ = b_shape[1];

View File

@ -16,6 +16,9 @@
#include "src/runtime/kernel/arm/fp32/matmul_fp32_base.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "nnacl/fp32/pack_fp32.h"
using mindspore::lite::RET_NULL_PTR;
namespace mindspore::kernel {
int MatmulBaseFloatRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
@ -32,7 +35,6 @@ MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
FreeResizeBufA();
FreeResizeBufB();
FreeBiasBuf();
return;
}
void MatmulFp32BaseCPUKernel::InitParameter() {
@ -43,29 +45,25 @@ void MatmulFp32BaseCPUKernel::InitParameter() {
params_->a_const_ = false;
params_->b_const_ = false;
}
#ifdef ENABLE_AVX
row_tile_ = C6NUM;
col_tile_ = C16NUM;
#elif defined(ENABLE_ARM32)
row_tile_ = C12NUM;
col_tile_ = C4NUM;
#elif defined(ENABLE_SSE)
row_tile_ = C4NUM;
col_tile_ = C8NUM;
#else
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
return;
}
void MatmulFp32BaseCPUKernel::ResizeParameter() {
if (params_->row_ == 1) {
vec_matmul_ = true;
#ifdef ENABLE_AVX
// for vector matmul, col is aligned to C8NUM under AVX
col_tile_ = C8NUM;
#endif
row_tile_ = 1;
}
params_->row_align_ = vec_matmul_ ? 1 : UP_ROUND(params_->row_, row_tile_);
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
#ifdef ENABLE_AVX
// under AVX, col is always aligned to col_tile_
params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
#else
params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
return;
#endif
oc_res_ = params_->col_ % col_tile_;
}
int MatmulFp32BaseCPUKernel::InitBufferA() {
@ -102,7 +100,7 @@ int MatmulFp32BaseCPUKernel::InitBufferB() {
int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
lite::Tensor *bias_tensor = in_tensors_.at(2);
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), col_tile_);
if (!params_->b_const_) {
MS_LOG(WARNING) << "matmul do not support broadcast bias data";
} else {
@ -112,9 +110,9 @@ int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
return max_bias_data;
}
if (params_->b_transpose_) {
max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kBiasIndex], C16NUM);
max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kBiasIndex], col_tile_);
} else {
max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kWeightIndex], C16NUM);
max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kWeightIndex], col_tile_);
}
}
return max_bias_data;
@ -123,26 +121,22 @@ int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
int MatmulFp32BaseCPUKernel::InitBiasData() {
if (in_tensors_.size() == 3) {
auto bias_tensor = in_tensors_[2];
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), col_tile_);
// the malloc'd address needs to be aligned to 32 bytes
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR;
}
// whether to broadcast bias data
if (bias_tensor->ElementsNum() == 1) {
max_bias_data = CalBroadCastBiasDataElements();
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR;
}
float broadcast_data = (reinterpret_cast<float *>(bias_tensor->data_c()))[0];
// broadcast bias data
for (int i = 0; i < max_bias_data; ++i) {
bias_ptr_[i] = broadcast_data;
}
} else {
bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
if (bias_ptr_ == nullptr) {
MS_LOG(ERROR) << "malloc bias_ptr_ failed";
return RET_ERROR;
}
memset(bias_ptr_, 0, max_bias_data * sizeof(float));
memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float));
}
@ -159,38 +153,32 @@ int MatmulFp32BaseCPUKernel::InitMatrixA(const float *src_ptr) {
for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->row_;
float *dst = a_pack_ptr_ + i * params_->deep_ * params_->row_align_;
#ifdef ENABLE_AVX
if (params_->a_transpose_) {
RowMajor2Row6Major(src, dst, params_->deep_, params_->row_);
matrix_a_pack_fun_(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col6Major(src, dst, params_->row_, params_->deep_);
matrix_a_pack_fun_(src, dst, params_->row_, params_->deep_);
}
#elif defined(ENABLE_SSE)
if (params_->a_transpose_) {
RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col4Major(src, dst, params_->row_, params_->deep_);
}
#else
if (params_->a_transpose_) {
RowMajor2Row12Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col12Major(src, dst, params_->row_, params_->deep_);
}
#endif
}
return RET_OK;
}
int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
if (vec_matmul_) {
if (params_->b_transpose_) {
memcpy(b_pack_ptr_, src_ptr, params_->batch * params_->col_ * params_->deep_ * sizeof(float));
} else {
for (int i = 0; i < params_->batch; i++) {
const float *src_data = src_ptr + i * params_->deep_ * params_->col_;
float *dst_data = b_pack_ptr_ + i * params_->deep_ * params_->col_;
RowMajor2ColMajor(src_data, dst_data, params_->deep_, params_->col_);
for (int i = 0; i < params_->batch; i++) {
const float *src_data = src_ptr + i * params_->deep_ * params_->col_;
float *dst = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
if (params_->b_transpose_) {
#ifdef ENABLE_AVX
RowMajor2Col32Major(src_data, dst, params_->deep_, params_->col_);
#else
memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
#endif
} else {
#ifdef ENABLE_AVX
RowMajor2Row32Major(src_data, dst, params_->col_, params_->deep_);
#else
RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
#endif
}
}
return RET_OK;
@ -199,25 +187,11 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
for (int i = 0; i < params_->batch; i++) {
const float *src = src_ptr + i * params_->deep_ * params_->col_;
float *dst = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
#ifdef ENABLE_AVX
if (params_->b_transpose_) {
RowMajor2Col16Major(src, dst, params_->col_, params_->deep_);
matrix_b_pack_fun_(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row16Major(src, dst, params_->deep_, params_->col_);
matrix_b_pack_fun_(src, dst, params_->deep_, params_->col_);
}
#elif defined(ENABLE_ARM32)
if (params_->b_transpose_) {
RowMajor2Col4Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row4Major(src, dst, params_->deep_, params_->col_);
}
#else
if (params_->b_transpose_) {
RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
} else {
RowMajor2Row8Major(src, dst, params_->deep_, params_->col_);
}
#endif
}
return RET_OK;
}
@ -227,7 +201,6 @@ void MatmulFp32BaseCPUKernel::FreeBiasBuf() {
free(bias_ptr_);
bias_ptr_ = nullptr;
}
return;
}
void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
@ -239,7 +212,6 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
} else {
a_pack_ptr_ = nullptr;
}
return;
}
void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
@ -251,22 +223,34 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
} else {
b_pack_ptr_ = nullptr;
}
return;
}
int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
int current_stride_oc = thread_stride_ * col_tile_;
int current_rest_oc = params_->col_ - task_id * thread_stride_ * col_tile_;
int cur_oc = MSMIN(current_stride_oc, current_rest_oc);
int current_start_oc = task_id * thread_stride_ * col_tile_;
int current_rest_oc = 0;
#if defined(ENABLE_AVX)
if (vec_matmul_) {
current_rest_oc = params_->col_align_ - current_start_oc;
} else {
current_rest_oc = params_->col_ - current_start_oc;
}
#else
current_rest_oc = params_->col_ - current_start_oc;
#endif
int cur_oc = MSMIN(thread_stride_ * col_tile_, current_rest_oc);
if (cur_oc <= 0) {
return RET_OK;
}
auto b = batch_b_ptr_ + task_id * thread_stride_ * col_tile_ * params_->deep_;
auto c = batch_c_ptr_ + task_id * thread_stride_ * col_tile_;
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * col_tile_;
auto b = batch_b_ptr_ + current_start_oc * params_->deep_;
auto c = batch_c_ptr_ + current_start_oc;
auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + current_start_oc;
if (vec_matmul_) {
#ifdef ENABLE_AVX
MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
#else
MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
#endif
} else {
MatMulOpt(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, params_->row_, cur_oc, params_->col_,
OutType_Nhwc);
@ -275,81 +259,141 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
}
int MatmulFp32BaseCPUKernel::Init() {
ResizeParameter();
#ifdef ENABLE_AVX
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row6Major : RowMajor2Col6Major;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col16Major : RowMajor2Row16Major;
row_tile_ = C6NUM;
col_tile_ = C16NUM;
#elif defined(ENABLE_ARM32)
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row12Major : RowMajor2Col12Major;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col4Major : RowMajor2Row4Major;
row_tile_ = C12NUM;
col_tile_ = C4NUM;
#elif defined(ENABLE_SSE)
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row4Major : RowMajor2Col4Major;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8Major : RowMajor2Row8Major;
row_tile_ = C4NUM;
col_tile_ = C8NUM;
#else
matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row12Major : RowMajor2Col12Major;
matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8Major : RowMajor2Row8Major;
row_tile_ = C12NUM;
col_tile_ = C8NUM;
#endif
params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
matrix_a_pack_size_ = params_->batch * params_->row_align_ * params_->deep_;
matrix_b_pack_size_ = params_->batch * params_->col_align_ * params_->deep_;
if ((matrix_a_pack_size_ + matrix_b_pack_size_) < 0) {
if (matrix_a_pack_size_ < 0) {
MS_LOG(ERROR) << "Matrix pack size is negative "
<< "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size" << matrix_b_pack_size_;
<< "matrix_a_pack_size=" << matrix_a_pack_size_;
return RET_ERROR;
}
auto ret = InitBiasData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitBiasData failed";
return ret;
}
if (params_->a_const_ == true) {
if (params_->a_const_) {
if (RET_OK != InitBufferA()) {
return RET_ERROR;
}
InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()));
ret = InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()));
if (ret != RET_OK) {
MS_LOG(ERROR) << "InitMatrixA failed!";
return ret;
}
}
if (params_->b_const_ == true) {
/* copy origin b data, pack in resize
* pack after a infershape done */
if (params_->b_const_) {
// only copy the weight data here; it is packed later in Resize or Run
auto b_tensor = in_tensors_[1];
src_b_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_ * params_->deep_ * sizeof(float)));
src_b_ = reinterpret_cast<float *>(malloc(params_->batch * params_->deep_ * params_->col_ * sizeof(float)));
if (src_b_ == nullptr) {
MS_LOG(ERROR) << "Matmul fp16 malloc src_b_ failed";
MS_LOG(ERROR) << "matmul fp16 src_b_ is failed!";
return RET_ERROR;
}
memcpy(src_b_, b_tensor->data_c(), params_->batch * params_->col_ * params_->deep_ * sizeof(float));
memcpy(src_b_, b_tensor->data_c(), params_->batch * params_->deep_ * params_->col_ * sizeof(float));
}
return RET_OK;
}
void MatmulFp32BaseCPUKernel::FreeBuffSrcB() {
if (src_b_ != nullptr) {
free(src_b_);
src_b_ = nullptr;
}
}
int MatmulFp32BaseCPUKernel::ReSize() {
ResizeParameter();
matrix_a_pack_size_ = params_->batch * params_->row_align_ * params_->deep_;
matrix_b_pack_size_ = params_->batch * params_->col_align_ * params_->deep_;
if ((matrix_a_pack_size_ + matrix_b_pack_size_) < 0) {
if (matrix_a_pack_size_ < 0 || matrix_b_pack_size_ < 0) {
MS_LOG(ERROR) << "Matrix pack size is negative "
<< "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size" << matrix_b_pack_size_;
<< "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size=" << matrix_b_pack_size_;
return RET_ERROR;
}
if (op_parameter_->is_train_session_) {
set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * sizeof(float));
}
if (params_->b_const_ == true && src_b_ != nullptr) {
if (RET_OK != InitBufferB()) {
if (params_->b_const_ && src_b_ != nullptr) {
if (InitBufferB() != RET_OK) {
FreeBuffSrcB();
return RET_ERROR;
}
InitMatrixB(src_b_);
free(src_b_);
src_b_ = nullptr;
if (InitMatrixB(src_b_) != RET_OK) {
FreeBuffSrcB();
MS_LOG(ERROR) << "InitMatrixB failed!";
return RET_ERROR;
}
FreeBuffSrcB();
}
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
#if defined(ENABLE_AVX)
if (vec_matmul_) {
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
} else {
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
}
#else
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
#endif
return RET_OK;
}
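Worked example for the AVX vec-matmul branch above, with hypothetical sizes not taken from the diff: for col_align_ = 200, col_tile_ = 8 and thread_count_ = 4, the stride is UP_DIV(UP_DIV(200, 8 * C4NUM), 4) * C4NUM = UP_DIV(7, 4) * 4 = 8 blocks of col_tile_, i.e. 64 columns per thread, so each thread except the last one works on whole 32-column groups, matching the granularity MatVecMulAvxFp32 steps through.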
int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
auto out_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
MS_ASSERT(out_data != nullptr);
#ifdef ENABLE_AVX
if (oc_res_ != 0 && vec_matmul_) { // vec matmul needs a temporary, col-aligned dst buffer
int out_channel = params_->col_;
int oc_block_num = UP_DIV(out_channel, col_tile_);
MS_ASSERT(context_->allocator != nullptr);
output_data_ = reinterpret_cast<float *>(
context_->allocator->Malloc(params_->batch * params_->row_ * oc_block_num * col_tile_ * sizeof(float)));
if (output_data_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp output data failed.";
return RET_NULL_PTR;
}
} else { // output is already usable, write into it directly
output_data_ = out_data;
}
#else
output_data_ = out_data;
#endif
return RET_OK;
}
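Buffer sizing note for the branch above (not part of the diff): the temporary output holds batch * row_ * UP_DIV(col_, col_tile_) * col_tile_ floats, i.e. the column dimension rounded up to col_tile_, and Run() later calls PackNHWCXToNHWCFp32 to strip the padded columns back down to col_ before freeing the buffer. With hypothetical sizes:
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
/* batch = 1, row_ = 1 (vec matmul), col_ = 100, col_tile_ = 8 */
int oc_block_num = UP_DIV(100, 8);        /* 13 blocks                       */
int tmp_floats = 1 * 1 * oc_block_num * 8; /* 104 padded outputs             */
/* Run() copies the first 100 of them back via PackNHWCXToNHWCFp32. */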
int MatmulFp32BaseCPUKernel::Run() {
auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
auto c_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
if (params_->a_const_ == false) {
if (!params_->a_const_) {
auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
if (RET_OK != InitBufferA()) {
return RET_ERROR;
}
InitMatrixA(a_ptr);
}
if (params_->b_const_ == false) {
if (!params_->b_const_) {
auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
if (RET_OK != InitBufferB()) {
FreeResizeBufA();
return RET_ERROR;
@ -357,31 +401,45 @@ int MatmulFp32BaseCPUKernel::Run() {
InitMatrixB(b_ptr);
}
auto ret = InitTmpOutBuffer();
if (ret != RET_OK) {
FreeResizeBufA();
FreeResizeBufB();
MS_LOG(ERROR) << "InitTmpOutBuffer error!";
return ret;
}
for (int i = 0; i < params_->batch; ++i) {
batch_a_ptr_ = a_pack_ptr_ + i * params_->row_align_ * params_->deep_;
batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
if (vec_matmul_) {
batch_a_ptr_ = a_pack_ptr_ + i * params_->deep_;
batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_;
batch_c_ptr_ = c_ptr + i * params_->row_ * params_->col_;
batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_align_;
} else {
batch_a_ptr_ = a_pack_ptr_ + i * params_->row_align_ * params_->deep_;
batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
batch_c_ptr_ = c_ptr + i * params_->row_ * params_->col_;
// no alignment needed here
batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_;
}
auto ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(MatmulBaseFloatRun, this, thread_count_);
ret = static_cast<const lite::InnerContext *>(this->context_)
->thread_pool_->ParallelLaunch(MatmulBaseFloatRun, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulBaseFloatRun failed";
return ret;
}
}
if (params_->a_const_ == false) {
#ifdef ENABLE_AVX
if (oc_res_ != 0 && vec_matmul_) {
auto out_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
context_->allocator->Free(output_data_);
output_data_ = nullptr;
}
#endif
if (!params_->a_const_) {
FreeResizeBufA();
}
if (params_->b_const_ == false) {
if (!params_->b_const_) {
FreeResizeBufB();
}
return RET_OK;
return ret;
}
} // namespace mindspore::kernel

View File

@ -27,6 +27,7 @@ using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
namespace mindspore::kernel {
using MatrixPackFun = void (*)(const float *src_ptr, float *dst_ptr, int row, int col);
class MatmulFp32BaseCPUKernel : public InnerKernel {
public:
MatmulFp32BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@ -35,7 +36,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
params_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
vec_matmul_ = false;
}
~MatmulFp32BaseCPUKernel();
~MatmulFp32BaseCPUKernel() override;
int Init() override;
int ReSize() override;
int Run() override;
@ -56,7 +57,9 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
void ResizeParameter();
void FreeResizeBufA();
void FreeResizeBufB();
void FreeBuffSrcB();
int CalBroadCastBiasDataElements();
int InitTmpOutBuffer();
protected:
MatMulParameter *params_ = nullptr;
@ -66,16 +69,20 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
private:
int col_tile_ = 0;
int row_tile_ = 0;
int oc_res_ = 0;
int thread_stride_ = 0;
int thread_count_ = 0;
bool vec_matmul_ = false;
float *src_b_ = nullptr;
float *bias_ptr_ = nullptr;
float *batch_a_ptr_ = nullptr;
float *batch_b_ptr_ = nullptr;
float *batch_c_ptr_ = nullptr;
float *output_data_ = nullptr;
int matrix_a_pack_size_ = -1;
int matrix_b_pack_size_ = -1;
float *src_b_ = nullptr;
MatrixPackFun matrix_a_pack_fun_ = nullptr;
MatrixPackFun matrix_b_pack_fun_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_