!33374 quantizer debugger new api

Merge pull request !33374 from liyan2022/dev_main
i-robot 2022-04-22 02:57:40 +00:00 committed by Gitee
commit 20f1148148
13 changed files with 282 additions and 288 deletions

View File

@ -466,7 +466,7 @@ bool NeedBitUppackCheck(const SchemaTensorWrapper &src_tensor) {
return need_bit_unpack;
}
int WeightDecoder::DecompressTensor(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor) {
int WeightDecoder::DecompressTensor(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor) {
MS_ASSERT(src_tensor.handler() != nullptr);
MS_ASSERT(dst_tensor != nullptr);
if (src_tensor.handler()->weightQunatCompressType() == schema::WeightQunatCompressType_FSE) {

View File

@ -172,7 +172,7 @@ class WeightDecoder {
}
}
static int DecompressTensor(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor);
static int DecompressTensor(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor);
private:
static int DequantTensor(Tensor *tensor, int preferred_dim, TypeId dst_data_type = kNumberTypeFloat32);
@ -267,6 +267,7 @@ class WeightDecoder {
}
static int GetMatMulPreferredDim(const OpParameter *op_parameter, int input_index, const std::vector<int> &dims);
static int GetDeConvPreferredDim(const OpParameter *op_parameter, const std::vector<int> &dims);
template <typename T>

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,11 +28,13 @@
#include <cfloat>
#include "schema/inner/model_generated.h"
#include "src/common/log_adapter.h"
#include "src/common/log_util.h"
#include "ir/dtype/type_id.h"
#include "ir/tensor.h"
#include "src/common/utils.h"
#include "tools/common/statistic_utils.h"
#include "src/tensor.h"
#include "include/api/model.h"
namespace mindspore {
namespace lite {
@ -107,54 +109,56 @@ struct CheckTensor {
// tensorData need to be converter first
template <typename T>
float CompareDataByCosineDistance(const std::unordered_map<String, mindspore::tensor::MSTensor *> &calib_tensors,
const std::unordered_map<String, mindspore::tensor::MSTensor *> &out_tensors) {
if (calib_tensors.empty() || out_tensors.empty()) {
float CompareDataByCosineDistance(const std::shared_ptr<mindspore::Model> &origin_model,
const std::shared_ptr<mindspore::Model> &quant_model) {
CHECK_NULL_RETURN(origin_model);
CHECK_NULL_RETURN(quant_model);
if (origin_model->GetOutputs().empty() || quant_model->GetOutputs().empty()) {
MS_LOG(ERROR) << "calib or out tenor is empty.";
return RET_ERROR;
}
float total_cos = 0;
for (const auto &calib : calib_tensors) {
auto calib_tensors = origin_model->GetOutputs();
for (const auto &calib_tensor : calib_tensors) {
size_t error_count = 0;
float mean_error = 0;
auto calib_tensor = calib.second;
auto calib_data = static_cast<const T *>(calib_tensor->data());
auto out_tensor_iter = out_tensors.find(calib_tensor->tensor_name());
if (out_tensor_iter == out_tensors.end()) {
MS_LOG(ERROR) << "Cant find " << calib_tensor->tensor_name() << " in out_tensors";
auto calib_data = reinterpret_cast<const T *>(calib_tensor.Data().get());
auto out_tensor = quant_model->GetOutputByTensorName(calib_tensor.Name());
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "Cant find " << calib_tensor.Name() << " in out_tensors";
return RET_ERROR;
}
auto out_tensor = out_tensor_iter->second;
auto out_data = static_cast<const T *>(out_tensor->data());
auto cos = mindspore::lite::GetCosSimilarity<T>(calib_data, out_data, out_tensor->ElementsNum());
auto out_data = reinterpret_cast<const T *>(out_tensor.Data().get());
auto cos = mindspore::lite::GetCosSimilarity<T>(calib_data, out_data, out_tensor.ElementNum());
total_cos += cos;
MS_LOG(INFO) << "tensor_name:" << calib_tensor->tensor_name() << " cos_sim: " << mean_error
MS_LOG(INFO) << "tensor_name:" << calib_tensor.Name() << " cos_sim: " << mean_error
<< " error_count:" << error_count;
}
return total_cos / calib_tensors.size();
}
template <typename T>
float CompareData(const std::unordered_map<String, mindspore::tensor::MSTensor *> &calib_tensors,
const std::unordered_map<String, mindspore::tensor::MSTensor *> &out_tensors) {
if (calib_tensors.empty() || out_tensors.empty()) {
float CompareData(const std::shared_ptr<mindspore::Model> &origin_model,
const std::shared_ptr<mindspore::Model> &quant_model) {
CHECK_NULL_RETURN(origin_model);
CHECK_NULL_RETURN(quant_model);
if (origin_model->GetOutputs().empty() || quant_model->GetOutputs().empty()) {
MS_LOG(ERROR) << "calib or out tenor is empty.";
return RET_ERROR;
}
float total_meam_error = 0;
for (const auto &calib : calib_tensors) {
auto calib_tensors = origin_model->GetOutputs();
for (const auto &calib_tensor : calib_tensors) {
size_t error_count = 0;
float mean_error = 0;
auto calib_tensor = calib.second;
auto calib_data = static_cast<const T *>(calib_tensor->data());
auto out_tensor_iter = out_tensors.find(calib_tensor->tensor_name());
if (out_tensor_iter == out_tensors.end()) {
MS_LOG(ERROR) << "Cant find " << calib_tensor->tensor_name() << " in out_tensors";
auto calib_data = reinterpret_cast<const T *>(calib_tensor.Data().get());
auto out_tensor = quant_model->GetOutputByTensorName(calib_tensor.Name());
if (out_tensor == nullptr) {
MS_LOG(ERROR) << "Cant find " << calib_tensor.Name() << " in out_tensors";
return RET_ERROR;
}
auto out_tensor = out_tensor_iter->second;
auto out_data = static_cast<const T *>(out_tensor->data());
for (int j = 0; j < calib_tensor->ElementsNum(); j++) {
auto out_data = reinterpret_cast<const T *>(out_tensor.Data().get());
for (int j = 0; j < calib_tensor.ElementNum(); j++) {
if (std::is_same<T, float>::value && (std::isnan(out_data[j]) || std::isinf(out_data[j]))) {
MS_LOG(ERROR) << "Output tensor has nan or inf data, compare fail";
return RET_ERROR;
@ -182,7 +186,7 @@ float CompareData(const std::unordered_map<String, mindspore::tensor::MSTensor *
mean_error /= error_count;
}
total_meam_error += std::abs(mean_error);
MS_LOG(INFO) << "tensor_name:" << calib_tensor->tensor_name() << " mean_error: " << mean_error
MS_LOG(INFO) << "tensor_name:" << calib_tensor.Name() << " mean_error: " << mean_error
<< " error_count:" << error_count;
}
return total_meam_error / calib_tensors.size();
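
For reference, the two metrics computed above (GetCosSimilarity in CompareDataByCosineDistance and the per-tensor mean error in CompareData) reduce to standard formulas. A minimal standalone sketch over raw float buffers, independent of the MSTensor API; the exact per-element thresholding of CompareData is not visible in this hunk, so the tolerance handling below is an assumption:

#include <cmath>
#include <cstddef>

// Cosine similarity between two buffers: dot(a, b) / (|a| * |b|).
float CosSimilarity(const float *a, const float *b, size_t n) {
  double dot = 0.0, norm_a = 0.0, norm_b = 0.0;
  for (size_t i = 0; i < n; ++i) {
    dot += a[i] * b[i];
    norm_a += a[i] * a[i];
    norm_b += b[i] * b[i];
  }
  if (norm_a == 0.0 || norm_b == 0.0) {
    return 0.0f;
  }
  return static_cast<float>(dot / (std::sqrt(norm_a) * std::sqrt(norm_b)));
}

// Mean relative error of b against reference a, averaged over the elements that
// deviate by more than a small tolerance (assumed thresholding).
float MeanRelativeError(const float *a, const float *b, size_t n, float tolerance = 1e-5f) {
  double sum = 0.0;
  size_t error_count = 0;
  for (size_t i = 0; i < n; ++i) {
    float diff = b[i] - a[i];
    if (std::fabs(diff) > tolerance && std::fabs(a[i]) > tolerance) {
      sum += std::fabs(diff / a[i]);
      ++error_count;
    }
  }
  return error_count == 0 ? 0.0f : static_cast<float>(sum / error_count);
}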

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,15 +15,16 @@
*/
#define USE_DEPRECATED_API
#include <fstream>
#include <map>
#include "tools/converter/quantizer/debug_info_manager.h"
#include "src/weight_decoder.h"
#include "src/common/log_adapter.h"
#include "src/lite_session.h"
#include "include/errorcode.h"
#include "tools/converter/preprocess/image_preprocess.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite {
namespace {
@ -232,7 +233,7 @@ int DebugInfoManager::SetQuantStaticInfo(const std::vector<mindspore::tensor::MS
return RET_OK;
}
int DebugInfoManager::AddOriginInfo(const mindspore::CallBackParam &call_back_param, OpParameter *op_parameter,
int DebugInfoManager::AddOriginInfo(const mindspore::MSCallBackParam &call_back_param, OpParameter *op_parameter,
bool is_input, int tensor_index, mindspore::lite::Tensor *origin_tensor) {
CHECK_NULL_RETURN(op_parameter);
CHECK_NULL_RETURN(origin_tensor);
@ -269,7 +270,7 @@ int DebugInfoManager::AddOriginInfo(const mindspore::CallBackParam &call_back_pa
return RET_OK;
}
int DebugInfoManager::AddComparedInfo(const mindspore::CallBackParam &call_back_param,
int DebugInfoManager::AddComparedInfo(const mindspore::MSCallBackParam &call_back_param,
const std::vector<mindspore::tensor::MSTensor *> &inputs,
OpParameter *op_parameter, bool is_input, int tensor_index,
mindspore::lite::Tensor *compared_tensor) {
@ -301,7 +302,8 @@ int DebugInfoManager::AddComparedInfo(const mindspore::CallBackParam &call_back_
return RET_OK;
}
std::map<std::string, mindspore::schema::Tensor *> DebugInfoManager::ParseInputTensorFromModel(const Model &model) {
std::map<std::string, mindspore::schema::Tensor *> DebugInfoManager::ParseInputTensors(
const mindspore::lite::Model &model) {
std::map<std::string, mindspore::schema::Tensor *> maps;
for (auto node : model.all_nodes_) {
for (auto index : node->input_indices_) {
@ -376,34 +378,32 @@ int DebugInfoManager::GetConstTensor(const std::map<std::string, mindspore::sche
return RET_OK;
}
KernelCallBack DebugInfoManager::GetOriginBeforeCallBack(
MSKernelCallBack DebugInfoManager::GetOriginBeforeCallBack(
const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters) {
auto before_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<mindspore::tensor::MSTensor *> &outputs,
const CallBackParam &call_param) {
auto before_callback = [&](const std::vector<mindspore::MSTensor> &inputs,
const std::vector<mindspore::MSTensor> &outputs, const MSCallBackParam &call_param) {
for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs.at(i);
MS_LOG(INFO) << " Get " << tensor->tensor_name() << " statistics info.";
auto is_const = static_cast<mindspore::lite::Tensor *>(tensor)->category() == CONST_TENSOR ||
static_cast<mindspore::lite::Tensor *>(tensor)->category() == CONST_SCALAR;
if (is_const) {
auto lite_tensor = quant::MSTensorToLiteTensor(inputs.at(i));
MS_LOG(INFO) << " Get " << tensor.Name() << " statistics info.";
if (tensor.IsConst()) {
mindspore::lite::Tensor new_tensor;
auto ret = GetConstTensor(input_tensor_map, tensor, &new_tensor);
auto ret = GetConstTensor(input_tensor_map, lite_tensor, &new_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << tensor->tensor_name() << " get const tensor failed.";
MS_LOG(ERROR) << tensor.Name() << " get const tensor failed.";
return false;
}
ret = AddOriginInfo(call_param, op_parameters.at(call_param.node_name), true, i, &new_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << tensor->tensor_name() << " add origin info failed.";
MS_LOG(ERROR) << tensor.Name() << " add origin info failed.";
return false;
}
} else {
auto ret = AddOriginInfo(call_param, op_parameters.at(call_param.node_name), true, i,
static_cast<mindspore::lite::Tensor *>(tensor));
static_cast<mindspore::lite::Tensor *>(lite_tensor));
if (ret != RET_OK) {
MS_LOG(ERROR) << tensor->tensor_name() << " add origin info failed.";
MS_LOG(ERROR) << tensor.Name() << " add origin info failed.";
return false;
}
}
@ -413,44 +413,45 @@ KernelCallBack DebugInfoManager::GetOriginBeforeCallBack(
return before_callback;
}
KernelCallBack DebugInfoManager::GetQuantBeforeCallBack(
MSKernelCallBack DebugInfoManager::GetQuantBeforeCallBack(
const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters) {
auto before_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<mindspore::tensor::MSTensor *> &outputs,
const CallBackParam &call_param) {
auto before_callback = [&](const std::vector<mindspore::MSTensor> &inputs,
const std::vector<mindspore::MSTensor> &outputs, const MSCallBackParam &call_param) {
auto lite_inputs = quant::MSTensorToLiteTensors(inputs);
for (size_t i = 0; i < inputs.size(); ++i) {
auto tensor = inputs.at(i);
MS_LOG(INFO) << " Get " << tensor->tensor_name() << " statistics info.";
if (save_flag_ && !tensor->quant_params().empty()) {
auto lite_tensor = quant::MSTensorToLiteTensor(tensor);
MS_LOG(INFO) << " Get " << tensor.Name() << " statistics info.";
if (save_flag_ && !tensor.QuantParams().empty()) {
QuantParamExtend quant_param;
quant_param.node_name = call_param.node_name;
quant_param.node_type = call_param.node_type;
quant_param.quant_params = tensor->quant_params();
quant_param.tensor_name = tensor->tensor_name();
quant_param.element_num = tensor->ElementsNum();
quant_param.dims = tensor->shape();
quant_param.quant_params = lite_tensor->quant_params();
quant_param.tensor_name = lite_tensor->tensor_name();
quant_param.element_num = lite_tensor->ElementsNum();
quant_param.dims = lite_tensor->shape();
quant_params_.push_back(quant_param);
}
auto is_const = static_cast<mindspore::lite::Tensor *>(tensor)->category() == CONST_TENSOR ||
static_cast<mindspore::lite::Tensor *>(tensor)->category() == CONST_SCALAR;
auto is_const = static_cast<mindspore::lite::Tensor *>(lite_tensor)->category() == CONST_TENSOR ||
static_cast<mindspore::lite::Tensor *>(lite_tensor)->category() == CONST_SCALAR;
if (is_const) {
mindspore::lite::Tensor new_tensor;
auto ret = GetConstTensor(input_tensor_map, tensor, &new_tensor);
auto ret = GetConstTensor(input_tensor_map, lite_tensor, &new_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << tensor->tensor_name() << " get const tensor failed.";
MS_LOG(ERROR) << tensor.Name() << " get const tensor failed.";
return false;
}
ret = AddComparedInfo(call_param, inputs, op_parameters.at(call_param.node_name), true, i, &new_tensor);
ret = AddComparedInfo(call_param, lite_inputs, op_parameters.at(call_param.node_name), true, i, &new_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << tensor->tensor_name() << " add compared info failed.";
MS_LOG(ERROR) << tensor.Name() << " add compared info failed.";
return false;
}
} else {
auto ret = AddComparedInfo(call_param, inputs, op_parameters.at(call_param.node_name), true, i,
static_cast<mindspore::lite::Tensor *>(tensor));
auto ret = AddComparedInfo(call_param, lite_inputs, op_parameters.at(call_param.node_name), true, i,
static_cast<mindspore::lite::Tensor *>(lite_tensor));
if (ret != RET_OK) {
MS_LOG(ERROR) << tensor->tensor_name() << " add compared info failed.";
MS_LOG(ERROR) << tensor.Name() << " add compared info failed.";
return false;
}
}
@ -460,7 +461,7 @@ KernelCallBack DebugInfoManager::GetQuantBeforeCallBack(
return before_callback;
}
KernelCallBack DebugInfoManager::GetBeforeCallBack(
MSKernelCallBack DebugInfoManager::GetBeforeCallBack(
const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters, bool is_origin) {
if (is_origin) {
@ -470,38 +471,41 @@ KernelCallBack DebugInfoManager::GetBeforeCallBack(
}
}
KernelCallBack DebugInfoManager::GetAfterCallBack(const std::map<std::string, OpParameter *> &op_parameters,
bool is_origin) {
KernelCallBack after_callback;
MSKernelCallBack DebugInfoManager::GetAfterCallBack(const std::map<std::string, OpParameter *> &op_parameters,
bool is_origin) {
MSKernelCallBack after_callback;
if (is_origin) {
after_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<mindspore::tensor::MSTensor *> &outputs, const CallBackParam &call_param) {
after_callback = [&](const std::vector<mindspore::MSTensor> &inputs,
const std::vector<mindspore::MSTensor> &outputs, const MSCallBackParam &call_param) {
// all outputs are same dtype.
for (size_t i = 0; i < outputs.size(); ++i) {
auto tensor = outputs.at(i);
if (save_flag_ && !tensor->quant_params().empty()) {
auto lite_tensor = quant::MSTensorToLiteTensor(tensor);
if (save_flag_ && !tensor.QuantParams().empty()) {
QuantParamExtend quant_param;
quant_param.node_name = call_param.node_name;
quant_param.node_type = call_param.node_type;
quant_param.quant_params = tensor->quant_params();
quant_param.tensor_name = tensor->tensor_name();
quant_param.element_num = tensor->ElementsNum();
quant_param.dims = tensor->shape();
quant_param.quant_params = lite_tensor->quant_params();
quant_param.tensor_name = lite_tensor->tensor_name();
quant_param.element_num = lite_tensor->ElementsNum();
quant_param.dims = lite_tensor->shape();
quant_params_.push_back(quant_param);
}
AddOriginInfo(call_param, op_parameters.at(call_param.node_name), false, i,
static_cast<mindspore::lite::Tensor *>(tensor));
static_cast<mindspore::lite::Tensor *>(lite_tensor));
}
return true;
};
} else {
after_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &inputs,
const std::vector<mindspore::tensor::MSTensor *> &outputs, const CallBackParam &call_param) {
after_callback = [&](const std::vector<mindspore::MSTensor> &inputs,
const std::vector<mindspore::MSTensor> &outputs, const MSCallBackParam &call_param) {
// all outputs are same dtype.
for (size_t i = 0; i < outputs.size(); ++i) {
auto tensor = outputs.at(i);
AddComparedInfo(call_param, inputs, op_parameters.at(call_param.node_name), false, i,
static_cast<mindspore::lite::Tensor *>(tensor));
auto lite_tensor = quant::MSTensorToLiteTensor(tensor);
auto lite_inputs = quant::MSTensorToLiteTensors(inputs);
AddComparedInfo(call_param, lite_inputs, op_parameters.at(call_param.node_name), false, i,
static_cast<mindspore::lite::Tensor *>(lite_tensor));
}
return true;
};
@ -593,60 +597,61 @@ int DebugInfoManager::GetClipAndCos() {
return RET_OK;
}
int DebugInfoManager::CompareOriginWithQuant(const quant::SessionModel &origin, const quant::SessionModel &quant,
int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptr<mindspore::Model> &origin,
const std::shared_ptr<mindspore::Model> &quant,
const std::map<std::string, OpParameter *> &op_parameters,
const std::string &debug_info_save_path,
const preprocess::DataPreProcessParam &data_preprocess) {
const preprocess::DataPreProcessParam &data_preprocess,
const mindspore::lite::Model *origin_lite_model,
const mindspore::lite::Model *quant_lite_model) {
auto begin = GetTimeUs();
auto origin_input_tensor_map = ParseInputTensorFromModel(*origin.model);
auto quant_input_tensor_map = ParseInputTensorFromModel(*quant.model);
auto origin_input_tensor_map = ParseInputTensors(*origin_lite_model);
auto quant_input_tensor_map = ParseInputTensors(*quant_lite_model);
int ret;
// When the calibration data set does not exist, use 1 round of random numbers for comparison
int rounds = data_preprocess.calibrate_size > 0 ? data_preprocess.calibrate_size : 1;
for (int round = 0; round < rounds; round++) {
for (auto tensor : origin.session->GetInputs()) {
for (auto tensor : origin->GetInputs()) {
if (data_preprocess.calibrate_size > 0) {
ret = preprocess::PreProcess(data_preprocess, tensor->tensor_name(), round, tensor);
ret = preprocess::PreProcess(data_preprocess, tensor.Name(), round, &tensor);
} else {
ret = GenerateRandomData(tensor);
ret = GenerateRandomData(&tensor);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "round" << round << ":" << tensor->tensor_name() << " pre-process failed.";
MS_LOG(ERROR) << "round" << round << ":" << tensor.Name() << " pre-process failed.";
return ret;
}
}
std::cout << "Statistics the original data distribution. Round " << round << std::endl;
auto origin_before_callBack = GetBeforeCallBack(origin_input_tensor_map, op_parameters, true);
auto origin_after_callBack = GetAfterCallBack(op_parameters, true);
origin.session->BindThread(true);
ret = origin.session->RunGraph(origin_before_callBack, origin_after_callBack);
origin.session->BindThread(false);
if (ret != RET_OK) {
MS_LOG(ERROR) << "round:" << round << " origin session run graph failed.";
auto origin_outputs = origin->GetOutputs();
auto status = origin->Predict(origin->GetInputs(), &origin_outputs, origin_before_callBack, origin_after_callBack);
if (status != kSuccess) {
MS_LOG(ERROR) << "round:" << round << " origin model run graph failed.";
FreeBuffer();
return ret;
return RET_ERROR;
}
std::cout << "Statistics the quant data distribution. Round " << round << std::endl;
auto quant_before_callBack = GetBeforeCallBack(quant_input_tensor_map, op_parameters, false);
auto quant_after_callBack = GetAfterCallBack(op_parameters, false);
for (auto tensor : quant.session->GetInputs()) {
auto tensor_data = tensor->MutableData();
for (auto tensor : quant->GetInputs()) {
auto tensor_data = tensor.MutableData();
CHECK_NULL_RETURN(tensor_data);
ret = memcpy_s(tensor_data, tensor->Size(), origin.session->GetInputsByTensorName(tensor->tensor_name())->data(),
origin.session->GetInputsByTensorName(tensor->tensor_name())->Size());
ret = memcpy_s(tensor_data, tensor.DataSize(), origin->GetInputByTensorName(tensor.Name()).Data().get(),
origin->GetInputByTensorName(tensor.Name()).DataSize());
if (ret != EOK) {
MS_LOG(ERROR) << tensor->tensor_name() << " memcpy failed.";
MS_LOG(ERROR) << tensor.Name() << " memcpy failed.";
return RET_ERROR;
}
}
quant.session->BindThread(true);
ret = quant.session->RunGraph(quant_before_callBack, quant_after_callBack);
quant.session->BindThread(false);
if (ret != RET_OK) {
MS_LOG(ERROR) << "round:" << round << " quant session run graph failed.";
auto quant_outputs = quant->GetOutputs();
status = quant->Predict(quant->GetInputs(), &quant_outputs, quant_before_callBack, quant_after_callBack);
if (status != kSuccess) {
MS_LOG(ERROR) << "round:" << round << " quant model run graph failed.";
FreeBuffer();
return ret;
return RET_ERROR;
}
ret = GetClipAndCos();
if (ret != RET_OK) {
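
For context, the hunks above replace session::LiteSession::RunGraph(KernelCallBack, KernelCallBack) with the unified Model API, where per-node hooks are passed directly to Model::Predict as MSKernelCallBack lambdas and node information arrives through MSCallBackParam. A minimal sketch of wiring up such a hook pair, mirroring the callbacks above (illustrative only, error handling trimmed):

#include <iostream>
#include <memory>
#include <vector>
#include "include/api/model.h"

// Runs one inference with a before/after hook on every kernel; returning false
// from either hook aborts the run, which is how the debugger bails out above.
bool RunWithNodeHooks(const std::shared_ptr<mindspore::Model> &model) {
  mindspore::MSKernelCallBack before = [](const std::vector<mindspore::MSTensor> &inputs,
                                          const std::vector<mindspore::MSTensor> &outputs,
                                          const mindspore::MSCallBackParam &param) {
    std::cout << "before " << param.node_type << " " << param.node_name
              << ", inputs: " << inputs.size() << std::endl;
    return true;
  };
  mindspore::MSKernelCallBack after = [](const std::vector<mindspore::MSTensor> &inputs,
                                         const std::vector<mindspore::MSTensor> &outputs,
                                         const mindspore::MSCallBackParam &param) {
    std::cout << "after " << param.node_name << ", outputs: " << outputs.size() << std::endl;
    return true;
  };
  auto outputs = model->GetOutputs();
  auto status = model->Predict(model->GetInputs(), &outputs, before, after);
  return status == mindspore::kSuccess;
}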

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -21,6 +21,7 @@
#include <vector>
#include <cstdio>
#include <map>
#include <memory>
#include "tools/converter/quantizer/quantize_util.h"
#include "nnacl/op_base.h"
#include "tools/common/statistic_utils.h"
@ -35,10 +36,12 @@ struct PrimaryKey {
std::string node_name;
InOutFlag in_out_flag;
size_t index;
friend bool operator<(const struct PrimaryKey &p1, const struct PrimaryKey &p2) {
return p1.node_name < p2.node_name || (p1.node_name == p2.node_name && p1.in_out_flag < p2.in_out_flag) ||
(p1.node_name == p2.node_name && p1.in_out_flag == p2.in_out_flag && p1.index < p2.index);
}
friend std::ostream &operator<<(std::ostream &os, const PrimaryKey &p) { // for struct output
os << "[" << p.node_name << "," << p.in_out_flag << "," << p.index << "]";
return os;
@ -82,16 +85,19 @@ struct QuantParamExtend {
class DebugInfoManager {
public:
int CompareOriginWithQuant(const quant::SessionModel &origin, const quant::SessionModel &quant,
int CompareOriginWithQuant(const std::shared_ptr<mindspore::Model> &origin,
const std::shared_ptr<mindspore::Model> &quant,
const std::map<std::string, OpParameter *> &op_parameters,
const std::string &debug_info_save_path,
const preprocess::DataPreProcessParam &data_preprocess);
const preprocess::DataPreProcessParam &data_preprocess,
const mindspore::lite::Model *origin_lite_model,
const mindspore::lite::Model *quant_lite_model);
private:
int AddOriginInfo(const mindspore::CallBackParam &call_back_param, OpParameter *op_parameter, bool is_input,
int AddOriginInfo(const mindspore::MSCallBackParam &call_back_param, OpParameter *op_parameter, bool is_input,
int tensor_index, mindspore::lite::Tensor *origin_tensor);
int AddComparedInfo(const mindspore::CallBackParam &call_back_param,
int AddComparedInfo(const mindspore::MSCallBackParam &call_back_param,
const std::vector<mindspore::tensor::MSTensor *> &inputs, OpParameter *op_parameter,
bool is_input, int tensor_index, mindspore::lite::Tensor *compared_tensor);
@ -114,22 +120,22 @@ class DebugInfoManager {
void SaveInfo(std::ofstream &out_file, const QuantDebugInfo &info);
std::map<std::string, mindspore::schema::Tensor *> ParseInputTensorFromModel(const Model &model);
std::map<std::string, mindspore::schema::Tensor *> ParseInputTensors(const mindspore::lite::Model &model);
std::map<std::string, mindspore::schema::Tensor *> ParseOutputTensorFromModel(const Model &model);
int GetDataFromTensorMap(const mindspore::schema::Tensor &schema_tensor, mindspore::lite::Tensor *dst_tensor);
KernelCallBack GetBeforeCallBack(const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters, bool is_origin);
MSKernelCallBack GetBeforeCallBack(const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters, bool is_origin);
KernelCallBack GetOriginBeforeCallBack(const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters);
MSKernelCallBack GetOriginBeforeCallBack(const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters);
KernelCallBack GetQuantBeforeCallBack(const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters);
MSKernelCallBack GetQuantBeforeCallBack(const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
const std::map<std::string, OpParameter *> &op_parameters);
KernelCallBack GetAfterCallBack(const std::map<std::string, OpParameter *> &op_parameters, bool is_origin);
MSKernelCallBack GetAfterCallBack(const std::map<std::string, OpParameter *> &op_parameters, bool is_origin);
int GetConstTensor(const std::map<std::string, mindspore::schema::Tensor *> &input_tensor_map,
mindspore::tensor::MSTensor *tensor, mindspore::lite::Tensor *new_tensor);

View File

@ -36,7 +36,6 @@
#include "include/model.h"
#include "base/base.h"
#include "abstract/dshape.h"
#include "src/lite_session.h"
#include "src/common/quant_utils.h"
namespace mindspore::lite::quant {

View File

@ -598,7 +598,6 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
}
// anf -- fb
flags_.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
MS_LOG(INFO) << "start create session";
fp32_ms_model_ = std::make_shared<mindspore::Model>();
if (fp32_ms_model_ == nullptr) {
@ -610,10 +609,6 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(ERROR) << "Build model failed.";
return RET_ERROR;
}
if (fp32_ms_model_ == nullptr) {
MS_LOG(ERROR) << "fp32_ms_model_ nullptr.";
return RET_ERROR;
}
MS_LOG(INFO) << "start to update divergence's max value";
status = DoInference(MIN_MAX);
if (status != RET_OK) {

View File

@ -27,7 +27,6 @@
#include <set>
#include "ops/primitive_c.h"
#include "schema/inner/model_generated.h"
#include "src/lite_session.h"
#include "tools/converter/quantizer/quantizer.h"
#include "include/ms_tensor.h"
#include "tools/converter/quantizer/quantize_util.h"

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,6 +14,7 @@
* limitations under the License.
*/
#define USE_DEPRECATED_API
#include "tools/converter/quantizer/parameter_tunner.h"
#include <set>
#include <functional>
@ -27,6 +28,7 @@
#include "tools/converter/export_model.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/parser_utils.h"
namespace mindspore::lite::quant {
MinMax ParameterOptimizer::GetFineTuneRange(std::vector<float> *candidate_scales) {
const int top_3 = 3;
@ -59,14 +61,14 @@ int ParameterOptimizer::CloneFuncGraph(const FuncGraphPtr &func_graph, converter
}
int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
session::LiteSession *origin_session, int origin_model_size,
std::shared_ptr<mindspore::Model> origin_model, int origin_model_size,
const InferenceParam &param, double *init_scale,
std::vector<float> *candidate_scales, bool is_run_all) {
CHECK_NULL_RETURN(flags);
CHECK_NULL_RETURN(origin_session);
CHECK_NULL_RETURN(origin_model);
CHECK_NULL_RETURN(init_scale);
CHECK_NULL_RETURN(candidate_scales);
auto origin_out_tensor = origin_session->GetOutputs();
auto origin_out_tensor = origin_model->GetOutputs();
const float threshold = 0.995f;
float best_compress_ratio = 0.0f;
float best_compress_mean_error = 0.0f;
@ -94,39 +96,30 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
MS_LOG(INFO) << "create quant session";
int weight_quant_size;
auto weight_quant_sm = CreateSessionByFuncGraph(func_graph_bak, *flags, param.thread_num, &weight_quant_size);
auto weight_quant_session = weight_quant_sm.session;
auto weight_quant_model = weight_quant_sm.model;
if (weight_quant_session == nullptr || weight_quant_model == nullptr) {
MS_LOG(WARNING) << "create session failed!";
auto weight_quant_model = std::make_shared<mindspore::Model>();
auto build_status = BuildModelByFuncGraph(weight_quant_model, func_graph_bak, *flags, &weight_quant_size);
if (build_status != kSuccess) {
MS_LOG(WARNING) << "build model failed!";
continue;
}
auto weight_quant_inputs = weight_quant_session->GetInputs();
auto weight_quant_inputs = weight_quant_model->GetInputs();
for (auto input : weight_quant_inputs) {
auto origin_tensor = origin_session->GetInputsByTensorName(input->tensor_name());
auto weight_quant_tensor_data = input->MutableData();
if (memcpy_s(weight_quant_tensor_data, input->Size(), origin_tensor->data(), origin_tensor->Size()) != EOK) {
auto origin_tensor = origin_model->GetInputByTensorName(input.Name());
auto weight_quant_tensor_data = input.MutableData();
if (memcpy_s(weight_quant_tensor_data, input.DataSize(), origin_tensor.Data().get(), origin_tensor.DataSize()) !=
EOK) {
MS_LOG(ERROR) << "memcpy data failed.";
delete weight_quant_session;
delete weight_quant_model;
return RET_ERROR;
}
}
weight_quant_session->BindThread(true);
ret = weight_quant_session->RunGraph();
weight_quant_session->BindThread(false);
if (ret != RET_OK) {
auto weight_quant_outputs = weight_quant_model->GetOutputs();
auto model_status = weight_quant_model->Predict(weight_quant_inputs, &weight_quant_outputs);
if (model_status != kSuccess) {
MS_LOG(ERROR) << "Run origin session failed.";
delete weight_quant_session;
delete weight_quant_model;
return ret;
return RET_ERROR;
}
auto weight_quant_tensor = weight_quant_session->GetOutputs();
auto cos_sim = CompareDataByCosineDistance<float>(origin_out_tensor, weight_quant_tensor);
auto mean_error = CompareData<float>(origin_out_tensor, weight_quant_tensor);
delete weight_quant_session;
delete weight_quant_model;
auto cos_sim = CompareDataByCosineDistance<float>(origin_model, weight_quant_model);
auto mean_error = CompareData<float>(origin_model, weight_quant_model);
if (!is_run_all) {
const int tolerate_round = 3;
@ -161,10 +154,10 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
return RET_OK;
}
int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags, SessionModel *sm,
int *origin_model_size) {
int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
std::shared_ptr<mindspore::Model> origin_model, int *origin_model_size) {
CHECK_NULL_RETURN(flags);
CHECK_NULL_RETURN(sm);
CHECK_NULL_RETURN(origin_model);
CHECK_NULL_RETURN(origin_model_size);
FuncGraphPtr func_graph_bak;
auto ret = CloneFuncGraph(func_graph, flags, &func_graph_bak);
@ -174,32 +167,29 @@ int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, con
}
flags->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
*origin_model_size = 0;
*sm = CreateSessionByFuncGraph(func_graph_bak, *flags, flags->commonQuantParam.thread_num, origin_model_size);
auto origin_session = sm->session;
auto origin_model = sm->model;
if (origin_session == nullptr || origin_model == nullptr) {
MS_LOG(ERROR) << "create session failed!";
auto status = BuildModelByFuncGraph(origin_model, func_graph_bak, *flags, origin_model_size);
if (status != kSuccess) {
MS_LOG(ERROR) << "build model failed!";
return RET_ERROR;
}
auto origin_inputs = origin_session->GetInputs();
auto origin_inputs = origin_model->GetInputs();
for (auto input : origin_inputs) {
if (flags->dataPreProcessParam.calibrate_size > 0) {
ret = preprocess::PreProcess(flags->dataPreProcessParam, input->tensor_name(), 0, input);
ret = preprocess::PreProcess(flags->dataPreProcessParam, input.Name(), 0, &input);
} else {
ret = GenerateRandomData(input);
ret = GenerateRandomData(&input);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << input->tensor_name() << ":"
MS_LOG(ERROR) << input.Name() << ":"
<< "Generate random data failed.";
return ret;
}
}
origin_session->BindThread(true);
ret = origin_session->RunGraph();
origin_session->BindThread(false);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run origin session failed.";
return ret;
auto origin_outputs = origin_model->GetOutputs();
auto model_status = origin_model->Predict(origin_inputs, &origin_outputs);
if (model_status != kSuccess) {
MS_LOG(ERROR) << "Run origin predict failed.";
return RET_ERROR;
}
return RET_OK;
}
@ -211,15 +201,13 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
double default_init_scale = *init_scale;
SessionModel sm;
auto origin_model = std::make_shared<mindspore::Model>();
int origin_model_size;
auto ret = OriginModelInference(func_graph, flags, &sm, &origin_model_size);
auto ret = OriginModelInference(func_graph, flags, origin_model, &origin_model_size);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Origin Model Inference failed.";
return ret;
}
auto origin_session = sm.session;
auto origin_model = sm.model;
float start_scale = 0.005f;
const int giant_rounds = 10;
@ -233,12 +221,10 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
param.thread_num = flags->commonQuantParam.thread_num;
std::cout << "==========Search with giant step==============\n";
ret = WeightQuantModelInference(func_graph, flags, origin_session, origin_model_size, param, init_scale,
ret = WeightQuantModelInference(func_graph, flags, origin_model, origin_model_size, param, init_scale,
&candidate_scales, false);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Weight quant graph inference failed.";
delete origin_session;
delete origin_model;
return ret;
}
@ -248,8 +234,6 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
if (min_max.max - min_max.min <= 0) {
MS_LOG(WARNING) << "search reach max step, init_scale return default " << *init_scale;
*init_scale = default_init_scale;
delete origin_session;
delete origin_model;
return RET_OK;
}
const int baby_step_rounds = 25;
@ -260,16 +244,12 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
param.step = step;
param.thread_num = flags->commonQuantParam.thread_num;
std::cout << "==========Search with baby step==============\n";
ret = WeightQuantModelInference(func_graph, flags, origin_session, origin_model_size, param, init_scale,
ret = WeightQuantModelInference(func_graph, flags, origin_model, origin_model_size, param, init_scale,
&candidate_scales, true);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Weight quant graph inference failed.";
delete origin_session;
delete origin_model;
return ret;
}
delete origin_session;
delete origin_model;
return RET_OK;
}
} // namespace mindspore::lite::quant
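
The GridSearchForScale flow above is a coarse-to-fine one-dimensional search over the weight-quant init scale: a "giant step" pass brackets a promising range (via GetFineTuneRange over the best candidate scales), then a "baby step" pass refines it. A standalone sketch of that idea, with Evaluate() standing in for one weight-quant build, inference, and CompareDataByCosineDistance/CompareData round; the function name and the scalar score are illustrative, not the converter's API:

#include <algorithm>
#include <functional>
#include <utility>

// Coarse-to-fine scan over a scalar parameter. Evaluate returns a score to maximize
// (e.g. a compression-ratio vs. cosine-similarity trade-off); assumes both round
// counts are positive.
std::pair<float, float> CoarseToFineSearch(const std::function<float(float)> &Evaluate,
                                           float start, float stop,
                                           int coarse_rounds, int fine_rounds) {
  float best_scale = start;
  float best_score = Evaluate(start);
  // "Giant step": wide steps to bracket the promising region.
  float coarse_step = (stop - start) / coarse_rounds;
  for (int i = 1; i <= coarse_rounds; ++i) {
    float s = start + i * coarse_step;
    float score = Evaluate(s);
    if (score > best_score) {
      best_score = score;
      best_scale = s;
    }
  }
  // "Baby step": narrow steps around the coarse optimum.
  float lo = std::max(start, best_scale - coarse_step);
  float hi = best_scale + coarse_step;
  float fine_step = (hi - lo) / fine_rounds;
  for (int i = 0; i <= fine_rounds; ++i) {
    float s = lo + i * fine_step;
    float score = Evaluate(s);
    if (score > best_score) {
      best_score = score;
      best_scale = s;
    }
  }
  return {best_scale, best_score};
}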

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,18 +16,20 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_PARAMETER_TUNNER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_PARAMETER_TUNNER_H
#include <utility>
#include <map>
#include <vector>
#include <memory>
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/export_model.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/parser/parser_utils.h"
#include "include/lite_session.h"
#include "include/model.h"
#include "base/base.h"
#include "tools/converter/converter_flags.h"
namespace mindspore::lite::quant {
struct InferenceParam {
size_t rounds;
@ -35,6 +37,7 @@ struct InferenceParam {
float step;
int thread_num;
};
class ParameterOptimizer {
public:
ParameterOptimizer() = default;
@ -49,12 +52,12 @@ class ParameterOptimizer {
int CloneFuncGraph(const FuncGraphPtr &func_graph, converter::Flags *flags, FuncGraphPtr *func_graph_bak);
int WeightQuantModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
session::LiteSession *origin_session, int origin_model_size,
std::shared_ptr<mindspore::Model> origin_model, int origin_model_size,
const InferenceParam &param, double *init_scale, std::vector<float> *candidate_scales,
bool is_run_all);
int OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags, SessionModel *sm,
int *origin_model_size);
int OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
std::shared_ptr<mindspore::Model> origin_model, int *origin_model_size);
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_PARAMETER_TUNNER_H

View File

@ -15,6 +15,7 @@
*/
#define USE_DEPRECATED_API
#include "tools/converter/quantizer/quantization_optimizer.h"
#include <memory>
#include <string>
@ -30,6 +31,7 @@
#include "tools/converter/quantizer/debug_info_manager.h"
#include "tools/converter/quantizer/parameter_tunner.h"
#include "tools/converter/quantizer/dynamic_quantizer.h"
#include "tools/anf_exporter/anf_exporter.h"
namespace mindspore::lite::quant {
void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
@ -119,22 +121,59 @@ int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config
return RET_OK;
}
int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config, const SessionModel &origin) {
auto quant = CreateSessionByFuncGraph(old_graph, *config, config->commonQuantParam.thread_num);
lite::Model *ParseLiteModel(const FuncGraphPtr &func_graph, const converter::Flags &flags) {
auto meta_graph = Export(func_graph, true, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph failed";
return static_cast<Model *>(nullptr);
}
// transform
GraphDefTransform fb_transform;
fb_transform.SetGraphDef(meta_graph);
auto status = fb_transform.Transform(flags);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed";
delete meta_graph;
return static_cast<Model *>(nullptr);
}
meta_graph->version = Version();
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
int size = builder.GetSize();
auto content = builder.GetBufferPointer();
if (content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer nullptr";
return static_cast<Model *>(nullptr);
}
return lite::Model::Import((const char *)content, size);
}
int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config,
const std::shared_ptr<mindspore::Model> &origin_model, mindspore::lite::Model *origin_lite_model) {
auto quant_model = std::make_shared<mindspore::Model>();
CHECK_NULL_RETURN(quant_model);
auto ret = BuildModelByFuncGraph(quant_model, old_graph, *config);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed";
return RET_ERROR;
}
std::map<std::string, OpParameter *> op_parameters;
FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
DebugInfoManager manager;
CHECK_NULL_RETURN(origin.model);
CHECK_NULL_RETURN(origin.session);
CHECK_NULL_RETURN(quant.model);
CHECK_NULL_RETURN(quant.session);
auto status = manager.CompareOriginWithQuant(
origin, quant, op_parameters, config->commonQuantParam.debug_info_save_path, config->dataPreProcessParam);
auto quant_lite_model = ParseLiteModel(old_graph, *config);
if (quant_lite_model == nullptr) {
MS_LOG(ERROR) << "Parse lite model failed";
return RET_ERROR;
}
auto status = manager.CompareOriginWithQuant(origin_model, quant_model, op_parameters,
config->commonQuantParam.debug_info_save_path,
config->dataPreProcessParam, origin_lite_model, quant_lite_model);
auto free_buffer = [&] {
delete origin.session;
delete origin.model;
delete quant.session;
delete quant.model;
for (auto parameter : op_parameters) {
if (parameter.second != nullptr) {
free(parameter.second);
@ -158,11 +197,23 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags
}
int status;
SessionModel origin;
std::shared_ptr<mindspore::Model> origin;
lite::Model *origin_lite_model = nullptr;
if (config->commonQuantParam.is_debug) { // Bak fp32 model for debug
converter::Flags new_flag = *config;
new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
origin = CreateSessionByFuncGraph(old_graph, new_flag, config->commonQuantParam.thread_num);
origin = std::make_shared<mindspore::Model>();
CHECK_NULL_RETURN(origin);
auto ret = BuildModelByFuncGraph(origin, old_graph, new_flag);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed";
return RET_ERROR;
}
origin_lite_model = ParseLiteModel(old_graph, *config);
if (origin_lite_model == nullptr) {
MS_LOG(ERROR) << "Parse lite model failed.";
return RET_ERROR;
}
}
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) { // Full Quantization
status = DoFullQuant(old_graph, config);
@ -184,7 +235,7 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags
}
}
if (config->commonQuantParam.is_debug) {
status = DoQuantDebug(old_graph, config, origin);
status = DoQuantDebug(old_graph, config, origin, origin_lite_model);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do quant debug failed.";
return status;

View File

@ -254,76 +254,6 @@ std::string NodePrimitiveType(const CNodePtr &cnode) {
return primitive_c->name();
}
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
int *size) {
SessionModel sm;
auto meta_graph = Export(func_graph, true, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph failed";
return sm;
}
// transform
GraphDefTransform fb_transform;
fb_transform.SetGraphDef(meta_graph);
auto status = fb_transform.Transform(flags);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed";
delete meta_graph;
return sm;
}
meta_graph->version = Version();
flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
*size = builder.GetSize();
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
if (content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer return null";
delete meta_graph;
return sm;
}
auto model = lite::Model::Import(content, *size);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
delete meta_graph;
return sm;
}
Context ctx;
ctx.thread_num_ = thread_num;
MS_ASSERT(!ctx.device_list_.empty());
ctx.device_list_.front().device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
auto session = session::LiteSession::CreateSession(&ctx);
if (session == nullptr) {
MS_LOG(ERROR) << "create session failed.";
model->Free();
delete meta_graph;
delete model;
return sm;
}
status = session->CompileGraph(model);
if (status != RET_OK) {
MS_LOG(ERROR) << "CompileGraph error";
model->Free();
delete meta_graph;
delete session;
delete model;
return sm;
}
delete meta_graph;
sm.session = session;
sm.model = model;
return sm;
}
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num) {
int size = 0;
return CreateSessionByFuncGraph(func_graph, flags, thread_num, &size);
}
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
const converter::Flags &flags) {
int size = 0;
@ -370,7 +300,31 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
auto &device_list = context->MutableDeviceInfo();
device_list.push_back(device_info);
return model->Build(content, *size, kMindIR, context);
auto ret = model->Build(content, *size, kMindIR, context);
delete meta_graph;
return ret;
}
mindspore::tensor::MSTensor *MSTensorToLiteTensor(const MSTensor &tensor) {
if (tensor.impl() == nullptr) {
MS_LOG(ERROR) << "Tensor " << tensor.Name() << " is nullptr.";
return static_cast<lite::Tensor *>(nullptr);
}
auto lite_impl = std::static_pointer_cast<LiteTensorImpl>(tensor.impl());
return static_cast<tensor::MSTensor *>(lite_impl->lite_tensor());
}
std::vector<mindspore::tensor::MSTensor *> MSTensorToLiteTensors(const std::vector<MSTensor> &srcTensors) {
std::vector<mindspore::tensor::MSTensor *> dstTensors;
dstTensors.reserve(srcTensors.size());
for (auto inTensor : srcTensors) {
auto tensor = MSTensorToLiteTensor(inTensor);
if (tensor == nullptr) {
return {};
}
dstTensors.emplace_back(tensor);
}
return dstTensors;
}
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {

View File

@ -18,8 +18,11 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
#ifndef _MSC_VER
#include <dirent.h>
#endif
#include <sys/stat.h>
#include <memory>
#include <string>
@ -73,11 +76,6 @@ constexpr int kCpuBindMode = 1;
constexpr int kAnfWeightIndex = 2;
constexpr int kAnfBiasIndex = 3;
struct SessionModel {
session::LiteSession *session{nullptr};
Model *model{nullptr};
};
QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive);
std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epochs, schema::QuantParamT *quantParam);
@ -199,17 +197,16 @@ int FixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPt
std::string NodePrimitiveType(const CNodePtr &cnode);
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num);
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
int *size);
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
const converter::Flags &flags);
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
const converter::Flags &flags, int *size);
mindspore::tensor::MSTensor *MSTensorToLiteTensor(const MSTensor &tensor);
std::vector<mindspore::tensor::MSTensor *> MSTensorToLiteTensors(const std::vector<MSTensor> &srcTensors);
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types);