!40091 modify quant params converter of onnx parser

Merge pull request !40091 from liyan2022/qat_fix_onnx_quant_parser
This commit is contained in:
i-robot 2022-08-10 08:27:54 +00:00 committed by Gitee
commit 51fee4ef61
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 23 additions and 8 deletions

View File

@ -980,8 +980,8 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
return RET_ERROR;
}
// set input tensors
auto quant_params_holder = std::make_shared<QuantParamHolder>(onnx_node.input_size(), onnx_node.output_size());
MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, RET_NULL_PTR, "create QuantParamHolder return nullptr");
std::map<int, std::vector<schema::QuantParamT>> input_quant_params;
size_t idx = 0;
for (int i = 0; i < onnx_node.input_size(); ++i) {
const auto &input_name = onnx_node.input(i);
std::vector<schema::QuantParamT> quant_params;
@ -990,9 +990,14 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
MS_LOG(ERROR) << "set input tensor quant param failed.";
return status;
}
quant_params_holder->set_input_quant_param(i, quant_params);
if (!quant_params.empty()) {
input_quant_params.insert({idx, quant_params});
idx++;
}
}
// set out tensors
idx = 0;
std::map<int, std::vector<schema::QuantParamT>> output_quant_params;
for (int i = 0; i < onnx_node.output_size(); ++i) {
const auto &output_name = onnx_node.output(i);
std::vector<schema::QuantParamT> quant_params;
@ -1001,9 +1006,22 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
MS_LOG(ERROR) << "set output tensor quant param failed.";
return status;
}
quant_params_holder->set_output_quant_param(i, quant_params);
if (!quant_params.empty()) {
output_quant_params.insert({idx, quant_params});
idx++;
}
}
if (!input_quant_params.empty() || !output_quant_params.empty()) {
auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
MSLITE_CHECK_PTR(quant_params_holder);
for (auto &iter : input_quant_params) {
quant_params_holder->set_input_quant_param(iter.first, iter.second);
}
for (auto &iter : output_quant_params) {
quant_params_holder->set_output_quant_param(iter.first, iter.second);
}
primitive_c->AddAttr("quant_params", quant_params_holder);
}
primitive_c->AddAttr("quant_params", quant_params_holder);
return RET_OK;
}
@ -1089,9 +1107,6 @@ STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_na
}
if (quant_param->inited) {
quant_params->push_back(*std::move(quant_param));
} else {
std::vector<schema::QuantParamT> notinited_quant_params(1);
*quant_params = notinited_quant_params;
}
return RET_OK;
}