move target_device to full quant
This commit is contained in:
parent
6542ba2f2f
commit
2fd2606fee
|
@ -114,7 +114,6 @@ int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::ma
|
|||
{"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
|
||||
{"skip_quant_node", common_quant_string_.skip_quant_node},
|
||||
{"debug_info_save_path", common_quant_string_.debug_info_save_path},
|
||||
{"target_device", common_quant_string_.target_device},
|
||||
};
|
||||
return SetMapData(map, parse_map, kCommonQuantParam);
|
||||
}
|
||||
|
@ -139,6 +138,7 @@ int ConfigFileParser::ParseFullQuantString(const std::map<std::string, std::map<
|
|||
std::map<std::string, std::string &> parse_map{
|
||||
{"activation_quant_method", full_quant_string_.activation_quant_method},
|
||||
{"bias_correction", full_quant_string_.bias_correction},
|
||||
{"target_device", full_quant_string_.target_device},
|
||||
};
|
||||
return SetMapData(map, parse_map, kFullQuantParam);
|
||||
}
|
||||
|
|
|
@ -43,7 +43,6 @@ struct CommonQuantString {
|
|||
std::string min_quant_weight_channel;
|
||||
std::string skip_quant_node;
|
||||
std::string debug_info_save_path;
|
||||
std::string target_device;
|
||||
};
|
||||
|
||||
struct MixedBitWeightQuantString {
|
||||
|
@ -54,6 +53,7 @@ struct MixedBitWeightQuantString {
|
|||
struct FullQuantString {
|
||||
std::string activation_quant_method;
|
||||
std::string bias_correction;
|
||||
std::string target_device;
|
||||
};
|
||||
|
||||
struct RegistryInfoString {
|
||||
|
|
|
@ -104,13 +104,6 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
|
|||
common_quant->is_debug = true;
|
||||
}
|
||||
|
||||
if (!common_quant_string.target_device.empty()) {
|
||||
ret = ParseTargetDevice(common_quant_string.target_device, &common_quant->target_device);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse device failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -147,6 +140,13 @@ int QuantParamParser::ParseFullQuant(const FullQuantString &full_quant_string, q
|
|||
MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (!full_quant_string.target_device.empty()) {
|
||||
auto ret = ParseTargetDevice(full_quant_string.target_device, &full_quant->target_device);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse device failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -167,14 +167,11 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
|
|||
}
|
||||
|
||||
int QuantParamParser::ParseTargetDevice(const std::string &target_device_str, quant::TargetDevice *target_device) {
|
||||
if (target_device_str == "CPU") {
|
||||
(*target_device) = quant::CPU;
|
||||
return RET_OK;
|
||||
} else if (target_device_str == "KIRIN") {
|
||||
if (target_device_str == "KIRIN") {
|
||||
(*target_device) = quant::KIRIN;
|
||||
return RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be CPU|KIRIN.";
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -578,7 +578,7 @@ int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
|
||||
switch (flags_.commonQuantParam.target_device) {
|
||||
switch (flags_.fullQuantParam.target_device) {
|
||||
case CPU:
|
||||
InitCpuConfig();
|
||||
break;
|
||||
|
@ -586,7 +586,7 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
|
|||
InitKirinConfig();
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << " Unsupported device " << flags_.commonQuantParam.target_device;
|
||||
MS_LOG(ERROR) << " Unsupported device " << flags_.fullQuantParam.target_device;
|
||||
return RET_ERROR;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -41,7 +41,6 @@ struct CommonQuantParam {
|
|||
bool is_debug = false;
|
||||
std::string debug_info_save_path;
|
||||
std::set<std::string> skip_quant_node;
|
||||
TargetDevice target_device = CPU;
|
||||
int thread_num = 4;
|
||||
};
|
||||
|
||||
|
@ -53,6 +52,7 @@ struct MixedBitWeightQuantParam {
|
|||
struct FullQuantParam {
|
||||
ActivationQuantizedMethod activation_quant_method = MAX_MIN;
|
||||
bool bias_correction = true;
|
||||
TargetDevice target_device = CPU;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
|
||||
|
|
Loading…
Reference in New Issue