forked from mindspore-Ecosystem/mindspore
!6702 [MS][LITE] optimize conv preprocess
Merge pull request !6702 from fuzhiye/tmp
This commit is contained in: commit 87859d2ef9
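Note (not part of the diff): the change drops the NHWC4 pre-packing of the convolution input. Previously the kernels copied the input into a channel-padded NHWC4 buffer and addressed it with an ic4 * C4NUM stride; now they read the original NHWC tensor directly and address it with a plain in_channel stride. A minimal standalone C sketch of the two addressing schemes, with made-up shapes (UP_DIV and C4NUM mirror the macros used in the diff):

#include <stdio.h>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

/* Offset of element (h, w, c) in a plain NHWC buffer. */
static int OffsetNHWC(int in_w, int in_channel, int h, int w, int c) {
  return (h * in_w + w) * in_channel + c;
}

/* Offset of the same element in an NHWC4 buffer, whose channel
 * dimension is padded up to a multiple of C4NUM. */
static int OffsetNHWC4(int in_w, int in_channel, int h, int w, int c) {
  int ic4 = UP_DIV(in_channel, C4NUM);
  return (h * in_w + w) * ic4 * C4NUM + c;
}

int main(void) {
  int in_w = 8, in_channel = 3; /* channel count not a multiple of 4 */
  /* The old path had to materialize a padded copy of the whole input just to
   * use the second formula; the new path applies the first one directly. */
  printf("NHWC  offset: %d\n", OffsetNHWC(in_w, in_channel, 2, 5, 1));
  printf("NHWC4 offset: %d\n", OffsetNHWC4(in_w, in_channel, 2, 5, 1));
  return 0;
}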
@@ -221,7 +221,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
  int unit_size = kernel_plane * ic4 * C4NUM;
  int packed_input_size = output_tile_count * TILE_NUM * unit_size;
  bool relu = conv_param->act_type_ == ActType_Relu;
  bool relu6 = conv_param->act_type_ == ActType_Relu6;

@@ -232,13 +231,14 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
  size_t output_offset = out_channel * sizeof(float);

  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
    int in_batch_offset = b * in_channel * in_h * in_w;
    int out_batch_offset = b * out_channel * out_h * out_w;
    int gemm_in_batch_offset = b * packed_input_size;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
      int start_index = thread_id * TILE_NUM;
      int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
      float *gemm_input = packed_input + thread_id * unit_size * TILE_NUM + gemm_in_batch_offset;
      float *gemm_input = packed_input + task_id * unit_size * TILE_NUM;
      size_t packed_input_size = unit_size * TILE_NUM * sizeof(float);
      memset(gemm_input, 0, packed_input_size);
      Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index);

      int out_offset = thread_id * TILE_NUM * out_channel + out_batch_offset;
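Reviewer note (illustration only, not from the patch): in the new ConvFp32 loop each worker packs into a slice of packed_input selected by task_id and clears just that slice, so the packed buffer no longer scales with the number of output tiles or batches. A small standalone C sketch of that per-thread slice arithmetic (names and sizes are invented to mirror the hunk):

#include <stdlib.h>
#include <string.h>

#define TILE_NUM 12

/* Start of the scratch region a given worker reuses for every tile it packs. */
static float *GemmInputSlice(float *packed_input, int task_id, int unit_size) {
  return packed_input + (size_t)task_id * unit_size * TILE_NUM;
}

int main(void) {
  int thread_num = 4;
  int unit_size = 3 * 3 * 16; /* kernel_plane * ic4 * C4NUM, for example */
  float *packed_input = (float *)malloc((size_t)thread_num * unit_size * TILE_NUM * sizeof(float));
  if (packed_input == NULL) return 1;
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    float *gemm_input = GemmInputSlice(packed_input, task_id, unit_size);
    memset(gemm_input, 0, (size_t)unit_size * TILE_NUM * sizeof(float));
    /* Im2ColPackUnitFp32 + the GEMM would run here on gemm_input. */
  }
  free(packed_input);
  return 0;
}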
@@ -291,7 +291,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
  // step 1 : filter transform (pre-processed offline)
  // step 2 : input transform (online)
  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
    int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
    int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
      int out_tile_index = thread_id * tile_num;

@@ -322,144 +322,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
  }
}

void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel,
                          int output_unit) {
  int out_h_block_num = UP_DIV(height, output_unit);
  int out_w_block_num = UP_DIV(width, output_unit);
  int c4 = UP_DIV(channel, C4NUM);
  int c4_block = C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
  for (int b = 0; b < batch; b++) {
    int src_batch_offset = b * c4 * c4_block;
    int dst_batch_offset = b * height * width * channel;
    for (int h = 0; h < height; h++) {
      int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit);
      int dst_h_offset = dst_batch_offset + h * width * channel;
      for (int w = 0; w < width; w++) {
        int src_w_offset = src_h_offset + w * C4NUM;
        int dst_w_offset = dst_h_offset + w * channel;
        for (int c = 0; c < c4 - 1; c++) {
          int src_c4_offset = src_w_offset + c * c4_block;
          int dst_c4_offset = dst_w_offset + c * C4NUM;
#ifdef ENABLE_NEON
          vst1q_f32(dst + dst_c4_offset, vld1q_f32(src + src_c4_offset));
#else
          for (int i = 0; i < C4NUM; ++i) {
            dst[dst_c4_offset + i] = src[src_c4_offset + i];
          }
#endif
        }
        int c_res = channel - (c4 - 1) * C4NUM;
        int src_c_res_offset = (c4 - 1) * c4_block;
        int dst_c_res_offset = (c4 - 1) * C4NUM;
        for (int c = 0; c < c_res; c++) {
          int src_c4_res_offset = src_w_offset + src_c_res_offset + c;
          int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c;
          dst[dst_c4_res_offset] = src[src_c4_res_offset];
        }
      }
    }
  }
}

void UnPackWinogradReluOutput(const float *src, float *dst, int batch, int height, int width, int channel,
                              int output_unit) {
  int out_h_block_num = UP_DIV(height, output_unit);
  int out_w_block_num = UP_DIV(width, output_unit);
  int c4 = UP_DIV(channel, C4NUM);
  int c4_block = C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
  for (int b = 0; b < batch; b++) {
    int src_batch_offset = b * c4 * c4_block;
    int dst_batch_offset = b * height * width * channel;
    for (int h = 0; h < height; h++) {
      int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit);
      int dst_h_offset = dst_batch_offset + h * width * channel;
      for (int w = 0; w < width; w++) {
        int src_w_offset = src_h_offset + w * C4NUM;
        int dst_w_offset = dst_h_offset + w * channel;
        for (int c = 0; c < c4 - 1; c++) {
          int src_c4_offset = src_w_offset + c * c4_block;
          int dst_c4_offset = dst_w_offset + c * C4NUM;
#ifdef ENABLE_NEON
          float32x4_t input_ptr = vld1q_f32(src + src_c4_offset);
          float32x4_t zero = vdupq_n_f32(0);
          input_ptr = vmaxq_f32(zero, input_ptr);
          vst1q_f32(dst + dst_c4_offset, input_ptr);
#else
          for (int i = 0; i < C4NUM; ++i) {
            float input_data = src[src_c4_offset + i];
            input_data = input_data < 0 ? 0 : input_data;
            dst[dst_c4_offset + i] = input_data;
          }
#endif
        }
        int c_res = channel - (c4 - 1) * C4NUM;
        int src_c_res_offset = (c4 - 1) * c4_block;
        int dst_c_res_offset = (c4 - 1) * C4NUM;
        for (int c = 0; c < c_res; c++) {
          int src_c4_res_offset = src_w_offset + src_c_res_offset + c;
          int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c;
          float input_data = src[src_c4_res_offset];
          input_data = input_data < 0 ? 0 : input_data;
          dst[dst_c4_res_offset] = input_data;
        }
      }
    }
  }
}

void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int height, int width, int channel,
                               int output_unit) {
  int out_h_block_num = UP_DIV(height, output_unit);
  int out_w_block_num = UP_DIV(width, output_unit);
  int c4 = UP_DIV(channel, C4NUM);
  int c4_block = C4NUM * out_h_block_num * output_unit * out_w_block_num * output_unit;
  for (int b = 0; b < batch; b++) {
    int src_batch_offset = b * c4 * c4_block;
    int dst_batch_offset = b * height * width * channel;
    for (int h = 0; h < height; h++) {
      int src_h_offset = src_batch_offset + C4NUM * (h * out_w_block_num * output_unit);
      int dst_h_offset = dst_batch_offset + h * width * channel;
      for (int w = 0; w < width; w++) {
        int src_w_offset = src_h_offset + w * C4NUM;
        int dst_w_offset = dst_h_offset + w * channel;
        for (int c = 0; c < c4 - 1; c++) {
          int src_c4_offset = src_w_offset + c * c4_block;
          int dst_c4_offset = dst_w_offset + c * C4NUM;
#ifdef ENABLE_NEON
          float32x4_t input_ptr = vld1q_f32(src + src_c4_offset);
          float32x4_t zero = vdupq_n_f32(0);
          float32x4_t six = vdupq_n_f32(6);
          input_ptr = vmaxq_f32(zero, input_ptr);
          input_ptr = vminq_f32(six, input_ptr);
          vst1q_f32(dst + dst_c4_offset, input_ptr);
#else
          for (int i = 0; i < C4NUM; ++i) {
            float input_data = src[src_c4_offset + i];
            input_data = input_data < 0 ? 0 : input_data;
            input_data = input_data > 6 ? 6 : input_data;
            dst[dst_c4_offset + i] = input_data;
          }
#endif
        }
        int c_res = channel - (c4 - 1) * C4NUM;
        int src_c_res_offset = (c4 - 1) * c4_block;
        int dst_c_res_offset = (c4 - 1) * C4NUM;
        for (int c = 0; c < c_res; c++) {
          int src_c4_res_offset = src_w_offset + src_c_res_offset + c;
          int dst_c4_res_offset = dst_w_offset + dst_c_res_offset + c;
          float input_data = src[src_c4_res_offset];
          input_data = input_data < 0 ? 0 : input_data;
          input_data = input_data > 6 ? 6 : input_data;
          dst[dst_c4_res_offset] = input_data;
        }
      }
    }
  }
}

// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
                 int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func) {
                 int task_id, ConvParameter *conv_param) {
  int thread_count = conv_param->thread_num_;
  int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
  int output_channel = conv_param->output_channel_;
@@ -488,7 +353,7 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat

  int input_batch = conv_param->input_batch_;
  for (int batch = 0; batch < input_batch; batch++) {
    int in_batch_offset = batch * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
    int in_batch_offset = batch * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;

    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {

@@ -67,7 +67,7 @@ void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int heig

// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
                 int task_id, ConvParameter *conv_param, GEMM_FUNC_FP32 gemm_func);
                 int task_id, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

@@ -273,7 +273,6 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
  int kernel_plane = kernel_h * kernel_w;
  int plane_block = UP_DIV(kernel_plane, C4NUM);
  int unit_size = plane_block * C4NUM * ic4 * C4NUM;
  int packed_input_size = output_tile_count * tile_n * unit_size;
  int input_sum_offset;
  if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
    input_sum_offset = tile_n * oc4 * C4NUM;

@@ -284,12 +283,11 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
    int out_batch_offset = b * out_channel * out_h * out_w;
    int gemm_in_batch_offset = b * packed_input_size;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
      int start_index = thread_id * tile_n;
      int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
      int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset;
      int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
      int8_t *gemm_input = packed_input + task_id * unit_size * tile_n;
      // clear tmp buffer before compute
      memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
      int out_offset = thread_id * tile_n * out_channel + out_batch_offset;

@@ -336,7 +334,6 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
  int ic4 = UP_DIV(in_channel, C4NUM);
  int kernel_plane = kernel_h * kernel_w;
  int unit_size = kernel_plane * ic4 * C4NUM;
  int packed_input_size = output_tile_count * tile_n * unit_size;
  int input_sum_offset;
  if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) {
    input_sum_offset = tile_n * oc4 * C4NUM;

@@ -347,12 +344,11 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
  for (int b = 0; b < in_batch; b++) {
    int in_batch_offset = b * ic4 * C4NUM * in_h * in_w;
    int out_batch_offset = b * out_channel * out_h * out_w;
    int gemm_in_batch_offset = b * packed_input_size;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
      int start_index = thread_id * tile_n;
      int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
      int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset;
      int8_t *gemm_input = packed_input + thread_id * unit_size * tile_n + gemm_in_batch_offset;
      int8_t *gemm_input = packed_input + task_id * unit_size * tile_n;
      // clear tmp buffer before compute
      memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
      int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
@@ -297,24 +297,24 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
  int in_h = conv_param->input_h_;
  int in_w = conv_param->input_w_;
  int out_w = conv_param->output_w_;
  int ic4_minus = in_channel / C4NUM;
  int ic4 = UP_DIV(in_channel, C4NUM);
  memset(packed_input, 0, kernel_h * kernel_w * ic4 * C4NUM * TILE_NUM * sizeof(float));

  for (int i = 0; i < real_cal_num; i++) {
    int block_start = block_index + i;
    int input_h = block_start / out_w * stride_h - pad_h;
    int input_w = block_start % out_w * stride_w - pad_w;
    int input_stride = input_h * in_w * ic4 * C4NUM + input_w * ic4 * C4NUM;
    int input_stride = (input_h * in_w + input_w) * in_channel;
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    for (int j = kh_s; j < kh_e; j++) {
      int input_y_stride = j * dilation_h * in_w * ic4 * C4NUM + input_stride;
      int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
      for (int n = kw_s; n < kw_e; n++) {
        int input_x_stride = input_y_stride + n * dilation_w * ic4 * C4NUM;
        int input_x_stride = input_y_stride + n * dilation_w * in_channel;
        int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM;
        for (int m = 0; m < ic4; m++) {
        for (int m = 0; m < ic4_minus; m++) {
          int channel_block_stride = input_x_stride + m * C4NUM;
          int channel_block_offset = input_plane_offset + m * C8NUM * C4NUM;
#ifdef ENABLE_NEON

@@ -325,9 +325,15 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
          }
#endif
        } // channel_block loop
      } // kernel_w loop
    } // kernel_h loop
  } // tile num loop
        int ic_res = conv_param->input_channel_ - ic4_minus * C4NUM;
        for (int l = 0; l < ic_res; ++l) {
          int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
          int channel_block_offset = input_plane_offset + ic4_minus * C8NUM * C4NUM + l;
          packed_input[channel_block_offset] = input_data[channel_block_stride];
        }
      } // kernel_w loop
    } // kernel_h loop
  } // tile num loop
}

void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,
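Illustration (not part of the diff): because the packer now reads the raw NHWC input, a channel count that is not a multiple of C4NUM is handled in two steps: the main loop copies the ic4_minus full 4-channel groups, then a short tail loop copies the remaining ic_res channels one by one, leaving the padding lanes zeroed by the earlier memset. A tiny standalone C example of that split arithmetic (values are made up):

#include <stdio.h>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  int in_channel = 7;                          /* not a multiple of 4 */
  int ic4_minus = in_channel / C4NUM;          /* full 4-channel groups: 1 */
  int ic4 = UP_DIV(in_channel, C4NUM);         /* padded group count: 2 */
  int ic_res = in_channel - ic4_minus * C4NUM; /* leftover channels: 3 */
  printf("full groups: %d, padded groups: %d, tail channels: %d\n", ic4_minus, ic4, ic_res);
  /* Per pixel: ic4_minus vectorized C4NUM-wide copies, then ic_res scalar
   * copies; the remaining ic4 * C4NUM - in_channel lanes stay zero, so the
   * packed layout is still ic4 * C4NUM wide. */
  return 0;
}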
@@ -346,6 +352,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
  int in_channel = conv_param->input_channel_;
  int in_h = conv_param->input_h_;
  int in_w = conv_param->input_w_;
  int ic4_minus = in_channel / C4NUM;
  int ic4 = UP_DIV(in_channel, C4NUM);
  int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
  int out_w = conv_param->output_w_;

@@ -362,19 +369,19 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
        input_accumulator += ic4 * C4NUM * conv_param->conv_quant_arg_.input_quant_args_[0].zp_ * kernel_w;
        continue;
      }
      int input_y_stride = input_y * in_w * ic4 * C4NUM;
      int input_y_stride = input_y * in_w * in_channel;
      for (int n = 0; n < kernel_w; n++) {
        int input_x = input_w + n * dilation_w;
        if (input_x < 0 || input_x >= in_w) {
          input_accumulator += ic4 * C4NUM * conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
          continue;
        }
        int input_x_stride = input_y_stride + input_x * ic4 * C4NUM;
        int input_x_stride = input_y_stride + input_x * in_channel;
        int plane_c4_block = (j * kernel_w + n) / C4NUM;
        int plane_c4_res = (j * kernel_w + n) % C4NUM;
        int input_plane_offset =
          plane_c4_block * tile_num * C4NUM * C4NUM * ic4 + plane_c4_res * C4NUM + input_cal_num_offset;
        for (int m = 0; m < ic4; m++) {
        for (int m = 0; m < ic4_minus; m++) {
          int channel_block_stride = input_x_stride + m * C4NUM;
          int channel_block_offset = input_plane_offset + m * tile_num * C4NUM * C4NUM;
          (packed_input + channel_block_offset)[0] = (input_data + channel_block_stride)[0];

@@ -386,8 +393,15 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
          input_accumulator += (packed_input + channel_block_offset)[2];
          input_accumulator += (packed_input + channel_block_offset)[3];
        } // channel_block loop
      } // kernel_w loop
    } // kernel_h loop
        int ic_res = conv_param->input_channel_ - ic4_minus * C4NUM;
        for (int l = 0; l < ic_res; ++l) {
          int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
          int channel_block_offset = input_plane_offset + ic4_minus * tile_num * C4NUM + l;
          packed_input[channel_block_offset] = input_data[channel_block_stride];
          input_accumulator += (packed_input + channel_block_offset)[0];
        }
      } // kernel_w loop
    } // kernel_h loop
    if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
      continue;
    } else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&

@@ -419,6 +433,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
  int in_channel = conv_param->input_channel_;
  int in_h = conv_param->input_h_;
  int in_w = conv_param->input_w_;
  int ic4_minus = in_channel / C4NUM;
  int ic4 = UP_DIV(in_channel, C4NUM);
  int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
  int out_w = conv_param->output_w_;

@@ -428,26 +443,29 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
    int block_start = block_index + i;
    int input_h = block_start / out_w * stride_h - pad_h;
    int input_w = block_start % out_w * stride_w - pad_w;
    for (int j = 0; j < kernel_h; j++) {
      int input_y = input_h + j * dilation_h;
      if (input_y < 0 || input_y >= in_h) {
        continue;
      }
      int input_y_stride = input_y * in_w * ic4 * C4NUM;
      for (int n = 0; n < kernel_w; n++) {
        int input_x = input_w + n * dilation_w;
        if (input_x < 0 || input_x >= in_w) {
          continue;
        }
        int input_x_stride = input_y_stride + input_x * ic4 * C4NUM;
    int input_stride = input_h * in_w * in_channel + input_w * in_channel;
    int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
    int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
    int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
    int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
    for (int j = kh_s; j < kh_e; j++) {
      int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
      for (int n = kw_s; n < kw_e; n++) {
        int input_x_stride = input_y_stride + n * dilation_w * in_channel;
        int input_plane_offset = (j * kernel_w + n) * tile_num * C4NUM * ic4 + i * C4NUM;
        for (int m = 0; m < ic4; m++) {
        for (int m = 0; m < ic4_minus; m++) {
          int channel_block_stride = input_x_stride + m * C4NUM;
          int channel_block_offset = input_plane_offset + m * tile_num * C4NUM;
          memcpy(packed_input + channel_block_offset, input_data + channel_block_stride, 4);
        } // channel_block loop
      } // kernel_w loop
    } // kernel_h loop
        int ic_res = conv_param->input_channel_ - ic4_minus * C4NUM;
        for (int l = 0; l < ic_res; ++l) {
          int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
          int channel_block_offset = input_plane_offset + ic4_minus * tile_num * C4NUM + l;
          packed_input[channel_block_offset] = input_data[channel_block_stride];
        }
      } // kernel_w loop
    } // kernel_h loop
    int32_t input_accumulator = 0;
    for (int j = 0; j < block_size; j++) {
      int block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM;
@@ -41,30 +41,48 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
    int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s);
    int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);

    int src_plane_offset = ic4 * C4NUM * (src_y_s * input_w + src_x_s);
    int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
    int dst_plane_offset = c * C4NUM * ic4;
    for (int ic = 0; ic < ic4; ic++) {
      // clear tmp buffer
      memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float));

      // get real input block with padding
      int real_c = in_channel - ic * C4NUM;
      real_c = real_c > C4NUM ? C4NUM : real_c;
      int src_ic4_offset = src_plane_offset + ic * C4NUM;
      for (int interval = interval_y_s; interval < interval_y_e; interval++) {
        int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * ic4 * C4NUM;
        int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
        for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
          int src_x_offset = src_y_offset + j * ic4 * C4NUM;
          int dst_x_offset = dst_y_offset + j * C4NUM;
          float *src_addr = (float *)(input_data) + src_x_offset;
          float *dst_addr = tmp_data + dst_x_offset;
      // get real input block with padding
      if (real_c == C4NUM) {
        for (int interval = interval_y_s; interval < interval_y_e; interval++) {
          int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
          int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
          for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
            int src_x_offset = src_y_offset + j * in_channel;
            int dst_x_offset = dst_y_offset + j * C4NUM;
            float *src_addr = (float *)(input_data) + src_x_offset;
            float *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_NEON
          vst1q_f32(dst_addr, vld1q_f32(src_addr));
            vst1q_f32(dst_addr, vld1q_f32(src_addr));
#else
          for (int k = 0; k < C4NUM; k++) {
            dst_addr[k] = src_addr[k];
          }
            for (int k = 0; k < C4NUM; k++) {
              dst_addr[k] = src_addr[k];
            }
#endif
          }
        } // interval x loop
      } // interval y loop
      } else {
        for (int interval = interval_y_s; interval < interval_y_e; interval++) {
          int src_y_offset = src_ic4_offset + (interval * input_w + interval_x_s) * in_channel;
          int dst_y_offset = interval * input_unit * C4NUM + interval_x_s * C4NUM;
          for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
            int src_x_offset = src_y_offset + j * in_channel;
            int dst_x_offset = dst_y_offset + j * C4NUM;
            float *src_addr = (float *)(input_data) + src_x_offset;
            float *dst_addr = tmp_data + dst_x_offset;
            for (int k = 0; k < real_c; k++) {
              dst_addr[k] = src_addr[k];
            }
          } // interval x loop
        } // interval y loop
      }
      // input transform
#ifdef ENABLE_ARM32

@@ -314,29 +332,47 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa
    int real_y_start = origin_y > 0 ? 0 : -origin_y;
    int real_y_end = (origin_y + input_unit) < input_height ? input_unit : (input_height - origin_y);

    int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x);
    int src_plane_offset = input_channel * (origin_y * input_width + origin_x);
    int dst_plane_offset = cal_id * C4NUM * ic4;
    for (int ic = 0; ic < ic4; ic++) {
      // clear tmp buffer
      memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float));
      int real_c = input_channel - ic * C4NUM;
      real_c = real_c > C4NUM ? C4NUM : real_c;

      // get real input block with padding
      int src_ic4_offset = src_plane_offset + ic * C4NUM;
      for (int interval = real_y_start; interval < real_y_end; interval++) {
        int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic4 * C4NUM;
        int dst_y_offset = interval * input_unit * C4NUM + real_x_start * C4NUM;
        for (int j = 0; j < (real_x_end - real_x_start); j++) {
          int src_x_offset = src_y_offset + j * ic4 * C4NUM;
          int dst_x_offset = dst_y_offset + j * C4NUM;
          float *src_addr = (float *)(input_data) + src_x_offset;
          float *dst_addr = tmp_data + dst_x_offset;
      if (real_c == C4NUM) {
        for (int interval = real_y_start; interval < real_y_end; interval++) {
          int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * input_channel;
          int dst_y_offset = interval * input_unit * C4NUM + real_x_start * C4NUM;
          for (int j = 0; j < (real_x_end - real_x_start); j++) {
            int src_x_offset = src_y_offset + j * input_channel;
            int dst_x_offset = dst_y_offset + j * C4NUM;
            float *src_addr = (float *)(input_data) + src_x_offset;
            float *dst_addr = tmp_data + dst_x_offset;
#ifdef ENABLE_NEON
          vst1q_f32(dst_addr, vld1q_f32(src_addr));
            vst1q_f32(dst_addr, vld1q_f32(src_addr));
#else
          for (int k = 0; k < C4NUM; k++) {
            (dst_addr + k)[0] = (src_addr + k)[0];
          }
            for (int k = 0; k < C4NUM; k++) {
              dst_addr[k] = src_addr[k];
            }
#endif
          }
        }
      } else {
        for (int interval = real_y_start; interval < real_y_end; interval++) {
          int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * input_channel;
          int dst_y_offset = interval * input_unit * C4NUM + real_x_start * C4NUM;
          for (int j = 0; j < (real_x_end - real_x_start); j++) {
            int src_x_offset = src_y_offset + j * input_channel;
            int dst_x_offset = dst_y_offset + j * C4NUM;
            float *src_addr = (float *)(input_data) + src_x_offset;
            float *dst_addr = tmp_data + dst_x_offset;
            for (int k = 0; k < real_c; k++) {
              dst_addr[k] = src_addr[k];
            }
          }
        }
      }
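Illustration (not part of the diff): the Winograd/3x3 input-transform hunks above split the block copy into a fast path for a full C4NUM-channel group (four floats at a time, NEON when available) and a tail path that copies only the real_c leftover channels into the zero-initialized tmp buffer. A minimal standalone C sketch of that dispatch (the helper name and the sample sizes are invented):

#include <string.h>

#define C4NUM 4

/* Copies one channel group from an NHWC source into a C4NUM-padded block.
 * A full group is copied four floats at a time (vld1q/vst1q in the real
 * code); a partial group falls back to a scalar loop and leaves the
 * pre-zeroed padding lanes untouched. */
static void CopyChannelGroup(float *dst, const float *src, int real_c) {
  if (real_c == C4NUM) {
    memcpy(dst, src, C4NUM * sizeof(float));
  } else {
    for (int k = 0; k < real_c; ++k) {
      dst[k] = src[k];
    }
  }
}

int main(void) {
  float src[3] = {1.0f, 2.0f, 3.0f};
  float dst[C4NUM] = {0.0f}; /* padding lane stays zero */
  CopyChannelGroup(dst, src, 3);
  return 0;
}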
@@ -84,21 +84,8 @@ int ConvolutionCPUKernel::InitTmpBuffer() {
  MS_ASSERT(ctx_->allocator != nullptr);

  int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
  size_t nhwc4_input_size =
    ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
  MS_ASSERT(nullptr != ctx_->allocator);
  nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
  if (nhwc4_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc nhwc4 input failed.";
    return RET_ERROR;
  }

  int output_count = conv_param_->output_h_ * conv_param_->output_w_;
  int output_tile_count = UP_DIV(output_count, TILE_NUM);
  int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * ic4 * C4NUM;
  int packed_input_size = output_tile_count * TILE_NUM * unit_size;
  packed_input_ =
    reinterpret_cast<float *>(ctx_->allocator->Malloc(conv_param_->input_batch_ * packed_input_size * sizeof(float)));
  int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * ic4 * C4NUM * TILE_NUM * thread_count_;
  packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed input failed.";
    return RET_ERROR;
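Illustration (not part of the diff): with packing now done per thread inside ConvFp32, packed_input_ only has to hold one tile per worker, and the separate nhwc4_input_ copy disappears entirely. A rough standalone C comparison of the two allocation sizes (the shapes are invented; TILE_NUM = 12 mirrors the fp32 path):

#include <stdio.h>

#define C4NUM 4
#define TILE_NUM 12
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

int main(void) {
  /* Example convolution: 1 x 224 x 224 x 16 input, 3x3 kernel, 224x224 output, 4 threads. */
  int batch = 1, in_h = 224, in_w = 224, in_channel = 16;
  int kernel_h = 3, kernel_w = 3, out_h = 224, out_w = 224, thread_count = 4;

  int ic4 = UP_DIV(in_channel, C4NUM);
  int unit_size = kernel_h * kernel_w * ic4 * C4NUM;

  /* Old scheme: a packed copy of every tile of every batch, plus an NHWC4 input copy. */
  int output_tile_count = UP_DIV(out_h * out_w, TILE_NUM);
  size_t old_packed = (size_t)batch * output_tile_count * TILE_NUM * unit_size * sizeof(float);
  size_t old_nhwc4 = (size_t)batch * in_h * in_w * ic4 * C4NUM * sizeof(float);

  /* New scheme: one tile-sized scratch region per worker thread, no NHWC4 copy. */
  size_t new_packed = (size_t)thread_count * TILE_NUM * unit_size * sizeof(float);

  printf("old: %zu bytes packed + %zu bytes nhwc4\n", old_packed, old_nhwc4);
  printf("new: %zu bytes packed\n", new_packed);
  return 0;
}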
@@ -158,9 +145,11 @@ int ConvolutionCPUKernel::RunImpl(int task_id) {
    MS_LOG(ERROR) << "gemm_func is nullptr.";
    return RET_ERROR;
  }
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
  auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
  ConvFp32(reinterpret_cast<float *>(nhwc4_input_), packed_input_, packed_weight_,
           reinterpret_cast<float *>(bias_data_), tmp_output_block_, output_addr, task_id, conv_param_, gemm_func_);
  ConvFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), tmp_output_block_,
           output_addr, task_id, conv_param_, gemm_func_);
  return RET_OK;
}

@@ -187,11 +176,6 @@ int ConvolutionCPUKernel::Run() {
    return RET_ERROR;
  }

  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = input_tensor->MutableData();
  PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, conv_param_->input_batch_,
                      conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);

  int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionImpl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "conv error error_code[" << error_code << "]";

@@ -51,10 +51,6 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
      ctx_->allocator->Free(tmp_output_block_);
      tmp_output_block_ = nullptr;
    }
    if (nhwc4_input_ != nullptr) {
      ctx_->allocator->Free(nhwc4_input_);
      nhwc4_input_ = nullptr;
    }
    if (packed_input_ != nullptr) {
      ctx_->allocator->Free(packed_input_);
      packed_input_ = nullptr;

@@ -100,13 +100,6 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
#else
  const int tile_num = 12;
#endif
  size_t nhwc4_input_size =
    ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
  nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
  if (nhwc4_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
    return RET_ERROR;
  }

  size_t tile_buffer_size = thread_count_ * tile_num * C16NUM * ic4 * C4NUM * sizeof(float);
  tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));

@@ -152,16 +145,6 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
  return RET_OK;
}

void Convolution3x3CPUKernel::ConfigInputOutput() {
  auto output_tensor = out_tensors_.at(kOutputIndex);
  output_tensor->SetFormat(schema::Format::Format_NHWC);
  // #ifdef ENABLE_ARM32
  // gemm_func_ = IndirectGemmFp32_8x4;
  // #else
  gemm_func_ = IndirectGemmFp32_8x8;
  // #endif
}

int Convolution3x3CPUKernel::Init() {
  auto ret = InitWeightBias();
  if (ret != RET_OK) {

@@ -171,7 +154,6 @@ int Convolution3x3CPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }
  ConfigInputOutput();
  return ReSize();
}

@@ -191,12 +173,10 @@ int Convolution3x3CPUKernel::ReSize() {
}

int Convolution3x3CPUKernel::RunImpl(int task_id) {
  if (gemm_func_ == nullptr) {
    MS_LOG(ERROR) << "gemm_func is nullptr.";
    return RET_ERROR;
  }
  Conv3x3Fp32(reinterpret_cast<float *>(nhwc4_input_), transformed_filter_addr_, reinterpret_cast<float *>(bias_data_),
              tmp_buffer_address_list_, task_id, conv_param_, gemm_func_);
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
  Conv3x3Fp32(ori_input_data, transformed_filter_addr_, reinterpret_cast<float *>(bias_data_), tmp_buffer_address_list_,
              task_id, conv_param_);
  return RET_OK;
}

@@ -245,10 +225,6 @@ int Convolution3x3CPUKernel::Run() {
    MS_LOG(ERROR) << "Init tmp buffer failed.ret: " << ret;
    return RET_ERROR;
  }
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = input_tensor->MutableData();
  PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, conv_param_->input_batch_,
                      conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);

  int error_code = ParallelLaunch(this->context_->thread_pool_, Convolution3x3Impl, this, thread_count_);
  if (error_code != RET_OK) {

@@ -260,6 +236,7 @@ int Convolution3x3CPUKernel::Run() {
  ret = PostProcess();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Post process failed.";
    FreeTmpBuffer();
    return ret;
  }
  FreeTmpBuffer();
@@ -40,15 +40,10 @@ class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel {
  int RunImpl(int task_id);
  int InitWeightBias();
  int InitTmpBuffer();
  void ConfigInputOutput();
  int PostProcess();

 private:
  void FreeTmpBuffer() {
    if (nhwc4_input_ != nullptr) {
      ctx_->allocator->Free(nhwc4_input_);
      nhwc4_input_ = nullptr;
    }
    if (tile_buffer_ != nullptr) {
      ctx_->allocator->Free(tile_buffer_);
      tile_buffer_ = nullptr;

@@ -78,7 +73,6 @@ class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel {
  float *col_buffer_ = nullptr;
  float *nc4hw4_out_ = nullptr;
  TmpBufferAddress tmp_buffer_address_list_[5];
  GEMM_FUNC_FP32 gemm_func_ = nullptr;
};
void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num);
} // namespace mindspore::kernel

@@ -184,14 +184,6 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
#endif
  MS_ASSERT(ctx_->allocator != nullptr);

  size_t nhwc4_input_size =
    ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
  nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
  if (nhwc4_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc nhwc4_input_ failed.";
    return RET_MEMORY_FAILED;
  }

  size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
  trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
  if (trans_input_ == nullptr) {

@@ -298,9 +290,11 @@ int ConvolutionWinogradCPUKernel::ReSize() {
}

int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
  auto output_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
  ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), trans_weight_, reinterpret_cast<const float *>(bias_data_),
                   output_data, tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
  ConvWinogardFp32(ori_input_data, trans_weight_, reinterpret_cast<const float *>(bias_data_), output_data,
                   tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
  return RET_OK;
}

@@ -314,30 +308,6 @@ int ConvolutionWinogradImpl(void *cdata, int task_id) {
  return RET_OK;
}

int ConvolutionWinogradCPUKernel::PostProcess() {
  auto out_tensor = out_tensors_.front();
  auto out_data = reinterpret_cast<float *>(out_tensor->MutableData());
  auto act_type = conv_param_->act_type_;
  switch (act_type) {
    case ActType_No:
      UnPackWinogradOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
                           conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
      break;
    case ActType_Relu:
      UnPackWinogradReluOutput(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
                               conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
      break;
    case ActType_Relu6:
      UnPackWinogradRelu6Output(tmp_out_data_, out_data, conv_param_->output_batch_, conv_param_->output_h_,
                                conv_param_->output_w_, conv_param_->output_channel_, output_unit_);
      break;
    default:
      MS_LOG(ERROR) << "Unsupport activation type.";
      return RET_ERROR;
  }
  return RET_OK;
}

int ConvolutionWinogradCPUKernel::Run() {
  auto prepare_ret = Prepare();
  if (prepare_ret != RET_OK) {

@@ -351,11 +321,6 @@ int ConvolutionWinogradCPUKernel::Run() {
    return RET_ERROR;
  }

  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = input_tensor->MutableData();
  PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, conv_param_->input_batch_,
                      conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);

  int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionWinogradImpl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]";

@@ -45,15 +45,10 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
  int InitWeightBias();
  int InitTmpBuffer();
  int ConfigInputOutput();
  int PostProcess();
  int WinogradFilterTransform(const float *weight_data, float *matrix_g, float *matrix_gt, int oc_block);

 private:
  void FreeTmpBuffer() {
    if (nhwc4_input_ != nullptr) {
      ctx_->allocator->Free(nhwc4_input_);
      nhwc4_input_ = nullptr;
    }
    if (trans_input_ != nullptr) {
      ctx_->allocator->Free(trans_input_);
      trans_input_ = nullptr;
@@ -137,25 +137,15 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
  MS_ASSERT(ctx_->allocator != nullptr);

  int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
  int output_count = conv_param_->output_h_ * conv_param_->output_w_;
  int output_tile_count = UP_DIV(output_count, tile_num_);
  int kernel_plane = conv_param_->kernel_h_ * conv_param_->kernel_w_;
  int plane_c4 = UP_DIV(kernel_plane, C4NUM);
  int unit_size = plane_c4 * C4NUM * ic4 * C4NUM;
  int packed_input_size = output_tile_count * tile_num_ * unit_size;
  packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(conv_param_->input_batch_ * packed_input_size));
  packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(unit_size * thread_count_ * tile_num_));
  if (packed_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc packed_input_ failed.";
    return RET_ERROR;
  }

  size_t nhwc4_input_size = ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_;
  nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
  if (nhwc4_input_ == nullptr) {
    MS_LOG(ERROR) << "malloc nhwc4 input failed.";
    return RET_ERROR;
  }

  size_t tmp_dst_size = thread_count_ * tile_num_ * conv_param_->output_channel_ * sizeof(int32_t);
  tmp_dst_ = reinterpret_cast<int32_t *>(ctx_->allocator->Malloc(tmp_dst_size));
  if (tmp_dst_ == nullptr) {

@@ -322,15 +312,15 @@ int ConvolutionInt8CPUKernel::ReSize() {
}

int ConvolutionInt8CPUKernel::RunImpl(int task_id) {
  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = reinterpret_cast<int8_t *>(input_tensor->MutableData());
  auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(kOutputIndex)->MutableData());
  if (support_optimize_) {
    ConvInt8Opt(reinterpret_cast<int8_t *>(nhwc4_input_), packed_input_, packed_weight_,
                reinterpret_cast<int32_t *>(bias_data_), tmp_dst_, tmp_out_, output_addr, input_sum_, task_id,
                conv_param_, gemm_func_);
    ConvInt8Opt(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), tmp_dst_,
                tmp_out_, output_addr, input_sum_, task_id, conv_param_, gemm_func_);
  } else {
    ConvInt8(reinterpret_cast<int8_t *>(nhwc4_input_), packed_input_, packed_weight_,
             reinterpret_cast<int32_t *>(bias_data_), tmp_dst_, tmp_out_, output_addr, input_sum_, task_id,
             conv_param_);
    ConvInt8(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), tmp_dst_, tmp_out_,
             output_addr, input_sum_, task_id, conv_param_);
  }
  return RET_OK;
}

@@ -358,11 +348,6 @@ int ConvolutionInt8CPUKernel::Run() {
    return RET_ERROR;
  }

  auto input_tensor = in_tensors_.at(kInputIndex);
  auto ori_input_data = input_tensor->MutableData();
  PackNHWCToNHWC4Int8(ori_input_data, nhwc4_input_, conv_param_->input_batch_,
                      conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);

  int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionInt8Impl, this, thread_count_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "conv int8 error error_code[" << error_code << "]";

@@ -55,10 +55,6 @@ class ConvolutionInt8CPUKernel : public ConvolutionBaseCPUKernel {

 private:
  void FreeTmpBuffer() {
    if (nhwc4_input_ != nullptr) {
      ctx_->allocator->Free(nhwc4_input_);
      nhwc4_input_ = nullptr;
    }
    if (packed_input_ != nullptr) {
      ctx_->allocator->Free(packed_input_);
      packed_input_ = nullptr;