!39896 fix determiner and TF/TFLite quant parser
Merge pull request !39896 from liyan2022/qat_fix_determin_parser
Commit: 31576ec31c
@@ -1035,11 +1035,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
    return status;
  }

  status = ConvertQuantParams(inputs.size() - 1, output_size, primitive_c);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Convert quant params for " << anf_node->fullname_with_scope() << " failed.";
    return status;
  }
  return status;
}
@@ -1077,18 +1072,6 @@ STATUS TFModelParser::ProcessControlFlowOp(const CNodePtr &anf_node, const strin
  return RET_OK;
}

STATUS TFModelParser::ConvertQuantParams(const size_t &input_size, const size_t &output_size,
                                         PrimitiveCPtr primitive_c) {
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is null, get quant params failed.";
    return RET_NULL_PTR;
  }
  auto quant_params_holder = std::make_shared<QuantParamHolder>(input_size, output_size);
  CHECK_NULL_RETURN(quant_params_holder);
  primitive_c->AddAttr("quant_params", quant_params_holder);
  return RET_OK;
}

std::set<std::string> TFModelParser::GetAllNodeInputs() {
  std::set<std::string> all_node_inputs;
  for (auto &node : tf_root_graph_nodes_vec_) {
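Side note on the two hunks above: the removed TFModelParser::ConvertQuantParams attached a freshly built QuantParamHolder, sized only by input/output counts and carrying no initialized parameters, to every parsed primitive. A minimal standalone sketch of that pattern, using mock stand-ins rather than the real MindSpore types:

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Mock stand-ins for QuantParamHolder and PrimitiveC; the real MindSpore
// types carry far more state.
struct MockQuantParamHolder {
  MockQuantParamHolder(size_t input_size, size_t output_size)
      : input_params(input_size), output_params(output_size) {}
  // One (possibly empty) parameter list per input/output tensor.
  std::vector<std::vector<double>> input_params;
  std::vector<std::vector<double>> output_params;
};

struct MockPrimitive {
  std::map<std::string, std::shared_ptr<MockQuantParamHolder>> attrs;
  void AddAttr(const std::string &key, std::shared_ptr<MockQuantParamHolder> value) {
    attrs[key] = std::move(value);
  }
};

// Pattern of the removed helper: always attach an empty holder, regardless of
// whether any quant params actually exist for this op.
int ConvertQuantParams(size_t input_size, size_t output_size, MockPrimitive *primitive) {
  if (primitive == nullptr) {
    return -1;  // RET_NULL_PTR in the original
  }
  auto holder = std::make_shared<MockQuantParamHolder>(input_size, output_size);
  primitive->AddAttr("quant_params", holder);
  return 0;  // RET_OK
}

int main() {
  MockPrimitive prim;
  ConvertQuantParams(2, 1, &prim);
  std::cout << "attrs attached: " << prim.attrs.count("quant_params") << std::endl;
  return 0;
}

With the helper gone, TF-parsed nodes no longer carry an empty "quant_params" attribute by default, which appears to be what lets the quant type determiner later in this PR tell genuinely quantized nodes from float ones.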
@@ -92,8 +92,6 @@ class TFModelParser : public converter::ModelParser {
  STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
                                    const std::map<CNodePtr, FuncGraphPtr> &second_func_map);

  static STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, PrimitiveCPtr primitive_c);

  static STATUS MakeAnfGraphOutputs(const std::vector<AnfNodePtr> &output_nodes, const FuncGraphPtr &anf_graph);

  STATUS RecordNullInput(const CNodePtr &node, const std::vector<std::string> &input_name_not_found);
@@ -435,23 +435,8 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptr<tflite::Ope
  if (primitive_c->name() == "Conv2D" || primitive_c->name() == "Conv2DFusion") {
    round_type = 2;
  }
  int32_t inputs_size = 0;
  int32_t outputs_size = 0;
  if (primitive_c->name() == "FullyConnection") {
    std::vector<int32_t> inputs(op->inputs.size());
    std::vector<int32_t> outputs(op->outputs.size());
    auto it =
      std::copy_if(op->inputs.begin(), op->inputs.end(), inputs.begin(), [](const int32_t item) { return item >= 0; });
    inputs.resize(std::distance(inputs.begin(), it));
    it = std::copy_if(op->outputs.begin(), op->outputs.end(), outputs.begin(),
                      [](const int32_t item) { return item >= 0; });
    outputs.resize(std::distance(outputs.begin(), it));
  } else {
    inputs_size = op->inputs.size();
    outputs_size = op->outputs.size();
  }
  auto quant_params_holder = std::make_shared<QuantParamHolder>(inputs_size, outputs_size);
  MSLITE_CHECK_PTR(quant_params_holder);

  std::map<int, std::vector<schema::QuantParamT>> in_quant_param;
  size_t idx = 0;
  for (auto input_idx : op->inputs) {
    if (input_idx < 0) {
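Aside: the FullyConnection branch shown above keeps only non-negative tensor indices before sizing the holder; in TFLite, -1 marks an optional tensor that is absent (for example a FullyConnected op with no bias). A self-contained illustration of that std::copy_if idiom:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  // -1 marks an optional tensor that is absent (e.g. a FullyConnected op
  // without bias in a TFLite model).
  std::vector<int32_t> op_inputs = {3, 7, -1};

  std::vector<int32_t> inputs(op_inputs.size());
  auto it = std::copy_if(op_inputs.begin(), op_inputs.end(), inputs.begin(),
                         [](const int32_t item) { return item >= 0; });
  inputs.resize(std::distance(inputs.begin(), it));

  std::cout << "valid inputs: " << inputs.size() << std::endl;  // prints 2
  return 0;
}

The resize to std::distance(...) trims the destination vector down to the number of elements actually copied.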
@@ -468,9 +453,10 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptr<tflite::Ope
      MS_LOG(ERROR) << "set input tensor quant param failed.";
      return status;
    }
    quant_params_holder->set_input_quant_param(idx, quant_params);
    in_quant_param.insert({idx, quant_params});
    idx++;
  }
  std::map<size_t, std::vector<schema::QuantParamT>> out_quant_param;
  idx = 0;
  for (auto output_idx : op->outputs) {
    if (output_idx < 0) {
@@ -487,10 +473,20 @@ STATUS TfliteModelParser::ConvertOpQuantParams(const std::unique_ptr<tflite::Ope
      MS_LOG(ERROR) << "set output tensor quant param failed.";
      return status;
    }
    quant_params_holder->set_output_quant_param(idx, quant_params);
    out_quant_param.insert({idx, quant_params});
    idx++;
  }
  primitive_c->AddAttr("quant_params", quant_params_holder);
  if (!in_quant_param.empty() || !out_quant_param.empty()) {
    auto quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
    MSLITE_CHECK_PTR(quant_params_holder);
    for (auto &iter : in_quant_param) {
      quant_params_holder->set_input_quant_param(iter.first, iter.second);
    }
    for (auto &iter : out_quant_param) {
      quant_params_holder->set_output_quant_param(iter.first, iter.second);
    }
    primitive_c->AddAttr("quant_params", quant_params_holder);
  }
  return RET_OK;
}
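The reworked flow in the three TFLite hunks above first collects per-tensor parameters into in_quant_param / out_quant_param and only constructs and attaches a holder when at least one of the maps is non-empty. A simplified standalone sketch of that control flow; the types and the ParseTensorQuantParams helper are mocks, not the converter's API:

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct QuantParam { double scale = 1.0; int zero_point = 0; };

// Minimal mock of a quant-param holder keyed by tensor index.
struct MockHolder {
  std::map<size_t, std::vector<QuantParam>> inputs;
  std::map<size_t, std::vector<QuantParam>> outputs;
};

struct MockPrimitive {
  std::map<std::string, std::shared_ptr<MockHolder>> attrs;
};

// Returns the quant params parsed for a tensor; empty means "not quantized".
std::vector<QuantParam> ParseTensorQuantParams(bool quantized) {
  return quantized ? std::vector<QuantParam>{{0.05, 128}} : std::vector<QuantParam>{};
}

void ConvertOpQuantParams(const std::vector<bool> &input_quantized,
                          const std::vector<bool> &output_quantized, MockPrimitive *prim) {
  std::map<size_t, std::vector<QuantParam>> in_quant_param;
  std::map<size_t, std::vector<QuantParam>> out_quant_param;
  for (size_t i = 0; i < input_quantized.size(); ++i) {
    auto params = ParseTensorQuantParams(input_quantized[i]);
    if (!params.empty()) {
      in_quant_param.insert({i, params});
    }
  }
  for (size_t i = 0; i < output_quantized.size(); ++i) {
    auto params = ParseTensorQuantParams(output_quantized[i]);
    if (!params.empty()) {
      out_quant_param.insert({i, params});
    }
  }
  // Only attach a holder when some quant params were actually found.
  if (!in_quant_param.empty() || !out_quant_param.empty()) {
    auto holder = std::make_shared<MockHolder>();
    holder->inputs = in_quant_param;
    holder->outputs = out_quant_param;
    prim->attrs["quant_params"] = holder;
  }
}

int main() {
  MockPrimitive float_op, quant_op;
  ConvertOpQuantParams({false, false}, {false}, &float_op);
  ConvertOpQuantParams({true, true}, {true}, &quant_op);
  std::cout << "float op has holder: " << float_op.attrs.count("quant_params") << std::endl;  // 0
  std::cout << "quant op has holder: " << quant_op.attrs.count("quant_params") << std::endl;  // 1
  return 0;
}

Compared with the old code, a purely floating-point op now leaves ConvertOpQuantParams without any "quant_params" attribute instead of carrying an empty holder.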
@@ -188,6 +188,10 @@ int QuantNodePass::DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPt
    MS_LOG(ERROR) << input_node->fullname_with_scope() << " can not get value";
    return RET_NULL_PTR;
  }
  if (tensor_info->data_type() != kNumberTypeFloat32) {
    MS_LOG(INFO) << cnode->fullname_with_scope() << " is not float32, data will not quant.";
    return RET_OK;
  }
  int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
  MS_CHECK_GT(static_cast<int>(quant_param_holder->get_input_quant_params().size()),
              static_cast<int>(input_index) - 1, RET_ERROR);
@@ -216,6 +220,10 @@ int QuantNodePass::DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &i
    MS_LOG(ERROR) << input_node->fullname_with_scope() << " can not get value";
    return RET_NULL_PTR;
  }
  if (tensor_info->data_type() != kNumberTypeFloat32) {
    MS_LOG(INFO) << cnode->fullname_with_scope() << " is not float32, data will not quant.";
    return RET_OK;
  }
  int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
  MS_CHECK_GT(static_cast<int>(quant_param_holder->get_input_quant_params().size()), static_cast<int>(input_index) - 1,
              RET_ERROR);
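Both QuantNodePass hunks above add the same guard: if the weight tensor is not float32, the pass logs and returns RET_OK without quantizing, so integer or already-quantized data is left untouched. A minimal sketch of that early-return pattern with a mock tensor type:

#include <iostream>
#include <vector>

enum class DataType { kFloat32, kInt8, kInt32 };

struct MockTensor {
  DataType data_type;
  std::vector<float> data;
};

// Returns 0 (RET_OK in the original) in all cases here; the point is the
// early return: only float32 tensors are actually quantized.
int DoWeightQuant(const MockTensor &tensor) {
  if (tensor.data_type != DataType::kFloat32) {
    std::cout << "skip: not float32, data will not quant." << std::endl;
    return 0;
  }
  std::cout << "quantizing " << tensor.data.size() << " float32 values" << std::endl;
  return 0;
}

int main() {
  DoWeightQuant({DataType::kFloat32, {0.1f, 0.2f}});  // quantized
  DoWeightQuant({DataType::kInt8, {}});                // skipped
  return 0;
}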
@@ -21,33 +21,15 @@
#include "src/litert/kernel_exec.h"
#include "src/litert/kernel_registry.h"
#include "src/common/ops/anf_utils.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"

namespace mindspore::lite::quant {
std::pair<size_t, size_t> QuantTypeDeterminer::GetQuantParamsNum(const QuantParamHolderPtr &quant_holder) {
  // update input quant params num
  auto input_inited_quant_params = 0;
  auto input_tensors = quant_holder->get_input_quant_params();
  for (auto input : input_tensors) {
    bool is_quant_params_inited = std::all_of(
      input.begin(), input.end(), [](const schema::QuantParamT &quant_param) { return quant_param.inited; });
    if (is_quant_params_inited) {
      input_inited_quant_params++;
    }
  }
  auto output_inited_quant_params = 0;
  auto output_tensors = quant_holder->get_output_quant_params();
  for (auto output : output_tensors) {
    bool is_quant_params_inited = !std::any_of(
      output.begin(), output.end(), [](const schema::QuantParamT &quant_param) { return !quant_param.inited; });
    if (is_quant_params_inited) {
      output_inited_quant_params++;
    }
  }
  return {input_inited_quant_params, output_inited_quant_params};
}

bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
  MS_ASSERT(node != nullptr);
  MS_ASSERT(cnode != nullptr);
  if (opt::IsSpecialType(cnode)) {
    return false;
  }
  auto primT = GetPrimitiveT(cnode->input(kPrimIndex));
  if (primT == nullptr) {
    MS_LOG(WARNING) << cnode->fullname_with_scope() << " primitive is nullptr.";
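The GetQuantParamsNum removed above counted, per tensor, whether every quant parameter entry was initialized, using std::all_of on the input side (and an equivalent negated std::any_of on the output side). A standalone illustration of that counting idiom:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

struct QuantParamT {
  bool inited = false;
};

// Counts how many tensors have *all* of their quant params initialized,
// mirroring the std::all_of check in the removed GetQuantParamsNum.
size_t CountInitedTensors(const std::vector<std::vector<QuantParamT>> &tensors) {
  size_t inited = 0;
  for (const auto &params : tensors) {
    bool all_inited = std::all_of(params.begin(), params.end(),
                                  [](const QuantParamT &p) { return p.inited; });
    if (all_inited) {
      ++inited;
    }
  }
  return inited;
}

int main() {
  std::vector<std::vector<QuantParamT>> inputs = {
    {{true}, {true}},   // fully inited -> counted
    {{true}, {false}},  // partially inited -> not counted
    {},                 // empty: std::all_of on an empty range is true -> counted
  };
  std::cout << "inited tensors: " << CountInitedTensors(inputs) << std::endl;  // prints 2
  return 0;
}

Note that std::all_of over an empty range returns true, so a tensor with no quant params at all counted as initialized here.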
@@ -66,51 +48,23 @@ bool QuantTypeDeterminer::DetermineQuantAll(const CNodePtr &cnode) {
    return false;
  }

  // GetCNodeQuantType
  // Get CNode QuantType directly.
  if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) {
    return quant_holder->quant_type() == schema::QuantType_QUANT_ALL;
  }
  // All output need init.
  if (!quant_holder->IsOutputQuantParamsInited()) {
    return false;
  }
  if (CheckNodeInSet(cnode, bias_ops_)) {
    auto input_quant_params = quant_holder->get_input_quant_params();
    MS_CHECK_TRUE_RET(!input_quant_params.empty(), false);
    bool input_params_inited =
      (!input_quant_params.at(kInputIndex).empty() && input_quant_params.at(kInputIndex).front().inited) &&
      (!input_quant_params.at(kWeightIndex).empty() && input_quant_params.at(kWeightIndex).front().inited);
    if (!input_params_inited || !quant_holder->IsOutputQuantParamsInited()) {

  // Quantization parameters exist for all activations.
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto input = cnode->input(i);
    if (input->isa<CNode>() && !quant_holder->CheckInit(i - kPrimOffset, true)) {
      return false;
    }
  }

  auto in_out_quant_params = GetQuantParamsNum(quant_holder);
  // Check quant param size is same as tensor size.
  auto input_size = (cnode->size() - kPrimOffset);
  if (CheckNodeInSet(cnode, bias_ops_)) {
    input_size -= kPrimOffset;
  }
  // exclude input(not fp32)
  for (size_t index = 1; index < cnode->size(); ++index) {
    CHECK_NULL_RETURN(cnode->input(index));
    auto abstract_base = cnode->input(index)->abstract();
    if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " index: " << index << " should be AbstractTensorPtr.";
      return RET_ERROR;
    }
    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
    CHECK_NULL_RETURN(abstract_tensor);
    CHECK_NULL_RETURN(abstract_tensor->element());
    if (abstract_tensor->element()->GetTypeTrack()->type_id() != kNumberTypeFloat32) {
      input_size -= kPrimOffset;
    }
  }
  auto output_size = opt::GetOutputSize(cnode);
  if (in_out_quant_params.first == input_size && in_out_quant_params.second == output_size) {
    quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
    return true;
  }
  return false;
  return true;
}

bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
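The rewrite of DetermineQuantAll appears to replace the count-and-compare against GetQuantParamsNum with a direct scan: every activation (CNode) input must have initialized quant params, and the function then returns true. A rough standalone model of that decision; the types are mocks and the surrounding special-op and output checks are omitted:

#include <iostream>
#include <vector>

// Simplified view of one node input: is it an activation (produced by another
// node, i.e. a CNode) and does it carry initialized quant params?
struct InputInfo {
  bool is_activation = false;
  bool quant_inited = false;
};

// Hypothetical condensation of the loop shown in the diff: constants may be
// handled elsewhere; every activation input must be inited.
bool DetermineQuantAll(const std::vector<InputInfo> &inputs) {
  for (const auto &input : inputs) {
    if (input.is_activation && !input.quant_inited) {
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<InputInfo> full_quant = {{true, true}, {false, false}};  // weight without params is fine
  std::vector<InputInfo> partial = {{true, true}, {true, false}};      // one activation missing params
  std::cout << DetermineQuantAll(full_quant) << " " << DetermineQuantAll(partial) << std::endl;  // 1 0
  return 0;
}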
@@ -120,7 +74,7 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
    return false;
  }

  // GetCNodeQuantType
  // Get CNode QuantType directly.
  if (quant_holder->quant_type() != schema::QuantType_QUANT_NONE) {
    return quant_holder->quant_type() == schema::QuantType_QUANT_WEIGHT;
  }
@@ -130,8 +84,12 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
    return false;
  }

  bool quant_flag = false;
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input = cnode->input(i);
    if (IsGraphInput(input)) {
      continue;
    }
    // non-constants(CNode) don't include quantization parameters
    if (input->isa<mindspore::CNode>()) {
      if (quant_holder->CheckInit(i - kPrimOffset, true)) {
@@ -140,11 +98,12 @@ bool QuantTypeDeterminer::DetermineQuantWeight(const CNodePtr &cnode) {
    } else {
      // Constants have quantization parameters
      if (quant_holder->CheckInit(i - kPrimOffset, true)) {
        return true;
        quant_flag = true;
        continue;
      }
    }
  }
  return false;
  return quant_flag;
}

int QuantTypeDeterminer::Determine() {
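In DetermineQuantWeight, a constant input that carries quant params now sets quant_flag and keeps scanning instead of returning true on the first hit, and the accumulated flag is returned at the end; returning false therefore also covers the case where no input has any quant params at all. A rough standalone model of that pattern (the activation-input branch cut off between the two hunks above is an assumption here):

#include <iostream>
#include <vector>

struct InputInfo {
  bool is_constant = false;   // parameter / value node (weight, bias)
  bool quant_inited = false;  // has initialized quant params
};

// Rough model of the quant_flag pattern: remember that at least one constant
// input is quantized, but keep checking the rest before deciding.
bool DetermineQuantWeight(const std::vector<InputInfo> &inputs) {
  bool quant_flag = false;
  for (const auto &input : inputs) {
    if (!input.is_constant) {
      // Assumption: an activation input that already has quant params points
      // to full quantization rather than weight-only, so bail out; the exact
      // handling sits between the two hunks above and is not shown there.
      if (input.quant_inited) {
        return false;
      }
      continue;
    }
    if (input.quant_inited) {
      quant_flag = true;
      continue;
    }
  }
  return quant_flag;
}

int main() {
  // Weight quantized, activation not: weight-only quantization.
  std::cout << DetermineQuantWeight({{false, false}, {true, true}}) << std::endl;   // 1
  // Nothing quantized at all.
  std::cout << DetermineQuantWeight({{false, false}, {true, false}}) << std::endl;  // 0
  return 0;
}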
@@ -156,15 +115,18 @@ int QuantTypeDeterminer::Determine() {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " quant holder is nullptr.";
      continue;
    }
    if (DetermineQuantWeight(cnode)) {
    if (!quant_holder->IsInputQuantParamsInited() && !quant_holder->IsOutputQuantParamsInited()) { // Check FP32.
      if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
        continue;
      }
      MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info";
      quant_holder->ClearQuantParams();
    } else if (DetermineQuantWeight(cnode)) {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_WEIGHT";
      quant_holder->set_quant_type(schema::QuantType_QUANT_WEIGHT);
    } else if (DetermineQuantAll(cnode)) {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " set QuantType_QUANT_ALL";
      quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
    } else {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " Remove unused quant info";
      quant_holder->ClearQuantParams();
    }
  }
  return RET_OK;
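Determine() now classifies each node in a fixed order: a node with no initialized input or output quant params is treated as FP32 and its stale quant info is cleared (QuantDTypeCast nodes are skipped), otherwise weight-only quantization is tested, then full quantization, and anything left also has its quant info cleared. A compact sketch of that dispatch, with the two determiner predicates stubbed out as booleans:

#include <iostream>
#include <string>

enum class QuantType { kNone, kWeight, kAll };

struct NodeInfo {
  std::string name;
  bool is_quant_dtype_cast = false;
  bool any_params_inited = false;  // any input or output quant param inited
  bool weight_only = false;        // stub for DetermineQuantWeight
  bool full_quant = false;         // stub for DetermineQuantAll
};

// Mirrors the branch order added in Determine(): FP32 check first, then
// weight-only, then full quantization, otherwise clear the quant info.
QuantType Classify(const NodeInfo &node) {
  if (!node.any_params_inited) {  // Check FP32.
    if (node.is_quant_dtype_cast) {
      return QuantType::kNone;    // leave QuantDTypeCast alone
    }
    std::cout << node.name << " Remove unused quant info" << std::endl;
    return QuantType::kNone;
  }
  if (node.weight_only) {
    return QuantType::kWeight;
  }
  if (node.full_quant) {
    return QuantType::kAll;
  }
  std::cout << node.name << " Remove unused quant info" << std::endl;
  return QuantType::kNone;
}

int main() {
  std::cout << static_cast<int>(Classify({"conv_fp32", false, false, false, false})) << std::endl;  // 0
  std::cout << static_cast<int>(Classify({"conv_w8", false, true, true, false})) << std::endl;      // 1
  std::cout << static_cast<int>(Classify({"conv_full", false, true, false, true})) << std::endl;    // 2
  return 0;
}

The order matters: the FP32 check runs before DetermineQuantWeight, so nodes with no quant information at all never reach the weight/full-quant tests.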
@@ -35,8 +35,6 @@ class QuantTypeDeterminer {
 private:
  bool DetermineQuantAll(const CNodePtr &cnode);

  std::pair<size_t, size_t> GetQuantParamsNum(const QuantParamHolderPtr &quant_holder);

  bool DetermineQuantWeight(const CNodePtr &cnode);

 private: