forked from mindspore-Ecosystem/mindspore
!23664 converter change QuantType_PostTraining to QuantType_QUANT_ALL
Merge pull request !23664 from yeyunpeng2020/quant_bak_3
This commit is contained in:
commit
d5632e318f
|
@ -100,7 +100,7 @@ int OpenCLExecutor::RunOrTune(const std::vector<Tensor *> &inputs, const std::ve
|
|||
int OpenCLExecutor::Tune(kernel::OpenCLKernel *op_kernel) {
|
||||
auto ret = op_kernel->PreProcess();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(WARNING) << "PreProcess kernel failed, name: " << op_kernel->name() << " in tuning";
|
||||
MS_LOG(ERROR) << "PreProcess kernel failed, name: " << op_kernel->name() << " in tuning";
|
||||
return ret;
|
||||
}
|
||||
ret = op_kernel->Tune();
|
||||
|
@ -110,7 +110,7 @@ int OpenCLExecutor::Tune(kernel::OpenCLKernel *op_kernel) {
|
|||
}
|
||||
ret = op_kernel->PostProcess();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(WARNING) << "PostProcess kernel failed, name: " << op_kernel->name() << " in tuning";
|
||||
MS_LOG(ERROR) << "PostProcess kernel failed, name: " << op_kernel->name() << " in tuning";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
|
|
|
@ -95,7 +95,7 @@ int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptr<schema::Meta
|
|||
const std::unique_ptr<schema::CNodeT> &dst_node) {
|
||||
auto first_output_index = dst_node->outputIndex[0];
|
||||
auto first_tensor_output = meta_graph->allTensors[first_output_index].get();
|
||||
if (dst_node->quantType == schema::QuantType_PostTraining) {
|
||||
if (dst_node->quantType == schema::QuantType_QUANT_ALL) {
|
||||
if (primitive->name() != mindspore::ops::kNameQuantDTypeCast) {
|
||||
first_tensor_output->dataType = kNumberTypeInt8;
|
||||
} else {
|
||||
|
|
|
@ -368,7 +368,7 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
|
|||
|
||||
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
// quant
|
||||
if (config->commonQuantParam.quant_type == schema::QuantType_PostTraining) {
|
||||
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
|
||||
this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
|
||||
if (m_quantizer_ == nullptr) {
|
||||
MS_LOG(ERROR) << "New FullQuantQuantizer failed";
|
||||
|
|
|
@ -43,7 +43,7 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
|
|||
MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [0,16].";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
} else if (common_quant->quant_type == schema::QuantType_PostTraining) {
|
||||
} else if (common_quant->quant_type == schema::QuantType_QUANT_ALL) {
|
||||
if (common_quant->bit_num <= 0 || common_quant->bit_num > kQuantBitNumInt8) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
|
@ -107,7 +107,7 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
|
|||
(*quant_type) = schema::QuantType_WeightQuant;
|
||||
return RET_OK;
|
||||
} else if (quant_type_str == "FULL_QUANT") {
|
||||
(*quant_type) = schema::QuantType_PostTraining;
|
||||
(*quant_type) = schema::QuantType_QUANT_ALL;
|
||||
return RET_OK;
|
||||
} else if (quant_type_str.empty()) {
|
||||
(*quant_type) = schema::QuantType_QUANT_NONE;
|
||||
|
|
|
@ -516,7 +516,7 @@ STATUS FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNo
|
|||
auto quant_max_t = quant_max;
|
||||
auto quant_min_t = quant_min;
|
||||
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t,
|
||||
auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_QUANT_ALL, quant_max_t, quant_min_t,
|
||||
bit_num_t, weight_quant_type, kNumberTypeInt8);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
|
@ -1179,7 +1179,7 @@ STATUS FullQuantQuantizer::CollectDataFrequency() {
|
|||
return true;
|
||||
}
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
|
||||
auto tensor = beforeInputs[i];
|
||||
|
@ -1201,7 +1201,7 @@ STATUS FullQuantQuantizer::CollectDataFrequency() {
|
|||
return true;
|
||||
}
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
int output_i = 0;
|
||||
for (const auto &tensor : after_outputs) {
|
||||
|
@ -1301,7 +1301,7 @@ STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
if (calibrator_->GetBiasCorrection()) {
|
||||
// init int8 session
|
||||
MS_LOG(INFO) << "create quant session";
|
||||
flags.commonQuantParam.quant_type = schema::QuantType_PostTraining;
|
||||
flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
|
||||
auto int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
|
||||
int8_session_ = int8_sm.session;
|
||||
int8_model_ = int8_sm.model;
|
||||
|
@ -1374,7 +1374,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
|
|||
const CallBackParam &callParam) -> bool {
|
||||
if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
auto tensor = before_inputs[0];
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
|
@ -1537,7 +1537,7 @@ KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
|
|||
const CallBackParam &callParam) -> bool {
|
||||
if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
auto tensor = afterOutputs[0];
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
|
|
|
@ -223,7 +223,7 @@ STATUS FixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
if (quant_type == QuantType_PostTraining) {
|
||||
if (quant_type == QuantType_QUANT_ALL) {
|
||||
quant_param_holder->set_input_quant_param(index, quant_params);
|
||||
} else {
|
||||
quant_param_holder->set_input_quant_param(index, quant_params);
|
||||
|
|
|
@ -113,7 +113,8 @@ int Cropper::GetModelOps() {
|
|||
MS_LOG(DEBUG) << "PrimitiveType:" << schema::EnumNamePrimitiveType(node->primitive()->value_type())
|
||||
<< " QuantType:" << schema::EnumNameQuantType(node->quantType());
|
||||
// QuantType_AwareTraining may change
|
||||
if (node->quantType() == schema::QuantType_AwareTraining || node->quantType() == schema::QuantType_PostTraining) {
|
||||
if (node->quantType() == schema::QuantType_AwareTraining || node->quantType() == schema::QuantType_PostTraining ||
|
||||
node->quantType() == schema::QuantType_QUANT_ALL) {
|
||||
this->int8_operators_.insert(node->primitive()->value_type());
|
||||
} else {
|
||||
this->fp32_operators_.insert(node->primitive()->value_type());
|
||||
|
|
Loading…
Reference in New Issue