forked from mindspore-Ecosystem/mindspore
!17898 fix codegen bugs
Merge pull request !17898 from yangjie159/modify_cpp
This commit is contained in:
commit
ad5e03c856
|
@ -186,8 +186,7 @@ int DeConvolutionFP32Coder::DoCode(CoderContext *const context) {
|
||||||
if (!support_parallel_) {
|
if (!support_parallel_) {
|
||||||
code.CodeFunction("DeConvFp32Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("DeConvFp32Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction(kParallelLaunch, "DeConvFp32Run", kRunArgsAddr, "conv_parameter.thread_num_", kLhsScale,
|
code.CodeFunction(kParallelLaunch, "DeConvFp32Run", kRunArgsAddr, "conv_parameter.thread_num_");
|
||||||
kRhsScale);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
context->AppendCode(code.str());
|
context->AppendCode(code.str());
|
||||||
|
|
|
@ -156,9 +156,9 @@ int AddInt8Coder::DoCode(CoderContext *const context) {
|
||||||
support_opt_add_, input0, input1, output_tensor_);
|
support_opt_add_, input0, input1, output_tensor_);
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
if (arith_para_->broadcasting_) {
|
if (arith_para_->broadcasting_) {
|
||||||
code.CodeFunction(kParallelLaunch, "AddBroadcastInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "AddBroadcastInt8Run", kRunArgsAddr, gThreadNum);
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction(kParallelLaunch, "AddInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "AddInt8Run", kRunArgsAddr, gThreadNum);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (arith_para_->broadcasting_) {
|
if (arith_para_->broadcasting_) {
|
||||||
|
|
|
@ -113,7 +113,7 @@ int ConcatInt8Coder::DoCode(CoderContext *const context) {
|
||||||
code.CodeBaseStruct<false>("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_,
|
code.CodeBaseStruct<false>("ConcatInt8Args", kRunArgs, "input_data", output_tensor_, "&concat_param", axis_,
|
||||||
before_axis_size, count_unit_);
|
before_axis_size, count_unit_);
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "ConcatInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "ConcatInt8Run", kRunArgsAddr, gThreadNum);
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("ConcatInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("ConcatInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
|
|
|
@ -88,7 +88,7 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
|
||||||
/* input transpose and input sum */
|
/* input transpose and input sum */
|
||||||
code << "if (GetSupportOptFlag()) {\n";
|
code << "if (GetSupportOptFlag()) {\n";
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "OcOptPre", kRunArgsAddr, "args.thread_count_hw", kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "OcOptPre", kRunArgsAddr, "args.thread_count_hw");
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("OcOptPre", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("OcOptPre", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
|
@ -107,13 +107,13 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
|
||||||
/* matmul parallel by oc */
|
/* matmul parallel by oc */
|
||||||
code << "if (GetSupportOptFlag()) {\n";
|
code << "if (GetSupportOptFlag()) {\n";
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "RunArm64OptOc", kRunArgsAddr, "args.thread_count_oc", kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "RunArm64OptOc", kRunArgsAddr, "args.thread_count_oc");
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("RunArm64OptOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("RunArm64OptOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
code << "} else {\n";
|
code << "} else {\n";
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "RunArmOc", kRunArgsAddr, "args.thread_count_oc", kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "RunArmOc", kRunArgsAddr, "args.thread_count_oc");
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("RunArmOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("RunArmOc", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
|
@ -122,13 +122,13 @@ int Conv2D1x1Int8Coder::DoCode(CoderContext *const context) {
|
||||||
/* matmul parallel by hw */
|
/* matmul parallel by hw */
|
||||||
code << "if (GetSupportOptFlag()) {\n";
|
code << "if (GetSupportOptFlag()) {\n";
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "RunArm64OptHw", kRunArgsAddr, "args.thread_count_hw, kLhsScale, kRhsScale");
|
code.CodeFunction(kParallelLaunch, "RunArm64OptHw", kRunArgsAddr, "args.thread_count_hw");
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("RunArm64OptHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("RunArm64OptHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
code << "} else {\n";
|
code << "} else {\n";
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "RunArmHw", kRunArgsAddr, "args.thread_count_hw", kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "RunArmHw", kRunArgsAddr, "args.thread_count_hw");
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("RunArmHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("RunArmHw", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
|
|
|
@ -163,7 +163,7 @@ int Conv2D3x3Int8Coder::DoCode(CoderContext *const context) {
|
||||||
if (thread_num_ > 1) {
|
if (thread_num_ > 1) {
|
||||||
code.CodeBaseStruct("Conv3x3Int8Args", kRunArgs, c8_input_, transformed_filter_addr_, new_bias_addr_,
|
code.CodeBaseStruct("Conv3x3Int8Args", kRunArgs, c8_input_, transformed_filter_addr_, new_bias_addr_,
|
||||||
output_tensor_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, tmp_out_, "&conv_param_");
|
output_tensor_, tile_buffer_, block_unit_buffer_, tmp_dst_buffer_, tmp_out_, "&conv_param_");
|
||||||
code.CodeFunction(kParallelLaunch, "Conv3x3Int8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "Conv3x3Int8Run", kRunArgsAddr, gThreadNum);
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("Conv3x3Int8", c8_input_, transformed_filter_addr_, new_bias_addr_, output_tensor_, tile_buffer_,
|
code.CodeFunction("Conv3x3Int8", c8_input_, transformed_filter_addr_, new_bias_addr_, output_tensor_, tile_buffer_,
|
||||||
block_unit_buffer_, tmp_dst_buffer_, tmp_out_, kDefaultTaskId, "&conv_param_", kLhsScale,
|
block_unit_buffer_, tmp_dst_buffer_, tmp_out_, kDefaultTaskId, "&conv_param_", kLhsScale,
|
||||||
|
|
|
@ -237,7 +237,7 @@ int Conv2DINT8Coder::DoCode(CoderContext *const context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "ConvolutionInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "ConvolutionInt8Run", kRunArgsAddr, gThreadNum);
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("ConvolutionInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("ConvolutionInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,8 +122,7 @@ int ConvolutionDepthwiseINT8Coder::DoCode(CoderContext *const context) {
|
||||||
code.CodeBaseStruct("ConvDepthwiseInt8Args", kRunArgs, output_tensor_, row_buffer_, input_tensor_, packed_weight_,
|
code.CodeBaseStruct("ConvDepthwiseInt8Args", kRunArgs, output_tensor_, row_buffer_, input_tensor_, packed_weight_,
|
||||||
bias_data_, "&conv_param");
|
bias_data_, "&conv_param");
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "ConvDepthwiseInt8Run", kRunArgsAddr, "conv_param.thread_num_", kLhsScale,
|
code.CodeFunction(kParallelLaunch, "ConvDepthwiseInt8Run", kRunArgsAddr, "conv_param.thread_num_");
|
||||||
kRhsScale);
|
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("ConvDepthwiseInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("ConvDepthwiseInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,7 +92,7 @@ int ResizeInt8Coder::DoCode(CoderContext *const context) {
|
||||||
code.CodeBaseStruct("ResizeInt8Args", kRunArgs, input_tensor_, output_tensor_, "input_shape", "output_shape",
|
code.CodeBaseStruct("ResizeInt8Args", kRunArgs, input_tensor_, output_tensor_, "input_shape", "output_shape",
|
||||||
align_corners, gThreadNum);
|
align_corners, gThreadNum);
|
||||||
if (support_parallel_) {
|
if (support_parallel_) {
|
||||||
code.CodeFunction(kParallelLaunch, "ResizeInt8Run", kRunArgsAddr, gThreadNum, kLhsScale, kRhsScale);
|
code.CodeFunction(kParallelLaunch, "ResizeInt8Run", kRunArgsAddr, gThreadNum);
|
||||||
} else {
|
} else {
|
||||||
code.CodeFunction("ResizeInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
code.CodeFunction("ResizeInt8Run", kRunArgsAddr, kDefaultTaskId, kLhsScale, kRhsScale);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue