From 5b9fd2d8bd20e3721ec03c0eb5b6baaf6e80b279 Mon Sep 17 00:00:00 2001 From: albert-yan Date: Sun, 20 Nov 2022 19:25:19 +0800 Subject: [PATCH] fix transform u8 pass --- .../quant_helper/transform_uint8_pass.cc | 25 +++++++++++++++++++ .../quant_helper/transform_uint8_pass.h | 2 ++ 2 files changed, 27 insertions(+) diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc index 16dd5bbaeaf..1ffb70ee5fe 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc @@ -220,6 +220,12 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) { if (opt::IsSpecialType(cnode) || CheckControlFlowType(cnode)) { return false; } + + // if CastNode(U8toInt8 or Int8toU8), do nonthing + if (CheckCastNodeUint8Int8(cnode)) { + return false; + } + // If CastNode(kDeQuant) as graph input node, or CastNode(kQuant) as graph output node, do nothing. CastNodeType cast_node_type = kNone; auto status = quant::GetCastNodeType(func_graph_, cnode, &cast_node_type); @@ -251,6 +257,25 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) { return true; } +bool TransformUint8Pass::CheckCastNodeUint8Int8(const CNodePtr &cnode) { + if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + auto prim = GetValueNode>(cnode->input(kPrimIndex)); + if (prim == nullptr) { + MS_LOG(ERROR) << "Get prim from value node failed."; + return false; + } + auto primc = api::MakeShared(prim); + MS_CHECK_TRUE_MSG(primc != nullptr, false, "cast ptr failed."); + auto src_type = primc->get_src_t(); + auto dst_type = primc->get_dst_t(); + if ((src_type == kNumberTypeUInt8 && dst_type == kNumberTypeInt8) || + (src_type == kNumberTypeInt8 && dst_type == kNumberTypeUInt8)) { + return true; + } + } + return false; +} + bool TransformUint8Pass::IsSharedWeightParameter(const AnfNodePtr &anf_node) { auto manager = this->func_graph_->manager(); if (manager == nullptr) { diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.h b/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.h index e901ace4c4b..272ed6b0208 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.h +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.h @@ -51,6 +51,8 @@ class TransformUint8Pass { bool IsSharedWeightParameter(const AnfNodePtr &anf_node); + bool CheckCastNodeUint8Int8(const CNodePtr &cnode); + FuncGraphPtr func_graph_ = nullptr; // key is tensor_name