forked from mindspore-Ecosystem/mindspore
!10227 [ms][lite][cpu] win x86 avx optimize
From: @lzkcode Reviewed-by: Signed-off-by:
This commit is contained in:
commit
3f0aeaa8fc
|
@ -1,15 +1,16 @@
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
#ifndef WIN32
|
|
||||||
|
|
||||||
.text
|
.text
|
||||||
.align 4
|
.align 4
|
||||||
.global ConvDwFp32Avx3x3
|
.global ConvDwFp32Avx3x3
|
||||||
#ifndef __APPLE__
|
#ifndef __APPLE__
|
||||||
|
#ifndef WIN32
|
||||||
.type ConvDwFp32Avx3x3, %function
|
.type ConvDwFp32Avx3x3, %function
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels, int output_width,
|
// void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels, size_t output_width,
|
||||||
// size_t input_stride, size_t relu)
|
// size_t input_stride, size_t relum, szie_t relu6)
|
||||||
|
// in linux x64 platfrom:
|
||||||
// rdi: output
|
// rdi: output
|
||||||
// rsi: input
|
// rsi: input
|
||||||
// rdx: weights
|
// rdx: weights
|
||||||
|
@ -20,6 +21,16 @@
|
||||||
// 16: relu
|
// 16: relu
|
||||||
// 24: relu6
|
// 24: relu6
|
||||||
|
|
||||||
|
// in win x64 platfrom: "shadow space" needs to be opened up for first four parameters ==> 32 bites
|
||||||
|
// rcx: output
|
||||||
|
// rdx: input
|
||||||
|
// r8: weights
|
||||||
|
// r9: bias
|
||||||
|
// 40: channels
|
||||||
|
// 48: output_width
|
||||||
|
// 56: input_stride
|
||||||
|
// 64: relu
|
||||||
|
// 72: relu6
|
||||||
ConvDwFp32Avx3x3:
|
ConvDwFp32Avx3x3:
|
||||||
pushq %r15
|
pushq %r15
|
||||||
pushq %r14
|
pushq %r14
|
||||||
|
@ -27,14 +38,34 @@ ConvDwFp32Avx3x3:
|
||||||
pushq %r12
|
pushq %r12
|
||||||
pushq %rbx
|
pushq %rbx
|
||||||
pushq %rbp
|
pushq %rbp
|
||||||
pushq %r9
|
pushq %r9 // -56
|
||||||
pushq %r8
|
pushq %r8 // -64
|
||||||
pushq %rcx
|
pushq %rcx // -72
|
||||||
pushq %rdx
|
pushq %rdx // -80
|
||||||
pushq %rsi
|
pushq %rsi // -88
|
||||||
pushq %rdi
|
pushq %rdi // -96
|
||||||
addq $96, %rsp
|
addq $96, %rsp
|
||||||
|
|
||||||
|
#ifdef WIN32
|
||||||
|
movq %rcx, %rdi
|
||||||
|
movq %rdx, %rsi
|
||||||
|
movq %r8, %rdx
|
||||||
|
movq %r9, %rcx
|
||||||
|
movq 40(%rsp), %r8 // channels
|
||||||
|
movq 48(%rsp), %r9 // output_width
|
||||||
|
|
||||||
|
mov %rdx, -80(%rsp)
|
||||||
|
mov %rcx, -72(%rsp)
|
||||||
|
mov %r9, -56(%rsp)
|
||||||
|
mov %r8, -64(%rsp)
|
||||||
|
movq 56(%rsp), %rbp // input_stride
|
||||||
|
movq %rbp, 8(%rsp)
|
||||||
|
movq 64(%rsp), %rbp // relu
|
||||||
|
movq %rbp, 16(%rsp)
|
||||||
|
movq 72(%rsp), %rbp // relu6
|
||||||
|
movq %rbp, 24(%rsp)
|
||||||
|
#endif
|
||||||
|
|
||||||
movq $6, %rax
|
movq $6, %rax
|
||||||
vcvtsi2ss %rax, %xmm15, %xmm15
|
vcvtsi2ss %rax, %xmm15, %xmm15
|
||||||
vshufps $0, %xmm15, %xmm15, %xmm15
|
vshufps $0, %xmm15, %xmm15, %xmm15
|
||||||
|
@ -270,4 +301,3 @@ End:
|
||||||
popq %r15
|
popq %r15
|
||||||
retq
|
retq
|
||||||
#endif
|
#endif
|
||||||
#endif
|
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
#ifndef WIN32
|
|
||||||
.text
|
.text
|
||||||
.align 4
|
.align 4
|
||||||
.global MatmulFloatAvxOpt
|
.global MatmulFloatAvxOpt
|
||||||
#ifndef __APPLE__
|
#ifndef __APPLE__
|
||||||
|
#ifndef WIN32
|
||||||
.type MatmulFloatAvxOpt, %function
|
.type MatmulFloatAvxOpt, %function
|
||||||
#endif
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
// void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
|
// void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth
|
||||||
// int row, int col, size_t stride, size_t writeMode)
|
// int row, int col, size_t stride, size_t writeMode)
|
||||||
|
// parameters pass in Linux x86 platform:
|
||||||
// rdi: a
|
// rdi: a
|
||||||
// rsi: b
|
// rsi: b
|
||||||
// rdx: c
|
// rdx: c
|
||||||
|
@ -20,6 +22,18 @@
|
||||||
// 24: stride
|
// 24: stride
|
||||||
// 32: writeNhwc/writeWino
|
// 32: writeNhwc/writeWino
|
||||||
|
|
||||||
|
// parameters pass in win x64 platfrom: "shadow space" needs to be opened up for first four parameters ==> 32 bites
|
||||||
|
// rcx: a
|
||||||
|
// rdx: b
|
||||||
|
// r8: c
|
||||||
|
// r9: bias
|
||||||
|
// 40: act_type
|
||||||
|
// 48: depth
|
||||||
|
// 56: row
|
||||||
|
// 64: col
|
||||||
|
// 72: stride
|
||||||
|
// 80: writeMode
|
||||||
|
|
||||||
MatmulFloatAvxOpt:
|
MatmulFloatAvxOpt:
|
||||||
// rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention
|
// rbx, rsp, rbp, r12-r15 must be saved according to x86 calling convention
|
||||||
pushq %r15
|
pushq %r15
|
||||||
|
@ -28,14 +42,37 @@ MatmulFloatAvxOpt:
|
||||||
pushq %r12
|
pushq %r12
|
||||||
pushq %rbx
|
pushq %rbx
|
||||||
pushq %rbp
|
pushq %rbp
|
||||||
pushq %r9
|
pushq %r9 // -56
|
||||||
pushq %r8
|
pushq %r8 // -64
|
||||||
pushq %rcx
|
pushq %rcx // -72
|
||||||
pushq %rdx
|
pushq %rdx // -80
|
||||||
pushq %rsi
|
pushq %rsi // -88
|
||||||
pushq %rdi
|
pushq %rdi // -96
|
||||||
addq $96, %rsp
|
pushq %rsi // -104 rsi
|
||||||
|
pushq %rdi // -112 rdi
|
||||||
|
addq $112, %rsp
|
||||||
|
#ifdef WIN32
|
||||||
|
movq %rcx, %rdi
|
||||||
|
movq %rdx, %rsi
|
||||||
|
movq %r8, %rdx
|
||||||
|
movq %r9, %rcx
|
||||||
|
movq 40(%rsp), %r8 // act_type
|
||||||
|
movq 48(%rsp), %r9 // depth
|
||||||
|
movq %r9, -56(%rsp) // r9
|
||||||
|
movq %rcx, -72(%rsp) // rcx
|
||||||
|
movq %rdx, -80(%rsp) // rdx
|
||||||
|
movq %rsi, -88(%rsp) // rsi
|
||||||
|
movq %rdi, -96(%rsp) // rdi
|
||||||
|
|
||||||
|
movq 56(%rsp), %rbp // row
|
||||||
|
movq %rbp, 8(%rsp)
|
||||||
|
movq 64(%rsp), %rbp // col
|
||||||
|
movq %rbp, 16(%rsp)
|
||||||
|
movq 72(%rsp), %rbp // stride
|
||||||
|
movq %rbp, 24(%rsp)
|
||||||
|
movq 80(%rsp), %rbp // weiteMode
|
||||||
|
movq %rbp, 32(%rsp)
|
||||||
|
#endif
|
||||||
movq 8(%rsp), %rbp
|
movq 8(%rsp), %rbp
|
||||||
movq 16(%rsp), %rbx
|
movq 16(%rsp), %rbx
|
||||||
movq 24(%rsp), %r10
|
movq 24(%rsp), %r10
|
||||||
|
@ -926,10 +963,12 @@ LoopRow:
|
||||||
jmp LoopRow
|
jmp LoopRow
|
||||||
|
|
||||||
LoopRowEnd:
|
LoopRowEnd:
|
||||||
subq $96, %rsp
|
subq $112, %rsp
|
||||||
popq %rdi
|
popq %rdi
|
||||||
popq %rsi
|
popq %rsi
|
||||||
popq %rdx
|
popq %rdx
|
||||||
|
popq %rdx
|
||||||
|
popq %rdx
|
||||||
popq %rcx
|
popq %rcx
|
||||||
popq %r8
|
popq %r8
|
||||||
popq %r9
|
popq %r9
|
||||||
|
@ -941,4 +980,3 @@ LoopRowEnd:
|
||||||
popq %r15
|
popq %r15
|
||||||
retq
|
retq
|
||||||
#endif
|
#endif
|
||||||
#endif
|
|
||||||
|
|
|
@ -681,47 +681,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
#ifdef WIN32
|
|
||||||
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
|
|
||||||
int output_width, int input_stride, bool relu, bool relu6, int kernel) {
|
|
||||||
do {
|
|
||||||
float *in[kernel];
|
|
||||||
for (int k = 0; k < kernel; k++) {
|
|
||||||
in[k] = input[k];
|
|
||||||
}
|
|
||||||
input = input + input_stride;
|
|
||||||
|
|
||||||
size_t c = channels;
|
|
||||||
const float *w = weights;
|
|
||||||
float *out = output;
|
|
||||||
memcpy(out, bias, channels * sizeof(float));
|
|
||||||
for (; c >= C8NUM; c -= C8NUM) {
|
|
||||||
for (int i = 0; i < C8NUM; i++) {
|
|
||||||
for (int k = 0; k < kernel; k++) {
|
|
||||||
out[i] += in[k][i] * w[i + k * C8NUM];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
w += kernel * C8NUM;
|
|
||||||
out += C8NUM;
|
|
||||||
for (int k = 0; k < kernel; k++) {
|
|
||||||
in[k] += C8NUM;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int i = 0; i < c; i++) {
|
|
||||||
for (int k = 0; k < kernel; k++) {
|
|
||||||
out[i] += in[k][i] * w[i + k * C8NUM];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (relu) {
|
|
||||||
ReluFp32C8(output, output, channels);
|
|
||||||
}
|
|
||||||
if (relu6) {
|
|
||||||
Relu6Fp32C8(output, output, channels);
|
|
||||||
}
|
|
||||||
output += channels;
|
|
||||||
} while (--output_width != 0);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
|
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
|
||||||
int output_width, int input_stride, bool relu, bool relu6, int kernel) {
|
int output_width, int input_stride, bool relu, bool relu6, int kernel) {
|
||||||
if (kernel == 9) {
|
if (kernel == 9) {
|
||||||
|
@ -729,7 +688,6 @@ void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#endif
|
|
||||||
|
|
||||||
void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
|
void ConvDwIndirection(float *output_data, float **indirect_buffer, const float *weight_data, const float *bias_data,
|
||||||
float *zero_ptr, const ConvParameter *conv_param, int task_id) {
|
float *zero_ptr, const ConvParameter *conv_param, int task_id) {
|
||||||
|
|
|
@ -67,10 +67,8 @@ void ConvDwFp32Indirect5x5(float *output, float **input, const float *weights, c
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
#ifndef WIN32
|
void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, size_t channels,
|
||||||
void ConvDwFp32Avx3x3(float *output, float **input, const float *weights, const float *bias, int channels,
|
size_t output_width, size_t input_stride, size_t relu, size_t relu6);
|
||||||
int output_width, size_t input_stride, size_t relu, size_t relu6);
|
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
|
void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels,
|
||||||
|
|
|
@ -883,11 +883,7 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
|
||||||
if (out_type == OutType_C8) {
|
if (out_type == OutType_C8) {
|
||||||
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
|
MatmulFloatSse64(a, b, c, bias, (int)act_type, deep, row, col, stride, 0, 0);
|
||||||
} else {
|
} else {
|
||||||
#ifdef WIN32
|
MatmulFloatAvxOpt(a, b, c, bias, (size_t)act_type, deep, row, col, stride, (size_t)(out_type));
|
||||||
MatMul6x16(a, b, c, bias, act_type, deep, row, col, stride, out_type);
|
|
||||||
#else
|
|
||||||
MatmulFloatAvxOpt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type));
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
#elif ENABLE_SSE
|
#elif ENABLE_SSE
|
||||||
if (out_type == OutType_C8) {
|
if (out_type == OutType_C8) {
|
||||||
|
|
|
@ -62,8 +62,8 @@ void MatmulFloatSse64(const float *a, const float *b, float *c, const float *bia
|
||||||
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
void MatmulFloatSse64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
||||||
int col, int stride, int write_mode);
|
int col, int stride, int write_mode);
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
|
void MatmulFloatAvxOpt(const float *a, const float *b, float *c, const float *bias, size_t act_type, size_t depth,
|
||||||
int col, int stride, int write_mode);
|
size_t row, size_t col, size_t stride, size_t write_mode);
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue