fix bug in insert format transform node pass

This commit is contained in:
hangangqiang 2020-12-15 10:20:21 +08:00
parent 6e9fd1ef95
commit 59ca96c1f2
1 changed files with 4 additions and 0 deletions

View File

@ -415,6 +415,7 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
toAddTensor->nodeType = schema::NodeType_CNode;
preTensor->refCount = 0;
preTensor->data.clear();
MS_ASSERT(toAddNodeIn->primitive != nullptr);
@ -456,6 +457,7 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
MS_LOG(ERROR) << "Copy TensorT failed";
return graphT->nodes.end();
}
toAddTensor->nodeType = schema::NodeType_CNode;
MS_ASSERT(toAddNodeIn->primitive != nullptr);
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
@ -515,6 +517,7 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
toAddTensor->nodeType = schema::NodeType_CNode;
MS_ASSERT(toAddNodeIn->primitive != nullptr);
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
@ -559,6 +562,7 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
toAddTensor->nodeType = schema::NodeType_CNode;
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn.get());