!17898 fix codegen bugs

Merge pull request !17898 from yangjie159/modify_cpp
This commit is contained in:
i-robot 2021-06-07 16:04:28 +08:00 committed by Gitee
commit ad5e03c856
8 changed files with 13 additions and 15 deletions

View File

@ -186,8 +186,7 @@ int DeConvolutionFP32Coder::DoCode(CoderContext *const context) {
if (!support_parallel_) { if (!support_parallel_) {
code.CodeFunction("DeConvFp32Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("DeConvFp32Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} else { } else {
code.CodeFunction(kParallelLaunch, "DeConvFp32Run", kRunArgsAddr, "conv_parameter.thread_num_", kLhsScale, code.CodeFunction(kParallelLaunch, "DeConvFp32Run", kRunArgsAddr, "conv_parameter.thread_num_");
kRhsScale);
} }
} }
context->AppendCode(code.str()); context->AppendCode(code.str());

View File

@ -156,9 +156,9 @@ int AddInt8Coder::DoCode(CoderContext *const context) {
support_opt_add_, input0, input1, output_tensor_); support_opt_add_, input0, input1, output_tensor_);
if (support_parallel_) { if (support_parallel_) {
if (arith_para_->broadcasting_) { if (arith_para_->broadcasting_) {
code.CodeFunction(kParallelLaunch, "AddBroadcastInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "AddBroadcastInt8Run", kRunArgsAddr, gThreadNum);
} else { } else {
code.CodeFunction(kParallelLaunch, "AddInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "AddInt8Run", kRunArgsAddr, gThreadNum);
} }
} else { } else {
if (arith_para_->broadcasting_) { if (arith_para_->broadcasting_) {

View File

@ -113,7 +113,7 @@ int ConcatInt8Coder::DoCode(CoderContext *const context) {
code.CodeBaseStruct<false>("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_, code.CodeBaseStruct<false>("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_,
before_axis_size, count_unit_); before_axis_size, count_unit_);
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "ConcatInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "ConcatInt8Run", kRunArgsAddr, gThreadNum);
} else { } else {
code.CodeFunction("ConcatInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("ConcatInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }

View File

@ -88,7 +88,7 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
/* input transpose and input sum */ /* input transpose and input sum */
code << "if (GetSupportOptFlag()) {\n"; code << "if (GetSupportOptFlag()) {\n";
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "OcOptPre", kRunArgsAddr, "args.thread_count_hw", kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "OcOptPre", kRunArgsAddr, "args.thread_count_hw");
} else { } else {
code.CodeFunction("OcOptPre", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("OcOptPre", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }
@ -107,13 +107,13 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
/* matmul parallel by oc */ /* matmul parallel by oc */
code << "if (GetSupportOptFlag()) {\n"; code << "if (GetSupportOptFlag()) {\n";
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "RunArm64OptOc", kRunArgsAddr, "args.thread_count_oc", kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "RunArm64OptOc", kRunArgsAddr, "args.thread_count_oc");
} else { } else {
code.CodeFunction("RunArm64OptOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("RunArm64OptOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }
code << "} else {\n"; code << "} else {\n";
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "RunArmOc", kRunArgsAddr, "args.thread_count_oc", kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "RunArmOc", kRunArgsAddr, "args.thread_count_oc");
} else { } else {
code.CodeFunction("RunArmOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("RunArmOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }
@ -122,13 +122,13 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
/* matmul parallel by hw */ /* matmul parallel by hw */
code << "if (GetSupportOptFlag()) {\n"; code << "if (GetSupportOptFlag()) {\n";
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "RunArm64OptHw", kRunArgsAddr, "args.thread_count_hw, kLhsScale, kRhsScale"); code.CodeFunction(kParallelLaunch, "RunArm64OptHw", kRunArgsAddr, "args.thread_count_hw");
} else { } else {
code.CodeFunction("RunArm64OptHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("RunArm64OptHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }
code << "} else {\n"; code << "} else {\n";
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "RunArmHw", kRunArgsAddr, "args.thread_count_hw", kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "RunArmHw", kRunArgsAddr, "args.thread_count_hw");
} else { } else {
code.CodeFunction("RunArmHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("RunArmHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }

View File

@ -163,7 +163,7 @@ int Conv2D3x3Int8Coder::DoCode(CoderContext *const context) {
if (thread_num_ > 1) { if (thread_num_ > 1) {
code.CodeBaseStruct("Conv3x3Int8Args", kRunArgs, c8_input_, transformed_filter_addr_, new_bias_addr_, code.CodeBaseStruct("Conv3x3Int8Args", kRunArgs, c8_input_, transformed_filter_addr_, new_bias_addr_,
output_tensor_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, tmp_out_, "&conv_param_"); output_tensor_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, tmp_out_, "&conv_param_");
code.CodeFunction(kParallelLaunch, "Conv3x3Int8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "Conv3x3Int8Run", kRunArgsAddr, gThreadNum);
} else { } else {
code.CodeFunction("Conv3x3Int8", c8_input_, transformed_filter_addr_, new_bias_addr_, output_tensor_, tile_buffer_, code.CodeFunction("Conv3x3Int8", c8_input_, transformed_filter_addr_, new_bias_addr_, output_tensor_, tile_buffer_,
block_unit_buffer_, tmp_dst_buffer_, tmp_out_, kDefaultTaskId, "&conv_param_", kLhsScale, block_unit_buffer_, tmp_dst_buffer_, tmp_out_, kDefaultTaskId, "&conv_param_", kLhsScale,

View File

@ -237,7 +237,7 @@ int Conv2DINT8Coder::DoCode(CoderContext *const context) {
} }
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "ConvolutionInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "ConvolutionInt8Run", kRunArgsAddr, gThreadNum);
} else { } else {
code.CodeFunction("ConvolutionInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("ConvolutionInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }

View File

@ -122,8 +122,7 @@ int ConvolutionDepthwiseINT8Coder::DoCode(CoderContext *const context) {
code.CodeBaseStruct("ConvDepthwiseInt8Args", kRunArgs, output_tensor_, row_buffer_, input_tensor_, packed_weight_, code.CodeBaseStruct("ConvDepthwiseInt8Args", kRunArgs, output_tensor_, row_buffer_, input_tensor_, packed_weight_,
bias_data_, "&conv_param"); bias_data_, "&conv_param");
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "ConvDepthwiseInt8Run", kRunArgsAddr, "conv_param.thread_num_", kLhsScale, code.CodeFunction(kParallelLaunch, "ConvDepthwiseInt8Run", kRunArgsAddr, "conv_param.thread_num_");
kRhsScale);
} else { } else {
code.CodeFunction("ConvDepthwiseInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("ConvDepthwiseInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }

View File

@ -92,7 +92,7 @@ int ResizeInt8Coder::DoCode(CoderContext *const context) {
code.CodeBaseStruct("ResizeInt8Args", kRunArgs, input_tensor_, output_tensor_, "input_shape", "output_shape", code.CodeBaseStruct("ResizeInt8Args", kRunArgs, input_tensor_, output_tensor_, "input_shape", "output_shape",
align_corners, gThreadNum); align_corners, gThreadNum);
if (support_parallel_) { if (support_parallel_) {
code.CodeFunction(kParallelLaunch, "ResizeInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale); code.CodeFunction(kParallelLaunch, "ResizeInt8Run", kRunArgsAddr, gThreadNum);
} else { } else {
code.CodeFunction("ResizeInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale); code.CodeFunction("ResizeInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
} }