fix post training quant

This commit is contained in:
xutianchun 2020-08-17 10:57:44 +08:00
parent cde696477c
commit 3f68f575c4
1 changed files with 12 additions and 1 deletions

View File

@ -238,7 +238,6 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) {
}
meta_graphT->nodes.emplace_back(std::move(node));
primitiveT_value->SetPrimitiveT(nullptr);
}
// set graph input tensors
SetGraphInputIndex(meta_graphT);
@ -296,6 +295,18 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
paramTensor->nodeType = schema::NodeType_ValueNode;
paramTensor->data.resize(paramValue->tensor_size());
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
for (auto &ite : paramValue->quant_param()) {
auto quantPar = std::make_unique<schema::QuantParamT>();
quantPar->scale = ite->scale;
quantPar->zeroPoint = ite->zeroPoint;
quantPar->min = ite->zeroPoint;
quantPar->max = ite->max;
quantPar->narrowRange = ite->narrowRange;
quantPar->inited = ite->inited;
quantPar->numBits = ite->numBits;
paramTensor->quantParams.emplace_back(std::move(quantPar));
paramTensor->dataType = paramValue->tensor_type();
}
}
node_id_map_[paramNode->fullname_with_scope()] = meta_graphT->allTensors.size();
output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());