diff --git a/mindspore/lite/nnacl/infer/common_infer.h b/mindspore/lite/nnacl/infer/common_infer.h index 9aedd92b258..b7489a89b6b 100644 --- a/mindspore/lite/nnacl/infer/common_infer.h +++ b/mindspore/lite/nnacl/infer/common_infer.h @@ -123,8 +123,10 @@ enum NNACLQuantType { QuantType_AwareTraining = 1, QuantType_WeightQuant = 2, QuantType_PostTraining = 3, + QuantType_QUANT_WEIGHT = 4, + QuantType_QUANT_ALL = 5, QuantType_MIN = QuantType_QUANT_NONE, - QuantType_MAX = QuantType_PostTraining + QuantType_MAX = QuantType_QUANT_ALL }; typedef struct vvector { diff --git a/mindspore/lite/nnacl/infer/gather_infer.c b/mindspore/lite/nnacl/infer/gather_infer.c index f7c86d8cb9e..78c025cc865 100644 --- a/mindspore/lite/nnacl/infer/gather_infer.c +++ b/mindspore/lite/nnacl/infer/gather_infer.c @@ -26,7 +26,7 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * const TensorC *indices = inputs[1]; TensorC *output = outputs[0]; output->data_type_ = input->data_type_; - if (parameter->quant_type_ == QuantType_WeightQuant) { + if (parameter->quant_type_ == QuantType_QUANT_WEIGHT) { output->data_type_ = kNumberTypeFloat32; } output->format_ = input->format_; diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 50fdbfcaf4b..76d60c404c6 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -56,9 +56,11 @@ table Tensor { enum QuantType: int { QUANT_NONE, - AwareTraining, - WeightQuant, - PostTraining + AwareTraining, // deprecated, use QUANT_ALL instead + WeightQuant, // deprecated, use QUANT_WEIGHT instead + PostTraining, // deprecated, use QUANT_ALL instead + QUANT_WEIGHT, + QUANT_ALL } table Primitive { diff --git a/mindspore/lite/src/lite_model.h b/mindspore/lite/src/lite_model.h index 11bf025c0f0..7f6edbe5715 100644 --- a/mindspore/lite/src/lite_model.h +++ b/mindspore/lite/src/lite_model.h @@ -65,6 +65,11 @@ class LiteModel : public Model { auto c_node = meta_graph.nodes()->template 
GetAs(i); node->primitive_ = c_node->primitive(); node->quant_type_ = c_node->quantType(); + if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) { + node->quant_type_ = schema::QuantType_QUANT_ALL; + } else if (node->quant_type_ == schema::QuantType_WeightQuant) { + node->quant_type_ = schema::QuantType_QUANT_WEIGHT; + } node->name_ = c_node->name()->c_str(); auto count = c_node->inputIndex()->size(); for (uint32_t j = 0; j < count; ++j) { diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index a4849df6eb4..33eea31bd55 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -430,7 +430,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector &in MS_ASSERT(node != nullptr); // why we need this TypeId data_type = - (node->quant_type_ == schema::QuantType_WeightQuant) ? kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors); + (node->quant_type_ == schema::QuantType_QUANT_WEIGHT) ? 
kNumberTypeFloat32 : GetFirstFp32Fp16OrInt8Type(in_tensors); OpParameter *op_parameter = op_parameters_[node->output_indices_.at(0)]; if (op_parameter == nullptr) { MS_LOG(ERROR) << "Can not find OpParameter!type: " << PrimitiveTypeName(GetPrimitiveType(node->primitive_)); diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index b477839cef9..4576a286ed7 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -669,6 +669,9 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_ano MS_LOG(ERROR) << "schema tensor format is wrong, " << schema_tensor->format; return RET_ERROR; } + if (primitive_c->GetAttr(opt::kWeightFormat) != nullptr) { + schema_tensor->format = static_cast(GetValue(primitive_c->GetAttr(opt::kWeightFormat))); + } schema_tensor->name = param_node->name(); ShapeVector shape_vector; TypeId data_type; @@ -704,9 +707,6 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_ano } } } - if (primitive_c->GetAttr(opt::kWeightFormat) != nullptr) { - schema_tensor->format = static_cast(GetValue(primitive_c->GetAttr(opt::kWeightFormat))); - } schema_tensor->name = input_name; QuantParamHolderPtr quant_param_holder = primitive_c->GetAttr("quant_params") == nullptr ? 
nullptr @@ -722,9 +722,8 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_ano } int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT) { - int ret; + const std::shared_ptr &value, const std::shared_ptr &primitive, + schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { auto valueAbstract = value_node->abstract(); auto abstract_tensor = utils::cast(valueAbstract); if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) { @@ -742,15 +741,30 @@ int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptrnodeType = NodeType_ValueNode; auto data = value->cast(); (*schema_tensor)->data.resize(data->Size()); - ret = memcpy_s((*schema_tensor)->data.data(), data->Size(), data->data_c(), data->Size()); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s error."; + (*schema_tensor)->format = schema::Format_NHWC; + + (*schema_tensor)->format = GetFormatByFmk(meta_graphT->fmkType); + if ((*schema_tensor)->format != schema::Format_NHWC && (*schema_tensor)->format != schema::Format_NCHW) { + MS_LOG(ERROR) << "schema tensor format is wrong, " << (*schema_tensor)->format; return RET_ERROR; } + + // process weight tensor + if (data->Size() > 0) { + if (memcpy_s((*schema_tensor)->data.data(), data->Size(), data->data_c(), data->Size()) != EOK) { + MS_LOG(ERROR) << "memcpy_s error."; + return RET_ERROR; + } + + if (primitive->GetAttr(opt::kWeightFormat) != nullptr) { + (*schema_tensor)->format = static_cast(GetValue(primitive->GetAttr(opt::kWeightFormat))); + } + } + node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); - return ret; + return RET_OK; } int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr 
&value_node, std::unique_ptr *schema_tensor, const std::shared_ptr &value, schema::CNodeT *output_cnode, @@ -877,6 +891,7 @@ int AnfExporter::ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_p } int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_anode, + const std::shared_ptr &primitive, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode) { auto value_node = input_anode->cast(); @@ -888,7 +903,7 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano schema_tensor->name = value_node->fullname_with_scope(); } if (value->isa()) { - ret = ProcessTensor(value_node, &schema_tensor, value, output_cnode, meta_graphT); + ret = ProcessTensor(value_node, &schema_tensor, value, primitive, output_cnode, meta_graphT); } else if (value->isa() || value->isa()) { ret = ProcessInt32OrInt64Imm(value_node, &schema_tensor, value, output_cnode, meta_graphT); } else if (value->isa()) { @@ -945,7 +960,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptrisa()) { - auto ret = ConvertInputValueNode(input_node, meta_graphT, fb_node); + auto ret = ConvertInputValueNode(input_node, primitive_c, meta_graphT, fb_node); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvertInputValueNode failed"; return RET_ERROR; diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index 00b21891b0f..cfdcc8dc5af 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -51,11 +51,11 @@ class AnfExporter { int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode); int ConvertInputParameter(const std::shared_ptr &input_anode, const std::shared_ptr &primitive, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); - int ConvertInputValueNode(const std::shared_ptr &input_anode, + int ConvertInputValueNode(const std::shared_ptr &input_anode, const std::shared_ptr 
&primitive, const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); int ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT); + const std::shared_ptr &value, const std::shared_ptr &primitive, + schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT); int ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, const std::shared_ptr &value, schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT); diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 2186d24bcac..2c7772850a5 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -16,14 +16,19 @@ #include "tools/common/node_util.h" #include +#include #include +#include "src/ops/populate/populate_register.h" #include "src/common/common.h" #include "src/common/log_adapter.h" #include "tools/common/graph_util.h" #include "tools/common/tensor_util.h" +#include "src/runtime/infer_manager.h" namespace mindspore { namespace lite { +constexpr size_t kInitialSize = 1024; + static const std::vector nhwcOpList = {schema::PrimitiveType_Conv2DBackpropFilterFusion, schema::PrimitiveType_Conv2DBackpropInputFusion, schema::PrimitiveType_AvgPoolGrad, @@ -350,6 +355,33 @@ static int Convert2CKHW(int srcFormat) { return -1; } +STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector &inputs, std::vector *outputs) { + flatbuffers::FlatBufferBuilder fbb(kInitialSize); + auto prim = ConvertToPrimitive(node.primitive.get(), &fbb); + if (prim == nullptr) { + MS_LOG(ERROR) << "get primitive failed."; + fbb.Clear(); + return RET_ERROR; + } + auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim->value_type(), SCHEMA_CUR); + if (parameter_gen == nullptr) { + fbb.Clear(); + MS_LOG(ERROR) << "PopulateParameter return nullptr, 
type: " << schema::EnumNamePrimitiveType(prim->value_type()); + return RET_ERROR; + } + auto parameter = parameter_gen(prim); + if (parameter == nullptr) { + fbb.Clear(); + MS_LOG(ERROR) << "parameter is nullptr."; + return RET_ERROR; + } + parameter->infer_flag_ = true; + auto ret = KernelInferShape(inputs, outputs, parameter); + fbb.Clear(); + free(parameter); + return ret; +} + STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { if (tensor == nullptr) { MS_LOG(ERROR) << "tensor is null"; diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 1e4d9f49802..dd430594d1d 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -44,6 +44,8 @@ int CreateOperator(const std::unique_ptr &primitive, schema: using STATUS = int; STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node); +STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector &inputs, std::vector *outputs); + inline schema::PrimitiveType GetCNodeTType(const schema::CNodeT &cNodeT) { return cNodeT.primitive->value.type; } inline std::string GetCNodeTTypeName(const schema::CNodeT &cNodeT) { diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 3ede4f979da..5dff0f0b654 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -205,13 +205,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { // init old node indices auto old_nodes = GetGraphNodes(); Optimizer quantNodeOptimizer; - auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); - if (dTypeTransPass == nullptr) { - MS_LOG(ERROR) << "new dTypeTransPass failed"; - return RET_MEMORY_FAILED; - } - dTypeTransPass->SetInputDataDType(ctx.inputDataType); - dTypeTransPass->SetOutputDataDType(ctx.outputDataType); quantNodeOptimizer.AddPass(new 
(std::nothrow) SubgraphNodePass(old_nodes)); quantNodeOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); quantNodeOptimizer.AddPass(new (std::nothrow) InferShapePass()); @@ -221,6 +214,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { return status; } auto old_nodes2 = GetGraphNodes(); + quantNodeOptimizer.AddPass(new (std::nothrow) InferQuantParamPass()); + auto dTypeTransPass = new (std::nothrow) DTypeTransPass(); + if (dTypeTransPass == nullptr) { + MS_LOG(ERROR) << "new dTypeTransPass failed"; + return RET_MEMORY_FAILED; + } + dTypeTransPass->SetInputDataDType(ctx.inputDataType); + dTypeTransPass->SetOutputDataDType(ctx.outputDataType); quantNodeOptimizer.AddPass(dTypeTransPass); quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index c8b76b8c64a..c445572a865 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -17,6 +17,8 @@ #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" #include #include +#include +#include #include "tools/common/node_util.h" #include "tools/converter/converter_context.h" #include "src/common/common.h" @@ -36,17 +38,18 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { return status; } + status = DoNodeInoutDTypeTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; + return status; + } + status = DoModelOutputDTypeTrans(graph); if (status != RET_OK) { MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; return status; } - status = DoNodeInoutDTypeTrans(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; - return 
status; - } return RET_OK; } @@ -128,48 +131,103 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { return RET_OK; } -STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - // insert transNode before and after existNode - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) || (*iter)->quantType != QuantType_AwareTraining) { - continue; - } - auto nodeName = (*iter)->name; - if ((*iter)->inputIndex.empty()) { - MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; - return RET_ERROR; - } - STATUS status; - // insert pre - for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { - MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); - auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); - if (preTensor->dataType != TypeId::kNumberTypeInt8) { - continue; - } - iter = InsertDTypeTransNode(graph, iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status); +STATUS DTypeTransPass::InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter) { + auto node_name = (**iter)->name; + auto status = RET_OK; + // insert fp32 to int8 before + for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) { + auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i)); + if (pre_tensor->dataType == kNumberTypeFloat32 && !pre_tensor->quantParams.empty() && + pre_tensor->quantParams.front()->inited) { + *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeFloat32, kNumberTypeInt8, &status); if (status != RET_OK) { - MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; + MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed"; return RET_ERROR; } } - - // insert post - for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { - auto &postTensor = 
graph->allTensors.at((*iter)->outputIndex.at(i)); - if (postTensor->dataType != TypeId::kNumberTypeInt8) { - continue; - } - iter = InsertDTypeTransNode(graph, iter, kAfter, i, kNumberTypeFloat32, kNumberTypeUInt8, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; - return RET_ERROR; - } - } - (*iter)->quantType = QuantType_QUANT_NONE; } + // insert int8 to fp32 after + for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) { + auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i)); + if (post_tensor->dataType == kNumberTypeFloat32 && !post_tensor->quantParams.empty() && + post_tensor->quantParams.front()->inited) { + *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed"; + return RET_ERROR; + } + } + } + return RET_OK; +} + +STATUS DTypeTransPass::InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter) { + auto node_name = (**iter)->name; + auto status = RET_OK; + // insert int8 to fp32 before + for (size_t i = 0; i < (**iter)->inputIndex.size(); i++) { + auto &pre_tensor = graph->allTensors.at((**iter)->inputIndex.at(i)); + if (pre_tensor->dataType == kNumberTypeInt8 && !pre_tensor->quantParams.empty() && + pre_tensor->quantParams.front()->inited) { + *iter = InsertDTypeTransNode(graph, *iter, kBefore, i, kNumberTypeInt8, kNumberTypeFloat32, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << node_name.c_str() << " failed"; + return RET_ERROR; + } + } + } + + // insert fp32 to int8 after + for (size_t i = 0; i < (**iter)->outputIndex.size(); i++) { + auto &post_tensor = graph->allTensors.at((**iter)->outputIndex.at(i)); + if (post_tensor->dataType == kNumberTypeInt8 && !post_tensor->quantParams.empty() && + post_tensor->quantParams.front()->inited) { 
+ *iter = InsertDTypeTransNode(graph, *iter, kAfter, i, kNumberTypeInt8, kNumberTypeFloat32, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed"; + return RET_ERROR; + } + } + } + return RET_OK; +} + +STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto node_name = (*iter)->name; + if ((*iter)->inputIndex.empty()) { + MS_LOG(ERROR) << "Op " << node_name.c_str() << " should have " << kMinInputNum << " input tensor at least"; + return RET_ERROR; + } + + if ((*iter)->primitive->value.type == schema::PrimitiveType_QuantDTypeCast || + (*iter)->primitive->value.type == schema::PrimitiveType_Cast) { + continue; + } + + STATUS status = RET_OK; + // quant_type is quant_all, but inputs/outputs are float32 + if ((*iter)->quantType == QuantType_QUANT_ALL) { + status = InsetDTypeTransNodeForWrongDtypeQuantOp(graph, &iter); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed"; + return status; + } + continue; + } + + // quant_type is quant_none, but inputs/outputs have quant params and dtype is int8, which means this int8 op is not + // supported yet + if ((*iter)->quantType == QuantType_QUANT_NONE) { + status = InsetDTypeTransNodeForUnsupportedInt8Op(graph, &iter); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertFloat32ToInt8Node before " << node_name.c_str() << " failed"; + return status; + } + } + } return RET_OK; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index 42bb39e14f5..c5c765767ea 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -46,6 +46,11 @@ class DTypeTransPass : 
public GraphPass { STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); + + STATUS InsetDTypeTransNodeForWrongDtypeQuantOp(schema::MetaGraphT *graph, NodeIter *iter); + + STATUS InsetDTypeTransNodeForUnsupportedInt8Op(schema::MetaGraphT *graph, NodeIter *iter); + NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, int32_t inputDataType, int32_t outputDataType, STATUS *errorCode); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc index 34846361c0a..6afa7c02bb2 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc @@ -18,75 +18,32 @@ #include #include #include "src/common/utils.h" -#include "tools/converter/quantizer/calc_quant_param.h" +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" #include "tools/common/node_util.h" namespace mindspore::lite { STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - auto *quantParamRegister = QuantParamCalcRegister::GetInstance(); + if (graph == nullptr) { + MS_LOG(ERROR) << "graph is null"; + return RET_NULL_PTR; + } for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { auto &node = *iter; - MS_ASSERT(node != nullptr); - if (node->quantType == schema::QuantType_WeightQuant) { + if (node == nullptr) { + MS_LOG(ERROR) << "node is null"; + return RET_NULL_PTR; + } + + // make sure reenter node will keep the node status + if (node->quantType != schema::QuantType_QUANT_NONE) { continue; } - DetermineNodeQuantType(*graph, node.get()); - if (node->quantType == schema::QuantType_AwareTraining) { - continue; - } - if (GetCNodeTType(*node) == 
schema::PrimitiveType_FakeQuantWithMinMaxVars) { - MS_ASSERT(false); - } - auto quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); - if (quantParamCalcer == nullptr) { - MS_LOG(DEBUG) << "Can not find QuantParamCalcer for " << node->name.c_str() - << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; - node->quantType = schema::QuantType_QUANT_NONE; - } else { - auto status = quantParamCalcer->Calc(graph, *node); - if (status != RET_OK) { - MS_LOG(DEBUG) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); - node->quantType = schema::QuantType_QUANT_NONE; - } else { - node->quantType = schema::QuantType_AwareTraining; - } - } + + auto quant_helper = QuantHelperRegister::GetInstance()->GetQuantHelper(node->primitive->value.type); + + quant_helper->NodeQuantPreprocess(graph, node.get()); } return RET_OK; } - -void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(cnode != nullptr); - bool canQuant = true; - for (auto &inputTensorIdx : cnode->inputIndex) { - MS_ASSERT(graph.allTensors.size() > inputTensorIdx); - auto &inTensor = graph.allTensors.at(inputTensorIdx); - MS_ASSERT(inTensor != nullptr); - if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || - !inTensor->quantParams.front()->inited) { - canQuant = false; - break; - } - } - - for (auto &outTensorIdx : cnode->outputIndex) { - MS_ASSERT(graph.allTensors.size() > outTensorIdx); - auto &outTensor = graph.allTensors.at(outTensorIdx); - MS_ASSERT(outTensor != nullptr); - if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || - !outTensor->quantParams.front()->inited) { - canQuant = false; - break; - } - } - - if (canQuant) { - cnode->quantType = schema::QuantType_AwareTraining; - } else { - cnode->quantType = schema::QuantType_QUANT_NONE; - } -} } // namespace mindspore::lite 
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h index e00b1c2da73..cce0c52f22a 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h @@ -24,14 +24,11 @@ namespace mindspore { namespace lite { class InferQuantParamPass : public GraphPass { public: - InferQuantParamPass() {} + InferQuantParamPass() = default; ~InferQuantParamPass() override = default; STATUS Run(schema::MetaGraphT *graph) override; - - private: - void DetermineNodeQuantType(const schema::MetaGraphT &graph, schema::CNodeT *cnode); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc index 7671bc41be1..ebc2188b5c9 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.cc @@ -31,7 +31,7 @@ namespace mindspore { namespace lite { namespace { constexpr int DEFAULT_DIM_VALUE = -1; -constexpr size_t INITIAL_SIZE = 1024; +constexpr size_t kInitialSize = 1024; void FreeTensors(std::vector input_tensors, std::vector output_tensors) { for (auto &tensor : input_tensors) { @@ -118,7 +118,7 @@ std::vector ConvertTensorToLiteTensor(MetaGraphT *graph, const std::ve STATUS NodeInferShape(const std::unique_ptr &node, const std::vector &inputs, std::vector *outputs) { - flatbuffers::FlatBufferBuilder fbb(INITIAL_SIZE); + flatbuffers::FlatBufferBuilder fbb(kInitialSize); auto prim = ConvertToPrimitive(node->primitive.get(), &fbb); if (prim == nullptr) { MS_LOG(ERROR) << "get primitive failed."; diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt 
index f2bf1f52e0d..585ab7ff94e 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -5,6 +5,7 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4) file(GLOB QUANTIZER ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc + ${CMAKE_CURRENT_SOURCE_DIR}/quant_helper/* ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index f9d13631e4f..8b1082e32f2 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -829,7 +829,7 @@ STATUS PostTrainingQuantizer::QuantNode() { << "'s output quant_params size: " << input_primitive_quant_holder->output_quant_params().size() << ", but index: " << index; } - primitive_quant_holder->set_quant_type(schema::QuantType_PostTraining); + primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL); continue; } else if (op_type == ops::kNameConv2DFusion || op_type == ops::kNameConv2dTransposeFusion || op_type == ops::kNameFullConnection || op_type == ops::kNameLayerNormFusion) { @@ -870,7 +870,7 @@ STATUS PostTrainingQuantizer::QuantNode() { output_min_max.min = info->min; DoQuantOutput(output_scale, output_zp, &output_min_max, primitive); - primitive_quant_holder->set_quant_type(schema::QuantType_PostTraining); + primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL); } } return RET_OK; diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index 4442ac06af7..0570946b1f1 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -26,7 +26,7 @@ ValueNodePtr 
NewQuantCastValueNode(int src_type, int dst_type, const std::vector auto prim_c = std::make_shared(); prim_c->Init(src_type, dst_type); auto quant_params_holder = std::make_shared(); - quant_params_holder->set_quant_type(schema::QuantType_PostTraining); + quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL); for (auto &quant_param : quant_params) { std::vector quant_params_in = {quant_param}; quant_params_holder->AddInputQuantParam(quant_params_in); @@ -80,7 +80,7 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) { if (curnode_quant_type != input_cnode_quant_type) { ValueNodePtr value_node = nullptr; - if (curnode_quant_type == schema::QuantType_PostTraining && + if (curnode_quant_type == schema::QuantType_QUANT_ALL && input_cnode_quant_type == schema::QuantType_QUANT_NONE) { if (primitive_quant_param_holder->input_quant_params().size() < i) { MS_LOG(ERROR) << "quant param is invalid."; @@ -89,7 +89,7 @@ STATUS QuantCast::Run(const FuncGraphPtr &graph) { value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitive_quant_param_holder->input_quant_params()[i - 1]); } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && - input_cnode_quant_type == schema::QuantType_PostTraining) { + input_cnode_quant_type == schema::QuantType_QUANT_ALL) { auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c); value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, input_primitive_quant_param_holder->output_quant_params().front()); diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.cc new file mode 100644 index 00000000000..79093f58236 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h" +#include "mindspore/core/ir/dtype/type_id.h" + +namespace mindspore::lite { +static constexpr size_t kBiasAddSize = 2; +STATUS BiasAddQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph, + const mindspore::schema::CNodeT &node) { + if (node.inputIndex.size() == kBiasAddSize) { + auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAddSize - 1)); + for (auto &quantParam : bias_tensor->quantParams) { + quantParam->dstDtype = TypeId::kNumberTypeInt32; + } + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h b/mindspore/lite/tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h new file mode 100644 index 00000000000..f96dcb0e28d --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class BiasAddQuantParamPropogator : public QuantParamPropogator { + public: + STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.cc new file mode 100644 index 00000000000..0336918171b --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h" +#include +#include +#include "tools/common/tensor_util.h" +namespace mindspore::lite { +STATUS CarryDataQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) { + UpdateQuantParamsNum(*graph, node); + + // refresh in_tensor quant_params by out_tensor quant_params + if (input_inited_quant_params_ < 1) { + auto &out_tensor = graph->allTensors.at(node.outputIndex.at(0)); + auto out_quant_param = GetTensorQuantParam(out_tensor); + MS_ASSERT(out_quant_param != nullptr); + if (out_quant_param == nullptr || !out_quant_param->inited) { + MS_LOG(DEBUG) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name; + return RET_ERROR; + } + MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); + auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0)); + MS_ASSERT(in_tensor != nullptr); + auto in_quant_param = GetTensorQuantParam(in_tensor); + if (in_quant_param != nullptr && !in_quant_param->inited) { + in_tensor->quantParams.front() = std::move(out_quant_param); + } + } + + // refresh out_tensor quant_params by in_tensor quant_params + if (output_inited_quant_params_ < 1) { + MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0)); + auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0)); + MS_ASSERT(in_tensor != nullptr); + auto in_quant_param = GetTensorQuantParam(in_tensor); + if (in_quant_param == nullptr || !in_quant_param->inited) { + MS_LOG(DEBUG) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name; + return RET_ERROR; + } + for (unsigned int i : node.outputIndex) { + MS_ASSERT(graph->allTensors.size() > node.outputIndex.at(i)); + auto &out_tensor = graph->allTensors.at(i); + MS_ASSERT(out_tensor != nullptr); + auto out_quant_param = 
GetTensorQuantParam(out_tensor); + if (out_quant_param == nullptr) { + out_tensor->quantParams.emplace_back(std::move(in_quant_param)); + continue; + } + if (out_quant_param->inited) { + continue; + } + out_tensor->quantParams.front() = std::move(in_quant_param); + } + } + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h new file mode 100644 index 00000000000..8985afe4584 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class CarryDataQuantParamPropogator : public QuantParamPropogator { + public: + STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.cc new file mode 100644 index 00000000000..ab0d11e7576 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h" +#include +#include +#include "tools/common/tensor_util.h" +namespace mindspore::lite { +bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) { + MS_ASSERT(node->inputIndex.size() >= 1); + MS_ASSERT(node->outputIndex.size() >= 1); + + // check first in tensor + auto &in_tensor = graph.allTensors.at(node->inputIndex.at(0)); + if (!in_tensor->quantParams.empty()) { + if (std::any_of(in_tensor->quantParams.begin(), in_tensor->quantParams.end(), + [](const std::unique_ptr &quant_param) { return !quant_param->inited; })) { + return false; + } + } else { + return false; + } + + // check first out tensor + auto &out_tensor = graph.allTensors.at(node->outputIndex.at(0)); + if (!out_tensor->quantParams.empty()) { + if (std::any_of(out_tensor->quantParams.begin(), out_tensor->quantParams.end(), + [](const std::unique_ptr &quant_param) { return !quant_param->inited; })) { + return false; + } + node->quantType = schema::QuantType_QUANT_ALL; + return true; + } + return false; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h new file mode 100644 index 00000000000..3e1176aed4c --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class CarryDataQuantTypeDeterminer : public QuantTypeDeterminer { + public: + bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CARRY_DATA_QUANT_TYPE_DETERMINER_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.cc new file mode 100644 index 00000000000..ecf87296585 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h" +#include +#include +#include +#include "src/common/log_adapter.h" +#include "tools/common/tensor_util.h" +#include "tools/converter/quantizer/quantize_util.h" + +namespace mindspore::lite { +STATUS ConcatQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph, + const mindspore::schema::CNodeT &node) { + UpdateQuantParamsNum(*graph, node); + + if (input_inited_quant_params_ != node.inputIndex.size()) { + MS_LOG(DEBUG) << "Can not determine concat inputTensor quantParam, node " << node.name; + return RET_ERROR; + } + + if (output_inited_quant_params_ != 1) { + MS_ASSERT(output_inited_quant_params_ == 0); + float min_min = FLT_MAX; + float max_max = FLT_MIN; + bool narrow_range = false; + int num_bits = -1; + for (size_t i = 0; i < node.inputIndex.size(); i++) { + MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i)); + auto &in_tensor = graph->allTensors.at(i); + MS_ASSERT(in_tensor != nullptr); + auto in_quant_param = GetTensorQuantParam(in_tensor); + if (in_quant_param == nullptr || !in_quant_param->inited) { + return RET_ERROR; + } + if (num_bits == -1) { + narrow_range = in_quant_param->narrowRange; + num_bits = in_quant_param->numBits; + } else { + MS_ASSERT(narrow_range == quantParam->narrowRange); + MS_ASSERT(num_bits == quantParam->numBits); + } + if (min_min > in_quant_param->min) { + min_min = in_quant_param->min; + } + if (max_max < in_quant_param->max) { + max_max = in_quant_param->max; + } + } + + MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); + auto &out_tensor = graph->allTensors.at(node.outputIndex.front()); + MS_ASSERT(out_tensor != nullptr); + auto out_quant_param = std::make_unique(); + + auto status = quant::CalQuantizationParams(out_quant_param.get(), min_min, max_max, narrow_range, num_bits); + if (status 
!= RET_OK) { + MS_LOG(DEBUG) << "in aware quantization run CalQuantizationParams failed!"; + return RET_ERROR; + } + out_tensor->quantParams.emplace_back(std::move(out_quant_param)); + output_inited_quant_params_++; + } + + return RET_OK; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h b/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h new file mode 100644 index 00000000000..e5a033472ec --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class ConcatQuantParamPropogator : public QuantParamPropogator { + public: + STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONCAT_QUANT_PARAM_PROPOGATOR_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc new file mode 100644 index 00000000000..061e31b8aae --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h" +#include "mindspore/core/ir/dtype/type_id.h" +namespace mindspore::lite { +static constexpr size_t kBiasAdd = 3; + +STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph, + const mindspore::schema::CNodeT &node) { + if (node.inputIndex.size() == kBiasAdd) { + auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAdd - 1)); + for (auto &quantParam : bias_tensor->quantParams) { + quantParam->dstDtype = TypeId::kNumberTypeInt32; + } + } + return RET_OK; +} + +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h new file mode 100644 index 00000000000..bab2ef441de --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class ConvQuantParamPropogator : public QuantParamPropogator { + public: + STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_PARAM_PROPOGATOR_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_type_determiner.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_type_determiner.cc new file mode 100644 index 00000000000..69a21aa3f28 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_type_determiner.cc @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h" +#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h" +#include "mindspore/core/utils/log_adapter.h" +#include "mindspore/core/ir/dtype/type_id.h" +namespace mindspore::lite { +static constexpr size_t kInputIndex = 0; +static constexpr size_t kWeightIndex = 1; + +bool ConvQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph, + mindspore::schema::CNodeT *node) { + MS_ASSERT(node->inputIndex.size() >= 2); + auto &input_tensor = graph.allTensors.at(node->inputIndex.at(kInputIndex)); + auto &weight_tensor = graph.allTensors.at(node->inputIndex.at(kWeightIndex)); + if (input_tensor->quantParams.empty() || !input_tensor->quantParams.front()->inited) { + if (!weight_tensor->quantParams.empty() && weight_tensor->quantParams.front()->inited) { + node->quantType = schema::QuantType_QUANT_WEIGHT; + return true; + } + } + return false; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h new file mode 100644 index 00000000000..2d1b4fd8b56 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class ConvQuantTypeDeterminer : public QuantTypeDeterminer { + public: + bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_CONV_QUANT_TYPE_DETERMINER_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.cc new file mode 100644 index 00000000000..2783bb08929 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.cc @@ -0,0 +1,23 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h" + +namespace mindspore::lite { + +bool DefaultQuantAllQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) { + return true; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h b/mindspore/lite/tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h new file mode 100644 index 00000000000..0db9761efb1 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" + +namespace mindspore::lite { +class DefaultQuantAllQuantTypeDeterminer : public QuantTypeDeterminer { + public: + bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.cc new file mode 100644 index 00000000000..b32b338efed --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h" + +namespace mindspore::lite { + +bool OnlyNeedInputsQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) { + UpdateQuantParamsNum(graph, *node); + if (input_inited_quant_params_ == node->inputIndex.size()) { + node->quantType = schema::QuantType_QUANT_ALL; + return true; + } + return false; +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h b/mindspore/lite/tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h new file mode 100644 index 00000000000..a3862fe99a7 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h @@ -0,0 +1,27 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H + +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +namespace mindspore::lite { +class OnlyNeedInputsQuantTypeDeterminer : public QuantTypeDeterminer { + public: + bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override; +}; +} // namespace mindspore::lite + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc new file mode 100644 index 00000000000..d70ec8b68a1 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/converter/quantizer/quant_helper/quant_node_helper.h" +#include +#include +#include "mindspore/core/utils/log_adapter.h" +#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h" +#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h" +#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h" +#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h" +#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h" +#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h" +#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h" +#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h" + +namespace mindspore::lite { +void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) { + // update input quant params num + input_inited_quant_params_ = 0; + for (auto index : node.inputIndex) { + auto &input_tensor = graph.allTensors.at(index); + if (!input_tensor->quantParams.empty()) { + bool is_quant_params_inited = + !std::any_of(input_tensor->quantParams.begin(), input_tensor->quantParams.end(), + [](const std::unique_ptr &quant_param) { return !quant_param->inited; }); + + if (is_quant_params_inited) { + input_inited_quant_params_++; + } + } + } + + // update output quant params num + output_inited_quant_params_ = 0; + for (auto index : node.outputIndex) { + auto &output_tensor = graph.allTensors.at(index); + if (!output_tensor->quantParams.empty()) { + bool is_quant_params_inited = + !std::any_of(output_tensor->quantParams.begin(), output_tensor->quantParams.end(), + [](const std::unique_ptr &quant_param) { return !quant_param->inited; }); + + if (is_quant_params_inited) { + output_inited_quant_params_++; + } + } + } +} + +bool QuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT 
&graph, schema::CNodeT *node) { + if (node->quantType != schema::QuantType_QUANT_NONE) { + return node->quantType == schema::QuantType_QUANT_ALL; + } + + UpdateQuantParamsNum(graph, *node); + if (input_inited_quant_params_ == node->inputIndex.size() && + output_inited_quant_params_ == node->outputIndex.size()) { + node->quantType = schema::QuantType_QUANT_ALL; + return true; + } + return false; +} +bool QuantTypeDeterminer::DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) { + return node->quantType == schema::QuantType_QUANT_WEIGHT; +} + +void QuantNodeHelper::NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node) { + if (quant_type_determiner_->DetermineQuantWeight(*graph, node)) { + return; + } + quant_param_propogator_->PropogateQuantParams(graph, *node); + quant_type_determiner_->DetermineQuantAll(*graph, node); +} + +QuantHelperRegister *QuantHelperRegister::GetInstance() { + static QuantHelperRegister instance; + return &instance; +} + +QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_type) { + auto it = register_map_.find(op_type); + if (it != register_map_.end()) { + return it->second; + } + return register_map_[schema::PrimitiveType_NONE]; +} + +QuantHelperRegister::QuantHelperRegister() { + auto base_propogator = std::make_shared(); + auto base_determiner = std::make_shared(); + auto bias_add_propogator = std::make_shared(); + auto carry_data_propogator = std::make_shared(); + auto carry_data_determiner = std::make_shared(); + auto concat_propogator = std::make_shared(); + auto conv_propogator = std::make_shared(); + auto conv_determiner = std::make_shared(); + auto default_quant_all_determiner = std::make_shared(); + auto only_need_inputs_determiner = std::make_shared(); + + register_map_[schema::PrimitiveType_BiasAdd] = new QuantNodeHelper(bias_add_propogator, base_determiner); + + register_map_[schema::PrimitiveType_MaxPoolFusion] = + new QuantNodeHelper(carry_data_propogator, 
carry_data_determiner); + register_map_[schema::PrimitiveType_Resize] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner); + register_map_[schema::PrimitiveType_Reshape] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner); + register_map_[schema::PrimitiveType_StridedSlice] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner); + register_map_[schema::PrimitiveType_Transpose] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner); + register_map_[schema::PrimitiveType_PadFusion] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner); + register_map_[schema::PrimitiveType_ReduceFusion] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner); + register_map_[schema::PrimitiveType_Gather] = new QuantNodeHelper(carry_data_propogator, carry_data_determiner); + + register_map_[schema::PrimitiveType_Concat] = new QuantNodeHelper(concat_propogator, base_determiner); + + register_map_[schema::PrimitiveType_Conv2DFusion] = new QuantNodeHelper(conv_propogator, conv_determiner); + register_map_[schema::PrimitiveType_MatMul] = new QuantNodeHelper(conv_propogator, conv_determiner); + + register_map_[schema::PrimitiveType_QuantDTypeCast] = + new QuantNodeHelper(base_propogator, default_quant_all_determiner); + + register_map_[schema::PrimitiveType_DetectionPostProcess] = + new QuantNodeHelper(base_propogator, only_need_inputs_determiner); + register_map_[schema::PrimitiveType_NONE] = new QuantNodeHelper(base_propogator, base_determiner); +} + +QuantHelperRegister::~QuantHelperRegister() { + for (const auto &iter : register_map_) { + delete iter.second; + } + this->register_map_.clear(); +} + +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.h b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.h new file mode 100644 index 00000000000..9bc4c5df1a6 --- /dev/null +++ 
b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_helper.h
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
+#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
+
+#include <memory>
+#include <unordered_map>
+#include "include/errorcode.h"
+#include "schema/inner/model_generated.h"
+
+namespace mindspore::lite {
+class QuantNodeBase {
+ public:
+  void UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node);
+
+ protected:
+  size_t input_inited_quant_params_ = 0;
+  size_t output_inited_quant_params_ = 0;
+};
+
+class QuantParamPropogator : public QuantNodeBase {
+ public:
+  virtual STATUS PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) { return RET_OK; }
+};
+
+class QuantTypeDeterminer : public QuantNodeBase {
+ public:
+  virtual bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node);
+  virtual bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node);
+};
+
+class QuantNodeHelper {
+ public:
+  void NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node);
+  QuantNodeHelper(std::shared_ptr<QuantParamPropogator> quant_param_propogator,
+                  std::shared_ptr<QuantTypeDeterminer> quant_type_determiner) {
+    quant_param_propogator_ = quant_param_propogator;
+    quant_type_determiner_ = quant_type_determiner;
+  }
+  virtual ~QuantNodeHelper() = default;
+
+ protected:
+  std::shared_ptr<QuantParamPropogator> quant_param_propogator_;
+  std::shared_ptr<QuantTypeDeterminer> quant_type_determiner_;
+};
+
+class QuantHelperRegister {
+ public:
+  virtual ~QuantHelperRegister();
+  QuantNodeHelper *GetQuantHelper(schema::PrimitiveType op_type);
+  static QuantHelperRegister *GetInstance();
+
+ private:
+  QuantHelperRegister();
+  std::unordered_map<schema::PrimitiveType, QuantNodeHelper *> register_map_;
+};
+}  // namespace mindspore::lite
+
+#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PREPROCESSOR_QUANT_NODE_PREPROCESSOR_H
diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
index bf9f0ea6329..b80d163f0a7 100644
--- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc
@@ -73,7 +73,7 @@ STATUS WeightQuantizer::SetAbstract(const tensor::TensorPtr &tensor_info, const
   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
   abstract_tensor->element()->set_type(TypeIdToType(type_id_));
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
-  quant_param_holder->set_quant_type(schema::QuantType_WeightQuant);
+  quant_param_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
index fa1263c38f1..e183b906acd 100644
--- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
@@ -296,6 +296,7 @@ bool MindirAdjustPass::Run(const FuncGraphPtr &graph) {
     return lite::RET_OK;
   }
   MS_ASSERT(graph != nullptr);
+  graph_ = graph;
   auto node_list = TopoSort(graph->get_return());
   int status = lite::RET_OK;
   bool success_flag = true;
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h
index
e1f0a13439b..4afa6888ff6 100644
--- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.h
@@ -31,15 +31,17 @@ class MindirAdjustPass : public Pass {
   ~MindirAdjustPass() override = default;
   void SetQuantType(QuantType quant_type) { quant_type_ = quant_type; }
   void SetFmkType(FmkType fmk_type) { fmk_type_ = fmk_type; }
-  int ValueNodeInt64Convert(AnfNodePtr anf_node);
   void SetTrainFlag(bool train_flag) { train_flag_ = train_flag; }
-  int ComputeQuantParams(AnfNodePtr anf_node);
   bool Run(const FuncGraphPtr &graph) override;
 
  protected:
+  int ValueNodeInt64Convert(AnfNodePtr anf_node);
+  int ComputeQuantParams(AnfNodePtr anf_node);
+
   QuantType quant_type_ = QuantType::QuantType_QUANT_NONE;
   FmkType fmk_type_ = FmkType::FmkType_MS;
   bool train_flag_ = false;
+  FuncGraphPtr graph_;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_ADJUST_PASS_H_
diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
index 4751e1633e1..8d8858144cf 100644
--- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
@@ -26,7 +26,9 @@ using mindspore::lite::converter::FmkType_TF;
 using mindspore::lite::converter::FmkType_TFLITE;
 using mindspore::schema::QuantType_AwareTraining;
 using mindspore::schema::QuantType_PostTraining;
+using mindspore::schema::QuantType_QUANT_ALL;
 using mindspore::schema::QuantType_QUANT_NONE;
+using mindspore::schema::QuantType_QUANT_WEIGHT;
 using mindspore::schema::QuantType_WeightQuant;
 namespace mindspore::opt {
 namespace {
@@ -122,9 +124,10 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const CNodePtr &conv_node,
   }
   auto weight_node = conv_node->input(kConvWeightIndex);
   switch (this->quant_type) {
-    case QuantType_AwareTraining:
-    case QuantType_PostTraining:
     case QuantType_WeightQuant:
+    case QuantType_PostTraining:
+    case QuantType_QUANT_WEIGHT:
+    case QuantType_QUANT_ALL:
     case QuantType_QUANT_NONE: {
       // sum up from current ms quant models
       prim->AddAttr(opt::kWeightFormat, MakeValue<int64_t>(Format::KCHW));