delete unused code

This commit is contained in:
yeyunpeng2020 2023-01-04 11:07:46 +08:00
parent ec992b0641
commit c1e641d4f6
41 changed files with 96 additions and 1342 deletions

View File

@ -140,7 +140,7 @@ int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vec
template <typename T> template <typename T>
int DoPerChannelQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params, int DoPerChannelQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
const int &quant_max, const int &quant_min, const size_t &bit_num, std::vector<T> *quant_datas, int quant_max, int quant_min, size_t bit_num, std::vector<T> *quant_datas,
const std::vector<int> &dims, int preferred_dim, bool cal_gain = true, bool symmetric = false, const std::vector<int> &dims, int preferred_dim, bool cal_gain = true, bool symmetric = false,
bool narrow_range = false) { bool narrow_range = false) {
if (raw_datas == nullptr || quant_params == nullptr || quant_datas == nullptr) { if (raw_datas == nullptr || quant_params == nullptr || quant_datas == nullptr) {

View File

@ -26,7 +26,6 @@
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h" #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h" #include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h" #include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h" #include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"

View File

@ -1,72 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define USE_DEPRECATED_API
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include <vector>
#include <memory>
#include "src/common/utils.h"
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
#include "tools/common/node_util.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
if (graph == nullptr) {
MS_LOG(ERROR) << "graph is null";
return RET_NULL_PTR;
}
// forward infer nodes' quant params
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
if (node == nullptr) {
MS_LOG(ERROR) << "node is null";
return RET_NULL_PTR;
}
auto quant_helper = QuantHelperRegister::GetInstance()->GetQuantHelper(node->primitive->value.type);
MS_CHECK_TRUE_MSG(quant_helper != nullptr, RET_ERROR, "Find QuantHelper return nullptr");
auto ret = quant_helper->NodeQuantPreprocess(graph, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "Node quant preprocess failed.";
return ret;
}
}
// backward infer nodes' quant params
for (auto iter = graph->nodes.rbegin(); iter != graph->nodes.rend(); iter++) {
auto &node = *iter;
if (node == nullptr) {
MS_LOG(ERROR) << "node is null";
return RET_NULL_PTR;
}
if (!node->primitive) {
continue;
}
auto quant_helper = QuantHelperRegister::GetInstance()->GetQuantHelper(node->primitive->value.type);
MS_CHECK_TRUE_MSG(quant_helper != nullptr, RET_ERROR, "Find QuantHelper return nullptr");
auto ret = quant_helper->NodeQuantPreprocess(graph, node.get());
if (ret != RET_OK) {
MS_LOG(ERROR) << "Node quant preprocess failed.";
return ret;
}
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,35 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_INFER_QUANT_PARAM_PASS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_INFER_QUANT_PARAM_PASS_H_
#include <memory>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
namespace mindspore {
namespace lite {
class InferQuantParamPass : public GraphPass {
public:
InferQuantParamPass() = default;
~InferQuantParamPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_INFER_QUANT_PARAM_PASS_H_

View File

@ -548,7 +548,7 @@ int BiasCorrectionStrategy::DoCNodeBiasCorrection(const FuncGraphPtr &quant_func
return RET_ERROR; return RET_ERROR;
} }
} }
} else if (cnode->size() == kHasBiasTensorSize - 1) { } else if (cnode->size() == kHasBiasTensorSize - kPrimOffset) {
MS_LOG(INFO) << op_name << " add bias input"; MS_LOG(INFO) << op_name << " add bias input";
// need to add bias input // need to add bias input
auto parameter = quant_func_graph->add_parameter(); auto parameter = quant_func_graph->add_parameter();

View File

@ -176,7 +176,7 @@ int ClusterQuantization::KMeansQuantization(const CNodePtr &cnode, const std::ve
auto input = cnode->input(idx); auto input = cnode->input(idx);
ParameterPtr parameter; ParameterPtr parameter;
tensor::TensorPtr tensor_info; tensor::TensorPtr tensor_info;
GetLiteParameter(input, &parameter, &tensor_info); GetParameterAndTensor(input, &parameter, &tensor_info);
if (parameter == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32 || if (parameter == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32 ||
tensor_info->compression_type() != mindspore::kNoCompression) { tensor_info->compression_type() != mindspore::kNoCompression) {
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight"; MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";

View File

@ -128,7 +128,7 @@ int FixedBitWeightQuantization::QuantBias(const ParameterPtr &bias, const Primit
auto ret = auto ret =
UpdateTensorDataAndSize(bias, bias_param, quant_datas.data(), shape_size * sizeof(int32_t), kNumberTypeInt32); UpdateTensorDataAndSize(bias, bias_param, quant_datas.data(), shape_size * sizeof(int32_t), kNumberTypeInt32);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << bias->fullname_with_scope() << " update tensor data adn size failed."; MS_LOG(ERROR) << bias->fullname_with_scope() << " update tensor data and size failed.";
return RET_ERROR; return RET_ERROR;
} }
return RET_OK; return RET_OK;

View File

@ -96,14 +96,15 @@ class FixedBitWeightQuantization {
weight_quant_type = FIXED_BIT_PER_LAYER; weight_quant_type = FIXED_BIT_PER_LAYER;
} }
} }
if (weight->data_type_c() != kNumberTypeFloat32) {
MS_LOG(ERROR) << "data type is not Float32.";
return RET_ERROR;
}
std::vector<schema::QuantParamT> quant_params; std::vector<schema::QuantParamT> quant_params;
int ret = RET_OK; int ret = RET_OK;
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) { if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
bool cal_gain = false; bool cal_gain = (quant_type == QUANT_WEIGHT) ? true : false;
if (quant_type == QUANT_WEIGHT) {
cal_gain = true;
}
ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max, ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
quant_min, bit_num, quant_data, ConvertShapeVectorToInt32(dims), preferred_dim, quant_min, bit_num, quant_data, ConvertShapeVectorToInt32(dims), preferred_dim,
cal_gain, symmetric, narrow_range); cal_gain, symmetric, narrow_range);

View File

@ -33,7 +33,6 @@
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
namespace { namespace {
constexpr size_t kMinSize3 = 3; constexpr size_t kMinSize3 = 3;
constexpr size_t kPrimitiveCOffset = 1;
constexpr size_t kTableExtend = 3; constexpr size_t kTableExtend = 3;
constexpr size_t kAlignOffset = 7; constexpr size_t kAlignOffset = 7;
constexpr size_t kInt32Mask = 31; constexpr size_t kInt32Mask = 31;
@ -112,16 +111,16 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3."; MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
return RET_ERROR; return RET_ERROR;
} }
auto input = cnode->input(kInputIndex + kPrimitiveCOffset); auto input = cnode->input(kInputIndex + kPrimOffset);
if (input->isa<mindspore::CNode>() || IsGraphInput(input)) { if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset); auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimOffset);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Insert dynamic quant with index failed."; MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
} }
} }
auto weight = cnode->input(kWeightIndex + kPrimitiveCOffset); auto weight = cnode->input(kWeightIndex + kPrimOffset);
if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) { if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset); auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimOffset);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Insert dynamic quant with index failed."; MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
} }
@ -130,12 +129,9 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
} }
int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) { int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR); CHECK_NULL_RETURN(cnode);
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
if (primitive == nullptr) { CHECK_NULL_RETURN(primitive);
MS_LOG(ERROR) << "primitive is nullptr";
return RET_ERROR;
}
auto quant_param_holder = GetCNodeQuantHolder(primitive); auto quant_param_holder = GetCNodeQuantHolder(primitive);
quant_param_holder->set_quant_type(quant::QUANT_DYNAMIC); quant_param_holder->set_quant_type(quant::QUANT_DYNAMIC);
return RET_OK; return RET_OK;
@ -144,7 +140,7 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph, int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
const std::set<PrimitivePtr> &support_dynamic_quant_ops, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
const std::set<std::string> &skip_quant_node) { const std::set<std::string> &skip_quant_node) {
MS_ASSERT(graph != nullptr); CHECK_NULL_RETURN(graph);
auto cnodes = graph->GetOrderedCnodes(); auto cnodes = graph->GetOrderedCnodes();
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto op_name = cnode->fullname_with_scope(); auto op_name = cnode->fullname_with_scope();
@ -184,7 +180,7 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
return RET_OK; return RET_OK;
} }
int InsertQuantNodeManager::InsertFP32DtypeCastNode(const FuncGraphPtr &graph) { int InsertQuantNodeManager::InsertDequantNode(const FuncGraphPtr &graph) {
CHECK_NULL_RETURN(graph); CHECK_NULL_RETURN(graph);
auto cnodes = graph->GetOrderedCnodes(); auto cnodes = graph->GetOrderedCnodes();
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
@ -219,6 +215,8 @@ int InsertQuantNodeManager::InserQuantCastNode(const FuncGraphPtr &graph, const
InsertDirection insert_direction, TypeId cast_dtype, InsertDirection insert_direction, TypeId cast_dtype,
CastNodeType cast_node_type, size_t index, CastNodeType cast_node_type, size_t index,
const AnfNodePtr &output_node) { const AnfNodePtr &output_node) {
CHECK_NULL_RETURN(graph);
CHECK_NULL_RETURN(cnode);
if (insert_direction == FORWARD) { if (insert_direction == FORWARD) {
return InsertForwardQuantCastNode(graph, cnode, cast_dtype, index, cast_node_type); return InsertForwardQuantCastNode(graph, cnode, cast_dtype, index, cast_node_type);
} else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) { } else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
@ -287,7 +285,7 @@ int InsertQuantNodeManager::InsertForwardQuantCastNode(const FuncGraphPtr &graph
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params); ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node}; std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
auto quant_cast_cnode = graph->NewCNode(op_inputs); auto quant_cast_cnode = graph->NewCNode(op_inputs);
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr."); CHECK_NULL_RETURN(quant_cast_cnode);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) + quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
"_pre"); "_pre");
// set abstract // set abstract
@ -305,7 +303,7 @@ int InsertQuantNodeManager::InsertForwardQuantCastNode(const FuncGraphPtr &graph
if (manager == nullptr) { if (manager == nullptr) {
manager = Manage(graph, true); manager = Manage(graph, true);
} }
MS_CHECK_TRUE_RET(manager != nullptr, RET_NULL_PTR); CHECK_NULL_RETURN(manager);
manager->SetEdge(cnode, index, quant_cast_cnode); manager->SetEdge(cnode, index, quant_cast_cnode);
MS_LOG(INFO) << "InsertForwardQuantCastNode cnode name: " << cnode->fullname_with_scope() MS_LOG(INFO) << "InsertForwardQuantCastNode cnode name: " << cnode->fullname_with_scope()
<< " src dtype:" << src_dtype << " dst_type: " << dst_dtype; << " src dtype:" << src_dtype << " dst_type: " << dst_dtype;
@ -464,8 +462,9 @@ int InsertQuantNodeManager::InsertBackwardCastNode(const FuncGraphPtr &graph, co
} // node_users } // node_users
return RET_OK; return RET_OK;
} }
int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int InsertQuantNodeManager::InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
size_t input_index, TypeId src_dtype, TypeId dst_dtype, int axis) { size_t input_index, TypeId src_dtype, TypeId dst_dtype,
int axis) {
auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex)); auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
if (primitive == nullptr) { if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope(); MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
@ -688,9 +687,8 @@ ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(int src_type, int dst
return NewValueNode(prim); return NewValueNode(prim);
} }
ValueNodePtr InsertQuantNodeManager::NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk, ValueNodePtr InsertQuantNodeManager::NewFSEDecodePrimitive(int dst_type, uint64_t curr_chunk, int64_t curr_chunk_index,
const int64_t curr_chunk_index, const int64_t curr_bit_count, int64_t curr_bit_count, int64_t table_log) {
const int64_t table_log) {
auto prim_c = std::make_shared<ops::FSEDecode>(); auto prim_c = std::make_shared<ops::FSEDecode>();
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr."); MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log); prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log);

View File

@ -36,7 +36,7 @@ class InsertQuantNodeManager {
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops, int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
const std::set<std::string> &skip_quant_node); const std::set<std::string> &skip_quant_node);
int InsertFP32DtypeCastNode(const FuncGraphPtr &graph); int InsertDequantNode(const FuncGraphPtr &graph);
int InsertForwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, int InsertForwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
quant::QuantType curr_quant_type); quant::QuantType curr_quant_type);
@ -47,8 +47,9 @@ class InsertQuantNodeManager {
int InsertBackwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, int InsertBackwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
quant::QuantType curr_quant_type); quant::QuantType curr_quant_type);
int InsertWeightQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId src_dtype, int InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
TypeId dst_dtype, int axis); TypeId src_dtype, TypeId dst_dtype, int axis);
int InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId dst_dtype); int InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId dst_dtype);
private: private:
@ -67,8 +68,10 @@ class InsertQuantNodeManager {
int InsertBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index, int InsertBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index,
const AnfNodePtr &output_node); const AnfNodePtr &output_node);
int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction, int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction,
TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node); TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node);
int CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, std::vector<AnfNodePtr> *op_inputs, int CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, std::vector<AnfNodePtr> *op_inputs,
TypeId dst_dtype); TypeId dst_dtype);
@ -77,8 +80,8 @@ class InsertQuantNodeManager {
const std::vector<schema::QuantParamT> &output_quant_params, int axis = 0, const std::vector<schema::QuantParamT> &output_quant_params, int axis = 0,
bool set_quant_flag = true); bool set_quant_flag = true);
ValueNodePtr NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk, const int64_t curr_chunk_index, ValueNodePtr NewFSEDecodePrimitive(int dst_type, uint64_t curr_chunk, int64_t curr_chunk_index,
const int64_t curr_bit_count, const int64_t table_log); int64_t curr_bit_count, int64_t table_log);
private: private:
TypeId dst_type_ = kNumberTypeInt8; TypeId dst_type_ = kNumberTypeInt8;

View File

