Optimize nvgpu bias correction

parent ea80c9d188
commit 1f13808a91
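Support fp32 bias correction for the NVGPU target: DoBiasCorrection() and DoCNodeBiasCorrection() take an int32_bias flag (true on CPU, false on NVGPU), the inline compensation code is split into CreateFp32BiasTensor(), AddBiasToInt32Tensor() and AddBiasToFp32Tensor() helpers, the bias-correction callbacks honor skip-quant nodes through QuantStrategy::IsSkipOp(), DeQuantData() becomes a template over the output element type, and ParseCalibratePath() accepts calibration paths that contain ':'.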
@@ -97,7 +97,7 @@ hiai_model_normalize_object_scene_ps_20200519.pb;1;1,224,224,3 18
 hiai_label_and_video.pb;1;1,224,224,3 16
 tinyyolov2-8.onnx;1;1,416,416,3 11
 emotion-ferplus-8.onnx
-rcnn-ilsvrc13-9.onnx
+#rcnn-ilsvrc13-9.onnx
 shufflenet-v2-10.onnx
 squeezenet1.1-7.onnx
 ml_table_detection_fp32_tmp.onnx

@@ -301,7 +301,7 @@ porseg_tmp.onnx;2:img,prev_mask
 hiai_nlu_onnx_model_v1_0.onnx;3:input_ids,segment_ids,position_ids
 ml_video_edit_makeup_mobilenetv203.onnx;1:input.1
 Q888_CV_face_recognition_self.onnx;1:input
-ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4
+ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3
 ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5:input,bbox,init_pose,init_shape,init_cam
 ml_video_edit_dimming_tech_model_345000_color.onnx;2:input.18,1
 Ireland_gaze_corrector.onnx;3:image,target_angle,strength 1

@@ -109,17 +109,21 @@ int PreprocessParser::ParseCalibratePath(const std::string &str, std::map<std::s
   }
   auto key_values = SplitStringToVector(str, ',');
   for (const auto &key_value : key_values) {
-    auto tmp = SplitStringToVector(key_value, ':');
-    if (tmp.size() != 2) {
-      MS_LOG(ERROR) << "vector need size = 2, size is " << tmp.size();
+    auto string_split = SplitStringToVector(key_value, ':');
+    const size_t min_size = 2;
+    if (string_split.size() < min_size) {
+      MS_LOG(ERROR) << "vector need size >= 2, size is " << string_split.size();
       return RET_INPUT_PARAM_INVALID;
     }
-    auto data_path = RealPath(tmp.at(1).c_str());
+    auto data_path = string_split.at(1);
+    for (size_t i = 2; i < string_split.size() - 1; ++i) {
+      data_path += ":" + string_split[i];
+    }
     if (data_path.empty()) {
       MS_LOG(ERROR) << "path is invalid.";
       return RET_INPUT_PARAM_INVALID;
     }
-    (*value)[tmp.at(0)] = data_path;
+    (*value)[string_split.at(0)] = data_path;
   }
   return RET_OK;
 }

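The rejoin loop exists so a calibration path may itself contain ':' (for example a Windows-style path such as C:/data/b.bin). A standalone sketch of the split-and-rejoin idea; SplitString is a hypothetical stand-in for the converter's SplitStringToVector, and this sketch rejoins every remaining segment:

#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for the converter's SplitStringToVector.
std::vector<std::string> SplitString(const std::string &s, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(s);
  std::string part;
  while (std::getline(ss, part, delim)) parts.push_back(part);
  return parts;
}

int main() {
  std::map<std::string, std::string> calibrate_paths;
  const std::string str = "input_a:/data/a.bin,input_b:C:/data/b.bin";
  for (const auto &key_value : SplitString(str, ',')) {
    auto parts = SplitString(key_value, ':');
    if (parts.size() < 2) return 1;  // need at least "name:path"
    // Everything after the first ':' belongs to the path, so rejoin it.
    std::string data_path = parts.at(1);
    for (size_t i = 2; i < parts.size(); ++i) data_path += ":" + parts[i];
    calibrate_paths[parts.at(0)] = data_path;
  }
  std::cout << calibrate_paths["input_b"] << std::endl;  // prints C:/data/b.bin
  return 0;
}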
@@ -161,7 +161,8 @@ KernelCallBack BiasCorrectionStrategy::GetCPUFloatBeforeCallBack() {
   auto before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                                  const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                                  const CallBackParam &call_param) -> bool {
-    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end()) {
+    auto is_skip_op = quant_strategy_->IsSkipOp(call_param.node_name);
+    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end() || is_skip_op) {
       return true;
     }
     auto tensor = before_inputs[0];

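The same two-line guard is applied to every before/after callback in this file, including the three hunks that follow and the NVGPU callbacks further down. In isolation the check looks like this sketch, where kSupportNodeTypes and skip_nodes stand in for kSupportBiasCorrectionNode and QuantStrategy's skip set:

#include <set>
#include <string>

// Sketch: decide whether a node takes part in bias correction.
bool ShouldSkip(const std::string &node_type, const std::string &node_name,
                const std::set<std::string> &kSupportNodeTypes,
                const std::set<std::string> &skip_nodes) {
  bool unsupported = kSupportNodeTypes.find(node_type) == kSupportNodeTypes.end();
  bool user_skipped = skip_nodes.find(node_name) != skip_nodes.end();  // IsSkipOp
  return unsupported || user_skipped;  // the callback returns true to pass the node through
}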
@@ -191,7 +192,8 @@ KernelCallBack BiasCorrectionStrategy::GetCPUInt8BeforeCallBack() {
   auto before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                                  const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                                  const CallBackParam &call_param) -> bool {
-    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end()) {
+    auto is_skip_op = quant_strategy_->IsSkipOp(call_param.node_name);
+    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end() || is_skip_op) {
       return true;
     }
     auto tensor = before_inputs[0];

@@ -229,7 +231,8 @@ KernelCallBack BiasCorrectionStrategy::GetCPUInt8AfterCallBack() {
   auto after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                                 const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                                 const CallBackParam &call_param) -> bool {
-    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end()) {
+    auto is_skip_op = quant_strategy_->IsSkipOp(call_param.node_name);
+    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end() || is_skip_op) {
       return true;
     }
     auto tensor = after_outputs[0];

@@ -271,7 +274,8 @@ KernelCallBack BiasCorrectionStrategy::GetCPUFloatAfterCallBack() {
   auto after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                                 const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                                 const CallBackParam &call_param) -> bool {
-    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end()) {
+    auto is_skip_op = quant_strategy_->IsSkipOp(call_param.node_name);
+    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end() || is_skip_op) {
       return true;
     }
     auto tensor = after_outputs[0];

@@ -351,30 +355,36 @@ int BiasCorrectionStrategy::CreateQuantModel(const FuncGraphPtr &quant_func_grap
   auto int8_sm = CreateSessionByFuncGraph(quant_func_graph, flags_, this->flags_.commonQuantParam.thread_num);
   int8_session_ = int8_sm.session;
   int8_model_ = int8_sm.model;
-  if (int8_session_ == nullptr || int8_model_ == nullptr) {
-    MS_LOG(ERROR) << "create session failed!";
-    return RET_NULL_PTR;
-  }
+  CHECK_NULL_RETURN(int8_session_);
+  CHECK_NULL_RETURN(int8_model_);
   return RET_OK;
 }
 
 int BiasCorrectionStrategy::DoCPUBiasCorrection(const FuncGraphPtr &quant_func_graph) {
   CHECK_NULL_RETURN(calibrator_);
+  CHECK_NULL_RETURN(quant_strategy_);
   CHECK_NULL_RETURN(fp32_session_);
   CHECK_NULL_RETURN(fp32_model_);
   int8_before_call_back_ = GetBeforeCallBack(CPUInt8);
   int8_after_call_back_ = GetAfterCallBack(CPUInt8);
   fp32_before_call_back_ = GetBeforeCallBack(CPUFP32);
   fp32_after_call_back_ = GetAfterCallBack(CPUFP32);
-  return DoBiasCorrection(quant_func_graph);
+  return DoBiasCorrection(quant_func_graph, true);
 }
 
 int BiasCorrectionStrategy::DoNVGPUBiasCorrection(const FuncGraphPtr &quant_func_graph) {
   CHECK_NULL_RETURN(calibrator_);
+  CHECK_NULL_RETURN(quant_strategy_);
   CHECK_NULL_RETURN(fp32_session_);
   CHECK_NULL_RETURN(fp32_model_);
   int8_before_call_back_ = GetBeforeCallBack(NVGPUInt8);
   int8_after_call_back_ = GetAfterCallBack(NVGPUInt8);
   fp32_before_call_back_ = GetBeforeCallBack(CPUFP32);
   fp32_after_call_back_ = GetAfterCallBack(CPUFP32);
-  return DoBiasCorrection(quant_func_graph);
+  return DoBiasCorrection(quant_func_graph, false);
 }
 
-int BiasCorrectionStrategy::DoBiasCorrection(const FuncGraphPtr &quant_func_graph) {
+int BiasCorrectionStrategy::DoBiasCorrection(const FuncGraphPtr &quant_func_graph, bool int32_bias) {
   auto ret = CreateQuantModel(quant_func_graph);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Create quant model failed:" << ret;

@@ -408,21 +418,106 @@ int BiasCorrectionStrategy::DoBiasCorrection(const FuncGraphPtr &quant_func_grap
     if (op_bias_diff_sum_map_.find(op_name) == op_bias_diff_sum_map_.end()) {
       continue;
     }
-    status = DoCNodeBiasCorrection(quant_func_graph, cnode);
+    status = DoCNodeBiasCorrection(quant_func_graph, cnode, int32_bias);
     if (status != RET_OK) {
-      MS_LOG(ERROR) << "do node bias correct failed.";
+      MS_LOG(ERROR) << op_name << " do node bias correct failed.";
       break;
     }
   }
   return status;
 }
 
-int BiasCorrectionStrategy::DoCNodeBiasCorrection(const FuncGraphPtr &quant_func_graph, const CNodePtr &cnode) {
+int BiasCorrectionStrategy::CreateFp32BiasTensor(const FuncGraphPtr &quant_func_graph, const CNodePtr &cnode,
+                                                 const ParameterPtr &parameter, const std::vector<float> &bias_diff) {
+  auto op_name = cnode->fullname_with_scope();
+  if (parameter == nullptr) {
+    MS_LOG(ERROR) << op_name << " parameter is nullptr.";
+    return RET_NULL_PTR;
+  }
+  std::vector<int64_t> shape;
+  shape.push_back(bias_diff.size());
+
+  auto tensor_info = CreateTensorInfo(bias_diff.data(), sizeof(float) * bias_diff.size(), shape, kNumberTypeFloat32);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << op_name << " create tensor info failed.";
+    return RET_ERROR;
+  }
+  auto status = InitParameterFromTensorInfo(parameter, tensor_info);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << op_name << " init parameter from tensor info failed";
+    return RET_ERROR;
+  }
+  parameter->set_name("added_" + op_name + "_bias");
+  cnode->add_input(parameter);
+  return RET_OK;
+}
+
+int BiasCorrectionStrategy::AddBiasToInt32Tensor(const CNodePtr &cnode, const tensor::TensorPtr &bias_tensor,
+                                                 const std::vector<schema::QuantParamT> &bias_quant_params,
+                                                 const std::vector<float> &bias_diff) {
+  auto op_name = cnode->fullname_with_scope();
+  int *bias_datas = static_cast<int *>(bias_tensor->data_c());
+  if (static_cast<size_t>(bias_tensor->DataSize()) != bias_diff.size()) {
+    MS_LOG(ERROR) << op_name << " unexpected bias data count: " << bias_tensor->DataSize()
+                  << " not the same as bias_diff: " << bias_diff.size();
+    return RET_ERROR;
+  }
+  if (bias_quant_params.size() != bias_diff.size()) {
+    MS_LOG(ERROR) << op_name << " unexpected bias quant params size: " << bias_quant_params.size()
+                  << " not the same as bias_diff: " << bias_diff.size();
+    return RET_ERROR;
+  }
+  for (size_t i = 0; i < bias_tensor->DataSize(); i++) {
+    auto scale = bias_quant_params[i].scale;
+    if (fabs(scale) <= 0.0f) {
+      MS_LOG(ERROR) << op_name << " divisor 'scale' cannot be 0.";
+      return RET_ERROR;
+    }
+    double after_correct = std::round(bias_diff[i] / scale) + bias_datas[i];
+    const constexpr int32_t corrected_bias_abs_limit = 0.6 * INT32_MAX;
+    if (after_correct > corrected_bias_abs_limit) {
+      MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too large: " << after_correct
+                      << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
+      bias_datas[i] = static_cast<int>(corrected_bias_abs_limit);
+    } else if (after_correct < -corrected_bias_abs_limit) {
+      MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too small: " << after_correct
+                      << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
+      bias_datas[i] = static_cast<int>(-corrected_bias_abs_limit);
+    } else {
+      auto diff = static_cast<int>(std::round(bias_diff[i] / scale));
+      bias_datas[i] += diff;
+    }
+  }
+  return RET_OK;
+}
+
+int BiasCorrectionStrategy::AddBiasToFp32Tensor(const CNodePtr &cnode, const tensor::TensorPtr &bias_tensor,
+                                                const std::vector<float> &bias_diff) {
+  auto op_name = cnode->fullname_with_scope();
+  auto bias_datas = static_cast<float *>(bias_tensor->data_c());
+  if (static_cast<size_t>(bias_tensor->DataSize()) != bias_diff.size()) {
+    MS_LOG(ERROR) << op_name << " unexpected bias data count: " << bias_tensor->DataSize()
+                  << " not the same as bias_diff: " << bias_diff.size();
+    return RET_ERROR;
+  }
+  if (bias_tensor->DataSize() != bias_diff.size()) {
+    MS_LOG(ERROR) << op_name << " unexpected bias size: " << bias_tensor->DataSize()
+                  << " not the same as bias_diff: " << bias_diff.size();
+    return RET_ERROR;
+  }
+  for (size_t i = 0; i < bias_tensor->DataSize(); i++) {
+    bias_datas[i] += bias_diff[i];
+  }
+  return RET_OK;
+}
+
+int BiasCorrectionStrategy::DoCNodeBiasCorrection(const FuncGraphPtr &quant_func_graph, const CNodePtr &cnode,
+                                                  bool int32_bias) {
   auto op_name = cnode->fullname_with_scope();
   const auto &bias_diff = op_bias_diff_sum_map_[op_name];
   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
   if (primitive == nullptr) {
-    MS_LOG(ERROR) << op_name << " primitive is nullptr";
+    MS_LOG(ERROR) << op_name << " primitive is nullptr.";
     return RET_NULL_PTR;
   }
   auto quant_param_holder = GetCNodeQuantHolder(primitive);

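For reference, the compensation rule AddBiasToInt32Tensor applies per channel is: requantize the float error with the bias scale, add it to the stored int32 bias, and clamp to +/-0.6 * INT32_MAX. A minimal standalone sketch (values are illustrative):

#include <cmath>
#include <cstdint>
#include <iostream>

// Sketch of the int32 bias compensation rule: requantize the fp32 error
// with the bias scale, add it to the quantized bias, clamp the result.
int32_t CorrectBias(int32_t bias, float bias_diff, float scale) {
  const int32_t kLimit = static_cast<int32_t>(0.6 * INT32_MAX);
  double corrected = std::round(bias_diff / scale) + bias;
  if (corrected > kLimit) return kLimit;
  if (corrected < -kLimit) return -kLimit;
  return bias + static_cast<int32_t>(std::round(bias_diff / scale));
}

int main() {
  // e.g. quantized bias 1000, measured fp32-vs-int8 output diff 0.05, bias scale 0.001
  std::cout << CorrectBias(1000, 0.05f, 0.001f) << std::endl;  // prints 1050
  return 0;
}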
@@ -430,71 +525,39 @@ int BiasCorrectionStrategy::DoCNodeBiasCorrection(const FuncGraphPtr &quant_func
   auto input_quant_params = quant_param_holder->get_input_quant_params();
   if (input_quant_params.size() == kHasBiasTensorSize) {
     // compensate the existed
-    auto bias_quant_params = input_quant_params.at(THIRD_INPUT);
     auto bias = cnode->input(THIRD_INPUT + 1);
     auto bias_parameter_ptr = bias->cast<ParameterPtr>();
     auto bias_default_param = bias_parameter_ptr->default_param();
-    auto bias_param = bias_default_param->cast<tensor::TensorPtr>();
-    int *bias_datas = static_cast<int *>(bias_param->data_c());
-
-    if (static_cast<size_t>(bias_param->DataSize()) != bias_diff.size()) {
-      MS_LOG(DEBUG) << op_name << " unexpected bias data count: " << bias_param->DataSize()
-                    << " not the same as bias_diff: " << bias_diff.size();
-      return RET_ERROR;
-    }
-    if (bias_quant_params.size() != bias_diff.size()) {
-      MS_LOG(ERROR) << op_name << " unexpected bias quant params size: " << bias_quant_params.size()
-                    << " not the same as bias_diff: " << bias_diff.size();
-      return RET_ERROR;
-    }
-    for (size_t i = 0; i < bias_param->DataSize(); i++) {
-      auto scale = bias_quant_params[i].scale;
-      if (fabs(scale) <= 0.0f) {
-        MS_LOG(ERROR) << op_name << " divisor 'scale' cannot be 0.";
+    auto bias_tensor = bias_default_param->cast<tensor::TensorPtr>();
+    if (int32_bias) {
+      auto bias_quant_params = input_quant_params.at(THIRD_INPUT);
+      auto status = AddBiasToInt32Tensor(cnode, bias_tensor, bias_quant_params, bias_diff);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << op_name << " Add bias to int32 tensor failed.";
         return RET_ERROR;
       }
-      double after_correct = std::round(bias_diff[i] / scale) + bias_datas[i];
-      const constexpr int32_t corrected_bias_abs_limit = 0.6 * INT32_MAX;
-      if (after_correct > corrected_bias_abs_limit) {
-        MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too large: " << after_correct
-                        << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
-        bias_datas[i] = static_cast<int>(corrected_bias_abs_limit);
-      } else if (after_correct < -corrected_bias_abs_limit) {
-        MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too small: " << after_correct
-                        << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
-        bias_datas[i] = static_cast<int>(-corrected_bias_abs_limit);
-      } else {
-        auto diff = static_cast<int>(std::round(bias_diff[i] / scale));
-        bias_datas[i] += diff;
+    } else {
+      auto status = AddBiasToFp32Tensor(cnode, bias_tensor, bias_diff);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << op_name << " Add bias to int32 tensor failed.";
+        return RET_ERROR;
       }
     }
   } else if (input_quant_params.size() == kHasBiasTensorSize - 1) {
     MS_LOG(INFO) << op_name << " add bias input";
     // need to add bias input
     auto parameter = quant_func_graph->add_parameter();
     if (parameter == nullptr) {
       MS_LOG(ERROR) << "parameter is nullptr.";
       return RET_NULL_PTR;
     }
-    std::vector<int64_t> shape;
-    shape.push_back(bias_diff.size());
-
-    auto tensor_info = CreateTensorInfo(bias_diff.data(), sizeof(float) * bias_diff.size(), shape, kNumberTypeFloat32);
-    if (tensor_info == nullptr) {
-      MS_LOG(ERROR) << op_name << " create tensor info failed.";
-      return RET_ERROR;
-    }
-    auto status = InitParameterFromTensorInfo(parameter, tensor_info);
+    auto status = CreateFp32BiasTensor(quant_func_graph, cnode, parameter, bias_diff);
     if (status != RET_OK) {
-      MS_LOG(ERROR) << op_name << " init parameter from tensor info failed";
+      MS_LOG(ERROR) << op_name << " Create fp32 bias tensor failed.";
       return RET_ERROR;
     }
-    parameter->set_name("added_" + op_name + "_bias");
-    cnode->add_input(parameter);
-    status = DoParameterBiasQuant(parameter, primitive);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << op_name << " Do bias quant failed.";
-      return RET_ERROR;
+    if (int32_bias) {
+      status = DoParameterBiasQuant(parameter, primitive);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << op_name << " Do bias quant failed.";
+        return RET_ERROR;
+      }
     }
   } else {
     MS_LOG(WARNING) << op_name << " unexpected size: " << input_quant_params.size()

@@ -507,13 +570,14 @@ KernelCallBack BiasCorrectionStrategy::GetNVGPUInt8BeforeCallBack() {
   auto before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                                  const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                                  const CallBackParam &call_param) -> bool {
-    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end()) {
+    auto is_skip_op = quant_strategy_->IsSkipOp(call_param.node_name);
+    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end() || is_skip_op) {
       return true;
     }
     auto feature_map_tensor = before_inputs[0];
     MS_ASSERT(feature_map_tensor != nullptr);
     // op can be skipped.
-    if (feature_map_tensor->data_type() != kNumberTypeInt8) {
+    if (feature_map_tensor->data_type() != kNumberTypeFloat32) {
       MS_LOG(INFO) << "feature_map_tensor type is " << feature_map_tensor->data_type();
       return true;
     }

@@ -524,12 +588,21 @@ KernelCallBack BiasCorrectionStrategy::GetNVGPUInt8BeforeCallBack() {
     }
     // do quantization: activation is always per layer quantized
     std::vector<int8_t> quant_datas;
-    QuantOriginFeatureMap(static_cast<float *>(feature_map_tensor->data()), feature_map_tensor->ElementsNum(),
-                          feature_map_tensor->quant_params(), feature_map_tensor->Size(), &quant_datas);
-    std::vector<double> dequant_data;
-    DeQuantData(quant_datas.data(), quant_datas.size(), feature_map_tensor->quant_params(), &dequant_data);
-    auto ret = memcpy_s(feature_map_tensor->data(), feature_map_tensor->Size(), quant_datas.data(),
-                        quant_datas.size() * sizeof(int8_t));
+    auto ret = QuantOriginFeatureMap(static_cast<float *>(feature_map_tensor->data()),
+                                     feature_map_tensor->ElementsNum(), feature_map_tensor->quant_params(),
+                                     feature_map_tensor->ElementsNum() * sizeof(int8_t), &quant_datas);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << feature_map_tensor->tensor_name() << " Quant origin feature map failed. " << ret;
+      return false;
+    }
+    std::vector<float> dequant_data;
+    ret = DeQuantData(quant_datas.data(), quant_datas.size(), feature_map_tensor->quant_params(), &dequant_data);
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "DeQuant origin feature map failed. " << ret;
+      return false;
+    }
+    ret = memcpy_s(feature_map_tensor->data(), feature_map_tensor->Size(), dequant_data.data(),
+                   dequant_data.size() * sizeof(float));
     if (ret != EOK) {
       MS_LOG(ERROR) << "memcpy error: " << ret;
       return false;

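The rewritten before-callback fake-quantizes fp32 feature maps: quantize per layer to int8, dequantize, and write the dequantized values back, so the fp32 graph observes the quantization error the NVGPU int8 path would introduce. A self-contained sketch of that roundtrip with an assumed per-layer scale and zero point:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Sketch: per-layer int8 fake-quantization roundtrip (quantize,
// dequantize, write back), the idea behind the NVGPU before-callback.
std::vector<float> FakeQuantRoundtrip(const std::vector<float> &data, float scale, int zero_point) {
  std::vector<float> out(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    int q = static_cast<int>(std::round(data[i] / scale)) + zero_point;
    q = std::max(-128, std::min(127, q));  // clamp to the int8 range
    out[i] = scale * (q - zero_point);     // dequantize back to float
  }
  return out;
}

int main() {
  std::vector<float> fm = {0.12f, -0.7f, 1.001f};
  for (float v : FakeQuantRoundtrip(fm, 0.01f, 0)) std::cout << v << ' ';
  std::cout << std::endl;  // values snapped to the 0.01 grid: 0.12 -0.7 1
  return 0;
}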
@@ -543,7 +616,8 @@ KernelCallBack BiasCorrectionStrategy::GetNVGPUInt8AfterCallBack() {
   auto after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                                 const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                                 const CallBackParam &call_param) -> bool {
-    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end()) {
+    auto is_skip_op = quant_strategy_->IsSkipOp(call_param.node_name);
+    if (kSupportBiasCorrectionNode.find(call_param.node_type) == kSupportBiasCorrectionNode.end() || is_skip_op) {
       return true;
     }
     auto tensor = after_outputs[0];

@@ -23,6 +23,7 @@
 #include "base/base.h"
 #include "ir/anf.h"
 #include "tools/converter/quantizer/calibrator.h"
+#include "tools/converter/quantizer/quant_strategy.h"
 
 namespace mindspore::lite::quant {
 enum OperationType {

@@ -37,10 +38,11 @@ enum CallBackType {
 class BiasCorrectionStrategy {
  public:
   BiasCorrectionStrategy(const converter::Flags &flags, const std::shared_ptr<Calibrator> &calibrator,
-                         session::LiteSession *fp32_session, Model *fp32_model, int activation_q_min,
-                         int activation_q_max)
+                         const std::shared_ptr<QuantStrategy> &quant_strategy, session::LiteSession *fp32_session,
+                         Model *fp32_model, int activation_q_min, int activation_q_max)
       : flags_(flags),
         calibrator_(calibrator),
+        quant_strategy_(quant_strategy),
         fp32_session_(fp32_session),
         fp32_model_(fp32_model),
         activation_q_min_(activation_q_min),

@@ -59,8 +61,8 @@ class BiasCorrectionStrategy {
 
  private:
   int CreateQuantModel(const FuncGraphPtr &quant_func_graph);
-  int DoBiasCorrection(const FuncGraphPtr &quant_func_graph);
-  int DoCNodeBiasCorrection(const FuncGraphPtr &quant_func_graph, const CNodePtr &cnode);
+  int DoBiasCorrection(const FuncGraphPtr &quant_func_graph, bool int32_bias);
+  int DoCNodeBiasCorrection(const FuncGraphPtr &quant_func_graph, const CNodePtr &cnode, bool int32_bias);
   int Int8Inference(const KernelCallBack &before_call_back, const KernelCallBack &after_call_back);
   int Fp32Inference(const KernelCallBack &before_call_back, const KernelCallBack &after_call_back);
   bool OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data);

@@ -80,6 +82,15 @@ class BiasCorrectionStrategy {
                             const std::vector<lite::LiteQuantParam> &feature_map_quant_params, size_t quant_size,
                             std::vector<int8_t> *quant_datas);
 
+  int CreateFp32BiasTensor(const FuncGraphPtr &quant_func_graph, const CNodePtr &cnode, const ParameterPtr &parameter,
+                           const std::vector<float> &bias_diff);
+
+  int AddBiasToInt32Tensor(const CNodePtr &cnode, const tensor::TensorPtr &bias_tensor,
+                           const std::vector<schema::QuantParamT> &bias_quant_params,
+                           const std::vector<float> &bias_diff);
+
+  int AddBiasToFp32Tensor(const CNodePtr &cnode, const tensor::TensorPtr &bias_tensor,
+                          const std::vector<float> &bias_diff);
   template <typename T>
   int CalculatePerChannelMeans(const T *tensor_data, size_t elem_count, std::vector<int> shapes,
                                std::vector<float> *per_channel_mean) {

@@ -110,6 +121,7 @@ class BiasCorrectionStrategy {
  private:
   converter::Flags flags_;
   std::shared_ptr<Calibrator> calibrator_{nullptr};
+  std::shared_ptr<QuantStrategy> quant_strategy_{nullptr};
   session::LiteSession *fp32_session_{nullptr};
   Model *fp32_model_{nullptr};
   int activation_q_min_{INT8_MIN};

@@ -36,12 +36,14 @@ int DataDistribution::RecordMaxMinValueArray(const std::vector<float> &data) {
   real_max_ = std::max(max_num, real_max_);
   if (activation_quant_method_ == REMOVAL_OUTLIER) {
     auto bak_data(data);
-    auto const q_min = static_cast<int>(0.0001 * bak_data.size());
-    auto const q_max = static_cast<int>(0.9999 * bak_data.size());
-    std::nth_element(bak_data.begin(), bak_data.begin() + q_min, bak_data.end());
-    auto quantile_min = bak_data.at(q_min);
-    std::nth_element(bak_data.begin() + q_min + 1, bak_data.begin() + q_max, bak_data.end());
-    auto quantile_max = bak_data.at(q_max);
+    const float min_percentage = 0.0001;
+    const float max_percentage = 0.9999;
+    auto const quantile_min_index = static_cast<int>(min_percentage * bak_data.size());
+    auto const quantile_max_index = static_cast<int>(max_percentage * bak_data.size());
+    std::nth_element(bak_data.begin(), bak_data.begin() + quantile_min_index, bak_data.end());
+    auto quantile_min = bak_data.at(quantile_min_index);
+    std::nth_element(bak_data.begin() + quantile_min_index + 1, bak_data.begin() + quantile_max_index, bak_data.end());
+    auto quantile_max = bak_data.at(quantile_max_index);
     MS_LOG(DEBUG) << "real_min_:" << real_min_ << " real_max_:" << real_max_ << " quantile_min:" << quantile_min
                   << " quantile_max:" << quantile_max;
     this->min_datas_.emplace_back(quantile_min);

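The renamed quantile indices are plain order statistics; std::nth_element partially sorts so the element at the chosen index lands in its sorted position. A small sketch with the same 0.01%/99.99% cut-offs:

#include <algorithm>
#include <iostream>
#include <vector>

// Sketch: pick the 0.01% and 99.99% order statistics with nth_element,
// mirroring REMOVAL_OUTLIER's quantile computation.
int main() {
  std::vector<float> data(100000);
  for (size_t i = 0; i < data.size(); ++i) data[i] = static_cast<float>(i);
  const int lo = static_cast<int>(0.0001 * data.size());  // index 10
  const int hi = static_cast<int>(0.9999 * data.size());  // index 99990
  std::nth_element(data.begin(), data.begin() + lo, data.end());
  float quantile_min = data.at(lo);
  std::nth_element(data.begin() + lo + 1, data.begin() + hi, data.end());
  float quantile_max = data.at(hi);
  std::cout << quantile_min << ' ' << quantile_max << std::endl;  // 10 99990
  return 0;
}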
@@ -25,7 +25,6 @@
 #include "src/tensor.h"
 #include "tools/converter/quantizer/quant_cast.h"
 #include "tools/converter/quantizer/quantize_util.h"
-#include "tools/converter/quantizer/quant_strategy.h"
 #include "tools/optimizer/common/gllo_utils.h"
 #include "src/common/log_adapter.h"
 #include "securec/include/securec.h"

@@ -438,18 +437,14 @@ int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
       MS_LOG(ERROR) << cnode->fullname_with_scope() << " cnode is null";
       return RET_NULL_PTR;
     }
-    auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
-                                                          flags_.commonQuantParam.min_quant_weight_channel,
-                                                          flags_.commonQuantParam.skip_quant_node);
-    CHECK_NULL_RETURN(quant_strategy);
-    auto is_skip_op = quant_strategy->IsSkipOp(anode);
+    auto is_skip_op = quant_strategy_->IsSkipOp(anode->fullname_with_scope());
     if (is_skip_op) {
       MS_LOG(INFO) << cnode->fullname_with_scope() << " is skip quant.";
       continue;
     }
     // Mark quantifiable nodes
     auto is_support_op =
-        quant_strategy->CanOpFullQuantized(anode, support_int8_ops_, skip_check_dtype_ops_, support_activation_);
+        quant_strategy_->CanOpFullQuantized(anode, support_int8_ops_, skip_check_dtype_ops_, support_activation_);
     if (is_support_op) {
       auto ret = calibrator_->AddQuantizedOp(cnode);
       if (ret != RET_OK) {

@@ -479,6 +474,10 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
                                              this->flags_.fullQuantParam.activation_quant_method,
                                              this->flags_.dataPreProcessParam, activation_symmetry_);
   MSLITE_CHECK_PTR(calibrator_);
+  quant_strategy_ = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
+                                                    flags_.commonQuantParam.min_quant_weight_channel,
+                                                    flags_.commonQuantParam.skip_quant_node);
+  CHECK_NULL_RETURN(quant_strategy_);
   auto ret = MarkQuantNode(func_graph);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Mark quant node failed.";

@@ -603,26 +602,26 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     return RET_ERROR;
   }
   if (this->flags_.fullQuantParam.bias_correction) {
     MS_LOG(INFO) << "do bias correction";
-    BiasCorrectionStrategy strategy(flags_, calibrator_, fp32_session_, fp32_model_, activation_q_min_,
-                                    activation_q_max_);
+    BiasCorrectionStrategy strategy(flags_, calibrator_, quant_strategy_, fp32_session_, fp32_model_,
+                                    activation_q_min_, activation_q_max_);
     switch (this->flags_.fullQuantParam.target_device) {
       case CPU:
         status = strategy.DoCPUBiasCorrection(func_graph);
         break;
       case NVGPU:
         status = strategy.DoNVGPUBiasCorrection(func_graph);
         break;
       default:
         MS_LOG(ERROR) << "Unsupported target device " << this->flags_.fullQuantParam.target_device
                       << " for bias correction.";
         return RET_ERROR;
     }
     if (status != RET_OK) {
       MS_LOG(ERROR) << "bias_correction failed.";
       return status;
     }
   }
   return RET_OK;

@@ -37,6 +37,7 @@
 #include "tools/converter/quantizer/calibrator.h"
 #include "tools/converter/quantizer/data_distribution.h"
 #include "src/common/quant_utils.h"
+#include "tools/converter/quantizer/quant_strategy.h"
 
 namespace mindspore::lite::quant {
 class FullQuantQuantizer : public Quantizer {

@@ -88,6 +89,7 @@ class FullQuantQuantizer : public Quantizer {
   std::set<mindspore::ActivationType> support_activation_;
 
   std::shared_ptr<Calibrator> calibrator_{nullptr};
+  std::shared_ptr<QuantStrategy> quant_strategy_{nullptr};
   session::LiteSession *fp32_session_{nullptr};
   Model *fp32_model_{nullptr};
 

@@ -142,7 +142,7 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
       }
     }
     MS_CHECK_TRUE_MSG(weight_quant_size > 0, RET_ERROR, "weight quant size must large 0");
-    auto compress_ratio = 1.0 * origin_model_size / weight_quant_size;
+    const auto compress_ratio = 1.0 * origin_model_size / weight_quant_size;
     std::cout << " round:" << round << " scale:" << scale << " cos_sim:" << cos_sim << " mean_error:" << mean_error
               << " ratio:" << compress_ratio << std::endl;
     if (cos_sim >= threshold && compress_ratio > best_compress_ratio) {

@@ -249,7 +249,7 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
     delete origin_model;
     return RET_OK;
   }
-  int baby_step_rounds = 25;
+  const int baby_step_rounds = 25;
   step = (min_max.max - min_max.min) / baby_step_rounds;
 
   param.rounds = baby_step_rounds;

@@ -127,7 +127,7 @@ bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node, const std::set<Pr
   return true;
 }
 
-bool QuantStrategy::IsSkipOp(const AnfNodePtr &input_node) {
-  return !(skip_node_.find(input_node->fullname_with_scope()) == skip_node_.end());
+bool QuantStrategy::IsSkipOp(const std::string &skip_node_name) {
+  return !(skip_node_.find(skip_node_name) == skip_node_.end());
 }
 }  // namespace mindspore::lite::quant

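Taking a plain node name means runtime callbacks, which only carry CallBackParam.node_name, can query the skip set directly. A sketch of the reworked lookup (the class and names are illustrative, not the converter's real types):

#include <iostream>
#include <set>
#include <string>
#include <utility>

// Sketch of the reworked IsSkipOp: lookup by node name against a set of
// user-specified skip-quant node names.
class QuantStrategySketch {
 public:
  explicit QuantStrategySketch(std::set<std::string> skip) : skip_node_(std::move(skip)) {}
  bool IsSkipOp(const std::string &skip_node_name) const {
    return skip_node_.find(skip_node_name) != skip_node_.end();
  }

 private:
  std::set<std::string> skip_node_;
};

int main() {
  QuantStrategySketch strategy({"Conv2D-op12"});
  std::cout << std::boolalpha << strategy.IsSkipOp("Conv2D-op12") << std::endl;  // true
  return 0;
}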
@@ -38,7 +38,7 @@ class QuantStrategy {
                           const std::set<PrimitivePtr> &skip_check_dtype_ops,
                           const std::set<mindspore::ActivationType> &support_activation);
   bool CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node, int preferred_dim);
-  bool IsSkipOp(const AnfNodePtr &input_node);
+  bool IsSkipOp(const std::string &skip_node_name);
 
  private:
   size_t min_quant_weight_size_;

@@ -620,21 +620,6 @@ int DoParameterBiasQuant(const ParameterPtr &bias, const PrimitivePtr &primitive
   return RET_OK;
 }
 
-int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<lite::LiteQuantParam> quant_params,
-                std::vector<double> *dequant_data, int preferred_dim) {
-  if (quant_params.size() != 1) {
-    MS_LOG(ERROR) << "unexpected quant_params size: " << quant_params.size() << " only support per-layer now.";
-    return RET_ERROR;
-  }
-  auto scale = quant_params[0].scale;
-  auto zp = quant_params[0].zeroPoint;
-  dequant_data->resize(elements_num);
-  for (int64_t i = 0; i < elements_num; i++) {
-    dequant_data->at(i) = scale * (tensor_data[i] - zp);
-  }
-  return RET_OK;
-}
-
 int DeQuantData(mindspore::tensor::MSTensor *tensor, std::vector<double> *dequant_data, int preferred_dim) {
   return DeQuantData(static_cast<int8_t *>(tensor->data()), tensor->ElementsNum(), tensor->quant_params(), dequant_data,
                      preferred_dim);

@@ -100,11 +100,24 @@ int DoParameterBiasQuant(const ParameterPtr &bias, const PrimitivePtr &primitive
 
 int DeQuantData(mindspore::tensor::MSTensor *tensor, std::vector<double> *dequant_data, int preferred_dim = 0);
 
-int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<lite::LiteQuantParam> quant_params,
-                std::vector<double> *dequant_data, int preferred_dim = 0);
-
 int DoBitPack(const size_t &bit_num, schema::TensorT *tensor_input);
 
 template <typename T>
+int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<lite::LiteQuantParam> quant_params,
+                std::vector<T> *dequant_data, int preferred_dim = 0) {
+  if (quant_params.size() != 1) {
+    MS_LOG(ERROR) << "unexpected quant_params size: " << quant_params.size() << " only support per-layer now.";
+    return RET_ERROR;
+  }
+  auto scale = quant_params[0].scale;
+  auto zp = quant_params[0].zeroPoint;
+  dequant_data->resize(elements_num);
+  for (int64_t i = 0; i < elements_num; i++) {
+    dequant_data->at(i) = scale * (tensor_data[i] - zp);
+  }
+  return RET_OK;
+}
+
+template <typename T>
 int FixedBitQuantFilter(const AnfNodePtr &parameter, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
                         QuantType quant_type, int quant_max, int quant_min, size_t bit_num,

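Making DeQuantData a template lets each caller pick the output element type; the NVGPU callback above now dequantizes into std::vector<float>, where the removed free function was fixed to double. A hedged usage sketch of the same per-layer formula:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the per-layer dequantization formula the template implements:
// dequant[i] = scale * (q[i] - zero_point), element type chosen by caller.
template <typename T>
void DeQuant(const int8_t *q, int64_t n, double scale, int zp, std::vector<T> *out) {
  out->resize(n);
  for (int64_t i = 0; i < n; ++i) out->at(i) = static_cast<T>(scale * (q[i] - zp));
}

int main() {
  const int8_t q[] = {-128, 0, 127};
  std::vector<float> as_float;  // float output, as in the NVGPU callback
  DeQuant(q, 3, 0.5, 0, &as_float);
  std::cout << as_float[0] << ' ' << as_float[2] << std::endl;  // -64 63.5
  return 0;
}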
@@ -376,7 +389,7 @@ bool PackRepetition(size_t bit_num, schema::TensorT *tensor) {
   size_t coor_best_bit = 0;
   auto nz_cnt = CalCoorBestBit<T>(quant_data, elem_cnt, quant_params, unique_value_bit, &coor_best_bit);
   // 1. coor_best_bit 2. nz_cnt 3. quant_data_set size 4. unique_values 5. unique_value indexing 6. nz values coord
-  auto pack_sparsity_size_in_bit =
+  const auto pack_sparsity_size_in_bit =
       1 * k8Bit + 4 * k8Bit + bit_num + bit_num * unique_value_cnt + unique_value_bit * nz_cnt + nz_cnt * coor_best_bit;
   size_t pack_sparsity_size_in_byte = ceil(1.0 * pack_sparsity_size_in_bit / k8Bit);
   MS_LOG(DEBUG) << "coor_best_bit: " << coor_best_bit << " ori: " << origin_size_in_byte