forked from mindspore-Ecosystem/mindspore
!40091 modify quant params converter of onnx parser
Merge pull request !40091 from liyan2022/qat_fix_onnx_quant_parser
This commit is contained in:
commit
51fee4ef61
|
@ -980,8 +980,8 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
|
|||
return RET_ERROR;
|
||||
}
|
||||
// set input tensors
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>(onnx_node.input_size(), onnx_node.output_size());
|
||||
MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, RET_NULL_PTR, "create QuantParamHolder return nullptr");
|
||||
std::map<int, std::vector<schema::QuantParamT>> input_quant_params;
|
||||
size_t idx = 0;
|
||||
for (int i = 0; i < onnx_node.input_size(); ++i) {
|
||||
const auto &input_name = onnx_node.input(i);
|
||||
std::vector<schema::QuantParamT> quant_params;
|
||||
|
@ -990,9 +990,14 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
|
|||
MS_LOG(ERROR) << "set input tensor quant param failed.";
|
||||
return status;
|
||||
}
|
||||
quant_params_holder->set_input_quant_param(i, quant_params);
|
||||
if (!quant_params.empty()) {
|
||||
input_quant_params.insert({idx, quant_params});
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
// set out tensors
|
||||
idx = 0;
|
||||
std::map<int, std::vector<schema::QuantParamT>> output_quant_params;
|
||||
for (int i = 0; i < onnx_node.output_size(); ++i) {
|
||||
const auto &output_name = onnx_node.output(i);
|
||||
std::vector<schema::QuantParamT> quant_params;
|
||||
|
@ -1001,9 +1006,22 @@ STATUS OnnxModelParser::ConvertOpQuantParams(const onnx::NodeProto &onnx_node, o
|
|||
MS_LOG(ERROR) << "set output tensor quant param failed.";
|
||||
return status;
|
||||
}
|
||||
quant_params_holder->set_output_quant_param(i, quant_params);
|
||||
if (!quant_params.empty()) {
|
||||
output_quant_params.insert({idx, quant_params});
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
if (!input_quant_params.empty() || !output_quant_params.empty()) {
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
|
||||
MSLITE_CHECK_PTR(quant_params_holder);
|
||||
for (auto &iter : input_quant_params) {
|
||||
quant_params_holder->set_input_quant_param(iter.first, iter.second);
|
||||
}
|
||||
for (auto &iter : output_quant_params) {
|
||||
quant_params_holder->set_output_quant_param(iter.first, iter.second);
|
||||
}
|
||||
primitive_c->AddAttr("quant_params", quant_params_holder);
|
||||
}
|
||||
primitive_c->AddAttr("quant_params", quant_params_holder);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -1089,9 +1107,6 @@ STATUS OnnxModelParser::SetTensorQuantParamFromNode(const std::string &tensor_na
|
|||
}
|
||||
if (quant_param->inited) {
|
||||
quant_params->push_back(*std::move(quant_param));
|
||||
} else {
|
||||
std::vector<schema::QuantParamT> notinited_quant_params(1);
|
||||
*quant_params = notinited_quant_params;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue