forked from mindspore-Ecosystem/mindspore
!4469 change param_t_ value quant params
Merge pull request !4469 from cjh9368/aware_quant
This commit is contained in:
commit
ed9e62c760
|
@ -197,7 +197,7 @@ union PrimitiveType {
|
|||
|
||||
enum QuantType: int {
|
||||
QUANT_NONE,
|
||||
AwareTrainning,
|
||||
AwareTraining,
|
||||
WeightQuant,
|
||||
PostTraining
|
||||
}
|
||||
|
|
|
@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
|||
|
||||
// add quant param
|
||||
node->quantType = primitiveT_value->GetQuantType();
|
||||
if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) {
|
||||
if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTraining) {
|
||||
MS_LOG(INFO) << "node: " << node->name << " add QuantParam";
|
||||
// activation
|
||||
auto input_quant_params = primitiveT_value->GetInputQuantParams();
|
||||
|
@ -202,14 +202,12 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
|||
auto activate_index = node->inputIndex[i];
|
||||
auto tensor_input = metaGraphT->allTensors[activate_index].get();
|
||||
if (tensor_input->quantParams.empty()) {
|
||||
std::unique_ptr<schema::QuantParamT> input_quant_param =
|
||||
std::make_unique<schema::QuantParamT>(input_quant_params[i]);
|
||||
MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale
|
||||
<< " zp: " << input_quant_param->zeroPoint;
|
||||
tensor_input->quantParams.emplace_back(std::move(input_quant_param));
|
||||
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
|
||||
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) {
|
||||
tensor_input->dataType = kNumberTypeInt8;
|
||||
for (auto input_quant_param : input_quant_params[i]) {
|
||||
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
|
||||
std::make_unique<schema::QuantParamT>(input_quant_param);
|
||||
MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param_ptr->scale
|
||||
<< " zp: " << input_quant_param_ptr->zeroPoint;
|
||||
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -221,15 +219,18 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
|
|||
if (output_quant_params.empty()) {
|
||||
MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
|
||||
} else {
|
||||
if (tensor_output->quantParams.empty()) {
|
||||
std::unique_ptr<schema::QuantParamT> output_quant_param =
|
||||
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
|
||||
MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale
|
||||
<< " zp: " << output_quant_param->zeroPoint;
|
||||
tensor_output->quantParams.emplace_back(std::move(output_quant_param));
|
||||
for (auto output_quant_param : output_quant_params[0]) {
|
||||
if (tensor_output->quantParams.empty()) {
|
||||
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
|
||||
std::make_unique<schema::QuantParamT>(output_quant_param);
|
||||
MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << output_quant_param_ptr->scale
|
||||
<< " zp: " << output_quant_param_ptr->zeroPoint;
|
||||
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!(node_type == schema::PrimitiveType_QuantDTypeCast &&
|
||||
if (node->quantType != schema::QuantType_AwareTraining &&
|
||||
!(node_type == schema::PrimitiveType_QuantDTypeCast &&
|
||||
primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) {
|
||||
tensor_output->dataType = kNumberTypeInt8;
|
||||
}
|
||||
|
@ -322,18 +323,6 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
|
|||
paramTensor->nodeType = schema::NodeType_ValueNode;
|
||||
paramTensor->data.resize(paramValue->tensor_size());
|
||||
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
|
||||
for (auto &ite : paramValue->quant_param()) {
|
||||
auto quantPar = std::make_unique<schema::QuantParamT>();
|
||||
quantPar->scale = ite->scale;
|
||||
quantPar->zeroPoint = ite->zeroPoint;
|
||||
quantPar->min = ite->min;
|
||||
quantPar->max = ite->max;
|
||||
quantPar->narrowRange = ite->narrowRange;
|
||||
quantPar->inited = ite->inited;
|
||||
quantPar->numBits = ite->numBits;
|
||||
paramTensor->quantParams.emplace_back(std::move(quantPar));
|
||||
paramTensor->dataType = paramValue->tensor_type();
|
||||
}
|
||||
}
|
||||
nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
|
||||
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
|
||||
|
|
|
@ -225,7 +225,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim,
|
|||
PopulaterConv2DSingleGroup(prim, primitive, group);
|
||||
}
|
||||
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
||||
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) {
|
||||
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
|
||||
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
||||
PopulaterQuantParam(prim, &vecQuantParam);
|
||||
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
||||
|
|
|
@ -89,13 +89,15 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
|
|||
}
|
||||
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
|
||||
// add quant parameter
|
||||
if (cNode->quantType == schema::QuantType_AwareTrainning) {
|
||||
if (cNode->quantType == schema::QuantType_AwareTraining) {
|
||||
primTValue->SetQuantType(cNode->quantType);
|
||||
for (int index : cNode->inputIndex) {
|
||||
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
|
||||
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
|
||||
primTValue->AddInputQuantParam(quant_params);
|
||||
}
|
||||
for (int index : cNode->outputIndex) {
|
||||
primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
|
||||
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
|
||||
primTValue->AddOutputQuantParam(quant_params);
|
||||
}
|
||||
}
|
||||
cNode->primitive = nullptr;
|
||||
|
|
|
@ -49,17 +49,17 @@ class PrimitiveTValue : public Value {
|
|||
void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {
|
||||
}
|
||||
|
||||
void AddInputQuantParam(schema::QuantParamT quant_param) {
|
||||
void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
|
||||
this->input_quant_param_.emplace_back(quant_param);
|
||||
}
|
||||
std::vector<schema::QuantParamT> GetInputQuantParams() const {
|
||||
std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const {
|
||||
return input_quant_param_;
|
||||
}
|
||||
|
||||
void AddOutputQuantParam(schema::QuantParamT quant_param) {
|
||||
void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) {
|
||||
this->output_quant_param_.emplace_back(quant_param);
|
||||
}
|
||||
std::vector<schema::QuantParamT> GetOutputQuantParams() const {
|
||||
std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const {
|
||||
return output_quant_param_;
|
||||
}
|
||||
|
||||
|
@ -69,8 +69,8 @@ class PrimitiveTValue : public Value {
|
|||
|
||||
protected:
|
||||
schema::PrimitiveT *primitive = nullptr;
|
||||
std::vector<schema::QuantParamT> input_quant_param_;
|
||||
std::vector<schema::QuantParamT> output_quant_param_;
|
||||
std::vector<std::vector<schema::QuantParamT>> input_quant_param_;
|
||||
std::vector<std::vector<schema::QuantParamT>> output_quant_param_;
|
||||
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -131,7 +131,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) {
|
||||
auto type = flags->quantType;
|
||||
switch (type) {
|
||||
case mindspore::schema::QuantType_AwareTrainning: {
|
||||
case mindspore::schema::QuantType_AwareTraining: {
|
||||
// mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean));
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ Flags::Flags() {
|
|||
"Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", "");
|
||||
AddFlag(&Flags::inferenceType, "inferenceType",
|
||||
"Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT");
|
||||
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTrainning | WeightQuant | PostTraining", "");
|
||||
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", "");
|
||||
AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | UINT8", "FLOAT");
|
||||
AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
|
||||
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127");
|
||||
|
@ -98,8 +98,8 @@ int Flags::Init(int argc, const char **argv) {
|
|||
std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag";
|
||||
return 1;
|
||||
}
|
||||
if (this->quantTypeIn == "AwareTrainning") {
|
||||
this->quantType = QuantType_AwareTrainning;
|
||||
if (this->quantTypeIn == "AwareTraining") {
|
||||
this->quantType = QuantType_AwareTraining;
|
||||
} else if (this->quantTypeIn == "WeightQuant") {
|
||||
this->quantType = QuantType_WeightQuant;
|
||||
} else if (this->quantTypeIn == "PostTraining") {
|
||||
|
@ -107,7 +107,7 @@ int Flags::Init(int argc, const char **argv) {
|
|||
} else if (this->quantTypeIn.empty()) {
|
||||
this->quantType = QuantType_QUANT_NONE;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining";
|
||||
std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining";
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace lite {
|
|||
using mindspore::schema::QuantType;
|
||||
using mindspore::schema::QuantType_PostTraining;
|
||||
using mindspore::schema::QuantType_QUANT_NONE;
|
||||
using mindspore::schema::QuantType_AwareTrainning;
|
||||
using mindspore::schema::QuantType_AwareTraining;
|
||||
using mindspore::schema::QuantType_WeightQuant;
|
||||
using mindspore::schema::QuantType_PostTraining;
|
||||
using mindspore::schema::QuantType_PostTraining;
|
||||
|
|
|
@ -68,8 +68,8 @@ void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _
|
|||
void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
|
||||
auto type = flags->quantType;
|
||||
switch (type) {
|
||||
case QuantType::QuantType_AwareTrainning: {
|
||||
MS_LOG(INFO) << "create AwareTrainningQuantizer!";
|
||||
case QuantType::QuantType_AwareTraining: {
|
||||
MS_LOG(INFO) << "create AwareTrainingQuantizer!";
|
||||
fbQuantizer =
|
||||
std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean);
|
||||
break;
|
||||
|
@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
return status;
|
||||
}
|
||||
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
|
||||
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) {
|
||||
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
|
||||
status = mQuantizer->GenerateQuantParam();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "GenerateQuantParam failed";
|
||||
|
@ -173,7 +173,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
formatTransOptimizer.AddPass(formatTransPass);
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
|
||||
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||
// if (ctx.quantType == QuantType_AwareTrainning) {
|
||||
// if (ctx.quantType == QuantType_AwareTraining) {
|
||||
// formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass());
|
||||
// }
|
||||
status = formatTransOptimizer.Run(graphDefT);
|
||||
|
@ -193,7 +193,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
}
|
||||
|
||||
// insert quantNode and deQuantNode
|
||||
if (ctx.quantType == QuantType_AwareTrainning) {
|
||||
if (ctx.quantType == QuantType_AwareTraining) {
|
||||
Optimizer quantNodeOptimizer;
|
||||
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
|
||||
if (dTypeTransPass == nullptr) {
|
||||
|
|
|
@ -136,7 +136,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
|
|||
MS_ASSERT(graph != nullptr);
|
||||
// insert transNode before and after existNode
|
||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) {
|
||||
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) {
|
||||
continue;
|
||||
}
|
||||
auto &node = *iter;
|
||||
|
@ -208,7 +208,7 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
|
|||
transNode->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
transNode->primitive->value.value = quantDTypeCastParam;
|
||||
transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
|
||||
transNode->quantType = QuantType_AwareTrainning;
|
||||
transNode->quantType = QuantType_AwareTraining;
|
||||
if (nodeType == kInt8ToFP32) {
|
||||
quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
|
||||
quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32;
|
||||
|
|
|
@ -103,7 +103,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
||||
FormatTransNodeType beforeNodeType, afterNodeType;
|
||||
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc
|
||||
// if (quantType == QuantType_AwareTrainning) { // awaretrainning op use
|
||||
// if (quantType == QuantType_AwareTraining) { // AwareTraining op use
|
||||
// nhwc
|
||||
// if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only
|
||||
// support nhwc
|
||||
|
@ -120,7 +120,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) {
|
|||
// beforeNodeType = kNCHW2NHWC;
|
||||
// afterNodeType = kNHWC2NCHW;
|
||||
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw
|
||||
// if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc
|
||||
// if (quantType == QuantType_AwareTraining) { // AwareTraining op use nhwc
|
||||
// if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc
|
||||
// continue;
|
||||
// }
|
||||
|
|
|
@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
|
|||
MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status;
|
||||
return status;
|
||||
}
|
||||
if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) {
|
||||
if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) {
|
||||
status = QuantDataFormatTrans(graphNode);
|
||||
if (status != 0) {
|
||||
MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
|
||||
|
@ -96,7 +96,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
|
|||
return 0;
|
||||
} else if (fmkType == converter::FmkType_MS) {
|
||||
switch (node->quantType) {
|
||||
case QuantType_AwareTrainning: {
|
||||
case QuantType_AwareTraining: {
|
||||
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
weightTensor->format = schema::Format_HWCK;
|
||||
} else {
|
||||
|
@ -123,7 +123,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
|
|||
return 0;
|
||||
} else if (fmkType == converter::FmkType_TF) {
|
||||
switch (node->quantType) {
|
||||
case QuantType_AwareTrainning: {
|
||||
case QuantType_AwareTraining: {
|
||||
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
weightTensor->format = schema::Format_HWCK;
|
||||
} else {
|
||||
|
@ -148,7 +148,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
|
|||
} else if (fmkType == converter::FmkType_TFLITE) {
|
||||
switch (node->quantType) {
|
||||
case QuantType_QUANT_NONE:
|
||||
case QuantType_AwareTrainning:
|
||||
case QuantType_AwareTraining:
|
||||
case QuantType_PostTraining: {
|
||||
if (opType == schema::PrimitiveType_Conv2D) {
|
||||
weightTensor->format = schema::Format_KHWC;
|
||||
|
@ -170,7 +170,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
|
|||
return 0;
|
||||
} else if (fmkType == converter::FmkType_ONNX) {
|
||||
switch (node->quantType) {
|
||||
case QuantType_AwareTrainning: {
|
||||
case QuantType_AwareTraining: {
|
||||
// sum up from current onnx quant models
|
||||
if (opType == schema::PrimitiveType_Conv2D) {
|
||||
weightTensor->format = schema::Format_KHWC;
|
||||
|
|
|
@ -314,7 +314,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
|
|||
}
|
||||
}
|
||||
if (findQuantParams == needQuantParams) {
|
||||
dst_op->quantType = schema::QuantType_AwareTrainning;
|
||||
dst_op->quantType = schema::QuantType_AwareTraining;
|
||||
} else {
|
||||
dst_op->quantType = schema::QuantType_QUANT_NONE;
|
||||
}
|
||||
|
|
|
@ -324,7 +324,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
|
|||
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
node->quantType = schema::QuantType_AwareTrainning;
|
||||
node->quantType = schema::QuantType_AwareTraining;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -337,7 +337,7 @@ STATUS AwareQuantizer::DoQuantize() {
|
|||
if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
|
||||
continue;
|
||||
}
|
||||
if (node->quantType != schema::QuantType_AwareTrainning) {
|
||||
if (node->quantType != schema::QuantType_AwareTraining) {
|
||||
continue;
|
||||
}
|
||||
STATUS status;
|
||||
|
@ -584,7 +584,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
|
|||
}
|
||||
}
|
||||
if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
|
||||
node->quantType = schema::QuantType_AwareTrainning;
|
||||
node->quantType = schema::QuantType_AwareTraining;
|
||||
} else {
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
}
|
||||
|
|
|
@ -509,7 +509,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M
|
|||
quant_param.min = max_min->min;
|
||||
quant_param.numBits = bit_num;
|
||||
quant_param.narrowRange = false;
|
||||
lite_primitive->AddInputQuantParam(quant_param);
|
||||
std::vector<schema::QuantParamT> quant_params = {quant_param};
|
||||
lite_primitive->AddInputQuantParam(quant_params);
|
||||
// p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT));
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -526,7 +527,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
|
|||
quant_param.min = max_min->min;
|
||||
quant_param.numBits = bit_num;
|
||||
quant_param.narrowRange = false;
|
||||
lite_primitive->AddOutputQuantParam(quant_param);
|
||||
std::vector<schema::QuantParamT> quant_params = {quant_param};
|
||||
lite_primitive->AddOutputQuantParam(quant_params);
|
||||
// p->AddAttr("quant_output_dataType", MakeValue((int)DataType_DT_FLOAT));
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -569,7 +571,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
|
|||
auto quant_params = input->GetInputQuantParams();
|
||||
size_t sizeX = quant_params.size();
|
||||
for (size_t i = 0; i < sizeX; i++) {
|
||||
input_scales.emplace_back(quant_params[i].scale);
|
||||
input_scales.emplace_back(quant_params[i].front().scale);
|
||||
}
|
||||
size_t sizeY = weight_param->quant_param().size();
|
||||
if (sizeX != sizeY) {
|
||||
|
|
|
@ -31,7 +31,8 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
|
|||
auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release());
|
||||
primTValue->SetQuantType(schema::QuantType_PostTraining);
|
||||
for (auto &quant_param : quant_params) {
|
||||
primTValue->AddInputQuantParam(quant_param);
|
||||
std::vector<schema::QuantParamT> quant_params_in = {quant_param};
|
||||
primTValue->AddInputQuantParam(quant_params_in);
|
||||
}
|
||||
return NewValueNode(primTValue);
|
||||
}
|
||||
|
@ -53,7 +54,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
|
|||
if (first) {
|
||||
if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
|
||||
auto value_node =
|
||||
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams());
|
||||
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front());
|
||||
std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
|
||||
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
||||
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
|
||||
|
@ -84,11 +85,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
|
|||
if (curnode_quant_type == schema::QuantType_PostTraining &&
|
||||
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
|
||||
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8,
|
||||
primitiveT_value->GetInputQuantParams());
|
||||
primitiveT_value->GetInputQuantParams().front());
|
||||
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
|
||||
input_cnode_quant_type == schema::QuantType_PostTraining) {
|
||||
value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32,
|
||||
input_cnode_primitiveT_value->GetInputQuantParams());
|
||||
input_cnode_primitiveT_value->GetInputQuantParams().front());
|
||||
}
|
||||
if (value_node == nullptr) {
|
||||
MS_LOG(WARNING) << "value_node is null! "
|
||||
|
|
Loading…
Reference in New Issue