diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index f6775e8b09a..7652004fde3 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -67,23 +67,73 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = { static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {}; -static const std::vector<schema::PrimitiveType> int8OpList = { - schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, - schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, - schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, - schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, - schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation, - schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection, - schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin, - schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, - schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div, - schema::PrimitiveType_Mul, schema::PrimitiveType_Slice, - schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, - schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, - schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK, - schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul, - schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D, - schema::PrimitiveType_Scale}; +static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc, + schema::PrimitiveType_Nhwc2Nchw, + schema::PrimitiveType_Conv2D, + schema::PrimitiveType_DepthwiseConv2D, + schema::PrimitiveType_Add, + schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, + schema::PrimitiveType_SoftMax, + schema::PrimitiveType_Reshape, + schema::PrimitiveType_Activation, + schema::PrimitiveType_Resize, + schema::PrimitiveType_FullConnection, + schema::PrimitiveType_ArgMax, + schema::PrimitiveType_ArgMin, + schema::PrimitiveType_BatchNorm, + schema::PrimitiveType_FusedBatchNorm, + 
schema::PrimitiveType_BiasAdd, + schema::PrimitiveType_Div, + schema::PrimitiveType_Mul, + schema::PrimitiveType_Slice, + schema::PrimitiveType_SoftMax, + schema::PrimitiveType_Split, + schema::PrimitiveType_Squeeze, + schema::PrimitiveType_Sub, + schema::PrimitiveType_StridedSlice, + schema::PrimitiveType_TopK, + schema::PrimitiveType_Unsqueeze, + schema::PrimitiveType_MatMul, + schema::PrimitiveType_Pad, + schema::PrimitiveType_DeConv2D, + schema::PrimitiveType_Scale, + schema::PrimitiveType_Cast, + schema::PrimitiveType_Shape, + schema::PrimitiveType_ExpandDims, + schema::PrimitiveType_BatchToSpace, + schema::PrimitiveType_BatchToSpaceND, + schema::PrimitiveType_Reduce, + schema::PrimitiveType_Mean, + schema::PrimitiveType_Round, + schema::PrimitiveType_Floor, + schema::PrimitiveType_Ceil, + schema::PrimitiveType_Abs, + schema::PrimitiveType_Sin, + schema::PrimitiveType_Cos, + schema::PrimitiveType_Log, + schema::PrimitiveType_Sqrt, + schema::PrimitiveType_Rsqrt, + schema::PrimitiveType_Square, + schema::PrimitiveType_LogicalNot, + schema::PrimitiveType_SpaceToBatch, + schema::PrimitiveType_SpaceToBatchND, + schema::PrimitiveType_DepthToSpace, + schema::PrimitiveType_Power, + schema::PrimitiveType_GatherNd, + schema::PrimitiveType_LeakyReLU, + schema::PrimitiveType_Gather, + schema::PrimitiveType_Equal, + schema::PrimitiveType_NotEqual, + schema::PrimitiveType_LessEqual, + schema::PrimitiveType_Greater, + schema::PrimitiveType_GreaterEqual, + schema::PrimitiveType_Eltwise, + schema::PrimitiveType_DeDepthwiseConv2D, + schema::PrimitiveType_DetectionPostProcess, + schema::PrimitiveType_Crop, + schema::PrimitiveType_PriorBox, + schema::PrimitiveType_QuantDTypeCast}; static const std::vector<schema::PrimitiveType> needInsertOpList = { schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc 
index 1084bcad91d..e5adca31c82 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -40,6 +40,12 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status; return status; } + + status = DoNodeInoutDTypeTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status; + return status; + } return RET_OK; } @@ -126,6 +132,51 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { return RET_OK; } +STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + // insert transNode before and after existNode + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) || (*iter)->quantType != QuantType_AwareTraining) { + continue; + } + auto nodeName = (*iter)->name; + if ((*iter)->inputIndex.empty()) { + MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; + return RET_ERROR; + } + STATUS status; + // insert pre + for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { + MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); + auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); + if (preTensor->dataType != TypeId::kNumberTypeInt8) { + continue; + } + iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; + return RET_ERROR; + } + } + + // insert post + for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { + auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); + if (postTensor->dataType != TypeId::kNumberTypeInt8) { + continue; + } + iter = InsertDTypeTransNode(graph, iter, kAfter, i, 
kFP32ToInt8, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertFloat32ToInt8Node after " << nodeName.c_str() << " failed"; + return RET_ERROR; + } + } + (*iter)->quantType = QuantType_QUANT_NONE; + } + + return RET_OK; +} + NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) { MS_ASSERT((*existNodeIter) != nullptr); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index d898c10eb30..f38ee93fed3 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -45,6 +45,7 @@ class DTypeTransPass : public GraphPass { STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph); + STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph); NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc index c49af020ece..f49df78530f 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infer_quant_param_pass.cc @@ -30,6 +30,10 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { if (node->quantType == schema::QuantType_WeightQuant) { continue; } + DetermineNodeQuantType(*graph, node.get()); + if (node->quantType == schema::QuantType_AwareTraining) { + continue; + } if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax || GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { MS_ASSERT(false); @@ -38,14 +42,14 @@ 
STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) { if (quantParamCalcer == nullptr) { MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str() << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; - node->quantType = static_cast<schema::QuantType>(schema::QuantType_QUANT_NONE); + node->quantType = schema::QuantType_QUANT_NONE; } else { auto status = quantParamCalcer->Calc(graph, *node); if (status != RET_OK) { MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; } else { - DetermineNodeQuantType(*graph, node.get()); + node->quantType = schema::QuantType_AwareTraining; } } } @@ -77,7 +81,7 @@ void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph } } - if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) { + if (canQuant) { cnode->quantType = schema::QuantType_AwareTraining; } else { cnode->quantType = schema::QuantType_QUANT_NONE;