!40553 modify uint8 dtype transform pass

Merge pull request !40553 from liyan2022/dev_qat
This commit is contained in:
i-robot 2022-08-18 08:04:28 +00:00 committed by Gitee
commit efd05d4b71
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 846 additions and 126 deletions

View File

@ -102,6 +102,7 @@
#include "tools/converter/quantizer/quant_helper/quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/propagete_quant_param_pass.h"
#include "tools/converter/quantizer/quant_helper/dtype_transform_pass.h"
#include "tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h"
#include "tools/converter/quantizer/quant_helper/quant_node_pass.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
@ -445,7 +446,7 @@ STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::sha
quant::InsertQuantNodeManager inset_quant_node_pass;
ret = inset_quant_node_pass.InsertQuantDtypeCastNode(func_graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
MS_LOG(ERROR) << "Add QuantCast error";
return RET_ERROR;
}
return RET_OK;

View File

@ -15,12 +15,13 @@
*/
#define USE_DEPRECATED_API
#include "mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h"
#include <memory>
#include <set>
#include <vector>
#include <string>
#include "ops/quant_dtype_cast.h"
#include <algorithm>
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"
@ -30,21 +31,24 @@ namespace {
constexpr size_t kMinSize3 = 3;
constexpr size_t kPrimitiveCOffset = 1;
} // namespace
ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(
int src_type, int dst_type, const std::vector<schema::QuantParamT> &input_quant_params,
const std::vector<schema::QuantParamT> &output_quant_params) {
auto prim_c = std::make_shared<ops::QuantDTypeCast>();
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
prim_c->Init(src_type, dst_type);
auto quant_params_holder = std::make_shared<QuantParamHolder>(input_quant_params.size(), output_quant_params.size());
MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
quant_params_holder->set_input_quant_param(0, input_quant_params);
quant_params_holder->set_output_quant_param(0, output_quant_params);
auto prim = prim_c->GetPrim();
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
prim->AddAttr("quant_params", quant_params_holder);
return NewValueNode(prim);
int InsertQuantNodeManager::SetCastNodeAbstrac(const CNodePtr &cnode, const AnfNodePtr &input_node,
const CNodePtr &cast_cnode) {
CHECK_NULL_RETURN(cnode);
CHECK_NULL_RETURN(input_node);
CHECK_NULL_RETURN(cast_cnode);
AbstractBasePtr abstract;
if (cnode->abstract() != nullptr) {
abstract = cnode->abstract()->Clone();
} else if (input_node->abstract() != nullptr) {
abstract = input_node->abstract()->Clone();
} else {
MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope()
<< " input node: " << input_node->fullname_with_scope();
return RET_NULL_PTR;
}
cast_cnode->set_abstract(abstract);
return RET_OK;
}
int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index,
@ -95,47 +99,43 @@ int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNod
<< " input data type: " << data_type << " index " << input_index;
return RET_ERROR;
}
ValueNodePtr value_node;
// quant node, uint8toint8: update input_quant_param of cnode
TypeId src_dtype;
TypeId dst_dtype;
std::vector<schema::QuantParamT> input_quant_params;
std::vector<schema::QuantParamT> output_quant_params;
// quant node, uint8toint8: update output_quant_params
if (insert_quant_node) {
src_dtype = data_type;
dst_dtype = kNumberTypeInt8;
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(primitive);
if (curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) {
MS_LOG(ERROR) << "quant param is invalid.";
return RET_ERROR;
}
auto output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1];
std::vector<schema::QuantParamT> input_quant_params(output_quant_params);
// Uint8toInt8
if (data_type == kNumberTypeUInt8) {
for (auto &quant_param : output_quant_params) {
quant_param.zeroPoint -= kU8ZeroPointOffset;
}
curr_primitive_quant_param_holder->set_input_quant_param(input_index - 1, output_quant_params);
}
value_node = NewQuantCastPrimitive(data_type, kNumberTypeInt8, input_quant_params, output_quant_params);
} else { // insert_dequant_node, int8touint8: update output_quant_params of input_node
output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1];
std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params));
} else { // insert_dequant_node, int8touint8: update input_quant_params
src_dtype = kNumberTypeInt8;
dst_dtype = data_type;
auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
if (input_primitive_quant_param_holder->get_output_quant_params().empty()) {
MS_LOG(ERROR) << "output quant param is empty.";
return RET_ERROR;
}
std::vector<schema::QuantParamT> input_quant_params =
input_primitive_quant_param_holder->get_output_quant_params()[0];
std::vector<schema::QuantParamT> output_quant_params(input_quant_params);
if (data_type == kNumberTypeUInt8) {
for (auto &quant_param : input_quant_params) {
quant_param.zeroPoint -= kU8ZeroPointOffset;
}
input_primitive_quant_param_holder->set_output_quant_param(0, input_quant_params);
}
value_node = NewQuantCastPrimitive(kNumberTypeInt8, data_type, input_quant_params, output_quant_params);
input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0];
std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
}
std::vector<AnfNodePtr> op_inputs = {value_node, input_node};
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
auto quant_cast_cnode = graph->NewCNode(op_inputs);
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" +
std::to_string(input_index));
cnode->set_input(input_index, quant_cast_cnode);
MS_LOG(INFO) << "InsertCastNode cnode name: " << quant_cast_cnode->fullname_with_scope()
<< " src_dtype: " << src_dtype << " dst_dtype: " << dst_dtype;
return RET_OK;
}
@ -171,8 +171,8 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c
TypeId type_id;
auto ret = opt::GetDataTypeFromAnfNode(input_node, &type_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
return ret;
MS_LOG(WARNING) << "Fetch DataType from cnode failed.";
return RET_OK;
}
if (type_id != check_type_id) {
return RET_NO_CHANGE;
@ -181,9 +181,9 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c
return RET_OK;
}
int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId src_dtype) {
int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId cast_dtype) {
MS_ASSERT(graph != nullptr);
if (src_dtype != kNumberTypeFloat32 && src_dtype != kNumberTypeUInt8) {
if (cast_dtype != kNumberTypeFloat32 && cast_dtype != kNumberTypeUInt8) {
MS_LOG(ERROR) << "Invalid src dtype, only support fp32 and uint8.";
return RET_ERROR;
}
@ -192,11 +192,7 @@ int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph,
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
bool is_graph_input = IsGraphInput(input_node);
// Uint8 check quant params inited
if (src_dtype == kNumberTypeUInt8 && !is_graph_input && !CheckInited(input_node, i)) {
continue;
}
auto ret = CheckDataType(input_node, src_dtype);
auto ret = CheckDataType(input_node, cast_dtype);
if (ret == RET_NO_CHANGE) {
continue;
} else if (ret != RET_OK) {
@ -312,4 +308,139 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
}
return RET_OK;
}
int InsertQuantNodeManager::InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
InsertDirection insert_direction, TypeId cast_dtype,
CastNodeType cast_node_type, size_t index,
const AnfNodePtr &output_node) {
if (insert_direction == FORWARD && cast_node_type == kQuant) {
return InserForwardQuantCastNode(graph, cnode, cast_dtype, index);
} else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
return InserBackwardDeQuantCastNode(graph, cnode, cast_dtype, index, output_node);
}
MS_LOG(ERROR) << "Invalie insert direction: " << insert_direction;
return RET_NOT_SUPPORT;
}
int InsertQuantNodeManager::InserForwardQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
TypeId cast_dtype, size_t index) {
if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
MS_LOG(ERROR) << "Invalie cast dtype: " << cast_dtype;
return RET_NOT_SUPPORT;
}
auto input_node = cnode->input(index);
CHECK_NULL_RETURN(input_node);
if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
MS_LOG(ERROR) << "Invalid input node, input node name: " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto ret = CheckDataType(input_node, cast_dtype);
if (ret == RET_NO_CHANGE) {
MS_LOG(DEBUG) << "input node not dtype: " << cast_dtype;
return RET_OK;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "Check data type failed, input node name: " << input_node->fullname_with_scope();
return ret;
}
// insert forward cast_node
TypeId src_dtype = cast_dtype;
TypeId dst_dtype = kNumberTypeInt8;
std::vector<schema::QuantParamT> input_quant_params;
std::vector<schema::QuantParamT> output_quant_params;
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
if (curr_primitive_quant_param_holder->get_input_quant_params().size() < index) {
MS_LOG(ERROR) << "quant param is invalid.";
return RET_ERROR;
}
output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[index - 1];
std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params));
// Uint8toInt8
if (src_dtype == kNumberTypeUInt8) {
for (auto &quant_param : input_quant_params) {
quant_param.zeroPoint += kU8ZeroPointOffset;
}
}
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
auto quant_cast_cnode = graph->NewCNode(op_inputs);
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
"_pre");
if (SetCastNodeAbstrac(cnode, input_node, quant_cast_cnode) != RET_OK) {
MS_LOG(ERROR) << "SetCastNodeAbstrac failed.";
return RET_ERROR;
}
if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
return RET_ERROR;
}
cnode->set_input(index, quant_cast_cnode);
MS_LOG(INFO) << "InserForwardQuantCastNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
<< " dst_type: " << dst_dtype;
return RET_OK;
}
int InsertQuantNodeManager::InserBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
TypeId cast_dtype, size_t index,
const AnfNodePtr &output_node) {
if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
MS_LOG(ERROR) << "Invalie cast dtype: " << cast_dtype;
return RET_NOT_SUPPORT;
}
CHECK_NULL_RETURN(output_node);
auto ret = CheckDataType(output_node, cast_dtype);
if (ret == RET_NO_CHANGE) {
MS_LOG(DEBUG) << "input node not dtype: " << cast_dtype;
return RET_OK;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "Check data type failed, output node name: " << output_node->fullname_with_scope();
return ret;
}
auto manager = graph->manager();
if (manager == nullptr) {
manager = Manage(graph, true);
}
CHECK_NULL_RETURN(manager);
// insert backward cast_node
TypeId src_dtype = kNumberTypeInt8;
TypeId dst_dtype = cast_dtype;
std::vector<schema::QuantParamT> input_quant_params;
std::vector<schema::QuantParamT> output_quant_params;
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
if (curr_primitive_quant_param_holder->get_output_quant_params().empty()) {
MS_LOG(ERROR) << "quant param is invalid.";
return RET_ERROR;
}
input_quant_params = curr_primitive_quant_param_holder->get_output_quant_params().front();
std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
// Int8toUint8
if (dst_dtype == kNumberTypeUInt8) {
for (auto &quant_param : output_quant_params) {
quant_param.zeroPoint += kU8ZeroPointOffset;
}
}
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
std::vector<AnfNodePtr> op_inputs = {new_primitive, cnode->cast<AnfNodePtr>()};
auto quant_cast_cnode = graph->NewCNode(op_inputs);
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
"_post");
if (SetCastNodeAbstrac(cnode, output_node, quant_cast_cnode) != RET_OK) {
MS_LOG(ERROR) << "SetCastNodeAbstrac failed.";
return RET_ERROR;
}
if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
return RET_ERROR;
}
manager->SetEdge(output_node, index, quant_cast_cnode);
MS_LOG(INFO) << "InserBackwardDeQuantCastNode cnode name: " << cnode->fullname_with_scope()
<< " src dtype:" << src_dtype << " dst_type: " << dst_dtype;
return RET_OK;
}
} // namespace mindspore::lite::quant

View File

@ -33,16 +33,14 @@ class InsertQuantNodeManager {
~InsertQuantNodeManager() = default;
int InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId src_dtype = kNumberTypeFloat32);
int InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId cast_dtype = kNumberTypeFloat32);
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
const std::set<std::string> &skip_quant_node);
int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction,
TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node);
private:
ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
const std::vector<schema::QuantParamT> &input_quant_params,
const std::vector<schema::QuantParamT> &output_quant_params);
int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, bool is_graph_input);
bool CheckInited(const AnfNodePtr &input_node, size_t index) const;
@ -55,6 +53,13 @@ class InsertQuantNodeManager {
int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
int SetCastNodeAbstrac(const CNodePtr &cnode, const AnfNodePtr &input_node, const CNodePtr &cast_cnode);
int InserForwardQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index);
int InserBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index,
const AnfNodePtr &output_node);
private:
TypeId dst_type_ = kNumberTypeInt8;
bool symmetric_ = false;

View File

@ -15,75 +15,55 @@
*/
#include "tools/converter/quantizer/quant_helper/dtype_transform_pass.h"
#include <set>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include "tools/common/node_util.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::lite::quant {
// only enable for uint8
int DTypeTransformPass::Transform() {
CHECK_NULL_RETURN(func_graph_);
// insert CastNode Uint8toInt8 & Int8toUint8
quant::InsertQuantNodeManager insert_quant_node_manager;
auto ret = insert_quant_node_manager.InsertQuantDtypeCastNode(func_graph_, kNumberTypeUInt8);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Insert Uint8toInt8 CastNode failed.";
return RET_ERROR;
}
// update data type and zp
auto cnodes = func_graph_->GetOrderedCnodes();
for (auto &cnode : cnodes) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
if (!CheckNeedDTypeTrans(cnode)) {
MS_LOG(INFO) << "CheckNeedDTypeTrans invalid cnode, cnode name: " << cnode->fullname_with_scope();
continue;
}
TypeId cnode_dtype;
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
MS_LOG(ERROR) << "Get data type failed, cnode type: " << cnode->type_name();
auto status = DoNodeDTypeTrans(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeDTypeTrans failed, cnode name: " << cnode->fullname_with_scope();
return status;
}
schema::QuantType curr_quant_type;
if (GetQuantType(cnode, &curr_quant_type) != RET_OK) {
MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
return RET_ERROR;
}
if (cnode_dtype != kNumberTypeUInt8) {
if (curr_quant_type != schema::QuantType_QUANT_ALL) {
MS_LOG(ERROR) << "Invalid cnode quant type, cnode name: " << cnode->fullname_with_scope()
<< " quant type: " << curr_quant_type;
continue;
}
if (UpdateDataType(cnode, kNumberTypeInt8) != RET_OK) {
MS_LOG(ERROR) << "Update data type failed, cnode name: " << cnode->fullname_with_scope();
return RET_ERROR;
status = InsertForwardCastNode(cnode, curr_quant_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertForwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
return status;
}
auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(curr_quant_param_holder);
if (curr_quant_param_holder->get_output_quant_params().empty()) {
MS_LOG(ERROR) << "output quant params empty.";
return RET_ERROR;
}
auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0];
for (auto &quant_param : out_quant_params) {
quant_param.zeroPoint -= kU8ZeroPointOffset;
}
curr_quant_param_holder->set_output_quant_param(0, out_quant_params);
for (size_t i = 1; i < cnode->size(); i++) {
auto input_node = cnode->input(i);
CHECK_NULL_RETURN(input_node);
if (IsGraphInput(input_node)) {
continue;
} else if (input_node->isa<mindspore::CNode>()) {
// updata input_quant_params
if (curr_quant_param_holder->get_input_quant_params().size() < i) {
MS_LOG(ERROR) << "quant params invalid.";
return RET_ERROR;
}
auto input_quant_params = curr_quant_param_holder->get_input_quant_params()[i - 1];
for (auto &quant_param : input_quant_params) {
quant_param.zeroPoint -= kU8ZeroPointOffset;
}
curr_quant_param_holder->set_input_quant_param(i - 1, input_quant_params);
} else if (input_node->isa<mindspore::Parameter>()) {
ret = DoParameterNodeTrans(cnode, input_node->cast<ParameterPtr>(), i);
if (ret != RET_OK) {
MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope();
}
// DetectionPostProcess op(Uint8toFp32, not need backward cast node)
if (!CheckNodeInSet(cnode, kUint8toFP32Operator)) {
status = InsertBackwardCastNode(cnode, curr_quant_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertBackwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
return status;
}
}
}
} // for
return RET_OK;
}
@ -107,8 +87,11 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
return ret;
}
auto quant_param_holder = GetCNodeQuantHolder(cnode);
if (quant_param_holder->get_input_quant_params().size() < input_index) {
MS_LOG(ERROR) << "Invalid quant params. input node name: " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto quant_params = quant_param_holder->get_input_quant_params().at(input_index - 1);
MS_CHECK_FALSE_MSG(quant_params.empty(), RET_ERROR, "Quant params is empty.");
for (auto &quant_param : quant_params) {
quant_param.zeroPoint -= kU8ZeroPointOffset;
}
@ -121,18 +104,6 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
MS_LOG(ERROR) << input_node->fullname_with_scope() << " set new dtype failed.";
return ret;
}
auto abstract_base = input_node->abstract();
CHECK_NULL_RETURN(abstract_base);
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of node should be abstract tensor, input node name: "
<< input_node->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
CHECK_NULL_RETURN(abstract_tensor);
CHECK_NULL_RETURN(abstract_tensor->element());
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
return RET_OK;
}
@ -151,4 +122,164 @@ int DTypeTransformPass::Uint8toInt8(uint8_t *data, int size) {
}
return RET_OK;
}
/**
 * Read the quant type stored in the quant param holder of `cnode`.
 * On success the value is written to `*quant_type` and RET_OK is returned.
 * A null `cnode`, a null `quant_type` or a missing holder is rejected by CHECK_NULL_RETURN.
 * */
int DTypeTransformPass::GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type) {
  CHECK_NULL_RETURN(cnode);
  // The out-parameter is written below, so it must be validated like the node itself.
  CHECK_NULL_RETURN(quant_type);
  auto quant_param_holder = GetCNodeQuantHolder(cnode);
  CHECK_NULL_RETURN(quant_param_holder);
  *quant_type = quant_param_holder->quant_type();
  return RET_OK;
}
/**
 * Transform CNode(dtype,uint8toint8,weight data)
 *
 * Steps performed on `cnode`:
 *   1. If its data type is uint8: switch it to int8, retarget a QuantDTypeCast primitive to int8
 *      and shift the zero points of its first output quant params by -kU8ZeroPointOffset.
 *   2. If it is a QuantDTypeCast: mark its holder as QuantType_QUANT_NONE.
 *   3. For every input: shift the zero points of inited input quant params (CNode / graph input),
 *      or convert the weight through DoParameterNodeTrans (Parameter). Weight conversion failures
 *      are only logged as warnings.
 * Returns RET_OK on success, an error code otherwise.
 * */
int DTypeTransformPass::DoNodeDTypeTrans(const CNodePtr &cnode) {
  CHECK_NULL_RETURN(cnode);
  auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
  CHECK_NULL_RETURN(curr_quant_param_holder);
  TypeId cnode_dtype = kTypeUnknown;
  if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
    MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  // Evaluated once; it is needed both for retargeting the primitive and for the quant type.
  const bool is_quant_dtype_cast = opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast);
  if (cnode_dtype == kNumberTypeUInt8) {
    MS_LOG(INFO) << "cnode dtype kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
    if (UpdateDataType(cnode, kNumberTypeInt8) != RET_OK) {
      MS_LOG(ERROR) << "Update data type failed, cnode name: " << cnode->fullname_with_scope();
      return RET_ERROR;
    }
    if (is_quant_dtype_cast) {
      auto primitive_c = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(0));
      // Guard against a first input that does not hold a primitive.
      CHECK_NULL_RETURN(primitive_c);
      auto primc = api::MakeShared<mindspore::ops::QuantDTypeCast>(primitive_c);
      CHECK_NULL_RETURN(primc);
      primc->set_dst_t(kNumberTypeInt8);
    }
    // update output quant param zp
    if (curr_quant_param_holder->get_output_quant_params().empty()) {
      MS_LOG(ERROR) << "output quant params empty.";
      return RET_ERROR;
    }
    auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0];
    for (auto &quant_param : out_quant_params) {
      quant_param.zeroPoint -= kU8ZeroPointOffset;
    }
    curr_quant_param_holder->set_output_quant_param(0, out_quant_params);
  }

  // DTypeCastNode, set quant type
  if (is_quant_dtype_cast) {
    curr_quant_param_holder->set_quant_type(schema::QuantType_QUANT_NONE);
  }

  for (size_t index = 1; index < cnode->size(); index++) {
    auto input_node = cnode->input(index);
    CHECK_NULL_RETURN(input_node);
    if (IsGraphInput(input_node) || input_node->isa<mindspore::CNode>()) {
      // update activation (graph input / CNode output) quant params
      if (curr_quant_param_holder->get_input_quant_params().size() < index) {
        MS_LOG(WARNING) << "quant params invalid, input node name: " << input_node->fullname_with_scope();
        continue;
      }
      // Quant params are stored per data input, hence the index - 1 (input 0 is the primitive).
      auto input_quant_params = curr_quant_param_holder->get_input_quant_params()[index - 1];
      if (input_quant_params.empty() || !input_quant_params.front().inited) {
        MS_LOG(WARNING) << "input node not quantizied, input node name: " << input_node->fullname_with_scope();
        continue;
      }
      for (auto &quant_param : input_quant_params) {
        quant_param.zeroPoint -= kU8ZeroPointOffset;
      }
      curr_quant_param_holder->set_input_quant_param(index - 1, input_quant_params);
    } else if (input_node->isa<mindspore::Parameter>()) {  // weight data
      auto ret = DoParameterNodeTrans(cnode, input_node->cast<ParameterPtr>(), index);
      if (ret != RET_OK) {
        // Best effort: a weight that cannot be converted does not abort the pass.
        MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope();
      }
    }
  }
  return RET_OK;
}
int DTypeTransformPass::InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
// inputs
quant::InsertQuantNodeManager insert_node_manager;
for (size_t index = 1; index < cnode->size(); index++) {
auto input_node = cnode->input(index);
CHECK_NULL_RETURN(input_node);
if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
MS_LOG(DEBUG) << "Invalid input node, not CNode and graph input.";
continue;
}
schema::QuantType input_quant_type;
if (GetQuantType(cnode, &input_quant_type) != RET_OK) {
MS_LOG(WARNING) << "Get quant type failed, input node name: " << input_node->fullname_with_scope();
return RET_ERROR;
}
schema::QuantType pre_quant_type = schema::QuantType_QUANT_NONE;
if (input_node->isa<mindspore::CNode>()) {
if (GetQuantType(input_node->cast<mindspore::CNodePtr>(), &pre_quant_type) != RET_OK) {
MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
return RET_ERROR;
}
}
if (pre_quant_type == schema::QuantType_QUANT_NONE && curr_quant_type == schema::QuantType_QUANT_ALL) {
auto status = insert_node_manager.InserQuantCastNode(this->func_graph_, cnode, FORWARD, kNumberTypeUInt8, kQuant,
index, nullptr);
if (status != RET_OK) {
MS_LOG(ERROR) << "InserQuantCastNode failed, cnode name: " << cnode->fullname_with_scope();
return status;
}
MS_LOG(INFO) << "InserQuantCastNode forward Uint8toInt8, cnode name: " << cnode->fullname_with_scope();
}
}
return RET_OK;
}
/**
 * Insert Int8toUint8 dequant cast nodes behind `cnode`.
 * A cast is requested for every user of `cnode` whose quant type is QuantType_QUANT_NONE
 * while `curr_quant_type` is QuantType_QUANT_ALL.
 * Returns RET_OK on success, an error code otherwise.
 * */
int DTypeTransformPass::InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
  CHECK_NULL_RETURN(cnode);
  CHECK_NULL_RETURN(this->func_graph_);
  // outputs
  auto manager = this->func_graph_->manager();
  if (manager == nullptr) {
    manager = Manage(this->func_graph_, true);
  }
  CHECK_NULL_RETURN(manager);
  // Taken by value on purpose: the insertion below rewires edges through the manager,
  // which presumably updates the manager's own user table while this loop is running.
  auto node_users = manager->node_users()[cnode];
  quant::InsertQuantNodeManager insert_node_manager;
  for (auto &node_user : node_users) {
    CHECK_NULL_RETURN(node_user.first);
    auto output_cnode = node_user.first->cast<CNodePtr>();
    // A user that is not a CNode would otherwise be dereferenced in the error log below.
    CHECK_NULL_RETURN(output_cnode);
    schema::QuantType post_quant_type;
    if (GetQuantType(output_cnode, &post_quant_type) != RET_OK) {
      MS_LOG(ERROR) << "Get quant type failed, cnode name: " << output_cnode->fullname_with_scope();
      return RET_ERROR;
    }
    if (curr_quant_type == schema::QuantType_QUANT_ALL && post_quant_type == schema::QuantType_QUANT_NONE) {
      auto status = insert_node_manager.InserQuantCastNode(this->func_graph_, cnode, BACKWARD, kNumberTypeUInt8,
                                                           kDeQuant, node_user.second, node_user.first);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InserQuantCastNode dequant failed, cnode name: " << cnode->fullname_with_scope();
        return status;
      }
      MS_LOG(INFO) << "InserQuantCastNode backward Int8toUint8, cnode name: " << cnode->fullname_with_scope();
    }
  }  // node_users
  return RET_OK;
}
bool DTypeTransformPass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
if (opt::IsSpecialType(cnode)) {
return false;
}
if (IsGraphInDTypeCast(cnode) || IsGraphOutDTypeCast(func_graph_, cnode)) {
return false;
}
TypeId cnode_dtype = kTypeUnknown;
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
return false;
}
bool is_fp32_output =
opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast) || CheckNodeInSet(cnode, kUint8toFP32Operator);
if (cnode_dtype != kNumberTypeUInt8 && !is_fp32_output) {
MS_LOG(DEBUG) << "dtype not kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
return false;
}
return true;
}
} // namespace mindspore::lite::quant

View File

@ -16,24 +16,43 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_
#include <memory>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ops/primitive_c.h"
#include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"
namespace mindspore::lite::quant {
/**
 * Transform CNode(dtype,uint8toint8,weight data)
 * Insert QuantCastNode
 * */
class DTypeTransformPass {
public:
// Binds the pass to the graph it will rewrite; the graph is validated in Transform().
explicit DTypeTransformPass(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {}
~DTypeTransformPass() = default;
// Entry point of the pass; returns RET_OK on success.
int Transform();
private:
// Converts the weight behind input `input_index` of `cnode`.
int DoParameterNodeTrans(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index);
int Uint8toInt8(uint8_t *data, int size);
// Writes the quant type of `cnode`'s quant param holder to `*quant_type`.
int GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type);
// Rewrites dtype and zero points of `cnode` and converts its weight inputs.
int DoNodeDTypeTrans(const CNodePtr &cnode);
// Returns true when `cnode` has to be processed by this pass.
bool CheckNeedDTypeTrans(const CNodePtr &cnode);
// Inserts Uint8toInt8 cast nodes in front of `cnode` where the producer is not quantized.
int InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type);
// Inserts Int8toUint8 cast nodes behind `cnode` where the user is not quantized.
int InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type);
// Graph the pass operates on.
FuncGraphPtr func_graph_ = nullptr;
};
} // namespace mindspore::lite::quant

View File

@ -0,0 +1,224 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h"
#include <set>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include "tools/common/node_util.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/format_utils.h"
#include "ops/quant_dtype_cast.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::lite::quant {
/**
 * Run the graph input dtype transform followed by the graph output dtype transform.
 * The first failing step aborts the pass and its status is returned; RET_OK otherwise.
 * */
int GraphInoutTransformPass::Transform() {
  const auto input_status = DoGraphInputDTypeTransform();
  if (input_status != RET_OK) {
    MS_LOG(ERROR) << "DoGraphInputDTypeTransform failed.";
    return input_status;
  }
  // Outputs are only handled once every graph input was processed successfully.
  const auto output_status = DoGraphOutputDTypeTransform();
  if (output_status != RET_OK) {
    MS_LOG(ERROR) << "DoGraphOutputDTypeTransform failed.";
    return output_status;
  }
  return RET_OK;
}
int GraphInoutTransformPass::DoGraphInputDTypeTransform() {
CHECK_NULL_RETURN(this->func_graph_);
if (this->graph_input_dtype_ != TypeId::kNumberTypeFloat32 && this->graph_input_dtype_ != TypeId::kNumberTypeUInt8 &&
this->graph_input_dtype_ != TypeId::kNumberTypeInt8 && this->graph_input_dtype_ != TypeId::kTypeUnknown) {
MS_LOG(ERROR) << "Invalid graph input dtype: " << this->graph_input_dtype_;
return RET_ERROR;
}
// not specify inputDataType
if (this->graph_input_dtype_ == TypeId::kTypeUnknown) {
return RET_OK;
}
auto cnodes = this->func_graph_->GetOrderedCnodes();
for (auto &cnode : cnodes) {
for (size_t index = 1; index < cnode->size(); index++) {
auto input_node = cnode->input(index);
CHECK_NULL_RETURN(input_node);
if (!IsGraphInput(input_node)) {
continue;
}
TypeId input_node_dtype = TypeId::kTypeUnknown;
if (opt::GetDataTypeFromAnfNode(input_node, &input_node_dtype) != RET_OK) {
MS_LOG(ERROR) << "GetDataTypeFromAnfNode failed, input node name: " << input_node->fullname_with_scope();
return RET_ERROR;
}
// graph input dtype transform
if (this->graph_input_dtype_ != input_node_dtype) {
auto ret = InsertDTypeCastNode(cnode, index, this->graph_input_dtype_, input_node_dtype, kQuant);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InsertDTypeCastNode failed, cnode name: " << cnode->fullname_with_scope()
<< " input index: " << index;
return ret;
}
}
} // for
} // for
return RET_OK;
}
int GraphInoutTransformPass::DoGraphOutputDTypeTransform() {
CHECK_NULL_RETURN(this->func_graph_);
if (this->graph_output_dtype_ != TypeId::kNumberTypeFloat32 &&
this->graph_output_dtype_ != TypeId::kNumberTypeUInt8 && this->graph_output_dtype_ != TypeId::kNumberTypeInt8 &&
this->graph_output_dtype_ != TypeId::kTypeUnknown) {
MS_LOG(ERROR) << "Invalid graph output dtype: " << this->graph_input_dtype_;
return RET_ERROR;
}
if (this->graph_output_dtype_ == TypeId::kTypeUnknown) {
return RET_OK;
}
auto cnodes = this->func_graph_->GetOrderedCnodes();
for (auto &cnode : cnodes) {
if (!IsGraphOutput(cnode)) {
continue;
}
for (size_t index = 1; index < cnode->size(); index++) {
auto input_node = cnode->input(index);
CHECK_NULL_RETURN(input_node);
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) { // MakeTuple
for (size_t input_index = 1; input_index < input_cnode->size(); input_index++) {
auto ret = DoOutputTransform(input_cnode, input_index);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoOutputTransform failed, cnode name: " << input_node->fullname_with_scope()
<< " input index: " << input_index;
return ret;
}
}
} else { // Return
auto ret = DoOutputTransform(cnode, index);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoOutputTransform failed, cnode name: " << cnode->fullname_with_scope()
<< " input index: " << index;
return ret;
}
}
} // for
} // for
return RET_OK;
}
int GraphInoutTransformPass::DoOutputTransform(const CNodePtr &cnode, size_t index) {
CHECK_NULL_RETURN(cnode);
auto input_node = cnode->input(index);
CHECK_NULL_RETURN(input_node);
if (!input_node->isa<mindspore::CNode>()) {
return RET_OK;
}
TypeId input_node_dtype = TypeId::kTypeUnknown;
if (opt::GetDataTypeFromAnfNode(input_node, &input_node_dtype) != RET_OK) {
MS_LOG(ERROR) << "GetDataTypeFromAnfNode failed, input node name: " << input_node->fullname_with_scope();
return RET_ERROR;
}
// graph output dtype transform
if (this->graph_output_dtype_ != input_node_dtype) {
auto ret = InsertDTypeCastNode(cnode, index, input_node_dtype, this->graph_output_dtype_, kDeQuant);
if (ret != RET_OK) {
MS_LOG(ERROR) << "InsertDTypeCastNode failed, cnode name: " << cnode->fullname_with_scope()
<< " input index: " << index;
return ret;
}
}
return RET_OK;
}
// A node counts as graph output when its primitive is Return.
bool GraphInoutTransformPass::IsGraphOutput(const CNodePtr &cnode) {
  const bool is_return_node = opt::CheckPrimitiveType(cnode, prim::kPrimReturn);
  return is_return_node;
}
int GraphInoutTransformPass::InsertDTypeCastNode(const CNodePtr &cnode, size_t input_index, TypeId src_dtype,
TypeId dst_dtype, CastNodeType node_type) {
std::vector<schema::QuantParamT> input_quant_params;
std::vector<schema::QuantParamT> output_quant_params;
CHECK_NULL_RETURN(cnode);
auto input_node = cnode->input(input_index);
CHECK_NULL_RETURN(input_node);
if (node_type == kQuant) {
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
if (curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) {
MS_LOG(ERROR) << "Invalid quant params.";
return RET_ERROR;
}
input_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1];
} else if (node_type == kDeQuant) {
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode);
CHECK_NULL_RETURN(input_primitive_quant_param_holder);
if (input_primitive_quant_param_holder->get_output_quant_params().empty()) {
MS_LOG(ERROR) << "Invalid quant params.";
return RET_ERROR;
}
input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0];
}
std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
// update zeroPoint(uint8toint8, int8touint8)
if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
output_quant_params.front().zeroPoint -= kU8ZeroPointOffset;
} else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
input_quant_params.front().zeroPoint -= kU8ZeroPointOffset;
}
auto primitive = quant::NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
std::vector<AnfNodePtr> op_inputs = {primitive, input_node};
auto quant_cast_cnode = this->func_graph_->NewCNode(op_inputs);
CHECK_NULL_RETURN(quant_cast_cnode);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" +
std::to_string(input_index));
// set quant_cast_cnode quant type
AbstractBasePtr abstract;
if (cnode->abstract() != nullptr) {
abstract = cnode->abstract()->Clone();
} else if (input_node->abstract() != nullptr) {
abstract = input_node->abstract()->Clone();
} else {
MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope()
<< " input node: " << input_node->fullname_with_scope();
return RET_NULL_PTR;
}
quant_cast_cnode->set_abstract(abstract);
auto quant_cast_cnode_param_holder = GetCNodeQuantHolder(quant_cast_cnode);
CHECK_NULL_RETURN(quant_cast_cnode_param_holder);
quant_cast_cnode_param_holder->set_quant_type(schema::QuantType_QUANT_NONE);
// update dtype: input_cnode, quant_cast_cnode
auto ret = quant::UpdateDataType(input_node, src_dtype);
if (ret != RET_OK) {
MS_LOG(ERROR) << "UpdateDataType failed, input node name: " << input_node->fullname_with_scope();
return RET_ERROR;
}
cnode->set_input(input_index, quant_cast_cnode);
ret = quant::UpdateDataType(std::dynamic_pointer_cast<mindspore::AnfNode>(quant_cast_cnode), dst_dtype);
if (ret != RET_OK) {
MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
return RET_ERROR;
}
return RET_OK;
}
} // namespace mindspore::lite::quant

View File

@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_
#include <memory>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ops/primitive_c.h"
#include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"
namespace mindspore::lite::quant {
class GraphInoutTransformPass {
public:
explicit GraphInoutTransformPass(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> param)
: func_graph_(func_graph), param_(param) {
CHECK_NULL_RETURN_VOID(this->param_);
this->graph_input_dtype_ = static_cast<TypeId>(this->param_->input_data_type);
this->graph_output_dtype_ = static_cast<TypeId>(this->param_->output_data_type);
}
~GraphInoutTransformPass() = default;
int Transform();
private:
int DoGraphInputDTypeTransform();
int DoGraphOutputDTypeTransform();
int InsertDTypeCastNode(const CNodePtr &cnode, size_t input_index, TypeId src_dtype, TypeId dst_dtype,
CastNodeType node_type);
int DoOutputTransform(const CNodePtr &cnode, size_t index);
bool IsGraphOutput(const CNodePtr &cnode);
FuncGraphPtr func_graph_ = nullptr;
std::shared_ptr<ConverterPara> param_ = nullptr;
TypeId graph_input_dtype_ = TypeId::kTypeUnknown;
TypeId graph_output_dtype_ = TypeId::kTypeUnknown;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_

View File

@ -118,6 +118,9 @@ int PropagateQuantParamPass::ForwardPropagate(const std::list<CNodePtr> &nodes)
if (IsGraphInput(cnode) || opt::IsSpecialType(cnode) || opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
continue;
}
if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
continue;
}
// Infer quant param with forward (output->input).
auto curr_quant_holder = GetCNodeQuantHolder(cnode);
if (curr_quant_holder == nullptr) {
@ -182,6 +185,9 @@ int PropagateQuantParamPass::BackwardPropagate(const std::list<CNodePtr> &nodes)
if (IsGraphInput(cnode) || opt::IsSpecialType(cnode) || opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
continue;
}
if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
continue;
}
// Infer quant param with forward (output<-input).
auto curr_quant_holder = GetCNodeQuantHolder(cnode);
if (curr_quant_holder == nullptr) {
@ -214,7 +220,7 @@ int PropagateQuantParamPass::BackwardPropagate(const std::list<CNodePtr> &nodes)
}
}
} else {
MS_LOG(ERROR) << "Support for multi output.";
MS_LOG(INFO) << "Support for multi output.";
}
}
}

View File

@ -41,6 +41,10 @@ int QuantNodePass::DoWeightQuant(const CNodePtr &cnode) {
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
continue;
}
if (!CanTensorQuantized(cnode, input)) {
MS_LOG(INFO) << input->fullname_with_scope() << " is not quantized.";
continue;
}
int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(weight->shape()));
MS_CHECK_GT(static_cast<int>(quant_param_holder->get_input_quant_params().size()), static_cast<int>(idx - 1),
RET_ERROR);
@ -265,6 +269,40 @@ int QuantNodePass::DoFullQuant(const CNodePtr &cnode) {
return RET_OK;
}
// Tells whether `input_node` is a weight tensor that may be quantized.
//
// The node qualifies only if it is a Parameter with a default value, carries an abstract
// whose shape track is an abstract::Shape, and has at least DIMENSION_2D dimensions.
// `cnode` is not inspected by this check.
bool QuantNodePass::CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node) {
  if (input_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
    return false;
  }
  ParameterPtr param_node = nullptr;
  if (input_node->isa<Parameter>()) {
    param_node = input_node->cast<ParameterPtr>();
  }
  if (param_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized invalid param_node!";
    return false;
  }
  if (!param_node->has_default()) {
    MS_LOG(INFO) << "param_node don't has default.";
    return false;
  }
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(INFO) << "abstract is nullptr";
    return false;
  }
  auto shape_track = abstract_base->GetShapeTrack();
  if (!utils::isa<abstract::ShapePtr>(shape_track)) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << param_node->name();
    return false;
  }
  // Validate the pointer that is dereferenced; the shape vector itself cannot be null.
  auto shape_ptr = utils::cast<abstract::ShapePtr>(shape_track);
  if (shape_ptr == nullptr) {
    MS_LOG(INFO) << "Shape of Abstract of parameter is nullptr " << param_node->name();
    return false;
  }
  const auto &weight_shape = shape_ptr->shape();
  // do not quant single dim tensors
  return weight_shape.size() >= DIMENSION_2D;
}
int QuantNodePass::Quant() {
CHECK_NULL_RETURN(func_graph_);
auto cnodes = func_graph_->GetOrderedCnodes();

View File

@ -47,6 +47,7 @@ class QuantNodePass {
int DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index);
int DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &input_node, size_t input_index);
int IsSupportWeightQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index);
bool CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node);
private:
FuncGraphPtr func_graph_ = nullptr;

View File

@ -26,9 +26,6 @@
#include "tools/common/node_util.h"
namespace mindspore::lite::quant {
namespace {
static const std::set<PrimitivePtr> fp32_output_operator = {prim::kPrimDetectionPostProcess};
}
bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
if (opt::IsSpecialType(cnode)) {
@ -56,15 +53,28 @@ bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) {
return quant_holder->quant_type() == schema::QuantType_QUANT_ALL;
}
// if DTypeCastNode is graph input or output, set QuantType_QUANT_NONE
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
return (!IsGraphInDTypeCast(cnode) && !IsGraphOutDTypeCast(func_graph_, cnode));
}
// Check input quant params, quantization parameters exist for all activations.
for (size_t i = 1; i < cnode->size(); ++i) {
auto input = cnode->input(i);
TypeId input_node_dtype = kTypeUnknown;
if (opt::GetDataTypeFromAnfNode(input, &input_node_dtype) != RET_OK) {
MS_LOG(INFO) << "Get data type failed, input node name: " << input->fullname_with_scope();
continue;
}
// Only specified dtype need check input quant params
if (kFullQuantDType.find(input_node_dtype) == kFullQuantDType.end()) {
continue;
}
if (input->isa<CNode>() && !quant_holder->CheckInit(i - kPrimOffset, true)) {
return false;
}
}
// Check output quant params.
if (CheckNodeInSet(cnode, fp32_output_operator) && !quant_holder->IsOutputQuantParamsInited()) {
if (!CheckNodeInSet(cnode, kUint8toFP32Operator) && !quant_holder->IsOutputQuantParamsInited()) {
return false;
}
return true;
@ -118,7 +128,7 @@ int QuantTypeDeterminer::Determine() {
MS_LOG(INFO) << cnode->fullname_with_scope() << " quant holder is nullptr.";
continue;
}
if (!quant_holder->IsInputQuantParamsInited() && !quant_holder->IsOutputQuantParamsInited()) { // Check FP32.
if (!quant_holder->IsInputExistInited() && !quant_holder->IsOutputExistInited()) { // Check FP32.
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
continue;
}

View File

@ -20,7 +20,6 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "mindspore/core/ops/core_ops.h"
namespace mindspore::lite::quant {
class QuantTypeDeterminer {

View File

@ -63,6 +63,26 @@ bool QuantParamHolder::IsOutputQuantParamsInited() {
return is_quant_params_inited;
}
bool QuantParamHolder::IsInputExistInited() {
if (this->input_quant_params_.empty()) {
return false;
}
bool is_exist_param_inited =
std::any_of(this->input_quant_params_.begin(), this->input_quant_params_.end(),
[](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
return is_exist_param_inited;
}
bool QuantParamHolder::IsOutputExistInited() {
if (this->output_quant_params_.empty()) {
return false;
}
bool is_exist_param_inited =
std::any_of(this->output_quant_params_.begin(), this->output_quant_params_.end(),
[](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
return is_exist_param_inited;
}
void QuantParamHolder::ClearQuantParams() {
quant_type_ = schema::QuantType_QUANT_NONE;
input_quant_params_.clear();

View File

@ -110,6 +110,10 @@ class QuantParamHolder : public Value {
bool IsOutputQuantParamsInited();
bool IsInputExistInited();
bool IsOutputExistInited();
void ClearQuantParams();
bool CheckInit(size_t index, bool is_input);

View File

@ -44,11 +44,15 @@ constexpr int kPrimIndex = 0;
// Offset between a CNode input index and the matching quant-param slot
// (CNode input 0 is the primitive).
constexpr int kPrimOffset = 1;
// Value subtracted from a zero point when a cast between uint8 and int8 is inserted.
constexpr int kU8ZeroPointOffset = 128;
constexpr int kQuantRange = 127;
// Smallest and largest value representable in int8.
constexpr int kInt8LeftRange = -128;
constexpr int kInt8RightRange = 127;
constexpr int kMinIterations = 40;
// NOTE(review): presumably the operators that carry a bias input - confirm against the users.
const std::set<PrimitivePtr> kHasBiasOperator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
prim::kPrimLayerNormFusion};
// Operators that are exempt from the output quant-param check when deciding on full quantization.
const std::set<PrimitivePtr> kUint8toFP32Operator = {prim::kPrimDetectionPostProcess};
// Input dtypes for which input quant params are checked when deciding on full quantization.
const std::set<TypeId> kFullQuantDType = {kNumberTypeInt8, kNumberTypeUInt8, kNumberTypeFloat32};
enum ActivationQuantizedMethod {
MAX_MIN = 0,
@ -68,6 +72,17 @@ enum DebugMode {
DETAIL,
};
// Kind of dtype cast node to insert; selects where the quant params of the cast come from.
enum CastNodeType {
// No quant params are looked up.
kNone,
// Used for casts at graph inputs: quant params come from the input slot of the consumer node.
kQuant,
// Used for casts at graph outputs: quant params come from the output slot of the producer node.
kDeQuant,
};
// NOTE(review): presumably the direction in which a cast node is inserted relative to a node
// (FORWARD: before its inputs, BACKWARD: after its outputs) - confirm against the users.
enum InsertDirection {
FORWARD,
BACKWARD,
};
struct CommonQuantParam {
schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
int bit_num = 8;

View File

@ -139,6 +139,51 @@ int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) {
return RET_OK;
}
// Builds a value node holding a QuantDTypeCast primitive that converts `src_type` to `dst_type`.
//
// A QuantParamHolder with quant type QUANT_ALL is attached to the primitive under the
// "quant_params" attribute; `input_quant_params` / `output_quant_params` are stored in its
// slot 0. Returns nullptr when any intermediate object could not be created.
ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
                                   const std::vector<schema::QuantParamT> &input_quant_params,
                                   const std::vector<schema::QuantParamT> &output_quant_params) {
  // 1. The cast operator itself.
  auto prim_c = std::make_shared<ops::QuantDTypeCast>();
  MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
  prim_c->Init(src_type, dst_type);

  // 2. The holder carrying the quant params of the cast.
  const auto input_param_count = input_quant_params.size();
  const auto output_param_count = output_quant_params.size();
  auto quant_params_holder = std::make_shared<QuantParamHolder>(input_param_count, output_param_count);
  MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
  quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
  quant_params_holder->set_input_quant_param(0, input_quant_params);
  quant_params_holder->set_output_quant_param(0, output_quant_params);

  // 3. Attach the holder to the primitive and wrap it into a value node.
  auto prim = prim_c->GetPrim();
  MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
  prim->AddAttr("quant_params", quant_params_holder);
  return NewValueNode(prim);
}
// Returns true when `cnode` is a QuantDTypeCast whose first data input is a graph input.
bool IsGraphInDTypeCast(const CNodePtr &cnode) {
  if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
    // Input 0 is the primitive, input 1 is the tensor being cast.
    auto input_node = cnode->input(1);
    MS_CHECK_FALSE(input_node == nullptr, false);
    return IsGraphInput(input_node);
  }
  return false;
}
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
return false;
}
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
CHECK_NULL_RETURN(manager);
auto node_users = manager->node_users()[cnode];
for (auto &node_user : node_users) {
auto output_cnode = node_user.first->cast<CNodePtr>();
if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
return false;
}
}
return true;
}
bool TensorQuantParamsInited(const schema::TensorT &tensor) {
if (tensor.quantParams.empty()) {
return false;

View File

@ -43,8 +43,13 @@
#include "tools/converter/quantizer/mixed_bit_weight_quantization.h"
#include "tools/common/string_util.h"
#include "ops/core_ops.h"
#include "ops/quant_dtype_cast.h"
namespace mindspore::lite::quant {
static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
prim::kPrimLayerNormFusion};
QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive);
QuantParamHolderPtr GetCNodeQuantHolder(const CNodePtr &cnode);
@ -71,6 +76,13 @@ void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_f
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type);
// Creates a value node holding a QuantDTypeCast primitive (src_type -> dst_type) with the given
// quant params attached; returns nullptr on failure.
ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
const std::vector<schema::QuantParamT> &input_quant_params,
const std::vector<schema::QuantParamT> &output_quant_params);
// Returns true when cnode is a QuantDTypeCast whose first data input is a graph input.
bool IsGraphInDTypeCast(const CNodePtr &cnode);
// Returns true when cnode is a QuantDTypeCast whose users are all Return nodes.
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
template <typename T>
int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<mindspore::QuantParam> quant_params,
std::vector<T> *dequant_data, int preferred_dim = 0) {