Fix auto-coding style

This commit is contained in:
lz 2022-06-09 21:26:46 +08:00
parent 90ac7781ba
commit 51c0c3e97a
3 changed files with 25 additions and 19 deletions

View File

@ -187,9 +187,9 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
if (is_bias_broadcast_) {
float broad_cast_data = (reinterpret_cast<float *>(bias_tensor_->data()))[0];
std::string bias_ptr_str = "((float *)(" + allocator_->GetRuntimeAddr(bias_ptr_) + "))";
init_code << "\tfor (int i = 0; i < " << max_bias_data << "; ++i) {\n";
init_code << "\t\t" << bias_ptr_str << "[i] = " << broad_cast_data << ";\n";
init_code << "\t}\n";
init_code << "\t for (int i = 0; i < " << max_bias_data << "; ++i) {\n";
init_code << "\t\t " << bias_ptr_str << "[i] = " << broad_cast_data << ";\n";
init_code << " }\n";
} else {
std::string bias_tensor_str = allocator_->GetRuntimeAddr(bias_tensor_);
init_code.CodeFunction("memcpy", bias_ptr_, bias_tensor_str, ori_bias_pack_ptr_size_);
@ -246,25 +246,25 @@ int MatMulFP32BaseCoder::DoCode(CoderContext *const context) {
int current_rest_oc = params_->col_ - kDefaultTaskId * thread_stride_ * col_tile_;
int cur_oc = MSMIN(current_stride_oc, current_rest_oc);
if (cur_oc <= 0) return RET_OK;
code << "for (int i = 0; i < " << params_->batch << "; ++i) {\n";
code << " for (int i = 0; i < " << params_->batch << "; ++i) {\n";
if (vec_matmul_) {
code << "\t\tconst float *batch_a_ptr = " << a_pack_str << " + i * " << params_->deep_ << ";\n";
code << "\t\tconst float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_ << ";\n";
code << "\t\tfloat *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n";
code << " const float *batch_a_ptr = " << a_pack_str << " + i * " << params_->deep_ << ";\n";
code << " const float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_ << ";\n";
code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n";
code.CodeFunction("MatVecMulFp32", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_,
params_->deep_, cur_oc);
} else {
code << "\t\tconst float *batch_a_ptr = " << a_pack_str << " + i * " << params_->row_align_ * params_->deep_
code << " const float *batch_a_ptr = " << a_pack_str << " + i * " << params_->row_align_ * params_->deep_
<< ";\n";
code << "\t\tconst float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_align_
code << " const float *batch_b_ptr = " << b_pack_str << " + i * " << params_->deep_ * params_->col_align_
<< ";\n";
code << "\t\tfloat *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n";
code << " float *batch_c_ptr = " << c_str << " + i * " << params_->row_ * params_->col_ << ";\n";
code.CodeFunction("MatMulOpt", "batch_a_ptr", "batch_b_ptr", "batch_c_ptr", bias_ptr_, params_->act_type_,
params_->deep_, params_->row_, cur_oc, params_->col_, "OutType_Nhwc");
}
code << "\t\t}\n";
code << " }\n";
context->AppendInitWeightSizeCode(w_buf_size);
context->AppendCode(code.str());
context->AppendInitCode(init_code.str());

View File

@ -60,8 +60,8 @@ void NNaclFp32Serializer::CodeStruct(const std::string &name, const SoftmaxParam
}
void NNaclFp32Serializer::CodeStruct(const std::string &name, const ConvParameter &conv_parameter) {
code << gThreadNum << " = 1;\n";
code << "int thread_num = MSMIN(" << gThreadNum << ", " << conv_parameter.output_h_ << ");\n";
code << " " << gThreadNum << " = 1;\n";
code << " int thread_num = MSMIN(" << gThreadNum << ", " << conv_parameter.output_h_ << ");\n";
CodeBaseStruct<false>(
"ConvParameter", name, conv_parameter.op_parameter_, "{0}", conv_parameter.kernel_h_, conv_parameter.kernel_w_,
conv_parameter.stride_h_, conv_parameter.stride_w_, conv_parameter.dilation_h_, conv_parameter.dilation_w_,

View File

@ -67,7 +67,7 @@ class Serializer {
*/
template <typename... PARAMETERS>
void CodeFunction(const std::string &name, PARAMETERS... parameters) {
code << name << "(";
code << " " << name << "(";
GenCode(parameters...);
code << ");\n";
}
@ -105,9 +105,9 @@ class Serializer {
void CodeArray(const std::string &name, T *data, int length, bool is_const = true) {
std::string type = GetVariableTypeName<T>();
if (is_const) {
code << "const " << type << " " << name << "[" << length << "] = {";
code << " const " << type << " " << name << "[" << length << "] = {";
} else {
code << type << " " << name << "[" << length << "] = {";
code << " " << type << " " << name << "[" << length << "] = {";
}
for (int i = 0; i < length - 1; ++i) {
code << data[i] << ", ";
@ -190,9 +190,9 @@ class Serializer {
template <bool immutable = true, typename... PARAMETERS>
void CodeBaseStruct(const std::string &type, const std::string &name, PARAMETERS... parameters) {
if constexpr (immutable) {
code << "const " << type << " " << name << " = {";
code << " const " << type << " " << name << " = {";
} else {
code << type << " " << name << " = {";
code << " " << type << " " << name << " = {";
}
GenCode(parameters...);
code << "};\n";
@ -254,7 +254,13 @@ class Serializer {
void GenCode(int8_t t) { code << std::to_string(t); }
void GenCode(uint8_t t) { code << std::to_string(t); }
void GenCode(decltype(nullptr) t) { code << "NULL"; }
void GenCode(const char *t) { code << t; }
void GenCode(const char *t) {
if (t == nullptr || (t != nullptr && strlen(t) == 0)) {
code << "{0}";
} else {
code << t;
}
}
void GenCode(TypeIdC t) { code << "(TypeIdC)" << t; }
};
} // namespace mindspore::lite::micro