!11695 [MS][LITE]huffman code support 1~8 bit && change it to internal interface

From: @jianghui58
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-01-29 14:12:55 +08:00 committed by Gitee
commit 5412b6ba3f
14 changed files with 379 additions and 344 deletions

View File

@ -14,7 +14,10 @@
* limitations under the License. * limitations under the License.
*/ */
#include <cmath> #include <cmath>
#include <string>
#include <memory>
#include "src/dequant.h" #include "src/dequant.h"
#include "src/huffman_decode.h"
namespace mindspore::lite { namespace mindspore::lite {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
@ -34,13 +37,24 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
} }
} }
void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
MS_ASSERT(input_tensor != nullptr); MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(unpack_int_data != nullptr); MS_ASSERT(unpack_int_data != nullptr);
auto quant_params = input_tensor->quantParams(); auto quant_params = input_tensor->quantParams();
if (quant_params == nullptr) { if (quant_params == nullptr) {
MS_LOG(ERROR) << "low bits quantparams is empty."; MS_LOG(ERROR) << "low bits quantparams is empty.";
return; return RET_ERROR;
}
auto enable_huffman_code = input_tensor->enableHuffmanCode();
if (enable_huffman_code) {
std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end());
auto huffman_decode = std::make_unique<lite::HuffmanDecode>();
auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
return ret;
}
return RET_OK;
} }
int origin_bit = quant_params->Get(0)->numBits(); int origin_bit = quant_params->Get(0)->numBits();
if (origin_bit < 8 && origin_bit > 0) { if (origin_bit < 8 && origin_bit > 0) {
@ -48,6 +62,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
} else if (origin_bit < 16 && origin_bit > 8) { } else if (origin_bit < 16 && origin_bit > 8) {
UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data); UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
} }
return RET_OK;
} }
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors, std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,

View File

@ -31,7 +31,7 @@ class DequantUtil {
public: public:
static float *DequantWeight(lite::Tensor *input_tensor); static float *DequantWeight(lite::Tensor *input_tensor);
static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors, static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
TypeId data_type, bool need_restore = true); TypeId data_type, bool need_restore = true);
@ -110,6 +110,21 @@ class DequantUtil {
return dequant_datas; return dequant_datas;
} }
template <typename T1, typename T2>
static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
if (weight_data == nullptr || unpack_int_data == nullptr) {
MS_LOG(ERROR) << "data is nullptr";
return;
}
std::queue<bool> unpack_bit_data;
size_t count = 0;
for (int i = 0; i < pack_size; ++i) {
T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i];
bool is_last = i == pack_size - 1;
UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last);
}
}
private: private:
template <typename T1, typename T2> template <typename T1, typename T2>
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int, static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,

View File

@ -19,7 +19,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) { STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
if (decoded_data == nullptr) { if (decoded_data == nullptr) {
MS_LOG(ERROR) << "decoded_data is nullptr."; MS_LOG(ERROR) << "decoded_data is nullptr.";
return RET_ERROR; return RET_ERROR;
@ -64,7 +64,7 @@ STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decod
return RET_OK; return RET_OK;
} }
STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) { STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
HuffmanNodePtr cur_node, tmp_node, new_node; HuffmanNodePtr cur_node, tmp_node, new_node;
auto huffman_keys = Str2Vec(std::move(keys)); auto huffman_keys = Str2Vec(std::move(keys));
@ -121,7 +121,7 @@ STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, c
return RET_OK; return RET_OK;
} }
STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) { STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
HuffmanNodePtr cur_node = root; HuffmanNodePtr cur_node = root;
bool pseudo_eof = false; bool pseudo_eof = false;
size_t pos = 0; size_t pos = 0;
@ -157,7 +157,7 @@ STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string enco
return RET_OK; return RET_OK;
} }
huffman_decode::~huffman_decode() { HuffmanDecode::~HuffmanDecode() {
for (auto &node : this->huffman_nodes_) { for (auto &node : this->huffman_nodes_) {
delete node; delete node;
} }

View File

@ -38,11 +38,11 @@ struct HuffmanNode {
}; };
using HuffmanNodePtr = HuffmanNode *; using HuffmanNodePtr = HuffmanNode *;
class huffman_decode { class HuffmanDecode {
public: public:
huffman_decode() = default; HuffmanDecode() = default;
~huffman_decode(); ~HuffmanDecode();
STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data); STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data);

View File

@ -28,7 +28,6 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/lite_model.h" #include "src/lite_model.h"
#include "src/dequant.h" #include "src/dequant.h"
#include "src/huffman_decode.h"
#if SUPPORT_NPU #if SUPPORT_NPU
#include "src/runtime/agent/npu/npu_manager.h" #include "src/runtime/agent/npu/npu_manager.h"
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
@ -96,13 +95,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
int org_size = dst_tensor->Size(); int org_size = dst_tensor->Size();
return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16); return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
}; };
auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool {
auto data_type = src_tensor->dataType();
auto enable_huffman_code = src_tensor->enableHuffmanCode();
int pack_size = src_tensor->data()->size();
int org_size = dst_tensor->Size();
return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code;
};
auto src_category = TensorCategory(src_tensor); auto src_category = TensorCategory(src_tensor);
if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) && if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
src_tensor->data() != nullptr && src_tensor->data()->size() > 0) { src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
@ -116,21 +108,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
return RET_ERROR; return RET_ERROR;
} }
} else { } else {
if (NeedHuffmanDecode()) {
auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR;
}
std::string encode_str(src_tensor->data()->begin(), src_tensor->data()->end());
auto huffman_decode = std::make_unique<lite::huffman_decode>();
auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
return ret;
}
copyed_tensor_idxes_.emplace_back(tensor_index);
}
if (WeightTensorNeedCopy(model, tensor_index)) { if (WeightTensorNeedCopy(model, tensor_index)) {
auto dst_data = dst_tensor->MutableData(); auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) { if (dst_data == nullptr) {
@ -138,7 +115,11 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
return RET_NULL_PTR; return RET_NULL_PTR;
} }
if (NeedUnPack()) { if (NeedUnPack()) {
DequantUtil::UnPackToInt(src_tensor, dst_data); auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "unpack to int failed.";
return RET_NULL_PTR;
}
} else { } else {
memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size()); memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
} }
@ -148,9 +129,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
auto dst_data = dst_tensor->MutableData(); auto dst_data = dst_tensor->MutableData();
if (dst_data == nullptr) { if (dst_data == nullptr) {
MS_LOG(ERROR) << "Data from tensor is nullptr"; MS_LOG(ERROR) << "Data from tensor is nullptr";
return RET_NULL_PTR; return RET_ERROR;
}
auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "unpack to int failed.";
return RET_ERROR;
} }
DequantUtil::UnPackToInt(src_tensor, dst_data);
copyed_tensor_idxes_.emplace_back(tensor_index); copyed_tensor_idxes_.emplace_back(tensor_index);
} else { } else {
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data())); dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));

View File

@ -227,8 +227,8 @@ function Run_Converter() {
fi fi
model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'`
echo ${model_name} >> "${run_converter_log_file}" echo ${model_name} >> "${run_converter_log_file}"
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}" echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0' >> "${run_converter_log_file}"
./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0
if [ $? = 0 ]; then if [ $? = 0 ]; then
converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
else else

View File

@ -217,26 +217,14 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
const FuncGraphPtr &new_graph) { const FuncGraphPtr &new_graph) {
// quant // quant
if (config->quantType == schema::QuantType_PostTraining) { if (config->quantType == schema::QuantType_PostTraining) {
if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) { this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum);
MS_LOG(ERROR) << "bitNum must be valid pos num.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
this->mQuantizer =
std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum));
if (mQuantizer == nullptr) { if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR; return RET_ERROR;
} }
} else if (config->quantType == schema::QuantType_WeightQuant) { } else if (config->quantType == schema::QuantType_WeightQuant) {
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, *config);
MS_LOG(ERROR) << "weight quant input param error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
config->quantWeightChannel, config->bitNum);
if (mQuantizer == nullptr) { if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed"; MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
@ -255,10 +243,15 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
return RET_OK; return RET_OK;
} }
int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) { int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph,
if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) { bool enableHuffmanCode) {
auto huffman_encode = std::make_unique<lite::huffman_encode>(); if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) {
auto status = huffman_encode->DoHuffmanEncode(new_graph); if (config->bitNum < 16 && config->bitNum > 8) {
MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently.";
return RET_OK;
}
auto huffman_encode = std::make_unique<lite::HuffmanEncode>();
auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Huffman encode failed."; MS_LOG(ERROR) << "Huffman encode failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -322,7 +315,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
return nullptr; return nullptr;
} }
status = DoHuffmanEncode(config, new_graph); status = DoHuffmanEncode(config, new_graph, false);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Do HuffmanCode failed."; MS_LOG(ERROR) << "Do HuffmanCode failed.";
return nullptr; return nullptr;

View File

@ -59,7 +59,7 @@ class AnfTransform {
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);
int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph); int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode);
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

View File

@ -38,16 +38,12 @@ Flags::Flags() {
"UINT8 | DEFAULT", "UINT8 | DEFAULT",
"DEFAULT"); "DEFAULT");
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", ""); AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode",
"whether the weight quant model is going to use huffman code."
"true | false",
"false");
AddFlag(&Flags::trainModelIn, "trainModel", AddFlag(&Flags::trainModelIn, "trainModel",
"whether the model is going to be trained on device." "whether the model is going to be trained on device. "
"true | false", "true | false",
"false"); "false");
} }
@ -107,7 +103,41 @@ int Flags::InitFmk() {
return RET_OK; return RET_OK;
} }
int Flags::InitQuantType() { bool Flags::IsValidNum(const std::string &str, int *num) {
char *ptr;
*num = strtol(str.c_str(), &ptr, 10);
return ptr == (str.c_str() + str.size());
}
int Flags::QuantParamInputCheck() {
if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
std::cerr << "quantWeightChannel should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (this->quantWeightChannel < 0) {
std::cerr << "quantWeightChannel should be greater than or equal to zero.";
return RET_INPUT_PARAM_INVALID;
}
if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
std::cerr << "quantWeightSize should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (this->quantWeightSize < 0) {
std::cerr << "quantWeightSize should be greater than or equal to zero.";
return RET_INPUT_PARAM_INVALID;
}
if (!Flags::IsValidNum(this->bitNumIn, &this->bitNum)) {
std::cerr << "bitNum should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (this->bitNum <= 0 || this->bitNum > 16) {
std::cerr << "bitNum should be greater than zero and lesser than 16 currently.";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::InitQuantParam() {
if (this->quantTypeIn == "WeightQuant") { if (this->quantTypeIn == "WeightQuant") {
this->quantType = QuantType_WeightQuant; this->quantType = QuantType_WeightQuant;
} else if (this->quantTypeIn == "PostTraining") { } else if (this->quantTypeIn == "PostTraining") {
@ -118,19 +148,9 @@ int Flags::InitQuantType() {
std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining"; std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
return RET_OK;
}
int Flags::InitHuffmanCode() { auto ret = QuantParamInputCheck();
if (this->enableHuffmanCodeIn == "true") { return ret;
this->enableHuffmanCode = true;
} else if (this->enableHuffmanCodeIn == "false") {
this->enableHuffmanCode = false;
} else {
std::cerr << "INPUT ILLEGAL: trainModel must be true|false ";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
} }
int Flags::InitTrainModel() { int Flags::InitTrainModel() {
@ -218,15 +238,9 @@ int Flags::Init(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
ret = InitQuantType(); ret = InitQuantParam();
if (ret != RET_OK) { if (ret != RET_OK) {
std::cerr << "Init quant type failed."; std::cerr << "Init quant param failed.";
return RET_INPUT_PARAM_INVALID;
}
ret = InitHuffmanCode();
if (ret != RET_OK) {
std::cerr << "Init huffman code failed.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }

View File

@ -49,9 +49,11 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitFmk(); int InitFmk();
int InitQuantType(); bool IsValidNum(const std::string &str, int *num);
int InitHuffmanCode(); int QuantParamInputCheck();
int InitQuantParam();
int InitTrainModel(); int InitTrainModel();
@ -76,12 +78,13 @@ class Flags : public virtual mindspore::lite::FlagParser {
TypeId inputDataType; TypeId inputDataType;
TypeId outputDataType; TypeId outputDataType;
// used for post-trainning-weight // used for post-trainning-weight
std::string quantWeightSize; std::string quantWeightSizeIn;
std::string bitNum; int quantWeightSize;
std::string bitNumIn;
int bitNum;
std::string configFile; std::string configFile;
std::string quantWeightChannel; std::string quantWeightChannelIn;
std::string enableHuffmanCodeIn; int quantWeightChannel;
bool enableHuffmanCode = false;
std::string trainModelIn; std::string trainModelIn;
bool trainModel = false; bool trainModel = false;
}; };

View File

@ -18,18 +18,51 @@
#include <utility> #include <utility>
#include <iostream> #include <iostream>
#include <memory>
#include <vector>
#include "securec/include/securec.h" #include "src/dequant.h"
#include "src/param_value_lite.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value) {
if (!input_node->isa<Parameter>()) {
return RET_CONTINUE;
}
auto abstract_base = input_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(tensor_type != nullptr);
auto tensor_type_id = tensor_type->type_id();
if (tensor_type_id != kNumberTypeInt8) {
return RET_CONTINUE;
}
auto param_node = input_node->cast<ParameterPtr>();
if (param_node == nullptr) {
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!param_node->has_default()) {
MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope();
return RET_CONTINUE;
}
*param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
return RET_OK;
}
STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) {
auto cnodes = func_graph->GetOrderedCnodes(); auto cnodes = func_graph->GetOrderedCnodes();
STATUS status;
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) { if (primitive_c == nullptr) {
@ -41,45 +74,33 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
} }
for (size_t i = 1; i < cnode->inputs().size(); i++) { for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i); auto input_node = cnode->input(i);
if (!input_node->isa<Parameter>()) { ParamValueLitePtr param_value;
auto status = GetParamValueLitePtr(input_node, &param_value);
if (status == RET_CONTINUE) {
continue; continue;
} } else if (status == RET_ERROR) {
auto abstract_base = input_node->abstract(); MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR; return RET_ERROR;
} }
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
if (abstract_tensor->element() == nullptr) {
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
MS_ASSERT(tensor_type != nullptr);
auto tensor_type_id = tensor_type->type_id();
if (tensor_type_id != kNumberTypeInt8) {
continue;
}
auto param_node = input_node->cast<ParameterPtr>();
if (param_node == nullptr) {
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!param_node->has_default()) {
MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope();
continue;
}
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
size_t elem_count = param_value->tensor_shape_size(); size_t elem_count = param_value->tensor_shape_size();
size_t packed_size = param_value->tensor_size();
auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr()); auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr());
if (raw_datas == nullptr) { if (raw_datas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr"; MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR; return RET_ERROR;
} }
if (bit_num < 8 && bit_num > 0) {
auto dst_data = new (std::nothrow) int8_t[elem_count];
if (dst_data == nullptr) {
MS_LOG(ERROR) << "new int8_t[] failed";
return RET_ERROR;
}
DequantUtil::UnpackUtil<int8_t, uint8_t>(raw_datas, packed_size, bit_num, dst_data);
if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed.";
return RET_MEMORY_FAILED;
}
}
HuffmanPriorityQueue pq; HuffmanPriorityQueue pq;
status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
if (status != RET_OK) { if (status != RET_OK) {
@ -97,12 +118,14 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
return status; return status;
} }
size_t ch_size = huffman_encoded_str_.length(); size_t ch_size = huffman_encoded_str_.length();
if (ch_size < elem_count) { if (ch_size < packed_size) {
auto encode_data = new (std::nothrow) char[ch_size]; auto encode_data = new (std::nothrow) char[ch_size];
if (encode_data == nullptr) { if (encode_data == nullptr) {
MS_LOG(ERROR) << "new char[] failed."; MS_LOG(ERROR) << "new char[] failed.";
delete[] raw_datas;
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
delete[] raw_datas;
if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
MS_LOG(ERROR) << "memcpy_s failed."; MS_LOG(ERROR) << "memcpy_s failed.";
delete[] encode_data; delete[] encode_data;
@ -118,7 +141,7 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
return RET_SUCCESS; return RET_SUCCESS;
} }
STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) { STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
MS_ASSERT(data != nullptr); MS_ASSERT(data != nullptr);
std::map<int8_t, size_t> freq_map; std::map<int8_t, size_t> freq_map;
@ -166,7 +189,7 @@ STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t
return RET_OK; return RET_OK;
} }
void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) { void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
if (is_left_node) { if (is_left_node) {
node->code = node->parent->code + "0"; node->code = node->parent->code + "0";
} else { } else {
@ -185,7 +208,7 @@ void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_lef
} }
} }
STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
HuffmanNodePtr root = nullptr; HuffmanNodePtr root = nullptr;
while (!pq->empty()) { while (!pq->empty()) {
@ -228,7 +251,7 @@ STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
return RET_OK; return RET_OK;
} }
STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) { STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
unsigned char out_c; unsigned char out_c;
string code_str; string code_str;
std::map<int, string>::iterator iter; std::map<int, string>::iterator iter;
@ -270,7 +293,7 @@ STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t
return RET_OK; return RET_OK;
} }
huffman_encode::~huffman_encode() { HuffmanEncode::~HuffmanEncode() {
for (auto &node : this->huffman_nodes_) { for (auto &node : this->huffman_nodes_) {
delete node; delete node;
} }

View File

@ -23,9 +23,12 @@
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <map> #include <map>
#include <memory>
#include <fstream> #include <fstream>
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
#include "securec/include/securec.h"
#include "src/param_value_lite.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
namespace mindspore { namespace mindspore {
@ -49,13 +52,15 @@ struct cmp {
}; };
using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>; using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>;
class huffman_encode { class HuffmanEncode {
public: public:
huffman_encode() = default; HuffmanEncode() = default;
~huffman_encode(); ~HuffmanEncode();
STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph); STATUS GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value);
STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num);
private: private:
std::map<int, std::string> huffman_table_; std::map<int, std::string> huffman_table_;

View File

@ -25,52 +25,16 @@ using std::string;
using std::vector; using std::vector;
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
bool WeightQuantizer::IsPosNum(const std::string &str) {
for (size_t i = 0; i < str.size(); i++) {
if (str.at(i) < '0' || str.at(i) > '9') {
return false;
}
if (str.at(i) == '0' && i == 0 && str.size() != 1) {
return false;
}
}
return true;
}
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
MS_ASSERT(config != nullptr);
if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) {
MS_LOG(ERROR) << "quantWeightChannel must be valid pos num.";
return RET_ERROR;
}
if (!WeightQuantizer::IsPosNum(config->quantWeightSize)) {
MS_LOG(ERROR) << "quantWeightSize must be valid pos num.";
return RET_ERROR;
}
if (!WeightQuantizer::IsPosNum(config->bitNum)) {
MS_LOG(ERROR) << "bitNum must be valid pos num.";
return RET_ERROR;
}
int bitNum = std::stoi(config->bitNum);
if (bitNum <= 0 || bitNum > 16) {
MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently.";
return RET_ERROR;
}
return RET_OK;
}
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) { WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
quant_strategy_ = std::make_unique<QuantStrategy>(0, 0); quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
config_param_ = config; config_param_ = config;
} }
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize, WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) {
const std::string &convWeightChannelThreshold, const std::string &bitNum) this->config_file_ = config.configFile;
: Quantizer(graph) { auto quantSize = config.quantWeightSize;
this->config_file_ = config_file; this->bit_num_ = config.bitNum;
auto quantSize = static_cast<size_t>(std::stoull(weightSize)); auto convQuantWeightChannelThreshold = config.quantWeightChannel;
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold); quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
@ -222,7 +186,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
return RET_OK; return RET_OK;
} }
STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) {
MS_ASSERT(cnode != nullptr); MS_ASSERT(cnode != nullptr);
auto op_name = cnode->fullname_with_scope(); auto op_name = cnode->fullname_with_scope();
@ -233,110 +197,29 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size(); MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
return RET_ERROR; return RET_ERROR;
} }
{
auto weight_i = cnode->input(2); auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2);
ParameterPtr param_node; if (status != RET_OK) {
ParamValueLitePtr param_value; MS_LOG(ERROR) << "Process lstm weight i failed.";
GetLiteParameter(weight_i, &param_node, &param_value); return RET_ERROR;
if (param_node == nullptr || param_value == nullptr) { }
MS_LOG(ERROR) << "GetLiteParameter error"; status = ProcessLstmWeightByIndex(cnode, primitive_c, 3);
return RET_ERROR; if (status != RET_OK) {
} MS_LOG(ERROR) << "Process lstm weight h failed.";
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { return RET_ERROR;
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; }
return RET_OK; if (cnode->inputs().size() > 4) {
} status = ProcessLstmWeightByIndex(cnode, primitive_c, 4);
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
<< quant_strategy_->mWeightSize;
return RET_OK;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 1);
}
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status; MS_LOG(ERROR) << "Process lstm bias failed.";
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR; return RET_ERROR;
} }
} }
{
auto weight_h = cnode->input(3); return status;
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_h, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 2);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 2);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
}
{
if (cnode->inputs().size() > 4) {
auto bias = cnode->input(4);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(bias, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 3);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, 3);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
}
}
return RET_OK;
} }
STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) { STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
MS_ASSERT(primitive_c != nullptr); MS_ASSERT(primitive_c != nullptr);
@ -375,6 +258,46 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
return RET_OK; return RET_OK;
} }
STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
const int &index) {
auto op_name = cnode->fullname_with_scope();
auto weight_i = cnode->input(index);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_i, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_OK;
}
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
<< quant_strategy_->mWeightSize;
return RET_OK;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, index - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, index - 1);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return RET_OK;
}
constexpr float relative_tolerance = 1e-5; constexpr float relative_tolerance = 1e-5;
constexpr float abs_tolerance = 1e-4; constexpr float abs_tolerance = 1e-4;
@ -510,37 +433,28 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
return RET_OK; return RET_OK;
} }
STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
// 0.2 Parse input calib files
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
if (status != RET_OK) {
MS_LOG(ERROR) << "CollectCalibInputs fail";
return RET_ERROR;
}
MS_LOG(DEBUG) << "run fp32 model";
status = RunFp32Graph(func_graph);
if (status != RET_OK) {
return RET_ERROR;
}
auto cnodes = func_graph->GetOrderedCnodes(); auto cnodes = func_graph->GetOrderedCnodes();
int status = RET_OK;
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto op_type = NodePrimitiveType(cnode); auto op_type = NodePrimitiveType(cnode);
if (op_type == schema::PrimitiveType_Lstm) { if (op_type == schema::PrimitiveType_Lstm) {
status = DoLstmQuntize(cnode); status = DoLstmQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error"; MS_LOG(ERROR) << "DoLstmQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} else if (op_type == schema::PrimitiveType_Gather) { } else if (op_type == schema::PrimitiveType_Gather) {
status = DoGatherQuntize(cnode); status = DoGatherQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error"; MS_LOG(ERROR) << "DoGatherQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} }
} }
return status;
}
STATUS WeightQuantizer::CheckImageCnt() {
auto image_cnt = images_.at(0).size(); auto image_cnt = images_.at(0).size();
if (!config_param_.input_shapes.empty()) { if (!config_param_.input_shapes.empty()) {
if (config_param_.input_shapes.size() != image_cnt) { if (config_param_.input_shapes.size() != image_cnt) {
@ -548,7 +462,62 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
return RET_ERROR; return RET_ERROR;
} }
} }
return RET_OK;
}
STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
ParameterPtr *param_node, ParamValueLitePtr *param_value) {
if (!input_node->isa<Parameter>()) {
MS_LOG(WARNING) << op_name << " the second input is not parameter";
return RET_CONTINUE;
}
*param_node = input_node->cast<ParameterPtr>();
if (!(*param_node)->has_default()) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
return RET_CONTINUE;
}
*param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
if (*param_value == nullptr) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
return RET_CONTINUE;
}
if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << op_name << " the second input type is not float";
return RET_CONTINUE;
}
return RET_OK;
}
STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param_node,
const ParamValueLitePtr &param_value, const std::shared_ptr<PrimitiveC> &primitive_c) {
int status;
type_id_ = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, true);
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, true);
} else {
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;
}
if (status != RET_OK) {
MS_LOG(ERROR) << "quant filter failed.";
return RET_ERROR;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return status;
}
STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
auto cnodes = func_graph->GetOrderedCnodes();
auto image_cnt = images_.at(0).size();
int status = RET_OK;
for (auto iter = cnodes.end(); iter != cnodes.begin();) { for (auto iter = cnodes.end(); iter != cnodes.begin();) {
auto cnode = *(--iter); auto cnode = *(--iter);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@ -561,22 +530,10 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
<< " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type()); << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
auto input_node = cnode->input(2); auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) { ParameterPtr param_node;
MS_LOG(WARNING) << op_name << " the second input is not parameter"; ParamValueLitePtr param_value;
continue; status = GetParamNodeAndValue(input_node, op_name, &param_node, &param_value);
} if (status == RET_CONTINUE) {
auto param_node = input_node->cast<ParameterPtr>();
if (!param_node->has_default()) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
continue;
}
auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if (param_value == nullptr) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
continue;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << op_name << " the second input type is not float";
continue; continue;
} }
// copy origin data in case to recover // copy origin data in case to recover
@ -591,27 +548,9 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
} }
// 1. try quant // 1. try quant
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
type_id_ = TypeId::kNumberTypeInt8; status = TryQuant(bit_num_t, param_node, param_value, primitive_c);
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else {
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;
}
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "quant filter fail."; MS_LOG(ERROR) << "TryQuant failed.";
return RET_ERROR;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR; return RET_ERROR;
} }
// 2. evaluate the quant // 2. evaluate the quant
@ -679,6 +618,41 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
free(origin_data); free(origin_data);
} // if: conv and matmul } // if: conv and matmul
} // end loop: all cnode } // end loop: all cnode
return status;
}
STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
// 0.2 Parse input calib files
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
if (status != RET_OK) {
MS_LOG(ERROR) << "CollectCalibInputs failed.";
return RET_ERROR;
}
status = RunFp32Graph(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "RunFp32Graph failed.";
return RET_ERROR;
}
status = DoMixedQuantize(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoMixedQuantize failed.";
return RET_ERROR;
}
status = CheckImageCnt();
if (status != RET_OK) {
MS_LOG(ERROR) << "CheckImageCnt failed.";
return RET_ERROR;
}
status = DoQuantSearch(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantSearch failed.";
return RET_ERROR;
}
for (const auto &kv : opname_bit_) { for (const auto &kv : opname_bit_) {
MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second; MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
} }
@ -709,15 +683,15 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
return RET_ERROR; return RET_ERROR;
} }
} else if (op_type == schema::PrimitiveType_Lstm) { } else if (op_type == schema::PrimitiveType_Lstm) {
auto status = DoLstmQuntize(cnode); auto status = DoLstmQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error"; MS_LOG(ERROR) << "DoLstmQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} else if (op_type == schema::PrimitiveType_Gather) { } else if (op_type == schema::PrimitiveType_Gather) {
auto status = DoGatherQuntize(cnode); auto status = DoGatherQuantize(cnode);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error"; MS_LOG(ERROR) << "DoGatherQuantize error";
return RET_ERROR; return RET_ERROR;
} }
} else { } else {

View File

@ -36,18 +36,18 @@
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer { class WeightQuantizer : public Quantizer {
public: public:
WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize, WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
const std::string &covWeightChannelThreshold, const std::string &bitNum);
WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config); WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config);
~WeightQuantizer(); ~WeightQuantizer();
STATUS DoQuantize(FuncGraphPtr func_graph) override; STATUS DoQuantize(FuncGraphPtr func_graph) override;
STATUS DoConvQuantize(CNodePtr); STATUS DoConvQuantize(CNodePtr);
STATUS DoMulQuantize(CNodePtr); STATUS DoMulQuantize(CNodePtr);
STATUS DoLstmQuntize(CNodePtr cnode); STATUS DoLstmQuantize(CNodePtr cnode);
STATUS DoGatherQuntize(CNodePtr cnode); STATUS DoGatherQuantize(CNodePtr cnode);
static STATUS WeightQuantInputCheck(const converter::Flags *config);
static bool IsPosNum(const std::string &str); STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
const int &index);
int quant_max_{127}; int quant_max_{127};
int quant_min_{-128}; int quant_min_{-128};
@ -66,6 +66,14 @@ class WeightQuantizer : public Quantizer {
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c); STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
STATUS DoFixedQuant(FuncGraphPtr); STATUS DoFixedQuant(FuncGraphPtr);
STATUS RunFp32Graph(FuncGraphPtr); STATUS RunFp32Graph(FuncGraphPtr);
STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
STATUS CheckImageCnt();
STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
ParameterPtr *param_node, ParamValueLitePtr *param_value);
STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const ParamValueLitePtr &param_value,
const std::shared_ptr<PrimitiveC> &primitive_c);
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
}; };
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant
#endif #endif