forked from mindspore-Ecosystem/mindspore
delete unused code
This commit is contained in:
parent
ec992b0641
commit
c1e641d4f6
|
@ -140,7 +140,7 @@ int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vec
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int DoPerChannelQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
|
int DoPerChannelQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
|
||||||
const int &quant_max, const int &quant_min, const size_t &bit_num, std::vector<T> *quant_datas,
|
int quant_max, int quant_min, size_t bit_num, std::vector<T> *quant_datas,
|
||||||
const std::vector<int> &dims, int preferred_dim, bool cal_gain = true, bool symmetric = false,
|
const std::vector<int> &dims, int preferred_dim, bool cal_gain = true, bool symmetric = false,
|
||||||
bool narrow_range = false) {
|
bool narrow_range = false) {
|
||||||
if (raw_datas == nullptr || quant_params == nullptr || quant_datas == nullptr) {
|
if (raw_datas == nullptr || quant_params == nullptr || quant_datas == nullptr) {
|
||||||
|
|
|
@ -26,7 +26,6 @@
|
||||||
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
|
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
|
||||||
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
|
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
|
||||||
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
|
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
|
||||||
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
|
|
||||||
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
|
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
|
||||||
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
|
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
|
||||||
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
||||||
|
|
|
@ -1,72 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#define USE_DEPRECATED_API
|
|
||||||
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include "src/common/utils.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
#include "tools/common/node_util.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
STATUS InferQuantParamPass::Run(schema::MetaGraphT *graph) {
|
|
||||||
if (graph == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "graph is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward infer nodes' quant params
|
|
||||||
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
|
|
||||||
auto &node = *iter;
|
|
||||||
if (node == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "node is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto quant_helper = QuantHelperRegister::GetInstance()->GetQuantHelper(node->primitive->value.type);
|
|
||||||
MS_CHECK_TRUE_MSG(quant_helper != nullptr, RET_ERROR, "Find QuantHelper return nullptr");
|
|
||||||
auto ret = quant_helper->NodeQuantPreprocess(graph, node.get());
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Node quant preprocess failed.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// backward infer nodes' quant params
|
|
||||||
for (auto iter = graph->nodes.rbegin(); iter != graph->nodes.rend(); iter++) {
|
|
||||||
auto &node = *iter;
|
|
||||||
if (node == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "node is null";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!node->primitive) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto quant_helper = QuantHelperRegister::GetInstance()->GetQuantHelper(node->primitive->value.type);
|
|
||||||
MS_CHECK_TRUE_MSG(quant_helper != nullptr, RET_ERROR, "Find QuantHelper return nullptr");
|
|
||||||
auto ret = quant_helper->NodeQuantPreprocess(graph, node.get());
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "Node quant preprocess failed.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,35 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_INFER_QUANT_PARAM_PASS_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_INFER_QUANT_PARAM_PASS_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/optimizer.h"
|
|
||||||
#include "tools/common/graph_util.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace lite {
|
|
||||||
class InferQuantParamPass : public GraphPass {
|
|
||||||
public:
|
|
||||||
InferQuantParamPass() = default;
|
|
||||||
|
|
||||||
~InferQuantParamPass() override = default;
|
|
||||||
|
|
||||||
STATUS Run(schema::MetaGraphT *graph) override;
|
|
||||||
};
|
|
||||||
} // namespace lite
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_INFER_QUANT_PARAM_PASS_H_
|
|
|
@ -548,7 +548,7 @@ int BiasCorrectionStrategy::DoCNodeBiasCorrection(const FuncGraphPtr &quant_func
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (cnode->size() == kHasBiasTensorSize - 1) {
|
} else if (cnode->size() == kHasBiasTensorSize - kPrimOffset) {
|
||||||
MS_LOG(INFO) << op_name << " add bias input";
|
MS_LOG(INFO) << op_name << " add bias input";
|
||||||
// need to add bias input
|
// need to add bias input
|
||||||
auto parameter = quant_func_graph->add_parameter();
|
auto parameter = quant_func_graph->add_parameter();
|
||||||
|
|
|
@ -176,7 +176,7 @@ int ClusterQuantization::KMeansQuantization(const CNodePtr &cnode, const std::ve
|
||||||
auto input = cnode->input(idx);
|
auto input = cnode->input(idx);
|
||||||
ParameterPtr parameter;
|
ParameterPtr parameter;
|
||||||
tensor::TensorPtr tensor_info;
|
tensor::TensorPtr tensor_info;
|
||||||
GetLiteParameter(input, ¶meter, &tensor_info);
|
GetParameterAndTensor(input, ¶meter, &tensor_info);
|
||||||
if (parameter == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32 ||
|
if (parameter == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32 ||
|
||||||
tensor_info->compression_type() != mindspore::kNoCompression) {
|
tensor_info->compression_type() != mindspore::kNoCompression) {
|
||||||
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";
|
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";
|
||||||
|
|
|
@ -128,7 +128,7 @@ int FixedBitWeightQuantization::QuantBias(const ParameterPtr &bias, const Primit
|
||||||
auto ret =
|
auto ret =
|
||||||
UpdateTensorDataAndSize(bias, bias_param, quant_datas.data(), shape_size * sizeof(int32_t), kNumberTypeInt32);
|
UpdateTensorDataAndSize(bias, bias_param, quant_datas.data(), shape_size * sizeof(int32_t), kNumberTypeInt32);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << bias->fullname_with_scope() << " update tensor data adn size failed.";
|
MS_LOG(ERROR) << bias->fullname_with_scope() << " update tensor data and size failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
|
|
@ -96,14 +96,15 @@ class FixedBitWeightQuantization {
|
||||||
weight_quant_type = FIXED_BIT_PER_LAYER;
|
weight_quant_type = FIXED_BIT_PER_LAYER;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (weight->data_type_c() != kNumberTypeFloat32) {
|
||||||
|
MS_LOG(ERROR) << "data type is not Float32.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<schema::QuantParamT> quant_params;
|
std::vector<schema::QuantParamT> quant_params;
|
||||||
int ret = RET_OK;
|
int ret = RET_OK;
|
||||||
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
|
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
|
||||||
bool cal_gain = false;
|
bool cal_gain = (quant_type == QUANT_WEIGHT) ? true : false;
|
||||||
if (quant_type == QUANT_WEIGHT) {
|
|
||||||
cal_gain = true;
|
|
||||||
}
|
|
||||||
ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
|
ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
|
||||||
quant_min, bit_num, quant_data, ConvertShapeVectorToInt32(dims), preferred_dim,
|
quant_min, bit_num, quant_data, ConvertShapeVectorToInt32(dims), preferred_dim,
|
||||||
cal_gain, symmetric, narrow_range);
|
cal_gain, symmetric, narrow_range);
|
||||||
|
|
|
@ -33,7 +33,6 @@
|
||||||
namespace mindspore::lite::quant {
|
namespace mindspore::lite::quant {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t kMinSize3 = 3;
|
constexpr size_t kMinSize3 = 3;
|
||||||
constexpr size_t kPrimitiveCOffset = 1;
|
|
||||||
constexpr size_t kTableExtend = 3;
|
constexpr size_t kTableExtend = 3;
|
||||||
constexpr size_t kAlignOffset = 7;
|
constexpr size_t kAlignOffset = 7;
|
||||||
constexpr size_t kInt32Mask = 31;
|
constexpr size_t kInt32Mask = 31;
|
||||||
|
@ -112,16 +111,16 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
|
||||||
MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
|
MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto input = cnode->input(kInputIndex + kPrimitiveCOffset);
|
auto input = cnode->input(kInputIndex + kPrimOffset);
|
||||||
if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
|
if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
|
||||||
auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset);
|
auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimOffset);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
|
MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto weight = cnode->input(kWeightIndex + kPrimitiveCOffset);
|
auto weight = cnode->input(kWeightIndex + kPrimOffset);
|
||||||
if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
|
if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
|
||||||
auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset);
|
auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimOffset);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
|
MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
|
||||||
}
|
}
|
||||||
|
@ -130,12 +129,9 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
|
||||||
}
|
}
|
||||||
|
|
||||||
int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
|
int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
|
||||||
MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR);
|
CHECK_NULL_RETURN(cnode);
|
||||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
if (primitive == nullptr) {
|
CHECK_NULL_RETURN(primitive);
|
||||||
MS_LOG(ERROR) << "primitive is nullptr";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||||
quant_param_holder->set_quant_type(quant::QUANT_DYNAMIC);
|
quant_param_holder->set_quant_type(quant::QUANT_DYNAMIC);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -144,7 +140,7 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
|
||||||
int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
|
int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
|
||||||
const std::set<PrimitivePtr> &support_dynamic_quant_ops,
|
const std::set<PrimitivePtr> &support_dynamic_quant_ops,
|
||||||
const std::set<std::string> &skip_quant_node) {
|
const std::set<std::string> &skip_quant_node) {
|
||||||
MS_ASSERT(graph != nullptr);
|
CHECK_NULL_RETURN(graph);
|
||||||
auto cnodes = graph->GetOrderedCnodes();
|
auto cnodes = graph->GetOrderedCnodes();
|
||||||
for (auto &cnode : cnodes) {
|
for (auto &cnode : cnodes) {
|
||||||
auto op_name = cnode->fullname_with_scope();
|
auto op_name = cnode->fullname_with_scope();
|
||||||
|
@ -184,7 +180,7 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int InsertQuantNodeManager::InsertFP32DtypeCastNode(const FuncGraphPtr &graph) {
|
int InsertQuantNodeManager::InsertDequantNode(const FuncGraphPtr &graph) {
|
||||||
CHECK_NULL_RETURN(graph);
|
CHECK_NULL_RETURN(graph);
|
||||||
auto cnodes = graph->GetOrderedCnodes();
|
auto cnodes = graph->GetOrderedCnodes();
|
||||||
for (auto &cnode : cnodes) {
|
for (auto &cnode : cnodes) {
|
||||||
|
@ -219,6 +215,8 @@ int InsertQuantNodeManager::InserQuantCastNode(const FuncGraphPtr &graph, const
|
||||||
InsertDirection insert_direction, TypeId cast_dtype,
|
InsertDirection insert_direction, TypeId cast_dtype,
|
||||||
CastNodeType cast_node_type, size_t index,
|
CastNodeType cast_node_type, size_t index,
|
||||||
const AnfNodePtr &output_node) {
|
const AnfNodePtr &output_node) {
|
||||||
|
CHECK_NULL_RETURN(graph);
|
||||||
|
CHECK_NULL_RETURN(cnode);
|
||||||
if (insert_direction == FORWARD) {
|
if (insert_direction == FORWARD) {
|
||||||
return InsertForwardQuantCastNode(graph, cnode, cast_dtype, index, cast_node_type);
|
return InsertForwardQuantCastNode(graph, cnode, cast_dtype, index, cast_node_type);
|
||||||
} else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
|
} else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
|
||||||
|
@ -287,7 +285,7 @@ int InsertQuantNodeManager::InsertForwardQuantCastNode(const FuncGraphPtr &graph
|
||||||
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
|
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params);
|
||||||
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
|
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
|
||||||
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
auto quant_cast_cnode = graph->NewCNode(op_inputs);
|
||||||
MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
|
CHECK_NULL_RETURN(quant_cast_cnode);
|
||||||
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
|
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
|
||||||
"_pre");
|
"_pre");
|
||||||
// set abstract
|
// set abstract
|
||||||
|
@ -305,7 +303,7 @@ int InsertQuantNodeManager::InsertForwardQuantCastNode(const FuncGraphPtr &graph
|
||||||
if (manager == nullptr) {
|
if (manager == nullptr) {
|
||||||
manager = Manage(graph, true);
|
manager = Manage(graph, true);
|
||||||
}
|
}
|
||||||
MS_CHECK_TRUE_RET(manager != nullptr, RET_NULL_PTR);
|
CHECK_NULL_RETURN(manager);
|
||||||
manager->SetEdge(cnode, index, quant_cast_cnode);
|
manager->SetEdge(cnode, index, quant_cast_cnode);
|
||||||
MS_LOG(INFO) << "InsertForwardQuantCastNode cnode name: " << cnode->fullname_with_scope()
|
MS_LOG(INFO) << "InsertForwardQuantCastNode cnode name: " << cnode->fullname_with_scope()
|
||||||
<< " src dtype:" << src_dtype << " dst_type: " << dst_dtype;
|
<< " src dtype:" << src_dtype << " dst_type: " << dst_dtype;
|
||||||
|
@ -464,8 +462,9 @@ int InsertQuantNodeManager::InsertBackwardCastNode(const FuncGraphPtr &graph, co
|
||||||
} // node_users
|
} // node_users
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
int InsertQuantNodeManager::InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
|
||||||
size_t input_index, TypeId src_dtype, TypeId dst_dtype, int axis) {
|
size_t input_index, TypeId src_dtype, TypeId dst_dtype,
|
||||||
|
int axis) {
|
||||||
auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
|
auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
|
||||||
if (primitive == nullptr) {
|
if (primitive == nullptr) {
|
||||||
MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
|
MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
|
||||||
|
@ -688,9 +687,8 @@ ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(int src_type, int dst
|
||||||
return NewValueNode(prim);
|
return NewValueNode(prim);
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueNodePtr InsertQuantNodeManager::NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk,
|
ValueNodePtr InsertQuantNodeManager::NewFSEDecodePrimitive(int dst_type, uint64_t curr_chunk, int64_t curr_chunk_index,
|
||||||
const int64_t curr_chunk_index, const int64_t curr_bit_count,
|
int64_t curr_bit_count, int64_t table_log) {
|
||||||
const int64_t table_log) {
|
|
||||||
auto prim_c = std::make_shared<ops::FSEDecode>();
|
auto prim_c = std::make_shared<ops::FSEDecode>();
|
||||||
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
|
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
|
||||||
prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log);
|
prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log);
|
||||||
|
|
|
@ -36,7 +36,7 @@ class InsertQuantNodeManager {
|
||||||
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
|
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
|
||||||
const std::set<std::string> &skip_quant_node);
|
const std::set<std::string> &skip_quant_node);
|
||||||
|
|
||||||
int InsertFP32DtypeCastNode(const FuncGraphPtr &graph);
|
int InsertDequantNode(const FuncGraphPtr &graph);
|
||||||
|
|
||||||
int InsertForwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
|
int InsertForwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
|
||||||
quant::QuantType curr_quant_type);
|
quant::QuantType curr_quant_type);
|
||||||
|
@ -47,8 +47,9 @@ class InsertQuantNodeManager {
|
||||||
int InsertBackwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
|
int InsertBackwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
|
||||||
quant::QuantType curr_quant_type);
|
quant::QuantType curr_quant_type);
|
||||||
|
|
||||||
int InsertWeightQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId src_dtype,
|
int InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
|
||||||
TypeId dst_dtype, int axis);
|
TypeId src_dtype, TypeId dst_dtype, int axis);
|
||||||
|
|
||||||
int InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId dst_dtype);
|
int InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId dst_dtype);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -67,8 +68,10 @@ class InsertQuantNodeManager {
|
||||||
|
|
||||||
int InsertBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index,
|
int InsertBackwardDeQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype, size_t index,
|
||||||
const AnfNodePtr &output_node);
|
const AnfNodePtr &output_node);
|
||||||
|
|
||||||
int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction,
|
int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction,
|
||||||
TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node);
|
TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node);
|
||||||
|
|
||||||
int CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, std::vector<AnfNodePtr> *op_inputs,
|
int CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, std::vector<AnfNodePtr> *op_inputs,
|
||||||
TypeId dst_dtype);
|
TypeId dst_dtype);
|
||||||
|
|
||||||
|
@ -77,8 +80,8 @@ class InsertQuantNodeManager {
|
||||||
const std::vector<schema::QuantParamT> &output_quant_params, int axis = 0,
|
const std::vector<schema::QuantParamT> &output_quant_params, int axis = 0,
|
||||||
bool set_quant_flag = true);
|
bool set_quant_flag = true);
|
||||||
|
|
||||||
ValueNodePtr NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk, const int64_t curr_chunk_index,
|
ValueNodePtr NewFSEDecodePrimitive(int dst_type, uint64_t curr_chunk, int64_t curr_chunk_index,
|
||||||
const int64_t curr_bit_count, const int64_t table_log);
|
int64_t curr_bit_count, int64_t table_log);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TypeId dst_type_ = kNumberTypeInt8;
|
TypeId dst_type_ = kNumberTypeInt8;
|
||||||
|
|
|
@ -1,67 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/attention_quant_type_determiner.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
|
|
||||||
#include "tools/converter/quantizer/quantize_util.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "mindspore/core/ir/dtype/type_id.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
const size_t kWeightQueryIndex = 4;
|
|
||||||
const size_t kWeightKeyIndex = 5;
|
|
||||||
const size_t kWeightValueIndex = 6;
|
|
||||||
const size_t kWeightOutputIndex = 10;
|
|
||||||
|
|
||||||
bool AttentionQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
|
|
||||||
mindspore::schema::CNodeT *node) {
|
|
||||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
|
||||||
auto input_index = node->inputIndex;
|
|
||||||
MS_CHECK_FALSE_MSG(input_index.empty(), false, "inputIndex is empty.");
|
|
||||||
MS_CHECK_TRUE_MSG(input_index.size() > kInputIndex, false, "invalid access.");
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kInputIndex), false, "invalid access.");
|
|
||||||
auto &input_tensor = graph.allTensors.at(input_index.at(kInputIndex));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_MSG(input_index.size() > kWeightQueryIndex, false, "invalid access.");
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightQueryIndex), false, "invalid access.");
|
|
||||||
auto &weight_query_tensor = graph.allTensors.at(input_index.at(kWeightQueryIndex));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_MSG(input_index.size() > kWeightKeyIndex, false, "invalid access.");
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightKeyIndex), false, "invalid access.");
|
|
||||||
auto &weight_key_tensor = graph.allTensors.at(input_index.at(kWeightKeyIndex));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_MSG(input_index.size() > kWeightValueIndex, false, "invalid access.");
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightValueIndex), false, "invalid access.");
|
|
||||||
auto &weight_value_tensor = graph.allTensors.at(input_index.at(kWeightValueIndex));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_MSG(input_index.size() > kWeightOutputIndex, false, "invalid access.");
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > input_index.at(kWeightOutputIndex), false, "invalid access.");
|
|
||||||
auto &weight_output_tensor = graph.allTensors.at(input_index.at(kWeightOutputIndex));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_RET(input_tensor != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(weight_query_tensor != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(weight_key_tensor != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(weight_value_tensor != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(weight_output_tensor != nullptr, false);
|
|
||||||
if (!TensorQuantParamsInited(*input_tensor) && TensorQuantParamsInited(*weight_query_tensor) &&
|
|
||||||
TensorQuantParamsInited(*weight_key_tensor) && TensorQuantParamsInited(*weight_value_tensor) &&
|
|
||||||
TensorQuantParamsInited(*weight_output_tensor)) {
|
|
||||||
node->quantType = schema::QuantType_QUANT_WEIGHT;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ATTENTION_QUANT_TYPE_DETERMINER_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ATTENTION_QUANT_TYPE_DETERMINER_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class AttentionQuantTypeDeterminer : public QuantTypeDeterminer {
|
|
||||||
public:
|
|
||||||
bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ATTENTION_QUANT_TYPE_DETERMINER_H_
|
|
|
@ -1,37 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#define USE_DEPRECATED_API
|
|
||||||
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
|
|
||||||
#include "mindspore/core/ir/dtype/type_id.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
static constexpr size_t kBiasAddSize = 2;
|
|
||||||
int BiasAddQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
|
|
||||||
const mindspore::schema::CNodeT &node) {
|
|
||||||
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
|
|
||||||
if (node.inputIndex.size() == kBiasAddSize) {
|
|
||||||
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAddSize - 1));
|
|
||||||
MS_CHECK_TRUE_RET(bias_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
for (auto &quantParam : bias_tensor->quantParams) {
|
|
||||||
quantParam->dstDtype = TypeId::kNumberTypeInt32;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class BiasAddQuantParamPropogator : public QuantParamPropogator {
|
|
||||||
public:
|
|
||||||
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_BIAS_ADD_QUANT_PARAM_PROPOGATOR_H_
|
|
|
@ -1,79 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
|
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/common/tensor_util.h"
|
|
||||||
#include "src//common/log_util.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
int CarryDataQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
|
|
||||||
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "Graph is nullptr.");
|
|
||||||
UpdateQuantParamsNum(*graph, node);
|
|
||||||
|
|
||||||
MS_CHECK_FALSE_MSG(graph->allTensors.empty(), RET_ERROR, "Tensors is empty.");
|
|
||||||
// refresh in_tensor quant_params by out_tensor quant_params
|
|
||||||
if (input_inited_quant_params_ < 1) {
|
|
||||||
MS_CHECK_FALSE_MSG(node.outputIndex.empty(), RET_ERROR, "OutputIndex is empty.");
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.outputIndex.at(0), RET_ERROR);
|
|
||||||
auto &out_tensor = graph->allTensors.at(node.outputIndex.at(0));
|
|
||||||
auto out_quant_param = GetTensorQuantParam(out_tensor);
|
|
||||||
if (out_quant_param == nullptr || !out_quant_param->inited) {
|
|
||||||
MS_LOG(DEBUG) << node.name << " dont need to pass quant param.";
|
|
||||||
return RET_NO_CHANGE;
|
|
||||||
}
|
|
||||||
MS_CHECK_FALSE_MSG(node.inputIndex.empty(), RET_ERROR, "inputIndex is empty.");
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(0), RET_ERROR);
|
|
||||||
auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
|
|
||||||
MS_CHECK_TRUE_RET(in_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
auto in_quant_param = GetTensorQuantParam(in_tensor);
|
|
||||||
if (in_quant_param != nullptr && !in_quant_param->inited) {
|
|
||||||
MS_CHECK_FALSE_MSG(in_tensor->quantParams.empty(), RET_ERROR, "in_tensor quantParams is empty.");
|
|
||||||
in_tensor->quantParams.front() = std::move(out_quant_param);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// refresh out_tensor quant_params by in_tensor quant_params
|
|
||||||
if (output_inited_quant_params_ < 1) {
|
|
||||||
MS_CHECK_FALSE_MSG(node.inputIndex.empty(), RET_ERROR, "inputIndex is empty.");
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(0), RET_ERROR);
|
|
||||||
auto &in_tensor = graph->allTensors.at(node.inputIndex.at(0));
|
|
||||||
MS_CHECK_TRUE_RET(in_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
auto in_quant_param = GetTensorQuantParam(in_tensor);
|
|
||||||
if (in_quant_param == nullptr || !in_quant_param->inited) {
|
|
||||||
MS_LOG(DEBUG) << node.name << " dont need to pass quant param.";
|
|
||||||
return RET_NO_CHANGE;
|
|
||||||
}
|
|
||||||
for (unsigned int i : node.outputIndex) {
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > i, RET_ERROR);
|
|
||||||
auto &out_tensor = graph->allTensors.at(i);
|
|
||||||
MS_CHECK_TRUE_RET(out_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
auto out_quant_param = GetTensorQuantParam(out_tensor);
|
|
||||||
if (out_quant_param == nullptr) {
|
|
||||||
out_tensor->quantParams.emplace_back(std::move(in_quant_param));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (out_quant_param->inited) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
MS_CHECK_FALSE_MSG(out_tensor->quantParams.empty(), RET_ERROR, "out_tensor quantParams is empty.");
|
|
||||||
out_tensor->quantParams.front() = std::move(in_quant_param);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class CarryDataQuantParamPropogator : public QuantParamPropogator {
|
|
||||||
public:
|
|
||||||
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_PARAM_PROPOGATOR_H_
|
|
|
@ -1,56 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
|
|
||||||
#include <utility>
|
|
||||||
#include <memory>
|
|
||||||
#include "tools/common/tensor_util.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
|
||||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
|
||||||
MS_CHECK_TRUE_RET(node->inputIndex.size() >= kInputIndexOne, false);
|
|
||||||
MS_CHECK_TRUE_RET(node->outputIndex.size() >= kInputIndexOne, false);
|
|
||||||
|
|
||||||
// check first in tensor
|
|
||||||
MS_CHECK_FALSE_MSG(node->inputIndex.empty(), false, "inputIndex is empty.");
|
|
||||||
MS_CHECK_TRUE_RET(graph.allTensors.size() > node->inputIndex.at(0), false);
|
|
||||||
auto &in_tensor = graph.allTensors.at(node->inputIndex.at(0));
|
|
||||||
MS_CHECK_TRUE_RET(in_tensor != nullptr, false);
|
|
||||||
if (!in_tensor->quantParams.empty()) {
|
|
||||||
if (std::any_of(in_tensor->quantParams.begin(), in_tensor->quantParams.end(),
|
|
||||||
[](const std::unique_ptr<QuantParamT> &quant_param) { return !quant_param->inited; })) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// check first out tensor
|
|
||||||
MS_CHECK_FALSE_MSG(node->outputIndex.empty(), false, "outputIndex is empty.");
|
|
||||||
MS_CHECK_TRUE_RET(graph.allTensors.size() > node->outputIndex.at(0), false);
|
|
||||||
auto &out_tensor = graph.allTensors.at(node->outputIndex.at(0));
|
|
||||||
MS_CHECK_TRUE_RET(out_tensor != nullptr, false);
|
|
||||||
if (!out_tensor->quantParams.empty()) {
|
|
||||||
if (std::any_of(out_tensor->quantParams.begin(), out_tensor->quantParams.end(),
|
|
||||||
[](const std::unique_ptr<QuantParamT> &quant_param) { return !quant_param->inited; })) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
node->quantType = schema::QuantType_QUANT_ALL;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_TYPE_DETERMINER_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_TYPE_DETERMINER_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class CarryDataQuantTypeDeterminer : public QuantTypeDeterminer {
|
|
||||||
public:
|
|
||||||
bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CARRY_DATA_QUANT_TYPE_DETERMINER_H_
|
|
|
@ -1,89 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
|
|
||||||
#include <cfloat>
|
|
||||||
#include <memory>
|
|
||||||
#include <utility>
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "tools/common/tensor_util.h"
|
|
||||||
#include "tools/converter/quantizer/quantize_util.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
int ConcatQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
|
|
||||||
const mindspore::schema::CNodeT &node) {
|
|
||||||
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
|
|
||||||
UpdateQuantParamsNum(*graph, node);
|
|
||||||
|
|
||||||
if (input_inited_quant_params_ != node.inputIndex.size()) {
|
|
||||||
MS_LOG(DEBUG) << "Can not determine concat inputTensor quantParam, node " << node.name;
|
|
||||||
return RET_NO_CHANGE;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (output_inited_quant_params_ != 1) {
|
|
||||||
MS_CHECK_TRUE_RET(output_inited_quant_params_ == 0, RET_ERROR);
|
|
||||||
float min_min = FLT_MAX;
|
|
||||||
float max_max = FLT_MIN;
|
|
||||||
bool narrow_range = false;
|
|
||||||
int num_bits = -1;
|
|
||||||
for (size_t index : node.inputIndex) {
|
|
||||||
MS_ASSERT(graph->allTensors.size() > index);
|
|
||||||
auto &in_tensor = graph->allTensors.at(index);
|
|
||||||
MS_ASSERT(in_tensor != nullptr);
|
|
||||||
auto in_quant_param = GetTensorQuantParam(in_tensor);
|
|
||||||
if (in_quant_param == nullptr || !in_quant_param->inited) {
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (num_bits == -1) {
|
|
||||||
narrow_range = in_quant_param->narrowRange;
|
|
||||||
num_bits = in_quant_param->numBits;
|
|
||||||
} else {
|
|
||||||
MS_ASSERT(narrow_range == quantParam->narrowRange);
|
|
||||||
MS_ASSERT(num_bits == quantParam->numBits);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (in_quant_param->max < in_quant_param->min) {
|
|
||||||
MS_LOG(DEBUG) << "Input quant param is invalid for propogator";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (min_min > in_quant_param->min) {
|
|
||||||
min_min = in_quant_param->min;
|
|
||||||
}
|
|
||||||
if (max_max < in_quant_param->max) {
|
|
||||||
max_max = in_quant_param->max;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_CHECK_FALSE_MSG(node.outputIndex.empty(), RET_ERROR, "Output index is empty.");
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.outputIndex.front(), RET_ERROR);
|
|
||||||
auto &out_tensor = graph->allTensors.at(node.outputIndex.front());
|
|
||||||
MS_CHECK_TRUE_RET(out_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
auto out_quant_param = std::make_unique<QuantParamT>();
|
|
||||||
MS_CHECK_TRUE_MSG(out_quant_param != nullptr, RET_NULL_PTR, "out_quant_param is nullptr.");
|
|
||||||
|
|
||||||
auto status = CalQuantizationParams(out_quant_param.get(), min_min, max_max, num_bits, narrow_range);
|
|
||||||
if (status != RET_OK) {
|
|
||||||
MS_LOG(DEBUG) << "in aware quantization run CalQuantizationParams failed!";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
out_tensor->quantParams.emplace_back(std::move(out_quant_param));
|
|
||||||
output_inited_quant_params_++;
|
|
||||||
}
|
|
||||||
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONCAT_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONCAT_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class ConcatQuantParamPropogator : public QuantParamPropogator {
|
|
||||||
public:
|
|
||||||
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONCAT_QUANT_PARAM_PROPOGATOR_H_
|
|
|
@ -1,71 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include <utility>
|
|
||||||
#include "mindspore/core/ir/dtype/type_id.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
static constexpr size_t kBiasAdd = 3;
|
|
||||||
|
|
||||||
int ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGraphT *graph,
|
|
||||||
const mindspore::schema::CNodeT &node) {
|
|
||||||
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
|
|
||||||
if (node.inputIndex.size() == kBiasAdd) {
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(kBiasAdd - 1), RET_ERROR);
|
|
||||||
auto &bias_tensor = graph->allTensors.at(node.inputIndex.at(kBiasAdd - 1));
|
|
||||||
if (bias_tensor->quantParams.empty() || !bias_tensor->quantParams.front()->inited) {
|
|
||||||
// check input and weight quant params
|
|
||||||
auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
|
|
||||||
auto &weight_tensor = graph->allTensors.at(node.inputIndex.at(1));
|
|
||||||
MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
MS_CHECK_TRUE_RET(weight_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
if (input_tensor->quantParams.empty() || !input_tensor->quantParams.front()->inited) {
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (weight_tensor->quantParams.empty() || !weight_tensor->quantParams.front()->inited) {
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
auto &input_quant_param = input_tensor->quantParams.at(0);
|
|
||||||
std::vector<std::unique_ptr<schema::QuantParamT>> bias_quant_params;
|
|
||||||
for (auto &weight_quant_param : weight_tensor->quantParams) {
|
|
||||||
auto bias_quant_param = std::make_unique<schema::QuantParamT>();
|
|
||||||
MS_CHECK_TRUE_MSG(bias_quant_param != nullptr, RET_NULL_PTR, "bias_quant_param is nullptr.");
|
|
||||||
bias_quant_param->min = 0.0;
|
|
||||||
bias_quant_param->max = 0.0;
|
|
||||||
bias_quant_param->dstDtype = kNumberTypeInt32;
|
|
||||||
bias_quant_param->inited = input_quant_param->inited && weight_quant_param->inited;
|
|
||||||
bias_quant_param->zeroPoint = 0;
|
|
||||||
if (bias_quant_param->inited) {
|
|
||||||
bias_quant_param->scale = input_quant_param->scale * weight_quant_param->scale;
|
|
||||||
}
|
|
||||||
bias_quant_param->roundType = 1;
|
|
||||||
bias_quant_param->multiplier = 1;
|
|
||||||
bias_quant_params.emplace_back(std::move(bias_quant_param));
|
|
||||||
}
|
|
||||||
bias_tensor->quantParams = std::move(bias_quant_params);
|
|
||||||
}
|
|
||||||
for (auto &quantParam : bias_tensor->quantParams) {
|
|
||||||
quantParam->dstDtype = TypeId::kNumberTypeInt32;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class ConvQuantParamPropogator : public QuantParamPropogator {
|
|
||||||
public:
|
|
||||||
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_PARAM_PROPOGATOR_H_
|
|
|
@ -1,42 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
|
|
||||||
#include "tools/converter/quantizer/quantize_util.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
bool ConvQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
|
|
||||||
mindspore::schema::CNodeT *node) {
|
|
||||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
|
||||||
MS_CHECK_TRUE_RET(node->inputIndex.size() >= kInputIndexTwo, false);
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kInputIndex), false, "Out of vector range.");
|
|
||||||
auto &input_tensor = graph.allTensors.at(node->inputIndex.at(kInputIndex));
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kWeightIndex), false, "Out of vector range.");
|
|
||||||
auto &weight_tensor = graph.allTensors.at(node->inputIndex.at(kWeightIndex));
|
|
||||||
MS_CHECK_TRUE_RET(node->outputIndex.size() > kOutputIndex, false);
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->outputIndex.at(kOutputIndex), false, "Out of vector range.");
|
|
||||||
auto &output_tensor = graph.allTensors.at(node->outputIndex.at(kOutputIndex));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_RET(input_tensor != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(output_tensor != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(weight_tensor != nullptr, false);
|
|
||||||
if ((!TensorQuantParamsInited(*input_tensor) || !TensorQuantParamsInited(*output_tensor)) &&
|
|
||||||
TensorQuantParamsInited(*weight_tensor)) {
|
|
||||||
node->quantType = schema::QuantType_QUANT_WEIGHT;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_TYPE_DETERMINER_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_TYPE_DETERMINER_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class ConvQuantTypeDeterminer : public QuantTypeDeterminer {
|
|
||||||
public:
|
|
||||||
bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_CONV_QUANT_TYPE_DETERMINER_H_
|
|
|
@ -1,22 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
bool DefaultQuantAllQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,28 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class DefaultQuantAllQuantTypeDeterminer : public QuantTypeDeterminer {
|
|
||||||
public:
|
|
||||||
bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_DEFAULT_QUANT_ALL_QUANT_TYPE_DETERMINER_H_
|
|
|
@ -1,44 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/matmul_quant_type_determiner.h"
|
|
||||||
#include "tools/converter/quantizer/quantize_util.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "mindspore/core/ir/dtype/type_id.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
bool MatmulQuantTypeDeterminer::DetermineQuantWeight(const mindspore::schema::MetaGraphT &graph,
|
|
||||||
mindspore::schema::CNodeT *node) {
|
|
||||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
|
||||||
MS_CHECK_TRUE_RET(node->inputIndex.size() >= kInputIndexTwo, false);
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kInputIndex), false, "Out of vector range.");
|
|
||||||
auto &input_tensor1 = graph.allTensors.at(node->inputIndex.at(kInputIndex));
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->inputIndex.at(kWeightIndex), false, "Out of vector range.");
|
|
||||||
auto &input_tensor2 = graph.allTensors.at(node->inputIndex.at(kWeightIndex));
|
|
||||||
MS_CHECK_TRUE_RET(node->outputIndex.size() > kOutputIndex, false);
|
|
||||||
MS_CHECK_TRUE_MSG(graph.allTensors.size() > node->outputIndex.at(kOutputIndex), false, "Out of vector range.");
|
|
||||||
auto &output_tensor = graph.allTensors.at(node->outputIndex.at(kOutputIndex));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_RET(input_tensor1 != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(input_tensor2 != nullptr, false);
|
|
||||||
MS_CHECK_TRUE_RET(output_tensor != nullptr, false);
|
|
||||||
if (((!TensorQuantParamsInited(*input_tensor1) && !TensorQuantParamsInited(*input_tensor2)) ||
|
|
||||||
(!TensorQuantParamsInited(*input_tensor1) && !TensorQuantParamsInited(*input_tensor2))) &&
|
|
||||||
TensorQuantParamsInited(*output_tensor)) {
|
|
||||||
node->quantType = schema::QuantType_QUANT_WEIGHT;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_MATMUL_QUANT_TYPE_DETERMINER_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_MATMUL_QUANT_TYPE_DETERMINER_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class MatmulQuantTypeDeterminer : public QuantTypeDeterminer {
|
|
||||||
public:
|
|
||||||
bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_MATMUL_QUANT_TYPE_DETERMINER_H_
|
|
|
@ -1,29 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
bool OnlyNeedInputsQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
|
||||||
MS_ASSERT(node != nullptr);
|
|
||||||
UpdateQuantParamsNum(graph, *node);
|
|
||||||
if (input_inited_quant_params_ == node->inputIndex.size()) {
|
|
||||||
node->quantType = schema::QuantType_QUANT_ALL;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class OnlyNeedInputsQuantTypeDeterminer : public QuantTypeDeterminer {
|
|
||||||
public:
|
|
||||||
bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_ONLY_NEED_INPUTS_QUANT_TYPE_DETERMINER_H_
|
|
|
@ -73,7 +73,7 @@ STATUS QATTransform::DoSingleGraphQATTransform(const FuncGraphPtr &func_graph) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
InsertQuantNodeManager inset_quant_node_pass;
|
InsertQuantNodeManager inset_quant_node_pass;
|
||||||
ret = inset_quant_node_pass.InsertFP32DtypeCastNode(func_graph);
|
ret = inset_quant_node_pass.InsertDequantNode(func_graph);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Add QuantCast failed";
|
MS_LOG(ERROR) << "Add QuantCast failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -149,7 +149,7 @@ int QATTransform::QuantWeight(const FuncGraphPtr &func_graph) {
|
||||||
auto input = cnode->input(i);
|
auto input = cnode->input(i);
|
||||||
ParameterPtr parameter;
|
ParameterPtr parameter;
|
||||||
tensor::TensorPtr tensor_info;
|
tensor::TensorPtr tensor_info;
|
||||||
GetLiteParameter(input, ¶meter, &tensor_info);
|
GetParameterAndTensor(input, ¶meter, &tensor_info);
|
||||||
|
|
||||||
if (parameter == nullptr || tensor_info == nullptr ||
|
if (parameter == nullptr || tensor_info == nullptr ||
|
||||||
tensor_info->compression_type() != mindspore::kNoCompression ||
|
tensor_info->compression_type() != mindspore::kNoCompression ||
|
||||||
|
|
|
@ -1,41 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "mindspore/core/ir/dtype/type_id.h"
|
|
||||||
#include "nnacl/op_base.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
int QuantDtypeCastQuantParamPropogator::PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) {
|
|
||||||
MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
|
|
||||||
MS_CHECK_TRUE_MSG(!node.inputIndex.empty(), RET_ERROR, "inputIndex is empty.");
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.inputIndex.at(0), RET_ERROR);
|
|
||||||
auto &input_tensor = graph->allTensors.at(node.inputIndex.at(0));
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_RET(input_tensor != nullptr, RET_NULL_PTR);
|
|
||||||
if (!input_tensor->quantParams.empty() && input_tensor->quantParams.front()->inited) {
|
|
||||||
input_tensor->quantParams.front()->dstDtype = input_tensor->dataType;
|
|
||||||
}
|
|
||||||
|
|
||||||
MS_CHECK_TRUE_RET(node.outputIndex.size() > 0, RET_ERROR);
|
|
||||||
MS_CHECK_TRUE_RET(graph->allTensors.size() > node.outputIndex.at(0), RET_ERROR);
|
|
||||||
auto &output_tensor = graph->allTensors.at(node.outputIndex.at(0));
|
|
||||||
if (!output_tensor->quantParams.empty() && output_tensor->quantParams.front()->inited) {
|
|
||||||
output_tensor->quantParams.front()->dstDtype = output_tensor->dataType;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,27 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H_
|
|
||||||
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
namespace mindspore::lite {
|
|
||||||
class QuantDtypeCastQuantParamPropogator : public QuantParamPropogator {
|
|
||||||
public:
|
|
||||||
int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) override;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_DTYPE_CAST_QUANT_PARAM_PROPOGATOR_H_
|
|
|
@ -1,182 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_node_helper.h"
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <memory>
|
|
||||||
#include "src/common/log_adapter.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/bias_add_quant_param_propogator.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_param_propogator.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/carry_data_quant_type_determiner.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/concat_quant_param_propogator.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/conv_quant_param_propogator.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/conv_quant_type_determiner.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/default_quant_all_quant_type_determiner.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/only_need_inputs_quant_type_determiner.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/quant_dtype_cast_quant_param_propogator.h"
|
|
||||||
#include "tools/converter/quantizer/quant_helper/matmul_quant_type_determiner.h"
|
|
||||||
#include "src/litert/kernel_exec.h"
|
|
||||||
#include "src/litert/kernel_registry.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
void QuantNodeBase::UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node) {
|
|
||||||
// update input quant params num
|
|
||||||
input_inited_quant_params_ = 0;
|
|
||||||
for (auto index : node.inputIndex) {
|
|
||||||
MS_CHECK_TRUE_RET_VOID(graph.allTensors.size() > index);
|
|
||||||
auto &input_tensor = graph.allTensors.at(index);
|
|
||||||
MS_CHECK_TRUE_RET_VOID(input_tensor != nullptr);
|
|
||||||
if (!input_tensor->quantParams.empty()) {
|
|
||||||
bool is_quant_params_inited =
|
|
||||||
!std::any_of(input_tensor->quantParams.begin(), input_tensor->quantParams.end(),
|
|
||||||
[](const std::unique_ptr<schema::QuantParamT> &quant_param) { return !quant_param->inited; });
|
|
||||||
|
|
||||||
if (is_quant_params_inited) {
|
|
||||||
input_inited_quant_params_++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// update output quant params num
|
|
||||||
output_inited_quant_params_ = 0;
|
|
||||||
for (auto index : node.outputIndex) {
|
|
||||||
MS_CHECK_TRUE_RET_VOID(graph.allTensors.size() > index);
|
|
||||||
auto &output_tensor = graph.allTensors.at(index);
|
|
||||||
MS_CHECK_TRUE_RET_VOID(output_tensor != nullptr);
|
|
||||||
if (!output_tensor->quantParams.empty()) {
|
|
||||||
bool is_quant_params_inited =
|
|
||||||
!std::any_of(output_tensor->quantParams.begin(), output_tensor->quantParams.end(),
|
|
||||||
[](const std::unique_ptr<schema::QuantParamT> &quant_param) { return !quant_param->inited; });
|
|
||||||
|
|
||||||
if (is_quant_params_inited) {
|
|
||||||
output_inited_quant_params_++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool QuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
|
||||||
MS_CHECK_TRUE_RET(node != nullptr, false);
|
|
||||||
kernel::KernelKey desc{kernel::kCPU, kNumberTypeInt8, NHWC, node->primitive->value.type, ""};
|
|
||||||
if (!KernelRegistry::GetInstance()->SupportKernel(desc)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (node->quantType != schema::QuantType_QUANT_NONE) {
|
|
||||||
return node->quantType == schema::QuantType_QUANT_ALL;
|
|
||||||
}
|
|
||||||
|
|
||||||
UpdateQuantParamsNum(graph, *node);
|
|
||||||
if (input_inited_quant_params_ == node->inputIndex.size() &&
|
|
||||||
output_inited_quant_params_ == node->outputIndex.size()) {
|
|
||||||
node->quantType = schema::QuantType_QUANT_ALL;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
bool QuantTypeDeterminer::DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
|
||||||
return node->quantType == schema::QuantType_QUANT_WEIGHT;
|
|
||||||
}
|
|
||||||
|
|
||||||
int QuantNodeHelper::NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node) {
|
|
||||||
MS_CHECK_TRUE_RET(graph != nullptr, RET_NULL_PTR);
|
|
||||||
MS_CHECK_TRUE_RET(node != nullptr, RET_NULL_PTR);
|
|
||||||
|
|
||||||
if (quant_type_determiner_->DetermineQuantWeight(*graph, node)) {
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
auto ret = quant_param_propogator_->PropogateQuantParams(graph, *node);
|
|
||||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
|
||||||
MS_LOG(ERROR) << node->name << " propagate Quant Params failed.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
auto bool_ret = quant_type_determiner_->DetermineQuantAll(*graph, node);
|
|
||||||
if (!bool_ret) {
|
|
||||||
MS_LOG(DEBUG) << node->name << " dont need quant.";
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
return RET_OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
QuantHelperRegister *QuantHelperRegister::GetInstance() {
|
|
||||||
static QuantHelperRegister instance;
|
|
||||||
return &instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
QuantNodeHelper *QuantHelperRegister::GetQuantHelper(schema::PrimitiveType op_type) {
|
|
||||||
auto it = register_map_.find(op_type);
|
|
||||||
if (it != register_map_.end()) {
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
return register_map_[schema::PrimitiveType_NONE];
|
|
||||||
}
|
|
||||||
|
|
||||||
QuantHelperRegister::QuantHelperRegister() {
|
|
||||||
auto base_propogator = std::make_shared<QuantParamPropogator>();
|
|
||||||
auto base_determiner = std::make_shared<QuantTypeDeterminer>();
|
|
||||||
auto quant_dtype_cast_propogator = std::make_shared<QuantDtypeCastQuantParamPropogator>();
|
|
||||||
auto bias_add_propogator = std::make_shared<BiasAddQuantParamPropogator>();
|
|
||||||
auto carry_data_propogator = std::make_shared<CarryDataQuantParamPropogator>();
|
|
||||||
auto carry_data_determiner = std::make_shared<CarryDataQuantTypeDeterminer>();
|
|
||||||
auto concat_propogator = std::make_shared<ConcatQuantParamPropogator>();
|
|
||||||
auto conv_propogator = std::make_shared<ConvQuantParamPropogator>();
|
|
||||||
auto conv_determiner = std::make_shared<ConvQuantTypeDeterminer>();
|
|
||||||
auto default_quant_all_determiner = std::make_shared<DefaultQuantAllQuantTypeDeterminer>();
|
|
||||||
auto only_need_inputs_determiner = std::make_shared<OnlyNeedInputsQuantTypeDeterminer>();
|
|
||||||
auto matmul_determiner = std::make_shared<MatmulQuantTypeDeterminer>();
|
|
||||||
|
|
||||||
register_map_[schema::PrimitiveType_BiasAdd] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(bias_add_propogator, base_determiner);
|
|
||||||
|
|
||||||
register_map_[schema::PrimitiveType_MaxPoolFusion] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_Resize] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_Reshape] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_StridedSlice] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_Transpose] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_PadFusion] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_ReduceFusion] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(base_propogator, carry_data_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_Gather] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(carry_data_propogator, carry_data_determiner);
|
|
||||||
|
|
||||||
register_map_[schema::PrimitiveType_Concat] = new (std::nothrow) QuantNodeHelper(concat_propogator, base_determiner);
|
|
||||||
|
|
||||||
register_map_[schema::PrimitiveType_Conv2DFusion] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(conv_propogator, conv_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_MatMulFusion] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(conv_propogator, matmul_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_FullConnection] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(conv_propogator, matmul_determiner);
|
|
||||||
|
|
||||||
register_map_[schema::PrimitiveType_QuantDTypeCast] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(quant_dtype_cast_propogator, default_quant_all_determiner);
|
|
||||||
|
|
||||||
register_map_[schema::PrimitiveType_DetectionPostProcess] =
|
|
||||||
new (std::nothrow) QuantNodeHelper(base_propogator, only_need_inputs_determiner);
|
|
||||||
register_map_[schema::PrimitiveType_NONE] = new (std::nothrow) QuantNodeHelper(base_propogator, base_determiner);
|
|
||||||
}
|
|
||||||
|
|
||||||
QuantHelperRegister::~QuantHelperRegister() {
|
|
||||||
for (const auto &iter : register_map_) {
|
|
||||||
delete iter.second;
|
|
||||||
}
|
|
||||||
this->register_map_.clear();
|
|
||||||
}
|
|
||||||
} // namespace mindspore::lite
|
|
|
@ -1,74 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_NODE_HELPER_H_
|
|
||||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_NODE_HELPER_H_
|
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <memory>
|
|
||||||
#include "include/errorcode.h"
|
|
||||||
#include "schema/inner/model_generated.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
|
||||||
constexpr int kInputIndexOne = 1;
|
|
||||||
constexpr int kInputIndexTwo = 2;
|
|
||||||
class QuantNodeBase {
|
|
||||||
public:
|
|
||||||
void UpdateQuantParamsNum(const schema::MetaGraphT &graph, const schema::CNodeT &node);
|
|
||||||
|
|
||||||
protected:
|
|
||||||
size_t input_inited_quant_params_ = 0;
|
|
||||||
size_t output_inited_quant_params_ = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class QuantParamPropogator : public QuantNodeBase {
|
|
||||||
public:
|
|
||||||
virtual int PropogateQuantParams(schema::MetaGraphT *graph, const schema::CNodeT &node) { return RET_OK; }
|
|
||||||
};
|
|
||||||
|
|
||||||
class QuantTypeDeterminer : public QuantNodeBase {
|
|
||||||
public:
|
|
||||||
virtual bool DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node);
|
|
||||||
virtual bool DetermineQuantWeight(const schema::MetaGraphT &graph, schema::CNodeT *node);
|
|
||||||
};
|
|
||||||
|
|
||||||
class QuantNodeHelper {
|
|
||||||
public:
|
|
||||||
int NodeQuantPreprocess(schema::MetaGraphT *graph, schema::CNodeT *node);
|
|
||||||
QuantNodeHelper(std::shared_ptr<QuantParamPropogator> quant_param_propogator,
|
|
||||||
std::shared_ptr<QuantTypeDeterminer> quant_type_determiner) {
|
|
||||||
quant_param_propogator_ = quant_param_propogator;
|
|
||||||
quant_type_determiner_ = quant_type_determiner;
|
|
||||||
}
|
|
||||||
virtual ~QuantNodeHelper() = default;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
std::shared_ptr<QuantParamPropogator> quant_param_propogator_;
|
|
||||||
std::shared_ptr<QuantTypeDeterminer> quant_type_determiner_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class QuantHelperRegister {
|
|
||||||
public:
|
|
||||||
virtual ~QuantHelperRegister();
|
|
||||||
QuantNodeHelper *GetQuantHelper(schema::PrimitiveType op_type);
|
|
||||||
static QuantHelperRegister *GetInstance();
|
|
||||||
|
|
||||||
private:
|
|
||||||
QuantHelperRegister();
|
|
||||||
std::unordered_map<schema::PrimitiveType, QuantNodeHelper *> register_map_;
|
|
||||||
};
|
|
||||||
} // namespace mindspore::lite
|
|
||||||
|
|
||||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_NODE_HELPER_H_
|
|
|
@ -36,7 +36,7 @@ int QuantNodePass::DoWeightQuant(const CNodePtr &cnode) {
|
||||||
auto input = cnode->input(idx);
|
auto input = cnode->input(idx);
|
||||||
ParameterPtr parameter;
|
ParameterPtr parameter;
|
||||||
tensor::TensorPtr weight;
|
tensor::TensorPtr weight;
|
||||||
GetLiteParameter(input, ¶meter, &weight);
|
GetParameterAndTensor(input, ¶meter, &weight);
|
||||||
if (parameter == nullptr || weight == nullptr || weight->data_type() != TypeId::kNumberTypeFloat32) {
|
if (parameter == nullptr || weight == nullptr || weight->data_type() != TypeId::kNumberTypeFloat32) {
|
||||||
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
|
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
|
||||||
continue;
|
continue;
|
||||||
|
|
|
@ -194,7 +194,7 @@ int ConvertFp16ToFp32(const FuncGraphPtr &old_graph) {
|
||||||
}
|
}
|
||||||
ParameterPtr param_node;
|
ParameterPtr param_node;
|
||||||
tensor::TensorPtr tensor_info;
|
tensor::TensorPtr tensor_info;
|
||||||
GetLiteParameter(input, ¶m_node, &tensor_info);
|
GetParameterAndTensor(input, ¶m_node, &tensor_info);
|
||||||
CHECK_NULL_RETURN(tensor_info);
|
CHECK_NULL_RETURN(tensor_info);
|
||||||
CHECK_NULL_RETURN(param_node);
|
CHECK_NULL_RETURN(param_node);
|
||||||
if (tensor_info->data_type() == kNumberTypeFloat16) {
|
if (tensor_info->data_type() == kNumberTypeFloat16) {
|
||||||
|
@ -272,13 +272,13 @@ int PrepareQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<Convert
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
int DoSingleGraphQuantize(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> ¶m) {
|
||||||
CHECK_NULL_RETURN(param);
|
CHECK_NULL_RETURN(param);
|
||||||
if (param->commonQuantParam.quant_type == quant::QUANT_NONE) {
|
if (param->commonQuantParam.quant_type == quant::QUANT_NONE) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int status = PrepareQuantize(old_graph, param);
|
int status = PrepareQuantize(func_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "PrepareQuantize failed.";
|
MS_LOG(ERROR) << "PrepareQuantize failed.";
|
||||||
return status;
|
return status;
|
||||||
|
@ -292,54 +292,57 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<C
|
||||||
origin = std::make_shared<mindspore::Model>();
|
origin = std::make_shared<mindspore::Model>();
|
||||||
CHECK_NULL_RETURN(origin);
|
CHECK_NULL_RETURN(origin);
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
auto ret = BuildModelByFuncGraph(origin, old_graph, param, &size);
|
auto ret = BuildModelByFuncGraph(origin, func_graph, param, &size);
|
||||||
param->commonQuantParam.quant_type = quant_type;
|
param->commonQuantParam.quant_type = quant_type;
|
||||||
if (ret != kSuccess) {
|
if (ret != kSuccess) {
|
||||||
MS_LOG(ERROR) << "Build model failed";
|
MS_LOG(ERROR) << "Build model failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
origin_lite_model = ParseLiteModel(old_graph, param);
|
origin_lite_model = ParseLiteModel(func_graph, param);
|
||||||
if (origin_lite_model == nullptr) {
|
if (origin_lite_model == nullptr) {
|
||||||
MS_LOG(ERROR) << "Parse lite model failed.";
|
MS_LOG(ERROR) << "Parse lite model failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (param->commonQuantParam.quant_type == quant::QUANT_ALL) { // Full Quantization
|
if (param->commonQuantParam.quant_type == quant::QUANT_ALL) { // Full Quantization
|
||||||
status = DoFullQuant(old_graph, param);
|
status = ConvertFp16ToFp32(func_graph);
|
||||||
|
if (status != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Converter fp16 to fp32 failed.";
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
status = DoFullQuant(func_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Do full quant failed.";
|
MS_LOG(ERROR) << "Do full quant failed.";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
} else if (param->commonQuantParam.quant_type == quant::QUANT_WEIGHT) { // Weight Quantization
|
} else if (param->commonQuantParam.quant_type == quant::QUANT_WEIGHT) { // Weight Quantization
|
||||||
status = DoWeightQuant(old_graph, param);
|
status = DoWeightQuant(func_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Do weight quant failed.";
|
MS_LOG(ERROR) << "Do weight quant failed.";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
} else if (param->commonQuantParam.quant_type == quant::QUANT_DYNAMIC) { // Dynamic Quantization
|
} else if (param->commonQuantParam.quant_type == quant::QUANT_DYNAMIC) { // Dynamic Quantization
|
||||||
status = DoDynamicQuant(old_graph, param);
|
status = DoDynamicQuant(func_graph, param);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Do dynamic quant failed.";
|
MS_LOG(ERROR) << "Do dynamic quant failed.";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
CHECK_NULL_RETURN(optimizer);
|
||||||
CHECK_NULL_RETURN(optimizer);
|
auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false);
|
||||||
auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false);
|
CHECK_NULL_RETURN(fusion_pm);
|
||||||
CHECK_NULL_RETURN(fusion_pm);
|
fusion_pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
|
||||||
fusion_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
|
optimizer->AddPassManager(fusion_pm);
|
||||||
optimizer->AddPassManager(fusion_pm);
|
if (optimizer->Optimize(func_graph) == nullptr) {
|
||||||
if (optimizer->Optimize(old_graph) == nullptr) {
|
MS_LOG(ERROR) << "run cast node fusion failed.";
|
||||||
MS_LOG(ERROR) << "run cast node fusion failed.";
|
return RET_ERROR;
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (param->commonQuantParam.is_debug) {
|
if (param->commonQuantParam.is_debug) {
|
||||||
status = DoQuantDebug(old_graph, param, origin, origin_lite_model);
|
status = DoQuantDebug(func_graph, param, origin, origin_lite_model);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Do quant debug failed.";
|
MS_LOG(ERROR) << "Do quant debug failed.";
|
||||||
return status;
|
return status;
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -45,15 +45,8 @@ using std::vector;
|
||||||
|
|
||||||
namespace mindspore::lite::quant {
|
namespace mindspore::lite::quant {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr int kLstmInputWeightIndex = 1;
|
|
||||||
constexpr int kLstmStateWeightIndex = 2;
|
|
||||||
constexpr int kLstmWeightShapeSize = 3;
|
|
||||||
constexpr int kSingleDirBiasTensorSize = 4;
|
|
||||||
constexpr int kLstmBiasShapeSize = 2;
|
|
||||||
constexpr int kLstmBiasIndex = 3;
|
|
||||||
constexpr size_t kGatherAxisIndex = 3;
|
constexpr size_t kGatherAxisIndex = 3;
|
||||||
constexpr size_t kAnfPrimitiveIndex = 0;
|
constexpr int kDefaultThreadNum = 4;
|
||||||
constexpr int kDefaultThreadNumFour = 4;
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
|
int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
|
||||||
|
@ -92,10 +85,10 @@ void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_f
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) {
|
int UpdateDataType(const AnfNodePtr &node, TypeId new_data_type) {
|
||||||
auto abstract_base = cnode->abstract();
|
auto abstract_base = node->abstract();
|
||||||
if (abstract_base == nullptr) {
|
if (abstract_base == nullptr) {
|
||||||
MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
|
MS_LOG(ERROR) << "Abstract of node is nullptr, " << node->fullname_with_scope();
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,7 +253,7 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
|
||||||
delete meta_graph;
|
delete meta_graph;
|
||||||
return kLiteNullptr;
|
return kLiteNullptr;
|
||||||
}
|
}
|
||||||
context->SetThreadNum(kDefaultThreadNumFour);
|
context->SetThreadNum(kDefaultThreadNum);
|
||||||
context->SetThreadAffinity(kCpuBindMode);
|
context->SetThreadAffinity(kCpuBindMode);
|
||||||
|
|
||||||
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
|
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
|
||||||
|
@ -297,7 +290,7 @@ std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<m
|
||||||
return dst_tensors;
|
return dst_tensors;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
|
void GetParameterAndTensor(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
|
||||||
if (node == nullptr) {
|
if (node == nullptr) {
|
||||||
MS_LOG(ERROR) << "node is nullptr";
|
MS_LOG(ERROR) << "node is nullptr";
|
||||||
return;
|
return;
|
||||||
|
@ -330,7 +323,7 @@ int UpdateTensorDataAndSize(const AnfNodePtr &node, const tensor::TensorPtr &wei
|
||||||
MS_LOG(ERROR) << "Data size of tensor info is error.";
|
MS_LOG(ERROR) << "Data size of tensor info is error.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (memcpy_s(weight->data_c(), new_size, quant_datas, new_size) != EOK) {
|
if (memcpy_s(weight->data_c(), weight->data().nbytes(), quant_datas, new_size) != EOK) {
|
||||||
MS_LOG(ERROR) << "memcpy data failed.";
|
MS_LOG(ERROR) << "memcpy data failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -379,7 +372,7 @@ int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int>
|
||||||
}
|
}
|
||||||
|
|
||||||
int GetGatherPreferredDim(const CNodePtr &cnode) {
|
int GetGatherPreferredDim(const CNodePtr &cnode) {
|
||||||
if (cnode->size() < kGatherAxisIndex + 1) {
|
if (cnode->size() < kGatherAxisIndex + kPrimOffset) {
|
||||||
MS_LOG(WARNING) << "gather cnode size < 4.";
|
MS_LOG(WARNING) << "gather cnode size < 4.";
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -498,10 +491,10 @@ bool CheckControlFlowType(const AnfNodePtr &node) {
|
||||||
if (node->isa<mindspore::CNode>()) {
|
if (node->isa<mindspore::CNode>()) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
// control flow call
|
// control flow call
|
||||||
if (!IsValueNode<mindspore::Primitive>(cnode->input(kAnfPrimitiveIndex))) {
|
if (!IsValueNode<mindspore::Primitive>(cnode->input(kPrimIndex))) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
auto prim = GetValuePtr<mindspore::Primitive>(cnode->input(kAnfPrimitiveIndex));
|
auto prim = GetValuePtr<mindspore::Primitive>(cnode->input(kPrimIndex));
|
||||||
if (control_flow_ops.find(prim->name()) != control_flow_ops.end()) {
|
if (control_flow_ops.find(prim->name()) != control_flow_ops.end()) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -59,7 +59,7 @@ int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type);
|
||||||
|
|
||||||
void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
|
void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs);
|
||||||
|
|
||||||
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type);
|
int UpdateDataType(const AnfNodePtr &node, TypeId new_data_type);
|
||||||
|
|
||||||
bool IsGraphInDTypeCast(const CNodePtr &cnode);
|
bool IsGraphInDTypeCast(const CNodePtr &cnode);
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ mindspore::lite::Tensor *MSTensorToLiteTensor(const mindspore::MSTensor &tensor)
|
||||||
|
|
||||||
std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<mindspore::MSTensor> &src_tensors);
|
std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<mindspore::MSTensor> &src_tensors);
|
||||||
|
|
||||||
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
|
void GetParameterAndTensor(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
|
||||||
|
|
||||||
bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types);
|
bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types);
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -152,7 +152,7 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
|
||||||
auto input = cnode->input(idx);
|
auto input = cnode->input(idx);
|
||||||
ParameterPtr parameter;
|
ParameterPtr parameter;
|
||||||
tensor::TensorPtr tensor_info;
|
tensor::TensorPtr tensor_info;
|
||||||
GetLiteParameter(input, ¶meter, &tensor_info);
|
GetParameterAndTensor(input, ¶meter, &tensor_info);
|
||||||
if (parameter == nullptr || tensor_info == nullptr ||
|
if (parameter == nullptr || tensor_info == nullptr ||
|
||||||
tensor_info->compression_type() != mindspore::kNoCompression) {
|
tensor_info->compression_type() != mindspore::kNoCompression) {
|
||||||
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";
|
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";
|
||||||
|
@ -263,7 +263,7 @@ int WeightQuantizer::DoCompression(const CNodePtr &cnode, const ParameterPtr &pa
|
||||||
|
|
||||||
int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr ¶meter, int idx,
|
int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr ¶meter, int idx,
|
||||||
const tensor::TensorPtr &tensor_info, int preferred_dim,
|
const tensor::TensorPtr &tensor_info, int preferred_dim,
|
||||||
WeightQuantType weight_quant_type, bool symmetric, bool update_tensor) {
|
WeightQuantType weight_quant_type, bool symmetric) {
|
||||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
CHECK_NULL_RETURN(primitive);
|
CHECK_NULL_RETURN(primitive);
|
||||||
auto mixed_bit_quantization = MixedBitWeightQuantization(mixed_bit_init_scale_);
|
auto mixed_bit_quantization = MixedBitWeightQuantization(mixed_bit_init_scale_);
|
||||||
|
@ -293,14 +293,8 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
|
||||||
<< parameter->fullname_with_scope()
|
<< parameter->fullname_with_scope()
|
||||||
<< " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization.";
|
<< " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization.";
|
||||||
FixedBitWeightQuantization fixed_bit_quant;
|
FixedBitWeightQuantization fixed_bit_quant;
|
||||||
if (update_tensor) {
|
status = fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
|
||||||
status =
|
weight_quant_type, kNumberTypeInt8, idx - 1, preferred_dim, symmetric);
|
||||||
fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
|
|
||||||
weight_quant_type, kNumberTypeInt8, idx - 1, preferred_dim, symmetric);
|
|
||||||
} else {
|
|
||||||
status = fixed_bit_quant.StatisticsFilter(tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
|
|
||||||
weight_quant_type, kNumberTypeInt8, idx - 1, preferred_dim, symmetric);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -332,9 +326,11 @@ int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNo
|
||||||
MS_LOG(INFO) << tensor_name << " insert WeightQuant node";
|
MS_LOG(INFO) << tensor_name << " insert WeightQuant node";
|
||||||
auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()));
|
auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()));
|
||||||
if (type_id == kNumberTypeFloat32) {
|
if (type_id == kNumberTypeFloat32) {
|
||||||
status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32, axis);
|
status =
|
||||||
|
quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32, axis);
|
||||||
} else {
|
} else {
|
||||||
status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16, axis);
|
status =
|
||||||
|
quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16, axis);
|
||||||
}
|
}
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
|
MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
|
||||||
|
@ -344,8 +340,8 @@ int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNo
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int WeightQuantizer::MarkCnodeWeightQuantType(const CNodePtr &cnode) {
|
int WeightQuantizer::MarkCNodeWeightQuantType(const CNodePtr &cnode) {
|
||||||
MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR);
|
CHECK_NULL_RETURN(cnode);
|
||||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
if (primitive == nullptr) {
|
if (primitive == nullptr) {
|
||||||
MS_LOG(ERROR) << "primitive is nullptr";
|
MS_LOG(ERROR) << "primitive is nullptr";
|
||||||
|
@ -353,7 +349,7 @@ int WeightQuantizer::MarkCnodeWeightQuantType(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||||
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
|
CHECK_NULL_RETURN(quant_param_holder);
|
||||||
if (quant_param_holder->quant_type() == quant::QUANT_WEIGHT) {
|
if (quant_param_holder->quant_type() == quant::QUANT_WEIGHT) {
|
||||||
// already marked with QuantType_QUANT_WEIGHT
|
// already marked with QuantType_QUANT_WEIGHT
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -361,11 +357,11 @@ int WeightQuantizer::MarkCnodeWeightQuantType(const CNodePtr &cnode) {
|
||||||
|
|
||||||
// Support Share Weight Quant.
|
// Support Share Weight Quant.
|
||||||
for (size_t i = kPrimOffset; i < cnode->size(); i++) {
|
for (size_t i = kPrimOffset; i < cnode->size(); i++) {
|
||||||
auto inputNode = cnode->input(i);
|
auto input_node = cnode->input(i);
|
||||||
if (inputNode->isa<Parameter>()) {
|
if (input_node->isa<Parameter>()) {
|
||||||
ParameterPtr param_node;
|
ParameterPtr param_node;
|
||||||
tensor::TensorPtr tensor_info;
|
tensor::TensorPtr tensor_info;
|
||||||
GetLiteParameter(inputNode, ¶m_node, &tensor_info);
|
GetParameterAndTensor(input_node, ¶m_node, &tensor_info);
|
||||||
auto param = weight_quantized_tensors_.find(tensor_info);
|
auto param = weight_quantized_tensors_.find(tensor_info);
|
||||||
if (param != weight_quantized_tensors_.end()) {
|
if (param != weight_quantized_tensors_.end()) {
|
||||||
quant_param_holder->set_quant_type(quant::QUANT_WEIGHT);
|
quant_param_holder->set_quant_type(quant::QUANT_WEIGHT);
|
||||||
|
@ -381,12 +377,12 @@ int WeightQuantizer::MarkGraphWeightQuantType(const FuncGraphPtr &func_graph) {
|
||||||
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
||||||
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
|
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
|
||||||
if (primitive == nullptr) {
|
if (primitive == nullptr) {
|
||||||
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
|
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " primitive is nullptr";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto status = MarkCnodeWeightQuantType(cnode);
|
auto status = MarkCNodeWeightQuantType(cnode);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "MarkGraphWeightQuantType error marking " << cnode->fullname_with_scope();
|
MS_LOG(ERROR) << cnode->fullname_with_scope() << " mark graph QuantType failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -394,7 +390,7 @@ int WeightQuantizer::MarkGraphWeightQuantType(const FuncGraphPtr &func_graph) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||||
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
|
CHECK_NULL_RETURN(func_graph);
|
||||||
weight_quantized_tensors_.clear();
|
weight_quantized_tensors_.clear();
|
||||||
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
|
||||||
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
|
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
|
||||||
|
|
|
@ -115,11 +115,10 @@ class WeightQuantizer : public Quantizer {
|
||||||
const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
|
const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
|
||||||
bool compression = true);
|
bool compression = true);
|
||||||
int MarkGraphWeightQuantType(const FuncGraphPtr &func_graph);
|
int MarkGraphWeightQuantType(const FuncGraphPtr &func_graph);
|
||||||
int MarkCnodeWeightQuantType(const CNodePtr &cnode);
|
int MarkCNodeWeightQuantType(const CNodePtr &cnode);
|
||||||
int DoCompression(const CNodePtr &cnode, const ParameterPtr ¶meter, int idx);
|
int DoCompression(const CNodePtr &cnode, const ParameterPtr ¶meter, int idx);
|
||||||
int DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr ¶meter, int idx, const tensor::TensorPtr &tensor_info,
|
int DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr ¶meter, int idx, const tensor::TensorPtr &tensor_info,
|
||||||
int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true,
|
int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true);
|
||||||
bool update_tensor = true);
|
|
||||||
int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr ¶meter, int idx,
|
int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr ¶meter, int idx,
|
||||||
const tensor::TensorPtr &tensor_info);
|
const tensor::TensorPtr &tensor_info);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue