matvecmul

parent f8f54091c3
commit 872e3acc9d
@@ -25,7 +25,7 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
+typedef void (*Row2ColMajorFuncPtr)(const float *src_ptr, float *dst_ptr, int row, int col);
 #ifdef ENABLE_ARM64
 typedef void (*MatmulFloatOptFuncPtr)(const float *a, const float *b, float *c, const float *bias, int act_type,
                                       int depth, int row, int col, size_t stride, size_t write_mode);

@@ -15,7 +15,9 @@
  */
 
 #include "nnacl/fp32/matmul_fp32.h"
-
+#ifdef ENABLE_SSE
+#include <x86intrin.h>
+#endif
 void RowMajor2ColMajor(const float *src_ptr, float *dst_ptr, int row, int col) {
   for (int r = 0; r < row; ++r) {
     for (int c = 0; c < col; ++c) {
@@ -114,6 +116,20 @@ void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col)
   return;
 }
+
+void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
+  // Not exactly aligned to 32; the tail falls back to 24, 16 or 8 when 32 is not met.
+  int row_block_num = UP_DIV(row, C8NUM);
+  int row_block = C4NUM;
+  for (int i = 0; i < row_block_num; i += row_block) {
+    row_block = MSMIN(C4NUM, row_block_num - i);  // max_tile = 4
+    int row_remainder = MSMIN(row_block * C8NUM, row - i * C8NUM);
+    for (int oc = 0; oc < col; ++oc) {
+      memcpy(dst_ptr, src_ptr + oc * row + i * C8NUM, row_remainder * sizeof(float));
+      dst_ptr += row_block * C8NUM;
+    }
+  }
+}
 
 #ifdef ENABLE_ARM64
 void RowMajor2Col12Major_arm64(const float *src_c, float *dst_c, size_t col) {
   size_t stride = col * sizeof(float);
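
[Editor's note] The blocking above leans on nnacl's rounding helpers, which this diff never shows. A minimal standalone sketch, assuming UP_DIV and UP_ROUND carry their usual nnacl meanings (ceiling division and rounding up to a multiple):

    #include <stdio.h>

    /* Assumed definitions, matching how nnacl's op_base.h typically defines them. */
    #define UP_DIV(x, y) (((x) + (y) - 1) / (y))
    #define UP_ROUND(x, y) (UP_DIV((x), (y)) * (y))
    #define C8NUM 8

    int main(void) {
      /* row = 70 packs into UP_DIV(70, 8) = 9 blocks of 8 floats. */
      printf("UP_DIV(70, C8NUM)   = %d\n", UP_DIV(70, C8NUM));   /* 9 */
      printf("UP_ROUND(70, C8NUM) = %d\n", UP_ROUND(70, C8NUM)); /* 72 */
      /* RowMajor2Row32Major then walks those 9 blocks in groups of at most
       * 4 (max_tile), so the per-group stride is 32, 32 and finally 8 floats. */
      return 0;
    }
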
@@ -237,7 +253,7 @@ void RowMajor2Col12Major_arm32(const float *src_c, float *dst_c, size_t col) {
   return;
 }
 #endif
-void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
+void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col) {
   const float *src_r = src_ptr;
   float *dst_r = dst_ptr;
   size_t ri = 0;
@@ -508,7 +524,7 @@ void RowMajor2Col8Major_arm32(const float *src_c, float *dst_c, size_t col) {
 }
 #endif
 #endif
-void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
+void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col) {
   size_t row8 = row / C8NUM * C8NUM;
 #ifdef ENABLE_ARM64
   size_t col_skip = col / C8NUM * C8NUM;
@@ -591,7 +607,7 @@ void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
   return;
 }
 
-void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
+void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col) {
   size_t row16 = row / C16NUM * C16NUM;
   size_t col_skip = col / C4NUM * C4NUM;
   int skip_size = C4NUM;
@@ -638,7 +654,25 @@ void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_
   return;
 }
 
-void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
+void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col) {
+  // Not exactly aligned to 32; the tail falls back to 24, 16 or 8 when 32 is not met.
+  int col_block_num = UP_DIV(col, C8NUM);
+  int col_block = C4NUM;
+  for (int i = 0; i < col_block_num; i += col_block) {
+    col_block = MSMIN(C4NUM, col_block_num - i);  // max_tile = 4
+    int index = i * row * C8NUM;
+    int col_remainder = MSMIN(C8NUM * col_block, col - i * C8NUM);
+    for (int ir = 0; ir < row; ++ir) {
+      for (int oc = 0; oc < col_remainder; ++oc) {
+        int oc_index = oc * row + ir + index;
+        dst_ptr[oc] = src_ptr[oc_index];
+      }
+      dst_ptr += col_block * C8NUM;
+    }
+  }
+}
+
+void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col) {
   size_t totalRow = UP_ROUND(row, C6NUM);
   size_t row6 = row / C6NUM * C6NUM;
   size_t col8 = col / C8NUM * C8NUM;
@@ -737,7 +771,7 @@ void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t
   return;
 }
 
-void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col) {
+void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col) {
   size_t total_row = UP_ROUND(row, C4NUM);
   size_t row4 = row / C4NUM * C4NUM;
   size_t col4 = col / C4NUM * C4NUM;
@@ -845,7 +879,6 @@ void MatVecMulFp32(const float *a, const float *b, float *c, const float *bias,
     if (act_type == ActType_Relu || act_type == ActType_Relu6) value = MSMAX(0.0f, value);
     c[ci] = value;
   }
-  return;
 }
 #endif
 void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
@@ -908,7 +941,6 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
       }
     }
   }
-  return;
 }
 
 void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
@@ -943,3 +975,456 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
   MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
 #endif
 }
+
+#ifdef ENABLE_AVX
+void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
+                      int col_align) {
+  // process 32 output channels at a time
+  int col_block = C32NUM;
+  int act_flag = 0;
+  if (act_type == ActType_Relu6) {
+    act_flag += 1;
+  }
+  if (act_type == ActType_Relu || act_type == ActType_Relu6) {
+    act_flag += 2;
+  }
+  MatVecMulKernel kernel[4] = {MatVecMul1x8Kernel, MatVecMul1x16Kernel, MatVecMul1x24Kernel, MatVecMul1x32Kernel};
+  const float *bias_data = bias;
+  for (int col_index = 0; col_index < cur_col; col_index += col_block) {
+    col_block = cur_col - col_index < col_block ? cur_col - col_index : col_block;
+    kernel[(col_block >> 3) - 1](c + col_index, a, b + col_index * depth, bias_data, act_flag, 1, col_block >> 3,
+                                 col_align, depth);
+    if (bias_data != NULL) {
+      bias_data += col_block;
+    }
+  }
+}
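
[Editor's note] The kernel table above is indexed with (col_block >> 3) - 1. Because col_align is padded to a multiple of C8NUM, every residual col_block here is 8, 16, 24 or 32, which maps exactly onto the four kernels. A hypothetical helper illustrating the rule (not part of the commit):

    /* 8 -> 0 (1x8), 16 -> 1 (1x16), 24 -> 2 (1x24), 32 -> 3 (1x32) */
    static int MatVecKernelIndex(int col_block) {
      return (col_block >> 3) - 1; /* assumes col_block is a non-zero multiple of 8 */
    }
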
+
+void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
+  asm volatile(
+    "cmpq $0, %2\n"
+    "je 0f\n"
+    "vmovups (%2), %%ymm0\n"
+    "vmovups 0x20(%2), %%ymm1\n"
+    "vmovups 0x40(%2), %%ymm2\n"
+    "vmovups 0x60(%2), %%ymm3\n"
+    "jmp 1f\n"
+    "0:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
+    "vxorps %%ymm3, %%ymm3, %%ymm3\n"
+    "1:\n"  // deep_c8
+    "movq %3, %%rcx\n"
+    "shr $3, %%ecx\n"
+    "je 3f\n"
+    "2:\n"
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
+
+    "vbroadcastss 4(%0), %%ymm4\n"
+    "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 192(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 224(%1), %%ymm4, %%ymm3\n"
+
+    "vbroadcastss 8(%0), %%ymm4\n"
+    "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 320(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 352(%1), %%ymm4, %%ymm3\n"
+
+    "vbroadcastss 12(%0), %%ymm4\n"
+    "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 480(%1), %%ymm4, %%ymm3\n"
+
+    "vbroadcastss 16(%0), %%ymm4\n"
+    "vfmadd231ps 512(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 544(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 576(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 608(%1), %%ymm4, %%ymm3\n"
+
+    "vbroadcastss 20(%0), %%ymm4\n"
+    "vfmadd231ps 640(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 672(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 704(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 736(%1), %%ymm4, %%ymm3\n"
+
+    "vbroadcastss 24(%0), %%ymm4\n"
+    "vfmadd231ps 768(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 800(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 832(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 864(%1), %%ymm4, %%ymm3\n"
+
+    "vbroadcastss 28(%0), %%ymm4\n"
+    "vfmadd231ps 896(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 928(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 960(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 992(%1), %%ymm4, %%ymm3\n"
+    "addq $1024, %1\n"
+    "addq $32, %0\n"
+    "dec %%ecx\n"
+    "jg 2b\n"
+
+    "3:\n"
+    "and $7, %3\n"  // deep_remainder
+    "je 5f\n"
+    "4:\n"
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
+    "vfmadd231ps 0x60(%1), %%ymm4, %%ymm3\n"
+    "addq $128, %1\n"
+    "addq $4, %0\n"
+    "dec %3\n"
+    "jg 4b\n"
+
+    "5:\n"
+    "and $0x3, %%eax\n"  // act_type
+    "je 6f\n"
+    // Relu
+    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
+    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
+    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
+    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
+    "vmaxps %%ymm12, %%ymm3, %%ymm3\n"
+    "and $0x1, %%eax\n"
+    "je 6f\n"
+    // relu6
+    "mov $0x40C00000, %%ecx\n"
+    "vmovd %%ecx, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
+    "vminps %%ymm14, %%ymm0, %%ymm0\n"
+    "vminps %%ymm14, %%ymm1, %%ymm1\n"
+    "vminps %%ymm14, %%ymm2, %%ymm2\n"
+    "vminps %%ymm14, %%ymm3, %%ymm3\n"
+    "6:\n"
+    "vmovups %%ymm0, (%5)\n"  // dst_0
+    "vmovups %%ymm1, 0x20(%5)\n"
+    "vmovups %%ymm2, 0x40(%5)\n"
+    "vmovups %%ymm3, 0x60(%5)\n"
+    :
+    : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
+    : "%rcx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm12", "%ymm4", "%ymm14");
+}
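
[Editor's note] In the activation epilogue above, act_flag bit 1 requests Relu and bit 0 additionally requests the Relu6 clamp, matching the act_flag += 1 / += 2 encoding in MatVecMulAvxFp32. The immediate 0x40C00000 is the IEEE-754 single-precision bit pattern of 6.0f, broadcast into %ymm14 before vminps. A tiny standalone check (editorial, not from the commit):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    int main(void) {
      uint32_t bits = 0x40C00000u;
      float f;
      memcpy(&f, &bits, sizeof(f)); /* bit-cast without aliasing issues */
      printf("0x40C00000 as float = %f\n", f); /* prints 6.000000 */
      return 0;
    }
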
+
+void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
+  asm volatile(
+    "cmpq $0, %2\n"
+    "je 0f\n"
+    "vmovups (%2), %%ymm0\n"
+    "vmovups 0x20(%2), %%ymm1\n"
+    "vmovups 0x40(%2), %%ymm2\n"
+    "jmp 1f\n"
+    "0:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "vxorps %%ymm2, %%ymm2, %%ymm2\n"
+
+    "1:\n"  // deep
+    "movq %3, %%rcx\n"
+    "shr $3, %%ecx\n"
+    "je 3f\n"
+    "2:\n"
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
+
+    "vbroadcastss 4(%0), %%ymm4\n"
+    "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 128(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 160(%1), %%ymm4, %%ymm2\n"
+
+    "vbroadcastss 8(%0), %%ymm4\n"
+    "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 256(%1), %%ymm4, %%ymm2\n"
+
+    "vbroadcastss 12(%0), %%ymm4\n"
+    "vfmadd231ps 288(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 320(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 352(%1), %%ymm4, %%ymm2\n"
+
+    "vbroadcastss 16(%0), %%ymm4\n"
+    "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 448(%1), %%ymm4, %%ymm2\n"
+
+    "vbroadcastss 20(%0), %%ymm4\n"
+    "vfmadd231ps 480(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 512(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 544(%1), %%ymm4, %%ymm2\n"
+
+    "vbroadcastss 24(%0), %%ymm4\n"
+    "vfmadd231ps 576(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 608(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 640(%1), %%ymm4, %%ymm2\n"
+
+    "vbroadcastss 28(%0), %%ymm4\n"
+    "vfmadd231ps 672(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 704(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 736(%1), %%ymm4, %%ymm2\n"
+    "addq $768, %1\n"
+    "addq $32, %0\n"
+    "dec %%ecx\n"
+    "jg 2b\n"
+
+    "3:\n"
+    "and $7, %3\n"  // deep_remainder
+    "je 5f\n"
+    "4:\n"
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
+    "vfmadd231ps 0x40(%1), %%ymm4, %%ymm2\n"
+    "addq $96, %1\n"
+    "addq $4, %0\n"
+    "dec %3\n"
+    "jg 4b\n"
+
+    "5:\n"
+    "and $0x3, %%eax\n"  // act_type
+    "je 6f\n"
+    // Relu
+    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
+    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
+    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
+    "vmaxps %%ymm12, %%ymm2, %%ymm2\n"
+
+    "and $0x1, %%eax\n"
+    "je 6f\n"
+    // relu6
+    "mov $0x40C00000, %%ecx\n"
+    "vmovd %%ecx, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
+    "vminps %%ymm14, %%ymm0, %%ymm0\n"
+    "vminps %%ymm14, %%ymm1, %%ymm1\n"
+    "vminps %%ymm14, %%ymm2, %%ymm2\n"
+
+    "6:\n"
+    "vmovups %%ymm0, (%5)\n"  // dst_0
+    "vmovups %%ymm1, 0x20(%5)\n"
+    "vmovups %%ymm2, 0x40(%5)\n"
+
+    :
+    : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
+    : "%rcx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm2", "%ymm12", "%ymm4", "%ymm14");
+}
+
+void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
+  asm volatile(
+    "cmpq $0, %2\n"
+    "je 0f\n"
+    "vmovups (%2), %%ymm0\n"
+    "vmovups 0x20(%2), %%ymm1\n"
+    "jmp 1f\n"
+    "0:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "vxorps %%ymm1, %%ymm1, %%ymm1\n"
+    "1:\n"
+    "movq %3, %%rcx\n"
+    "shr $3, %%ecx\n"
+    "je 3f\n"
+    "2:\n"  // deep_c8
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
+
+    "vbroadcastss 4(%0), %%ymm4\n"
+    "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 96(%1), %%ymm4, %%ymm1\n"
+
+    "vbroadcastss 8(%0), %%ymm4\n"
+    "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 160(%1), %%ymm4, %%ymm1\n"
+
+    "vbroadcastss 12(%0), %%ymm4\n"
+    "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 224(%1), %%ymm4, %%ymm1\n"
+
+    "vbroadcastss 16(%0), %%ymm4\n"
+    "vfmadd231ps 256(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 288(%1), %%ymm4, %%ymm1\n"
+
+    "vbroadcastss 20(%0), %%ymm4\n"
+    "vfmadd231ps 320(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 352(%1), %%ymm4, %%ymm1\n"
+
+    "vbroadcastss 24(%0), %%ymm4\n"
+    "vfmadd231ps 384(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 416(%1), %%ymm4, %%ymm1\n"
+
+    "vbroadcastss 28(%0), %%ymm4\n"
+    "vfmadd231ps 448(%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 480(%1), %%ymm4, %%ymm1\n"
+    "addq $512, %1\n"
+    "addq $32, %0\n"
+    "dec %%ecx\n"
+    "jg 2b\n"
+
+    "3:\n"
+    "and $7, %3\n"
+    "je 5f\n"
+    "4:\n"
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "vfmadd231ps 0x20(%1), %%ymm4, %%ymm1\n"
+    "addq $64, %1\n"
+    "addq $4, %0\n"
+    "dec %3\n"
+    "jg 4b\n"
+
+    "5:\n"
+    "and $0x3, %%eax\n"  // act_type
+    "je 6f\n"
+    // Relu
+    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
+    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
+    "vmaxps %%ymm12, %%ymm1, %%ymm1\n"
+
+    "and $0x1, %%eax\n"
+    "je 6f\n"
+    // relu6
+    "mov $0x40C00000, %%ecx\n"
+    "vmovd %%ecx, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
+    "vminps %%ymm14, %%ymm0, %%ymm0\n"
+    "vminps %%ymm14, %%ymm1, %%ymm1\n"
+
+    "6:\n"
+    "vmovups %%ymm0, (%5)\n"  // dst_0
+    "vmovups %%ymm1, 0x20(%5)\n"
+
+    :
+    : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
+    : "%ecx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
+}
+
+void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                        size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
+  asm volatile(
+    "cmpq $0, %2\n"
+    "je 0f\n"
+    "vmovups (%2), %%ymm0\n"
+    "jmp 1f\n"
+    "0:\n"
+    "vxorps %%ymm0, %%ymm0, %%ymm0\n"
+    "1:\n"
+    "movq %3, %%rcx\n"
+    "shr $3, %%ecx\n"
+    "je 3f\n"
+    "2:\n"  // deep_c8
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "vbroadcastss 4(%0), %%ymm4\n"
+    "vfmadd231ps 32(%1), %%ymm4, %%ymm0\n"
+    "vbroadcastss 8(%0), %%ymm4\n"
+    "vfmadd231ps 64(%1), %%ymm4, %%ymm0\n"
+    "vbroadcastss 12(%0), %%ymm4\n"
+    "vfmadd231ps 96(%1), %%ymm4, %%ymm0\n"
+    "vbroadcastss 16(%0), %%ymm4\n"
+    "vfmadd231ps 128(%1), %%ymm4, %%ymm0\n"
+    "vbroadcastss 20(%0), %%ymm4\n"
+    "vfmadd231ps 160(%1), %%ymm4, %%ymm0\n"
+    "vbroadcastss 24(%0), %%ymm4\n"
+    "vfmadd231ps 192(%1), %%ymm4, %%ymm0\n"
+    "vbroadcastss 28(%0), %%ymm4\n"
+    "vfmadd231ps 224(%1), %%ymm4, %%ymm0\n"
+    "addq $256, %1\n"
+    "addq $32, %0\n"
+    "dec %%ecx\n"
+    "jg 2b\n"
+
+    "3:\n"
+    "and $7, %3\n"
+    "je 5f\n"
+    "4:\n"
+    "vbroadcastss (%0), %%ymm4\n"
+    "vfmadd231ps (%1), %%ymm4, %%ymm0\n"
+    "addq $32, %1\n"
+    "addq $4, %0\n"
+    "dec %3\n"
+    "jg 4b\n"
+
+    "5:\n"
+    "and $0x3, %%eax\n"  // act_type
+    "je 6f\n"
+    // Relu
+    "vxorps %%ymm12, %%ymm12, %%ymm12\n"
+    "vmaxps %%ymm12, %%ymm0, %%ymm0\n"
+
+    "and $0x1, %%eax\n"
+    "je 6f\n"
+    // relu6
+    "mov $0x40C00000, %%ecx\n"
+    "vmovd %%ecx, %%xmm14\n"
+    "vpermps %%ymm14, %%ymm12, %%ymm14\n"
+    "vminps %%ymm14, %%ymm0, %%ymm0\n"
+
+    "6:\n"
+    "vmovups %%ymm0, (%5)\n"  // dst_0
+
+    :
+    : "r"(src), "r"(weight), "r"(bias), "r"(deep), "a"(act_flag), "r"(dst)  // 5
+    : "%ecx", "%rsi", "%r12", "%ymm0", "%ymm1", "%ymm12", "%ymm4", "%ymm14");
+}
+
+#ifdef ENABLE_DEBUG
+void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                            size_t row_block, size_t col_block, size_t col_algin, size_t deep) {
+  __m256 dst_data[12];
+  const float *src_sw[12];
+  __m256 weight_data[4];
+  for (int i = 0; i < 4; ++i) {
+    weight_data[i] = _mm256_set1_ps(0.0f);
+  }
+  for (int i = 0; i < row_block; ++i) {
+    if (bias != NULL) {
+      for (int j = 0; j < col_block; ++j) {
+        dst_data[i * col_block + j] = _mm256_loadu_ps(bias + j * 8);
+      }
+    } else {
+      for (int j = 0; j < col_block; ++j) {
+        dst_data[i * col_block + j] = _mm256_set1_ps(0.0f);
+      }
+    }
+    src_sw[i] = src + i * deep;
+  }
+  const float *weight_kernel = weight;
+  for (int ic = 0; ic < deep; ++ic) {
+    for (int j = 0; j < col_block; ++j) {
+      weight_data[j] = _mm256_loadu_ps(weight_kernel + j * C8NUM);
+    }
+    for (int i = 0; i < row_block; ++i) {
+      for (int j = 0; j < col_block; ++j) {
+        dst_data[i * col_block + j] =
+          _mm256_fmadd_ps(_mm256_set1_ps(src_sw[i][ic]), weight_data[j], dst_data[i * col_block + j]);
+      }
+    }
+    weight_kernel += C8NUM * col_block;
+  }  // ic loop
+  // add bias and relu
+  for (int i = 0; i < row_block; ++i) {
+    for (int j = 0; j < col_block; ++j) {
+      if (0x1 & act_flag) {  // relu6
+        dst_data[i * col_block + j] = _mm256_min_ps(dst_data[i * col_block + j], _mm256_set1_ps(6.0f));
+      }
+      if (0x2 & act_flag) {  // relu
+        dst_data[i * col_block + j] = _mm256_max_ps(dst_data[i * col_block + j], _mm256_set1_ps(0.0f));
+      }
+      _mm256_storeu_ps(dst + i * col_algin + j * C8NUM, dst_data[i * col_block + j]);
+    }
+  }
+}
+#endif
+#endif

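[Editor's note] All four hand-written kernels above compute the same thing: one row vector of length deep multiplied against packed weight columns, plus bias and an optional Relu/Relu6. A scalar reference of that contract, written against a plain column-major weight layout (the AVX kernels consume the 32/24/16/8-wide packed layout instead; the activation codes below are assumptions, nnacl defines the real ActType enum elsewhere):

    #include <stddef.h>

    enum { kActRelu = 1, kActRelu6 = 3 }; /* assumed enum values */

    void MatVecMulRef(const float *a, const float *b, float *c, const float *bias,
                      int act_type, int depth, int col) {
      for (int ci = 0; ci < col; ci++) {
        float value = 0.0f;
        for (int di = 0; di < depth; di++) {
          value += a[di] * b[ci * depth + di]; /* one packed column per output channel */
        }
        if (bias != NULL) value += bias[ci];
        if (act_type == kActRelu6 && value > 6.0f) value = 6.0f;
        if ((act_type == kActRelu || act_type == kActRelu6) && value < 0.0f) value = 0.0f;
        c[ci] = value;
      }
    }
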
@@ -46,11 +46,13 @@ void RowMajor2Row6Major(const float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row8Major(const float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row12Major(const float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Row16Major(const float *src_ptr, float *dst_ptr, int row, int col);
-void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
-void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
-void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
-void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
-void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, size_t row, size_t col);
+void RowMajor2Row32Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Col4Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Col6Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Col8Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Col12Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Col16Major(const float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Col32Major(const float *src_ptr, float *dst_ptr, int row, int col);
 
 #ifdef ENABLE_ARM64
 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
@@ -78,6 +80,22 @@ void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *
 #ifdef ENABLE_AVX
 void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth,
                        size_t row, size_t col, size_t stride, size_t write_mode);
+typedef void (*MatVecMulKernel)(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                                size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMulAvxFp32(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int cur_col,
+                      int col_align);
+void MatVecMul1x32Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMul1x24Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMul1x16Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                         size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+void MatVecMul1x8Kernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                        size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+#ifdef ENABLE_DEBUG
+void MatVecMulRowxColKernel(float *dst, const float *src, const float *weight, const float *bias, size_t act_flag,
+                            size_t row_block, size_t col_block, size_t col_algin, size_t deep);
+#endif
 #endif
 #endif
 void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
@@ -86,5 +104,4 @@ void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, A
 #ifdef __cplusplus
 }
 #endif
-
 #endif  // MINDSPORE_NNACL_FP32_MATMUL_H_

@@ -31,6 +31,7 @@
 #define C8NUM 8
 #define C12NUM 12
 #define C16NUM 16
+#define C32NUM 32
 #define TILE_NUM 8
 
 #define MSMIN(x, y) ((x) < (y) ? (x) : (y))

@@ -79,6 +79,10 @@ class MS_API Allocator {
   ///
   /// \return Pointer of ready memory.
   virtual void *Prepare(void *ptr) { return ptr; }
+
+ protected:
+  // memory aligned bytes
+  size_t aligned_size_ = 32;
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_LITE_INCLUDE_ALLOCATOR_H_

@@ -23,7 +23,7 @@ std::shared_ptr<Allocator> Allocator::Create() {
   return std::shared_ptr<Allocator>(new (std::nothrow) DefaultAllocator());
 }
 
-DefaultAllocator::DefaultAllocator() = default;
+DefaultAllocator::DefaultAllocator(size_t aligned_size) { aligned_size_ = aligned_size; }
 
 DefaultAllocator::~DefaultAllocator() { Clear(); }
 
@@ -69,7 +69,7 @@ void *DefaultAllocator::Malloc(size_t size) {
     return membuf->buf;
   }
 
-  std::unique_ptr<MemBuf> membuf(reinterpret_cast<MemBuf *>(malloc(sizeof(MemBuf) + size)));
+  std::unique_ptr<MemBuf> membuf(reinterpret_cast<MemBuf *>(malloc(sizeof(MemBuf) + size + aligned_size_)));
   if (membuf == nullptr) {
     MS_LOG(ERROR) << "malloc membuf return nullptr";
     UnLock();
@@ -78,7 +78,10 @@ void *DefaultAllocator::Malloc(size_t size) {
   this->total_size_ += size;
   membuf->ref_count_ = 0;
   membuf->size = size;
-  membuf->buf = reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf);
+  auto aligned_bytes =
+    reinterpret_cast<size_t>((reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf))) % aligned_size_;
+  aligned_bytes = aligned_bytes == 0 ? 0 : aligned_size_ - aligned_bytes;
+  membuf->buf = reinterpret_cast<char *>(membuf.get()) + sizeof(MemBuf) + aligned_bytes;
   auto bufPtr = membuf->buf;
   allocatedList_[bufPtr] = membuf.release();
   UnLock();

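[Editor's note] The Malloc change above over-allocates by aligned_size_ bytes and then bumps buf forward to the next multiple of the alignment; the MemBuf header pointer itself is unchanged, so the eventual free still sees the original allocation. The same arithmetic in isolation (hypothetical helper, not in the commit):

    #include <stddef.h>
    #include <stdint.h>

    static char *AlignUp(char *p, size_t alignment) {
      size_t rem = (uintptr_t)p % alignment;
      /* over-allocating by `alignment` bytes keeps the bumped pointer in bounds */
      return rem == 0 ? p : p + (alignment - rem);
    }
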
@@ -35,7 +35,7 @@ struct AllocatorContext {
 
 class DefaultAllocator : public Allocator {
  public:
-  DefaultAllocator();
+  explicit DefaultAllocator(size_t aligned_size = 32);
   ~DefaultAllocator() override;
   void SetContext(const AllocatorContext &ctx);
   void *Malloc(size_t size) override;

@@ -27,13 +27,13 @@ namespace mindspore::kernel {
 int FullconnectionCPUKernel::Init() {
   MatmulFp32BaseCPUKernel::InitParameter();
 
-  if (params_->a_const_ == true) {
+  if (params_->a_const_) {
     auto a_shape = in_tensors_.at(0)->shape();
     params_->row_ = a_shape[0];
     params_->deep_ = a_shape[1];
   }
 
-  if (params_->b_const_ == true) {
+  if (params_->b_const_) {
     auto b_shape = in_tensors_.at(1)->shape();
     params_->col_ = b_shape[0];
     params_->deep_ = b_shape[1];

@@ -16,6 +16,9 @@
 
 #include "src/runtime/kernel/arm/fp32/matmul_fp32_base.h"
+#include "nnacl/fp32/matmul_fp32.h"
+#include "nnacl/fp32/pack_fp32.h"
 
+using mindspore::lite::RET_NULL_PTR;
 
 namespace mindspore::kernel {
 int MatmulBaseFloatRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
@@ -32,7 +35,6 @@ MatmulFp32BaseCPUKernel::~MatmulFp32BaseCPUKernel() {
   FreeResizeBufA();
   FreeResizeBufB();
   FreeBiasBuf();
-  return;
 }
 
 void MatmulFp32BaseCPUKernel::InitParameter() {
@@ -43,29 +45,25 @@ void MatmulFp32BaseCPUKernel::InitParameter() {
     params_->a_const_ = false;
     params_->b_const_ = false;
   }
-#ifdef ENABLE_AVX
-  row_tile_ = C6NUM;
-  col_tile_ = C16NUM;
-#elif defined(ENABLE_ARM32)
-  row_tile_ = C12NUM;
-  col_tile_ = C4NUM;
-#elif defined(ENABLE_SSE)
-  row_tile_ = C4NUM;
-  col_tile_ = C8NUM;
-#else
-  row_tile_ = C12NUM;
-  col_tile_ = C8NUM;
-#endif
-  return;
 }
 
 void MatmulFp32BaseCPUKernel::ResizeParameter() {
   if (params_->row_ == 1) {
     vec_matmul_ = true;
+#ifdef ENABLE_AVX
+    // vector matmul col is aligned to C8NUM in avx
+    col_tile_ = C8NUM;
+#endif
+    row_tile_ = 1;
   }
-  params_->row_align_ = vec_matmul_ ? 1 : UP_ROUND(params_->row_, row_tile_);
+  params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
+#ifdef ENABLE_AVX
+  // avx is aligned to col_tile_
+  params_->col_align_ = UP_ROUND(params_->col_, col_tile_);
+#else
   params_->col_align_ = vec_matmul_ ? params_->col_ : UP_ROUND(params_->col_, col_tile_);
-  return;
+#endif
+  oc_res_ = params_->col_ % col_tile_;
 }
 
 int MatmulFp32BaseCPUKernel::InitBufferA() {
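
[Editor's note] With the AVX tiles chosen in Init() further down (row_tile_ = C6NUM, col_tile_ = C16NUM), the alignment bookkeeping above works out as in this small example (assumed shapes, not from the commit):

    #include <stdio.h>

    #define UP_DIV(x, y) (((x) + (y) - 1) / (y))
    #define UP_ROUND(x, y) (UP_DIV((x), (y)) * (y))

    int main(void) {
      int row = 100, col = 37, row_tile = 6, col_tile = 16;
      printf("row_align = %d\n", UP_ROUND(row, row_tile)); /* 102 */
      printf("col_align = %d\n", UP_ROUND(col, col_tile)); /* 48 */
      printf("oc_res    = %d\n", col % col_tile);          /* 5  */
      /* row == 1 instead switches to the vector path, where col_tile_ drops to C8NUM. */
      return 0;
    }
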
@@ -102,7 +100,7 @@ int MatmulFp32BaseCPUKernel::InitBufferB() {
 
 int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
   lite::Tensor *bias_tensor = in_tensors_.at(2);
-  int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
+  int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), col_tile_);
   if (!params_->b_const_) {
     MS_LOG(WARNING) << "matmul do not support broadcast bias data";
   } else {
@@ -112,9 +110,9 @@ int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
       return max_bias_data;
     }
     if (params_->b_transpose_) {
-      max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kBiasIndex], C16NUM);
+      max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kBiasIndex], col_tile_);
     } else {
-      max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kWeightIndex], C16NUM);
+      max_bias_data = UP_ROUND(const_tensor->shape()[shape_size - kWeightIndex], col_tile_);
     }
   }
   return max_bias_data;
@@ -123,26 +121,22 @@ int MatmulFp32BaseCPUKernel::CalBroadCastBiasDataElements() {
 int MatmulFp32BaseCPUKernel::InitBiasData() {
   if (in_tensors_.size() == 3) {
     auto bias_tensor = in_tensors_[2];
-    int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), C16NUM);
+    int max_bias_data = UP_ROUND(bias_tensor->ElementsNum(), col_tile_);
+    // malloc addr needs to be aligned to 32 bytes
+    bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
+    if (bias_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "malloc bias_ptr_ failed";
+      return RET_ERROR;
+    }
     // whether to broadcast bias data
     if (bias_tensor->ElementsNum() == 1) {
       max_bias_data = CalBroadCastBiasDataElements();
-      bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
-      if (bias_ptr_ == nullptr) {
-        MS_LOG(ERROR) << "malloc bias_ptr_ failed";
-        return RET_ERROR;
-      }
       float broadcast_data = (reinterpret_cast<float *>(bias_tensor->data_c()))[0];
       // broadcast bias data
       for (int i = 0; i < max_bias_data; ++i) {
         bias_ptr_[i] = broadcast_data;
       }
     } else {
-      bias_ptr_ = reinterpret_cast<float *>(malloc(max_bias_data * sizeof(float)));
-      if (bias_ptr_ == nullptr) {
-        MS_LOG(ERROR) << "malloc bias_ptr_ failed";
-        return RET_ERROR;
-      }
       memset(bias_ptr_, 0, max_bias_data * sizeof(float));
       memcpy(bias_ptr_, bias_tensor->data_c(), bias_tensor->ElementsNum() * sizeof(float));
     }
@@ -159,38 +153,32 @@ int MatmulFp32BaseCPUKernel::InitMatrixA(const float *src_ptr) {
   for (int i = 0; i < params_->batch; i++) {
     const float *src = src_ptr + i * params_->deep_ * params_->row_;
     float *dst = a_pack_ptr_ + i * params_->deep_ * params_->row_align_;
-#ifdef ENABLE_AVX
     if (params_->a_transpose_) {
-      RowMajor2Row6Major(src, dst, params_->deep_, params_->row_);
+      matrix_a_pack_fun_(src, dst, params_->deep_, params_->row_);
     } else {
-      RowMajor2Col6Major(src, dst, params_->row_, params_->deep_);
+      matrix_a_pack_fun_(src, dst, params_->row_, params_->deep_);
     }
-#elif defined(ENABLE_SSE)
-    if (params_->a_transpose_) {
-      RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
-    } else {
-      RowMajor2Col4Major(src, dst, params_->row_, params_->deep_);
-    }
-#else
-    if (params_->a_transpose_) {
-      RowMajor2Row12Major(src, dst, params_->deep_, params_->row_);
-    } else {
-      RowMajor2Col12Major(src, dst, params_->row_, params_->deep_);
-    }
-#endif
   }
   return RET_OK;
 }
 
 int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
   if (vec_matmul_) {
-    if (params_->b_transpose_) {
-      memcpy(b_pack_ptr_, src_ptr, params_->batch * params_->col_ * params_->deep_ * sizeof(float));
-    } else {
-      for (int i = 0; i < params_->batch; i++) {
-        const float *src_data = src_ptr + i * params_->deep_ * params_->col_;
-        float *dst_data = b_pack_ptr_ + i * params_->deep_ * params_->col_;
-        RowMajor2ColMajor(src_data, dst_data, params_->deep_, params_->col_);
+    for (int i = 0; i < params_->batch; i++) {
+      const float *src_data = src_ptr + i * params_->deep_ * params_->col_;
+      float *dst = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
+      if (params_->b_transpose_) {
+#ifdef ENABLE_AVX
+        RowMajor2Col32Major(src_data, dst, params_->deep_, params_->col_);
+#else
+        memcpy(dst, src_data, params_->col_ * params_->deep_ * sizeof(float));
+#endif
+      } else {
+#ifdef ENABLE_AVX
+        RowMajor2Row32Major(src_data, dst, params_->col_, params_->deep_);
+#else
+        RowMajor2ColMajor(src_data, dst, params_->deep_, params_->col_);
+#endif
+      }
     }
     return RET_OK;
@@ -199,25 +187,11 @@ int MatmulFp32BaseCPUKernel::InitMatrixB(const float *src_ptr) {
   for (int i = 0; i < params_->batch; i++) {
     const float *src = src_ptr + i * params_->deep_ * params_->col_;
     float *dst = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
-#ifdef ENABLE_AVX
     if (params_->b_transpose_) {
-      RowMajor2Col16Major(src, dst, params_->col_, params_->deep_);
+      matrix_b_pack_fun_(src, dst, params_->col_, params_->deep_);
     } else {
-      RowMajor2Row16Major(src, dst, params_->deep_, params_->col_);
+      matrix_b_pack_fun_(src, dst, params_->deep_, params_->col_);
     }
-#elif defined(ENABLE_ARM32)
-    if (params_->b_transpose_) {
-      RowMajor2Col4Major(src, dst, params_->col_, params_->deep_);
-    } else {
-      RowMajor2Row4Major(src, dst, params_->deep_, params_->col_);
-    }
-#else
-    if (params_->b_transpose_) {
-      RowMajor2Col8Major(src, dst, params_->col_, params_->deep_);
-    } else {
-      RowMajor2Row8Major(src, dst, params_->deep_, params_->col_);
-    }
-#endif
   }
   return RET_OK;
 }
@@ -227,7 +201,6 @@ void MatmulFp32BaseCPUKernel::FreeBiasBuf() {
     free(bias_ptr_);
     bias_ptr_ = nullptr;
   }
-  return;
 }
 
 void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
@@ -239,7 +212,6 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufA() {
   } else {
     a_pack_ptr_ = nullptr;
   }
-  return;
 }
 
 void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
@@ -251,22 +223,34 @@ void MatmulFp32BaseCPUKernel::FreeResizeBufB() {
   } else {
     b_pack_ptr_ = nullptr;
   }
-  return;
 }
 
 int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
-  int current_stride_oc = thread_stride_ * col_tile_;
-  int current_rest_oc = params_->col_ - task_id * thread_stride_ * col_tile_;
-  int cur_oc = MSMIN(current_stride_oc, current_rest_oc);
+  int current_start_oc = task_id * thread_stride_ * col_tile_;
+  int current_rest_oc = 0;
+#if defined(ENABLE_AVX)
+  if (vec_matmul_) {
+    current_rest_oc = params_->col_align_ - current_start_oc;
+  } else {
+    current_rest_oc = params_->col_ - current_start_oc;
+  }
+#else
+  current_rest_oc = params_->col_ - current_start_oc;
+#endif
+  int cur_oc = MSMIN(thread_stride_ * col_tile_, current_rest_oc);
   if (cur_oc <= 0) {
     return RET_OK;
   }
 
-  auto b = batch_b_ptr_ + task_id * thread_stride_ * col_tile_ * params_->deep_;
-  auto c = batch_c_ptr_ + task_id * thread_stride_ * col_tile_;
-  auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + task_id * thread_stride_ * col_tile_;
+  auto b = batch_b_ptr_ + current_start_oc * params_->deep_;
+  auto c = batch_c_ptr_ + current_start_oc;
+  auto bias = (bias_ptr_ == nullptr) ? nullptr : bias_ptr_ + current_start_oc;
   if (vec_matmul_) {
+#ifdef ENABLE_AVX
+    MatVecMulAvxFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc, params_->col_align_);
+#else
     MatVecMulFp32(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, cur_oc);
+#endif
   } else {
     MatMulOpt(batch_a_ptr_, b, c, bias, params_->act_type_, params_->deep_, params_->row_, cur_oc, params_->col_,
               OutType_Nhwc);
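
[Editor's note] FloatRun above splits output channels across tasks in strides of thread_stride_ * col_tile_, with ReSize computing thread_stride_ = UP_DIV(UP_DIV(col_align_, col_tile_), thread_count_). A worked example with assumed sizes (not from the commit):

    #include <stdio.h>

    #define UP_DIV(x, y) (((x) + (y) - 1) / (y))

    int main(void) {
      int col_align = 96, col_tile = 8, thread_count = 4;
      int thread_stride = UP_DIV(UP_DIV(col_align, col_tile), thread_count); /* UP_DIV(12, 4) = 3 */
      for (int task_id = 0; task_id < thread_count; ++task_id) {
        int start_oc = task_id * thread_stride * col_tile; /* 0, 24, 48, 72 */
        int rest_oc = col_align - start_oc;
        int cur_oc = rest_oc < thread_stride * col_tile ? rest_oc : thread_stride * col_tile;
        printf("task %d: channels [%d, %d)\n", task_id, start_oc, start_oc + cur_oc);
      }
      return 0;
    }
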
@@ -275,81 +259,141 @@ int MatmulFp32BaseCPUKernel::FloatRun(int task_id) {
 }
 
 int MatmulFp32BaseCPUKernel::Init() {
-  ResizeParameter();
+#ifdef ENABLE_AVX
+  matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row6Major : RowMajor2Col6Major;
+  matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col16Major : RowMajor2Row16Major;
+  row_tile_ = C6NUM;
+  col_tile_ = C16NUM;
+#elif defined(ENABLE_ARM32)
+  matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row12Major : RowMajor2Col12Major;
+  matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col4Major : RowMajor2Row4Major;
+  row_tile_ = C12NUM;
+  col_tile_ = C4NUM;
+#elif defined(ENABLE_SSE)
+  matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row4Major : RowMajor2Col4Major;
+  matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8Major : RowMajor2Row8Major;
+  row_tile_ = C4NUM;
+  col_tile_ = C8NUM;
+#else
+  matrix_a_pack_fun_ = params_->a_transpose_ ? RowMajor2Row12Major : RowMajor2Col12Major;
+  matrix_b_pack_fun_ = params_->b_transpose_ ? RowMajor2Col8Major : RowMajor2Row8Major;
+  row_tile_ = C12NUM;
+  col_tile_ = C8NUM;
+#endif
+  params_->row_align_ = UP_ROUND(params_->row_, row_tile_);
   matrix_a_pack_size_ = params_->batch * params_->row_align_ * params_->deep_;
-  matrix_b_pack_size_ = params_->batch * params_->col_align_ * params_->deep_;
-  if ((matrix_a_pack_size_ + matrix_b_pack_size_) < 0) {
+  if (matrix_a_pack_size_ < 0) {
     MS_LOG(ERROR) << "Matrix pack size is negative "
-                  << "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size" << matrix_b_pack_size_;
+                  << "matrix_a_pack_size=" << matrix_a_pack_size_;
     return RET_ERROR;
   }
 
   auto ret = InitBiasData();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "InitBiasData failed";
     return ret;
   }
 
-  if (params_->a_const_ == true) {
+  if (params_->a_const_) {
     if (RET_OK != InitBufferA()) {
       return RET_ERROR;
     }
-    InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()));
+    ret = InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->data_c()));
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "InitMatrixA failed!";
+      return ret;
+    }
   }
 
-  if (params_->b_const_ == true) {
-    /* copy origin b data, pack in resize
-     * pack after a infershape done */
+  if (params_->b_const_) {
+    // only copy weight data
+    // resize or run to pack
     auto b_tensor = in_tensors_[1];
-    src_b_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_ * params_->deep_ * sizeof(float)));
+    src_b_ = reinterpret_cast<float *>(malloc(params_->batch * params_->deep_ * params_->col_ * sizeof(float)));
     if (src_b_ == nullptr) {
-      MS_LOG(ERROR) << "Matmul fp16 malloc src_b_ failed";
+      MS_LOG(ERROR) << "matmul fp16 src_b_ is failed!";
      return RET_ERROR;
    }
-    memcpy(src_b_, b_tensor->data_c(), params_->batch * params_->col_ * params_->deep_ * sizeof(float));
+    memcpy(src_b_, b_tensor->data_c(), params_->batch * params_->deep_ * params_->col_ * sizeof(float));
   }
   return RET_OK;
 }
 
+void MatmulFp32BaseCPUKernel::FreeBuffSrcB() {
+  if (src_b_ != nullptr) {
+    free(src_b_);
+    src_b_ = nullptr;
+  }
+}
+
 int MatmulFp32BaseCPUKernel::ReSize() {
   ResizeParameter();
   matrix_a_pack_size_ = params_->batch * params_->row_align_ * params_->deep_;
   matrix_b_pack_size_ = params_->batch * params_->col_align_ * params_->deep_;
-  if ((matrix_a_pack_size_ + matrix_b_pack_size_) < 0) {
+  if (matrix_a_pack_size_ < 0 || matrix_b_pack_size_ < 0) {
     MS_LOG(ERROR) << "Matrix pack size is negative "
-                  << "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size" << matrix_b_pack_size_;
+                  << "matrix_a_pack_size=" << matrix_a_pack_size_ << "matrix_b_pack_size=" << matrix_b_pack_size_;
     return RET_ERROR;
   }
   if (op_parameter_->is_train_session_) {
     set_workspace_size((matrix_a_pack_size_ + matrix_b_pack_size_) * sizeof(float));
   }
 
-  if (params_->b_const_ == true && src_b_ != nullptr) {
-    if (RET_OK != InitBufferB()) {
+  if (params_->b_const_ && src_b_ != nullptr) {
+    if (InitBufferB() != RET_OK) {
+      FreeBuffSrcB();
       return RET_ERROR;
     }
-    InitMatrixB(src_b_);
-    free(src_b_);
-    src_b_ = nullptr;
+    if (InitMatrixB(src_b_) != RET_OK) {
+      FreeBuffSrcB();
+      MS_LOG(ERROR) << "InitMatrixB failed!";
+      return RET_ERROR;
+    }
+    FreeBuffSrcB();
   }
 
   thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_));
+#if defined(ENABLE_AVX)
+  if (vec_matmul_) {
+    thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_ * C4NUM), thread_count_) * C4NUM;
+  } else {
+    thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
+  }
+#else
   thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_);
+#endif
   return RET_OK;
 }
 
+int MatmulFp32BaseCPUKernel::InitTmpOutBuffer() {
+  auto out_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
+  MS_ASSERT(out_data != nullptr);
+#ifdef ENABLE_AVX
+  if (oc_res_ != 0 && vec_matmul_) {  // vec matmul need to malloc dst
+    int out_channel = params_->col_;
+    int oc_block_num = UP_DIV(out_channel, col_tile_);
+    MS_ASSERT(context_->allocator != nullptr);
+    output_data_ = reinterpret_cast<float *>(
+      context_->allocator->Malloc(params_->batch * params_->row_ * oc_block_num * col_tile_ * sizeof(float)));
+    if (output_data_ == nullptr) {
+      MS_LOG(ERROR) << "malloc tmp output data failed.";
+      return RET_NULL_PTR;
+    }
+  } else {  // need to malloc dst to align block
+    output_data_ = out_data;
+  }
+#else
+  output_data_ = out_data;
+#endif
+  return RET_OK;
+}
+
 int MatmulFp32BaseCPUKernel::Run() {
-  auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
-  auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
-  auto c_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
-
-  if (params_->a_const_ == false) {
+  if (!params_->a_const_) {
+    auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
     if (RET_OK != InitBufferA()) {
       return RET_ERROR;
     }
     InitMatrixA(a_ptr);
   }
-  if (params_->b_const_ == false) {
+  if (!params_->b_const_) {
+    auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
     if (RET_OK != InitBufferB()) {
       FreeResizeBufA();
       return RET_ERROR;
@@ -357,31 +401,45 @@ int MatmulFp32BaseCPUKernel::Run() {
     InitMatrixB(b_ptr);
   }
 
+  auto ret = InitTmpOutBuffer();
+  if (ret != RET_OK) {
+    FreeResizeBufA();
+    FreeResizeBufB();
+    MS_LOG(ERROR) << "InitTmpOutBuffer error!";
+    return ret;
+  }
+
   for (int i = 0; i < params_->batch; ++i) {
-    batch_a_ptr_ = a_pack_ptr_ + i * params_->row_align_ * params_->deep_;
-    batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
-    batch_c_ptr_ = c_ptr + i * params_->row_ * params_->col_;
+    if (vec_matmul_) {
+      batch_a_ptr_ = a_pack_ptr_ + i * params_->deep_;
+      batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_;
+      batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_align_;
+    } else {
+      batch_a_ptr_ = a_pack_ptr_ + i * params_->row_align_ * params_->deep_;
+      batch_b_ptr_ = b_pack_ptr_ + i * params_->deep_ * params_->col_align_;
+      // need not aligned
+      batch_c_ptr_ = output_data_ + i * params_->row_ * params_->col_;
+    }
-    auto ret = static_cast<const lite::InnerContext *>(this->context_)
-                 ->thread_pool_->ParallelLaunch(MatmulBaseFloatRun, this, thread_count_);
+    ret = static_cast<const lite::InnerContext *>(this->context_)
+            ->thread_pool_->ParallelLaunch(MatmulBaseFloatRun, this, thread_count_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "MatmulBaseFloatRun failed";
       return ret;
     }
   }
 
+#ifdef ENABLE_AVX
+  if (oc_res_ != 0 && vec_matmul_) {
+    auto out_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
+    PackNHWCXToNHWCFp32(output_data_, out_data, params_->batch, params_->row_, params_->col_, col_tile_);
+    context_->allocator->Free(output_data_);
+    output_data_ = nullptr;
+  }
+#endif
-  if (params_->a_const_ == false) {
+  if (!params_->a_const_) {
     FreeResizeBufA();
   }
 
-  if (params_->b_const_ == false) {
+  if (!params_->b_const_) {
     FreeResizeBufB();
   }
-  return RET_OK;
+  return ret;
 }
 }  // namespace mindspore::kernel

@@ -27,6 +27,7 @@ using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
 
 namespace mindspore::kernel {
+using MatrixPackFun = void (*)(const float *src_ptr, float *dst_ptr, int row, int col);
 class MatmulFp32BaseCPUKernel : public InnerKernel {
  public:
  MatmulFp32BaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -35,7 +36,7 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
     params_ = reinterpret_cast<MatMulParameter *>(op_parameter_);
     vec_matmul_ = false;
   }
-  ~MatmulFp32BaseCPUKernel();
+  ~MatmulFp32BaseCPUKernel() override;
   int Init() override;
   int ReSize() override;
   int Run() override;
@@ -56,7 +57,9 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
   void ResizeParameter();
   void FreeResizeBufA();
   void FreeResizeBufB();
+  void FreeBuffSrcB();
   int CalBroadCastBiasDataElements();
+  int InitTmpOutBuffer();
 
  protected:
   MatMulParameter *params_ = nullptr;
@@ -66,16 +69,20 @@ class MatmulFp32BaseCPUKernel : public InnerKernel {
  private:
   int col_tile_ = 0;
   int row_tile_ = 0;
+  int oc_res_ = 0;
   int thread_stride_ = 0;
   int thread_count_ = 0;
   bool vec_matmul_ = false;
-  float *src_b_ = nullptr;
   float *bias_ptr_ = nullptr;
   float *batch_a_ptr_ = nullptr;
   float *batch_b_ptr_ = nullptr;
   float *batch_c_ptr_ = nullptr;
+  float *output_data_ = nullptr;
   int matrix_a_pack_size_ = -1;
   int matrix_b_pack_size_ = -1;
+  float *src_b_ = nullptr;
+  MatrixPackFun matrix_a_pack_fun_ = nullptr;
+  MatrixPackFun matrix_b_pack_fun_ = nullptr;
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_MATMUL_FP32_BASE_H_