forked from mindspore-Ecosystem/mindspore
[MS][LITE] move conv bias quant param to propogator
This commit is contained in:
parent
cac91018ad
commit
304664bd09
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
static constexpr size_t kBiasAdd = 3;
|
||||
|
||||
|
@ -22,6 +23,36 @@ STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGra
|
|||
const mindspore::schema::CNodeT &node) {
|
||||
if (node.inputIndex.size() == kBiasAdd) {
|
||||
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAdd - 1));
|
||||
if (bias_tensor->quantParams.empty() || !bias_tensor->quantParams.front()->inited) {
|
||||
// check input and weight quant params
|
||||
auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
|
||||
auto &weight_tensor = graph->allTensors.at(node.inputIndex.at(1));
|
||||
if (input_tensor->quantParams.empty() || !input_tensor->quantParams.front()->inited) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (weight_tensor->quantParams.empty() || !weight_tensor->quantParams.front()->inited) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto &input_quant_param = input_tensor->quantParams.at(0);
|
||||
auto &weight_quant_param = weight_tensor->quantParams.at(0);
|
||||
|
||||
if (bias_tensor->quantParams.empty()) {
|
||||
auto tmp_quant_param = std::make_unique<schema::QuantParamT>();
|
||||
bias_tensor->quantParams.emplace_back(std::move(tmp_quant_param));
|
||||
}
|
||||
auto &bias_quant_param = bias_tensor->quantParams.front();
|
||||
bias_quant_param->min = 0.0;
|
||||
bias_quant_param->max = 0.0;
|
||||
bias_quant_param->dstDtype = kNumberTypeInt32;
|
||||
bias_quant_param->inited = input_quant_param->inited && weight_quant_param->inited;
|
||||
bias_quant_param->zeroPoint = 0;
|
||||
if (bias_quant_param->inited) {
|
||||
bias_quant_param->scale = input_quant_param->scale * weight_quant_param->scale;
|
||||
}
|
||||
bias_quant_param->roundType = 1;
|
||||
bias_quant_param->multiplier = 1;
|
||||
}
|
||||
for (auto &quantParam : bias_tensor->quantParams) {
|
||||
quantParam->dstDtype = TypeId::kNumberTypeInt32;
|
||||
}
|
||||
|
|
|
@ -64,7 +64,8 @@ bool MatMulAddFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
auto matmul_cnode = cnode->input(index)->cast<CNodePtr>();
|
||||
auto bias_node = cnode->input(kAddInputSize - index);
|
||||
if (!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param()) {
|
||||
if (!utils::isa<ValueNode>(bias_node) &&
|
||||
(!utils::isa<Parameter>(bias_node) || !bias_node->cast<ParameterPtr>()->default_param())) {
|
||||
continue;
|
||||
}
|
||||
matmul_cnode->add_input(bias_node);
|
||||
|
|
|
@ -26,40 +26,6 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kDoubleNum = 2;
|
||||
void FillDefaultInputQuantParamIfNeed(const PrimitivePtr &prim, const size_t &input_size) {
|
||||
auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
|
||||
if (quant_tensor_info_ptr == nullptr) {
|
||||
prim->AddAttr("quant_params", std::make_shared<lite::QuantParamHolder>());
|
||||
}
|
||||
auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
|
||||
std::vector<schema::QuantParamT> quants;
|
||||
schema::QuantParamT quant_param;
|
||||
auto input_quant_params = quant_param_holder->input_quant_params();
|
||||
if (input_quant_params.size() == kDoubleNum) {
|
||||
quants.clear();
|
||||
quant_param.min = 0.0;
|
||||
quant_param.max = 0.0;
|
||||
quant_param.dstDtype = kNumberTypeInt32;
|
||||
quant_param.inited = input_quant_params.at(0).at(0).inited && input_quant_params.at(1).at(0).inited;
|
||||
quant_param.inited = false;
|
||||
quant_param.zeroPoint = 0;
|
||||
if (quant_param.inited) {
|
||||
quant_param.scale = input_quant_params.at(0).at(0).scale * input_quant_params.at(1).at(0).scale;
|
||||
}
|
||||
quant_param.roundType = 1;
|
||||
quant_param.multiplier = 1;
|
||||
quants.emplace_back(quant_param);
|
||||
input_quant_params.emplace_back(quants);
|
||||
}
|
||||
// fill input_quant_param_ by not inited quant_parm
|
||||
if (input_quant_params.size() < input_size) {
|
||||
schema::QuantParamT tmpQuantParam;
|
||||
quants.emplace_back(tmpQuantParam);
|
||||
input_quant_params.insert(input_quant_params.end(), input_size - input_quant_params.size(), quants);
|
||||
}
|
||||
quant_param_holder->set_input_quant_params(input_quant_params);
|
||||
}
|
||||
|
||||
int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
|
||||
auto quant_tensor_info_ptr = prim->GetAttr("quant_params");
|
||||
if (quant_tensor_info_ptr == nullptr) {
|
||||
|
@ -212,7 +178,6 @@ int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &i
|
|||
MS_LOG(ERROR) << "compute int quant param failed.";
|
||||
return status;
|
||||
}
|
||||
FillDefaultInputQuantParamIfNeed(prim, inputs.size());
|
||||
status = ConvertOutputQuantParam(prim, narrow_range_param, num_bits_param);
|
||||
if (status != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "compute output quant param failed.";
|
||||
|
|
Loading…
Reference in New Issue