add dtype_cast op between nodes

This commit is contained in:
cjh9368 2020-10-27 11:06:00 +08:00
parent b6b254f6e4
commit dbd606d17c
4 changed files with 126 additions and 20 deletions

View File

@ -67,23 +67,73 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = {
static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul,
schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_Scale};
static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc,
schema::PrimitiveType_Nhwc2Nchw,
schema::PrimitiveType_Conv2D,
schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add,
schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat,
schema::PrimitiveType_SoftMax,
schema::PrimitiveType_Reshape,
schema::PrimitiveType_Activation,
schema::PrimitiveType_Resize,
schema::PrimitiveType_FullConnection,
schema::PrimitiveType_ArgMax,
schema::PrimitiveType_ArgMin,
schema::PrimitiveType_BatchNorm,
schema::PrimitiveType_FusedBatchNorm,
schema::PrimitiveType_BiasAdd,
schema::PrimitiveType_Div,
schema::PrimitiveType_Mul,
schema::PrimitiveType_Slice,
schema::PrimitiveType_SoftMax,
schema::PrimitiveType_Split,
schema::PrimitiveType_Squeeze,
schema::PrimitiveType_Sub,
schema::PrimitiveType_StridedSlice,
schema::PrimitiveType_TopK,
schema::PrimitiveType_Unsqueeze,
schema::PrimitiveType_MatMul,
schema::PrimitiveType_Pad,
schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_Scale,
schema::PrimitiveType_Cast,
schema::PrimitiveType_Shape,
schema::PrimitiveType_ExpandDims,
schema::PrimitiveType_BatchToSpace,
schema::PrimitiveType_BatchToSpaceND,
schema::PrimitiveType_Reduce,
schema::PrimitiveType_Mean,
schema::PrimitiveType_Round,
schema::PrimitiveType_Floor,
schema::PrimitiveType_Ceil,
schema::PrimitiveType_Abs,
schema::PrimitiveType_Sin,
schema::PrimitiveType_Cos,
schema::PrimitiveType_Log,
schema::PrimitiveType_Sqrt,
schema::PrimitiveType_Rsqrt,
schema::PrimitiveType_Square,
schema::PrimitiveType_LogicalNot,
schema::PrimitiveType_SpaceToBatch,
schema::PrimitiveType_SpaceToBatchND,
schema::PrimitiveType_DepthToSpace,
schema::PrimitiveType_Power,
schema::PrimitiveType_GatherNd,
schema::PrimitiveType_LeakyReLU,
schema::PrimitiveType_Gather,
schema::PrimitiveType_Equal,
schema::PrimitiveType_NotEqual,
schema::PrimitiveType_LessEqual,
schema::PrimitiveType_Greater,
schema::PrimitiveType_GreaterEqual,
schema::PrimitiveType_Eltwise,
schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_DetectionPostProcess,
schema::PrimitiveType_Crop,
schema::PrimitiveType_PriorBox,
schema::PrimitiveType_QuantDTypeCast};
static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,

View File

@ -40,6 +40,12 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
return status;
}
status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}
return RET_OK;
}
@ -126,6 +132,51 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
return RET_OK;
}
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert transNode before and after existNode
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) || (*iter)->quantType != QuantType_AwareTraining) {
continue;
}
auto nodeName = (*iter)->name;
if ((*iter)->inputIndex.empty()) {
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
}
STATUS status;
// insert pre
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
if (preTensor->dataType != TypeId::kNumberTypeInt8) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
return RET_ERROR;
}
}
// insert post
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
if (postTensor->dataType != TypeId::kNumberTypeInt8) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
return RET_ERROR;
}
}
(*iter)->quantType = QuantType_QUANT_NONE;
}
return RET_OK;
}
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
MS_ASSERT((*existNodeIter) != nullptr);

View File

@ -45,6 +45,7 @@ class DTypeTransPass : public GraphPass {
STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);
STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
DTypeTransNodeType nodeType, STATUS *errorCode);

View File

@ -30,6 +30,10 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
if (node->quantType == schema::QuantType_WeightQuant) {
continue;
}
DetermineNodeQuantType(*graph, node.get());
if (node->quantType == schema::QuantType_AwareTraining) {
continue;
}
if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
MS_ASSERT(false);
@ -38,14 +42,14 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
if (quantParamCalcer == nullptr) {
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(schema::QuantType_QUANT_NONE);
node->quantType = schema::QuantType_QUANT_NONE;
} else {
auto status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
} else {
DetermineNodeQuantType(*graph, node.get());
node->quantType = schema::QuantType_AwareTraining;
}
}
}
@ -77,7 +81,7 @@ void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph
}
}
if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) {
if (canQuant) {
cnode->quantType = schema::QuantType_AwareTraining;
} else {
cnode->quantType = schema::QuantType_QUANT_NONE;