From fae78e11a7ca0ba43ea6e31d2250246ceb9f1445 Mon Sep 17 00:00:00 2001
From: xutianchun
Date: Thu, 30 Jul 2020 21:31:08 +0800
Subject: [PATCH] 1. compute threshold only once 2. fix anf_exporter bug: pool
 and concat ops may not be set into the metagraph 3. fix weight format pass
 returning an error during post-training quantization 4. make anf_exporter
 reentrant: do not set PrimitiveT * to nullptr

---
 .../src/common/anf_exporter/anf_exporter.cc | 19 ++++++++-----------
 mindspore/lite/src/ir/primitive_t_value.h   |  4 ++--
 .../optimizer/node/weight_format_pass.cc    |  5 +++--
 .../converter/quantizer/post_training.cc    | 27 +++++++++++++++++++++++----
 4 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
index a4bdcc9eae6..d5d55014a13 100644
--- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
+++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc
@@ -98,7 +98,6 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
     }
     node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT());
-    primitiveT_value->SetPrimitiveT(nullptr);
     std::vector<schema::TensorT *> outputs;
     SetOpInputNode(cnode, metaGraphT.get(), node.get());
     SetOpOutputNode(outputs, metaGraphT.get(), node.get());
@@ -113,24 +112,22 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
       auto input_quant_params = primitiveT_value->GetInputQuantParams();
       if (input_quant_params.empty()) {
         MS_LOG(WARNING) << "node: " << node->name << " input quant params is empty";
-        continue;
+      } else {
+        std::unique_ptr<schema::QuantParamT> input_quant_param =
+          std::make_unique<schema::QuantParamT>(input_quant_params[0]);
+        tensor_input->quantParams.emplace_back(std::move(input_quant_param));
       }
-
-      std::unique_ptr<schema::QuantParamT> input_quant_param =
-        std::make_unique<schema::QuantParamT>(input_quant_params[0]);
-      tensor_input->quantParams.emplace_back(std::move(input_quant_param));
       // output
       auto output_index = node->outputIndex[0];
       auto tensor_output = metaGraphT->allTensors[output_index].get();
       auto output_quant_params = primitiveT_value->GetOutputQuantParams();
       if (output_quant_params.empty()) {
         MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty";
-        continue;
+      } else {
+        std::unique_ptr<schema::QuantParamT> output_quant_param =
+          std::make_unique<schema::QuantParamT>(output_quant_params[0]);
+        tensor_output->quantParams.emplace_back(std::move(output_quant_param));
       }
-
-      std::unique_ptr<schema::QuantParamT> output_quant_param =
-        std::make_unique<schema::QuantParamT>(output_quant_params[0]);
-      tensor_output->quantParams.emplace_back(std::move(output_quant_param));
       //    // TensorType
       //    valuePtr = primitive->GetAttr(kInputTensorDataType);
       //    if (valuePtr != nullptr) {
diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h
index 56667890f3e..b13f4606eb1 100644
--- a/mindspore/lite/src/ir/primitive_t_value.h
+++ b/mindspore/lite/src/ir/primitive_t_value.h
@@ -26,8 +26,8 @@ namespace mindspore::lite {
 class PrimitiveTValue : public Value {
  public:
   explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {}
-
-  ~PrimitiveTValue() override { delete this->primitive; }
+  // Not responsible for freeing primitive: whoever created the dynamic memory must free it.
+  ~PrimitiveTValue() override = default;

   MS_DECLARE_PARENT(PrimitiveTValue, Value)
diff --git a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
index 3400c884515..ff173918b74 100644
--- a/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
+++ b/mindspore/lite/tools/converter/optimizer/node/weight_format_pass.cc
@@ -27,7 +27,7 @@ int WeightFormatPass::Run(GraphNode *graphNode) {
     MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status;
     return status;
   }
-  if (this->quantType == QuantType_AwareTrainning) {
+  if (this->quantType == QuantType_AwareTrainning || this->quantType == QuantType_PostTraining) {
     status = QuantDataFormatTrans(graphNode);
     if (status != 0) {
       MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
@@ -147,7 +147,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
   } else if (fmkType == converter::FmkType_TFLITE) {
     switch (node->quantType) {
       case QuantType_QUANT_NONE:
-      case QuantType_AwareTrainning: {
+      case QuantType_AwareTrainning:
+      case QuantType_PostTraining: {
         if (opType == schema::PrimitiveType_Conv2D) {
           weightTensor->format = schema::Format_KHWC;
         } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
diff --git a/mindspore/lite/tools/converter/quantizer/post_training.cc b/mindspore/lite/tools/converter/quantizer/post_training.cc
index 392c3f92949..08353856bea 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training.cc
@@ -292,14 +292,33 @@ STATUS Calibrator::RecordMaxValue(std::string opName, vector<float> data,
 }

 STATUS Calibrator::ComputeThreshold() {
-  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
-    DivergInfo *info = iter->second.get();
-    info->ComputeThreshold();
-  }
   for (auto iter = this->output_diverg_info_.begin(); iter != this->output_diverg_info_.end(); iter++) {
     DivergInfo *info = iter->second.get();
     info->ComputeThreshold();
   }
+  // node A's input may be node B's output; no need to re-compute node A's input quant param, which is the same as node B's output quant param
+  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
+    DivergInfo *info = iter->second.get();
+    auto cnode = info->cnode;
+
+    bool already_computed = false;
+    auto input = cnode->input(1);
+    if (input->isa<CNode>()) {
+      auto input_cnode = std::dynamic_pointer_cast<CNode>(input);
+      for (const auto &output_diverg_info : output_diverg_info_) {
+        auto output_diverg_cnode = output_diverg_info.second->cnode;
+        if (output_diverg_cnode == input_cnode) {
+          *info = *(output_diverg_info.second);
+          info->cnode = cnode;
+          already_computed = true;
+          break;
+        }
+      }
+    }
+    if (!already_computed) {
+      info->ComputeThreshold();
+    }
+  }
   return RET_OK;
 }
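
Illustration of fix 1 (compute threshold only once): the reworked Calibrator::ComputeThreshold() above first computes thresholds for all output DivergInfo entries, then, for each input DivergInfo, copies the already-computed info from the producing node's output instead of recomputing it. Below is a minimal, self-contained C++ sketch of that reuse pattern; Node, DivergInfo, and ComputeThresholds here are simplified stand-ins invented for illustration, not the converter's real types.

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Simplified stand-ins for the converter's types (hypothetical, for
// illustration only). The real DivergInfo also owns the histogram that
// the per-tensor threshold search runs over.
struct Node {
  std::string name;
  Node *producer = nullptr;  // node whose output feeds this node's first input
};

struct DivergInfo {
  Node *cnode = nullptr;
  float threshold = 0.0f;

  void ComputeThreshold() {
    // Placeholder for the expensive threshold computation.
    threshold = 1.0f;
    std::cout << "computed threshold for " << cnode->name << "\n";
  }
};

// Mirrors the patched Calibrator::ComputeThreshold(): outputs first, then
// each input reuses its producer's result when one exists.
void ComputeThresholds(std::map<Node *, std::unique_ptr<DivergInfo>> &input_diverg_info,
                       std::map<Node *, std::unique_ptr<DivergInfo>> &output_diverg_info) {
  for (auto &entry : output_diverg_info) {
    entry.second->ComputeThreshold();
  }
  for (auto &entry : input_diverg_info) {
    DivergInfo *info = entry.second.get();
    Node *self = info->cnode;
    auto it = self->producer != nullptr ? output_diverg_info.find(self->producer)
                                        : output_diverg_info.end();
    if (it != output_diverg_info.end()) {
      // The input quant param equals the producer's output quant param:
      // copy the computed info, then restore the owning-node pointer.
      *info = *(it->second);
      info->cnode = self;
    } else {
      info->ComputeThreshold();
    }
  }
}

int main() {
  Node a{"A"};
  Node b{"B", &a};  // B consumes A's output
  std::map<Node *, std::unique_ptr<DivergInfo>> inputs, outputs;
  outputs[&a] = std::make_unique<DivergInfo>(DivergInfo{&a});
  inputs[&b] = std::make_unique<DivergInfo>(DivergInfo{&b});
  ComputeThresholds(inputs, outputs);
  // Prints a single "computed threshold" line: B's input reused A's output info.
  std::cout << "B input threshold: " << inputs[&b]->threshold << "\n";
  return 0;
}

Computing the output side first guarantees that every copy on the input side finds a finished threshold; reversing the two loops would copy uninitialized values.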