@ -1,67 +0,0 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/attention_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/log_adapter.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
const size_t kWeightQueryIndex = 4;
const size_t kWeightKeyIndex = 5;
const size_t kWeightValueIndex = 6;
const size_t kWeightOutputIndex = 10;
bool AttentionQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
mindspore::schema::CNodeT *node) {
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
auto input_index = node->inputIndex;
MS_CHECK_FALSE_MSG(input_index.empty(), false, "inputIndex is empty.");
MS_CHECK_TRUE_MSG(input_index.size() > kInputIndex, false, "invalid access.");
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kInputIndex), false, "invalid access.");
auto &input_tensor = graph.allTensors.at(input_index.at(kInputIndex));
MS_CHECK_TRUE_MSG(input_index.size() > kWeightQueryIndex, false, "invalid access.");
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightQueryIndex), false, "invalid access.");
auto &weight_query_tensor = graph.allTensors.at(input_index.at(kWeightQueryIndex));
MS_CHECK_TRUE_MSG(input_index.size() > kWeightKeyIndex, false, "invalid access.");
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightKeyIndex), false, "invalid access.");
auto &weight_key_tensor = graph.allTensors.at(input_index.at(kWeightKeyIndex));
MS_CHECK_TRUE_MSG(input_index.size() > kWeightValueIndex, false, "invalid access.");
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightValueIndex), false, "invalid access.");
auto &weight_value_tensor = graph.allTensors.at(input_index.at(kWeightValueIndex));
MS_CHECK_TRUE_MSG(input_index.size() > kWeightOutputIndex, false, "invalid access.");
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightOutputIndex), false, "invalid access.");
auto &weight_output_tensor = graph.allTensors.at(input_index.at(kWeightOutputIndex));
MS_CHECK_TRUE_RET(input_tensor != nullptr, false);
MS_CHECK_TRUE_RET(weight_query_tensor != nullptr, false);
MS_CHECK_TRUE_RET(weight_key_tensor != nullptr, false);
MS_CHECK_TRUE_RET(weight_value_tensor != nullptr, false);
MS_CHECK_TRUE_RET(weight_output_tensor != nullptr, false);
if (!TensorQuantParamsInited(*input_tensor) && TensorQuantParamsInited(*weight_query_tensor) &&
TensorQuantParamsInited(*weight_key_tensor) && TensorQuantParamsInited(*weight_value_tensor) &&
TensorQuantParamsInited(*weight_output_tensor)) {
node->quantType = schema::QuantType_QUANT_WEIGHT;
return true;
}
return false;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ATTENTION_QUANT_TYPE_DETERMINER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ATTENTION_QUANT_TYPE_DETERMINER_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class AttentionQuantTypeDeterminer : public QuantTypeDeterminer {
public:
bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ATTENTION_QUANT_TYPE_DETERMINER_H_

View File

@ -1,37 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define USE_DEPRECATED_API
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
static constexpr size_t kBiasAddSize = 2;
int BiasAddQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
const mindspore::schema::CNodeT &node) {
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
if (node.inputIndex.size() == kBiasAddSize) {
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAddSize - 1));
MS_CHECK_TRUE_RET(bias_tensor != nullptr, RET_NULL_PTR);
for (auto &quantParam : bias_tensor->quantParams) {
quantParam->dstDtype = TypeId::kNumberTypeInt32;
}
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class BiasAddQuantParamPropogator : public QuantParamPropogator {
public:
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H_

View File

@ -1,79 +0,0 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
#include <utility>
#include <memory>
#include "tools/common/tensor_util.h"
#include "src//common/log_util.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
int CarryDataQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "Graph is nullptr.");
UpdateQuantParamsNum(*graph, node);
MS_CHECK_FALSE_MSG(graph->allTensors.empty(), RET_ERROR, "Tensors is empty.");
// refresh in_tensor quant_params by out_tensor quant_params
if (input_inited_quant_params_ < 1) {
MS_CHECK_FALSE_MSG(node.outputIndex.empty(), RET_ERROR, "OutputIndex is empty.");
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.outputIndex.at(0), RET_ERROR);
auto &out_tensor = graph->allTensors.at(node.outputIndex.at(0));
auto out_quant_param = GetTensorQuantParam(out_tensor);
if (out_quant_param == nullptr || !out_quant_param->inited) {
MS_LOG(DEBUG) << node.name << " dont need to pass quant param.";
return RET_NO_CHANGE;
}
MS_CHECK_FALSE_MSG(node.inputIndex.empty(), RET_ERROR, "inputIndex is empty.");
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(0), RET_ERROR);
auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
MS_CHECK_TRUE_RET(in_tensor != nullptr, RET_NULL_PTR);
auto in_quant_param = GetTensorQuantParam(in_tensor);
if (in_quant_param != nullptr && !in_quant_param->inited) {
MS_CHECK_FALSE_MSG(in_tensor->quantParams.empty(), RET_ERROR, "in_tensor quantParams is empty.");
in_tensor->quantParams.front() = std::move(out_quant_param);
}
}
// refresh out_tensor quant_params by in_tensor quant_params
if (output_inited_quant_params_ < 1) {
MS_CHECK_FALSE_MSG(node.inputIndex.empty(), RET_ERROR, "inputIndex is empty.");
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(0), RET_ERROR);
auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
MS_CHECK_TRUE_RET(in_tensor != nullptr, RET_NULL_PTR);
auto in_quant_param = GetTensorQuantParam(in_tensor);
if (in_quant_param == nullptr || !in_quant_param->inited) {
MS_LOG(DEBUG) << node.name << " dont need to pass quant param.";
return RET_NO_CHANGE;
}
for (unsigned int i : node.outputIndex) {
MS_CHECK_TRUE_RET(graph->allTensors.size() > i, RET_ERROR);
auto &out_tensor = graph->allTensors.at(i);
MS_CHECK_TRUE_RET(out_tensor != nullptr, RET_NULL_PTR);
auto out_quant_param = GetTensorQuantParam(out_tensor);
if (out_quant_param == nullptr) {
out_tensor->quantParams.emplace_back(std::move(in_quant_param));
continue;
}
if (out_quant_param->inited) {
continue;
}
MS_CHECK_FALSE_MSG(out_tensor->quantParams.empty(), RET_ERROR, "out_tensor quantParams is empty.");
out_tensor->quantParams.front() = std::move(in_quant_param);
}
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class CarryDataQuantParamPropogator : public QuantParamPropogator {
public:
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H_

View File

@ -1,56 +0,0 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
#include <utility>
#include <memory>
#include "tools/common/tensor_util.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
MS_CHECK_TRUE_RET(node->inputIndex.size() >= kInputIndexOne, false);
MS_CHECK_TRUE_RET(node->outputIndex.size() >= kInputIndexOne, false);
// check first in tensor
MS_CHECK_FALSE_MSG(node->inputIndex.empty(), false, "inputIndex is empty.");
MS_CHECK_TRUE_RET(graph.allTensors.size() > node->inputIndex.at(0), false);
auto &in_tensor = graph.allTensors.at(node->inputIndex.at(0));
MS_CHECK_TRUE_RET(in_tensor != nullptr, false);
if (!in_tensor->quantParams.empty()) {
if (std::any_of(in_tensor->quantParams.begin(), in_tensor->quantParams.end(),
[](const std::unique_ptr<QuantParamT> &quant_param) { return !quant_param->inited; })) {
return false;
}
} else {
return false;
}
// check first out tensor
MS_CHECK_FALSE_MSG(node->outputIndex.empty(), false, "outputIndex is empty.");
MS_CHECK_TRUE_RET(graph.allTensors.size() > node->outputIndex.at(0), false);
auto &out_tensor = graph.allTensors.at(node->outputIndex.at(0));
MS_CHECK_TRUE_RET(out_tensor != nullptr, false);
if (!out_tensor->quantParams.empty()) {
if (std::any_of(out_tensor->quantParams.begin(), out_tensor->quantParams.end(),
[](const std::unique_ptr<QuantParamT> &quant_param) { return !quant_param->inited; })) {
return false;
}
node->quantType = schema::QuantType_QUANT_ALL;
return true;
}
return false;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_TYPE_DETERMINER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_TYPE_DETERMINER_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class CarryDataQuantTypeDeterminer : public QuantTypeDeterminer {
public:
bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_TYPE_DETERMINER_H_

View File

@ -1,89 +0,0 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
#include <cfloat>
#include <memory>
#include <utility>
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
int ConcatQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
const mindspore::schema::CNodeT &node) {
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
UpdateQuantParamsNum(*graph, node);
if (input_inited_quant_params_ != node.inputIndex.size()) {
MS_LOG(DEBUG) << "Can not determine concat inputTensor quantParam, node " << node.name;
return RET_NO_CHANGE;
}
if (output_inited_quant_params_ != 1) {
MS_CHECK_TRUE_RET(output_inited_quant_params_ == 0, RET_ERROR);
float min_min = FLT_MAX;
float max_max = FLT_MIN;
bool narrow_range = false;
int num_bits = -1;
for (size_t index : node.inputIndex) {
MS_ASSERT(graph->allTensors.size() > index);
auto &in_tensor = graph->allTensors.at(index);
MS_ASSERT(in_tensor != nullptr);
auto in_quant_param = GetTensorQuantParam(in_tensor);
if (in_quant_param == nullptr || !in_quant_param->inited) {
return RET_ERROR;
}
if (num_bits == -1) {
narrow_range = in_quant_param->narrowRange;
num_bits = in_quant_param->numBits;
} else {
MS_ASSERT(narrow_range == quantParam->narrowRange);
MS_ASSERT(num_bits == quantParam->numBits);
}
if (in_quant_param->max < in_quant_param->min) {
MS_LOG(DEBUG) << "Input quant param is invalid for propogator";
return RET_ERROR;
}
if (min_min > in_quant_param->min) {
min_min = in_quant_param->min;
}
if (max_max < in_quant_param->max) {
max_max = in_quant_param->max;
}
}
MS_CHECK_FALSE_MSG(node.outputIndex.empty(), RET_ERROR, "Output index is empty.");
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.outputIndex.front(), RET_ERROR);
auto &out_tensor = graph->allTensors.at(node.outputIndex.front());
MS_CHECK_TRUE_RET(out_tensor != nullptr, RET_NULL_PTR);
auto out_quant_param = std::make_unique<QuantParamT>();
MS_CHECK_TRUE_MSG(out_quant_param != nullptr, RET_NULL_PTR, "out_quant_param is nullptr.");
auto status = CalQuantizationParams(out_quant_param.get(), min_min, max_max, num_bits, narrow_range);
if (status != RET_OK) {
MS_LOG(DEBUG) << "in aware quantization run CalQuantizationParams failed!";
return RET_ERROR;
}
out_tensor->quantParams.emplace_back(std::move(out_quant_param));
output_inited_quant_params_++;
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONCAT_QUANT_PARAM_PROPOGATOR_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONCAT_QUANT_PARAM_PROPOGATOR_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class ConcatQuantParamPropogator : public QuantParamPropogator {
public:
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONCAT_QUANT_PARAM_PROPOGATOR_H_

View File

@ -1,71 +0,0 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
#include <vector>
#include <memory>
#include <utility>
#include "mindspore/core/ir/dtype/type_id.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
static constexpr size_t kBiasAdd = 3;
int ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
const mindspore::schema::CNodeT &node) {
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
if (node.inputIndex.size() == kBiasAdd) {
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(kBiasAdd - 1), RET_ERROR);
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAdd - 1));
if (bias_tensor->quantParams.empty() || !bias_tensor->quantParams.front()->inited) {
// check input and weight quant params
auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
auto &weight_tensor = graph->allTensors.at(node.inputIndex.at(1));
MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_NULL_PTR);
MS_CHECK_TRUE_RET(weight_tensor != nullptr, RET_NULL_PTR);
if (input_tensor->quantParams.empty() || !input_tensor->quantParams.front()->inited) {
return RET_OK;
}
if (weight_tensor->quantParams.empty() || !weight_tensor->quantParams.front()->inited) {
return RET_OK;
}
auto &input_quant_param = input_tensor->quantParams.at(0);
std::vector<std::unique_ptr<schema::QuantParamT>> bias_quant_params;
for (auto &weight_quant_param : weight_tensor->quantParams) {
auto bias_quant_param = std::make_unique<schema::QuantParamT>();
MS_CHECK_TRUE_MSG(bias_quant_param != nullptr, RET_NULL_PTR, "bias_quant_param is nullptr.");
bias_quant_param->min = 0.0;
bias_quant_param->max = 0.0;
bias_quant_param->dstDtype = kNumberTypeInt32;
bias_quant_param->inited = input_quant_param->inited && weight_quant_param->inited;
bias_quant_param->zeroPoint = 0;
if (bias_quant_param->inited) {
bias_quant_param->scale = input_quant_param->scale * weight_quant_param->scale;
}
bias_quant_param->roundType = 1;
bias_quant_param->multiplier = 1;
bias_quant_params.emplace_back(std::move(bias_quant_param));
}
bias_tensor->quantParams = std::move(bias_quant_params);
}
for (auto &quantParam : bias_tensor->quantParams) {
quantParam->dstDtype = TypeId::kNumberTypeInt32;
}
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_PARAM_PROPOGATOR_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_PARAM_PROPOGATOR_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class ConvQuantParamPropogator : public QuantParamPropogator {
public:
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_PARAM_PROPOGATOR_H_

View File

@ -1,42 +0,0 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/log_adapter.h"
namespace mindspore::lite {
bool ConvQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
mindspore::schema::CNodeT *node) {
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
MS_CHECK_TRUE_RET(node->inputIndex.size() >= kInputIndexTwo, false);
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kInputIndex), false, "Out of vector range.");
auto &input_tensor = graph.allTensors.at(node->inputIndex.at(kInputIndex));
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kWeightIndex), false, "Out of vector range.");
auto &weight_tensor = graph.allTensors.at(node->inputIndex.at(kWeightIndex));
MS_CHECK_TRUE_RET(node->outputIndex.size() > kOutputIndex, false);
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->outputIndex.at(kOutputIndex), false, "Out of vector range.");
auto &output_tensor = graph.allTensors.at(node->outputIndex.at(kOutputIndex));
MS_CHECK_TRUE_RET(input_tensor != nullptr, false);
MS_CHECK_TRUE_RET(output_tensor != nullptr, false);
MS_CHECK_TRUE_RET(weight_tensor != nullptr, false);
if ((!TensorQuantParamsInited(*input_tensor) || !TensorQuantParamsInited(*output_tensor)) &&
TensorQuantParamsInited(*weight_tensor)) {
node->quantType = schema::QuantType_QUANT_WEIGHT;
return true;
}
return false;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_TYPE_DETERMINER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_TYPE_DETERMINER_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class ConvQuantTypeDeterminer : public QuantTypeDeterminer {
public:
bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_TYPE_DETERMINER_H_

View File

@ -1,22 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
namespace mindspore::lite {
bool DefaultQuantAllQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
return true;
}
} // namespace mindspore::lite

View File

@ -1,28 +0,0 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class DefaultQuantAllQuantTypeDeterminer : public QuantTypeDeterminer {
public:
bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H_

View File

@ -1,44 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/matmul_quant_type_determiner.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/log_adapter.h"
#include "mindspore/core/ir/dtype/type_id.h"
namespace mindspore::lite {
bool MatmulQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
mindspore::schema::CNodeT *node) {
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
MS_CHECK_TRUE_RET(node->inputIndex.size() >= kInputIndexTwo, false);
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kInputIndex), false, "Out of vector range.");
auto &input_tensor1 = graph.allTensors.at(node->inputIndex.at(kInputIndex));
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kWeightIndex), false, "Out of vector range.");
auto &input_tensor2 = graph.allTensors.at(node->inputIndex.at(kWeightIndex));
MS_CHECK_TRUE_RET(node->outputIndex.size() > kOutputIndex, false);
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->outputIndex.at(kOutputIndex), false, "Out of vector range.");
auto &output_tensor = graph.allTensors.at(node->outputIndex.at(kOutputIndex));
MS_CHECK_TRUE_RET(input_tensor1 != nullptr, false);
MS_CHECK_TRUE_RET(input_tensor2 != nullptr, false);
MS_CHECK_TRUE_RET(output_tensor != nullptr, false);
if (((!TensorQuantParamsInited(*input_tensor1) && !TensorQuantParamsInited(*input_tensor2)) ||
(!TensorQuantParamsInited(*input_tensor1) && !TensorQuantParamsInited(*input_tensor2))) &&
TensorQuantParamsInited(*output_tensor)) {
node->quantType = schema::QuantType_QUANT_WEIGHT;
return true;
}
return false;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_MATMUL_QUANT_TYPE_DETERMINER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_MATMUL_QUANT_TYPE_DETERMINER_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class MatmulQuantTypeDeterminer : public QuantTypeDeterminer {
public:
bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_MATMUL_QUANT_TYPE_DETERMINER_H_

View File

@ -1,29 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
#include "src/common/log_adapter.h"
namespace mindspore::lite {
bool OnlyNeedInputsQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
MS_ASSERT(node != nullptr);
UpdateQuantParamsNum(graph, *node);
if (input_inited_quant_params_ == node->inputIndex.size()) {
node->quantType = schema::QuantType_QUANT_ALL;
return true;
}
return false;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class OnlyNeedInputsQuantTypeDeterminer : public QuantTypeDeterminer {
public:
bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H_

View File

@ -73,7 +73,7 @@ STATUS QATTransform::DoSingleGraphQATTransform(const FuncGraphPtr &func_graph) {
return ret; return ret;
} }
InsertQuantNodeManager inset_quant_node_pass; InsertQuantNodeManager inset_quant_node_pass;
ret = inset_quant_node_pass.InsertFP32DtypeCastNode(func_graph); ret = inset_quant_node_pass.InsertDequantNode(func_graph);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Add QuantCast failed"; MS_LOG(ERROR) << "Add QuantCast failed";
return RET_ERROR; return RET_ERROR;
@ -149,7 +149,7 @@ int QATTransform::QuantWeight(const FuncGraphPtr &func_graph) {
auto input = cnode->input(i); auto input = cnode->input(i);
ParameterPtr parameter; ParameterPtr parameter;
tensor::TensorPtr tensor_info; tensor::TensorPtr tensor_info;
GetLiteParameter(input, &parameter, &tensor_info); GetParameterAndTensor(input, &parameter, &tensor_info);
if (parameter == nullptr || tensor_info == nullptr || if (parameter == nullptr || tensor_info == nullptr ||
tensor_info->compression_type() != mindspore::kNoCompression || tensor_info->compression_type() != mindspore::kNoCompression ||

View File

@ -1,41 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"
#include "src/common/log_adapter.h"
#include "mindspore/core/ir/dtype/type_id.h"
#include "nnacl/op_base.h"
namespace mindspore::lite {
int QuantDtypeCastQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
MS_CHECK_TRUE_MSG(!node.inputIndex.empty(), RET_ERROR, "inputIndex is empty.");
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(0), RET_ERROR);
auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_NULL_PTR);
if (!input_tensor->quantParams.empty() && input_tensor->quantParams.front()->inited) {
input_tensor->quantParams.front()->dstDtype = input_tensor->dataType;
}
MS_CHECK_TRUE_RET(node.outputIndex.size() > 0, RET_ERROR);
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.outputIndex.at(0), RET_ERROR);
auto &output_tensor = graph->allTensors.at(node.outputIndex.at(0));
if (!output_tensor->quantParams.empty() && output_tensor->quantParams.front()->inited) {
output_tensor->quantParams.front()->dstDtype = output_tensor->dataType;
}
return RET_OK;
}
} // namespace mindspore::lite

View File

@ -1,27 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H_
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
namespace mindspore::lite {
class QuantDtypeCastQuantParamPropogator : public QuantParamPropogator {
public:
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H_

View File

@ -1,182 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
#include <unordered_map>
#include <memory>
#include "src/common/log_adapter.h"
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"
#include "tools/converter/quantizer/quant_helper/matmul_quant_type_determiner.h"
#include "src/litert/kernel_exec.h"
#include "src/litert/kernel_registry.h"
namespace mindspore::lite {
void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) {
// update input quant params num
input_inited_quant_params_ = 0;
for (auto index : node.inputIndex) {
MS_CHECK_TRUE_RET_VOID(graph.allTensors.size() > index);
auto &input_tensor = graph.allTensors.at(index);
MS_CHECK_TRUE_RET_VOID(input_tensor != nullptr);
if (!input_tensor->quantParams.empty()) {
bool is_quant_params_inited =
!std::any_of(input_tensor->quantParams.begin(), input_tensor->quantParams.end(),
[](const std::unique_ptr<schema::QuantParamT> &quant_param) { return !quant_param->inited; });
if (is_quant_params_inited) {
input_inited_quant_params_++;
}
}
}
// update output quant params num
output_inited_quant_params_ = 0;
for (auto index : node.outputIndex) {
MS_CHECK_TRUE_RET_VOID(graph.allTensors.size() > index);
auto &output_tensor = graph.allTensors.at(index);
MS_CHECK_TRUE_RET_VOID(output_tensor != nullptr);
if (!output_tensor->quantParams.empty()) {
bool is_quant_params_inited =
!std::any_of(output_tensor->quantParams.begin(), output_tensor->quantParams.end(),
[](const std::unique_ptr<schema::QuantParamT> &quant_param) { return !quant_param->inited; });
if (is_quant_params_inited) {
output_inited_quant_params_++;
}
}
}
}
bool QuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
MS_CHECK_TRUE_RET(node != nullptr, false);
kernel::KernelKey desc{kernel::kCPU, kNumberTypeInt8, NHWC, node->primitive->value.type, ""};
if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
return false;
}
if (node->quantType != schema::QuantType_QUANT_NONE) {
return node->quantType == schema::QuantType_QUANT_ALL;
}
UpdateQuantParamsNum(graph, *node);
if (input_inited_quant_params_ == node->inputIndex.size() &&
output_inited_quant_params_ == node->outputIndex.size()) {
node->quantType = schema::QuantType_QUANT_ALL;
return true;
}
return false;
}
bool QuantTypeDeterminer::DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) {
return node->quantType == schema::QuantType_QUANT_WEIGHT;
}
int QuantNodeHelper::NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node) {
MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
MS_CHECK_TRUE_RET(node != nullptr, RET_NULL_PTR);
if (quant_type_determiner_->DetermineQuantWeight(*graph, node)) {
return RET_OK;
}
auto ret = quant_param_propogator_->PropogateQuantParams(graph, *node);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << node->name << " propagate Quant Params failed.";
return ret;
}
auto bool_ret = quant_type_determiner_->DetermineQuantAll(*graph, node);
if (!bool_ret) {
MS_LOG(DEBUG) << node->name << " dont need quant.";
return RET_OK;
}
return RET_OK;
}
QuantHelperRegister *QuantHelperRegister::GetInstance() {
static QuantHelperRegister instance;
return &instance;
}
QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_type) {
auto it = register_map_.find(op_type);
if (it != register_map_.end()) {
return it->second;
}
return register_map_[schema::PrimitiveType_NONE];
}
QuantHelperRegister::QuantHelperRegister() {
auto base_propogator = std::make_shared<QuantParamPropogator>();
auto base_determiner = std::make_shared<QuantTypeDeterminer>();
auto quant_dtype_cast_propogator = std::make_shared<QuantDtypeCastQuantParamPropogator>();
auto bias_add_propogator = std::make_shared<BiasAddQuantParamPropogator>();
auto carry_data_propogator = std::make_shared<CarryDataQuantParamPropogator>();
auto carry_data_determiner = std::make_shared<CarryDataQuantTypeDeterminer>();
auto concat_propogator = std::make_shared<ConcatQuantParamPropogator>();
auto conv_propogator = std::make_shared<ConvQuantParamPropogator>();
auto conv_determiner = std::make_shared<ConvQuantTypeDeterminer>();
auto default_quant_all_determiner = std::make_shared<DefaultQuantAllQuantTypeDeterminer>();
auto only_need_inputs_determiner = std::make_shared<OnlyNeedInputsQuantTypeDeterminer>();
auto matmul_determiner = std::make_shared<MatmulQuantTypeDeterminer>();
register_map_[schema::PrimitiveType_BiasAdd] =
new (std::nothrow) QuantNodeHelper(bias_add_propogator, base_determiner);
register_map_[schema::PrimitiveType_MaxPoolFusion] =
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_Resize] =
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_Reshape] =
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_StridedSlice] =
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_Transpose] =
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_PadFusion] =
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_ReduceFusion] =
new (std::nothrow) QuantNodeHelper(base_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_Gather] =
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
register_map_[schema::PrimitiveType_Concat] = new (std::nothrow) QuantNodeHelper(concat_propogator, base_determiner);
register_map_[schema::PrimitiveType_Conv2DFusion] =
new (std::nothrow) QuantNodeHelper(conv_propogator, conv_determiner);
register_map_[schema::PrimitiveType_MatMulFusion] =
new (std::nothrow) QuantNodeHelper(conv_propogator, matmul_determiner);
register_map_[schema::PrimitiveType_FullConnection] =
new (std::nothrow) QuantNodeHelper(conv_propogator, matmul_determiner);
register_map_[schema::PrimitiveType_QuantDTypeCast] =
new (std::nothrow) QuantNodeHelper(quant_dtype_cast_propogator, default_quant_all_determiner);
register_map_[schema::PrimitiveType_DetectionPostProcess] =
new (std::nothrow) QuantNodeHelper(base_propogator, only_need_inputs_determiner);
register_map_[schema::PrimitiveType_NONE] = new (std::nothrow) QuantNodeHelper(base_propogator, base_determiner);
}
QuantHelperRegister::~QuantHelperRegister() {
for (const auto &iter : register_map_) {
delete iter.second;
}
this->register_map_.clear();
}
} // namespace mindspore::lite

View File

@ -1,74 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_NODE_HELPER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_NODE_HELPER_H_
#include <unordered_map>
#include <memory>
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore::lite {
constexpr int kInputIndexOne = 1;
constexpr int kInputIndexTwo = 2;
class QuantNodeBase {
public:
void UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node);
protected:
size_t input_inited_quant_params_ = 0;
size_t output_inited_quant_params_ = 0;
};
class QuantParamPropogator : public QuantNodeBase {
public:
virtual int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) { return RET_OK; }
};
class QuantTypeDeterminer : public QuantNodeBase {
public:
virtual bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node);
virtual bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node);
};
class QuantNodeHelper {
public:
int NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node);
QuantNodeHelper(std::shared_ptr<QuantParamPropogator> quant_param_propogator,
std::shared_ptr<QuantTypeDeterminer> quant_type_determiner) {
quant_param_propogator_ = quant_param_propogator;
quant_type_determiner_ = quant_type_determiner;
}
virtual ~QuantNodeHelper() = default;
protected:
std::shared_ptr<QuantParamPropogator> quant_param_propogator_;
std::shared_ptr<QuantTypeDeterminer> quant_type_determiner_;
};
class QuantHelperRegister {
public:
virtual ~QuantHelperRegister();
QuantNodeHelper *GetQuantHelper(schema::PrimitiveType op_type);
static QuantHelperRegister *GetInstance();
private:
QuantHelperRegister();
std::unordered_map<schema::PrimitiveType, QuantNodeHelper *> register_map_;
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_NODE_HELPER_H_

View File

@ -36,7 +36,7 @@ int QuantNodePass::DoWeightQuant(const CNodePtr &cnode) {
auto input = cnode->input(idx); auto input = cnode->input(idx);
ParameterPtr parameter; ParameterPtr parameter;
tensor::TensorPtr weight; tensor::TensorPtr weight;
GetLiteParameter(input, &parameter, &weight); GetParameterAndTensor(input, &parameter, &weight);
if (parameter == nullptr || weight == nullptr || weight->data_type() != TypeId::kNumberTypeFloat32) { if (parameter == nullptr || weight == nullptr || weight->data_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight"; MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
continue; continue;

View File

@ -194,7 +194,7 @@ int ConvertFp16ToFp32(const FuncGraphPtr &old_graph) {
} }
ParameterPtr param_node; ParameterPtr param_node;
tensor::TensorPtr tensor_info; tensor::TensorPtr tensor_info;
GetLiteParameter(input, &param_node, &tensor_info); GetParameterAndTensor(input, &param_node, &tensor_info);
CHECK_NULL_RETURN(tensor_info); CHECK_NULL_RETURN(tensor_info);
CHECK_NULL_RETURN(param_node); CHECK_NULL_RETURN(param_node);
if (tensor_info->data_type() == kNumberTypeFloat16) { if (tensor_info->data_type() == kNumberTypeFloat16) {
@ -272,13 +272,13 @@ int PrepareQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<Convert
return RET_OK; return RET_OK;
} }
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) { int DoSingleGraphQuantize(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param) {
CHECK_NULL_RETURN(param); CHECK_NULL_RETURN(param);
if (param->commonQuantParam.quant_type == quant::QUANT_NONE) { if (param->commonQuantParam.quant_type == quant::QUANT_NONE) {
return RET_OK; return RET_OK;
} }
int status = PrepareQuantize(old_graph, param); int status = PrepareQuantize(func_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "PrepareQuantize failed."; MS_LOG(ERROR) << "PrepareQuantize failed.";
return status; return status;
@ -292,54 +292,57 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<C
origin = std::make_shared<mindspore::Model>(); origin = std::make_shared<mindspore::Model>();
CHECK_NULL_RETURN(origin); CHECK_NULL_RETURN(origin);
size_t size = 0; size_t size = 0;
auto ret = BuildModelByFuncGraph(origin, old_graph, param, &size); auto ret = BuildModelByFuncGraph(origin, func_graph, param, &size);
param->commonQuantParam.quant_type = quant_type; param->commonQuantParam.quant_type = quant_type;
if (ret != kSuccess) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed"; MS_LOG(ERROR) << "Build model failed";
return RET_ERROR; return RET_ERROR;
} }
origin_lite_model = ParseLiteModel(old_graph, param); origin_lite_model = ParseLiteModel(func_graph, param);
if (origin_lite_model == nullptr) { if (origin_lite_model == nullptr) {
MS_LOG(ERROR) << "Parse lite model failed."; MS_LOG(ERROR) << "Parse lite model failed.";
return RET_ERROR; return RET_ERROR;
} }
} }
if (param->commonQuantParam.quant_type == quant::QUANT_ALL) { // Full Quantization if (param->commonQuantParam.quant_type == quant::QUANT_ALL) { // Full Quantization
status = DoFullQuant(old_graph, param); status = ConvertFp16ToFp32(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Converter fp16 to fp32 failed.";
return status;
}
status = DoFullQuant(func_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Do full quant failed."; MS_LOG(ERROR) << "Do full quant failed.";
return status; return status;
} }
} else if (param->commonQuantParam.quant_type == quant::QUANT_WEIGHT) { // Weight Quantization } else if (param->commonQuantParam.quant_type == quant::QUANT_WEIGHT) { // Weight Quantization
status = DoWeightQuant(old_graph, param); status = DoWeightQuant(func_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Do weight quant failed."; MS_LOG(ERROR) << "Do weight quant failed.";
return status; return status;
} }
} else if (param->commonQuantParam.quant_type == quant::QUANT_DYNAMIC) { // Dynamic Quantization } else if (param->commonQuantParam.quant_type == quant::QUANT_DYNAMIC) { // Dynamic Quantization
status = DoDynamicQuant(old_graph, param); status = DoDynamicQuant(func_graph, param);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Do dynamic quant failed."; MS_LOG(ERROR) << "Do dynamic quant failed.";
return status; return status;
} }
} }
{ auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto optimizer = std::make_shared<opt::GraphOptimizer>(); CHECK_NULL_RETURN(optimizer);
CHECK_NULL_RETURN(optimizer); auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false);
auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false); CHECK_NULL_RETURN(fusion_pm);
CHECK_NULL_RETURN(fusion_pm); fusion_pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
fusion_pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>()); fusion_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
fusion_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model)); optimizer->AddPassManager(fusion_pm);
optimizer->AddPassManager(fusion_pm); if (optimizer->Optimize(func_graph) == nullptr) {
if (optimizer->Optimize(old_graph) == nullptr) { MS_LOG(ERROR) << "run cast node fusion failed.";
MS_LOG(ERROR) << "run cast node fusion failed."; return RET_ERROR;
return RET_ERROR;
}
} }
if (param->commonQuantParam.is_debug) { if (param->commonQuantParam.is_debug) {
status = DoQuantDebug(old_graph, param, origin, origin_lite_model); status = DoQuantDebug(func_graph, param, origin, origin_lite_model);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Do quant debug failed."; MS_LOG(ERROR) << "Do quant debug failed.";
return status; return status;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2022 Huawei Technologies Co., Ltd * Copyright 2020-2023 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -45,15 +45,8 @@ using std::vector;
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
namespace { namespace {
constexpr int kLstmInputWeightIndex = 1;
constexpr int kLstmStateWeightIndex = 2;
constexpr int kLstmWeightShapeSize = 3;
constexpr int kSingleDirBiasTensorSize = 4;
constexpr int kLstmBiasShapeSize = 2;
constexpr int kLstmBiasIndex = 3;
constexpr size_t kGatherAxisIndex = 3; constexpr size_t kGatherAxisIndex = 3;
constexpr size_t kAnfPrimitiveIndex = 0; constexpr int kDefaultThreadNum = 4;
constexpr int kDefaultThreadNumFour = 4;
} // namespace } // namespace
int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) { int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
@ -92,10 +85,10 @@ void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_f
} }
} }
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) { int UpdateDataType(const AnfNodePtr &node, TypeId new_data_type) {
auto abstract_base = cnode->abstract(); auto abstract_base = node->abstract();
if (abstract_base == nullptr) { if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope(); MS_LOG(ERROR) << "Abstract of node is nullptr, " << node->fullname_with_scope();
return RET_NULL_PTR; return RET_NULL_PTR;
} }
@ -260,7 +253,7 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
delete meta_graph; delete meta_graph;
return kLiteNullptr; return kLiteNullptr;
} }
context->SetThreadNum(kDefaultThreadNumFour); context->SetThreadNum(kDefaultThreadNum);
context->SetThreadAffinity(kCpuBindMode); context->SetThreadAffinity(kCpuBindMode);
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>(); std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
@ -297,7 +290,7 @@ std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<m
return dst_tensors; return dst_tensors;
} }
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) { void GetParameterAndTensor(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
if (node == nullptr) { if (node == nullptr) {
MS_LOG(ERROR) << "node is nullptr"; MS_LOG(ERROR) << "node is nullptr";
return; return;
@ -330,7 +323,7 @@ int UpdateTensorDataAndSize(const AnfNodePtr &node, const tensor::TensorPtr &wei
MS_LOG(ERROR) << "Data size of tensor info is error."; MS_LOG(ERROR) << "Data size of tensor info is error.";
return RET_ERROR; return RET_ERROR;
} }
if (memcpy_s(weight->data_c(), new_size, quant_datas, new_size) != EOK) { if (memcpy_s(weight->data_c(), weight->data().nbytes(), quant_datas, new_size) != EOK) {
MS_LOG(ERROR) << "memcpy data failed."; MS_LOG(ERROR) << "memcpy data failed.";
return RET_ERROR; return RET_ERROR;
} }
@ -379,7 +372,7 @@ int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int>
} }
int GetGatherPreferredDim(const CNodePtr &cnode) { int GetGatherPreferredDim(const CNodePtr &cnode) {
if (cnode->size() < kGatherAxisIndex + 1) { if (cnode->size() < kGatherAxisIndex + kPrimOffset) {
MS_LOG(WARNING) << "gather cnode size < 4."; MS_LOG(WARNING) << "gather cnode size < 4.";
return 0; return 0;
} }
@ -498,10 +491,10 @@ bool CheckControlFlowType(const AnfNodePtr &node) {
if (node->isa<mindspore::CNode>()) { if (node->isa<mindspore::CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
// control flow call // control flow call
if (!IsValueNode<mindspore::Primitive>(cnode->input(kAnfPrimitiveIndex))) { if (!IsValueNode<mindspore::Primitive>(cnode->input(kPrimIndex))) {
return true; return true;
} }
auto prim = GetValuePtr<mindspore::Primitive>(cnode->input(kAnfPrimitiveIndex)); auto prim = GetValuePtr<mindspore::Primitive>(cnode->input(kPrimIndex));
if (control_flow_ops.find(prim->name()) != control_flow_ops.end()) { if (control_flow_ops.find(prim->name()) != control_flow_ops.end()) {
return true; return true;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2022 Huawei Technologies Co., Ltd * Copyright 2020-2023 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -59,7 +59,7 @@ int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type);
void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs); void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type); int UpdateDataType(const AnfNodePtr &node, TypeId new_data_type);
bool IsGraphInDTypeCast(const CNodePtr &cnode); bool IsGraphInDTypeCast(const CNodePtr &cnode);
@ -92,7 +92,7 @@ mindspore::lite::Tensor *MSTensorToLiteTensor(const mindspore::MSTensor &tensor)
std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<mindspore::MSTensor> &src_tensors); std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<mindspore::MSTensor> &src_tensors);
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info); void GetParameterAndTensor(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types); bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2022 Huawei Technologies Co., Ltd * Copyright 2020-2023 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -152,7 +152,7 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
auto input = cnode->input(idx); auto input = cnode->input(idx);
ParameterPtr parameter; ParameterPtr parameter;
tensor::TensorPtr tensor_info; tensor::TensorPtr tensor_info;
GetLiteParameter(input, &parameter, &tensor_info); GetParameterAndTensor(input, &parameter, &tensor_info);
if (parameter == nullptr || tensor_info == nullptr || if (parameter == nullptr || tensor_info == nullptr ||
tensor_info->compression_type() != mindspore::kNoCompression) { tensor_info->compression_type() != mindspore::kNoCompression) {
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight"; MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";
@ -263,7 +263,7 @@ int WeightQuantizer::DoCompression(const CNodePtr &cnode, const ParameterPtr &pa
int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &parameter, int idx, int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &parameter, int idx,
const tensor::TensorPtr &tensor_info, int preferred_dim, const tensor::TensorPtr &tensor_info, int preferred_dim,
WeightQuantType weight_quant_type, bool symmetric, bool update_tensor) { WeightQuantType weight_quant_type, bool symmetric) {
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
CHECK_NULL_RETURN(primitive); CHECK_NULL_RETURN(primitive);
auto mixed_bit_quantization = MixedBitWeightQuantization(mixed_bit_init_scale_); auto mixed_bit_quantization = MixedBitWeightQuantization(mixed_bit_init_scale_);
@ -293,14 +293,8 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
<< parameter->fullname_with_scope() << parameter->fullname_with_scope()
<< " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization."; << " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization.";
FixedBitWeightQuantization fixed_bit_quant; FixedBitWeightQuantization fixed_bit_quant;
if (update_tensor) { status = fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
status = weight_quant_type, kNumberTypeInt8, idx - 1, preferred_dim, symmetric);
fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
weight_quant_type, kNumberTypeInt8, idx - 1, preferred_dim, symmetric);
} else {
status = fixed_bit_quant.StatisticsFilter(tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
weight_quant_type, kNumberTypeInt8, idx - 1, preferred_dim, symmetric);
}
} }
return status; return status;
} }
@ -332,9 +326,11 @@ int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNo
MS_LOG(INFO) << tensor_name << " insert WeightQuant node"; MS_LOG(INFO) << tensor_name << " insert WeightQuant node";
auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c())); auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()));
if (type_id == kNumberTypeFloat32) { if (type_id == kNumberTypeFloat32) {
status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32, axis); status =
quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32, axis);
} else { } else {
status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16, axis); status =
quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16, axis);
} }
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << tensor_name << " insert weight quant node failed."; MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
@ -344,8 +340,8 @@ int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNo
return RET_OK; return RET_OK;
} }
int WeightQuantizer::MarkCnodeWeightQuantType(const CNodePtr &cnode) { int WeightQuantizer::MarkCNodeWeightQuantType(const CNodePtr &cnode) {
MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR); CHECK_NULL_RETURN(cnode);
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
if (primitive == nullptr) { if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr"; MS_LOG(ERROR) << "primitive is nullptr";
@ -353,7 +349,7 @@ int WeightQuantizer::MarkCnodeWeightQuantType(const CNodePtr &cnode) {
} }
auto quant_param_holder = GetCNodeQuantHolder(primitive); auto quant_param_holder = GetCNodeQuantHolder(primitive);
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr."); CHECK_NULL_RETURN(quant_param_holder);
if (quant_param_holder->quant_type() == quant::QUANT_WEIGHT) { if (quant_param_holder->quant_type() == quant::QUANT_WEIGHT) {
// already marked with QuantType_QUANT_WEIGHT // already marked with QuantType_QUANT_WEIGHT
return RET_OK; return RET_OK;
@ -361,11 +357,11 @@ int WeightQuantizer::MarkCnodeWeightQuantType(const CNodePtr &cnode) {
// Support Share Weight Quant. // Support Share Weight Quant.
for (size_t i = kPrimOffset; i < cnode->size(); i++) { for (size_t i = kPrimOffset; i < cnode->size(); i++) {
auto inputNode = cnode->input(i); auto input_node = cnode->input(i);
if (inputNode->isa<Parameter>()) { if (input_node->isa<Parameter>()) {
ParameterPtr param_node; ParameterPtr param_node;
tensor::TensorPtr tensor_info; tensor::TensorPtr tensor_info;
GetLiteParameter(inputNode, &param_node, &tensor_info); GetParameterAndTensor(input_node, &param_node, &tensor_info);
auto param = weight_quantized_tensors_.find(tensor_info); auto param = weight_quantized_tensors_.find(tensor_info);
if (param != weight_quantized_tensors_.end()) { if (param != weight_quantized_tensors_.end()) {
quant_param_holder->set_quant_type(quant::QUANT_WEIGHT); quant_param_holder->set_quant_type(quant::QUANT_WEIGHT);
@ -381,12 +377,12 @@ int WeightQuantizer::MarkGraphWeightQuantType(const FuncGraphPtr &func_graph) {
for (auto &cnode : func_graph->GetOrderedCnodes()) { for (auto &cnode : func_graph->GetOrderedCnodes()) {
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0)); auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
if (primitive == nullptr) { if (primitive == nullptr) {
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr"; MS_LOG(DEBUG) << cnode->fullname_with_scope() << " primitive is nullptr";
continue; continue;
} }
auto status = MarkCnodeWeightQuantType(cnode); auto status = MarkCNodeWeightQuantType(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "MarkGraphWeightQuantType error marking " << cnode->fullname_with_scope(); MS_LOG(ERROR) << cnode->fullname_with_scope() << " mark graph QuantType failed.";
return RET_ERROR; return RET_ERROR;
} }
} }
@ -394,7 +390,7 @@ int WeightQuantizer::MarkGraphWeightQuantType(const FuncGraphPtr &func_graph) {
} }
int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR); CHECK_NULL_RETURN(func_graph);
weight_quantized_tensors_.clear(); weight_quantized_tensors_.clear();
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMulFusion, prim::kPrimFullConnection, prim::kPrimMatMulFusion, prim::kPrimFullConnection,

View File

@ -115,11 +115,10 @@ class WeightQuantizer : public Quantizer {
const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices, const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
bool compression = true); bool compression = true);
int MarkGraphWeightQuantType(const FuncGraphPtr &func_graph); int MarkGraphWeightQuantType(const FuncGraphPtr &func_graph);
int MarkCnodeWeightQuantType(const CNodePtr &cnode); int MarkCNodeWeightQuantType(const CNodePtr &cnode);
int DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter, int idx); int DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter, int idx);
int DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &parameter, int idx, const tensor::TensorPtr &tensor_info, int DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &parameter, int idx, const tensor::TensorPtr &tensor_info,
int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true, int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true);
bool update_tensor = true);
int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr &parameter, int idx, int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr &parameter, int idx,
const tensor::TensorPtr &tensor_info); const tensor::TensorPtr &tensor_info);