forked from mindspore-Ecosystem/mindspore
!45490 modify tranform uint8 pass
Merge pull request !45490 from liyan2022/master_codex
This commit is contained in:
commit
3c17465db9
|
@ -220,6 +220,12 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
|
|||
if (opt::IsSpecialType(cnode) || CheckControlFlowType(cnode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if CastNode(U8toInt8 or Int8toU8), do nonthing
|
||||
if (CheckCastNodeUint8Int8(cnode)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If CastNode(kDeQuant) as graph input node, or CastNode(kQuant) as graph output node, do nothing.
|
||||
CastNodeType cast_node_type = kNone;
|
||||
auto status = quant::GetCastNodeType(func_graph_, cnode, &cast_node_type);
|
||||
|
@ -251,6 +257,25 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TransformUint8Pass::CheckCastNodeUint8Int8(const CNodePtr &cnode) {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
auto prim = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
|
||||
if (prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Get prim from value node failed.";
|
||||
return false;
|
||||
}
|
||||
auto primc = api::MakeShared<mindspore::ops::QuantDTypeCast>(prim);
|
||||
MS_CHECK_TRUE_MSG(primc != nullptr, false, "cast ptr failed.");
|
||||
auto src_type = primc->get_src_t();
|
||||
auto dst_type = primc->get_dst_t();
|
||||
if ((src_type == kNumberTypeUInt8 && dst_type == kNumberTypeInt8) ||
|
||||
(src_type == kNumberTypeInt8 && dst_type == kNumberTypeUInt8)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool TransformUint8Pass::IsSharedWeightParameter(const AnfNodePtr &anf_node) {
|
||||
auto manager = this->func_graph_->manager();
|
||||
if (manager == nullptr) {
|
||||
|
|
|
@ -51,6 +51,8 @@ class TransformUint8Pass {
|
|||
|
||||
bool IsSharedWeightParameter(const AnfNodePtr &anf_node);
|
||||
|
||||
bool CheckCastNodeUint8Int8(const CNodePtr &cnode);
|
||||
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
|
||||
// key is tensor_name
|
||||
|
|
Loading…
Reference in New Issue