!45490 modify tranform uint8 pass

Merge pull request !45490 from liyan2022/master_codex
This commit is contained in:
i-robot 2022-11-21 11:00:58 +00:00 committed by Gitee
commit 3c17465db9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 27 additions and 0 deletions

View File

@ -220,6 +220,12 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
if (opt::IsSpecialType(cnode) || CheckControlFlowType(cnode)) {
return false;
}
// if CastNode(U8toInt8 or Int8toU8), do nonthing
if (CheckCastNodeUint8Int8(cnode)) {
return false;
}
// If CastNode(kDeQuant) as graph input node, or CastNode(kQuant) as graph output node, do nothing.
CastNodeType cast_node_type = kNone;
auto status = quant::GetCastNodeType(func_graph_, cnode, &cast_node_type);
@ -251,6 +257,25 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
return true;
}
bool TransformUint8Pass::CheckCastNodeUint8Int8(const CNodePtr &cnode) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
if (prim == nullptr) {
MS_LOG(ERROR) << "Get prim from value node failed.";
return false;
}
auto primc = api::MakeShared<mindspore::ops::QuantDTypeCast>(prim);
MS_CHECK_TRUE_MSG(primc != nullptr, false, "cast ptr failed.");
auto src_type = primc->get_src_t();
auto dst_type = primc->get_dst_t();
if ((src_type == kNumberTypeUInt8 && dst_type == kNumberTypeInt8) ||
(src_type == kNumberTypeInt8 && dst_type == kNumberTypeUInt8)) {
return true;
}
}
return false;
}
bool TransformUint8Pass::IsSharedWeightParameter(const AnfNodePtr &anf_node) {
auto manager = this->func_graph_->manager();
if (manager == nullptr) {

View File

@ -51,6 +51,8 @@ class TransformUint8Pass {
bool IsSharedWeightParameter(const AnfNodePtr &anf_node);
bool CheckCastNodeUint8Int8(const CNodePtr &cnode);
FuncGraphPtr func_graph_ = nullptr;
// key is tensor_name