diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 583488d9ffa..d938f201487 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -102,6 +102,7 @@ #include "tools/converter/quantizer/quant_helper/quant_type_determiner.h" #include "tools/converter/quantizer/quant_helper/propagete_quant_param_pass.h" #include "tools/converter/quantizer/quant_helper/dtype_transform_pass.h" +#include "tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h" #include "tools/converter/quantizer/quant_helper/quant_node_pass.h" #include "tools/converter/quantizer/insert_quant_node_manager.h" @@ -445,7 +446,7 @@ STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::sha quant::InsertQuantNodeManager inset_quant_node_pass; ret = inset_quant_node_pass.InsertQuantDtypeCastNode(func_graph); if (ret != RET_OK) { - MS_LOG(ERROR) << "add QuantCast error"; + MS_LOG(ERROR) << "Add QuantCast error"; return RET_ERROR; } return RET_OK; diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc index e8b9b415c4b..ea8f0b89826 100644 --- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc +++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.cc @@ -15,12 +15,13 @@ */ #define USE_DEPRECATED_API + #include "mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h" #include #include #include #include -#include "ops/quant_dtype_cast.h" +#include #include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/format_utils.h" #include "tools/common/node_util.h" @@ -30,21 +31,24 @@ namespace { constexpr size_t kMinSize3 = 3; constexpr size_t kPrimitiveCOffset = 1; } // namespace -ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive( - int src_type, int dst_type, const std::vector 
&input_quant_params, - const std::vector &output_quant_params) { - auto prim_c = std::make_shared(); - MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr."); - prim_c->Init(src_type, dst_type); - auto quant_params_holder = std::make_shared(input_quant_params.size(), output_quant_params.size()); - MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr."); - quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL); - quant_params_holder->set_input_quant_param(0, input_quant_params); - quant_params_holder->set_output_quant_param(0, output_quant_params); - auto prim = prim_c->GetPrim(); - MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr"); - prim->AddAttr("quant_params", quant_params_holder); - return NewValueNode(prim); +int InsertQuantNodeManager::SetCastNodeAbstrac(const CNodePtr &cnode, const AnfNodePtr &input_node, + const CNodePtr &cast_cnode) { + CHECK_NULL_RETURN(cnode); + CHECK_NULL_RETURN(input_node); + CHECK_NULL_RETURN(cast_cnode); + + AbstractBasePtr abstract; + if (cnode->abstract() != nullptr) { + abstract = cnode->abstract()->Clone(); + } else if (input_node->abstract() != nullptr) { + abstract = input_node->abstract()->Clone(); + } else { + MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope() + << " input node: " << input_node->fullname_with_scope(); + return RET_NULL_PTR; + } + cast_cnode->set_abstract(abstract); + return RET_OK; } int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, @@ -95,47 +99,43 @@ int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNod << " input data type: " << data_type << " index " << input_index; return RET_ERROR; } - ValueNodePtr value_node; - // quant node, uint8toint8: update input_quant_param of cnode + + TypeId src_dtype; + TypeId dst_dtype; + std::vector input_quant_params; + std::vector output_quant_params; + + // quant node, 
uint8toint8: update output_quant_params if (insert_quant_node) { + src_dtype = data_type; + dst_dtype = kNumberTypeInt8; auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(primitive); if (curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) { MS_LOG(ERROR) << "quant param is invalid."; return RET_ERROR; } - auto output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1]; - std::vector input_quant_params(output_quant_params); - // Uint8toInt8 - if (data_type == kNumberTypeUInt8) { - for (auto &quant_param : output_quant_params) { - quant_param.zeroPoint -= kU8ZeroPointOffset; - } - curr_primitive_quant_param_holder->set_input_quant_param(input_index - 1, output_quant_params); - } - value_node = NewQuantCastPrimitive(data_type, kNumberTypeInt8, input_quant_params, output_quant_params); - } else { // insert_dequant_node, int8touint8: update output_quant_params of input_node + output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1]; + std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params)); + } else { // insert_dequant_node, int8touint8: update input_quant_params + src_dtype = kNumberTypeInt8; + dst_dtype = data_type; auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c); if (input_primitive_quant_param_holder->get_output_quant_params().empty()) { MS_LOG(ERROR) << "output quant param is empty."; return RET_ERROR; } - std::vector input_quant_params = - input_primitive_quant_param_holder->get_output_quant_params()[0]; - std::vector output_quant_params(input_quant_params); - if (data_type == kNumberTypeUInt8) { - for (auto &quant_param : input_quant_params) { - quant_param.zeroPoint -= kU8ZeroPointOffset; - } - input_primitive_quant_param_holder->set_output_quant_param(0, input_quant_params); - } - value_node = NewQuantCastPrimitive(kNumberTypeInt8, data_type, 
input_quant_params, output_quant_params); + input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0]; + std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params)); } - std::vector op_inputs = {value_node, input_node}; + ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params); + std::vector op_inputs = {new_primitive, input_node}; auto quant_cast_cnode = graph->NewCNode(op_inputs); MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr."); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + std::to_string(input_index)); cnode->set_input(input_index, quant_cast_cnode); + MS_LOG(INFO) << "InsertCastNode cnode name: " << quant_cast_cnode->fullname_with_scope() + << " src_dtype: " << src_dtype << " dst_dtype: " << dst_dtype; return RET_OK; } @@ -171,8 +171,8 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c TypeId type_id; auto ret = opt::GetDataTypeFromAnfNode(input_node, &type_id); if (ret != RET_OK) { - MS_LOG(ERROR) << "Fetch DataType from cnode failed."; - return ret; + MS_LOG(WARNING) << "Fetch DataType from cnode failed."; + return RET_OK; } if (type_id != check_type_id) { return RET_NO_CHANGE; @@ -181,9 +181,9 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c return RET_OK; } -int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId src_dtype) { +int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId cast_dtype) { MS_ASSERT(graph != nullptr); - if (src_dtype != kNumberTypeFloat32 && src_dtype != kNumberTypeUInt8) { + if (cast_dtype != kNumberTypeFloat32 && cast_dtype != kNumberTypeUInt8) { MS_LOG(ERROR) << "Invalid src dtype, only support fp32 and uint8."; return RET_ERROR; } @@ -192,11 +192,7 @@ int 
InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, for (size_t i = 1; i < cnode->inputs().size(); i++) { auto input_node = cnode->input(i); bool is_graph_input = IsGraphInput(input_node); - // Uint8 check quant params inited - if (src_dtype == kNumberTypeUInt8 && !is_graph_input && !CheckInited(input_node, i)) { - continue; - } - auto ret = CheckDataType(input_node, src_dtype); + auto ret = CheckDataType(input_node, cast_dtype); if (ret == RET_NO_CHANGE) { continue; } else if (ret != RET_OK) { @@ -312,4 +308,139 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph, } return RET_OK; } + +int InsertQuantNodeManager::InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, + InsertDirection insert_direction, TypeId cast_dtype, + CastNodeType cast_node_type, size_t index, + const AnfNodePtr &output_node) { + if (insert_direction == FORWARD && cast_node_type == kQuant) { + return InserForwardQuantCastNode(graph, cnode, cast_dtype, index); + } else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) { + return InserBackwardDeQuantCastNode(graph, cnode, cast_dtype, index, output_node); + } + MS_LOG(ERROR) << "Invalie insert direction: " << insert_direction; + return RET_NOT_SUPPORT; +} + +int InsertQuantNodeManager::InserForwardQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, + TypeId cast_dtype, size_t index) { + if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) { + MS_LOG(ERROR) << "Invalie cast dtype: " << cast_dtype; + return RET_NOT_SUPPORT; + } + + auto input_node = cnode->input(index); + CHECK_NULL_RETURN(input_node); + if (!input_node->isa() && !IsGraphInput(input_node)) { + MS_LOG(ERROR) << "Invalid input node, input node name: " << input_node->fullname_with_scope(); + return RET_ERROR; + } + auto ret = CheckDataType(input_node, cast_dtype); + if (ret == RET_NO_CHANGE) { + MS_LOG(DEBUG) << "input node not dtype: " << cast_dtype; + return RET_OK; + } 
else if (ret != RET_OK) { + MS_LOG(ERROR) << "Check data type failed, input node name: " << input_node->fullname_with_scope(); + return ret; + } + // insert forward cast_node + TypeId src_dtype = cast_dtype; + TypeId dst_dtype = kNumberTypeInt8; + std::vector input_quant_params; + std::vector output_quant_params; + + auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode); + CHECK_NULL_RETURN(curr_primitive_quant_param_holder); + if (curr_primitive_quant_param_holder->get_input_quant_params().size() < index) { + MS_LOG(ERROR) << "quant param is invalid."; + return RET_ERROR; + } + output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[index - 1]; + std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params)); + // Uint8toInt8 + if (src_dtype == kNumberTypeUInt8) { + for (auto &quant_param : input_quant_params) { + quant_param.zeroPoint += kU8ZeroPointOffset; + } + } + ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params); + std::vector op_inputs = {new_primitive, input_node}; + auto quant_cast_cnode = graph->NewCNode(op_inputs); + MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr."); + quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) + + "_pre"); + if (SetCastNodeAbstrac(cnode, input_node, quant_cast_cnode) != RET_OK) { + MS_LOG(ERROR) << "SetCastNodeAbstrac failed."; + return RET_ERROR; + } + if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) { + MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope(); + return RET_ERROR; + } + cnode->set_input(index, quant_cast_cnode); + MS_LOG(INFO) << "InserForwardQuantCastNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype + << " dst_type: " << dst_dtype; + return RET_OK; +} +int 
InsertQuantNodeManager::InserBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, + TypeId cast_dtype, size_t index, + const AnfNodePtr &output_node) { + if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) { + MS_LOG(ERROR) << "Invalie cast dtype: " << cast_dtype; + return RET_NOT_SUPPORT; + } + CHECK_NULL_RETURN(output_node); + auto ret = CheckDataType(output_node, cast_dtype); + if (ret == RET_NO_CHANGE) { + MS_LOG(DEBUG) << "input node not dtype: " << cast_dtype; + return RET_OK; + } else if (ret != RET_OK) { + MS_LOG(ERROR) << "Check data type failed, output node name: " << output_node->fullname_with_scope(); + return ret; + } + auto manager = graph->manager(); + if (manager == nullptr) { + manager = Manage(graph, true); + } + CHECK_NULL_RETURN(manager); + + // insert backward cast_node + TypeId src_dtype = kNumberTypeInt8; + TypeId dst_dtype = cast_dtype; + std::vector input_quant_params; + std::vector output_quant_params; + + auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode); + CHECK_NULL_RETURN(curr_primitive_quant_param_holder); + if (curr_primitive_quant_param_holder->get_output_quant_params().empty()) { + MS_LOG(ERROR) << "quant param is invalid."; + return RET_ERROR; + } + input_quant_params = curr_primitive_quant_param_holder->get_output_quant_params().front(); + std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params)); + // Int8toUint8 + if (dst_dtype == kNumberTypeUInt8) { + for (auto &quant_param : output_quant_params) { + quant_param.zeroPoint += kU8ZeroPointOffset; + } + } + ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params); + std::vector op_inputs = {new_primitive, cnode->cast()}; + auto quant_cast_cnode = graph->NewCNode(op_inputs); + MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr."); + 
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) + + "_post"); + if (SetCastNodeAbstrac(cnode, output_node, quant_cast_cnode) != RET_OK) { + MS_LOG(ERROR) << "SetCastNodeAbstrac failed."; + return RET_ERROR; + } + if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) { + MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope(); + return RET_ERROR; + } + manager->SetEdge(output_node, index, quant_cast_cnode); + MS_LOG(INFO) << "InserBackwardDeQuantCastNode cnode name: " << cnode->fullname_with_scope() + << " src dtype:" << src_dtype << " dst_type: " << dst_dtype; + return RET_OK; +} } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h index dbaa095f2f2..9b3171eff6c 100644 --- a/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h +++ b/mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h @@ -33,16 +33,14 @@ class InsertQuantNodeManager { ~InsertQuantNodeManager() = default; - int InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId src_dtype = kNumberTypeFloat32); + int InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId cast_dtype = kNumberTypeFloat32); int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set &support_dynamic_quant_ops, const std::set &skip_quant_node); + int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction, + TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node); private: - ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type, - const std::vector &input_quant_params, - const std::vector &output_quant_params); - int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, bool is_graph_input); bool CheckInited(const 
AnfNodePtr &input_node, size_t index) const; @@ -55,6 +53,13 @@ class InsertQuantNodeManager { int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index); + int SetCastNodeAbstrac(const CNodePtr &cnode, const AnfNodePtr &input_node, const CNodePtr &cast_cnode); + + int InserForwardQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index); + + int InserBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index, + const AnfNodePtr &output_node); + private: TypeId dst_type_ = kNumberTypeInt8; bool symmetric_ = false; diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.cc index d246539e250..ba23e604f30 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.cc @@ -15,75 +15,55 @@ */ #include "tools/converter/quantizer/quant_helper/dtype_transform_pass.h" +#include +#include +#include +#include +#include #include "tools/common/node_util.h" #include "tools/converter/quantizer/insert_quant_node_manager.h" #include "tools/converter/quantizer/quantize_util.h" +#include "tools/optimizer/common/format_utils.h" +#include "tools/optimizer/common/gllo_utils.h" namespace mindspore::lite::quant { // only enable for uint8 int DTypeTransformPass::Transform() { - CHECK_NULL_RETURN(func_graph_); - // insert CastNode Uint8toInt8 & Int8toUint8 - quant::InsertQuantNodeManager insert_quant_node_manager; - auto ret = insert_quant_node_manager.InsertQuantDtypeCastNode(func_graph_, kNumberTypeUInt8); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Insert Uint8toInt8 CastNode failed."; - return RET_ERROR; - } - - // update data type and zp auto cnodes = func_graph_->GetOrderedCnodes(); for (auto &cnode : cnodes) { - if 
(opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + if (!CheckNeedDTypeTrans(cnode)) { + MS_LOG(INFO) << "CheckNeedDTypeTrans invalid cnode, cnode name: " << cnode->fullname_with_scope(); continue; } - TypeId cnode_dtype; - if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) { - MS_LOG(ERROR) << "Get data type failed, cnode type: " << cnode->type_name(); + auto status = DoNodeDTypeTrans(cnode); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoNodeDTypeTrans failed, cnode name: " << cnode->fullname_with_scope(); + return status; + } + schema::QuantType curr_quant_type; + if (GetQuantType(cnode, &curr_quant_type) != RET_OK) { + MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope(); return RET_ERROR; } - if (cnode_dtype != kNumberTypeUInt8) { + if (curr_quant_type != schema::QuantType_QUANT_ALL) { + MS_LOG(ERROR) << "Invalid cnode quant type, cnode name: " << cnode->fullname_with_scope() + << " quant type: " << curr_quant_type; continue; } - if (UpdateDataType(cnode, kNumberTypeInt8) != RET_OK) { - MS_LOG(ERROR) << "Update data type failed, cnode name: " << cnode->fullname_with_scope(); - return RET_ERROR; + status = InsertForwardCastNode(cnode, curr_quant_type); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertForwardCastNode failed, cnode name: " << cnode->fullname_with_scope(); + return status; } - auto curr_quant_param_holder = GetCNodeQuantHolder(cnode); - CHECK_NULL_RETURN(curr_quant_param_holder); - if (curr_quant_param_holder->get_output_quant_params().empty()) { - MS_LOG(ERROR) << "output quant params empty."; - return RET_ERROR; - } - auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0]; - for (auto &quant_param : out_quant_params) { - quant_param.zeroPoint -= kU8ZeroPointOffset; - } - curr_quant_param_holder->set_output_quant_param(0, out_quant_params); - for (size_t i = 1; i < cnode->size(); i++) { - auto input_node = cnode->input(i); - CHECK_NULL_RETURN(input_node); - if 
(IsGraphInput(input_node)) { - continue; - } else if (input_node->isa()) { - // updata input_quant_params - if (curr_quant_param_holder->get_input_quant_params().size() < i) { - MS_LOG(ERROR) << "quant params invalid."; - return RET_ERROR; - } - auto input_quant_params = curr_quant_param_holder->get_input_quant_params()[i - 1]; - for (auto &quant_param : input_quant_params) { - quant_param.zeroPoint -= kU8ZeroPointOffset; - } - curr_quant_param_holder->set_input_quant_param(i - 1, input_quant_params); - } else if (input_node->isa()) { - ret = DoParameterNodeTrans(cnode, input_node->cast(), i); - if (ret != RET_OK) { - MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope(); - } + // DetectionPostProcess op(Uint8toFp32, not need backward cast node) + if (!CheckNodeInSet(cnode, kUint8toFP32Operator)) { + status = InsertBackwardCastNode(cnode, curr_quant_type); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertBackwardCastNode failed, cnode name: " << cnode->fullname_with_scope(); + return status; } } - } + } // for return RET_OK; } @@ -107,8 +87,11 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame return ret; } auto quant_param_holder = GetCNodeQuantHolder(cnode); + if (quant_param_holder->get_input_quant_params().size() < input_index) { + MS_LOG(ERROR) << "Invalid quant params. 
input node name: " << input_node->fullname_with_scope(); + return RET_ERROR; + } auto quant_params = quant_param_holder->get_input_quant_params().at(input_index - 1); - MS_CHECK_FALSE_MSG(quant_params.empty(), RET_ERROR, "Quant params is empty."); for (auto &quant_param : quant_params) { quant_param.zeroPoint -= kU8ZeroPointOffset; } @@ -121,18 +104,6 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame MS_LOG(ERROR) << input_node->fullname_with_scope() << " set new dtype failed."; return ret; } - - auto abstract_base = input_node->abstract(); - CHECK_NULL_RETURN(abstract_base); - if (!utils::isa(abstract_base)) { - MS_LOG(ERROR) << "Abstract of node should be abstract tensor, input node name: " - << input_node->fullname_with_scope(); - return RET_ERROR; - } - auto abstract_tensor = utils::cast(abstract_base); - CHECK_NULL_RETURN(abstract_tensor); - CHECK_NULL_RETURN(abstract_tensor->element()); - abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt8)); return RET_OK; } @@ -151,4 +122,164 @@ int DTypeTransformPass::Uint8toInt8(uint8_t *data, int size) { } return RET_OK; } + +int DTypeTransformPass::GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type) { + CHECK_NULL_RETURN(cnode); + auto quant_param_holder = GetCNodeQuantHolder(cnode); + CHECK_NULL_RETURN(quant_param_holder); + *quant_type = quant_param_holder->quant_type(); + return RET_OK; +} + +/** + * Transform CNode(dtype,uint8toint8,weigh data) + * */ +int DTypeTransformPass::DoNodeDTypeTrans(const CNodePtr &cnode) { + auto curr_quant_param_holder = GetCNodeQuantHolder(cnode); + CHECK_NULL_RETURN(curr_quant_param_holder); + TypeId cnode_dtype = kTypeUnknown; + if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) { + MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope(); + return RET_ERROR; + } + if (cnode_dtype == kNumberTypeUInt8) { + MS_LOG(INFO) << "cnode dtype kNumberTypeUInt8, cnode name: " << 
cnode->fullname_with_scope(); + if (UpdateDataType(cnode, kNumberTypeInt8) != RET_OK) { + MS_LOG(ERROR) << "Update data type failed, cnode name: " << cnode->fullname_with_scope(); + return RET_ERROR; + } + if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + auto primitive_c = GetValueNode>(cnode->input(0)); + auto primc = api::MakeShared(primitive_c); + primc->set_dst_t(kNumberTypeInt8); + } + // update output quant param zp + if (curr_quant_param_holder->get_output_quant_params().empty()) { + MS_LOG(ERROR) << "output quant params empty."; + return RET_ERROR; + } + auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0]; + for (auto &quant_param : out_quant_params) { + quant_param.zeroPoint -= kU8ZeroPointOffset; + } + curr_quant_param_holder->set_output_quant_param(0, out_quant_params); + } + + // DTypeCastNode, set quant type + if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + curr_quant_param_holder->set_quant_type(schema::QuantType_QUANT_NONE); + } + + for (size_t index = 1; index < cnode->size(); index++) { + auto input_node = cnode->input(index); + CHECK_NULL_RETURN(input_node); + if (IsGraphInput(input_node) || input_node->isa()) { + // updata graph input quant params + if (curr_quant_param_holder->get_input_quant_params().size() < index) { + MS_LOG(WARNING) << "quant params invalid, input node name: " << input_node->fullname_with_scope(); + continue; + } + auto input_quant_params = curr_quant_param_holder->get_input_quant_params()[index - 1]; + if (input_quant_params.empty() || !input_quant_params.front().inited) { + MS_LOG(WARNING) << "input node not quantizied, input node name: " << input_node->fullname_with_scope(); + continue; + } + for (auto &quant_param : input_quant_params) { + quant_param.zeroPoint -= kU8ZeroPointOffset; + } + curr_quant_param_holder->set_input_quant_param(index - 1, input_quant_params); + } else if (input_node->isa()) { // weight data + auto ret = DoParameterNodeTrans(cnode, 
input_node->cast(), index); + if (ret != RET_OK) { + MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope(); + } + } + } + return RET_OK; +} + +int DTypeTransformPass::InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) { + // inputs + quant::InsertQuantNodeManager insert_node_manager; + for (size_t index = 1; index < cnode->size(); index++) { + auto input_node = cnode->input(index); + CHECK_NULL_RETURN(input_node); + if (!input_node->isa() && !IsGraphInput(input_node)) { + MS_LOG(DEBUG) << "Invalid input node, not CNode and graph input."; + continue; + } + schema::QuantType input_quant_type; + if (GetQuantType(cnode, &input_quant_type) != RET_OK) { + MS_LOG(WARNING) << "Get quant type failed, input node name: " << input_node->fullname_with_scope(); + return RET_ERROR; + } + schema::QuantType pre_quant_type = schema::QuantType_QUANT_NONE; + if (input_node->isa()) { + if (GetQuantType(input_node->cast(), &pre_quant_type) != RET_OK) { + MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope(); + return RET_ERROR; + } + } + if (pre_quant_type == schema::QuantType_QUANT_NONE && curr_quant_type == schema::QuantType_QUANT_ALL) { + auto status = insert_node_manager.InserQuantCastNode(this->func_graph_, cnode, FORWARD, kNumberTypeUInt8, kQuant, + index, nullptr); + if (status != RET_OK) { + MS_LOG(ERROR) << "InserQuantCastNode failed, cnode name: " << cnode->fullname_with_scope(); + return status; + } + MS_LOG(INFO) << "InserQuantCastNode forward Uint8toInt8, cnode name: " << cnode->fullname_with_scope(); + } + } + return RET_OK; +} + +int DTypeTransformPass::InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) { + // outputs + auto manager = this->func_graph_->manager(); + if (manager == nullptr) { + manager = Manage(this->func_graph_, true); + } + CHECK_NULL_RETURN(manager); + auto node_users = manager->node_users()[cnode]; + 
quant::InsertQuantNodeManager insert_node_manager; + for (auto &node_user : node_users) { + auto output_cnode = node_user.first->cast(); + schema::QuantType post_quant_type; + if (GetQuantType(output_cnode, &post_quant_type) != RET_OK) { + MS_LOG(ERROR) << "Get quant type failed, cnode name: " << output_cnode->fullname_with_scope(); + return RET_ERROR; + } + if (curr_quant_type == schema::QuantType_QUANT_ALL && post_quant_type == schema::QuantType_QUANT_NONE) { + auto status = insert_node_manager.InserQuantCastNode(this->func_graph_, cnode, BACKWARD, kNumberTypeUInt8, + kDeQuant, node_user.second, node_user.first); + if (status != RET_OK) { + MS_LOG(ERROR) << "InserQuantCastNode dequant failed, cnode name: " << cnode->fullname_with_scope(); + return status; + } + MS_LOG(INFO) << "InserQuantCastNode backward Int8toUint8, cnode name: " << cnode->fullname_with_scope(); + } + } // node_users + return RET_OK; +} + +bool DTypeTransformPass::CheckNeedDTypeTrans(const CNodePtr &cnode) { + if (opt::IsSpecialType(cnode)) { + return false; + } + if (IsGraphInDTypeCast(cnode) || IsGraphOutDTypeCast(func_graph_, cnode)) { + return false; + } + TypeId cnode_dtype = kTypeUnknown; + if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) { + MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope(); + return false; + } + bool is_fp32_output = + opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast) || CheckNodeInSet(cnode, kUint8toFP32Operator); + if (cnode_dtype != kNumberTypeUInt8 && !is_fp32_output) { + MS_LOG(DEBUG) << "dtype not kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope(); + return false; + } + return true; +} } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.h b/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.h index 1613e71aa0d..c772b2f81e4 100644 --- 
a/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.h +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/dtype_transform_pass.h @@ -16,24 +16,43 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_ #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_ +#include +#include #include "ir/anf.h" #include "ir/func_graph.h" #include "ops/primitive_c.h" #include "tools/converter/quantizer/quant_param_holder.h" #include "tools/converter/quantizer/quantize_util.h" +#include "tools/converter/quantizer/quant_params.h" namespace mindspore::lite::quant { +/** + * Transform CNode(dtype,uint8toint8,weigh data) + * Insert QuantCastNode + * */ class DTypeTransformPass { public: explicit DTypeTransformPass(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {} + ~DTypeTransformPass() = default; int Transform(); private: int DoParameterNodeTrans(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index); + int Uint8toInt8(uint8_t *data, int size); + int GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type); + + int DoNodeDTypeTrans(const CNodePtr &cnode); + + bool CheckNeedDTypeTrans(const CNodePtr &cnode); + + int InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type); + + int InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type); + FuncGraphPtr func_graph_ = nullptr; }; } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/graph_inout_transform_pass.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/graph_inout_transform_pass.cc new file mode 100644 index 00000000000..1e1da0dd90d --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/graph_inout_transform_pass.cc @@ -0,0 +1,224 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not 
use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h" +#include +#include +#include +#include +#include +#include "tools/common/node_util.h" +#include "tools/converter/quantizer/insert_quant_node_manager.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "tools/optimizer/common/format_utils.h" +#include "ops/quant_dtype_cast.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore::lite::quant { +int GraphInoutTransformPass::Transform() { + auto ret = DoGraphInputDTypeTransform(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoGraphInputDTypeTransform failed."; + return ret; + } + ret = DoGraphOutputDTypeTransform(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoGraphOutputDTypeTransform failed."; + return ret; + } + return RET_OK; +} + +int GraphInoutTransformPass::DoGraphInputDTypeTransform() { + CHECK_NULL_RETURN(this->func_graph_); + if (this->graph_input_dtype_ != TypeId::kNumberTypeFloat32 && this->graph_input_dtype_ != TypeId::kNumberTypeUInt8 && + this->graph_input_dtype_ != TypeId::kNumberTypeInt8 && this->graph_input_dtype_ != TypeId::kTypeUnknown) { + MS_LOG(ERROR) << "Invalid graph input dtype: " << this->graph_input_dtype_; + return RET_ERROR; + } + // not specify inputDataType + if (this->graph_input_dtype_ == TypeId::kTypeUnknown) { + return RET_OK; + } + auto cnodes = this->func_graph_->GetOrderedCnodes(); + for (auto &cnode : cnodes) { + for (size_t index = 1; index < cnode->size(); index++) { + auto input_node = 
cnode->input(index); + CHECK_NULL_RETURN(input_node); + if (!IsGraphInput(input_node)) { + continue; + } + TypeId input_node_dtype = TypeId::kTypeUnknown; + if (opt::GetDataTypeFromAnfNode(input_node, &input_node_dtype) != RET_OK) { + MS_LOG(ERROR) << "GetDataTypeFromAnfNode failed, input node name: " << input_node->fullname_with_scope(); + return RET_ERROR; + } + // graph input dtype transform + if (this->graph_input_dtype_ != input_node_dtype) { + auto ret = InsertDTypeCastNode(cnode, index, this->graph_input_dtype_, input_node_dtype, kQuant); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InsertDTypeCastNode failed, cnode name: " << cnode->fullname_with_scope() + << " input index: " << index; + return ret; + } + } + } // for + } // for + return RET_OK; +} +/* Casts the graph outputs to graph_output_dtype_. Only float32, uint8, int8 and kTypeUnknown are accepted; kTypeUnknown means nothing was requested and the function returns RET_OK without touching the graph. For every Return node, each input is handed to DoOutputTransform; when that input is a MakeTuple, each of the MakeTuple's inputs is handed over instead. Returns RET_OK on success, RET_ERROR for an unsupported dtype, or the failing DoOutputTransform status. */ +int GraphInoutTransformPass::DoGraphOutputDTypeTransform() { + CHECK_NULL_RETURN(this->func_graph_); + if (this->graph_output_dtype_ != TypeId::kNumberTypeFloat32 && + this->graph_output_dtype_ != TypeId::kNumberTypeUInt8 && this->graph_output_dtype_ != TypeId::kNumberTypeInt8 && + this->graph_output_dtype_ != TypeId::kTypeUnknown) { + /* Log the output dtype that was rejected; the input dtype is unrelated to this check. */ MS_LOG(ERROR) << "Invalid graph output dtype: " << this->graph_output_dtype_; + return RET_ERROR; + } + if (this->graph_output_dtype_ == TypeId::kTypeUnknown) { + return RET_OK; + } + auto cnodes = this->func_graph_->GetOrderedCnodes(); + for (auto &cnode : cnodes) { + if (!IsGraphOutput(cnode)) { + continue; + } + for (size_t index = 1; index < cnode->size(); index++) { + auto input_node = cnode->input(index); + CHECK_NULL_RETURN(input_node); + auto input_cnode = std::dynamic_pointer_cast(input_node); + if (opt::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) { // MakeTuple + for (size_t input_index = 1; input_index < input_cnode->size(); input_index++) { + auto ret = DoOutputTransform(input_cnode, input_index); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoOutputTransform failed, cnode name: " << input_node->fullname_with_scope() + << " input index: " << input_index; + 
return ret; + } + } + } else { // Return + auto ret = DoOutputTransform(cnode, index); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoOutputTransform failed, cnode name: " << cnode->fullname_with_scope() + << " input index: " << index; + return ret; + } + } + } // for + } // for + return RET_OK; +} +/* Handles one graph-output edge: input `index` of cnode. Inputs that fail the isa check are left untouched and RET_OK is returned. Otherwise the input's data type is read and, when it differs from graph_output_dtype_, a kDeQuant cast from that type to graph_output_dtype_ is inserted through InsertDTypeCastNode. Returns RET_ERROR when the data type cannot be read, or the InsertDTypeCastNode status when the insertion fails. */ +int GraphInoutTransformPass::DoOutputTransform(const CNodePtr &cnode, size_t index) { + CHECK_NULL_RETURN(cnode); + auto input_node = cnode->input(index); + CHECK_NULL_RETURN(input_node); + if (!input_node->isa()) { + return RET_OK; + } + TypeId input_node_dtype = TypeId::kTypeUnknown; + if (opt::GetDataTypeFromAnfNode(input_node, &input_node_dtype) != RET_OK) { + MS_LOG(ERROR) << "GetDataTypeFromAnfNode failed, input node name: " << input_node->fullname_with_scope(); + return RET_ERROR; + } + // graph output dtype transform + if (this->graph_output_dtype_ != input_node_dtype) { + auto ret = InsertDTypeCastNode(cnode, index, input_node_dtype, this->graph_output_dtype_, kDeQuant); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InsertDTypeCastNode failed, cnode name: " << cnode->fullname_with_scope() + << " input index: " << index; + return ret; + } + } + return RET_OK; +} +/* Reports whether opt::CheckPrimitiveType matches cnode against prim::kPrimReturn. */ +bool GraphInoutTransformPass::IsGraphOutput(const CNodePtr &cnode) { + return (opt::CheckPrimitiveType(cnode, prim::kPrimReturn)); +} +/* Inserts a cast node from src_dtype to dst_dtype between cnode and its input at input_index. For kQuant the cast's quant params are taken from slot input_index - 1 of cnode's own input quant params; for kDeQuant they come from the first output quant params of the input node. */ +int GraphInoutTransformPass::InsertDTypeCastNode(const CNodePtr &cnode, size_t input_index, TypeId src_dtype, + TypeId dst_dtype, CastNodeType node_type) { + std::vector input_quant_params; + std::vector output_quant_params; + CHECK_NULL_RETURN(cnode); + auto input_node = cnode->input(input_index); + CHECK_NULL_RETURN(input_node); + if (node_type == kQuant) { + auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode); + CHECK_NULL_RETURN(curr_primitive_quant_param_holder); + if (curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) { + MS_LOG(ERROR) << "Invalid quant params."; + return RET_ERROR; + } + input_quant_params = 
curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1]; + } else if (node_type == kDeQuant) { + auto input_cnode = std::dynamic_pointer_cast(input_node); + auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode); + CHECK_NULL_RETURN(input_primitive_quant_param_holder); + if (input_primitive_quant_param_holder->get_output_quant_params().empty()) { + MS_LOG(ERROR) << "Invalid quant params."; + return RET_ERROR; + } + input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0]; + } + std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params)); + // update zeroPoint(uint8toint8, int8touint8) + /* front() must not be called on an empty vector, so a uint8/int8 cast that has no quant params is rejected before the zero point is shifted. output_quant_params is a copy of input_quant_params, so one emptiness check covers both. */ + const bool is_uint8_int8_cast = (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) || + (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8); + if (is_uint8_int8_cast && input_quant_params.empty()) { + MS_LOG(ERROR) << "Quant params is empty, cnode name: " << cnode->fullname_with_scope(); + return RET_ERROR; + } + if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) { + output_quant_params.front().zeroPoint -= kU8ZeroPointOffset; + } else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) { + input_quant_params.front().zeroPoint -= kU8ZeroPointOffset; + } + + auto primitive = quant::NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params); + /* NewQuantCastPrimitive returns nullptr when it cannot build the primitive; do not feed a null input to NewCNode. */ CHECK_NULL_RETURN(primitive); + std::vector op_inputs = {primitive, input_node}; + auto quant_cast_cnode = this->func_graph_->NewCNode(op_inputs); + CHECK_NULL_RETURN(quant_cast_cnode); + quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + + std::to_string(input_index)); + + // set quant_cast_cnode quant type + AbstractBasePtr abstract; + if (cnode->abstract() != nullptr) { + abstract = cnode->abstract()->Clone(); + } else if (input_node->abstract() != nullptr) { + abstract = input_node->abstract()->Clone(); + } else { + MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope() + << " input node: " << input_node->fullname_with_scope(); + return RET_NULL_PTR; + } + quant_cast_cnode->set_abstract(abstract); + auto quant_cast_cnode_param_holder = GetCNodeQuantHolder(quant_cast_cnode); + 
CHECK_NULL_RETURN(quant_cast_cnode_param_holder); + quant_cast_cnode_param_holder->set_quant_type(schema::QuantType_QUANT_NONE); + + // update dtype: input_cnode, quant_cast_cnode + auto ret = quant::UpdateDataType(input_node, src_dtype); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UpdateDataType failed, input node name: " << input_node->fullname_with_scope(); + return RET_ERROR; + } + cnode->set_input(input_index, quant_cast_cnode); + + ret = quant::UpdateDataType(std::dynamic_pointer_cast(quant_cast_cnode), dst_dtype); + if (ret != RET_OK) { + MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope(); + return RET_ERROR; + } + return RET_OK; +} +} // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h b/mindspore/lite/tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h new file mode 100644 index 00000000000..12a5f4cde07 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h @@ -0,0 +1,59 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_ + +#include +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ops/primitive_c.h" +#include "tools/converter/quantizer/quant_param_holder.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "tools/converter/quantizer/quant_params.h" + +namespace mindspore::lite::quant { +class GraphInoutTransformPass { + public: + explicit GraphInoutTransformPass(const FuncGraphPtr &func_graph, const std::shared_ptr param) + : func_graph_(func_graph), param_(param) { + CHECK_NULL_RETURN_VOID(this->param_); + this->graph_input_dtype_ = static_cast(this->param_->input_data_type); + this->graph_output_dtype_ = static_cast(this->param_->output_data_type); + } + + ~GraphInoutTransformPass() = default; + + int Transform(); + + private: + int DoGraphInputDTypeTransform(); + + int DoGraphOutputDTypeTransform(); + + int InsertDTypeCastNode(const CNodePtr &cnode, size_t input_index, TypeId src_dtype, TypeId dst_dtype, + CastNodeType node_type); + + int DoOutputTransform(const CNodePtr &cnode, size_t index); + + bool IsGraphOutput(const CNodePtr &cnode); + + FuncGraphPtr func_graph_ = nullptr; + std::shared_ptr param_ = nullptr; + TypeId graph_input_dtype_ = TypeId::kTypeUnknown; + TypeId graph_output_dtype_ = TypeId::kTypeUnknown; +}; +} // namespace mindspore::lite::quant +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_ diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/propagete_quant_param_pass.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/propagete_quant_param_pass.cc index fb10cb1794a..7613506bd62 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/propagete_quant_param_pass.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/propagete_quant_param_pass.cc @@ -118,6 +118,9 @@ 
int PropagateQuantParamPass::ForwardPropagate(const std::list &nodes) if (IsGraphInput(cnode) || opt::IsSpecialType(cnode) || opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) { continue; } + if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) { + continue; + } // Infer quant param with forward (output->input). auto curr_quant_holder = GetCNodeQuantHolder(cnode); if (curr_quant_holder == nullptr) { @@ -182,6 +185,9 @@ int PropagateQuantParamPass::BackwardPropagate(const std::list &nodes) if (IsGraphInput(cnode) || opt::IsSpecialType(cnode) || opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) { continue; } + if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) { + continue; + } // Infer quant param with forward (output<-input). auto curr_quant_holder = GetCNodeQuantHolder(cnode); if (curr_quant_holder == nullptr) { @@ -214,7 +220,7 @@ int PropagateQuantParamPass::BackwardPropagate(const std::list &nodes) } } } else { - MS_LOG(ERROR) << "Support for multi output."; + MS_LOG(INFO) << "Support for multi output."; } } } diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc index 4752c1c204e..570d0a8d37e 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.cc @@ -41,6 +41,10 @@ int QuantNodePass::DoWeightQuant(const CNodePtr &cnode) { MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight"; continue; } + if (!CanTensorQuantized(cnode, input)) { + MS_LOG(INFO) << input->fullname_with_scope() << " is not quantized."; + continue; + } int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(weight->shape())); MS_CHECK_GT(static_cast(quant_param_holder->get_input_quant_params().size()), static_cast(idx - 1), RET_ERROR); @@ -265,6 +269,40 @@ int QuantNodePass::DoFullQuant(const CNodePtr &cnode) { return RET_OK; } 
+/* Reports whether the weight behind input_node may be quantized. Returns true only when input_node is non-null, yields a non-null ParameterPtr with has_default(), has a non-null abstract whose shape track passes the utils::isa check, and has a shape with at least DIMENSION_2D entries. Every rejection is logged at INFO level except the final dimension check. The cnode argument is not read by this function. */ bool QuantNodePass::CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node) { + if (input_node == nullptr) { + MS_LOG(INFO) << "CanTensorQuantized input is nullptr!"; + return false; + } + ParameterPtr param_node = nullptr; + if (input_node->isa()) { + param_node = input_node->cast(); + } + if (param_node == nullptr) { + MS_LOG(INFO) << "CanTensorQuantized invalid param_node!"; + return false; + } + if (!param_node->has_default()) { + MS_LOG(INFO) << "param_node don't has default."; + return false; + } + auto abstract_base = param_node->abstract(); + if (abstract_base == nullptr) { + MS_LOG(INFO) << "abstract is nullptr"; + return false; + } + if (!utils::isa(abstract_base->GetShapeTrack())) { + MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << param_node->name(); + return false; + } + auto weight_shape = utils::cast(abstract_base->GetShapeTrack())->shape(); + /* NOTE(review): weight_shape is accessed with '.' on the next line, so it does not look like a pointer; this nullptr comparison appears meaningless. Confirm the type of shape() and drop the assert if so. */ MS_ASSERT(weight_shape != nullptr); + if (weight_shape.size() < DIMENSION_2D) { // do not quant single dim tensors + return false; + } + return true; +} + int QuantNodePass::Quant() { CHECK_NULL_RETURN(func_graph_); auto cnodes = func_graph_->GetOrderedCnodes(); diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.h b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.h index 7479cd0eae2..092df09a0e3 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.h +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_node_pass.h @@ -47,6 +47,7 @@ class QuantNodePass { int DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index); int DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &input_node, size_t input_index); int IsSupportWeightQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index); + bool CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node); private: FuncGraphPtr func_graph_ = nullptr; diff 
--git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc index 9a1df763596..0f585efb132 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.cc @@ -26,9 +26,6 @@ #include "tools/common/node_util.h" namespace mindspore::lite::quant { -namespace { -static const std::set fp32_output_operator = {prim::kPrimDetectionPostProcess}; -} bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) { MS_ASSERT(cnode != nullptr); if (opt::IsSpecialType(cnode)) { @@ -56,15 +53,28 @@ bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) { if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) { return quant_holder->quant_type() == schema::QuantType_QUANT_ALL; } + // if DTypeCastNode is graph input or output, set QuantType_QUANT_NONE + if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + return (!IsGraphInDTypeCast(cnode) && !IsGraphOutDTypeCast(func_graph_, cnode)); + } // Check input quant params, quantization parameters exist for all activations. for (size_t i = 1; i < cnode->size(); ++i) { auto input = cnode->input(i); + TypeId input_node_dtype = kTypeUnknown; + if (opt::GetDataTypeFromAnfNode(input, &input_node_dtype) != RET_OK) { + MS_LOG(INFO) << "Get data type failed, input node name: " << input->fullname_with_scope(); + continue; + } + // Only specified dtype need check input quant params + if (kFullQuantDType.find(input_node_dtype) == kFullQuantDType.end()) { + continue; + } if (input->isa() && !quant_holder->CheckInit(i - kPrimOffset, true)) { return false; } } // Check output quant params. 
- if (CheckNodeInSet(cnode, fp32_output_operator) && !quant_holder->IsOutputQuantParamsInited()) { + if (!CheckNodeInSet(cnode, kUint8toFP32Operator) && !quant_holder->IsOutputQuantParamsInited()) { return false; } return true; @@ -118,7 +128,7 @@ int QuantTypeDeterminer::Determine() { MS_LOG(INFO) << cnode->fullname_with_scope() << " quant holder is nullptr."; continue; } - if (!quant_holder->IsInputQuantParamsInited() && !quant_holder->IsOutputQuantParamsInited()) { // Check FP32. + if (!quant_holder->IsInputExistInited() && !quant_holder->IsOutputExistInited()) { // Check FP32. if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { continue; } diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h index 7819bc29529..355b1e105c8 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/quant_type_determiner.h @@ -20,7 +20,6 @@ #include "ir/anf.h" #include "ir/func_graph.h" #include "tools/converter/quantizer/quantize_util.h" -#include "mindspore/core/ops/core_ops.h" namespace mindspore::lite::quant { class QuantTypeDeterminer { diff --git a/mindspore/lite/tools/converter/quantizer/quant_param_holder.cc b/mindspore/lite/tools/converter/quantizer/quant_param_holder.cc index b0c2b679099..7b16e4ee5a6 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_param_holder.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_param_holder.cc @@ -63,6 +63,26 @@ bool QuantParamHolder::IsOutputQuantParamsInited() { return is_quant_params_inited; } +/* Returns true when at least one input tensor's first quant param is marked inited. Returns false when there are no input quant params at all. */ bool QuantParamHolder::IsInputExistInited() { + if (this->input_quant_params_.empty()) { + return false; + } + bool is_exist_param_inited = + std::any_of(this->input_quant_params_.begin(), this->input_quant_params_.end(), + [](const std::vector &quant_params) { /* An empty per-tensor list has no front(); count it as not inited instead of invoking undefined behaviour. */ return !quant_params.empty() && quant_params.front().inited; }); + return 
is_exist_param_inited; +} +/* Returns true when at least one output tensor's first quant param is marked inited. Returns false when there are no output quant params at all. */ +bool QuantParamHolder::IsOutputExistInited() { + if (this->output_quant_params_.empty()) { + return false; + } + bool is_exist_param_inited = + std::any_of(this->output_quant_params_.begin(), this->output_quant_params_.end(), + [](const std::vector &quant_params) { /* An empty per-tensor list has no front(); count it as not inited instead of invoking undefined behaviour. */ return !quant_params.empty() && quant_params.front().inited; }); + return is_exist_param_inited; +} + void QuantParamHolder::ClearQuantParams() { quant_type_ = schema::QuantType_QUANT_NONE; input_quant_params_.clear(); diff --git a/mindspore/lite/tools/converter/quantizer/quant_param_holder.h b/mindspore/lite/tools/converter/quantizer/quant_param_holder.h index 4c773f9d940..66b7767b367 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_param_holder.h +++ b/mindspore/lite/tools/converter/quantizer/quant_param_holder.h @@ -110,6 +110,10 @@ class QuantParamHolder : public Value { bool IsOutputQuantParamsInited(); + bool IsInputExistInited(); + + bool IsOutputExistInited(); + void ClearQuantParams(); bool CheckInit(size_t index, bool is_input); diff --git a/mindspore/lite/tools/converter/quantizer/quant_params.h b/mindspore/lite/tools/converter/quantizer/quant_params.h index a91fca42a55..870b8bd8526 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_params.h +++ b/mindspore/lite/tools/converter/quantizer/quant_params.h @@ -44,11 +44,15 @@ constexpr int kPrimIndex = 0; constexpr int kPrimOffset = 1; constexpr int kU8ZeroPointOffset = 128; constexpr int kQuantRange = 127; +constexpr int kInt8LeftRange = -128; +constexpr int kInt8RightRange = 127; constexpr int kMinIterations = 40; const std::set kHasBiasOperator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, prim::kPrimMatMulFusion, prim::kPrimFullConnection, prim::kPrimLayerNormFusion}; +const std::set kUint8toFP32Operator = {prim::kPrimDetectionPostProcess}; +const std::set kFullQuantDType = {kNumberTypeInt8, kNumberTypeUInt8, kNumberTypeFloat32}; enum ActivationQuantizedMethod { MAX_MIN = 0, @@ -68,6 +72,17 @@ enum 
DebugMode { DETAIL, }; +enum CastNodeType { + kNone, + kQuant, + kDeQuant, +}; + +enum InsertDirection { + FORWARD, + BACKWARD, +}; + struct CommonQuantParam { schema::QuantType quant_type = schema::QuantType_QUANT_NONE; int bit_num = 8; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 07ebc4c3378..097668fa80c 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -139,6 +139,51 @@ int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) { return RET_OK; } +/* Builds a value node wrapping a cast primitive initialised with (src_type, dst_type). A quant-param holder sized from the two vectors is marked QuantType_QUANT_ALL, receives input_quant_params and output_quant_params at index 0, and is attached to the primitive under the "quant_params" attribute. Returns nullptr when the primitive wrapper, the holder or GetPrim() is null. */ ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type, + const std::vector &input_quant_params, + const std::vector &output_quant_params) { + auto prim_c = std::make_shared(); + MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr."); + prim_c->Init(src_type, dst_type); + auto quant_params_holder = std::make_shared(input_quant_params.size(), output_quant_params.size()); + MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr."); + quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL); + quant_params_holder->set_input_quant_param(0, input_quant_params); + quant_params_holder->set_output_quant_param(0, output_quant_params); + auto prim = prim_c->GetPrim(); + MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr"); + prim->AddAttr("quant_params", quant_params_holder); + return NewValueNode(prim); +} +/* Returns true when cnode is a QuantDTypeCast whose input(1) satisfies IsGraphInput. Returns false for any other primitive and when input(1) is null. */ +bool IsGraphInDTypeCast(const CNodePtr &cnode) { + if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + return false; + } + auto input_node = cnode->input(1); + MS_CHECK_FALSE(input_node == nullptr, false); + return IsGraphInput(input_node); +} +/* Returns false unless cnode is a QuantDTypeCast; for a QuantDTypeCast the result depends on the users of cnode recorded by func_graph's manager. func_graph is dereferenced without a null check. */ +bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { + if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { + return false; + } + auto manager = 
func_graph->manager(); + if (manager == nullptr) { + manager = Manage(func_graph, true); + } + /* This function returns bool. CHECK_NULL_RETURN is the macro the int-returning functions of this patch use and presumably returns a non-zero status code, which would convert to true here; return an explicit false instead, as IsGraphInDTypeCast does. */ MS_CHECK_FALSE(manager == nullptr, false); + auto node_users = manager->node_users()[cnode]; + /* The cast counts as a graph-output cast only when every user of cnode is a Return primitive. */ + for (auto &node_user : node_users) { + auto output_cnode = node_user.first->cast(); + if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) { + return false; + } + } + return true; +} + bool TensorQuantParamsInited(const schema::TensorT &tensor) { if (tensor.quantParams.empty()) { return false; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 94164c95bed..667256f6a87 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -43,8 +43,13 @@ #include "tools/converter/quantizer/mixed_bit_weight_quantization.h" #include "tools/common/string_util.h" #include "ops/core_ops.h" +#include "ops/quant_dtype_cast.h" namespace mindspore::lite::quant { +static const std::set has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, + prim::kPrimMatMulFusion, prim::kPrimFullConnection, + prim::kPrimLayerNormFusion}; + QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive); QuantParamHolderPtr GetCNodeQuantHolder(const CNodePtr &cnode); @@ -71,6 +76,13 @@ void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set *all_f int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type); +ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type, + const std::vector &input_quant_params, + const std::vector &output_quant_params); +bool IsGraphInDTypeCast(const CNodePtr &cnode); + +bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + template int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector quant_params, std::vector *dequant_data, int preferred_dim = 0) {