[MS][LITE] move conv bias quant param to propogator

This commit is contained in:
cjh9368 2021-04-13 17:04:22 +08:00
parent cac91018ad
commit 304664bd09
3 changed files with 33 additions and 36 deletions

View File

@@ -15,6 +15,7 @@
*/
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore::lite {
static constexpr size_t kBiasAdd = 3;
@@ -22,6 +23,36 @@ STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGra
const mindspore::schema::CNodeT &node) {
if (node.inputIndex.size() == kBiasAdd) {
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAdd - 1));
if (bias_tensor->quantParams.empty() || !bias_tensor->quantParams.front()->inited) {
// check input and weight quant params
auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
auto &weight_tensor = graph->allTensors.at(node.inputIndex.at(1));
if (input_tensor->quantParams.empty() || !input_tensor->quantParams.front()->inited) {
return RET_OK;
}
if (weight_tensor->quantParams.empty() || !weight_tensor->quantParams.front()->inited) {
return RET_OK;
}
auto &input_quant_param = input_tensor->quantParams.at(0);
auto &weight_quant_param = weight_tensor->quantParams.at(0);
if (bias_tensor->quantParams.empty()) {
auto tmp_quant_param = std::make_unique<schema::QuantParamT>();
bias_tensor->quantParams.emplace_back(std::move(tmp_quant_param));
}
auto &bias_quant_param = bias_tensor->quantParams.front();
bias_quant_param->min = 0.0;
bias_quant_param->max = 0.0;
bias_quant_param->dstDtype = kNumberTypeInt32;
bias_quant_param->inited = input_quant_param->inited && weight_quant_param->inited;
bias_quant_param->zeroPoint = 0;
if (bias_quant_param->inited) {
bias_quant_param->scale = input_quant_param->scale * weight_quant_param->scale;
}
bias_quant_param->roundType = 1;
bias_quant_param->multiplier = 1;
}
for (auto &quantParam : bias_tensor->quantParams) {
quantParam->dstDtype = TypeId::kNumberTypeInt32;
}

View File

@@ -64,7 +64,8 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
}
auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
auto bias_node = cnode->input(kAddInputSize - index);
if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
if (!utils::isa<ValueNode>(bias_node) &&
(!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param())) {
continue;
}
matmul_cnode->add_input(bias_node);

View File

@@ -26,40 +26,6 @@ namespace mindspore {
namespace opt {
namespace {
constexpr size_t kDoubleNum = 2;
// Ensures `prim` carries a "quant_params" holder with one quant-param entry
// per input, up to `input_size`.
//
// When exactly two inputs (activation + weight) already have quant params, a
// bias quant param is synthesized for the third input following the standard
// int8 convolution convention: scale = input_scale * weight_scale,
// zeroPoint = 0, dstDtype = int32. Any remaining inputs are padded with
// default (uninited) quant params.
//
// @param prim        primitive whose "quant_params" attribute is filled in.
// @param input_size  total number of inputs that must have a quant-param slot.
void FillDefaultInputQuantParamIfNeed(const PrimitivePtr &prim, const size_t &input_size) {
  auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
  if (quant_tensor_info_ptr == nullptr) {
    prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
  }
  auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
  std::vector<schema::QuantParamT> quants;
  schema::QuantParamT quant_param;
  auto input_quant_params = quant_param_holder->input_quant_params();
  if (input_quant_params.size() == kDoubleNum) {
    quants.clear();
    quant_param.min = 0.0;
    quant_param.max = 0.0;
    quant_param.dstDtype = kNumberTypeInt32;
    // The bias param is only usable when both the activation and the weight
    // params are initialized. NOTE: the previous code unconditionally reset
    // `inited` to false right after computing it, which made the scale
    // computation below dead code; that override is removed here so the bias
    // scale is actually propagated (matching ConvQuantParamPropogator).
    quant_param.inited = input_quant_params.at(0).at(0).inited && input_quant_params.at(1).at(0).inited;
    quant_param.zeroPoint = 0;
    if (quant_param.inited) {
      // Standard int8 conv bias quantization: S_bias = S_input * S_weight.
      quant_param.scale = input_quant_params.at(0).at(0).scale * input_quant_params.at(1).at(0).scale;
    }
    quant_param.roundType = 1;
    quant_param.multiplier = 1;
    quants.emplace_back(quant_param);
    input_quant_params.emplace_back(quants);
  }
  // Pad any remaining inputs with a default (not inited) quant param.
  // Clear first so a bias param created above doesn't leak into the padding.
  if (input_quant_params.size() < input_size) {
    quants.clear();
    schema::QuantParamT tmpQuantParam;
    quants.emplace_back(tmpQuantParam);
    input_quant_params.insert(input_quant_params.end(), input_size - input_quant_params.size(), quants);
  }
  quant_param_holder->set_input_quant_params(input_quant_params);
}
int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
if (quant_tensor_info_ptr == nullptr) {
@@ -212,7 +178,6 @@ int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &i
MS_LOG(ERROR) << "compute int quant param failed.";
return status;
}
FillDefaultInputQuantParamIfNeed(prim, inputs.size());
status = ConvertOutputQuantParam(prim, narrow_range_param, num_bits_param);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "compute output quant param failed.";