diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 0e09da86140..733fc28d76d 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -197,7 +197,7 @@ union PrimitiveType { enum QuantType: int { QUANT_NONE, - AwareTrainning, + AwareTraining, WeightQuant, PostTraining } diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index 3e6a10cf916..6e7aea38945 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { // add quant param node->quantType = primitiveT_value->GetQuantType(); - if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) { + if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTraining) { MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; // activation auto input_quant_params = primitiveT_value->GetInputQuantParams(); @@ -202,14 +202,12 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { auto activate_index = node->inputIndex[i]; auto tensor_input = metaGraphT->allTensors[activate_index].get(); if (tensor_input->quantParams.empty()) { - std::unique_ptr<schema::QuantParamT> input_quant_param = - std::make_unique<schema::QuantParamT>(input_quant_params[i]); - MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale - << " zp: " << input_quant_param->zeroPoint; - tensor_input->quantParams.emplace_back(std::move(input_quant_param)); - if (!(node_type == schema::PrimitiveType_QuantDTypeCast && - primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) { - tensor_input->dataType = kNumberTypeInt8; + for (auto input_quant_param : input_quant_params[i]) { + std::unique_ptr<schema::QuantParamT> input_quant_param_ptr = +
std::make_unique<schema::QuantParamT>(input_quant_param); + MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param_ptr->scale + << " zp: " << input_quant_param_ptr->zeroPoint; + tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); } } } @@ -221,15 +219,18 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { if (output_quant_params.empty()) { MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; } else { - if (tensor_output->quantParams.empty()) { - std::unique_ptr<schema::QuantParamT> output_quant_param = - std::make_unique<schema::QuantParamT>(output_quant_params[0]); - MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale - << " zp: " << output_quant_param->zeroPoint; - tensor_output->quantParams.emplace_back(std::move(output_quant_param)); + for (auto output_quant_param : output_quant_params[0]) { + if (tensor_output->quantParams.empty()) { + std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = + std::make_unique<schema::QuantParamT>(output_quant_param); + MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << output_quant_param_ptr->scale + << " zp: " << output_quant_param_ptr->zeroPoint; + tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); + } } } - if (!(node_type == schema::PrimitiveType_QuantDTypeCast && + if (node->quantType != schema::QuantType_AwareTraining && + !(node_type == schema::PrimitiveType_QuantDTypeCast && primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { tensor_output->dataType = kNumberTypeInt8; } @@ -322,18 +323,6 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta paramTensor->nodeType = schema::NodeType_ValueNode; paramTensor->data.resize(paramValue->tensor_size()); memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); - for (auto &ite : paramValue->quant_param()) { - auto quantPar = std::make_unique<schema::QuantParamT>(); - quantPar->scale = ite->scale; - quantPar->zeroPoint =
ite->zeroPoint; - quantPar->min = ite->min; - quantPar->max = ite->max; - quantPar->narrowRange = ite->narrowRange; - quantPar->inited = ite->inited; - quantPar->numBits = ite->numBits; - paramTensor->quantParams.emplace_back(std::move(quantPar)); - paramTensor->dataType = paramValue->tensor_type(); - } } nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc index 29407f9c8e4..4e4206845f1 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc +++ b/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc @@ -225,7 +225,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PopulaterConv2DSingleGroup(prim, primitive, group); } primitiveTValuePtr->SetPrimitiveT(primitive.release()); - if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTrainning) { + if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { std::vector<std::vector<schema::QuantParamT>> vecQuantParam; PopulaterQuantParam(prim, &vecQuantParam); primitiveTValuePtr->SetInputQuantParam(vecQuantParam); diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc index 786f825a651..4056480f4be 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc @@ -89,13 +89,15 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { } auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); // add quant parameter - if (cNode->quantType == schema::QuantType_AwareTrainning) { + if (cNode->quantType == schema::QuantType_AwareTraining) { primTValue->SetQuantType(cNode->quantType); for (int index : cNode->inputIndex) { -
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); + std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; + primTValue->AddInputQuantParam(quant_params); } for (int index : cNode->outputIndex) { - primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); + std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; + primTValue->AddOutputQuantParam(quant_params); } } cNode->primitive = nullptr; diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h index 7de7250b231..bc71bc73067 100644 --- a/mindspore/lite/src/ir/primitive_t_value.h +++ b/mindspore/lite/src/ir/primitive_t_value.h @@ -49,17 +49,17 @@ class PrimitiveTValue : public Value { void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) { } - void AddInputQuantParam(schema::QuantParamT quant_param) { + void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { this->input_quant_param_.emplace_back(quant_param); } - std::vector<schema::QuantParamT> GetInputQuantParams() const { + std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const { return input_quant_param_; } - void AddOutputQuantParam(schema::QuantParamT quant_param) { + void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) { this->output_quant_param_.emplace_back(quant_param); } - std::vector<schema::QuantParamT> GetOutputQuantParams() const { + std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const { return output_quant_param_; } @@ -69,8 +69,8 @@ class PrimitiveTValue : public Value { protected: schema::PrimitiveT *primitive = nullptr; - std::vector<schema::QuantParamT> input_quant_param_; - std::vector<schema::QuantParamT> output_quant_param_; + std::vector<std::vector<schema::QuantParamT>> input_quant_param_; + std::vector<std::vector<schema::QuantParamT>> output_quant_param_; schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index f7c92e2f725..c1ad5cff8bc 100644 --- a/mindspore/lite/tools/converter/converter.cc +++
b/mindspore/lite/tools/converter/converter.cc @@ -131,7 +131,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { auto type = flags->quantType; switch (type) { - case mindspore::schema::QuantType_AwareTrainning: { + case mindspore::schema::QuantType_AwareTraining: { // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); break; } diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 0c2991b8c84..4059ef12a4d 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -31,7 +31,7 @@ Flags::Flags() { "Input model weight file path. Needed when fmk is CAFFE. CAFFE: *.caffemodel", ""); AddFlag(&Flags::inferenceType, "inferenceType", "Real data type saved in output file, reserved param, NOT used for now. FLOAT | FP16 | UINT8", "FLOAT"); - AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTrainning | WeightQuant | PostTraining", ""); + AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. AwareTraining | WeightQuant | PostTraining", ""); AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. 
FLOAT | UINT8", "FLOAT"); AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128"); AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "127"); @@ -98,8 +98,8 @@ int Flags::Init(int argc, const char **argv) { std::cerr << "INPUT ILLEGAL: weightFile is not a valid flag"; return 1; } - if (this->quantTypeIn == "AwareTrainning") { - this->quantType = QuantType_AwareTrainning; + if (this->quantTypeIn == "AwareTraining") { + this->quantType = QuantType_AwareTraining; } else if (this->quantTypeIn == "WeightQuant") { this->quantType = QuantType_WeightQuant; } else if (this->quantTypeIn == "PostTraining") { @@ -107,7 +107,7 @@ int Flags::Init(int argc, const char **argv) { } else if (this->quantTypeIn.empty()) { this->quantType = QuantType_QUANT_NONE; } else { - std::cerr << "INPUT ILLEGAL: quantType must be AwareTrainning|WeightQuant|PostTraining"; + std::cerr << "INPUT ILLEGAL: quantType must be AwareTraining|WeightQuant|PostTraining"; return 1; } diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 9ccfed6ceb3..98e3581d78b 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -27,7 +27,7 @@ namespace lite { using mindspore::schema::QuantType; using mindspore::schema::QuantType_PostTraining; using mindspore::schema::QuantType_QUANT_NONE; -using mindspore::schema::QuantType_AwareTrainning; +using mindspore::schema::QuantType_AwareTraining; using mindspore::schema::QuantType_WeightQuant; using mindspore::schema::QuantType_PostTraining; using mindspore::schema::QuantType_PostTraining; diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index bd715fae292..0f564aeabdb 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -68,8 +68,8 @@ void 
GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { auto type = flags->quantType; switch (type) { - case QuantType::QuantType_AwareTrainning: { - MS_LOG(INFO) << "create AwareTrainningQuantizer!"; + case QuantType::QuantType_AwareTraining: { + MS_LOG(INFO) << "create AwareTrainingQuantizer!"; fbQuantizer = std::make_unique<AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean); break; @@ -146,7 +146,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { return status; } if (!(this->graphDefT->fmkType == converter::FmkType_TF && - this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) { + this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) { status = mQuantizer->GenerateQuantParam(); if (status != RET_OK) { MS_LOG(ERROR) << "GenerateQuantParam failed"; @@ -173,7 +173,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { formatTransOptimizer.AddPass(formatTransPass); formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - // if (ctx.quantType == QuantType_AwareTrainning) { + // if (ctx.quantType == QuantType_AwareTraining) { // formatTransOptimizer.AddPass(new (std::nothrow) FormatTransNodeQuantParamFillPass()); // } status = formatTransOptimizer.Run(graphDefT); @@ -193,7 +193,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } // insert quantNode and deQuantNode - if (ctx.quantType == QuantType_AwareTrainning) { + if (ctx.quantType == QuantType_AwareTraining) { Optimizer quantNodeOptimizer; auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); if (dTypeTransPass == nullptr) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index
ec9d979e32c..9c8da98aab9 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -136,7 +136,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); // insert transNode before and after existNode for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) { + if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { continue; } auto &node = *iter; @@ -208,7 +208,7 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte transNode->primitive = std::make_unique<schema::PrimitiveT>(); transNode->primitive->value.value = quantDTypeCastParam; transNode->primitive->value.type = PrimitiveType_QuantDTypeCast; - transNode->quantType = QuantType_AwareTrainning; + transNode->quantType = QuantType_AwareTraining; if (nodeType == kInt8ToFP32) { quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8; quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc index 9c01c074906..58ed5cc2ce2 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc @@ -103,7 +103,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { FormatTransNodeType beforeNodeType, afterNodeType; if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc - // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use + // if (quantType == QuantType_AwareTraining) { // AwareTraining op use
// nhwc // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only // support nhwc @@ -120,7 +120,7 @@ STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { // beforeNodeType = kNCHW2NHWC; // afterNodeType = kNHWC2NCHW; } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw - // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc + // if (quantType == QuantType_AwareTraining) { // AwareTraining op use nhwc // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc // continue; // } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index c22e3fe12f7..652cdd61c60 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) { MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status; return status; } - if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) { + if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) { status = QuantDataFormatTrans(graphNode); if (status != 0) { MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; @@ -96,7 +96,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { return 0; } else if (fmkType == converter::FmkType_MS) { switch (node->quantType) { - case QuantType_AwareTrainning: { + case QuantType_AwareTraining: { if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { weightTensor->format = schema::Format_HWCK; } else { @@ -123,7 +123,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { return 0; } else if (fmkType == converter::FmkType_TF) { switch (node->quantType) { - case 
QuantType_AwareTrainning: { + case QuantType_AwareTraining: { if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) { weightTensor->format = schema::Format_HWCK; } else { @@ -148,7 +148,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { } else if (fmkType == converter::FmkType_TFLITE) { switch (node->quantType) { case QuantType_QUANT_NONE: - case QuantType_AwareTrainning: + case QuantType_AwareTraining: case QuantType_PostTraining: { if (opType == schema::PrimitiveType_Conv2D) { weightTensor->format = schema::Format_KHWC; @@ -170,7 +170,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { return 0; } else if (fmkType == converter::FmkType_ONNX) { switch (node->quantType) { - case QuantType_AwareTrainning: { + case QuantType_AwareTraining: { // sum up from current onnx quant models if (opType == schema::PrimitiveType_Conv2D) { weightTensor->format = schema::Format_KHWC; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index 2427eb98a4d..c597e23c20b 100755 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -314,7 +314,7 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const } } if (findQuantParams == needQuantParams) { - dst_op->quantType = schema::QuantType_AwareTrainning; + dst_op->quantType = schema::QuantType_AwareTraining; } else { dst_op->quantType = schema::QuantType_QUANT_NONE; } diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index a899ad17e19..ef67965d35f 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -324,7 +324,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { MS_LOG(ERROR) << 
"quantParamCalcer failed: " << status << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; } else { - node->quantType = schema::QuantType_AwareTrainning; + node->quantType = schema::QuantType_AwareTraining; } } } @@ -337,7 +337,7 @@ STATUS AwareQuantizer::DoQuantize() { if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { continue; } - if (node->quantType != schema::QuantType_AwareTrainning) { + if (node->quantType != schema::QuantType_AwareTraining) { continue; } STATUS status; @@ -584,7 +584,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { } } if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { - node->quantType = schema::QuantType_AwareTrainning; + node->quantType = schema::QuantType_AwareTraining; } else { node->quantType = schema::QuantType_QUANT_NONE; } diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 08fc4ce2911..01d85b2674e 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -509,7 +509,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct M quant_param.min = max_min->min; quant_param.numBits = bit_num; quant_param.narrowRange = false; - lite_primitive->AddInputQuantParam(quant_param); + std::vector<schema::QuantParamT> quant_params = {quant_param}; + lite_primitive->AddInputQuantParam(quant_params); // p->AddAttr("quant_input_dataType", MakeValue((int)DataType_DT_FLOAT)); return RET_OK; } @@ -526,7 +527,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct quant_param.min = max_min->min; quant_param.numBits = bit_num; quant_param.narrowRange = false; - lite_primitive->AddOutputQuantParam(quant_param); + std::vector<schema::QuantParamT> quant_params = {quant_param}; + lite_primitive->AddOutputQuantParam(quant_params); // p->AddAttr("quant_output_dataType",
MakeValue((int)DataType_DT_FLOAT)); return RET_OK; } @@ -569,7 +571,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input auto quant_params = input->GetInputQuantParams(); size_t sizeX = quant_params.size(); for (size_t i = 0; i < sizeX; i++) { - input_scales.emplace_back(quant_params[i].scale); + input_scales.emplace_back(quant_params[i].front().scale); } size_t sizeY = weight_param->quant_param().size(); if (sizeX != sizeY) { diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index f924cda6efa..37dbfb43b11 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -31,7 +31,8 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release()); primTValue->SetQuantType(schema::QuantType_PostTraining); for (auto &quant_param : quant_params) { - primTValue->AddInputQuantParam(quant_param); + std::vector<schema::QuantParamT> quant_params_in = {quant_param}; + primTValue->AddInputQuantParam(quant_params_in); } return NewValueNode(primTValue); } @@ -53,7 +54,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { if (first) { if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { auto value_node = - NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams()); + NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams().front()); std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)}; auto quant_cast_cnode = graph->NewCNode(op_inputs); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); @@ -84,11 +85,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { if (curnode_quant_type == schema::QuantType_PostTraining && input_cnode_quant_type == schema::QuantType_QUANT_NONE) { value_node =
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, - primitiveT_value->GetInputQuantParams()); + primitiveT_value->GetInputQuantParams().front()); } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && input_cnode_quant_type == schema::QuantType_PostTraining) { value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, - input_cnode_primitiveT_value->GetInputQuantParams()); + input_cnode_primitiveT_value->GetInputQuantParams().front()); } if (value_node == nullptr) { MS_LOG(WARNING) << "value_node is null! "