forked from mindspore-Ecosystem/mindspore
!38593 tf parser support QAT
Merge pull request !38593 from liyan2022/dev_qat
This commit is contained in:
commit
619e8512db
|
@ -390,10 +390,10 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::sha
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param) {
|
||||
STATUS AnfTransform::QATTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
|
||||
if (param->fullQuantParam.target_device == quant::TargetDevice::DSP &&
|
||||
param->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL) {
|
||||
auto remove_pass = quant::RemoveUnusedQuantParam(func_graph);
|
||||
auto remove_pass = quant::RemoveUnusedQuantParam(old_graph);
|
||||
auto ret = remove_pass.Remove();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "remove unused quant param failed.";
|
||||
|
@ -516,6 +516,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
status = QATTransform(old_graph, param);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QAT model transform failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = DoQuantize(old_graph, param);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do Quantize failed.";
|
||||
|
|
|
@ -59,7 +59,7 @@ class AnfTransform {
|
|||
|
||||
static STATUS MarkTrainOp(const FuncGraphPtr &func_graph);
|
||||
|
||||
static STATUS QATTransform(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param);
|
||||
static STATUS QATTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,6 +46,7 @@ ADD_CONVERTER_ONLY_OP(Merge);
|
|||
ADD_CONVERTER_ONLY_OP(Einsum);
|
||||
ADD_CONVERTER_ONLY_OP(QuantizeLinear);
|
||||
ADD_CONVERTER_ONLY_OP(DequantizeLinear);
|
||||
ADD_CONVERTER_ONLY_OP(FakeQuantWithMinMaxVars);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
void OnnxQuantizeLinearAdjust::RemoveDequantizeLinear(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
|
@ -109,14 +110,13 @@ bool OnnxQuantizeLinearAdjust::Adjust(const FuncGraphPtr &func_graph) {
|
|||
MS_CHECK_TRUE_RET(manager != nullptr, false);
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
auto next_primitive = GetValueNode<PrimitivePtr>(node_user.first->cast<CNodePtr>()->input(0));
|
||||
auto next_quant_holder = GetCNodeQuantHolder(next_primitive);
|
||||
auto next_quant_holder = quant::GetCNodeQuantHolder(node_user.first->cast<CNodePtr>());
|
||||
auto ret = SetInputQuantParam(cnode, next_quant_holder, (node_user.second - kPrimOffset));
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Set quant param failed.";
|
||||
return false;
|
||||
}
|
||||
manager->SetEdge(node_user.first, node_user.second, cnode->inputs()[1]);
|
||||
manager->SetEdge(node_user.first, node_user.second, cnode->inputs()[kIndex1]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,7 +124,7 @@ bool OnnxQuantizeLinearAdjust::Adjust(const FuncGraphPtr &func_graph) {
|
|||
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
||||
MS_LOG(DEBUG) << "check cnode name: " << cnode->fullname_with_scope();
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
auto primitive_quant_holder = GetCNodeQuantHolder(primitive);
|
||||
auto primitive_quant_holder = quant::GetCNodeQuantHolder(primitive);
|
||||
auto input_quant_params = primitive_quant_holder->get_input_quant_params();
|
||||
for (size_t i = 0; i < input_quant_params.size(); i++) {
|
||||
auto quant_params = input_quant_params.at(i);
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/tf/tf_fake_quant_adjust.h"
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
bool TFFakeQuantAdjust::SetInputQuantParam(const CNodePtr &cnode, const QuantParamHolderPtr &quant_param_holder,
|
||||
size_t index) {
|
||||
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, false, "Primitive quant param holder nullptr.");
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
auto min_value = primitive->GetAttr("min");
|
||||
MS_CHECK_FALSE(min_value == nullptr, false);
|
||||
auto max_value = primitive->GetAttr("max");
|
||||
MS_CHECK_FALSE(max_value == nullptr, false);
|
||||
MS_LOG(INFO) << "min: " << GetValue<float>(min_value) << " max: " << GetValue<float>(max_value);
|
||||
|
||||
std::vector<schema::QuantParamT> quant_params;
|
||||
auto quant_param = std::make_unique<QuantParamT>();
|
||||
quant_param->min = GetValue<float>(min_value);
|
||||
quant_param->max = GetValue<float>(max_value);
|
||||
quant_param->scale = (std::max(abs(quant_param->min), abs(quant_param->max))) / quant::kQuantRange;
|
||||
quant_param->zeroPoint = 0;
|
||||
quant_param->inited = true;
|
||||
quant_params.push_back(*std::move(quant_param));
|
||||
quant_param_holder->set_input_quant_param(index, quant_params);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TFFakeQuantAdjust::Adjust(const FuncGraphPtr &func_graph) {
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, false);
|
||||
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
||||
if (!opt::CheckPrimitiveType(cnode, std::make_unique<Primitive>(lite::kNameFakeQuantWithMinMaxVars))) {
|
||||
continue;
|
||||
}
|
||||
MS_CHECK_GE(cnode->inputs().size(), kInputSize1, false);
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
manager = Manage(func_graph, true);
|
||||
}
|
||||
MS_CHECK_TRUE_RET(manager != nullptr, true);
|
||||
auto node_users = manager->node_users()[cnode];
|
||||
for (auto &node_user : node_users) {
|
||||
auto next_quant_holder = quant::GetCNodeQuantHolder(node_user.first->cast<CNodePtr>());
|
||||
auto ret = SetInputQuantParam(cnode, next_quant_holder, node_user.second - quant::kPrimOffset);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Set quant param failed.";
|
||||
return false;
|
||||
}
|
||||
manager->SetEdge(node_user.first, node_user.second, cnode->inputs()[kIndex1]);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_ADJUST_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_ADJUST_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/optimizer/pass.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "tools/converter/quantizer/quant_param_holder.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFFakeQuantAdjust {
|
||||
public:
|
||||
bool Adjust(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
bool SetInputQuantParam(const CNodePtr &cnode, const QuantParamHolderPtr &quant_param_holder, size_t index);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_ADJUST_H_
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/converter/parser/tf/tf_fake_quant_parser.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
// Parses a TensorFlow FakeQuantWithMinMaxVars node into a converter-only FakeQuantWithMinMaxVars
// primitive.
//
// The quantization range is read from the constant nodes feeding inputs SECOND_INPUT (min) and
// THIRD_INPUT (max) and stored on the primitive as the float attributes "min" and "max". Only
// input 0 of `tf_op` is registered in `inputs`, and `*output_size` is set to 1.
//
// Returns the primitive, or nullptr when an argument is null, the node has too few inputs, a
// min/max input is not a constant, or the constant carries no float value.
PrimitiveCPtr TFFakeQuantParser::Parse(const tensorflow::NodeDef &tf_op,
                                       const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
                                       std::vector<std::string> *inputs, int *output_size) {
  MS_CHECK_TRUE_RET(inputs != nullptr, nullptr);
  MS_CHECK_TRUE_RET(output_size != nullptr, nullptr);
  auto prim = std::make_unique<FakeQuantWithMinMaxVars>();
  MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
  MS_CHECK_GE(tf_op.input_size(), kInputSize2, nullptr);

  // min param
  auto min_node = GetConstInputNode(tf_node_map, tf_op.input(SECOND_INPUT));
  if (min_node == nullptr) {
    MS_LOG(ERROR) << "Find FakeQuant input min node failed.";
    return nullptr;
  }
  tensorflow::AttrValue min_attr;
  if (!TensorFlowUtils::FindAttrValue(*min_node, "value", &min_attr)) {
    MS_LOG(ERROR) << "The attribute min should be specified.";
    return nullptr;
  }
  // float_val is a repeated field; indexing it while empty is an out-of-range access.
  if (min_attr.tensor().float_val_size() <= 0) {
    MS_LOG(ERROR) << "FakeQuant input min has no float value.";
    return nullptr;
  }
  auto min_value = min_attr.tensor().float_val(0);

  // max param
  auto max_node = GetConstInputNode(tf_node_map, tf_op.input(THIRD_INPUT));
  if (max_node == nullptr) {
    MS_LOG(ERROR) << "Find FakeQuant input max node failed.";
    return nullptr;
  }
  // A separate AttrValue is used so the max lookup never sees data left over from the min lookup.
  tensorflow::AttrValue max_attr;
  if (!TensorFlowUtils::FindAttrValue(*max_node, "value", &max_attr)) {
    MS_LOG(ERROR) << "The attribute max should be specified.";
    return nullptr;
  }
  if (max_attr.tensor().float_val_size() <= 0) {
    MS_LOG(ERROR) << "FakeQuant input max has no float value.";
    return nullptr;
  }
  auto max_value = max_attr.tensor().float_val(0);

  prim->AddAttr("min", MakeValue(min_value));
  prim->AddAttr("max", MakeValue(max_value));

  *output_size = 1;
  // Only the data tensor stays a graph input; min/max now live on the primitive as attributes.
  if (AddOpInput(tf_op, 0, inputs) != RET_OK) {
    MS_LOG(ERROR) << "Add op input failed.";
    return nullptr;
  }
  return prim;
}
|
||||
|
||||
TFNodeRegistrar g_tfFakeQuantParser("FakeQuantWithMinMaxVars", new TFFakeQuantParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,39 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_PARSER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_PARSER_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tools/converter/parser/tf/tf_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFFakeQuantParser : public TFNodeParser {
|
||||
public:
|
||||
TFFakeQuantParser() = default;
|
||||
~TFFakeQuantParser() override = default;
|
||||
|
||||
PrimitiveCPtr Parse(const tensorflow::NodeDef &tf_op,
|
||||
const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
std::vector<std::string> *inputs, int *output_size) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FAKE_QUANT_PARSER_H_
|
|
@ -35,6 +35,7 @@
|
|||
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/converter/parser/lite_model_parser_creator.h"
|
||||
#include "tools/converter/parser/tf/tf_fake_quant_adjust.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "tools/converter/parser/unify_format.h"
|
||||
|
@ -1229,6 +1230,12 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto fake_quant_adjust = std::make_shared<TFFakeQuantAdjust>();
|
||||
if (!fake_quant_adjust->Adjust(func_graph)) {
|
||||
MS_LOG(ERROR) << "tf fake quant adjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <map>
|
||||
#include <memory>
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "tools/common/statistic_utils.h"
|
||||
#include "src/tensor.h"
|
||||
|
|
|
@ -94,7 +94,7 @@ int FullQuantQuantizer::DoParameterWeightQuant(const CNodePtr &cnode, const Para
|
|||
auto weight_q_min = per_channel ? weight_channel_q_min_ : weight_layer_q_min_;
|
||||
auto weight_q_max = per_channel ? weight_channel_q_max_ : weight_layer_q_max_;
|
||||
auto symmetric = per_channel ? weight_channel_symmetric_ : weight_layer_symmetric_;
|
||||
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max,
|
||||
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, schema::QuantType_QUANT_ALL, weight_q_max,
|
||||
weight_q_min, bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1,
|
||||
preferred_dim, symmetric);
|
||||
if (status != RET_OK) {
|
||||
|
@ -118,7 +118,7 @@ int FullQuantQuantizer::DoValueNodeWeightQuant(const CNodePtr &cnode, const Valu
|
|||
auto weight_q_min = per_channel ? weight_channel_q_min_ : weight_layer_q_min_;
|
||||
auto weight_q_max = per_channel ? weight_channel_q_max_ : weight_layer_q_max_;
|
||||
auto symmetric = per_channel ? weight_channel_symmetric_ : weight_layer_symmetric_;
|
||||
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max,
|
||||
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, schema::QuantType_QUANT_ALL, weight_q_max,
|
||||
weight_q_min, bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1,
|
||||
preferred_dim, symmetric);
|
||||
if (status != RET_OK) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "schema/inner/model_generated.h"
|
||||
#include "base/base.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/quantizer/quant_param_holder.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
|
|
|
@ -121,6 +121,7 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
int QuantTypeDeterminer::Determine() {
|
||||
CHECK_NULL_RETURN(func_graph_);
|
||||
auto nodes = func_graph_->GetOrderedCnodes();
|
||||
for (auto const &cnode : nodes) {
|
||||
auto quant_holder = GetCNodeQuantHolder(cnode);
|
||||
|
@ -130,10 +131,10 @@ int QuantTypeDeterminer::Determine() {
|
|||
}
|
||||
if (DetermineQuantWeight(cnode)) {
|
||||
MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_WEIGHT";
|
||||
quant_holder->set_quant_type(QuantType_QUANT_WEIGHT);
|
||||
quant_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
|
||||
} else if (DetermineQuantAll(cnode)) {
|
||||
MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_ALL";
|
||||
quant_holder->set_quant_type(QuantType_QUANT_ALL);
|
||||
quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
|
||||
} else {
|
||||
MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info";
|
||||
quant_holder->ClearQuantParams();
|
||||
|
|
|
@ -13,11 +13,12 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H_
|
||||
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
|
@ -37,4 +38,4 @@ class QuantTypeDeterminer {
|
|||
FuncGraphPtr func_graph_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_HELPER_QUANT_TYPE_DETERMINER_H_
|
||||
|
|
|
@ -41,6 +41,8 @@ constexpr float kRatio = 10.0;
|
|||
constexpr int kCpuBindMode = 1;
|
||||
constexpr int kPrimIndex = 0;
|
||||
constexpr int kPrimOffset = 1;
|
||||
constexpr int kU8ZeroPointOffset = 128;
|
||||
constexpr int kQuantRange = 127;
|
||||
|
||||
enum ActivationQuantizedMethod {
|
||||
MAX_MIN = 0,
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
@ -537,7 +538,7 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
|
|||
}
|
||||
|
||||
int MixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
|
||||
const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
|
||||
const PrimitivePtr &primitive, schema::QuantType quant_type, WeightQuantType weight_quant_type,
|
||||
TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetric) {
|
||||
MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
|
||||
MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
|
||||
|
|
|
@ -48,7 +48,6 @@
|
|||
#include "tools/converter/quantizer/quant_params.h"
|
||||
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
|
||||
#include "src/runtime/lite_session.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/quant_utils.h"
|
||||
#include "include/api/model.h"
|
||||
|
@ -71,7 +70,7 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
|
|||
bool TensorQuantParamsInited(const schema::TensorT &tensor);
|
||||
|
||||
int MixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
|
||||
const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
|
||||
const PrimitivePtr &primitive, schema::QuantType quant_type, WeightQuantType weight_quant_type,
|
||||
TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetric);
|
||||
|
||||
int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first);
|
||||
|
@ -104,7 +103,7 @@ int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<min
|
|||
|
||||
template <typename T>
|
||||
int FixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
|
||||
const PrimitivePtr &primitive, QuantType quant_type, int quant_max, int quant_min,
|
||||
const PrimitivePtr &primitive, schema::QuantType quant_type, int quant_max, int quant_min,
|
||||
size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type, int index,
|
||||
int preferred_dim, bool symmetric = false, bool narrow_range = false) {
|
||||
MS_ASSERT(weight != nullptr);
|
||||
|
|
|
@ -213,7 +213,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr<schema::MetaG
|
|||
schema_tensor->dataType = data_info.data_type_;
|
||||
schema_tensor->data = data_info.data_;
|
||||
schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;
|
||||
schema_tensor->weightQuantCompressType = static_cast<WeightQuantCompressType>(data_info.compress_type_);
|
||||
schema_tensor->weightQuantCompressType = static_cast<schema::WeightQuantCompressType>(data_info.compress_type_);
|
||||
schema_tensor->nodeType = NodeType_CNode;
|
||||
auto key = std::make_pair(input, 0);
|
||||
node_id_map_[key] = static_cast<int>(meta_graphT->allTensors.size());
|
||||
|
|
Loading…
Reference in New Issue