forked from mindspore-Ecosystem/mindspore
!23292 [MS][LITE][develop] optimize x86 init
Merge pull request !23292 from sunsuodong/optimize_x86_init
This commit is contained in:
commit
0589d797e5
|
@ -406,21 +406,79 @@ void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, const i
|
|||
}
|
||||
|
||||
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) {
|
||||
int channel_up8 = UP_ROUND(channel, C8NUM);
|
||||
for (int n = 0; n < batch; n++) {
|
||||
for (int hw = 0; hw < plane; hw++) {
|
||||
for (int c = 0; c < channel; c++) {
|
||||
int c = 0;
|
||||
for (; c < channel; c++) {
|
||||
int c8div = c / C8NUM;
|
||||
int c8mod = c % C8NUM;
|
||||
int src_index = n * plane * channel + hw * channel + c;
|
||||
int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
|
||||
((float *)dst)[dst_index] = ((float *)src)[src_index];
|
||||
}
|
||||
for (; c < channel_up8; c++) {
|
||||
int c8div = c / C8NUM;
|
||||
int c8mod = c % C8NUM;
|
||||
int dst_index = c8div * batch * plane * C8NUM + hw * batch * C8NUM + n * C8NUM + c8mod;
|
||||
((float *)dst)[dst_index] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel) {
|
||||
// Pack weight from NHWC to C24HWN24 when at least 24 output channels remain; otherwise fall back to C16HWN16, and finally to C8HWN8.
|
||||
#ifdef ENABLE_AVX
|
||||
int oc_block_num = UP_DIV(channel, C8NUM);
|
||||
int plane16 = plane / C16NUM * C16NUM;
|
||||
for (int i = 0, oc_block = 0; i < oc_block_num; i += oc_block) {
|
||||
oc_block = MSMIN(C3NUM, oc_block_num - i);
|
||||
int oc_remainder = MSMIN(C8NUM * oc_block, channel - i * C8NUM);
|
||||
int oc_remainder_c8 = oc_remainder / C8NUM * C8NUM;
|
||||
int p = 0;
|
||||
for (; p < plane16; p += C16NUM) {
|
||||
int index_plane = i * C8NUM + p * channel;
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
int index_batch = index_plane + b * plane * channel;
|
||||
int oc = 0;
|
||||
int stride = oc_block * C8NUM * batch;
|
||||
for (; oc < oc_remainder_c8; oc += C8NUM) {
|
||||
const float *cur_src = src + index_batch + oc;
|
||||
float *cur_dst = dst + oc;
|
||||
LOAD256X16_F32(r, cur_src, channel);
|
||||
STORE256X16_F32(cur_dst, stride, r);
|
||||
}
|
||||
for (; oc < oc_remainder; ++oc) {
|
||||
for (int k = 0; k < C16NUM; ++k) {
|
||||
dst[oc + stride * k] = src[index_batch + oc + channel * k];
|
||||
}
|
||||
}
|
||||
for (; oc < C8NUM; ++oc) {
|
||||
for (int k = 0; k < C16NUM; ++k) {
|
||||
dst[oc + stride * k] = 0;
|
||||
}
|
||||
}
|
||||
dst += oc_block * C8NUM;
|
||||
}
|
||||
dst += (C16NUM - 1) * oc_block * C8NUM * batch;
|
||||
}
|
||||
for (; p < plane; ++p) {
|
||||
int index_plane = i * C8NUM + p * channel;
|
||||
for (int b = 0; b < batch; ++b) {
|
||||
int index_batch = index_plane + b * plane * channel;
|
||||
int oc = 0;
|
||||
for (; oc < oc_remainder; ++oc) {
|
||||
dst[oc] = src[index_batch + oc];
|
||||
}
|
||||
for (; oc < C8NUM; ++oc) {
|
||||
dst[oc] = 0;
|
||||
}
|
||||
dst += oc_block * C8NUM;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
int oc_block = 0;
|
||||
int oc_block_num = UP_DIV(channel, C8NUM);
|
||||
for (int i = 0; i < oc_block_num; i += oc_block) {
|
||||
|
@ -437,6 +495,7 @@ void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, in
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel) {
|
||||
|
|
|
@ -235,6 +235,24 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
|
|||
MS_FLOAT32X8 src##7 = MS_LD256_F32(input_ptr + 6 * num); \
|
||||
MS_FLOAT32X8 src##8 = MS_LD256_F32(input_ptr + 7 * num);
|
||||
|
||||
// Declare sixteen MS_FLOAT32X8 variables named <src>1 .. <src>16 and
// initialise variable k with MS_LD256_F32((input_ptr) + (k - 1) * (num)).
//
// src       : identifier prefix for the declared variables (token-pasted).
// input_ptr : base pointer expression passed to MS_LD256_F32.
// num       : element stride between two consecutive loads.
//
// input_ptr and num are parenthesized so that compound arguments such as
// `base + off` or `a + b` keep their intended meaning after expansion.
// The macro expands to declarations, so it cannot be wrapped in
// do { } while (0); use it only where a declaration is allowed.
#define LOAD256X16_F32(src, input_ptr, num)                       \
  MS_FLOAT32X8 src##1 = MS_LD256_F32((input_ptr) + 0 * (num));    \
  MS_FLOAT32X8 src##2 = MS_LD256_F32((input_ptr) + 1 * (num));    \
  MS_FLOAT32X8 src##3 = MS_LD256_F32((input_ptr) + 2 * (num));    \
  MS_FLOAT32X8 src##4 = MS_LD256_F32((input_ptr) + 3 * (num));    \
  MS_FLOAT32X8 src##5 = MS_LD256_F32((input_ptr) + 4 * (num));    \
  MS_FLOAT32X8 src##6 = MS_LD256_F32((input_ptr) + 5 * (num));    \
  MS_FLOAT32X8 src##7 = MS_LD256_F32((input_ptr) + 6 * (num));    \
  MS_FLOAT32X8 src##8 = MS_LD256_F32((input_ptr) + 7 * (num));    \
  MS_FLOAT32X8 src##9 = MS_LD256_F32((input_ptr) + 8 * (num));    \
  MS_FLOAT32X8 src##10 = MS_LD256_F32((input_ptr) + 9 * (num));   \
  MS_FLOAT32X8 src##11 = MS_LD256_F32((input_ptr) + 10 * (num));  \
  MS_FLOAT32X8 src##12 = MS_LD256_F32((input_ptr) + 11 * (num));  \
  MS_FLOAT32X8 src##13 = MS_LD256_F32((input_ptr) + 12 * (num));  \
  MS_FLOAT32X8 src##14 = MS_LD256_F32((input_ptr) + 13 * (num));  \
  MS_FLOAT32X8 src##15 = MS_LD256_F32((input_ptr) + 14 * (num));  \
  MS_FLOAT32X8 src##16 = MS_LD256_F32((input_ptr) + 15 * (num));
|
||||
|
||||
#define STORE256X8_F32(output_ptr, num, dst) \
|
||||
MS_ST256_F32(output_ptr + 0 * num, dst##1); \
|
||||
MS_ST256_F32(output_ptr + 1 * num, dst##2); \
|
||||
|
@ -245,6 +263,24 @@ static inline MS_FLOAT32X8 MS_SQRTFX8_F32(MS_FLOAT32X8 src) {
|
|||
MS_ST256_F32(output_ptr + 6 * num, dst##7); \
|
||||
MS_ST256_F32(output_ptr + 7 * num, dst##8);
|
||||
|
||||
// Store sixteen values named <dst>1 .. <dst>16 through MS_ST256_F32, writing
// value k to (output_ptr) + (k - 1) * (num).
//
// output_ptr : base pointer expression passed to MS_ST256_F32.
// num        : element stride between two consecutive stores.
// dst        : identifier prefix of the variables to store (token-pasted).
//
// output_ptr and num are parenthesized so that compound arguments such as
// `base + off` or `a + b` keep their intended meaning after expansion.
#define STORE256X16_F32(output_ptr, num, dst)           \
  MS_ST256_F32((output_ptr) + 0 * (num), dst##1);       \
  MS_ST256_F32((output_ptr) + 1 * (num), dst##2);       \
  MS_ST256_F32((output_ptr) + 2 * (num), dst##3);       \
  MS_ST256_F32((output_ptr) + 3 * (num), dst##4);       \
  MS_ST256_F32((output_ptr) + 4 * (num), dst##5);       \
  MS_ST256_F32((output_ptr) + 5 * (num), dst##6);       \
  MS_ST256_F32((output_ptr) + 6 * (num), dst##7);       \
  MS_ST256_F32((output_ptr) + 7 * (num), dst##8);       \
  MS_ST256_F32((output_ptr) + 8 * (num), dst##9);       \
  MS_ST256_F32((output_ptr) + 9 * (num), dst##10);      \
  MS_ST256_F32((output_ptr) + 10 * (num), dst##11);     \
  MS_ST256_F32((output_ptr) + 11 * (num), dst##12);     \
  MS_ST256_F32((output_ptr) + 12 * (num), dst##13);     \
  MS_ST256_F32((output_ptr) + 13 * (num), dst##14);     \
  MS_ST256_F32((output_ptr) + 14 * (num), dst##15);     \
  MS_ST256_F32((output_ptr) + 15 * (num), dst##16);
|
||||
|
||||
static inline MS_FLOAT32X8 MS_TANHX8_F32(MS_FLOAT32X8 src) {
|
||||
static const MS_FLOAT32X8 data0 = {378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f, 378.0f};
|
||||
static const MS_FLOAT32X8 data1 = {17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f, 17325.0f};
|
||||
|
|
|
@ -67,7 +67,6 @@ int DeConvolutionCPUKernel::MallocWeightBiasData() {
|
|||
MS_LOG(ERROR) << "deconv malloc packed_weight_ error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(packed_weight_, 0, pack_weight_size);
|
||||
}
|
||||
|
||||
bias_data_ = MallocAlignedData(C32NUM, output_aligned_size * sizeof(float));
|
||||
|
|
Loading…
Reference in New Issue