!23292 [MS][LITE][develop] optimize x86 init

Merge pull request !23292 from sunsuodong/optimize_x86_init
This commit is contained in:
i-robot 2021-09-15 06:11:53 +00:00 committed by Gitee
commit 0589d797e5
3 changed files with 96 additions and 2 deletions

View File

@ -406,21 +406,79 @@ void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const i
}
// Packs NHWC-layout float data into C8HWN8 layout (channels blocked by 8).
// The channel dimension is zero-padded up to the next multiple of C8NUM so
// every element of the packed destination block is initialized.
//   src: NHWC float buffer, batch x plane x channel elements.
//   dst: C8HWN8 float buffer, UP_ROUND(channel, 8) x plane x batch elements.
// Fix: the diff-garbled original kept both the old loop header
// `for (int c = 0; c < channel; c++) {` and its replacement, leaving a
// duplicated, unclosed loop; this is the merged post-change body.
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
  int channel_up8 = UP_ROUND(channel, C8NUM);
  for (int n = 0; n < batch; n++) {
    for (int hw = 0; hw < plane; hw++) {
      int c = 0;
      // Copy the valid channels into their 8-wide channel block.
      for (; c < channel; c++) {
        int c8div = c / C8NUM;
        int c8mod = c % C8NUM;
        int src_index = n * plane * channel + hw * channel + c;
        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
        ((float *)dst)[dst_index] = ((float *)src)[src_index];
      }
      // Zero-pad the channel tail up to a multiple of C8NUM.
      for (; c < channel_up8; c++) {
        int c8div = c / C8NUM;
        int c8mod = c % C8NUM;
        int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
        ((float *)dst)[dst_index] = 0;
      }
    }
  }
}
// Packs NHWC float weights into a channel-blocked CxHWNx layout, taking the
// widest available channel block per step: 24 -> 16 -> 8 (see comment below).
// The AVX path additionally processes 16 spatial positions per iteration.
void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
// pack weight NHWC to C24HWN24 (Priority 24)=>C16HWN16 (Not satisfied 24)=>C8HWN8 (Not satisfied 16)
#ifdef ENABLE_AVX
// Number of 8-channel groups in `channel`, rounded up.
int oc_block_num = UP_DIV(channel, C8NUM);
// Spatial positions covered by the 16-wide vectorized loop below.
int plane16 = plane / C16NUM * C16NUM;
for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
// Take up to 3 groups of 8 channels (i.e. 24) per outer step.
oc_block = MSMIN(C3NUM, oc_block_num - i);
// Valid channels in this block; positions past them are zero-filled.
int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM;
int p = 0;
// Vectorized path: 16 spatial rows per iteration via 256-bit loads/stores.
for (; p < plane16; p += C16NUM) {
int index_plane = i * C8NUM + p * channel;
for (int b = 0; b < batch; ++b) {
int index_batch = index_plane + b * plane * channel;
int oc = 0;
int stride = oc_block * C8NUM * batch;
// Full 8-channel groups: move 16 rows x 8 channels with AVX.
for (; oc < oc_remainder_c8; oc += C8NUM) {
const float *cur_src = src + index_batch + oc;
float *cur_dst = dst + oc;
LOAD256X16_F32(r, cur_src, channel);
STORE256X16_F32(cur_dst, stride, r);
}
// Scalar tail for the last partial 8-channel group.
for (; oc < oc_remainder; ++oc) {
for (int k = 0; k < C16NUM; ++k) {
dst[oc + stride * k] = src[index_batch + oc + channel * k];
}
}
// Zero-fill padding channels.
// NOTE(review): the bound is C8NUM, not oc_block * C8NUM, so when
// oc_block > 1 and oc_remainder > C8NUM this loop never runs and the
// region [oc_remainder, oc_block * C8NUM) is left untouched — confirm
// the destination buffer is pre-zeroed by its allocator.
for (; oc < C8NUM; ++oc) {
for (int k = 0; k < C16NUM; ++k) {
dst[oc + stride * k] = 0;
}
}
dst += oc_block * C8NUM;
}
// Skip past the 15 extra rows already written by the 16-row stores.
dst += (C16NUM - 1) * oc_block * C8NUM * batch;
}
// Scalar path for the remaining (< 16) spatial rows.
for (; p < plane; ++p) {
int index_plane = i * C8NUM + p * channel;
for (int b = 0; b < batch; ++b) {
int index_batch = index_plane + b * plane * channel;
int oc = 0;
for (; oc < oc_remainder; ++oc) {
dst[oc] = src[index_batch + oc];
}
// Zero-fill padding channels (same C8NUM-bound caveat as above).
for (; oc < C8NUM; ++oc) {
dst[oc] = 0;
}
dst += oc_block * C8NUM;
}
}
}
#else
// Non-AVX fallback: plain 8-channel blocking.
int oc_block = 0;
int oc_block_num = UP_DIV(channel, C8NUM);
for (int i = 0; i < oc_block_num; i += oc_block) {
// NOTE(review): the diff view elides the loop body here; the marker line
// below is renderer residue, and the lines after it are the function tail.
@ -437,6 +495,7 @@ void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, in
}
}
}
#endif
}
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {

View File

@ -235,6 +235,24 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
// Loads 16 MS_FLOAT32X8 (256-bit, 8-float) vectors starting at `input_ptr`,
// one vector every `num` floats, declaring locals src##1 .. src##16.
// NOTE: function-like macro that expands its arguments many times —
// `input_ptr` and `num` must be side-effect free expressions.
#define LOAD256X16_F32(src, input_ptr, num) \
MS_FLOAT32X8 src##1 = MS_LD256_F32(input_ptr + 0 * num); \
MS_FLOAT32X8 src##2 = MS_LD256_F32(input_ptr + 1 * num); \
MS_FLOAT32X8 src##3 = MS_LD256_F32(input_ptr + 2 * num); \
MS_FLOAT32X8 src##4 = MS_LD256_F32(input_ptr + 3 * num); \
MS_FLOAT32X8 src##5 = MS_LD256_F32(input_ptr + 4 * num); \
MS_FLOAT32X8 src##6 = MS_LD256_F32(input_ptr + 5 * num); \
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num); \
MS_FLOAT32X8 src##9 = MS_LD256_F32(input_ptr + 8 * num); \
MS_FLOAT32X8 src##10 = MS_LD256_F32(input_ptr + 9 * num); \
MS_FLOAT32X8 src##11 = MS_LD256_F32(input_ptr + 10 * num); \
MS_FLOAT32X8 src##12 = MS_LD256_F32(input_ptr + 11 * num); \
MS_FLOAT32X8 src##13 = MS_LD256_F32(input_ptr + 12 * num); \
MS_FLOAT32X8 src##14 = MS_LD256_F32(input_ptr + 13 * num); \
MS_FLOAT32X8 src##15 = MS_LD256_F32(input_ptr + 14 * num); \
MS_FLOAT32X8 src##16 = MS_LD256_F32(input_ptr + 15 * num);
#define STORE256X8_F32(output_ptr, num, dst) \
MS_ST256_F32(output_ptr + 0 * num, dst##1); \
MS_ST256_F32(output_ptr + 1 * num, dst##2); \
@ -245,6 +263,24 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
MS_ST256_F32(output_ptr + 6 * num, dst##7); \
MS_ST256_F32(output_ptr + 7 * num, dst##8);
// Stores the 16 MS_FLOAT32X8 vectors dst##1 .. dst##16 to `output_ptr`,
// one vector every `num` floats (the store counterpart of LOAD256X16_F32).
// NOTE: function-like macro that expands its arguments many times —
// `output_ptr` and `num` must be side-effect free expressions.
#define STORE256X16_F32(output_ptr, num, dst) \
MS_ST256_F32(output_ptr + 0 * num, dst##1); \
MS_ST256_F32(output_ptr + 1 * num, dst##2); \
MS_ST256_F32(output_ptr + 2 * num, dst##3); \
MS_ST256_F32(output_ptr + 3 * num, dst##4); \
MS_ST256_F32(output_ptr + 4 * num, dst##5); \
MS_ST256_F32(output_ptr + 5 * num, dst##6); \
MS_ST256_F32(output_ptr + 6 * num, dst##7); \
MS_ST256_F32(output_ptr + 7 * num, dst##8); \
MS_ST256_F32(output_ptr + 8 * num, dst##9); \
MS_ST256_F32(output_ptr + 9 * num, dst##10); \
MS_ST256_F32(output_ptr + 10 * num, dst##11); \
MS_ST256_F32(output_ptr + 11 * num, dst##12); \
MS_ST256_F32(output_ptr + 12 * num, dst##13); \
MS_ST256_F32(output_ptr + 13 * num, dst##14); \
MS_ST256_F32(output_ptr + 14 * num, dst##15); \
MS_ST256_F32(output_ptr + 15 * num, dst##16);
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f};
static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f};

View File

@ -67,7 +67,6 @@ int DeConvolutionCPUKernel::MallocWeightBiasData() {
MS_LOG(ERROR) << "deconv malloc packed_weight_ error!";
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size);
}
bias_data_ = MallocAlignedData(C32NUM, output_aligned_size * sizeof(float));