!39896 fix quant type determiner and TF/TFLite quant parsers

Merge pull request !39896 from liyan2022/qat_fix_determin_parser
This commit is contained in:
i-robot 2022-08-06 04:18:55 +00:00 committed by Gitee
commit 31576ec31c
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 53 additions and 108 deletions

View File

@ -1035,11 +1035,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
return status;
}
status = ConvertQuantParams(inputs.size() - 1, output_size, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
return status;
}
return status;
}
@ -1077,18 +1072,6 @@ STATUS TFModelParser::ProcessControlFlowOp(const CNodePtr &anf_node, const strin
return RET_OK;
}
// Attaches a QuantParamHolder sized for `input_size` inputs and `output_size`
// outputs to `primitive_c` under the "quant_params" attribute, so later
// quantization passes can fill the per-tensor parameters in.
// Returns RET_NULL_PTR when primitive_c is null, RET_OK otherwise.
// NOTE(review): the holder is created with counts only — presumably its params
// start uninited and are populated downstream; confirm against QuantParamHolder.
STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size,
                                         PrimitiveCPtr primitive_c) {
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
    return RET_NULL_PTR;
  }
  auto quant_params_holder = std::make_shared<QuantParamHolder>(input_size, output_size);
  CHECK_NULL_RETURN(quant_params_holder);
  // Stored as a primitive attribute; consumers look it up by the "quant_params" key.
  primitive_c->AddAttr("quant_params", quant_params_holder);
  return RET_OK;
}
std::set<std::string> TFModelParser::GetAllNodeInputs() {
std::set<std::string> all_node_inputs;
for (auto &node : tf_root_graph_nodes_vec_) {

View File

@ -92,8 +92,6 @@ class TFModelParser : public converter::ModelParser {
STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
const std::map<CNodePtr, FuncGraphPtr> &second_func_map);
static STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, PrimitiveCPtr primitive_c);
static STATUS MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_nodes, const FuncGraphPtr &anf_graph);
STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found);

View File

@ -435,23 +435,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptr<tflite::Ope
if (primitive_c->name() == "Conv2D" || primitive_c->name() == "Conv2DFusion") {
round_type = 2;
}
int32_t inputs_size = 0;
int32_t outputs_size = 0;
if (primitive_c->name() == "FullyConnection") {
std::vector<int32_t> inputs(op->inputs.size());
std::vector<int32_t> outputs(op->outputs.size());
auto it =
std::copy_if(op->inputs.begin(), op->inputs.end(), inputs.begin(), [](const int32_t item) { return item >= 0; });
inputs.resize(std::distance(inputs.begin(), it));
it = std::copy_if(op->outputs.begin(), op->outputs.end(), outputs.begin(),
[](const int32_t item) { return item >= 0; });
outputs.resize(std::distance(outputs.begin(), it));
} else {
inputs_size = op->inputs.size();
outputs_size = op->outputs.size();
}
auto quant_params_holder = std::make_shared<QuantParamHolder>(inputs_size, outputs_size);
MSLITE_CHECK_PTR(quant_params_holder);
std::map<int, std::vector<schema::QuantParamT>> in_quant_param;
size_t idx = 0;
for (auto input_idx : op->inputs) {
if (input_idx < 0) {
@ -468,9 +453,10 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptr<tflite::Ope
MS_LOG(ERROR) << "set input tensor quant param failed.";
return status;
}
quant_params_holder->set_input_quant_param(idx, quant_params);
in_quant_param.insert({idx, quant_params});
idx++;
}
std::map<size_t, std::vector<schema::QuantParamT>> out_quant_param;
idx = 0;
for (auto output_idx : op->outputs) {
if (output_idx < 0) {
@ -487,10 +473,20 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptr<tflite::Ope
MS_LOG(ERROR) << "set output tensor quant param failed.";
return status;
}
quant_params_holder->set_output_quant_param(idx, quant_params);
out_quant_param.insert({idx, quant_params});
idx++;
}
primitive_c->AddAttr("quant_params", quant_params_holder);
if (!in_quant_param.empty() || !out_quant_param.empty()) {
auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
MSLITE_CHECK_PTR(quant_params_holder);
for (auto &iter : in_quant_param) {
quant_params_holder->set_input_quant_param(iter.first, iter.second);
}
for (auto &iter : out_quant_param) {
quant_params_holder->set_output_quant_param(iter.first, iter.second);
}
primitive_c->AddAttr("quant_params", quant_params_holder);
}
return RET_OK;
}

View File

@ -188,6 +188,10 @@ int QuantNodePass::DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPt
MS_LOG(ERROR) << input_node->fullname_with_scope() << " can not get value";
return RET_NULL_PTR;
}
if (tensor_info->data_type() != kNumberTypeFloat32) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " is not float32, data will not quant.";
return RET_OK;
}
int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
MS_CHECK_GT(static_cast<int>(quant_param_holder->get_input_quant_params().size()),
static_cast<int>(input_index) - 1, RET_ERROR);
@ -216,6 +220,10 @@ int QuantNodePass::DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &i
MS_LOG(ERROR) << input_node->fullname_with_scope() << " can not get value";
return RET_NULL_PTR;
}
if (tensor_info->data_type() != kNumberTypeFloat32) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " is not float32, data will not quant.";
return RET_OK;
}
int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
MS_CHECK_GT(static_cast<int>(quant_param_holder->get_input_quant_params().size()), static_cast<int>(input_index) - 1,
RET_ERROR);

View File

@ -21,33 +21,15 @@
#include "src/litert/kernel_exec.h"
#include "src/litert/kernel_registry.h"
#include "src/common/ops/anf_utils.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"
namespace mindspore::lite::quant {
std::pair<size_t, size_t> QuantTypeDeterminer::GetQuantParamsNum(const QuantParamHolderPtr &quant_holder) {
// update input quant params num
auto input_inited_quant_params = 0;
auto input_tensors = quant_holder->get_input_quant_params();
for (auto input : input_tensors) {
bool is_quant_params_inited = std::all_of(
input.begin(), input.end(), [](const schema::QuantParamT &quant_param) { return quant_param.inited; });
if (is_quant_params_inited) {
input_inited_quant_params++;
}
}
auto output_inited_quant_params = 0;
auto output_tensors = quant_holder->get_output_quant_params();
for (auto output : output_tensors) {
bool is_quant_params_inited = !std::any_of(
output.begin(), output.end(), [](const schema::QuantParamT &quant_param) { return !quant_param.inited; });
if (is_quant_params_inited) {
output_inited_quant_params++;
}
}
return {input_inited_quant_params, output_inited_quant_params};
}
bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
MS_ASSERT(node != nullptr);
MS_ASSERT(cnode != nullptr);
if (opt::IsSpecialType(cnode)) {
return false;
}
auto primT = GetPrimitiveT(cnode->input(kPrimIndex));
if (primT == nullptr) {
MS_LOG(WARNING) << cnode->fullname_with_scope() << " primitive is nullptr.";
@ -66,51 +48,23 @@ bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
return false;
}
// GetCNodeQuantType
// Get CNode QuantType directly.
if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) {
return quant_holder->quant_type() == schema::QuantType_QUANT_ALL;
}
// All output need init.
if (!quant_holder->IsOutputQuantParamsInited()) {
return false;
}
if (CheckNodeInSet(cnode, bias_ops_)) {
auto input_quant_params = quant_holder->get_input_quant_params();
MS_CHECK_TRUE_RET(!input_quant_params.empty(), false);
bool input_params_inited =
(!input_quant_params.at(kInputIndex).empty() && input_quant_params.at(kInputIndex).front().inited) &&
(!input_quant_params.at(kWeightIndex).empty() && input_quant_params.at(kWeightIndex).front().inited);
if (!input_params_inited || !quant_holder->IsOutputQuantParamsInited()) {
// Quantization parameters exist for all activations.
for (size_t i = 1; i < cnode->size(); ++i) {
auto input = cnode->input(i);
if (input->isa<CNode>() && !quant_holder->CheckInit(i - kPrimOffset, true)) {
return false;
}
}
auto in_out_quant_params = GetQuantParamsNum(quant_holder);
// Check quant param size is same as tensor size.
auto input_size = (cnode->size() - kPrimOffset);
if (CheckNodeInSet(cnode, bias_ops_)) {
input_size -= kPrimOffset;
}
// exclude input(not fp32)
for (size_t index = 1; index < cnode->size(); ++index) {
CHECK_NULL_RETURN(cnode->input(index));
auto abstract_base = cnode->input(index)->abstract();
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " index: " << index << " should be AbstractTensorPtr.";
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
CHECK_NULL_RETURN(abstract_tensor);
CHECK_NULL_RETURN(abstract_tensor->element());
if (abstract_tensor->element()->GetTypeTrack()->type_id() != kNumberTypeFloat32) {
input_size -= kPrimOffset;
}
}
auto output_size = opt::GetOutputSize(cnode);
if (in_out_quant_params.first == input_size && in_out_quant_params.second == output_size) {
quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
return true;
}
return false;
return true;
}
bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
@ -120,7 +74,7 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
return false;
}
// GetCNodeQuantType
// Get CNode QuantType directly.
if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) {
return quant_holder->quant_type() == schema::QuantType_QUANT_WEIGHT;
}
@ -130,8 +84,12 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
return false;
}
bool quant_flag = false;
for (size_t i = 1; i < cnode->size(); i++) {
auto input = cnode->input(i);
if (IsGraphInput(input)) {
continue;
}
// non-constants(CNode) don't include quantization parameters
if (input->isa<mindspore::CNode>()) {
if (quant_holder->CheckInit(i - kPrimOffset, true)) {
@ -140,11 +98,12 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
} else {
// Constants have quantization parameters
if (quant_holder->CheckInit(i - kPrimOffset, true)) {
return true;
quant_flag = true;
continue;
}
}
}
return false;
return quant_flag;
}
int QuantTypeDeterminer::Determine() {
@ -156,15 +115,18 @@ int QuantTypeDeterminer::Determine() {
MS_LOG(INFO) << cnode->fullname_with_scope() << " quant holder is nullptr.";
continue;
}
if (DetermineQuantWeight(cnode)) {
if (!quant_holder->IsInputQuantParamsInited() && !quant_holder->IsOutputQuantParamsInited()) { // Check FP32.
if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
continue;
}
MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info";
quant_holder->ClearQuantParams();
} else if (DetermineQuantWeight(cnode)) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_WEIGHT";
quant_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
} else if (DetermineQuantAll(cnode)) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_ALL";
quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
} else {
MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info";
quant_holder->ClearQuantParams();
}
}
return RET_OK;

View File

@ -35,8 +35,6 @@ class QuantTypeDeterminer {
private:
bool DetermineQuantAll(const CNodePtr &cnode);
std::pair<size_t, size_t> GetQuantParamsNum(const QuantParamHolderPtr &quant_holder);
bool DetermineQuantWeight(const CNodePtr &cnode);
private: