!23664 converter change QuantType_PostTraining to QuantType_QUANT_ALL

Merge pull request !23664 from yeyunpeng2020/quant_bak_3
i-robot 2021-09-17 07:35:09 +00:00 committed by Gitee
commit d5632e318f
7 changed files with 15 additions and 14 deletions

View File

@@ -100,7 +100,7 @@ int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::ve
 int OpenCLExecutor::Tune(kernel::OpenCLKernel *op_kernel) {
   auto ret = op_kernel->PreProcess();
   if (ret != RET_OK) {
-    MS_LOG(WARNING) << "PreProcess kernel failed, name: " << op_kernel->name() << " in tuning";
+    MS_LOG(ERROR) << "PreProcess kernel failed, name: " << op_kernel->name() << " in tuning";
     return ret;
   }
   ret = op_kernel->Tune();
@@ -110,7 +110,7 @@ int OpenCLExecutor::Tune(kernel::OpenCLKernel *op_kernel) {
   }
   ret = op_kernel->PostProcess();
   if (ret != RET_OK) {
-    MS_LOG(WARNING) << "PostProcess kernel failed, name: " << op_kernel->name() << " in tuning";
+    MS_LOG(ERROR) << "PostProcess kernel failed, name: " << op_kernel->name() << " in tuning";
     return ret;
   }
   return RET_OK;
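
Note: in both tuning hunks the function already propagates the non-RET_OK status to the caller, so raising the log level from WARNING to ERROR simply brings the message in line with that error return path.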

View File

@@ -95,7 +95,7 @@ int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptr<schema::Meta
                                               const std::unique_ptr<schema::CNodeT> &dst_node) {
   auto first_output_index = dst_node->outputIndex[0];
   auto first_tensor_output = meta_graph->allTensors[first_output_index].get();
-  if (dst_node->quantType == schema::QuantType_PostTraining) {
+  if (dst_node->quantType == schema::QuantType_QUANT_ALL) {
     if (primitive->name() != mindspore::ops::kNameQuantDTypeCast) {
       first_tensor_output->dataType = kNumberTypeInt8;
     } else {
View File

@@ -368,7 +368,7 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
 int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
   // quant
-  if (config->commonQuantParam.quant_type == schema::QuantType_PostTraining) {
+  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
     this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
     if (m_quantizer_ == nullptr) {
       MS_LOG(ERROR) << "New FullQuantQuantizer failed";

View File

@@ -43,7 +43,7 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [0,16].";
       return RET_INPUT_PARAM_INVALID;
     }
-  } else if (common_quant->quant_type == schema::QuantType_PostTraining) {
+  } else if (common_quant->quant_type == schema::QuantType_QUANT_ALL) {
     if (common_quant->bit_num <= 0 || common_quant->bit_num > kQuantBitNumInt8) {
       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
       return RET_INPUT_PARAM_INVALID;
@@ -107,7 +107,7 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
     (*quant_type) = schema::QuantType_WeightQuant;
     return RET_OK;
   } else if (quant_type_str == "FULL_QUANT") {
-    (*quant_type) = schema::QuantType_PostTraining;
+    (*quant_type) = schema::QuantType_QUANT_ALL;
     return RET_OK;
   } else if (quant_type_str.empty()) {
     (*quant_type) = schema::QuantType_QUANT_NONE;
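
For context, a minimal self-contained sketch of the string-to-enum mapping after this change, reconstructed from the ParseQuantType hunk above; the simplified enum and bool return are stand-ins for schema::QuantType and the RET_* status codes:

    #include <string>

    // Simplified stand-in for schema::QuantType (illustration only).
    enum QuantType { QuantType_QUANT_NONE, QuantType_WeightQuant, QuantType_QUANT_ALL };

    // Maps the converter's quant-type string to the enum; `true` stands in
    // for RET_OK and `false` for RET_INPUT_PARAM_INVALID.
    bool ParseQuantTypeSketch(const std::string &quant_type_str, QuantType *quant_type) {
      if (quant_type_str == "WEIGHT_QUANT") {
        *quant_type = QuantType_WeightQuant;
      } else if (quant_type_str == "FULL_QUANT") {
        *quant_type = QuantType_QUANT_ALL;  // previously QuantType_PostTraining
      } else if (quant_type_str.empty()) {
        *quant_type = QuantType_QUANT_NONE;
      } else {
        return false;  // unrecognized quant type string
      }
      return true;
    }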

View File

@@ -516,7 +516,7 @@ STATUS FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNo
   auto quant_max_t = quant_max;
   auto quant_min_t = quant_min;
   auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
-  auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t,
+  auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_QUANT_ALL, quant_max_t, quant_min_t,
                                             bit_num_t, weight_quant_type, kNumberTypeInt8);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
@@ -1179,7 +1179,7 @@ STATUS FullQuantQuantizer::CollectDataFrequency() {
       return true;
     }
     if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
-      return false;
+      return true;
     }
     for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
       auto tensor = beforeInputs[i];
@@ -1201,7 +1201,7 @@ STATUS FullQuantQuantizer::CollectDataFrequency() {
       return true;
     }
     if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
-      return false;
+      return true;
     }
     int output_i = 0;
     for (const auto &tensor : after_outputs) {
@@ -1301,7 +1301,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   if (calibrator_->GetBiasCorrection()) {
     // init in8 session
     MS_LOG(INFO) << "create quant session";
-    flags.commonQuantParam.quant_type = schema::QuantType_PostTraining;
+    flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
     auto int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
     int8_session_ = int8_sm.session;
     int8_model_ = int8_sm.model;
@@ -1374,7 +1374,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
                                  const CallBackParam &callParam) -> bool {
     if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
       if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
-        return false;
+        return true;
       }
       auto tensor = before_inputs[0];
       MS_ASSERT(tensor != nullptr);
@@ -1537,7 +1537,7 @@ KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
                                 const CallBackParam &callParam) -> bool {
     if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
       if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
-        return false;
+        return true;
      }
       auto tensor = afterOutputs[0];
       MS_ASSERT(tensor != nullptr);
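
Beyond the enum rename, the four callback hunks in this file flip the fp32-check failure path from `return false` to `return true`: a node that fails the check is now treated as "nothing to collect" rather than reported as a callback failure, so calibration continues past it. A minimal self-contained sketch of that pattern, with hypothetical stub types standing in for MindSpore Lite's MSTensor, CallBackParam, and KernelCallBack:

    #include <functional>
    #include <string>
    #include <vector>

    // Hypothetical stand-ins for illustration; not the real MindSpore Lite types.
    struct MSTensorStub {};
    struct CallBackParamStub { std::string node_name; std::string node_type; };
    using KernelCallBackStub = std::function<bool(const std::vector<MSTensorStub *> &,
                                                  const std::vector<MSTensorStub *> &,
                                                  const CallBackParamStub &)>;

    constexpr int kRetOk = 0;
    // Stand-in for FullQuantQuantizer::CheckFp32TensorVec.
    int CheckFp32TensorVecStub(const std::string &, const std::vector<MSTensorStub *> &) {
      return kRetOk;
    }

    KernelCallBackStub MakeBeforeCallBack() {
      return [](const std::vector<MSTensorStub *> &before_inputs,
                const std::vector<MSTensorStub *> & /*before_outputs*/,
                const CallBackParamStub &call_param) -> bool {
        if (CheckFp32TensorVecStub(call_param.node_name, before_inputs) != kRetOk) {
          // Was `return false;`, which reported the callback as failed;
          // returning true skips just this node and keeps calibrating.
          return true;
        }
        // ... collect min/max statistics from before_inputs here ...
        return true;
      };
    }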

View File

@@ -223,7 +223,7 @@ STATUS FixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &
     return RET_ERROR;
   }
   auto quant_param_holder = GetCNodeQuantHolder(primitive);
-  if (quant_type == QuantType_PostTraining) {
+  if (quant_type == QuantType_QUANT_ALL) {
     quant_param_holder->set_input_quant_param(index, quant_params);
   } else {
     quant_param_holder->set_input_quant_param(index, quant_params);

View File

@@ -113,7 +113,8 @@ int Cropper::GetModelOps() {
       MS_LOG(DEBUG) << "PrimitiveType:" << schema::EnumNamePrimitiveType(node->primitive()->value_type())
                     << " QuantType:" << schema::EnumNameQuantType(node->quantType());
       // QuantType_AwareTraining may change
-      if (node->quantType() == schema::QuantType_AwareTraining || node->quantType() == schema::QuantType_PostTraining) {
+      if (node->quantType() == schema::QuantType_AwareTraining || node->quantType() == schema::QuantType_PostTraining ||
+          node->quantType() == schema::QuantType_QUANT_ALL) {
         this->int8_operators_.insert(node->primitive()->value_type());
       } else {
         this->fp32_operators_.insert(node->primitive()->value_type());
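
Note: unlike the other hunks, the cropper check keeps QuantType_PostTraining alongside the new QuantType_QUANT_ALL, presumably so models converted before this rename are still classified as int8 operators.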