From 21f15b3981596069cc05e7b03afe5bd98a02bbfe Mon Sep 17 00:00:00 2001 From: albert-yan Date: Sat, 6 Aug 2022 09:49:53 +0800 Subject: [PATCH] fix determiner and tf tflite quant parser --- .../converter/parser/tf/tf_model_parser.cc | 17 ---- .../converter/parser/tf/tf_model_parser.h | 2 - .../parser/tflite/tflite_model_parser.cc | 36 ++++--- .../quantizer/quant_helper/quant_node_pass.cc | 8 ++ .../quant_helper/quant_type_determiner.cc | 96 ++++++------------- .../quant_helper/quant_type_determiner.h | 2 - 6 files changed, 53 insertions(+), 108 deletions(-) diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index e66533e5c42..e4f5c4465e4 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -1035,11 +1035,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def, return status; } - status = ConvertQuantParams(inputs.size() - 1, output_size, primitive_c); - if (status != RET_OK) { - MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed."; - return status; - } return status; } @@ -1077,18 +1072,6 @@ STATUS TFModelParser::ProcessControlFlowOp(const CNodePtr &anf_node, const strin return RET_OK; } -STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size, - PrimitiveCPtr primitive_c) { - if (primitive_c == nullptr) { - MS_LOG(ERROR) << "primitive_c is null, get quant params failed."; - return RET_NULL_PTR; - } - auto quant_params_holder = std::make_shared(input_size, output_size); - CHECK_NULL_RETURN(quant_params_holder); - primitive_c->AddAttr("quant_params", quant_params_holder); - return RET_OK; -} - std::set TFModelParser::GetAllNodeInputs() { std::set all_node_inputs; for (auto &node : tf_root_graph_nodes_vec_) { diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 368be5f59a3..cf18297cfcd 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -92,8 +92,6 @@ class TFModelParser : public converter::ModelParser { STATUS ControlFlowNodePostProcess(const std::map &first_func_map, const std::map &second_func_map); - static STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, PrimitiveCPtr primitive_c); - static STATUS MakeAnfGraphOutputs(const std::vector &output_nodes, const FuncGraphPtr &anf_graph); STATUS RecordNullInput(const CNodePtr &node, const std::vector &input_name_not_found); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 98608f118f8..f97415b3a2c 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -435,23 +435,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptrname() == "Conv2D" || primitive_c->name() == "Conv2DFusion") { round_type = 2; } - int32_t inputs_size = 0; - int32_t outputs_size = 0; - if (primitive_c->name() == "FullyConnection") { - std::vector inputs(op->inputs.size()); - std::vector outputs(op->outputs.size()); - auto it = - std::copy_if(op->inputs.begin(), op->inputs.end(), inputs.begin(), [](const int32_t item) { return item >= 0; }); - inputs.resize(std::distance(inputs.begin(), it)); - it = std::copy_if(op->outputs.begin(), op->outputs.end(), outputs.begin(), - [](const int32_t item) { return item >= 0; }); - outputs.resize(std::distance(outputs.begin(), it)); - } else { - inputs_size = op->inputs.size(); - outputs_size = op->outputs.size(); - } - auto quant_params_holder = std::make_shared(inputs_size, outputs_size); - MSLITE_CHECK_PTR(quant_params_holder); + + std::map> in_quant_param; size_t idx = 0; for (auto input_idx : op->inputs) { if (input_idx < 0) { @@ -468,9 +453,10 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptrset_input_quant_param(idx, quant_params); + in_quant_param.insert({idx, quant_params}); idx++; } + std::map> out_quant_param; idx = 0; for (auto output_idx : op->outputs) { if (output_idx < 0) { @@ -487,10 +473,20 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptrset_output_quant_param(idx, quant_params); + out_quant_param.insert({idx, quant_params}); idx++; } - primitive_c->AddAttr("quant_params", quant_params_holder); + if (!in_quant_param.empty() || !out_quant_param.empty()) { + auto quant_params_holder = std::make_shared(0, 0); + MSLITE_CHECK_PTR(quant_params_holder); + for (auto &iter : in_quant_param) { + quant_params_holder->set_input_quant_param(iter.first, iter.second); + } + for (auto &iter : out_quant_param) { + quant_params_holder->set_output_quant_param(iter.first, iter.second); + } + primitive_c->AddAttr("quant_params", quant_params_holder); + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc index 0475a1bcc36..fe5b00e3598 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc @@ -188,6 +188,10 @@ int QuantNodePass::DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPt MS_LOG(ERROR) << input_node->fullname_with_scope() << " can not get value"; return RET_NULL_PTR; } + if (tensor_info->data_type() != kNumberTypeFloat32) { + MS_LOG(INFO) << cnode->fullname_with_scope() << " is not float32, data will not quant."; + return RET_OK; + } int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape())); MS_CHECK_GT(static_cast(quant_param_holder->get_input_quant_params().size()), static_cast(input_index) - 1, RET_ERROR); @@ -216,6 +220,10 @@ int QuantNodePass::DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &i MS_LOG(ERROR) << input_node->fullname_with_scope() << " can not get value"; return RET_NULL_PTR; } + if (tensor_info->data_type() != kNumberTypeFloat32) { + MS_LOG(INFO) << cnode->fullname_with_scope() << " is not float32, data will not quant."; + return RET_OK; + } int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape())); MS_CHECK_GT(static_cast(quant_param_holder->get_input_quant_params().size()), static_cast(input_index) - 1, RET_ERROR); diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc index 59d23fde933..84d5b445088 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc @@ -21,33 +21,15 @@ #include "src/litert/kernel_exec.h" #include "src/litert/kernel_registry.h" #include "src/common/ops/anf_utils.h" +#include "tools/optimizer/common/format_utils.h" +#include "tools/common/node_util.h" namespace mindspore::lite::quant { -std::pair QuantTypeDeterminer::GetQuantParamsNum(const QuantParamHolderPtr &quant_holder) { - // update input quant params num - auto input_inited_quant_params = 0; - auto input_tensors = quant_holder->get_input_quant_params(); - for (auto input : input_tensors) { - bool is_quant_params_inited = std::all_of( - input.begin(), input.end(), [](const schema::QuantParamT &quant_param) { return quant_param.inited; }); - if (is_quant_params_inited) { - input_inited_quant_params++; - } - } - auto output_inited_quant_params = 0; - auto output_tensors = quant_holder->get_output_quant_params(); - for (auto output : output_tensors) { - bool is_quant_params_inited = !std::any_of( - output.begin(), output.end(), [](const schema::QuantParamT &quant_param) { return !quant_param.inited; }); - if (is_quant_params_inited) { - output_inited_quant_params++; - } - } - return {input_inited_quant_params, output_inited_quant_params}; -} - bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) { - MS_ASSERT(node != nullptr); + MS_ASSERT(cnode != nullptr); + if (opt::IsSpecialType(cnode)) { + return false; + } auto primT = GetPrimitiveT(cnode->input(kPrimIndex)); if (primT == nullptr) { MS_LOG(WARNING) << cnode->fullname_with_scope() << " primitive is nullptr."; @@ -66,51 +48,23 @@ bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) { return false; } - // GetCNodeQuantType + // Get CNode QuantType directly. if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) { return quant_holder->quant_type() == schema::QuantType_QUANT_ALL; } + // All output need init. if (!quant_holder->IsOutputQuantParamsInited()) { return false; } - if (CheckNodeInSet(cnode, bias_ops_)) { - auto input_quant_params = quant_holder->get_input_quant_params(); - MS_CHECK_TRUE_RET(!input_quant_params.empty(), false); - bool input_params_inited = - (!input_quant_params.at(kInputIndex).empty() && input_quant_params.at(kInputIndex).front().inited) && - (!input_quant_params.at(kWeightIndex).empty() && input_quant_params.at(kWeightIndex).front().inited); - if (!input_params_inited || !quant_holder->IsOutputQuantParamsInited()) { + + // Quantization parameters exist for all activations. + for (size_t i = 1; i < cnode->size(); ++i) { + auto input = cnode->input(i); + if (input->isa() && !quant_holder->CheckInit(i - kPrimOffset, true)) { return false; } } - - auto in_out_quant_params = GetQuantParamsNum(quant_holder); - // Check quant param size is same as tensor size. - auto input_size = (cnode->size() - kPrimOffset); - if (CheckNodeInSet(cnode, bias_ops_)) { - input_size -= kPrimOffset; - } - // exclude input(not fp32) - for (size_t index = 1; index < cnode->size(); ++index) { - CHECK_NULL_RETURN(cnode->input(index)); - auto abstract_base = cnode->input(index)->abstract(); - if (!utils::isa(abstract_base)) { - MS_LOG(ERROR) << cnode->fullname_with_scope() << " index: " << index << " should be AbstractTensorPtr."; - return RET_ERROR; - } - auto abstract_tensor = utils::cast(abstract_base); - CHECK_NULL_RETURN(abstract_tensor); - CHECK_NULL_RETURN(abstract_tensor->element()); - if (abstract_tensor->element()->GetTypeTrack()->type_id() != kNumberTypeFloat32) { - input_size -= kPrimOffset; - } - } - auto output_size = opt::GetOutputSize(cnode); - if (in_out_quant_params.first == input_size && in_out_quant_params.second == output_size) { - quant_holder->set_quant_type(schema::QuantType_QUANT_ALL); - return true; - } - return false; + return true; } bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) { @@ -120,7 +74,7 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) { return false; } - // GetCNodeQuantType + // Get CNode QuantType directly. if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) { return quant_holder->quant_type() == schema::QuantType_QUANT_WEIGHT; } @@ -130,8 +84,12 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) { return false; } + bool quant_flag = false; for (size_t i = 1; i < cnode->size(); i++) { auto input = cnode->input(i); + if (IsGraphInput(input)) { + continue; + } // non-constants(CNode) don't include quantization parameters if (input->isa()) { if (quant_holder->CheckInit(i - kPrimOffset, true)) { @@ -140,11 +98,12 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) { } else { // Constants have quantization parameters if (quant_holder->CheckInit(i - kPrimOffset, true)) { - return true; + quant_flag = true; + continue; } } } - return false; + return quant_flag; } int QuantTypeDeterminer::Determine() { @@ -156,15 +115,18 @@ int QuantTypeDeterminer::Determine() { MS_LOG(INFO) << cnode->fullname_with_scope() << " quant holder is nullptr."; continue; } - if (DetermineQuantWeight(cnode)) { + if (!quant_holder->IsInputQuantParamsInited() && !quant_holder->IsOutputQuantParamsInited()) { // Check FP32. + if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + continue; + } + MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info"; + quant_holder->ClearQuantParams(); + } else if (DetermineQuantWeight(cnode)) { MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_WEIGHT"; quant_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT); } else if (DetermineQuantAll(cnode)) { MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_ALL"; quant_holder->set_quant_type(schema::QuantType_QUANT_ALL); - } else { - MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info"; - quant_holder->ClearQuantParams(); } } return RET_OK; diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h index fe1f2d20b87..3d732c78d08 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h @@ -35,8 +35,6 @@ class QuantTypeDeterminer { private: bool DetermineQuantAll(const CNodePtr &cnode); - std::pair GetQuantParamsNum(const QuantParamHolderPtr &quant_holder); - bool DetermineQuantWeight(const CNodePtr &cnode); private: