!40670 enable QATTransform pass

Merge pull request !40670 from liyan2022/dev_qat_smoke
This commit is contained in:
i-robot 2022-08-22 02:08:54 +00:00 committed by Gitee
commit 6ec539f79e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 143 additions and 48 deletions

View File

@ -415,8 +415,8 @@ bool AnfTransform::CheckExternalExtension(const std::shared_ptr<ConverterPara> &
STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param) {
if (param->fullQuantParam.target_device == quant::TargetDevice::DSP &&
param->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL) {
auto remove_pass = quant::RemoveUnusedQuantParam(func_graph);
auto ret = remove_pass.Remove();
auto remove_quant_param_pass = quant::RemoveQuantParam(func_graph);
auto ret = remove_quant_param_pass.Remove();
if (ret != RET_OK) {
MS_LOG(ERROR) << "remove unused quant param failed.";
return RET_ERROR;
@ -434,8 +434,8 @@ STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::sha
MS_LOG(ERROR) << "Run quant type determine failed.";
return ret;
}
auto dtype_trans_pass = quant::DTypeTransformPass(func_graph);
ret = dtype_trans_pass.Transform();
auto transform_uint8_pass = quant::TransformUint8Pass(func_graph);
ret = transform_uint8_pass.Transform();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run dtype transform pass failed.";
return ret;
@ -589,6 +589,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
return nullptr;
}
status = QATTransform(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do QATTransform failed.";
return nullptr;
}
status = DoQuantize(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";

View File

@ -55,22 +55,6 @@ int QuantTransform(const std::shared_ptr<ConverterPara> &param, schema::MetaGrap
// quantization
if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE ||
param->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
{
// quantization
// init old node indices
auto old_nodes = GetGraphNodes(*graph_defT);
Optimizer tensor_quant_optimizer;
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
auto status = tensor_quant_optimizer.Run(graph_defT);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}
{
// quantization
// init old node indices

View File

@ -42,12 +42,6 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
return status;
}
status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}
return RET_OK;
}

View File

@ -28,11 +28,11 @@
namespace mindspore::lite::quant {
// only enable for uint8
int DTypeTransformPass::Transform() {
int TransformUint8Pass::Transform() {
auto cnodes = func_graph_->GetOrderedCnodes();
for (auto &cnode : cnodes) {
if (!CheckNeedDTypeTrans(cnode)) {
MS_LOG(INFO) << "CheckNeedDTypeTrans invalid cnode, cnode name: " << cnode->fullname_with_scope();
MS_LOG(DEBUG) << "CheckNeedDTypeTrans invalid cnode, cnode name: " << cnode->fullname_with_scope();
continue;
}
auto status = DoNodeDTypeTrans(cnode);
@ -67,26 +67,45 @@ int DTypeTransformPass::Transform() {
return RET_OK;
}
int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const ParameterPtr &input_node,
int TransformUint8Pass::DoParameterNodeTrans(const CNodePtr &cnode, const ParameterPtr &input_node,
size_t input_index) {
CHECK_NULL_RETURN(cnode);
CHECK_NULL_RETURN(input_node);
if (input_index == THIRD_INPUT + 1 && CheckNodeInSet(cnode, kHasBiasOperator)) {
return RET_OK;
return RET_NOT_SUPPORT;
}
auto tensor_info = input_node->default_param()->cast<tensor::TensorPtr>();
CHECK_NULL_RETURN(tensor_info);
bool is_shared_weight = IsSharedWeightParameter(input_node);
auto weight_name = input_node->fullname_with_scope();
if (is_shared_weight) {
auto iter = shared_weight_quant_params_.find(weight_name);
if (iter != shared_weight_quant_params_.end()) {
auto quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(quant_param_holder);
quant_param_holder->set_input_quant_param(input_index - 1, iter->second);
return RET_NO_CHANGE;
}
}
// filter condition: dtype == kNumberTypeUInt8
if (tensor_info->data_type() != kNumberTypeUInt8) {
MS_LOG(INFO) << input_node->fullname_with_scope() << " dtype not uint8.";
return RET_ERROR;
return RET_NOT_SUPPORT;
}
// transform weight data
size_t elem_count = tensor_info->DataSize();
auto ret = Uint8toInt8(static_cast<uint8_t *>(tensor_info->data().data()), elem_count);
if (ret != RET_OK) {
MS_LOG(ERROR) << input_node->fullname_with_scope() << " transform data uint8 to int8 failed.";
return ret;
}
// update zp
auto quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(quant_param_holder);
if (quant_param_holder->get_input_quant_params().size() < input_index) {
MS_LOG(ERROR) << "Invalid quant params. input node name: " << input_node->fullname_with_scope();
return RET_ERROR;
@ -96,6 +115,9 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
quant_param.zeroPoint -= kU8ZeroPointOffset;
}
quant_param_holder->set_input_quant_param(input_index - 1, quant_params);
if (is_shared_weight && shared_weight_quant_params_.find(weight_name) == shared_weight_quant_params_.end()) {
shared_weight_quant_params_.insert({weight_name, quant_params});
}
// set dtype
tensor_info->set_data_type(kNumberTypeInt8);
@ -107,7 +129,7 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
return RET_OK;
}
int DTypeTransformPass::Uint8toInt8(uint8_t *data, int size) {
int TransformUint8Pass::Uint8toInt8(uint8_t *data, int size) {
CHECK_NULL_RETURN(data);
for (int i = 0; i < size; i++) {
@ -123,7 +145,7 @@ int DTypeTransformPass::Uint8toInt8(uint8_t *data, int size) {
return RET_OK;
}
int DTypeTransformPass::GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type) {
int TransformUint8Pass::GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type) {
CHECK_NULL_RETURN(cnode);
auto quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(quant_param_holder);
@ -134,7 +156,7 @@ int DTypeTransformPass::GetQuantType(const CNodePtr &cnode, schema::QuantType *q
/**
 * Transform CNode(dtype, uint8toint8, weight data)
 * */
int DTypeTransformPass::DoNodeDTypeTrans(const CNodePtr &cnode) {
int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) {
auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
CHECK_NULL_RETURN(curr_quant_param_holder);
TypeId cnode_dtype = kTypeUnknown;
@ -190,15 +212,17 @@ int DTypeTransformPass::DoNodeDTypeTrans(const CNodePtr &cnode) {
curr_quant_param_holder->set_input_quant_param(index - 1, input_quant_params);
} else if (input_node->isa<mindspore::Parameter>()) { // weight data
auto ret = DoParameterNodeTrans(cnode, input_node->cast<ParameterPtr>(), index);
if (ret != RET_OK) {
bool is_failed = (ret != RET_OK && ret != RET_NOT_SUPPORT && ret != RET_NO_CHANGE);
if (is_failed) {
MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope();
return ret;
}
}
}
return RET_OK;
}
int DTypeTransformPass::InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
int TransformUint8Pass::InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
// inputs
quant::InsertQuantNodeManager insert_node_manager;
for (size_t index = 1; index < cnode->size(); index++) {
@ -233,7 +257,7 @@ int DTypeTransformPass::InsertForwardCastNode(const CNodePtr &cnode, schema::Qua
return RET_OK;
}
int DTypeTransformPass::InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
int TransformUint8Pass::InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
// outputs
auto manager = this->func_graph_->manager();
if (manager == nullptr) {
@ -262,13 +286,23 @@ int DTypeTransformPass::InsertBackwardCastNode(const CNodePtr &cnode, schema::Qu
return RET_OK;
}
bool DTypeTransformPass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
if (opt::IsSpecialType(cnode)) {
return false;
}
if (IsGraphInDTypeCast(cnode) || IsGraphOutDTypeCast(func_graph_, cnode)) {
// If CastNode(kDeQuant) as graph input node, or CastNode(kQuant) as graph output node, do nothing.
CastNodeType cast_node_type = kNone;
auto status = quant::GetCastNodeType(func_graph_, cnode, &cast_node_type);
if (status == RET_OK) {
if ((cast_node_type == kDeQuant && IsGraphInDTypeCast(cnode)) ||
(IsGraphOutDTypeCast(func_graph_, cnode) && cast_node_type == kQuant)) {
return false;
}
} else if (status != RET_NOT_SUPPORT) {
MS_LOG(ERROR) << "Get cast node type failed, cnode name: " << cnode->fullname_with_scope();
return false;
}
TypeId cnode_dtype = kTypeUnknown;
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
@ -282,4 +316,14 @@ bool DTypeTransformPass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
}
return true;
}
bool TransformUint8Pass::IsSharedWeightParameter(const AnfNodePtr &anf_node) {
auto manager = this->func_graph_->manager();
if (manager == nullptr) {
manager = Manage(this->func_graph_, true);
}
CHECK_NULL_RETURN(manager);
auto node_users = manager->node_users()[anf_node];
return (node_users.size() > 1);
}
} // namespace mindspore::lite::quant

View File

@ -18,6 +18,8 @@
#include <memory>
#include <vector>
#include <string>
#include <map>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ops/primitive_c.h"
@ -27,14 +29,14 @@
namespace mindspore::lite::quant {
/**
* Transform CNode(dtype,uint8toint8,weigh data)
* Transform CNode(dtype uint8toint8, transform weight data)
* Insert QuantCastNode
* */
class DTypeTransformPass {
class TransformUint8Pass {
public:
explicit DTypeTransformPass(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {}
explicit TransformUint8Pass(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {}
~DTypeTransformPass() = default;
~TransformUint8Pass() = default;
int Transform();
@ -53,7 +55,12 @@ class DTypeTransformPass {
int InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type);
bool IsSharedWeightParameter(const AnfNodePtr &anf_node);
FuncGraphPtr func_graph_ = nullptr;
// key is tensor_name
std::map<std::string, std::vector<schema::QuantParamT>> shared_weight_quant_params_;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_

View File

@ -156,7 +156,10 @@ int PropagateQuantParamPass::ForwardPropagate(const std::list<CNodePtr> &nodes)
auto before_cnode = before_cnode_map.first;
size_t before_out_index = before_cnode_map.second;
auto before_quant_holder = GetCNodeQuantHolder(before_cnode);
CHECK_NULL_RETURN(before_quant_holder);
if (before_quant_holder == nullptr) {
MS_LOG(WARNING) << cnode->fullname_with_scope() << " get before_quant_holder failed.";
continue;
}
auto before_output_quant_param = before_quant_holder->get_output_quant_params();
if (before_output_quant_param.size() > before_out_index && before_quant_holder->IsOutputQuantParamsInited()) {
MS_LOG(INFO) << before_cnode->fullname_with_scope() << " forward propagate to " << cnode->fullname_with_scope();

View File

@ -17,7 +17,7 @@
#include "tools/converter/quantizer/quant_helper/remove_unused_quant_param.h"
namespace mindspore::lite::quant {
int RemoveUnusedQuantParam::Remove() {
int RemoveQuantParam::Remove() {
CHECK_NULL_RETURN(func_graph_);
auto nodes = func_graph_->GetOrderedCnodes();
for (auto const &cnode : nodes) {

View File

@ -22,10 +22,10 @@
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite::quant {
class RemoveUnusedQuantParam {
class RemoveQuantParam {
public:
explicit RemoveUnusedQuantParam(const FuncGraphPtr &funcGraph) : func_graph_(funcGraph) {}
~RemoveUnusedQuantParam() = default;
explicit RemoveQuantParam(const FuncGraphPtr &funcGraph) : func_graph_(funcGraph) {}
~RemoveQuantParam() = default;
public:
int Remove();

View File

@ -175,8 +175,10 @@ bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
}
CHECK_NULL_RETURN(manager);
auto node_users = manager->node_users()[cnode];
MS_CHECK_TRUE_RET(!node_users.empty(), RET_NULL_PTR);
for (auto &node_user : node_users) {
auto output_cnode = node_user.first->cast<CNodePtr>();
CHECK_NULL_RETURN(output_cnode);
if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
return false;
}
@ -184,6 +186,59 @@ bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
return true;
}
int GetCastNodeType(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CastNodeType *cast_node_type) {
if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
MS_LOG(DEBUG) << "Not QuantDtypeCastNode, cnode name: " << cnode->fullname_with_scope();
return RET_NOT_SUPPORT;
}
auto input_node = cnode->input(1);
MS_CHECK_FALSE(input_node == nullptr, RET_ERROR);
// input node
TypeId pre_node_dtype = kTypeUnknown;
if (opt::GetDataTypeFromAnfNode(input_node, &pre_node_dtype) != RET_OK) {
MS_LOG(ERROR) << "Get data type failed, cnode name: " << input_node->fullname_with_scope();
return RET_ERROR;
}
// output node
TypeId post_node_dtype = kTypeUnknown;
auto manager = func_graph->manager();
if (manager == nullptr) {
manager = Manage(func_graph, true);
}
CHECK_NULL_RETURN(manager);
auto node_users = manager->node_users()[cnode];
MS_CHECK_TRUE_RET(!node_users.empty(), RET_NULL_PTR);
auto output_cnode = node_users.begin()->first->cast<CNodePtr>();
CHECK_NULL_RETURN(output_cnode);
if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
if (opt::GetDataTypeFromAnfNode(output_cnode, &post_node_dtype) != RET_OK) {
MS_LOG(ERROR) << "Get data type failed, cnode name: " << output_cnode->fullname_with_scope();
return RET_ERROR;
}
if (pre_node_dtype == kNumberTypeFloat32 &&
(post_node_dtype == kNumberTypeInt8 || post_node_dtype == kNumberTypeUInt8)) {
*cast_node_type = kQuant;
} else if ((pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) &&
post_node_dtype == kNumberTypeFloat32) {
*cast_node_type = kDeQuant;
} else {
MS_LOG(ERROR) << "Not support QuantDTypeCastNode, cnode name: " << cnode->fullname_with_scope();
}
} else {
if (pre_node_dtype == kNumberTypeFloat32) {
*cast_node_type = kQuant;
} else if (pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) {
*cast_node_type = kDeQuant;
} else {
MS_LOG(ERROR) << "Not support QuantDTypeCastNode, cnode name: " << cnode->fullname_with_scope();
}
}
return RET_OK;
}
bool TensorQuantParamsInited(const schema::TensorT &tensor) {
if (tensor.quantParams.empty()) {
return false;

View File

@ -83,6 +83,8 @@ bool IsGraphInDTypeCast(const CNodePtr &cnode);
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
int GetCastNodeType(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CastNodeType *cast_node_type);
template <typename T>
int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<mindspore::QuantParam> quant_params,
std::vector<T> *dequant_data, int preferred_dim = 0) {