forked from mindspore-Ecosystem/mindspore
add dtype_cast op between nodes
This commit is contained in:
parent
b6b254f6e4
commit
dbd606d17c
|
@ -67,23 +67,73 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = {
|
|||
|
||||
// List of primitive types named as int8 ops that need NHWC handling (presumably — inferred from the
// identifier only; confirm against its accessor). The list is currently empty.
static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> int8OpList = {
|
||||
schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw,
|
||||
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
|
||||
schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
|
||||
schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax,
|
||||
schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation,
|
||||
schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection,
|
||||
schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin,
|
||||
schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm,
|
||||
schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div,
|
||||
schema::PrimitiveType_Mul, schema::PrimitiveType_Slice,
|
||||
schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
|
||||
schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
|
||||
schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK,
|
||||
schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul,
|
||||
schema::PrimitiveType_Pad, schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_Scale};
|
||||
static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc,
|
||||
schema::PrimitiveType_Nhwc2Nchw,
|
||||
schema::PrimitiveType_Conv2D,
|
||||
schema::PrimitiveType_DepthwiseConv2D,
|
||||
schema::PrimitiveType_Add,
|
||||
schema::PrimitiveType_Pooling,
|
||||
schema::PrimitiveType_Concat,
|
||||
schema::PrimitiveType_SoftMax,
|
||||
schema::PrimitiveType_Reshape,
|
||||
schema::PrimitiveType_Activation,
|
||||
schema::PrimitiveType_Resize,
|
||||
schema::PrimitiveType_FullConnection,
|
||||
schema::PrimitiveType_ArgMax,
|
||||
schema::PrimitiveType_ArgMin,
|
||||
schema::PrimitiveType_BatchNorm,
|
||||
schema::PrimitiveType_FusedBatchNorm,
|
||||
schema::PrimitiveType_BiasAdd,
|
||||
schema::PrimitiveType_Div,
|
||||
schema::PrimitiveType_Mul,
|
||||
schema::PrimitiveType_Slice,
|
||||
schema::PrimitiveType_SoftMax,
|
||||
schema::PrimitiveType_Split,
|
||||
schema::PrimitiveType_Squeeze,
|
||||
schema::PrimitiveType_Sub,
|
||||
schema::PrimitiveType_StridedSlice,
|
||||
schema::PrimitiveType_TopK,
|
||||
schema::PrimitiveType_Unsqueeze,
|
||||
schema::PrimitiveType_MatMul,
|
||||
schema::PrimitiveType_Pad,
|
||||
schema::PrimitiveType_DeConv2D,
|
||||
schema::PrimitiveType_Scale,
|
||||
schema::PrimitiveType_Cast,
|
||||
schema::PrimitiveType_Shape,
|
||||
schema::PrimitiveType_ExpandDims,
|
||||
schema::PrimitiveType_BatchToSpace,
|
||||
schema::PrimitiveType_BatchToSpaceND,
|
||||
schema::PrimitiveType_Reduce,
|
||||
schema::PrimitiveType_Mean,
|
||||
schema::PrimitiveType_Round,
|
||||
schema::PrimitiveType_Floor,
|
||||
schema::PrimitiveType_Ceil,
|
||||
schema::PrimitiveType_Abs,
|
||||
schema::PrimitiveType_Sin,
|
||||
schema::PrimitiveType_Cos,
|
||||
schema::PrimitiveType_Log,
|
||||
schema::PrimitiveType_Sqrt,
|
||||
schema::PrimitiveType_Rsqrt,
|
||||
schema::PrimitiveType_Square,
|
||||
schema::PrimitiveType_LogicalNot,
|
||||
schema::PrimitiveType_SpaceToBatch,
|
||||
schema::PrimitiveType_SpaceToBatchND,
|
||||
schema::PrimitiveType_DepthToSpace,
|
||||
schema::PrimitiveType_Power,
|
||||
schema::PrimitiveType_GatherNd,
|
||||
schema::PrimitiveType_LeakyReLU,
|
||||
schema::PrimitiveType_Gather,
|
||||
schema::PrimitiveType_Equal,
|
||||
schema::PrimitiveType_NotEqual,
|
||||
schema::PrimitiveType_LessEqual,
|
||||
schema::PrimitiveType_Greater,
|
||||
schema::PrimitiveType_GreaterEqual,
|
||||
schema::PrimitiveType_Eltwise,
|
||||
schema::PrimitiveType_DeDepthwiseConv2D,
|
||||
schema::PrimitiveType_DetectionPostProcess,
|
||||
schema::PrimitiveType_Crop,
|
||||
schema::PrimitiveType_PriorBox,
|
||||
schema::PrimitiveType_QuantDTypeCast};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> needInsertOpList = {
|
||||
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,
|
||||
|
|
|
@ -40,6 +40,12 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
|
|||
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
status = DoNodeInoutDTypeTrans(graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
|
||||
return status;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -126,6 +132,51 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// Wraps aware-training nodes that are not in the int8 op list with dtype cast nodes.
//
// A node is processed only when its quantType is QuantType_AwareTraining and its primitive type
// is NOT contained in GetInt8OpList(); every other node is left untouched. For a processed node:
//   - an int8 -> fp32 trans node is inserted before each input tensor whose dataType is int8;
//   - an fp32 -> int8 trans node is inserted after each output tensor whose dataType is int8;
//   - the node's quantType is reset to QuantType_QUANT_NONE.
//
// @param graph  meta graph whose nodes are scanned and extended in place.
// @return RET_OK on success; RET_ERROR when graph is null, a processed node has no inputs,
//         or inserting a trans node fails.
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  // Runtime guard in addition to the assertion above, so a null graph is reported, not dereferenced.
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr";
    return RET_ERROR;
  }
  // insert transNode before and after existNode
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); ++iter) {
    if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) || (*iter)->quantType != QuantType_AwareTraining) {
      continue;
    }
    // Copied on purpose: iter is reassigned by the insertions below, the name is used for logging afterwards.
    auto nodeName = (*iter)->name;
    if ((*iter)->inputIndex.empty()) {
      MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
      return RET_ERROR;
    }
    // Initialised so the checks below never read an indeterminate value if the callee leaves it unset.
    STATUS status = RET_OK;
    // insert pre
    for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
      auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
      if (preTensor->dataType != TypeId::kNumberTypeInt8) {
        continue;
      }
      // InsertDTypeTransNode returns the iterator to continue with (presumably the position of the
      // existing node after the insertion — its definition is not visible here).
      iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
        return RET_ERROR;
      }
    }

    // insert post
    for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
      MS_ASSERT(graph->allTensors.size() > (*iter)->outputIndex.at(i));
      auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i));
      if (postTensor->dataType != TypeId::kNumberTypeInt8) {
        continue;
      }
      iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertFloat32ToInt8Node after " << nodeName.c_str() << " failed";
        return RET_ERROR;
      }
    }
    // The node now has cast nodes around its int8 tensors, so it is no longer marked as quantized.
    (*iter)->quantType = QuantType_QUANT_NONE;
  }

  return RET_OK;
}
|
||||
|
||||
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
|
||||
size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
|
||||
MS_ASSERT((*existNodeIter) != nullptr);
|
||||
|
|
|
@ -45,6 +45,7 @@ class DTypeTransPass : public GraphPass {
|
|||
|
||||
STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);
|
||||
|
||||
// Inserts int8<->fp32 dtype trans nodes around aware-training nodes whose op type is not in the
// int8 op list, then resets those nodes' quantType to QuantType_QUANT_NONE.
STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);
|
||||
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
|
||||
DTypeTransNodeType nodeType, STATUS *errorCode);
|
||||
|
||||
|
|
|
@ -30,6 +30,10 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
|
|||
if (node->quantType == schema::QuantType_WeightQuant) {
|
||||
continue;
|
||||
}
|
||||
DetermineNodeQuantType(*graph, node.get());
|
||||
if (node->quantType == schema::QuantType_AwareTraining) {
|
||||
continue;
|
||||
}
|
||||
if (GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMax ||
|
||||
GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) {
|
||||
MS_ASSERT(false);
|
||||
|
@ -38,14 +42,14 @@ STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
|
|||
if (quantParamCalcer == nullptr) {
|
||||
MS_LOG(WARNING) << "Can not find QuantParamCalcer for " << node->name.c_str()
|
||||
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
||||
node->quantType = static_cast<schema::QuantType>(schema::QuantType_QUANT_NONE);
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
auto status = quantParamCalcer->Calc(graph, *node);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(WARNING) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
|
||||
node->quantType = schema::QuantType_QUANT_NONE;
|
||||
} else {
|
||||
DetermineNodeQuantType(*graph, node.get());
|
||||
node->quantType = schema::QuantType_AwareTraining;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -77,7 +81,7 @@ void InferQuantParamPass::DetermineNodeQuantType(const schema::MetaGraphT &graph
|
|||
}
|
||||
}
|
||||
|
||||
if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*cnode))) {
|
||||
if (canQuant) {
|
||||
cnode->quantType = schema::QuantType_AwareTraining;
|
||||
} else {
|
||||
cnode->quantType = schema::QuantType_QUANT_NONE;
|
||||
|
|
Loading…
Reference in New Issue