!38593 tf parser: support QAT

Merge pull request !38593 from liyan2022/dev_qat
This commit is contained in:
i-robot 2022-07-25 11:11:49 +00:00 committed by Gitee
commit 619e8512db
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
18 changed files with 259 additions and 19 deletions

View File

@ -390,10 +390,10 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::sha
return RET_OK;
}
STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param) {
STATUS AnfTransform::QATTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
if (param->fullQuantParam.target_device == quant::TargetDevice::DSP &&
param->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL) {
auto remove_pass = quant::RemoveUnusedQuantParam(func_graph);
auto remove_pass = quant::RemoveUnusedQuantParam(old_graph);
auto ret = remove_pass.Remove();
if (ret != RET_OK) {
MS_LOG(ERROR) << "remove unused quant param failed.";
@ -516,6 +516,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
return nullptr;
}
status = QATTransform(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "QAT model transform failed.";
return nullptr;
}
status = DoQuantize(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";

View File

@ -59,7 +59,7 @@ class AnfTransform {
static STATUS MarkTrainOp(const FuncGraphPtr &func_graph);
static STATUS QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param);
static STATUS QATTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
};
} // namespace lite
} // namespace mindspore

View File

@ -46,6 +46,7 @@ ADD_CONVERTER_ONLY_OP(Merge);
ADD_CONVERTER_ONLY_OP(Einsum);
ADD_CONVERTER_ONLY_OP(QuantizeLinear);
ADD_CONVERTER_ONLY_OP(DequantizeLinear);
ADD_CONVERTER_ONLY_OP(FakeQuantWithMinMaxVars);
} // namespace lite
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include "tools/converter/ops/ops_def.h"
#include "src/common/utils.h"
#include "nnacl/op_base.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite {
void OnnxQuantizeLinearAdjust::RemoveDequantizeLinear(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
@ -109,14 +110,13 @@ bool OnnxQuantizeLinearAdjust::Adjust(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(manager != nullptr, false);
auto node_users = manager->node_users()[cnode];
for (auto &node_user : node_users) {
auto next_primitive = GetValueNode<PrimitivePtr>(node_user.first->cast<CNodePtr>()->input(0));
auto next_quant_holder = GetCNodeQuantHolder(next_primitive);
auto next_quant_holder = quant::GetCNodeQuantHolder(node_user.first->cast<CNodePtr>());
auto ret = SetInputQuantParam(cnode, next_quant_holder, (node_user.second - kPrimOffset));
if (!ret) {
MS_LOG(ERROR) << "Set quant param failed.";
return false;
}
manager->SetEdge(node_user.first, node_user.second, cnode->inputs()[1]);
manager->SetEdge(node_user.first, node_user.second, cnode->inputs()[kIndex1]);
}
}
@ -124,7 +124,7 @@ bool OnnxQuantizeLinearAdjust::Adjust(const FuncGraphPtr &func_graph) {
for (auto &cnode : func_graph->GetOrderedCnodes()) {
MS_LOG(DEBUG) << "check cnode name: " << cnode->fullname_with_scope();
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
auto primitive_quant_holder = GetCNodeQuantHolder(primitive);
auto primitive_quant_holder = quant::GetCNodeQuantHolder(primitive);
auto input_quant_params = primitive_quant_holder->get_input_quant_params();
for (size_t i = 0; i < input_quant_params.size(); i++) {
auto quant_params = input_quant_params.at(i);

View File

@ -0,0 +1,77 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_fake_quant_adjust.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <utility>
#include "ops/primitive_c.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore {
namespace lite {
bool TFFakeQuantAdjust::SetInputQuantParam(const CNodePtr &cnode, const QuantParamHolderPtr &quant_param_holder,
                                           size_t index) {
  // Translates the FakeQuantWithMinMaxVars "min"/"max" attributes stored on
  // `cnode`'s primitive into a symmetric quant parameter and attaches it to
  // `quant_param_holder` at input slot `index`.
  // Returns false when the holder/primitive is null or an attribute is missing.
  MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, false, "Primitive quant param holder nullptr.");
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(primitive != nullptr, false, "Primitive is nullptr.");
  auto min_value = primitive->GetAttr("min");
  MS_CHECK_FALSE(min_value == nullptr, false);
  auto max_value = primitive->GetAttr("max");
  MS_CHECK_FALSE(max_value == nullptr, false);
  MS_LOG(INFO) << "min: " << GetValue<float>(min_value) << " max: " << GetValue<float>(max_value);
  std::vector<schema::QuantParamT> quant_params;
  auto quant_param = std::make_unique<QuantParamT>();
  quant_param->min = GetValue<float>(min_value);
  quant_param->max = GetValue<float>(max_value);
  // Symmetric quantization: the larger magnitude of min/max maps onto
  // quant::kQuantRange (127). NOTE: use std::fabs, not ::abs — the integer
  // overload of abs() would truncate the float bounds before the comparison.
  quant_param->scale = (std::max(std::fabs(quant_param->min), std::fabs(quant_param->max))) / quant::kQuantRange;
  quant_param->zeroPoint = 0;  // symmetric → zero point fixed at 0
  quant_param->inited = true;
  quant_params.push_back(*std::move(quant_param));
  quant_param_holder->set_input_quant_param(index, quant_params);
  return true;
}
bool TFFakeQuantAdjust::Adjust(const FuncGraphPtr &func_graph) {
  // Folds every FakeQuantWithMinMaxVars node out of `func_graph`: the node's
  // min/max bounds become input quant params on each consumer, and each
  // consumer edge is re-wired to the FakeQuant node's first input.
  // Returns false on any failure.
  MS_CHECK_TRUE_RET(func_graph != nullptr, false);
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    if (!opt::CheckPrimitiveType(cnode, std::make_unique<Primitive>(lite::kNameFakeQuantWithMinMaxVars))) {
      continue;
    }
    MS_CHECK_GE(cnode->inputs().size(), kInputSize1, false);
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      manager = Manage(func_graph, true);
    }
    // Fail if a manager still could not be created: the graph cannot be
    // re-wired without one. (Previously this returned true, silently
    // skipping the adjustment and leaving FakeQuant nodes in the graph.)
    MS_CHECK_TRUE_RET(manager != nullptr, false);
    auto node_users = manager->node_users()[cnode];
    for (auto &node_user : node_users) {
      auto next_quant_holder = quant::GetCNodeQuantHolder(node_user.first->cast<CNodePtr>());
      // node_user.second is the consumer's input index (primitive at 0),
      // so subtract kPrimOffset to get the tensor slot.
      auto ret = SetInputQuantParam(cnode, next_quant_holder, node_user.second - quant::kPrimOffset);
      if (!ret) {
        MS_LOG(ERROR) << "Set quant param failed.";
        return false;
      }
      // Bypass the FakeQuant node: consumer now reads its first input.
      manager->SetEdge(node_user.first, node_user.second, cnode->inputs()[kIndex1]);
    }
  }
  return true;
}
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_ADJUST_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_ADJUST_H_
#include <string>
#include <vector>
#include "backend/common/optimizer/pass.h"
#include "backend/common/optimizer/optimizer.h"
#include "tools/converter/quantizer/quant_param_holder.h"
namespace mindspore {
namespace lite {
// Graph-adjust pass run after TF parsing (QAT support): removes
// FakeQuantWithMinMaxVars nodes and records their min/max bounds as input
// quant params on the consuming nodes.
class TFFakeQuantAdjust {
 public:
  // Runs the pass over `func_graph`. Returns false on failure.
  bool Adjust(const FuncGraphPtr &func_graph);

 private:
  // Converts the FakeQuant node's "min"/"max" attributes into a quant param
  // stored on `quant_param_holder` at input slot `index`. Returns false on
  // failure (null holder or missing attributes).
  bool SetInputQuantParam(const CNodePtr &cnode, const QuantParamHolderPtr &quant_param_holder, size_t index);
};
} // namespace lite
} // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_ADJUST_H_

View File

@ -0,0 +1,67 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_fake_quant_parser.h"
#include "nnacl/op_base.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
PrimitiveCPtr TFFakeQuantParser::Parse(const tensorflow::NodeDef &tf_op,
                                       const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                                       std::vector<std::string> *inputs, int *output_size) {
  // FakeQuantWithMinMaxVars carries its quantization bounds as two scalar
  // const inputs; lift them onto the primitive as "min"/"max" attributes so
  // the later adjust pass can turn them into quant params.
  auto fake_quant_prim = std::make_unique<FakeQuantWithMinMaxVars>();
  MS_CHECK_TRUE_RET(fake_quant_prim != nullptr, nullptr);
  MS_CHECK_GE(tf_op.input_size(), kInputSize2, nullptr);

  tensorflow::AttrValue attr_value;

  // Second input: the scalar "min" const node.
  auto min_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
  if (min_node == nullptr) {
    MS_LOG(ERROR) << "Find FakeQuant input min node failed.";
    return nullptr;
  }
  if (!TensorFlowUtils::FindAttrValue(*min_node, "value", &attr_value)) {
    MS_LOG(ERROR) << "The attribute min should be specified.";
    return nullptr;
  }
  const float min_value = attr_value.tensor().float_val(0);

  // Third input: the scalar "max" const node.
  auto max_node = GetConstInputNode(tf_node_map, tf_op.input(THIRD_INPUT));
  if (max_node == nullptr) {
    MS_LOG(ERROR) << "Find FakeQuant input max node failed.";
    return nullptr;
  }
  if (!TensorFlowUtils::FindAttrValue(*max_node, "value", &attr_value)) {
    MS_LOG(ERROR) << "The attribute max should be specified.";
    return nullptr;
  }
  const float max_value = attr_value.tensor().float_val(0);

  fake_quant_prim->AddAttr("min", MakeValue(min_value));
  fake_quant_prim->AddAttr("max", MakeValue(max_value));

  // Single output; only the first TF input (the tensor being quantized) is
  // registered as a graph input.
  *output_size = 1;
  if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
    MS_LOG(ERROR) << "Add op input failed.";
    return nullptr;
  }
  return fake_quant_prim;
}
TFNodeRegistrar g_tfFakeQuantParser("FakeQuantWithMinMaxVars", new TFFakeQuantParser());
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
// TF node parser for FakeQuantWithMinMaxVars: builds a converter-only
// primitive carrying "min"/"max" float attributes read from the op's two
// scalar const inputs.
class TFFakeQuantParser : public TFNodeParser {
 public:
  TFFakeQuantParser() = default;
  ~TFFakeQuantParser() override = default;

  // Parses `tf_op` (using `tf_node_map` to resolve its const inputs),
  // appends the op's data input to `inputs`, sets `*output_size` to 1, and
  // returns the primitive — or nullptr on failure.
  PrimitiveCPtr Parse(const tensorflow::NodeDef &tf_op,
                      const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
                      std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_PARSER_H_

View File

@ -35,6 +35,7 @@
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/lite_model_parser_creator.h"
#include "tools/converter/parser/tf/tf_fake_quant_adjust.h"
#include "tools/common/tensor_util.h"
#include "src/common/log_util.h"
#include "tools/converter/parser/unify_format.h"
@ -1229,6 +1230,12 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
auto fake_quant_adjust = std::make_shared<TFFakeQuantAdjust>();
if (!fake_quant_adjust->Adjust(func_graph)) {
MS_LOG(ERROR) << "tf fake quant adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
}
return RET_OK;
}

View File

@ -23,6 +23,7 @@
#include <map>
#include <memory>
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/graphdef_transform.h"
#include "nnacl/op_base.h"
#include "tools/common/statistic_utils.h"
#include "src/tensor.h"

View File

@ -94,7 +94,7 @@ int FullQuantQuantizer::DoParameterWeightQuant(const CNodePtr &cnode, const Para
auto weight_q_min = per_channel ? weight_channel_q_min_ : weight_layer_q_min_;
auto weight_q_max = per_channel ? weight_channel_q_max_ : weight_layer_q_max_;
auto symmetric = per_channel ? weight_channel_symmetric_ : weight_layer_symmetric_;
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max,
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, schema::QuantType_QUANT_ALL, weight_q_max,
weight_q_min, bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1,
preferred_dim, symmetric);
if (status != RET_OK) {
@ -118,7 +118,7 @@ int FullQuantQuantizer::DoValueNodeWeightQuant(const CNodePtr &cnode, const Valu
auto weight_q_min = per_channel ? weight_channel_q_min_ : weight_layer_q_min_;
auto weight_q_max = per_channel ? weight_channel_q_max_ : weight_layer_q_max_;
auto symmetric = per_channel ? weight_channel_symmetric_ : weight_layer_symmetric_;
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max,
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, schema::QuantType_QUANT_ALL, weight_q_max,
weight_q_min, bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1,
preferred_dim, symmetric);
if (status != RET_OK) {

View File

@ -26,6 +26,7 @@
#include "schema/inner/model_generated.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "tools/converter/quantizer/quant_param_holder.h"
namespace mindspore::lite::quant {

View File

@ -121,6 +121,7 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
}
int QuantTypeDeterminer::Determine() {
CHECK_NULL_RETURN(func_graph_);
auto nodes = func_graph_->GetOrderedCnodes();
for (auto const &cnode : nodes) {
auto quant_holder = GetCNodeQuantHolder(cnode);
@ -130,10 +131,10 @@ int QuantTypeDeterminer::Determine() {
}
if (DetermineQuantWeight(cnode)) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_WEIGHT";
quant_holder->set_quant_type(QuantType_QUANT_WEIGHT);
quant_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
} else if (DetermineQuantAll(cnode)) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_ALL";
quant_holder->set_quant_type(QuantType_QUANT_ALL);
quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
} else {
MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info";
quant_holder->ClearQuantParams();

View File

@ -13,11 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H_
#include <utility>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite::quant {
@ -37,4 +38,4 @@ class QuantTypeDeterminer {
FuncGraphPtr func_graph_ = nullptr;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H_

View File

@ -41,6 +41,8 @@ constexpr float kRatio = 10.0;
constexpr int kCpuBindMode = 1;
constexpr int kPrimIndex = 0;
constexpr int kPrimOffset = 1;
constexpr int kU8ZeroPointOffset = 128;
constexpr int kQuantRange = 127;
enum ActivationQuantizedMethod {
MAX_MIN = 0,

View File

@ -37,6 +37,7 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"
#include "tools/converter/graphdef_transform.h"
using std::string;
using std::vector;
@ -537,7 +538,7 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
}
int MixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
const PrimitivePtr &primitive, schema::QuantType quant_type, WeightQuantType weight_quant_type,
TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetric) {
MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);

View File

@ -48,7 +48,6 @@
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
#include "src/runtime/lite_session.h"
#include "tools/converter/graphdef_transform.h"
#include "src/common/file_utils.h"
#include "src/common/quant_utils.h"
#include "include/api/model.h"
@ -71,7 +70,7 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
bool TensorQuantParamsInited(const schema::TensorT &tensor);
int MixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
const PrimitivePtr &primitive, schema::QuantType quant_type, WeightQuantType weight_quant_type,
TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetric);
int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first);
@ -104,7 +103,7 @@ int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<min
template <typename T>
int FixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
const PrimitivePtr &primitive, QuantType quant_type, int quant_max, int quant_min,
const PrimitivePtr &primitive, schema::QuantType quant_type, int quant_max, int quant_min,
size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type, int index,
int preferred_dim, bool symmetric = false, bool narrow_range = false) {
MS_ASSERT(weight != nullptr);

View File

@ -213,7 +213,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
schema_tensor->dataType = data_info.data_type_;
schema_tensor->data = data_info.data_;
schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;
schema_tensor->weightQuantCompressType = static_cast<WeightQuantCompressType>(data_info.compress_type_);
schema_tensor->weightQuantCompressType = static_cast<schema::WeightQuantCompressType>(data_info.compress_type_);
schema_tensor->nodeType = NodeType_CNode;
auto key = std::make_pair(input, 0);
node_id_map_[key] = static_cast<int>(meta_graphT->allTensors.size());