forked from mindspore-Ecosystem/mindspore
!40670 enable QATTransform pass
Merge pull request !40670 from liyan2022/dev_qat_smoke
This commit is contained in:
commit
6ec539f79e
|
@ -415,8 +415,8 @@ bool AnfTransform::CheckExternalExtension(const std::shared_ptr<ConverterPara> &
|
|||
STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
||||
if (param->fullQuantParam.target_device == quant::TargetDevice::DSP &&
|
||||
param->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL) {
|
||||
auto remove_pass = quant::RemoveUnusedQuantParam(func_graph);
|
||||
auto ret = remove_pass.Remove();
|
||||
auto remove_quant_param_pass = quant::RemoveQuantParam(func_graph);
|
||||
auto ret = remove_quant_param_pass.Remove();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "remove unused quant param failed.";
|
||||
return RET_ERROR;
|
||||
|
@ -434,8 +434,8 @@ STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::sha
|
|||
MS_LOG(ERROR) << "Run quant type determine failed.";
|
||||
return ret;
|
||||
}
|
||||
auto dtype_trans_pass = quant::DTypeTransformPass(func_graph);
|
||||
ret = dtype_trans_pass.Transform();
|
||||
auto transform_uint8_pass = quant::TransformUint8Pass(func_graph);
|
||||
ret = transform_uint8_pass.Transform();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run dtype transform pass failed.";
|
||||
return ret;
|
||||
|
@ -589,6 +589,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
status = QATTransform(old_graph, param);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do QATTransform failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = DoQuantize(old_graph, param);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do Quantize failed.";
|
||||
|
|
|
@ -55,22 +55,6 @@ int QuantTransform(const std::shared_ptr<ConverterPara> ¶m, schema::MetaGrap
|
|||
// quantization
|
||||
if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE ||
|
||||
param->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
|
||||
{
|
||||
// quantization
|
||||
// init old node indices
|
||||
auto old_nodes = GetGraphNodes(*graph_defT);
|
||||
Optimizer tensor_quant_optimizer;
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
|
||||
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||
auto status = tensor_quant_optimizer.Run(graph_defT);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoQuantize failed!";
|
||||
return status;
|
||||
}
|
||||
}
|
||||
{
|
||||
// quantization
|
||||
// init old node indices
|
||||
|
|
|
@ -42,12 +42,6 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
|
|||
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
|
||||
return status;
|
||||
}
|
||||
|
||||
status = DoNodeInoutDTypeTrans(graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
|
||||
return status;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -28,11 +28,11 @@
|
|||
|
||||
namespace mindspore::lite::quant {
|
||||
// only enable for uint8
|
||||
int DTypeTransformPass::Transform() {
|
||||
int TransformUint8Pass::Transform() {
|
||||
auto cnodes = func_graph_->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
if (!CheckNeedDTypeTrans(cnode)) {
|
||||
MS_LOG(INFO) << "CheckNeedDTypeTrans invalid cnode, cnode name: " << cnode->fullname_with_scope();
|
||||
MS_LOG(DEBUG) << "CheckNeedDTypeTrans invalid cnode, cnode name: " << cnode->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
auto status = DoNodeDTypeTrans(cnode);
|
||||
|
@ -67,26 +67,45 @@ int DTypeTransformPass::Transform() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const ParameterPtr &input_node,
|
||||
int TransformUint8Pass::DoParameterNodeTrans(const CNodePtr &cnode, const ParameterPtr &input_node,
|
||||
size_t input_index) {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (input_index == THIRD_INPUT + 1 && CheckNodeInSet(cnode, kHasBiasOperator)) {
|
||||
return RET_OK;
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
auto tensor_info = input_node->default_param()->cast<tensor::TensorPtr>();
|
||||
CHECK_NULL_RETURN(tensor_info);
|
||||
bool is_shared_weight = IsSharedWeightParameter(input_node);
|
||||
auto weight_name = input_node->fullname_with_scope();
|
||||
|
||||
if (is_shared_weight) {
|
||||
auto iter = shared_weight_quant_params_.find(weight_name);
|
||||
if (iter != shared_weight_quant_params_.end()) {
|
||||
auto quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(quant_param_holder);
|
||||
quant_param_holder->set_input_quant_param(input_index - 1, iter->second);
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
}
|
||||
|
||||
// filter condition: dtype == kNumberTypeUInt8
|
||||
if (tensor_info->data_type() != kNumberTypeUInt8) {
|
||||
MS_LOG(INFO) << input_node->fullname_with_scope() << " dtype not uint8.";
|
||||
return RET_ERROR;
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
// transform weight data
|
||||
size_t elem_count = tensor_info->DataSize();
|
||||
auto ret = Uint8toInt8(static_cast<uint8_t *>(tensor_info->data().data()), elem_count);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << input_node->fullname_with_scope() << " transform data uint8 to int8 failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
// update zp
|
||||
auto quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(quant_param_holder);
|
||||
if (quant_param_holder->get_input_quant_params().size() < input_index) {
|
||||
MS_LOG(ERROR) << "Invalid quant params. input node name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
|
@ -96,6 +115,9 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
|
|||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
quant_param_holder->set_input_quant_param(input_index - 1, quant_params);
|
||||
if (is_shared_weight && shared_weight_quant_params_.find(weight_name) == shared_weight_quant_params_.end()) {
|
||||
shared_weight_quant_params_.insert({weight_name, quant_params});
|
||||
}
|
||||
|
||||
// set dtype
|
||||
tensor_info->set_data_type(kNumberTypeInt8);
|
||||
|
@ -107,7 +129,7 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::Uint8toInt8(uint8_t *data, int size) {
|
||||
int TransformUint8Pass::Uint8toInt8(uint8_t *data, int size) {
|
||||
CHECK_NULL_RETURN(data);
|
||||
|
||||
for (int i = 0; i < size; i++) {
|
||||
|
@ -123,7 +145,7 @@ int DTypeTransformPass::Uint8toInt8(uint8_t *data, int size) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type) {
|
||||
int TransformUint8Pass::GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type) {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
auto quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(quant_param_holder);
|
||||
|
@ -134,7 +156,7 @@ int DTypeTransformPass::GetQuantType(const CNodePtr &cnode, schema::QuantType *q
|
|||
/**
|
||||
* Transform CNode(dtype,uint8toint8,weight data)
|
||||
* */
|
||||
int DTypeTransformPass::DoNodeDTypeTrans(const CNodePtr &cnode) {
|
||||
int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) {
|
||||
auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(curr_quant_param_holder);
|
||||
TypeId cnode_dtype = kTypeUnknown;
|
||||
|
@ -190,15 +212,17 @@ int DTypeTransformPass::DoNodeDTypeTrans(const CNodePtr &cnode) {
|
|||
curr_quant_param_holder->set_input_quant_param(index - 1, input_quant_params);
|
||||
} else if (input_node->isa<mindspore::Parameter>()) { // weight data
|
||||
auto ret = DoParameterNodeTrans(cnode, input_node->cast<ParameterPtr>(), index);
|
||||
if (ret != RET_OK) {
|
||||
bool is_failed = (ret != RET_OK && ret != RET_NOT_SUPPORT && ret != RET_NO_CHANGE);
|
||||
if (is_failed) {
|
||||
MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope();
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
|
||||
int TransformUint8Pass::InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
|
||||
// inputs
|
||||
quant::InsertQuantNodeManager insert_node_manager;
|
||||
for (size_t index = 1; index < cnode->size(); index++) {
|
||||
|
@ -233,7 +257,7 @@ int DTypeTransformPass::InsertForwardCastNode(const CNodePtr &cnode, schema::Qua
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
|
||||
int TransformUint8Pass::InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
|
||||
// outputs
|
||||
auto manager = this->func_graph_->manager();
|
||||
if (manager == nullptr) {
|
||||
|
@ -262,13 +286,23 @@ int DTypeTransformPass::InsertBackwardCastNode(const CNodePtr &cnode, schema::Qu
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool DTypeTransformPass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
|
||||
bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
|
||||
if (opt::IsSpecialType(cnode)) {
|
||||
return false;
|
||||
}
|
||||
if (IsGraphInDTypeCast(cnode) || IsGraphOutDTypeCast(func_graph_, cnode)) {
|
||||
// If CastNode(kDeQuant) as graph input node, or CastNode(kQuant) as graph output node, do nothing.
|
||||
CastNodeType cast_node_type = kNone;
|
||||
auto status = quant::GetCastNodeType(func_graph_, cnode, &cast_node_type);
|
||||
if (status == RET_OK) {
|
||||
if ((cast_node_type == kDeQuant && IsGraphInDTypeCast(cnode)) ||
|
||||
(IsGraphOutDTypeCast(func_graph_, cnode) && cast_node_type == kQuant)) {
|
||||
return false;
|
||||
}
|
||||
} else if (status != RET_NOT_SUPPORT) {
|
||||
MS_LOG(ERROR) << "Get cast node type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
|
||||
TypeId cnode_dtype = kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
|
@ -282,4 +316,14 @@ bool DTypeTransformPass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TransformUint8Pass::IsSharedWeightParameter(const AnfNodePtr &anf_node) {
|
||||
auto manager = this->func_graph_->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(this->func_graph_, true);
|
||||
}
|
||||
CHECK_NULL_RETURN(manager);
|
||||
auto node_users = manager->node_users()[anf_node];
|
||||
return (node_users.size() > 1);
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ops/primitive_c.h"
|
||||
|
@ -27,14 +29,14 @@
|
|||
|
||||
namespace mindspore::lite::quant {
|
||||
/**
|
||||
* Transform CNode(dtype,uint8toint8,weigh data)
|
||||
* Transform CNode(dtype uint8toint8, transform weight data)
|
||||
* Insert QuantCastNode
|
||||
* */
|
||||
class DTypeTransformPass {
|
||||
class TransformUint8Pass {
|
||||
public:
|
||||
explicit DTypeTransformPass(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {}
|
||||
explicit TransformUint8Pass(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {}
|
||||
|
||||
~DTypeTransformPass() = default;
|
||||
~TransformUint8Pass() = default;
|
||||
|
||||
int Transform();
|
||||
|
||||
|
@ -53,7 +55,12 @@ class DTypeTransformPass {
|
|||
|
||||
int InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type);
|
||||
|
||||
bool IsSharedWeightParameter(const AnfNodePtr &anf_node);
|
||||
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
|
||||
// key is tensor_name
|
||||
std::map<std::string, std::vector<schema::QuantParamT>> shared_weight_quant_params_;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_
|
||||
|
|
|
@ -156,7 +156,10 @@ int PropagateQuantParamPass::ForwardPropagate(const std::list<CNodePtr> &nodes)
|
|||
auto before_cnode = before_cnode_map.first;
|
||||
size_t before_out_index = before_cnode_map.second;
|
||||
auto before_quant_holder = GetCNodeQuantHolder(before_cnode);
|
||||
CHECK_NULL_RETURN(before_quant_holder);
|
||||
if (before_quant_holder == nullptr) {
|
||||
MS_LOG(WARNING) << cnode->fullname_with_scope() << " get before_quant_holder failed.";
|
||||
continue;
|
||||
}
|
||||
auto before_output_quant_param = before_quant_holder->get_output_quant_params();
|
||||
if (before_output_quant_param.size() > before_out_index && before_quant_holder->IsOutputQuantParamsInited()) {
|
||||
MS_LOG(INFO) << before_cnode->fullname_with_scope() << " forward propagate to " << cnode->fullname_with_scope();
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include "tools/converter/quantizer/quant_helper/remove_unused_quant_param.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
int RemoveUnusedQuantParam::Remove() {
|
||||
int RemoveQuantParam::Remove() {
|
||||
CHECK_NULL_RETURN(func_graph_);
|
||||
auto nodes = func_graph_->GetOrderedCnodes();
|
||||
for (auto const &cnode : nodes) {
|
||||
|
|
|
@ -22,10 +22,10 @@
|
|||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
class RemoveUnusedQuantParam {
|
||||
class RemoveQuantParam {
|
||||
public:
|
||||
explicit RemoveUnusedQuantParam(const FuncGraphPtr &funcGraph) : func_graph_(funcGraph) {}
|
||||
~RemoveUnusedQuantParam() = default;
|
||||
explicit RemoveQuantParam(const FuncGraphPtr &funcGraph) : func_graph_(funcGraph) {}
|
||||
~RemoveQuantParam() = default;
|
||||
|
||||
public:
|
||||
int Remove();
|
||||
|
|
|
@ -175,8 +175,10 @@ bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
|
|||
}
|
||||
CHECK_NULL_RETURN(manager);
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
MS_CHECK_TRUE_RET(!node_users.empty(), RET_NULL_PTR);
|
||||
for (auto &node_user : node_users) {
|
||||
auto output_cnode = node_user.first->cast<CNodePtr>();
|
||||
CHECK_NULL_RETURN(output_cnode);
|
||||
if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -184,6 +186,59 @@ bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode)
|
|||
return true;
|
||||
}
|
||||
|
||||
int GetCastNodeType(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CastNodeType *cast_node_type) {
|
||||
if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
MS_LOG(DEBUG) << "Not QuantDtypeCastNode, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
auto input_node = cnode->input(1);
|
||||
MS_CHECK_FALSE(input_node == nullptr, RET_ERROR);
|
||||
|
||||
// input node
|
||||
TypeId pre_node_dtype = kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(input_node, &pre_node_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed, cnode name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// output node
|
||||
TypeId post_node_dtype = kTypeUnknown;
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(func_graph, true);
|
||||
}
|
||||
CHECK_NULL_RETURN(manager);
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
MS_CHECK_TRUE_RET(!node_users.empty(), RET_NULL_PTR);
|
||||
auto output_cnode = node_users.begin()->first->cast<CNodePtr>();
|
||||
CHECK_NULL_RETURN(output_cnode);
|
||||
|
||||
if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
|
||||
if (opt::GetDataTypeFromAnfNode(output_cnode, &post_node_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed, cnode name: " << output_cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (pre_node_dtype == kNumberTypeFloat32 &&
|
||||
(post_node_dtype == kNumberTypeInt8 || post_node_dtype == kNumberTypeUInt8)) {
|
||||
*cast_node_type = kQuant;
|
||||
} else if ((pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) &&
|
||||
post_node_dtype == kNumberTypeFloat32) {
|
||||
*cast_node_type = kDeQuant;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not support QuantDTypeCastNode, cnode name: " << cnode->fullname_with_scope();
|
||||
}
|
||||
} else {
|
||||
if (pre_node_dtype == kNumberTypeFloat32) {
|
||||
*cast_node_type = kQuant;
|
||||
} else if (pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) {
|
||||
*cast_node_type = kDeQuant;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Not support QuantDTypeCastNode, cnode name: " << cnode->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool TensorQuantParamsInited(const schema::TensorT &tensor) {
|
||||
if (tensor.quantParams.empty()) {
|
||||
return false;
|
||||
|
|
|
@ -83,6 +83,8 @@ bool IsGraphInDTypeCast(const CNodePtr &cnode);
|
|||
|
||||
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
int GetCastNodeType(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CastNodeType *cast_node_type);
|
||||
|
||||
template <typename T>
|
||||
int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<mindspore::QuantParam> quant_params,
|
||||
std::vector<T> *dequant_data, int preferred_dim = 0) {
|
||||
|
|
Loading…
Reference in New Issue