forked from OSSInnovation/mindspore
!6428 [MS][LITE]rewrite winograd output transform func
Merge pull request !6428 from fuzhiye/tmp
This commit is contained in:
commit
37234c17f8
|
@ -258,9 +258,9 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
|
|||
}
|
||||
|
||||
// fp32 conv winograd
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
|
||||
GEMM_FUNC_FP32 gemm_func) {
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
|
||||
OutputTransFunc out_func) {
|
||||
int thread_num = conv_param->thread_num_;
|
||||
int input_unit = conv_param->input_unit_;
|
||||
int in_batch = conv_param->input_batch_;
|
||||
|
@ -277,13 +277,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
|
|||
#endif
|
||||
int output_tile_count = UP_DIV(output_count, tile_num);
|
||||
int out_channel = conv_param->output_channel_;
|
||||
int oc4 = UP_DIV(out_channel, C4NUM);
|
||||
int oc8 = UP_DIV(out_channel, C8NUM);
|
||||
int input_unit_square = input_unit * input_unit;
|
||||
|
||||
float *trans_input = buffer_list[0];
|
||||
float *gemm_out = buffer_list[1];
|
||||
float *tmp_out_data = buffer_list[2];
|
||||
float *tmp_data = buffer_list[3];
|
||||
float *col_buffer = buffer_list[4];
|
||||
int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
|
||||
|
@ -294,7 +292,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
|
|||
// step 2 : input transform (online)
|
||||
for (int b = 0; b < in_batch; b++) {
|
||||
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
|
||||
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
|
||||
int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_;
|
||||
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
|
||||
int out_tile_index = thread_id * tile_num;
|
||||
int cal_num = output_count - thread_id * tile_num;
|
||||
|
@ -317,8 +315,9 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
|
|||
}
|
||||
|
||||
// step 4 : output transform
|
||||
WinogradOutputTransform(dst_ptr, tmp_out_data + tmp_out_batch_offset, bias_data, cal_num, out_tile_index,
|
||||
out_w_block, conv_param, out_func);
|
||||
float *output_ptr = output_data + out_batch_offset;
|
||||
WinogradOutputTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
out_func);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,9 +53,9 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
|
|||
GEMM_FUNC_FP32 gemm_func);
|
||||
|
||||
// fp32 convolution winograd
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
|
||||
int task_id, ConvParameter *conv_param, InputTransFunc in_func, OutputTransFunc out_func,
|
||||
GEMM_FUNC_FP32 gemm_func);
|
||||
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
|
||||
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
|
||||
OutputTransFunc out_func);
|
||||
|
||||
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);
|
||||
|
||||
|
|
|
@ -82,13 +82,11 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
|
|||
} // cal_tile_num loop
|
||||
}
|
||||
|
||||
void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
|
||||
void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
|
||||
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func) {
|
||||
int output_unit = conv_param->output_unit_;
|
||||
int output_w = conv_param->output_w_;
|
||||
int output_h = conv_param->output_h_;
|
||||
int output_w_unit_block = UP_DIV(output_w, output_unit);
|
||||
int output_h_unit_block = UP_DIV(output_h, output_unit);
|
||||
int output_channel = conv_param->output_channel_;
|
||||
int oc4 = UP_DIV(output_channel, C4NUM);
|
||||
int oc8 = UP_DIV(output_channel, C8NUM);
|
||||
|
@ -99,19 +97,29 @@ void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const f
|
|||
for (int i = 0; i < cal_num; i++) {
|
||||
int dst_x_s = out_tile_index % output_unit_num;
|
||||
int dst_y_s = out_tile_index / output_unit_num;
|
||||
int r_w = output_w - dst_x_s * output_unit;
|
||||
r_w = r_w > output_unit ? output_unit : r_w;
|
||||
int r_h = output_h - dst_y_s * output_unit;
|
||||
r_h = r_h > output_unit ? output_unit : r_h;
|
||||
int tmp_ix = dst_x_s * output_unit;
|
||||
dst_x_s = tmp_ix > output_w ? output_w : tmp_ix;
|
||||
int tmp_iy = dst_y_s * output_unit;
|
||||
dst_y_s = tmp_iy > output_h ? output_h : tmp_iy;
|
||||
|
||||
int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
|
||||
int dst_tile_offset = C4NUM * output_unit * (dst_x_s + dst_y_s * output_w_unit_block * output_unit);
|
||||
int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w);
|
||||
|
||||
for (int j = 0; j < oc4; j++) {
|
||||
int c8_block = j / 2;
|
||||
int c8_res = j % 2;
|
||||
int r_c = output_channel - j * C4NUM;
|
||||
r_c = r_c > C4NUM ? C4NUM : r_c;
|
||||
int src_oc4_offset = src_tile_offset + c8_block * input_unit * input_unit * C8NUM + c8_res * C4NUM;
|
||||
int dst_oc4_offset =
|
||||
dst_tile_offset + j * C4NUM * output_h_unit_block * output_w_unit_block * output_unit * output_unit;
|
||||
int dst_oc4_offset = dst_tile_offset + j * C4NUM;
|
||||
const float *src_ptr = gemm_out + src_oc4_offset;
|
||||
const float *bias_ptr = bias_data + j * C4NUM;
|
||||
float *dst_ptr = tmp_out_data + dst_oc4_offset;
|
||||
func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w_unit_block * output_unit);
|
||||
float *dst_ptr = out_data + dst_oc4_offset;
|
||||
func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c);
|
||||
// GeneralOutputTransformUnit(src_ptr, dst_ptr, bias_ptr, matrix_a, matrix_at, C8NUM,
|
||||
// output_w_unit_block * output_unit, input_unit, output_unit);
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ extern "C" {
|
|||
void WinogradInputTransform(const float *input_data, float *trans_input, float *tmp_data, int cal_num,
|
||||
int out_tile_index, int out_w_block_num, ConvParameter *conv_param, InputTransFunc func);
|
||||
|
||||
void WinogradOutputTransform(const float *gemm_out, float *tmp_out_data, const float *bias_data, int cal_num,
|
||||
void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
|
||||
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func);
|
||||
|
||||
// for fp32 convolution 3x3 filter/input/output transform
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -31,7 +31,7 @@ extern "C" {
|
|||
typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step);
|
||||
|
||||
typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step);
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
|
||||
void GeneralInputTransformUnit(const float *src_data, float *dst_data, float *matrix_b, float *matrix_bt, int src_step,
|
||||
int dst_step, int in_unit);
|
||||
|
@ -169,84 +169,144 @@ void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step,
|
|||
|
||||
void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step);
|
||||
|
||||
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit);
|
||||
OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
|
||||
|
||||
#define Store4Data \
|
||||
vst1q_f32(dst_data, m[0]); \
|
||||
vst1q_f32(dst_data + C4NUM, m[1]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM, m[2]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[3]);
|
||||
vst1q_f32(dst_data + out_c, m[1]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c, m[2]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + out_c, m[3]);
|
||||
|
||||
#define Store9Data \
|
||||
vst1q_f32(dst_data, m[0]); \
|
||||
vst1q_f32(dst_data + C4NUM, m[1]); \
|
||||
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM, m[3]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[4]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[5]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[6]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[7]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[8]);
|
||||
vst1q_f32(dst_data + out_c, m[1]); \
|
||||
vst1q_f32(dst_data + 2 * out_c, m[2]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c, m[3]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + out_c, m[4]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c, m[6]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
|
||||
|
||||
#define Store16Data \
|
||||
vst1q_f32(dst_data, m[0]); \
|
||||
vst1q_f32(dst_data + C4NUM, m[1]); \
|
||||
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
|
||||
vst1q_f32(dst_data + 3 * C4NUM, m[3]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM, m[4]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[5]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[6]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[7]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[8]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[9]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[10]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[11]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[12]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[13]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[14]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[15]);
|
||||
vst1q_f32(dst_data + out_c, m[1]); \
|
||||
vst1q_f32(dst_data + 2 * out_c, m[2]); \
|
||||
vst1q_f32(dst_data + 3 * out_c, m[3]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c, m[4]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + out_c, m[5]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[6]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[7]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c, m[8]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[9]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c, m[12]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[13]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
|
||||
|
||||
#define Store25Data \
|
||||
vst1q_f32(dst_data, m[0]); \
|
||||
vst1q_f32(dst_data + C4NUM, m[1]); \
|
||||
vst1q_f32(dst_data + 2 * C4NUM, m[2]); \
|
||||
vst1q_f32(dst_data + 3 * C4NUM, m[3]); \
|
||||
vst1q_f32(dst_data + 4 * C4NUM, m[4]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM, m[5]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + C4NUM, m[6]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + 2 * C4NUM, m[7]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + 3 * C4NUM, m[8]); \
|
||||
vst1q_f32(dst_data + dst_step * C4NUM + 4 * C4NUM, m[9]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM, m[10]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + C4NUM, m[11]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 2 * C4NUM, m[12]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 3 * C4NUM, m[13]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * C4NUM + 4 * C4NUM, m[14]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM, m[15]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM + C4NUM, m[16]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 2 * C4NUM, m[17]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 3 * C4NUM, m[18]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * C4NUM + 4 * C4NUM, m[19]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * C4NUM, m[20]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * C4NUM + C4NUM, m[21]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 2 * C4NUM, m[22]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 3 * C4NUM, m[23]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * C4NUM + 4 * C4NUM, m[24]);
|
||||
vst1q_f32(dst_data + out_c, m[1]); \
|
||||
vst1q_f32(dst_data + 2 * out_c, m[2]); \
|
||||
vst1q_f32(dst_data + 3 * out_c, m[3]); \
|
||||
vst1q_f32(dst_data + 4 * out_c, m[4]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c, m[5]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + out_c, m[6]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + 2 * out_c, m[7]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + 3 * out_c, m[8]); \
|
||||
vst1q_f32(dst_data + dst_step * out_c + 4 * out_c, m[9]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c, m[10]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + out_c, m[11]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \
|
||||
vst1q_f32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c, m[15]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c + out_c, m[16]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \
|
||||
vst1q_f32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * out_c, m[20]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * out_c + out_c, m[21]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
|
||||
vst1q_f32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
|
||||
|
||||
void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
|
||||
void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
|
||||
void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step);
|
||||
void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
|
||||
int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
|
||||
int dst_step, int out_c, int r_w, int r_h, int r_c);
|
||||
|
||||
int SelectOutputUnit(ConvParameter *conv_param);
|
||||
|
||||
|
|
|
@ -136,8 +136,12 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() {
|
|||
float matrix_at[64];
|
||||
float matrix_b[64];
|
||||
float matrix_bt[64];
|
||||
float coef = 1.0f;
|
||||
if (input_unit_ == 8) {
|
||||
coef = 0.5f;
|
||||
}
|
||||
auto ret =
|
||||
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 1.0f, output_unit_, kernel_unit_);
|
||||
CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "get matrix g from CookToomFilter failed.";
|
||||
return ret;
|
||||
|
@ -243,17 +247,11 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() {
|
|||
MS_LOG(ERROR) << "in_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
out_func_ = GetOutputTransFunc(input_unit_, output_unit_);
|
||||
out_func_ = GetOutputTransFunc(input_unit_, output_unit_, conv_param_->act_type_);
|
||||
if (out_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "out_func_ is null.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// #ifdef ENABLE_ARM32
|
||||
// gemm_func_ = IndirectGemmFp32_8x4;
|
||||
// #else
|
||||
gemm_func_ = IndirectGemmFp32_8x8;
|
||||
// #endif
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -300,12 +298,9 @@ int ConvolutionWinogradCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int ConvolutionWinogradCPUKernel::RunImpl(int task_id) {
|
||||
if (gemm_func_ == nullptr) {
|
||||
MS_LOG(ERROR) << "gemm_func is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto output_data = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
|
||||
ConvWinogardFp32(reinterpret_cast<float *>(nhwc4_input_), trans_weight_, reinterpret_cast<const float *>(bias_data_),
|
||||
tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_, gemm_func_);
|
||||
output_data, tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -368,12 +363,6 @@ int ConvolutionWinogradCPUKernel::Run() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ret = PostProcess();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Post process failed.";
|
||||
FreeTmpBuffer();
|
||||
return ret;
|
||||
}
|
||||
FreeTmpBuffer();
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -87,7 +87,6 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
|
|||
TmpBufferAddress tmp_buffer_address_list_[5];
|
||||
InputTransFunc in_func_;
|
||||
OutputTransFunc out_func_;
|
||||
GEMM_FUNC_FP32 gemm_func_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue