forked from mindspore-Ecosystem/mindspore
!10369 fix quant bug
From: @cjh9368 Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang
This commit is contained in:
commit
224ef73f7c
|
@ -446,6 +446,9 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
|
|||
MS_ASSERT(prim != nullptr);
|
||||
preTensor->dataType = prim->srcT;
|
||||
toAddTensor->dataType = prim->dstT;
|
||||
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
|
||||
preTensor->quantParams.front()->zeroPoint += 128;
|
||||
}
|
||||
}
|
||||
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
||||
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
||||
|
@ -486,6 +489,9 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
|
|||
MS_ASSERT(prim != nullptr);
|
||||
preTensor->dataType = prim->srcT;
|
||||
toAddTensor->dataType = prim->dstT;
|
||||
if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
|
||||
preTensor->quantParams.front()->zeroPoint += 128;
|
||||
}
|
||||
}
|
||||
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
||||
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
||||
|
@ -546,6 +552,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|||
MS_ASSERT(prim != nullptr);
|
||||
postTensor->dataType = prim->srcT;
|
||||
toAddTensor->dataType = prim->dstT;
|
||||
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
|
||||
toAddTensor->quantParams.front()->zeroPoint += 128;
|
||||
}
|
||||
}
|
||||
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
||||
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
||||
|
@ -613,6 +622,9 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
|
|||
MS_ASSERT(prim != nullptr);
|
||||
postTensor->dataType = prim->srcT;
|
||||
toAddTensor->dataType = prim->dstT;
|
||||
if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
|
||||
toAddTensor->quantParams.front()->zeroPoint += 128;
|
||||
}
|
||||
}
|
||||
graphT->allTensors.emplace_back(std::move(toAddTensor));
|
||||
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
|
||||
|
|
Loading…
Reference in New Issue