move target_device to full quant

This commit is contained in:
yeyunpeng2020 2021-12-03 11:07:53 +08:00
parent 6542ba2f2f
commit 2fd2606fee
5 changed files with 14 additions and 17 deletions

View File

@ -114,7 +114,6 @@ int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::ma
{"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
{"skip_quant_node", common_quant_string_.skip_quant_node},
{"debug_info_save_path", common_quant_string_.debug_info_save_path},
{"target_device", common_quant_string_.target_device},
};
return SetMapData(map, parse_map, kCommonQuantParam);
}
@ -139,6 +138,7 @@ int ConfigFileParser::ParseFullQuantString(const std::map<std::string, std::map<
std::map<std::string, std::string &> parse_map{
{"activation_quant_method", full_quant_string_.activation_quant_method},
{"bias_correction", full_quant_string_.bias_correction},
{"target_device", full_quant_string_.target_device},
};
return SetMapData(map, parse_map, kFullQuantParam);
}

View File

@ -43,7 +43,6 @@ struct CommonQuantString {
std::string min_quant_weight_channel;
std::string skip_quant_node;
std::string debug_info_save_path;
std::string target_device;
};
struct MixedBitWeightQuantString {
@ -54,6 +53,7 @@ struct MixedBitWeightQuantString {
struct FullQuantString {
std::string activation_quant_method;
std::string bias_correction;
std::string target_device;
};
struct RegistryInfoString {

View File

@ -104,13 +104,6 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
common_quant->is_debug = true;
}
if (!common_quant_string.target_device.empty()) {
ret = ParseTargetDevice(common_quant_string.target_device, &common_quant->target_device);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse device failed.";
return ret;
}
}
return RET_OK;
}
@ -147,6 +140,13 @@ int QuantParamParser::ParseFullQuant(const FullQuantString &full_quant_string, q
MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
return RET_INPUT_PARAM_INVALID;
}
if (!full_quant_string.target_device.empty()) {
auto ret = ParseTargetDevice(full_quant_string.target_device, &full_quant->target_device);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse device failed.";
return ret;
}
}
return RET_OK;
}
@ -167,14 +167,11 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
}
int QuantParamParser::ParseTargetDevice(const std::string &target_device_str, quant::TargetDevice *target_device) {
if (target_device_str == "CPU") {
(*target_device) = quant::CPU;
return RET_OK;
} else if (target_device_str == "KIRIN") {
if (target_device_str == "KIRIN") {
(*target_device) = quant::KIRIN;
return RET_OK;
} else {
MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be CPU|KIRIN.";
MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN.";
return RET_INPUT_PARAM_INVALID;
}
}

View File

@ -578,7 +578,7 @@ int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
}
int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
switch (flags_.commonQuantParam.target_device) {
switch (flags_.fullQuantParam.target_device) {
case CPU:
InitCpuConfig();
break;
@ -586,7 +586,7 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
InitKirinConfig();
break;
default:
MS_LOG(ERROR) << " Unsupported device " << flags_.commonQuantParam.target_device;
MS_LOG(ERROR) << " Unsupported device " << flags_.fullQuantParam.target_device;
return RET_ERROR;
break;
}

View File

@ -41,7 +41,6 @@ struct CommonQuantParam {
bool is_debug = false;
std::string debug_info_save_path;
std::set<std::string> skip_quant_node;
TargetDevice target_device = CPU;
int thread_num = 4;
};
@ -53,6 +52,7 @@ struct MixedBitWeightQuantParam {
struct FullQuantParam {
ActivationQuantizedMethod activation_quant_method = MAX_MIN;
bool bias_correction = true;
TargetDevice target_device = CPU;
};
} // namespace mindspore::lite::quant