!6428 [MS][LITE]rewrite winograd output transform func

Merge pull request !6428 from fuzhiye/tmp
This commit is contained in:
mindspore-ci-bot 2020-09-22 10:05:34 +08:00 committed by Gitee
commit 37234c17f8
8 changed files with 2960 additions and 427 deletions

View File

@ -258,9 +258,9 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
}
// fp32 conv winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
GEMM_FUNC_FP32 gemm_func) {
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
OutputTransFunc out_func) {
int thread_num = conv_param->thread_num_;
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
@ -277,13 +277,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
#endif
int output_tile_count = UP_DIV(output_count, tile_num);
int out_channel = conv_param->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);
int input_unit_square = input_unit * input_unit;
float *trans_input = buffer_list[0];
float *gemm_out = buffer_list[1];
float *tmp_out_data = buffer_list[2];
float *tmp_data = buffer_list[3];
float *col_buffer = buffer_list[4];
int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
@ -294,7 +292,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * tile_num;
int cal_num = output_count - thread_id * tile_num;
@ -317,8 +315,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
}
// step 4 : output transform
WinogradOutputTransform(dst_ptr, tmp_out_data + tmp_out_batch_offset, bias_data, cal_num, out_tile_index,
out_w_block, conv_param, out_func);
float *output_ptr = output_data + out_batch_offset;
WinogradOutputTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
out_func);
}
}
}

View File

@ -53,9 +53,9 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
GEMM_FUNC_FP32 gemm_func);
// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
GEMM_FUNC_FP32 gemm_func);
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
OutputTransFunc out_func);
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);

View File

@ -82,13 +82,11 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
} // cal_tile_num loop
}
void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func) {
int output_unit = conv_param->output_unit_;
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
int output_w_unit_block = UP_DIV(output_w, output_unit);
int output_h_unit_block = UP_DIV(output_h, output_unit);
int output_channel = conv_param->output_channel_;
int oc4 = UP_DIV(output_channel, C4NUM);
int oc8 = UP_DIV(output_channel, C8NUM);
@ -99,19 +97,29 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
for (int i = 0; i < cal_num; i++) {
int dst_x_s = out_tile_index % output_unit_num;
int dst_y_s = out_tile_index / output_unit_num;
int r_w = output_w - dst_x_s * output_unit;
r_w = r_w > output_unit ? output_unit : r_w;
int r_h = output_h - dst_y_s * output_unit;
r_h = r_h > output_unit ? output_unit : r_h;
int tmp_ix = dst_x_s * output_unit;
dst_x_s = tmp_ix > output_w ? output_w : tmp_ix;
int tmp_iy = dst_y_s * output_unit;
dst_y_s = tmp_iy > output_h ? output_h : tmp_iy;
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_w_unit_block * output_unit);
int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w);
for (int j = 0; j < oc4; j++) {
int c8_block = j / 2;
int c8_res = j % 2;
int r_c = output_channel - j * C4NUM;
r_c = r_c > C4NUM ? C4NUM : r_c;
int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM;
int dst_oc4_offset =
dst_tile_offset + j * C4NUM * output_h_unit_block * output_w_unit_block * output_unit * output_unit;
int dst_oc4_offset = dst_tile_offset + j * C4NUM;
const float *src_ptr = gemm_out + src_oc4_offset;
const float *bias_ptr = bias_data + j * C4NUM;
float *dst_ptr = tmp_out_data + dst_oc4_offset;
func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
float *dst_ptr = out_data + dst_oc4_offset;
func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c);
// GeneralOutputTransformUnit(src_ptr, dst_ptr, bias_ptr, matrix_a, matrix_at, C8NUM,
// output_w_unit_block * output_unit, input_unit, output_unit);
}

View File

@ -35,7 +35,7 @@ extern "C" {
void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
int out_tile_index, int out_w_block_num, ConvParameter *conv_param, InputTransFunc func);
void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func);
// for fp32 convolution 3x3 filter/input/output transform

File diff suppressed because it is too large Load Diff

View File

@ -31,7 +31,7 @@ extern "C" {
typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step);
typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step);
int dst_step, int out_c, int r_w, int r_h, int r_c);
void GeneralInputTransformUnit(const float *src_data, float *dst_data, float *matrix_b, float *matrix_bt, int src_step,
int dst_step, int in_unit);
@ -169,84 +169,144 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step,
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit);
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
#define Store4Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[2]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[3]);
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + dst_step * out_c, m[2]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[3]);
#define Store9Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[3]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[4]); \
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[5]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[6]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[8]);
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + 2 * out_c, m[2]); \
vst1q_f32(dst_data + dst_step * out_c, m[3]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[4]); \
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \
vst1q_f32(dst_data + 2 * dst_step * out_c, m[6]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
#define Store16Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
vst1q_f32(dst_data + 3 * C4NUM, m[3]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[4]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[5]); \
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[6]); \
vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[8]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[11]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[12]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[13]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[15]);
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + 2 * out_c, m[2]); \
vst1q_f32(dst_data + 3 * out_c, m[3]); \
vst1q_f32(dst_data + dst_step * out_c, m[4]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[5]); \
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \
vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \
vst1q_f32(dst_data + 2 * dst_step * out_c, m[8]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \
vst1q_f32(dst_data + 3 * dst_step * out_c, m[12]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
#define Store25Data \
vst1q_f32(dst_data, m[0]); \
vst1q_f32(dst_data + C4NUM, m[1]); \
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
vst1q_f32(dst_data + 3 * C4NUM, m[3]); \
vst1q_f32(dst_data + 4 * C4NUM, m[4]); \
vst1q_f32(dst_data + dst_step * C4NUM, m[5]); \
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[6]); \
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[7]); \
vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[8]); \
vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[11]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[12]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[13]); \
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[15]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[16]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[17]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[18]); \
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, m[19]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM, m[20]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, m[21]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, m[22]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, m[23]); \
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, m[24]);
vst1q_f32(dst_data + out_c, m[1]); \
vst1q_f32(dst_data + 2 * out_c, m[2]); \
vst1q_f32(dst_data + 3 * out_c, m[3]); \
vst1q_f32(dst_data + 4 * out_c, m[4]); \
vst1q_f32(dst_data + dst_step * out_c, m[5]); \
vst1q_f32(dst_data + dst_step * out_c + out_c, m[6]); \
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \
vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \
vst1q_f32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \
vst1q_f32(dst_data + 2 * dst_step * out_c, m[10]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \
vst1q_f32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \
vst1q_f32(dst_data + 3 * dst_step * out_c, m[15]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \
vst1q_f32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \
vst1q_f32(dst_data + 4 * dst_step * out_c, m[20]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
vst1q_f32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
int dst_step, int out_c, int r_w, int r_h, int r_c);
int SelectOutputUnit(ConvParameter *conv_param);

View File

@ -136,8 +136,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
float matrix_at[64];
float matrix_b[64];
float matrix_bt[64];
float coef = 1.0f;
if (input_unit_ == 8) {
coef = 0.5f;
}
auto ret =
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 1.0f, output_unit_, kernel_unit_);
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
return ret;
@ -243,17 +247,11 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
MS_LOG(ERROR) << "in_func_ is null.";
return RET_ERROR;
}
out_func_ = GetOutputTransFunc(input_unit_, output_unit_);
out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
if (out_func_ == nullptr) {
MS_LOG(ERROR) << "out_func_ is null.";
return RET_ERROR;
}
// #ifdef ENABLE_ARM32
// gemm_func_ = IndirectGemmFp32_8x4;
// #else
gemm_func_ = IndirectGemmFp32_8x8;
// #endif
return RET_OK;
}
@ -300,12 +298,9 @@ int ConvolutionWinogradCPUKernel::ReSize() {
}
int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
if (gemm_func_ == nullptr) {
MS_LOG(ERROR) << "gemm_func is nullptr.";
return RET_ERROR;
}
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), trans_weight_, reinterpret_cast<const float *>(bias_data_),
tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_, gemm_func_);
output_data, tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
return RET_OK;
}
@ -368,12 +363,6 @@ int ConvolutionWinogradCPUKernel::Run() {
return RET_ERROR;
}
ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
FreeTmpBuffer();
return ret;
}
FreeTmpBuffer();
return RET_OK;
}

View File

@ -87,7 +87,6 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
TmpBufferAddress tmp_buffer_address_list_[5];
InputTransFunc in_func_;
OutputTransFunc out_func_;
GEMM_FUNC_FP32 gemm_func_ = nullptr;
};
} // namespace mindspore::kernel