!11695 [MS][LITE]huffman code support 1~8 bit && change it to internal interface

From: @jianghui58
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-29 14:12:55 +08:00 committed by Gitee
commit 5412b6ba3f
14 changed files with 379 additions and 344 deletions

View File

@ -14,7 +14,10 @@
* limitations under the License.
*/
#include <cmath>
#include <string>
#include <memory>
#include "src/dequant.h"
#include "src/huffman_decode.h"
namespace mindspore::lite {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
@ -34,13 +37,24 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
}
}
void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(unpack_int_data != nullptr);
auto quant_params = input_tensor->quantParams();
if (quant_params == nullptr) {
MS_LOG(ERROR) << "low bits quantparams is empty.";
return;
return RET_ERROR;
}
auto enable_huffman_code = input_tensor->enableHuffmanCode();
if (enable_huffman_code) {
std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end());
auto huffman_decode = std::make_unique<lite::HuffmanDecode>();
auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
return ret;
}
return RET_OK;
}
int origin_bit = quant_params->Get(0)->numBits();
if (origin_bit < 8 && origin_bit > 0) {
@ -48,6 +62,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
} else if (origin_bit < 16 && origin_bit > 8) {
UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
}
return RET_OK;
}
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,

View File

@ -31,7 +31,7 @@ class DequantUtil {
public:
static float *DequantWeight(lite::Tensor *input_tensor);
static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
TypeId data_type, bool need_restore = true);
@ -110,6 +110,21 @@ class DequantUtil {
return dequant_datas;
}
template <typename T1, typename T2>
static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
if (weight_data == nullptr || unpack_int_data == nullptr) {
MS_LOG(ERROR) << "data is nullptr";
return;
}
std::queue<bool> unpack_bit_data;
size_t count = 0;
for (int i = 0; i < pack_size; ++i) {
T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i];
bool is_last = i == pack_size - 1;
UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last);
}
}
private:
template <typename T1, typename T2>
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,

View File

@ -19,7 +19,7 @@
namespace mindspore {
namespace lite {
STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
if (decoded_data == nullptr) {
MS_LOG(ERROR) << "decoded_data is nullptr.";
return RET_ERROR;
@ -64,7 +64,7 @@ STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decod
return RET_OK;
}
STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
HuffmanNodePtr cur_node, tmp_node, new_node;
auto huffman_keys = Str2Vec(std::move(keys));
@ -121,7 +121,7 @@ STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, c
return RET_OK;
}
STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
HuffmanNodePtr cur_node = root;
bool pseudo_eof = false;
size_t pos = 0;
@ -157,7 +157,7 @@ STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string enco
return RET_OK;
}
huffman_decode::~huffman_decode() {
HuffmanDecode::~HuffmanDecode() {
for (auto &node : this->huffman_nodes_) {
delete node;
}

View File

@ -38,11 +38,11 @@ struct HuffmanNode {
};
using HuffmanNodePtr = HuffmanNode *;
class huffman_decode {
class HuffmanDecode {
public:
huffman_decode() = default;
HuffmanDecode() = default;
~huffman_decode();
~HuffmanDecode();
STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data);

View File

@ -28,7 +28,6 @@
#include "src/kernel_registry.h"
#include "src/lite_model.h"
#include "src/dequant.h"
#include "src/huffman_decode.h"
#if SUPPORT_NPU
#include "src/runtime/agent/npu/npu_manager.h"
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
@ -96,13 +95,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
int org_size = dst_tensor->Size();
return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
};
auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool {
auto data_type = src_tensor->dataType();
auto enable_huffman_code = src_tensor->enableHuffmanCode();
int pack_size = src_tensor->data()->size();
int org_size = dst_tensor->Size();
return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code;
};
auto src_category = TensorCategory(src_tensor);
if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
@ -116,21 +108,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
return RET_ERROR;
}
} else {
if (NeedHuffmanDecode()) {
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR;
}
std::string encode_str(src_tensor->data()->begin(), src_tensor->data()->end());
auto huffman_decode = std::make_unique<lite::huffman_decode>();
auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
return ret;
}
copyed_tensor_idxes_.emplace_back(tensor_index);
}
if (WeightTensorNeedCopy(model, tensor_index)) {
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
@ -138,7 +115,11 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
return RET_NULL_PTR;
}
if (NeedUnPack()) {
DequantUtil::UnPackToInt(src_tensor, dst_data);
auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "unpack to int failed.";
return RET_NULL_PTR;
}
} else {
memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
}
@ -148,9 +129,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR;
return RET_ERROR;
}
auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "unpack to int failed.";
return RET_ERROR;
}
DequantUtil::UnPackToInt(src_tensor, dst_data);
copyed_tensor_idxes_.emplace_back(tensor_index);
} else {
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));

View File

@ -227,8 +227,8 @@ function Run_Converter() {
fi
model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'`
echo ${model_name} >> "${run_converter_log_file}"
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}"
./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0' >> "${run_converter_log_file}"
./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0
if [ $? = 0 ]; then
converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
else

View File

@ -217,26 +217,14 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
const FuncGraphPtr &new_graph) {
// quant
if (config->quantType == schema::QuantType_PostTraining) {
if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) {
MS_LOG(ERROR) << "bitNum must be valid pos num.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
this->mQuantizer =
std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum));
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
} else if (config->quantType == schema::QuantType_WeightQuant) {
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
MS_LOG(ERROR) << "weight quant input param error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
config->quantWeightChannel, config->bitNum);
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, *config);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
@ -255,10 +243,15 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
return RET_OK;
}
int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) {
if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) {
auto huffman_encode = std::make_unique<lite::huffman_encode>();
auto status = huffman_encode->DoHuffmanEncode(new_graph);
int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph,
bool enableHuffmanCode) {
if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) {
if (config->bitNum < 16 && config->bitNum > 8) {
MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently.";
return RET_OK;
}
auto huffman_encode = std::make_unique<lite::HuffmanEncode>();
auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "Huffman encode failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -322,7 +315,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
return nullptr;
}
status = DoHuffmanEncode(config, new_graph);
status = DoHuffmanEncode(config, new_graph, false);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do HuffmanCode failed.";
return nullptr;

View File

@ -59,7 +59,7 @@ class AnfTransform {
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);
int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph);
int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode);
};
} // namespace lite
} // namespace mindspore

View File

@ -38,16 +38,12 @@ Flags::Flags() {
"UINT8 | DEFAULT",
"DEFAULT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16");
AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode",
"whether the weight quant model is going to use huffman code."
"true | false",
"false");
AddFlag(&Flags::trainModelIn, "trainModel",
"whether the model is going to be trained on device."
"whether the model is going to be trained on device. "
"true | false",
"false");
}
@ -107,7 +103,41 @@ int Flags::InitFmk() {
return RET_OK;
}
int Flags::InitQuantType() {
bool Flags::IsValidNum(const std::string &str, int *num) {
char *ptr;
*num = strtol(str.c_str(), &ptr, 10);
return ptr == (str.c_str() + str.size());
}
int Flags::QuantParamInputCheck() {
if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
std::cerr << "quantWeightChannel should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (this->quantWeightChannel < 0) {
std::cerr << "quantWeightChannel should be greater than or equal to zero.";
return RET_INPUT_PARAM_INVALID;
}
if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
std::cerr << "quantWeightSize should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (this->quantWeightSize < 0) {
std::cerr << "quantWeightSize should be greater than or equal to zero.";
return RET_INPUT_PARAM_INVALID;
}
if (!Flags::IsValidNum(this->bitNumIn, &this->bitNum)) {
std::cerr << "bitNum should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (this->bitNum <= 0 || this->bitNum > 16) {
std::cerr << "bitNum should be greater than zero and lesser than 16 currently.";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::InitQuantParam() {
if (this->quantTypeIn == "WeightQuant") {
this->quantType = QuantType_WeightQuant;
} else if (this->quantTypeIn == "PostTraining") {
@ -118,19 +148,9 @@ int Flags::InitQuantType() {
std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::InitHuffmanCode() {
if (this->enableHuffmanCodeIn == "true") {
this->enableHuffmanCode = true;
} else if (this->enableHuffmanCodeIn == "false") {
this->enableHuffmanCode = false;
} else {
std::cerr << "INPUT ILLEGAL: trainModel must be true|false ";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
auto ret = QuantParamInputCheck();
return ret;
}
int Flags::InitTrainModel() {
@ -218,15 +238,9 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID;
}
ret = InitQuantType();
ret = InitQuantParam();
if (ret != RET_OK) {
std::cerr << "Init quant type failed.";
return RET_INPUT_PARAM_INVALID;
}
ret = InitHuffmanCode();
if (ret != RET_OK) {
std::cerr << "Init huffman code failed.";
std::cerr << "Init quant param failed.";
return RET_INPUT_PARAM_INVALID;
}

View File

@ -49,9 +49,11 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitFmk();
int InitQuantType();
bool IsValidNum(const std::string &str, int *num);
int InitHuffmanCode();
int QuantParamInputCheck();
int InitQuantParam();
int InitTrainModel();
@ -76,12 +78,13 @@ class Flags : public virtual mindspore::lite::FlagParser {
TypeId inputDataType;
TypeId outputDataType;
// used for post-trainning-weight
std::string quantWeightSize;
std::string bitNum;
std::string quantWeightSizeIn;
int quantWeightSize;
std::string bitNumIn;
int bitNum;
std::string configFile;
std::string quantWeightChannel;
std::string enableHuffmanCodeIn;
bool enableHuffmanCode = false;
std::string quantWeightChannelIn;
int quantWeightChannel;
std::string trainModelIn;
bool trainModel = false;
};

View File

@ -18,18 +18,51 @@
#include <utility>
#include <iostream>
#include <memory>
#include <vector>
#include "securec/include/securec.h"
#include "src/param_value_lite.h"
#include "src/dequant.h"
namespace mindspore {
namespace lite {
STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value) {
if (!input_node->isa<Parameter>()) {
return RET_CONTINUE;
}
auto abstract_base = input_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(tensor_type != nullptr);
auto tensor_type_id = tensor_type->type_id();
if (tensor_type_id != kNumberTypeInt8) {
return RET_CONTINUE;
}
auto param_node = input_node->cast<ParameterPtr>();
if (param_node == nullptr) {
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!param_node->has_default()) {
MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope();
return RET_CONTINUE;
}
*param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
return RET_OK;
}
STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) {
auto cnodes = func_graph->GetOrderedCnodes();
STATUS status;
for (auto &cnode : cnodes) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
@ -41,45 +74,33 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
}
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
if (!input_node->isa<Parameter>()) {
ParamValueLitePtr param_value;
auto status = GetParamValueLitePtr(input_node, &param_value);
if (status == RET_CONTINUE) {
continue;
}
auto abstract_base = input_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
} else if (status == RET_ERROR) {
MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope();
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(tensor_type != nullptr);
auto tensor_type_id = tensor_type->type_id();
if (tensor_type_id != kNumberTypeInt8) {
continue;
}
auto param_node = input_node->cast<ParameterPtr>();
if (param_node == nullptr) {
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!param_node->has_default()) {
MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope();
continue;
}
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
size_t elem_count = param_value->tensor_shape_size();
size_t packed_size = param_value->tensor_size();
auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr());
if (raw_datas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR;
}
if (bit_num < 8 && bit_num > 0) {
auto dst_data = new (std::nothrow) int8_t[elem_count];
if (dst_data == nullptr) {
MS_LOG(ERROR) << "new int8_t[] failed";
return RET_ERROR;
}
DequantUtil::UnpackUtil<int8_t, uint8_t>(raw_datas, packed_size, bit_num, dst_data);
if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_MEMORY_FAILED;
}
}
HuffmanPriorityQueue pq;
status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
if (status != RET_OK) {
@ -97,12 +118,14 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
return status;
}
size_t ch_size = huffman_encoded_str_.length();
if (ch_size < elem_count) {
if (ch_size < packed_size) {
auto encode_data = new (std::nothrow) char[ch_size];
if (encode_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed.";
delete[] raw_datas;
return RET_MEMORY_FAILED;
}
delete[] raw_datas;
if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed.";
delete[] encode_data;
@ -118,7 +141,7 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
return RET_SUCCESS;
}
STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
MS_ASSERT(data != nullptr);
std::map<int8_t, size_t> freq_map;
@ -166,7 +189,7 @@ STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t
return RET_OK;
}
void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
if (is_left_node) {
node->code = node->parent->code + "0";
} else {
@ -185,7 +208,7 @@ void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_lef
}
}
STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
HuffmanNodePtr root = nullptr;
while (!pq->empty()) {
@ -228,7 +251,7 @@ STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
return RET_OK;
}
STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
unsigned char out_c;
string code_str;
std::map<int, string>::iterator iter;
@ -270,7 +293,7 @@ STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t
return RET_OK;
}
huffman_encode::~huffman_encode() {
HuffmanEncode::~HuffmanEncode() {
for (auto &node : this->huffman_nodes_) {
delete node;
}

View File

@ -23,9 +23,12 @@
#include <vector>
#include <queue>
#include <map>
#include <memory>
#include <fstream>
#include "src/common/log_adapter.h"
#include "src/ops/primitive_c.h"
#include "securec/include/securec.h"
#include "src/param_value_lite.h"
#include "ir/func_graph.h"
namespace mindspore {
@ -49,13 +52,15 @@ struct cmp {
};
using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>;
class huffman_encode {
class HuffmanEncode {
public:
huffman_encode() = default;
HuffmanEncode() = default;
~huffman_encode();
~HuffmanEncode();
STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph);
STATUS GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value);
STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num);
private:
std::map<int, std::string> huffman_table_;

View File

@ -25,52 +25,16 @@ using std::string;
using std::vector;
namespace mindspore::lite::quant {
bool WeightQuantizer::IsPosNum(const std::string &str) {
for (size_t i = 0; i < str.size(); i++) {
if (str.at(i) < '0' || str.at(i) > '9') {
return false;
}
if (str.at(i) == '0' && i == 0 && str.size() != 1) {
return false;
}
}
return true;
}
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
MS_ASSERT(config != nullptr);
if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) {
MS_LOG(ERROR) << "quantWeightChannel must be valid pos num.";
return RET_ERROR;
}
if (!WeightQuantizer::IsPosNum(config->quantWeightSize)) {
MS_LOG(ERROR) << "quantWeightSize must be valid pos num.";
return RET_ERROR;
}
if (!WeightQuantizer::IsPosNum(config->bitNum)) {
MS_LOG(ERROR) << "bitNum must be valid pos num.";
return RET_ERROR;
}
int bitNum = std::stoi(config->bitNum);
if (bitNum <= 0 || bitNum > 16) {
MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently.";
return RET_ERROR;
}
return RET_OK;
}
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
config_param_ = config;
}
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize,
const std::string &convWeightChannelThreshold, const std::string &bitNum)
: Quantizer(graph) {
this->config_file_ = config_file;
auto quantSize = static_cast<size_t>(std::stoull(weightSize));
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) {
this->config_file_ = config.configFile;
auto quantSize = config.quantWeightSize;
this->bit_num_ = config.bitNum;
auto convQuantWeightChannelThreshold = config.quantWeightChannel;
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
@ -222,7 +186,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
return RET_OK;
}
STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) {
MS_ASSERT(cnode != nullptr);
auto op_name = cnode->fullname_with_scope();
@ -233,110 +197,29 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
return RET_ERROR;
}
{
auto weight_i = cnode->input(2);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_i, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_OK;
}
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
<< quant_strategy_->mWeightSize;
return RET_OK;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 1);
}
auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2);
if (status != RET_OK) {
MS_LOG(ERROR) << "Process lstm weight i failed.";
return RET_ERROR;
}
status = ProcessLstmWeightByIndex(cnode, primitive_c, 3);
if (status != RET_OK) {
MS_LOG(ERROR) << "Process lstm weight h failed.";
return RET_ERROR;
}
if (cnode->inputs().size() > 4) {
status = ProcessLstmWeightByIndex(cnode, primitive_c, 4);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
MS_LOG(ERROR) << "Process lstm bias failed.";
return RET_ERROR;
}
}
{
auto weight_h = cnode->input(3);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_h, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 2);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 2);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
}
{
if (cnode->inputs().size() > 4) {
auto bias = cnode->input(4);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(bias, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 3);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 3);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
}
}
return RET_OK;
return status;
}
STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
MS_ASSERT(primitive_c != nullptr);
@ -375,6 +258,46 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
return RET_OK;
}
STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
const int &index) {
auto op_name = cnode->fullname_with_scope();
auto weight_i = cnode->input(index);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_i, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_OK;
}
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
<< quant_strategy_->mWeightSize;
return RET_OK;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, index - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, index - 1);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return RET_OK;
}
constexpr float relative_tolerance = 1e-5;
constexpr float abs_tolerance = 1e-4;
@ -510,37 +433,28 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
return RET_OK;
}
STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
// 0.2 Parse input calib files
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
if (status != RET_OK) {
MS_LOG(ERROR) << "CollectCalibInputs fail";
return RET_ERROR;
}
MS_LOG(DEBUG) << "run fp32 model";
status = RunFp32Graph(func_graph);
if (status != RET_OK) {
return RET_ERROR;
}
STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
auto cnodes = func_graph->GetOrderedCnodes();
int status = RET_OK;
for (auto &cnode : cnodes) {
auto op_type = NodePrimitiveType(cnode);
if (op_type == schema::PrimitiveType_Lstm) {
status = DoLstmQuntize(cnode);
status = DoLstmQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error";
MS_LOG(ERROR) << "DoLstmQuantize error";
return RET_ERROR;
}
} else if (op_type == schema::PrimitiveType_Gather) {
status = DoGatherQuntize(cnode);
status = DoGatherQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error";
MS_LOG(ERROR) << "DoGatherQuantize error";
return RET_ERROR;
}
}
}
return status;
}
STATUS WeightQuantizer::CheckImageCnt() {
auto image_cnt = images_.at(0).size();
if (!config_param_.input_shapes.empty()) {
if (config_param_.input_shapes.size() != image_cnt) {
@ -548,7 +462,62 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
return RET_ERROR;
}
}
return RET_OK;
}
STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
ParameterPtr *param_node, ParamValueLitePtr *param_value) {
if (!input_node->isa<Parameter>()) {
MS_LOG(WARNING) << op_name << " the second input is not parameter";
return RET_CONTINUE;
}
*param_node = input_node->cast<ParameterPtr>();
if (!(*param_node)->has_default()) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
return RET_CONTINUE;
}
*param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
if (*param_value == nullptr) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
return RET_CONTINUE;
}
if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << op_name << " the second input type is not float";
return RET_CONTINUE;
}
return RET_OK;
}
STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param_node,
const ParamValueLitePtr &param_value, const std::shared_ptr<PrimitiveC> &primitive_c) {
int status;
type_id_ = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, true);
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, true);
} else {
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;
}
if (status != RET_OK) {
MS_LOG(ERROR) << "quant filter failed.";
return RET_ERROR;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return status;
}
STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
auto cnodes = func_graph->GetOrderedCnodes();
auto image_cnt = images_.at(0).size();
int status = RET_OK;
for (auto iter = cnodes.end(); iter != cnodes.begin();) {
auto cnode = *(--iter);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@ -561,22 +530,10 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
<< " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) {
MS_LOG(WARNING) << op_name << " the second input is not parameter";
continue;
}
auto param_node = input_node->cast<ParameterPtr>();
if (!param_node->has_default()) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
continue;
}
auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if (param_value == nullptr) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
continue;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << op_name << " the second input type is not float";
ParameterPtr param_node;
ParamValueLitePtr param_value;
status = GetParamNodeAndValue(input_node, op_name, &param_node, &param_value);
if (status == RET_CONTINUE) {
continue;
}
// copy origin data in case to recover
@ -591,27 +548,9 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
}
// 1. try quant
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
type_id_ = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else {
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;
}
status = TryQuant(bit_num_t, param_node, param_value, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "quant filter fail.";
return RET_ERROR;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
MS_LOG(ERROR) << "TryQuant failed.";
return RET_ERROR;
}
// 2. evaluate the quant
@ -679,6 +618,41 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
free(origin_data);
} // if: conv and matmul
} // end loop: all cnode
return status;
}
STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
// 0.2 Parse input calib files
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
if (status != RET_OK) {
MS_LOG(ERROR) << "CollectCalibInputs failed.";
return RET_ERROR;
}
status = RunFp32Graph(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "RunFp32Graph failed.";
return RET_ERROR;
}
status = DoMixedQuantize(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoMixedQuantize failed.";
return RET_ERROR;
}
status = CheckImageCnt();
if (status != RET_OK) {
MS_LOG(ERROR) << "CheckImageCnt failed.";
return RET_ERROR;
}
status = DoQuantSearch(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantSearch failed.";
return RET_ERROR;
}
for (const auto &kv : opname_bit_) {
MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
}
@ -709,15 +683,15 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
return RET_ERROR;
}
} else if (op_type == schema::PrimitiveType_Lstm) {
auto status = DoLstmQuntize(cnode);
auto status = DoLstmQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error";
MS_LOG(ERROR) << "DoLstmQuantize error";
return RET_ERROR;
}
} else if (op_type == schema::PrimitiveType_Gather) {
auto status = DoGatherQuntize(cnode);
auto status = DoGatherQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error";
MS_LOG(ERROR) << "DoGatherQuantize error";
return RET_ERROR;
}
} else {

View File

@ -36,18 +36,18 @@
namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer {
public:
WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize,
const std::string &covWeightChannelThreshold, const std::string &bitNum);
WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config);
~WeightQuantizer();
STATUS DoQuantize(FuncGraphPtr func_graph) override;
STATUS DoConvQuantize(CNodePtr);
STATUS DoMulQuantize(CNodePtr);
STATUS DoLstmQuntize(CNodePtr cnode);
STATUS DoGatherQuntize(CNodePtr cnode);
static STATUS WeightQuantInputCheck(const converter::Flags *config);
static bool IsPosNum(const std::string &str);
STATUS DoLstmQuantize(CNodePtr cnode);
STATUS DoGatherQuantize(CNodePtr cnode);
STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
const int &index);
int quant_max_{127};
int quant_min_{-128};
@ -66,6 +66,14 @@ class WeightQuantizer : public Quantizer {
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
STATUS DoFixedQuant(FuncGraphPtr);
STATUS RunFp32Graph(FuncGraphPtr);
STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
STATUS CheckImageCnt();
STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
ParameterPtr *param_node, ParamValueLitePtr *param_value);
STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const ParamValueLitePtr &param_value,
const std::shared_ptr<PrimitiveC> &primitive_c);
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
};
} // namespace mindspore::lite::quant
#endif