forked from mindspore-Ecosystem/mindspore
!40553 modify uint8 dtype transform pass
Merge pull request !40553 from liyan2022/dev_qat
This commit is contained in:
commit
efd05d4b71
|
@ -102,6 +102,7 @@
|
|||
#include "tools/converter/quantizer/quant_helper/quant_type_determiner.h"
|
||||
#include "tools/converter/quantizer/quant_helper/propagete_quant_param_pass.h"
|
||||
#include "tools/converter/quantizer/quant_helper/dtype_transform_pass.h"
|
||||
#include "tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h"
|
||||
#include "tools/converter/quantizer/quant_helper/quant_node_pass.h"
|
||||
#include "tools/converter/quantizer/insert_quant_node_manager.h"
|
||||
|
||||
|
@ -445,7 +446,7 @@ STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::sha
|
|||
quant::InsertQuantNodeManager inset_quant_node_pass;
|
||||
ret = inset_quant_node_pass.InsertQuantDtypeCastNode(func_graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "add QuantCast error";
|
||||
MS_LOG(ERROR) << "Add QuantCast error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
|
@ -15,12 +15,13 @@
|
|||
*/
|
||||
|
||||
#define USE_DEPRECATED_API
|
||||
|
||||
#include "mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h"
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ops/quant_dtype_cast.h"
|
||||
#include <algorithm>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
|
@ -30,21 +31,24 @@ namespace {
|
|||
constexpr size_t kMinSize3 = 3;
|
||||
constexpr size_t kPrimitiveCOffset = 1;
|
||||
} // namespace
|
||||
ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(
|
||||
int src_type, int dst_type, const std::vector<schema::QuantParamT> &input_quant_params,
|
||||
const std::vector<schema::QuantParamT> &output_quant_params) {
|
||||
auto prim_c = std::make_shared<ops::QuantDTypeCast>();
|
||||
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
|
||||
prim_c->Init(src_type, dst_type);
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>(input_quant_params.size(), output_quant_params.size());
|
||||
MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
|
||||
quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
|
||||
quant_params_holder->set_input_quant_param(0, input_quant_params);
|
||||
quant_params_holder->set_output_quant_param(0, output_quant_params);
|
||||
auto prim = prim_c->GetPrim();
|
||||
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
|
||||
prim->AddAttr("quant_params", quant_params_holder);
|
||||
return NewValueNode(prim);
|
||||
int InsertQuantNodeManager::SetCastNodeAbstrac(const CNodePtr &cnode, const AnfNodePtr &input_node,
|
||||
const CNodePtr &cast_cnode) {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
CHECK_NULL_RETURN(cast_cnode);
|
||||
|
||||
AbstractBasePtr abstract;
|
||||
if (cnode->abstract() != nullptr) {
|
||||
abstract = cnode->abstract()->Clone();
|
||||
} else if (input_node->abstract() != nullptr) {
|
||||
abstract = input_node->abstract()->Clone();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope()
|
||||
<< " input node: " << input_node->fullname_with_scope();
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
cast_cnode->set_abstract(abstract);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index,
|
||||
|
@ -95,47 +99,43 @@ int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNod
|
|||
<< " input data type: " << data_type << " index " << input_index;
|
||||
return RET_ERROR;
|
||||
}
|
||||
ValueNodePtr value_node;
|
||||
// quant node, uint8toint8: update input_quant_param of cnode
|
||||
|
||||
TypeId src_dtype;
|
||||
TypeId dst_dtype;
|
||||
std::vector<schema::QuantParamT> input_quant_params;
|
||||
std::vector<schema::QuantParamT> output_quant_params;
|
||||
|
||||
// quant node, uint8toint8: update output_quant_params
|
||||
if (insert_quant_node) {
|
||||
src_dtype = data_type;
|
||||
dst_dtype = kNumberTypeInt8;
|
||||
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
if (curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) {
|
||||
MS_LOG(ERROR) << "quant param is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1];
|
||||
std::vector<schema::QuantParamT> input_quant_params(output_quant_params);
|
||||
// Uint8toInt8
|
||||
if (data_type == kNumberTypeUInt8) {
|
||||
for (auto &quant_param : output_quant_params) {
|
||||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
curr_primitive_quant_param_holder->set_input_quant_param(input_index - 1, output_quant_params);
|
||||
}
|
||||
value_node = NewQuantCastPrimitive(data_type, kNumberTypeInt8, input_quant_params, output_quant_params);
|
||||
} else { // insert_dequant_node, int8touint8: update output_quant_params of input_node
|
||||
output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1];
|
||||
std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params));
|
||||
} else { // insert_dequant_node, int8touint8: update input_quant_params
|
||||
src_dtype = kNumberTypeInt8;
|
||||
dst_dtype = data_type;
|
||||
auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
|
||||
if (input_primitive_quant_param_holder->get_output_quant_params().empty()) {
|
||||
MS_LOG(ERROR) << "output quant param is empty.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<schema::QuantParamT> input_quant_params =
|
||||
input_primitive_quant_param_holder->get_output_quant_params()[0];
|
||||
std::vector<schema::QuantParamT> output_quant_params(input_quant_params);
|
||||
if (data_type == kNumberTypeUInt8) {
|
||||
for (auto &quant_param : input_quant_params) {
|
||||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
input_primitive_quant_param_holder->set_output_quant_param(0, input_quant_params);
|
||||
}
|
||||
value_node = NewQuantCastPrimitive(kNumberTypeInt8, data_type, input_quant_params, output_quant_params);
|
||||
input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0];
|
||||
std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
|
||||
}
|
||||
std::vector<AnfNodePtr> op_inputs = {value_node, input_node};
|
||||
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
|
||||
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
|
||||
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
||||
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
|
||||
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" +
|
||||
std::to_string(input_index));
|
||||
cnode->set_input(input_index, quant_cast_cnode);
|
||||
MS_LOG(INFO) << "InsertCastNode cnode name: " << quant_cast_cnode->fullname_with_scope()
|
||||
<< " src_dtype: " << src_dtype << " dst_dtype: " << dst_dtype;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -171,8 +171,8 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c
|
|||
TypeId type_id;
|
||||
auto ret = opt::GetDataTypeFromAnfNode(input_node, &type_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
|
||||
return ret;
|
||||
MS_LOG(WARNING) << "Fetch DataType from cnode failed.";
|
||||
return RET_OK;
|
||||
}
|
||||
if (type_id != check_type_id) {
|
||||
return RET_NO_CHANGE;
|
||||
|
@ -181,9 +181,9 @@ int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId c
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId src_dtype) {
|
||||
int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId cast_dtype) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
if (src_dtype != kNumberTypeFloat32 && src_dtype != kNumberTypeUInt8) {
|
||||
if (cast_dtype != kNumberTypeFloat32 && cast_dtype != kNumberTypeUInt8) {
|
||||
MS_LOG(ERROR) << "Invalid src dtype, only support fp32 and uint8.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -192,11 +192,7 @@ int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph,
|
|||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
auto input_node = cnode->input(i);
|
||||
bool is_graph_input = IsGraphInput(input_node);
|
||||
// Uint8 check quant params inited
|
||||
if (src_dtype == kNumberTypeUInt8 && !is_graph_input && !CheckInited(input_node, i)) {
|
||||
continue;
|
||||
}
|
||||
auto ret = CheckDataType(input_node, src_dtype);
|
||||
auto ret = CheckDataType(input_node, cast_dtype);
|
||||
if (ret == RET_NO_CHANGE) {
|
||||
continue;
|
||||
} else if (ret != RET_OK) {
|
||||
|
@ -312,4 +308,139 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int InsertQuantNodeManager::InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
|
||||
InsertDirection insert_direction, TypeId cast_dtype,
|
||||
CastNodeType cast_node_type, size_t index,
|
||||
const AnfNodePtr &output_node) {
|
||||
if (insert_direction == FORWARD && cast_node_type == kQuant) {
|
||||
return InserForwardQuantCastNode(graph, cnode, cast_dtype, index);
|
||||
} else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
|
||||
return InserBackwardDeQuantCastNode(graph, cnode, cast_dtype, index, output_node);
|
||||
}
|
||||
MS_LOG(ERROR) << "Invalie insert direction: " << insert_direction;
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
int InsertQuantNodeManager::InserForwardQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
|
||||
TypeId cast_dtype, size_t index) {
|
||||
if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "Invalie cast dtype: " << cast_dtype;
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
|
||||
auto input_node = cnode->input(index);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
|
||||
MS_LOG(ERROR) << "Invalid input node, input node name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = CheckDataType(input_node, cast_dtype);
|
||||
if (ret == RET_NO_CHANGE) {
|
||||
MS_LOG(DEBUG) << "input node not dtype: " << cast_dtype;
|
||||
return RET_OK;
|
||||
} else if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Check data type failed, input node name: " << input_node->fullname_with_scope();
|
||||
return ret;
|
||||
}
|
||||
// insert forward cast_node
|
||||
TypeId src_dtype = cast_dtype;
|
||||
TypeId dst_dtype = kNumberTypeInt8;
|
||||
std::vector<schema::QuantParamT> input_quant_params;
|
||||
std::vector<schema::QuantParamT> output_quant_params;
|
||||
|
||||
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
|
||||
if (curr_primitive_quant_param_holder->get_input_quant_params().size() < index) {
|
||||
MS_LOG(ERROR) << "quant param is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[index - 1];
|
||||
std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params));
|
||||
// Uint8toInt8
|
||||
if (src_dtype == kNumberTypeUInt8) {
|
||||
for (auto &quant_param : input_quant_params) {
|
||||
quant_param.zeroPoint += kU8ZeroPointOffset;
|
||||
}
|
||||
}
|
||||
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
|
||||
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
|
||||
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
||||
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
|
||||
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
|
||||
"_pre");
|
||||
if (SetCastNodeAbstrac(cnode, input_node, quant_cast_cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetCastNodeAbstrac failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->set_input(index, quant_cast_cnode);
|
||||
MS_LOG(INFO) << "InserForwardQuantCastNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
|
||||
<< " dst_type: " << dst_dtype;
|
||||
return RET_OK;
|
||||
}
|
||||
int InsertQuantNodeManager::InserBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
|
||||
TypeId cast_dtype, size_t index,
|
||||
const AnfNodePtr &output_node) {
|
||||
if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "Invalie cast dtype: " << cast_dtype;
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
CHECK_NULL_RETURN(output_node);
|
||||
auto ret = CheckDataType(output_node, cast_dtype);
|
||||
if (ret == RET_NO_CHANGE) {
|
||||
MS_LOG(DEBUG) << "input node not dtype: " << cast_dtype;
|
||||
return RET_OK;
|
||||
} else if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Check data type failed, output node name: " << output_node->fullname_with_scope();
|
||||
return ret;
|
||||
}
|
||||
auto manager = graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(graph, true);
|
||||
}
|
||||
CHECK_NULL_RETURN(manager);
|
||||
|
||||
// insert backward cast_node
|
||||
TypeId src_dtype = kNumberTypeInt8;
|
||||
TypeId dst_dtype = cast_dtype;
|
||||
std::vector<schema::QuantParamT> input_quant_params;
|
||||
std::vector<schema::QuantParamT> output_quant_params;
|
||||
|
||||
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
|
||||
if (curr_primitive_quant_param_holder->get_output_quant_params().empty()) {
|
||||
MS_LOG(ERROR) << "quant param is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
input_quant_params = curr_primitive_quant_param_holder->get_output_quant_params().front();
|
||||
std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
|
||||
// Int8toUint8
|
||||
if (dst_dtype == kNumberTypeUInt8) {
|
||||
for (auto &quant_param : output_quant_params) {
|
||||
quant_param.zeroPoint += kU8ZeroPointOffset;
|
||||
}
|
||||
}
|
||||
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
|
||||
std::vector<AnfNodePtr> op_inputs = {new_primitive, cnode->cast<AnfNodePtr>()};
|
||||
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
||||
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
|
||||
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
|
||||
"_post");
|
||||
if (SetCastNodeAbstrac(cnode, output_node, quant_cast_cnode) != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetCastNodeAbstrac failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
manager->SetEdge(output_node, index, quant_cast_cnode);
|
||||
MS_LOG(INFO) << "InserBackwardDeQuantCastNode cnode name: " << cnode->fullname_with_scope()
|
||||
<< " src dtype:" << src_dtype << " dst_type: " << dst_dtype;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -33,16 +33,14 @@ class InsertQuantNodeManager {
|
|||
|
||||
~InsertQuantNodeManager() = default;
|
||||
|
||||
int InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId src_dtype = kNumberTypeFloat32);
|
||||
int InsertQuantDtypeCastNode(const FuncGraphPtr &graph, TypeId cast_dtype = kNumberTypeFloat32);
|
||||
|
||||
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
|
||||
const std::set<std::string> &skip_quant_node);
|
||||
int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction,
|
||||
TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node);
|
||||
|
||||
private:
|
||||
ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
|
||||
const std::vector<schema::QuantParamT> &input_quant_params,
|
||||
const std::vector<schema::QuantParamT> &output_quant_params);
|
||||
|
||||
int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, bool is_graph_input);
|
||||
|
||||
bool CheckInited(const AnfNodePtr &input_node, size_t index) const;
|
||||
|
@ -55,6 +53,13 @@ class InsertQuantNodeManager {
|
|||
|
||||
int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
|
||||
|
||||
int SetCastNodeAbstrac(const CNodePtr &cnode, const AnfNodePtr &input_node, const CNodePtr &cast_cnode);
|
||||
|
||||
int InserForwardQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index);
|
||||
|
||||
int InserBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index,
|
||||
const AnfNodePtr &output_node);
|
||||
|
||||
private:
|
||||
TypeId dst_type_ = kNumberTypeInt8;
|
||||
bool symmetric_ = false;
|
||||
|
|
|
@ -15,75 +15,55 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/dtype_transform_pass.h"
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/quantizer/insert_quant_node_manager.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
// only enable for uint8
|
||||
int DTypeTransformPass::Transform() {
|
||||
CHECK_NULL_RETURN(func_graph_);
|
||||
// insert CastNode Uint8toInt8 & Int8toUint8
|
||||
quant::InsertQuantNodeManager insert_quant_node_manager;
|
||||
auto ret = insert_quant_node_manager.InsertQuantDtypeCastNode(func_graph_, kNumberTypeUInt8);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Insert Uint8toInt8 CastNode failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// update data type and zp
|
||||
auto cnodes = func_graph_->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
if (!CheckNeedDTypeTrans(cnode)) {
|
||||
MS_LOG(INFO) << "CheckNeedDTypeTrans invalid cnode, cnode name: " << cnode->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
TypeId cnode_dtype;
|
||||
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed, cnode type: " << cnode->type_name();
|
||||
auto status = DoNodeDTypeTrans(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoNodeDTypeTrans failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return status;
|
||||
}
|
||||
schema::QuantType curr_quant_type;
|
||||
if (GetQuantType(cnode, &curr_quant_type) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (cnode_dtype != kNumberTypeUInt8) {
|
||||
if (curr_quant_type != schema::QuantType_QUANT_ALL) {
|
||||
MS_LOG(ERROR) << "Invalid cnode quant type, cnode name: " << cnode->fullname_with_scope()
|
||||
<< " quant type: " << curr_quant_type;
|
||||
continue;
|
||||
}
|
||||
if (UpdateDataType(cnode, kNumberTypeInt8) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Update data type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
status = InsertForwardCastNode(cnode, curr_quant_type);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertForwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return status;
|
||||
}
|
||||
auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(curr_quant_param_holder);
|
||||
if (curr_quant_param_holder->get_output_quant_params().empty()) {
|
||||
MS_LOG(ERROR) << "output quant params empty.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0];
|
||||
for (auto &quant_param : out_quant_params) {
|
||||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
curr_quant_param_holder->set_output_quant_param(0, out_quant_params);
|
||||
for (size_t i = 1; i < cnode->size(); i++) {
|
||||
auto input_node = cnode->input(i);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (IsGraphInput(input_node)) {
|
||||
continue;
|
||||
} else if (input_node->isa<mindspore::CNode>()) {
|
||||
// updata input_quant_params
|
||||
if (curr_quant_param_holder->get_input_quant_params().size() < i) {
|
||||
MS_LOG(ERROR) << "quant params invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_quant_params = curr_quant_param_holder->get_input_quant_params()[i - 1];
|
||||
for (auto &quant_param : input_quant_params) {
|
||||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
curr_quant_param_holder->set_input_quant_param(i - 1, input_quant_params);
|
||||
} else if (input_node->isa<mindspore::Parameter>()) {
|
||||
ret = DoParameterNodeTrans(cnode, input_node->cast<ParameterPtr>(), i);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope();
|
||||
}
|
||||
// DetectionPostProcess op(Uint8toFp32, not need backward cast node)
|
||||
if (!CheckNodeInSet(cnode, kUint8toFP32Operator)) {
|
||||
status = InsertBackwardCastNode(cnode, curr_quant_type);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertBackwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -107,8 +87,11 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
|
|||
return ret;
|
||||
}
|
||||
auto quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
if (quant_param_holder->get_input_quant_params().size() < input_index) {
|
||||
MS_LOG(ERROR) << "Invalid quant params. input node name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto quant_params = quant_param_holder->get_input_quant_params().at(input_index - 1);
|
||||
MS_CHECK_FALSE_MSG(quant_params.empty(), RET_ERROR, "Quant params is empty.");
|
||||
for (auto &quant_param : quant_params) {
|
||||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
|
@ -121,18 +104,6 @@ int DTypeTransformPass::DoParameterNodeTrans(const CNodePtr &cnode, const Parame
|
|||
MS_LOG(ERROR) << input_node->fullname_with_scope() << " set new dtype failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto abstract_base = input_node->abstract();
|
||||
CHECK_NULL_RETURN(abstract_base);
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
||||
MS_LOG(ERROR) << "Abstract of node should be abstract tensor, input node name: "
|
||||
<< input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
CHECK_NULL_RETURN(abstract_tensor);
|
||||
CHECK_NULL_RETURN(abstract_tensor->element());
|
||||
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -151,4 +122,164 @@ int DTypeTransformPass::Uint8toInt8(uint8_t *data, int size) {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type) {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
auto quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(quant_param_holder);
|
||||
*quant_type = quant_param_holder->quant_type();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform CNode(dtype,uint8toint8,weigh data)
|
||||
* */
|
||||
int DTypeTransformPass::DoNodeDTypeTrans(const CNodePtr &cnode) {
|
||||
auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(curr_quant_param_holder);
|
||||
TypeId cnode_dtype = kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (cnode_dtype == kNumberTypeUInt8) {
|
||||
MS_LOG(INFO) << "cnode dtype kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
|
||||
if (UpdateDataType(cnode, kNumberTypeInt8) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Update data type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(0));
|
||||
auto primc = api::MakeShared<mindspore::ops::QuantDTypeCast>(primitive_c);
|
||||
primc->set_dst_t(kNumberTypeInt8);
|
||||
}
|
||||
// update output quant param zp
|
||||
if (curr_quant_param_holder->get_output_quant_params().empty()) {
|
||||
MS_LOG(ERROR) << "output quant params empty.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0];
|
||||
for (auto &quant_param : out_quant_params) {
|
||||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
curr_quant_param_holder->set_output_quant_param(0, out_quant_params);
|
||||
}
|
||||
|
||||
// DTypeCastNode, set quant type
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
curr_quant_param_holder->set_quant_type(schema::QuantType_QUANT_NONE);
|
||||
}
|
||||
|
||||
for (size_t index = 1; index < cnode->size(); index++) {
|
||||
auto input_node = cnode->input(index);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (IsGraphInput(input_node) || input_node->isa<mindspore::CNode>()) {
|
||||
// updata graph input quant params
|
||||
if (curr_quant_param_holder->get_input_quant_params().size() < index) {
|
||||
MS_LOG(WARNING) << "quant params invalid, input node name: " << input_node->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
auto input_quant_params = curr_quant_param_holder->get_input_quant_params()[index - 1];
|
||||
if (input_quant_params.empty() || !input_quant_params.front().inited) {
|
||||
MS_LOG(WARNING) << "input node not quantizied, input node name: " << input_node->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
for (auto &quant_param : input_quant_params) {
|
||||
quant_param.zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
curr_quant_param_holder->set_input_quant_param(index - 1, input_quant_params);
|
||||
} else if (input_node->isa<mindspore::Parameter>()) { // weight data
|
||||
auto ret = DoParameterNodeTrans(cnode, input_node->cast<ParameterPtr>(), index);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(WARNING) << "DoParameterNodeTrans failed, input node name: " << input_node->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
|
||||
// inputs
|
||||
quant::InsertQuantNodeManager insert_node_manager;
|
||||
for (size_t index = 1; index < cnode->size(); index++) {
|
||||
auto input_node = cnode->input(index);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
|
||||
MS_LOG(DEBUG) << "Invalid input node, not CNode and graph input.";
|
||||
continue;
|
||||
}
|
||||
schema::QuantType input_quant_type;
|
||||
if (GetQuantType(cnode, &input_quant_type) != RET_OK) {
|
||||
MS_LOG(WARNING) << "Get quant type failed, input node name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
schema::QuantType pre_quant_type = schema::QuantType_QUANT_NONE;
|
||||
if (input_node->isa<mindspore::CNode>()) {
|
||||
if (GetQuantType(input_node->cast<mindspore::CNodePtr>(), &pre_quant_type) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
if (pre_quant_type == schema::QuantType_QUANT_NONE && curr_quant_type == schema::QuantType_QUANT_ALL) {
|
||||
auto status = insert_node_manager.InserQuantCastNode(this->func_graph_, cnode, FORWARD, kNumberTypeUInt8, kQuant,
|
||||
index, nullptr);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InserQuantCastNode failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return status;
|
||||
}
|
||||
MS_LOG(INFO) << "InserQuantCastNode forward Uint8toInt8, cnode name: " << cnode->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int DTypeTransformPass::InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type) {
|
||||
// outputs
|
||||
auto manager = this->func_graph_->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(this->func_graph_, true);
|
||||
}
|
||||
CHECK_NULL_RETURN(manager);
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
quant::InsertQuantNodeManager insert_node_manager;
|
||||
for (auto &node_user : node_users) {
|
||||
auto output_cnode = node_user.first->cast<CNodePtr>();
|
||||
schema::QuantType post_quant_type;
|
||||
if (GetQuantType(output_cnode, &post_quant_type) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get quant type failed, cnode name: " << output_cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (curr_quant_type == schema::QuantType_QUANT_ALL && post_quant_type == schema::QuantType_QUANT_NONE) {
|
||||
auto status = insert_node_manager.InserQuantCastNode(this->func_graph_, cnode, BACKWARD, kNumberTypeUInt8,
|
||||
kDeQuant, node_user.second, node_user.first);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "InserQuantCastNode dequant failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return status;
|
||||
}
|
||||
MS_LOG(INFO) << "InserQuantCastNode backward Int8toUint8, cnode name: " << cnode->fullname_with_scope();
|
||||
}
|
||||
} // node_users
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool DTypeTransformPass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
|
||||
if (opt::IsSpecialType(cnode)) {
|
||||
return false;
|
||||
}
|
||||
if (IsGraphInDTypeCast(cnode) || IsGraphOutDTypeCast(func_graph_, cnode)) {
|
||||
return false;
|
||||
}
|
||||
TypeId cnode_dtype = kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
bool is_fp32_output =
|
||||
opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast) || CheckNodeInSet(cnode, kUint8toFP32Operator);
|
||||
if (cnode_dtype != kNumberTypeUInt8 && !is_fp32_output) {
|
||||
MS_LOG(DEBUG) << "dtype not kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -16,24 +16,43 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DTYPE_TRANSFORM_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "tools/converter/quantizer/quant_param_holder.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/converter/quantizer/quant_params.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
/**
|
||||
* Transform CNode(dtype,uint8toint8,weigh data)
|
||||
* Insert QuantCastNode
|
||||
* */
|
||||
class DTypeTransformPass {
|
||||
public:
|
||||
explicit DTypeTransformPass(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {}
|
||||
|
||||
~DTypeTransformPass() = default;
|
||||
|
||||
int Transform();
|
||||
|
||||
private:
|
||||
int DoParameterNodeTrans(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index);
|
||||
|
||||
int Uint8toInt8(uint8_t *data, int size);
|
||||
|
||||
int GetQuantType(const CNodePtr &cnode, schema::QuantType *quant_type);
|
||||
|
||||
int DoNodeDTypeTrans(const CNodePtr &cnode);
|
||||
|
||||
bool CheckNeedDTypeTrans(const CNodePtr &cnode);
|
||||
|
||||
int InsertForwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type);
|
||||
|
||||
int InsertBackwardCastNode(const CNodePtr &cnode, schema::QuantType curr_quant_type);
|
||||
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -0,0 +1,224 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/quantizer/quant_helper/graph_inout_transform_pass.h"
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/quantizer/insert_quant_node_manager.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "ops/quant_dtype_cast.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
int GraphInoutTransformPass::Transform() {
|
||||
auto ret = DoGraphInputDTypeTransform();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoGraphInputDTypeTransform failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = DoGraphOutputDTypeTransform();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoGraphOutputDTypeTransform failed.";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GraphInoutTransformPass::DoGraphInputDTypeTransform() {
|
||||
CHECK_NULL_RETURN(this->func_graph_);
|
||||
if (this->graph_input_dtype_ != TypeId::kNumberTypeFloat32 && this->graph_input_dtype_ != TypeId::kNumberTypeUInt8 &&
|
||||
this->graph_input_dtype_ != TypeId::kNumberTypeInt8 && this->graph_input_dtype_ != TypeId::kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Invalid graph input dtype: " << this->graph_input_dtype_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
// not specify inputDataType
|
||||
if (this->graph_input_dtype_ == TypeId::kTypeUnknown) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto cnodes = this->func_graph_->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
for (size_t index = 1; index < cnode->size(); index++) {
|
||||
auto input_node = cnode->input(index);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (!IsGraphInput(input_node)) {
|
||||
continue;
|
||||
}
|
||||
TypeId input_node_dtype = TypeId::kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(input_node, &input_node_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "GetDataTypeFromAnfNode failed, input node name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
// graph input dtype transform
|
||||
if (this->graph_input_dtype_ != input_node_dtype) {
|
||||
auto ret = InsertDTypeCastNode(cnode, index, this->graph_input_dtype_, input_node_dtype, kQuant);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertDTypeCastNode failed, cnode name: " << cnode->fullname_with_scope()
|
||||
<< " input index: " << index;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
} // for
|
||||
} // for
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GraphInoutTransformPass::DoGraphOutputDTypeTransform() {
|
||||
CHECK_NULL_RETURN(this->func_graph_);
|
||||
if (this->graph_output_dtype_ != TypeId::kNumberTypeFloat32 &&
|
||||
this->graph_output_dtype_ != TypeId::kNumberTypeUInt8 && this->graph_output_dtype_ != TypeId::kNumberTypeInt8 &&
|
||||
this->graph_output_dtype_ != TypeId::kTypeUnknown) {
|
||||
MS_LOG(ERROR) << "Invalid graph output dtype: " << this->graph_input_dtype_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->graph_output_dtype_ == TypeId::kTypeUnknown) {
|
||||
return RET_OK;
|
||||
}
|
||||
auto cnodes = this->func_graph_->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
if (!IsGraphOutput(cnode)) {
|
||||
continue;
|
||||
}
|
||||
for (size_t index = 1; index < cnode->size(); index++) {
|
||||
auto input_node = cnode->input(index);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
|
||||
if (opt::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) { // MakeTuple
|
||||
for (size_t input_index = 1; input_index < input_cnode->size(); input_index++) {
|
||||
auto ret = DoOutputTransform(input_cnode, input_index);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoOutputTransform failed, cnode name: " << input_node->fullname_with_scope()
|
||||
<< " input index: " << input_index;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
} else { // Return
|
||||
auto ret = DoOutputTransform(cnode, index);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoOutputTransform failed, cnode name: " << cnode->fullname_with_scope()
|
||||
<< " input index: " << index;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
} // for
|
||||
} // for
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int GraphInoutTransformPass::DoOutputTransform(const CNodePtr &cnode, size_t index) {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
auto input_node = cnode->input(index);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (!input_node->isa<mindspore::CNode>()) {
|
||||
return RET_OK;
|
||||
}
|
||||
TypeId input_node_dtype = TypeId::kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(input_node, &input_node_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "GetDataTypeFromAnfNode failed, input node name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
// graph output dtype transform
|
||||
if (this->graph_output_dtype_ != input_node_dtype) {
|
||||
auto ret = InsertDTypeCastNode(cnode, index, input_node_dtype, this->graph_output_dtype_, kDeQuant);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InsertDTypeCastNode failed, cnode name: " << cnode->fullname_with_scope()
|
||||
<< " input index: " << index;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
bool GraphInoutTransformPass::IsGraphOutput(const CNodePtr &cnode) {
|
||||
return (opt::CheckPrimitiveType(cnode, prim::kPrimReturn));
|
||||
}
|
||||
|
||||
int GraphInoutTransformPass::InsertDTypeCastNode(const CNodePtr &cnode, size_t input_index, TypeId src_dtype,
|
||||
TypeId dst_dtype, CastNodeType node_type) {
|
||||
std::vector<schema::QuantParamT> input_quant_params;
|
||||
std::vector<schema::QuantParamT> output_quant_params;
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
auto input_node = cnode->input(input_index);
|
||||
CHECK_NULL_RETURN(input_node);
|
||||
if (node_type == kQuant) {
|
||||
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
|
||||
if (curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) {
|
||||
MS_LOG(ERROR) << "Invalid quant params.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
input_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[input_index - 1];
|
||||
} else if (node_type == kDeQuant) {
|
||||
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
|
||||
auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode);
|
||||
CHECK_NULL_RETURN(input_primitive_quant_param_holder);
|
||||
if (input_primitive_quant_param_holder->get_output_quant_params().empty()) {
|
||||
MS_LOG(ERROR) << "Invalid quant params.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0];
|
||||
}
|
||||
std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
|
||||
// update zeroPoint(uint8toint8, int8touint8)
|
||||
if (src_dtype == TypeId::kNumberTypeUInt8 && dst_dtype == TypeId::kNumberTypeInt8) {
|
||||
output_quant_params.front().zeroPoint -= kU8ZeroPointOffset;
|
||||
} else if (src_dtype == TypeId::kNumberTypeInt8 && dst_dtype == TypeId::kNumberTypeUInt8) {
|
||||
input_quant_params.front().zeroPoint -= kU8ZeroPointOffset;
|
||||
}
|
||||
|
||||
auto primitive = quant::NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
|
||||
std::vector<AnfNodePtr> op_inputs = {primitive, input_node};
|
||||
auto quant_cast_cnode = this->func_graph_->NewCNode(op_inputs);
|
||||
CHECK_NULL_RETURN(quant_cast_cnode);
|
||||
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" +
|
||||
std::to_string(input_index));
|
||||
|
||||
// set quant_cast_cnode quant type
|
||||
AbstractBasePtr abstract;
|
||||
if (cnode->abstract() != nullptr) {
|
||||
abstract = cnode->abstract()->Clone();
|
||||
} else if (input_node->abstract() != nullptr) {
|
||||
abstract = input_node->abstract()->Clone();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope()
|
||||
<< " input node: " << input_node->fullname_with_scope();
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
quant_cast_cnode->set_abstract(abstract);
|
||||
auto quant_cast_cnode_param_holder = GetCNodeQuantHolder(quant_cast_cnode);
|
||||
CHECK_NULL_RETURN(quant_cast_cnode_param_holder);
|
||||
quant_cast_cnode_param_holder->set_quant_type(schema::QuantType_QUANT_NONE);
|
||||
|
||||
// update dtype: input_cnode, quant_cast_cnode
|
||||
auto ret = quant::UpdateDataType(input_node, src_dtype);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateDataType failed, input node name: " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
cnode->set_input(input_index, quant_cast_cnode);
|
||||
|
||||
ret = quant::UpdateDataType(std::dynamic_pointer_cast<mindspore::AnfNode>(quant_cast_cnode), dst_dtype);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_
|
||||
|
||||
#include <memory>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "tools/converter/quantizer/quant_param_holder.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/converter/quantizer/quant_params.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
class GraphInoutTransformPass {
|
||||
public:
|
||||
explicit GraphInoutTransformPass(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> param)
|
||||
: func_graph_(func_graph), param_(param) {
|
||||
CHECK_NULL_RETURN_VOID(this->param_);
|
||||
this->graph_input_dtype_ = static_cast<TypeId>(this->param_->input_data_type);
|
||||
this->graph_output_dtype_ = static_cast<TypeId>(this->param_->output_data_type);
|
||||
}
|
||||
|
||||
~GraphInoutTransformPass() = default;
|
||||
|
||||
int Transform();
|
||||
|
||||
private:
|
||||
int DoGraphInputDTypeTransform();
|
||||
|
||||
int DoGraphOutputDTypeTransform();
|
||||
|
||||
int InsertDTypeCastNode(const CNodePtr &cnode, size_t input_index, TypeId src_dtype, TypeId dst_dtype,
|
||||
CastNodeType node_type);
|
||||
|
||||
int DoOutputTransform(const CNodePtr &cnode, size_t index);
|
||||
|
||||
bool IsGraphOutput(const CNodePtr &cnode);
|
||||
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
std::shared_ptr<ConverterPara> param_ = nullptr;
|
||||
TypeId graph_input_dtype_ = TypeId::kTypeUnknown;
|
||||
TypeId graph_output_dtype_ = TypeId::kTypeUnknown;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_GRAPH_INOUT_TRANSFORM_PASS_H_
|
|
@ -118,6 +118,9 @@ int PropagateQuantParamPass::ForwardPropagate(const std::list<CNodePtr> &nodes)
|
|||
if (IsGraphInput(cnode) || opt::IsSpecialType(cnode) || opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
|
||||
continue;
|
||||
}
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
|
||||
continue;
|
||||
}
|
||||
// Infer quant param with forward (output->input).
|
||||
auto curr_quant_holder = GetCNodeQuantHolder(cnode);
|
||||
if (curr_quant_holder == nullptr) {
|
||||
|
@ -182,6 +185,9 @@ int PropagateQuantParamPass::BackwardPropagate(const std::list<CNodePtr> &nodes)
|
|||
if (IsGraphInput(cnode) || opt::IsSpecialType(cnode) || opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
|
||||
continue;
|
||||
}
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
|
||||
continue;
|
||||
}
|
||||
// Infer quant param with forward (output<-input).
|
||||
auto curr_quant_holder = GetCNodeQuantHolder(cnode);
|
||||
if (curr_quant_holder == nullptr) {
|
||||
|
@ -214,7 +220,7 @@ int PropagateQuantParamPass::BackwardPropagate(const std::list<CNodePtr> &nodes)
|
|||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Support for multi output.";
|
||||
MS_LOG(INFO) << "Support for multi output.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,6 +41,10 @@ int QuantNodePass::DoWeightQuant(const CNodePtr &cnode) {
|
|||
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
|
||||
continue;
|
||||
}
|
||||
if (!CanTensorQuantized(cnode, input)) {
|
||||
MS_LOG(INFO) << input->fullname_with_scope() << " is not quantized.";
|
||||
continue;
|
||||
}
|
||||
int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(weight->shape()));
|
||||
MS_CHECK_GT(static_cast<int>(quant_param_holder->get_input_quant_params().size()), static_cast<int>(idx - 1),
|
||||
RET_ERROR);
|
||||
|
@ -265,6 +269,40 @@ int QuantNodePass::DoFullQuant(const CNodePtr &cnode) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool QuantNodePass::CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node) {
|
||||
if (input_node == nullptr) {
|
||||
MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
|
||||
return false;
|
||||
}
|
||||
ParameterPtr param_node = nullptr;
|
||||
if (input_node->isa<Parameter>()) {
|
||||
param_node = input_node->cast<ParameterPtr>();
|
||||
}
|
||||
if (param_node == nullptr) {
|
||||
MS_LOG(INFO) << "CanTensorQuantized invalid param_node!";
|
||||
return false;
|
||||
}
|
||||
if (!param_node->has_default()) {
|
||||
MS_LOG(INFO) << "param_node don't has default.";
|
||||
return false;
|
||||
}
|
||||
auto abstract_base = param_node->abstract();
|
||||
if (abstract_base == nullptr) {
|
||||
MS_LOG(INFO) << "abstract is nullptr";
|
||||
return false;
|
||||
}
|
||||
if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
|
||||
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << param_node->name();
|
||||
return false;
|
||||
}
|
||||
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
|
||||
MS_ASSERT(weight_shape != nullptr);
|
||||
if (weight_shape.size() < DIMENSION_2D) { // do not quant single dim tensors
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int QuantNodePass::Quant() {
|
||||
CHECK_NULL_RETURN(func_graph_);
|
||||
auto cnodes = func_graph_->GetOrderedCnodes();
|
||||
|
|
|
@ -47,6 +47,7 @@ class QuantNodePass {
|
|||
int DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index);
|
||||
int DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &input_node, size_t input_index);
|
||||
int IsSupportWeightQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index);
|
||||
bool CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node);
|
||||
|
||||
private:
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
|
|
|
@ -26,9 +26,6 @@
|
|||
#include "tools/common/node_util.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
namespace {
|
||||
static const std::set<PrimitivePtr> fp32_output_operator = {prim::kPrimDetectionPostProcess};
|
||||
}
|
||||
bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
if (opt::IsSpecialType(cnode)) {
|
||||
|
@ -56,15 +53,28 @@ bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
|
|||
if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) {
|
||||
return quant_holder->quant_type() == schema::QuantType_QUANT_ALL;
|
||||
}
|
||||
// if DTypeCastNode is graph input or output, set QuantType_QUANT_NONE
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
return (!IsGraphInDTypeCast(cnode) && !IsGraphOutDTypeCast(func_graph_, cnode));
|
||||
}
|
||||
// Check input quant params, quantization parameters exist for all activations.
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto input = cnode->input(i);
|
||||
TypeId input_node_dtype = kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(input, &input_node_dtype) != RET_OK) {
|
||||
MS_LOG(INFO) << "Get data type failed, input node name: " << input->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
// Only specified dtype need check input quant params
|
||||
if (kFullQuantDType.find(input_node_dtype) == kFullQuantDType.end()) {
|
||||
continue;
|
||||
}
|
||||
if (input->isa<CNode>() && !quant_holder->CheckInit(i - kPrimOffset, true)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Check output quant params.
|
||||
if (CheckNodeInSet(cnode, fp32_output_operator) && !quant_holder->IsOutputQuantParamsInited()) {
|
||||
if (!CheckNodeInSet(cnode, kUint8toFP32Operator) && !quant_holder->IsOutputQuantParamsInited()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -118,7 +128,7 @@ int QuantTypeDeterminer::Determine() {
|
|||
MS_LOG(INFO) << cnode->fullname_with_scope() << " quant holder is nullptr.";
|
||||
continue;
|
||||
}
|
||||
if (!quant_holder->IsInputQuantParamsInited() && !quant_holder->IsOutputQuantParamsInited()) { // Check FP32.
|
||||
if (!quant_holder->IsInputExistInited() && !quant_holder->IsOutputExistInited()) { // Check FP32.
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
class QuantTypeDeterminer {
|
||||
|
|
|
@ -63,6 +63,26 @@ bool QuantParamHolder::IsOutputQuantParamsInited() {
|
|||
return is_quant_params_inited;
|
||||
}
|
||||
|
||||
bool QuantParamHolder::IsInputExistInited() {
|
||||
if (this->input_quant_params_.empty()) {
|
||||
return false;
|
||||
}
|
||||
bool is_exist_param_inited =
|
||||
std::any_of(this->input_quant_params_.begin(), this->input_quant_params_.end(),
|
||||
[](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
|
||||
return is_exist_param_inited;
|
||||
}
|
||||
|
||||
bool QuantParamHolder::IsOutputExistInited() {
|
||||
if (this->output_quant_params_.empty()) {
|
||||
return false;
|
||||
}
|
||||
bool is_exist_param_inited =
|
||||
std::any_of(this->output_quant_params_.begin(), this->output_quant_params_.end(),
|
||||
[](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
|
||||
return is_exist_param_inited;
|
||||
}
|
||||
|
||||
void QuantParamHolder::ClearQuantParams() {
|
||||
quant_type_ = schema::QuantType_QUANT_NONE;
|
||||
input_quant_params_.clear();
|
||||
|
|
|
@ -110,6 +110,10 @@ class QuantParamHolder : public Value {
|
|||
|
||||
bool IsOutputQuantParamsInited();
|
||||
|
||||
bool IsInputExistInited();
|
||||
|
||||
bool IsOutputExistInited();
|
||||
|
||||
void ClearQuantParams();
|
||||
|
||||
bool CheckInit(size_t index, bool is_input);
|
||||
|
|
|
@ -44,11 +44,15 @@ constexpr int kPrimIndex = 0;
|
|||
constexpr int kPrimOffset = 1;
|
||||
constexpr int kU8ZeroPointOffset = 128;
|
||||
constexpr int kQuantRange = 127;
|
||||
constexpr int kInt8LeftRange = -128;
|
||||
constexpr int kInt8RightRange = 127;
|
||||
constexpr int kMinIterations = 40;
|
||||
|
||||
const std::set<PrimitivePtr> kHasBiasOperator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
||||
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
|
||||
prim::kPrimLayerNormFusion};
|
||||
const std::set<PrimitivePtr> kUint8toFP32Operator = {prim::kPrimDetectionPostProcess};
|
||||
const std::set<TypeId> kFullQuantDType = {kNumberTypeInt8, kNumberTypeUInt8, kNumberTypeFloat32};
|
||||
|
||||
enum ActivationQuantizedMethod {
|
||||
MAX_MIN = 0,
|
||||
|
@ -68,6 +72,17 @@ enum DebugMode {
|
|||
DETAIL,
|
||||
};
|
||||
|
||||
// Role of an inserted cast node; it selects where the cast's quant params are read from.
enum CastNodeType {
  kNone = 0,     // no source selected
  kQuant = 1,    // params come from the consumer's input slot
  kDeQuant = 2,  // params come from the producer's output
};
|
||||
|
||||
// Side of a node on which a cast node is inserted.
enum InsertDirection {
  FORWARD = 0,   // presumably the input side of the node - TODO confirm
  BACKWARD = 1,  // between the node and its consumers
};
|
||||
|
||||
struct CommonQuantParam {
|
||||
schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
|
||||
int bit_num = 8;
|
||||
|
|
|
@ -139,6 +139,51 @@ int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
/**
 * Build a value node holding a QuantDTypeCast primitive with its quant params attached.
 *
 * The primitive is initialized with `src_type`/`dst_type`, and a QuantParamHolder marked
 * QuantType_QUANT_ALL is attached under the "quant_params" attribute with the given params stored
 * at input index 0 and output index 0.
 *
 * @return the new value node, or nullptr when one of the intermediate objects is null.
 */
ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
                                   const std::vector<schema::QuantParamT> &input_quant_params,
                                   const std::vector<schema::QuantParamT> &output_quant_params) {
  auto cast_op = std::make_shared<ops::QuantDTypeCast>();
  MS_CHECK_TRUE_MSG(cast_op != nullptr, nullptr, "prim_c is nullptr.");
  cast_op->Init(src_type, dst_type);

  const auto input_param_count = input_quant_params.size();
  const auto output_param_count = output_quant_params.size();
  auto holder = std::make_shared<QuantParamHolder>(input_param_count, output_param_count);
  MS_CHECK_TRUE_MSG(holder != nullptr, nullptr, "quant_params_holder is nullptr.");
  holder->set_quant_type(schema::QuantType_QUANT_ALL);
  holder->set_input_quant_param(0, input_quant_params);
  holder->set_output_quant_param(0, output_quant_params);

  auto cast_prim = cast_op->GetPrim();
  MS_CHECK_TRUE_MSG(cast_prim != nullptr, nullptr, "prim is nullptr");
  cast_prim->AddAttr("quant_params", holder);
  return NewValueNode(cast_prim);
}
|
||||
|
||||
bool IsGraphInDTypeCast(const CNodePtr &cnode) {
|
||||
if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
return false;
|
||||
}
|
||||
auto input_node = cnode->input(1);
|
||||
MS_CHECK_FALSE(input_node == nullptr, false);
|
||||
return IsGraphInput(input_node);
|
||||
}
|
||||
|
||||
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
|
||||
return false;
|
||||
}
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(func_graph, true);
|
||||
}
|
||||
CHECK_NULL_RETURN(manager);
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
auto output_cnode = node_user.first->cast<CNodePtr>();
|
||||
if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TensorQuantParamsInited(const schema::TensorT &tensor) {
|
||||
if (tensor.quantParams.empty()) {
|
||||
return false;
|
||||
|
|
|
@ -43,8 +43,13 @@
|
|||
#include "tools/converter/quantizer/mixed_bit_weight_quantization.h"
|
||||
#include "tools/common/string_util.h"
|
||||
#include "ops/core_ops.h"
|
||||
#include "ops/quant_dtype_cast.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
||||
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
|
||||
prim::kPrimLayerNormFusion};
|
||||
|
||||
QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive);
|
||||
|
||||
QuantParamHolderPtr GetCNodeQuantHolder(const CNodePtr &cnode);
|
||||
|
@ -71,6 +76,13 @@ void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_f
|
|||
|
||||
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type);
|
||||
|
||||
ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
|
||||
const std::vector<schema::QuantParamT> &input_quant_params,
|
||||
const std::vector<schema::QuantParamT> &output_quant_params);
|
||||
bool IsGraphInDTypeCast(const CNodePtr &cnode);
|
||||
|
||||
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
template <typename T>
|
||||
int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<mindspore::QuantParam> quant_params,
|
||||
std::vector<T> *dequant_data, int preferred_dim = 0) {
|
||||
|
|
Loading…
Reference in New Issue