forked from mindspore-Ecosystem/mindspore
!11695 [MS][LITE]huffman code support 1~8 bit && change it to internal interface
From: @jianghui58 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
5412b6ba3f
|
@ -14,7 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "src/dequant.h"
|
||||
#include "src/huffman_decode.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
|
||||
|
@ -34,13 +37,24 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
}
|
||||
}
|
||||
|
||||
void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
|
||||
int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
|
||||
MS_ASSERT(input_tensor != nullptr);
|
||||
MS_ASSERT(unpack_int_data != nullptr);
|
||||
auto quant_params = input_tensor->quantParams();
|
||||
if (quant_params == nullptr) {
|
||||
MS_LOG(ERROR) << "low bits quantparams is empty.";
|
||||
return;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto enable_huffman_code = input_tensor->enableHuffmanCode();
|
||||
if (enable_huffman_code) {
|
||||
std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end());
|
||||
auto huffman_decode = std::make_unique<lite::HuffmanDecode>();
|
||||
auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
int origin_bit = quant_params->Get(0)->numBits();
|
||||
if (origin_bit < 8 && origin_bit > 0) {
|
||||
|
@ -48,6 +62,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
|
|||
} else if (origin_bit < 16 && origin_bit > 8) {
|
||||
UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
|
||||
|
|
|
@ -31,7 +31,7 @@ class DequantUtil {
|
|||
public:
|
||||
static float *DequantWeight(lite::Tensor *input_tensor);
|
||||
|
||||
static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
|
||||
static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
|
||||
|
||||
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
|
||||
TypeId data_type, bool need_restore = true);
|
||||
|
@ -110,6 +110,21 @@ class DequantUtil {
|
|||
return dequant_datas;
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) {
|
||||
if (weight_data == nullptr || unpack_int_data == nullptr) {
|
||||
MS_LOG(ERROR) << "data is nullptr";
|
||||
return;
|
||||
}
|
||||
std::queue<bool> unpack_bit_data;
|
||||
size_t count = 0;
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
T2 pack_data = (static_cast<const T2 *>(static_cast<const void *>(weight_data)))[i];
|
||||
bool is_last = i == pack_size - 1;
|
||||
UnPackData<T1, T2>(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T1, typename T2>
|
||||
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
|
||||
STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) {
|
||||
if (decoded_data == nullptr) {
|
||||
MS_LOG(ERROR) << "decoded_data is nullptr.";
|
||||
return RET_ERROR;
|
||||
|
@ -64,7 +64,7 @@ STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decod
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
|
||||
STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) {
|
||||
HuffmanNodePtr cur_node, tmp_node, new_node;
|
||||
|
||||
auto huffman_keys = Str2Vec(std::move(keys));
|
||||
|
@ -121,7 +121,7 @@ STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, c
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
|
||||
STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) {
|
||||
HuffmanNodePtr cur_node = root;
|
||||
bool pseudo_eof = false;
|
||||
size_t pos = 0;
|
||||
|
@ -157,7 +157,7 @@ STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string enco
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
huffman_decode::~huffman_decode() {
|
||||
HuffmanDecode::~HuffmanDecode() {
|
||||
for (auto &node : this->huffman_nodes_) {
|
||||
delete node;
|
||||
}
|
||||
|
|
|
@ -38,11 +38,11 @@ struct HuffmanNode {
|
|||
};
|
||||
using HuffmanNodePtr = HuffmanNode *;
|
||||
|
||||
class huffman_decode {
|
||||
class HuffmanDecode {
|
||||
public:
|
||||
huffman_decode() = default;
|
||||
HuffmanDecode() = default;
|
||||
|
||||
~huffman_decode();
|
||||
~HuffmanDecode();
|
||||
|
||||
STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data);
|
||||
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "src/lite_model.h"
|
||||
#include "src/dequant.h"
|
||||
#include "src/huffman_decode.h"
|
||||
#if SUPPORT_NPU
|
||||
#include "src/runtime/agent/npu/npu_manager.h"
|
||||
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
|
||||
|
@ -96,13 +95,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
|
|||
int org_size = dst_tensor->Size();
|
||||
return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
|
||||
};
|
||||
auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool {
|
||||
auto data_type = src_tensor->dataType();
|
||||
auto enable_huffman_code = src_tensor->enableHuffmanCode();
|
||||
int pack_size = src_tensor->data()->size();
|
||||
int org_size = dst_tensor->Size();
|
||||
return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code;
|
||||
};
|
||||
auto src_category = TensorCategory(src_tensor);
|
||||
if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
|
||||
src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
|
||||
|
@ -116,21 +108,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
|
|||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
if (NeedHuffmanDecode()) {
|
||||
auto dst_data = dst_tensor->MutableData();
|
||||
if (dst_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data from tensor is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
std::string encode_str(src_tensor->data()->begin(), src_tensor->data()->end());
|
||||
auto huffman_decode = std::make_unique<lite::huffman_decode>();
|
||||
auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoHuffmanDecode failed.";
|
||||
return ret;
|
||||
}
|
||||
copyed_tensor_idxes_.emplace_back(tensor_index);
|
||||
}
|
||||
if (WeightTensorNeedCopy(model, tensor_index)) {
|
||||
auto dst_data = dst_tensor->MutableData();
|
||||
if (dst_data == nullptr) {
|
||||
|
@ -138,7 +115,11 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
if (NeedUnPack()) {
|
||||
DequantUtil::UnPackToInt(src_tensor, dst_data);
|
||||
auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "unpack to int failed.";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
} else {
|
||||
memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
|
||||
}
|
||||
|
@ -148,9 +129,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
|
|||
auto dst_data = dst_tensor->MutableData();
|
||||
if (dst_data == nullptr) {
|
||||
MS_LOG(ERROR) << "Data from tensor is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "unpack to int failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
DequantUtil::UnPackToInt(src_tensor, dst_data);
|
||||
copyed_tensor_idxes_.emplace_back(tensor_index);
|
||||
} else {
|
||||
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
|
||||
|
|
|
@ -227,8 +227,8 @@ function Run_Converter() {
|
|||
fi
|
||||
model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'`
|
||||
echo ${model_name} >> "${run_converter_log_file}"
|
||||
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}"
|
||||
./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true
|
||||
echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0' >> "${run_converter_log_file}"
|
||||
./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0
|
||||
if [ $? = 0 ]; then
|
||||
converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file}
|
||||
else
|
||||
|
|
|
@ -217,26 +217,14 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
|
|||
const FuncGraphPtr &new_graph) {
|
||||
// quant
|
||||
if (config->quantType == schema::QuantType_PostTraining) {
|
||||
if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) {
|
||||
MS_LOG(ERROR) << "bitNum must be valid pos num.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->mQuantizer =
|
||||
std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, std::stoi(config->bitNum));
|
||||
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, config->bitNum);
|
||||
if (mQuantizer == nullptr) {
|
||||
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (config->quantType == schema::QuantType_WeightQuant) {
|
||||
if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) {
|
||||
MS_LOG(ERROR) << "weight quant input param error";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
|
||||
config->quantWeightChannel, config->bitNum);
|
||||
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, *config);
|
||||
if (mQuantizer == nullptr) {
|
||||
MS_LOG(ERROR) << "New WeightQuantizer failed";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
|
||||
|
@ -255,10 +243,15 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) {
|
||||
if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) {
|
||||
auto huffman_encode = std::make_unique<lite::huffman_encode>();
|
||||
auto status = huffman_encode->DoHuffmanEncode(new_graph);
|
||||
int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph,
|
||||
bool enableHuffmanCode) {
|
||||
if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) {
|
||||
if (config->bitNum < 16 && config->bitNum > 8) {
|
||||
MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently.";
|
||||
return RET_OK;
|
||||
}
|
||||
auto huffman_encode = std::make_unique<lite::HuffmanEncode>();
|
||||
auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Huffman encode failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -322,7 +315,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
status = DoHuffmanEncode(config, new_graph);
|
||||
status = DoHuffmanEncode(config, new_graph, false);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do HuffmanCode failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -59,7 +59,7 @@ class AnfTransform {
|
|||
|
||||
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph);
|
||||
|
||||
int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph);
|
||||
int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,16 +38,12 @@ Flags::Flags() {
|
|||
"UINT8 | DEFAULT",
|
||||
"DEFAULT");
|
||||
AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", "");
|
||||
AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
|
||||
AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0");
|
||||
AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16");
|
||||
AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8");
|
||||
AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0");
|
||||
AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16");
|
||||
AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", "");
|
||||
AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode",
|
||||
"whether the weight quant model is going to use huffman code."
|
||||
"true | false",
|
||||
"false");
|
||||
AddFlag(&Flags::trainModelIn, "trainModel",
|
||||
"whether the model is going to be trained on device."
|
||||
"whether the model is going to be trained on device. "
|
||||
"true | false",
|
||||
"false");
|
||||
}
|
||||
|
@ -107,7 +103,41 @@ int Flags::InitFmk() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int Flags::InitQuantType() {
|
||||
bool Flags::IsValidNum(const std::string &str, int *num) {
|
||||
char *ptr;
|
||||
*num = strtol(str.c_str(), &ptr, 10);
|
||||
return ptr == (str.c_str() + str.size());
|
||||
}
|
||||
|
||||
int Flags::QuantParamInputCheck() {
|
||||
if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) {
|
||||
std::cerr << "quantWeightChannel should be a valid number.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (this->quantWeightChannel < 0) {
|
||||
std::cerr << "quantWeightChannel should be greater than or equal to zero.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) {
|
||||
std::cerr << "quantWeightSize should be a valid number.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (this->quantWeightSize < 0) {
|
||||
std::cerr << "quantWeightSize should be greater than or equal to zero.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (!Flags::IsValidNum(this->bitNumIn, &this->bitNum)) {
|
||||
std::cerr << "bitNum should be a valid number.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (this->bitNum <= 0 || this->bitNum > 16) {
|
||||
std::cerr << "bitNum should be greater than zero and lesser than 16 currently.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Flags::InitQuantParam() {
|
||||
if (this->quantTypeIn == "WeightQuant") {
|
||||
this->quantType = QuantType_WeightQuant;
|
||||
} else if (this->quantTypeIn == "PostTraining") {
|
||||
|
@ -118,19 +148,9 @@ int Flags::InitQuantType() {
|
|||
std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Flags::InitHuffmanCode() {
|
||||
if (this->enableHuffmanCodeIn == "true") {
|
||||
this->enableHuffmanCode = true;
|
||||
} else if (this->enableHuffmanCodeIn == "false") {
|
||||
this->enableHuffmanCode = false;
|
||||
} else {
|
||||
std::cerr << "INPUT ILLEGAL: trainModel must be true|false ";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
return RET_OK;
|
||||
auto ret = QuantParamInputCheck();
|
||||
return ret;
|
||||
}
|
||||
|
||||
int Flags::InitTrainModel() {
|
||||
|
@ -218,15 +238,9 @@ int Flags::Init(int argc, const char **argv) {
|
|||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
ret = InitQuantType();
|
||||
ret = InitQuantParam();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init quant type failed.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
ret = InitHuffmanCode();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init huffman code failed.";
|
||||
std::cerr << "Init quant param failed.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
|
|
|
@ -49,9 +49,11 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
|
||||
int InitFmk();
|
||||
|
||||
int InitQuantType();
|
||||
bool IsValidNum(const std::string &str, int *num);
|
||||
|
||||
int InitHuffmanCode();
|
||||
int QuantParamInputCheck();
|
||||
|
||||
int InitQuantParam();
|
||||
|
||||
int InitTrainModel();
|
||||
|
||||
|
@ -76,12 +78,13 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
TypeId inputDataType;
|
||||
TypeId outputDataType;
|
||||
// used for post-trainning-weight
|
||||
std::string quantWeightSize;
|
||||
std::string bitNum;
|
||||
std::string quantWeightSizeIn;
|
||||
int quantWeightSize;
|
||||
std::string bitNumIn;
|
||||
int bitNum;
|
||||
std::string configFile;
|
||||
std::string quantWeightChannel;
|
||||
std::string enableHuffmanCodeIn;
|
||||
bool enableHuffmanCode = false;
|
||||
std::string quantWeightChannelIn;
|
||||
int quantWeightChannel;
|
||||
std::string trainModelIn;
|
||||
bool trainModel = false;
|
||||
};
|
||||
|
|
|
@ -18,18 +18,51 @@
|
|||
|
||||
#include <utility>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "securec/include/securec.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "src/dequant.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
|
||||
STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value) {
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
auto abstract_base = input_node->abstract();
|
||||
if (abstract_base == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
if (abstract_tensor->element() == nullptr) {
|
||||
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
|
||||
MS_ASSERT(tensor_type != nullptr);
|
||||
auto tensor_type_id = tensor_type->type_id();
|
||||
if (tensor_type_id != kNumberTypeInt8) {
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
auto param_node = input_node->cast<ParameterPtr>();
|
||||
if (param_node == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!param_node->has_default()) {
|
||||
MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope();
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
*param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) {
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
STATUS status;
|
||||
for (auto &cnode : cnodes) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
|
@ -41,45 +74,33 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
auto input_node = cnode->input(i);
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
ParamValueLitePtr param_value;
|
||||
auto status = GetParamValueLitePtr(input_node, ¶m_value);
|
||||
if (status == RET_CONTINUE) {
|
||||
continue;
|
||||
}
|
||||
auto abstract_base = input_node->abstract();
|
||||
if (abstract_base == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
|
||||
} else if (status == RET_ERROR) {
|
||||
MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
if (abstract_tensor->element() == nullptr) {
|
||||
MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_type = abstract_tensor->element()->GetTypeTrack();
|
||||
MS_ASSERT(tensor_type != nullptr);
|
||||
auto tensor_type_id = tensor_type->type_id();
|
||||
if (tensor_type_id != kNumberTypeInt8) {
|
||||
continue;
|
||||
}
|
||||
auto param_node = input_node->cast<ParameterPtr>();
|
||||
if (param_node == nullptr) {
|
||||
MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!param_node->has_default()) {
|
||||
MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope();
|
||||
continue;
|
||||
}
|
||||
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
||||
size_t elem_count = param_value->tensor_shape_size();
|
||||
size_t packed_size = param_value->tensor_size();
|
||||
auto *raw_datas = static_cast<int8_t *>(param_value->tensor_addr());
|
||||
if (raw_datas == nullptr) {
|
||||
MS_LOG(ERROR) << "rawDatas is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (bit_num < 8 && bit_num > 0) {
|
||||
auto dst_data = new (std::nothrow) int8_t[elem_count];
|
||||
if (dst_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new int8_t[] failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
DequantUtil::UnpackUtil<int8_t, uint8_t>(raw_datas, packed_size, bit_num, dst_data);
|
||||
if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed.";
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
}
|
||||
HuffmanPriorityQueue pq;
|
||||
status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq);
|
||||
if (status != RET_OK) {
|
||||
|
@ -97,12 +118,14 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
|
|||
return status;
|
||||
}
|
||||
size_t ch_size = huffman_encoded_str_.length();
|
||||
if (ch_size < elem_count) {
|
||||
if (ch_size < packed_size) {
|
||||
auto encode_data = new (std::nothrow) char[ch_size];
|
||||
if (encode_data == nullptr) {
|
||||
MS_LOG(ERROR) << "new char[] failed.";
|
||||
delete[] raw_datas;
|
||||
return RET_MEMORY_FAILED;
|
||||
}
|
||||
delete[] raw_datas;
|
||||
if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s failed.";
|
||||
delete[] encode_data;
|
||||
|
@ -118,7 +141,7 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) {
|
|||
return RET_SUCCESS;
|
||||
}
|
||||
|
||||
STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
|
||||
STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) {
|
||||
MS_ASSERT(data != nullptr);
|
||||
|
||||
std::map<int8_t, size_t> freq_map;
|
||||
|
@ -166,7 +189,7 @@ STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
|
||||
void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) {
|
||||
if (is_left_node) {
|
||||
node->code = node->parent->code + "0";
|
||||
} else {
|
||||
|
@ -185,7 +208,7 @@ void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_lef
|
|||
}
|
||||
}
|
||||
|
||||
STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
|
||||
STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
|
||||
HuffmanNodePtr root = nullptr;
|
||||
|
||||
while (!pq->empty()) {
|
||||
|
@ -228,7 +251,7 @@ STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
|
||||
STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) {
|
||||
unsigned char out_c;
|
||||
string code_str;
|
||||
std::map<int, string>::iterator iter;
|
||||
|
@ -270,7 +293,7 @@ STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
huffman_encode::~huffman_encode() {
|
||||
HuffmanEncode::~HuffmanEncode() {
|
||||
for (auto &node : this->huffman_nodes_) {
|
||||
delete node;
|
||||
}
|
||||
|
|
|
@ -23,9 +23,12 @@
|
|||
#include <vector>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <fstream>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -49,13 +52,15 @@ struct cmp {
|
|||
};
|
||||
using HuffmanPriorityQueue = std::priority_queue<HuffmanNodePtr, std::vector<HuffmanNodePtr>, cmp>;
|
||||
|
||||
class huffman_encode {
|
||||
class HuffmanEncode {
|
||||
public:
|
||||
huffman_encode() = default;
|
||||
HuffmanEncode() = default;
|
||||
|
||||
~huffman_encode();
|
||||
~HuffmanEncode();
|
||||
|
||||
STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph);
|
||||
STATUS GetParamValueLitePtr(const std::shared_ptr<AnfNode> &input_node, ParamValueLitePtr *param_value);
|
||||
|
||||
STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num);
|
||||
|
||||
private:
|
||||
std::map<int, std::string> huffman_table_;
|
||||
|
|
|
@ -25,52 +25,16 @@ using std::string;
|
|||
using std::vector;
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
bool WeightQuantizer::IsPosNum(const std::string &str) {
|
||||
for (size_t i = 0; i < str.size(); i++) {
|
||||
if (str.at(i) < '0' || str.at(i) > '9') {
|
||||
return false;
|
||||
}
|
||||
if (str.at(i) == '0' && i == 0 && str.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
|
||||
MS_ASSERT(config != nullptr);
|
||||
if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) {
|
||||
MS_LOG(ERROR) << "quantWeightChannel must be valid pos num.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!WeightQuantizer::IsPosNum(config->quantWeightSize)) {
|
||||
MS_LOG(ERROR) << "quantWeightSize must be valid pos num.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!WeightQuantizer::IsPosNum(config->bitNum)) {
|
||||
MS_LOG(ERROR) << "bitNum must be valid pos num.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int bitNum = std::stoi(config->bitNum);
|
||||
if (bitNum <= 0 || bitNum > 16) {
|
||||
MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
|
||||
quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
|
||||
config_param_ = config;
|
||||
}
|
||||
|
||||
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize,
|
||||
const std::string &convWeightChannelThreshold, const std::string &bitNum)
|
||||
: Quantizer(graph) {
|
||||
this->config_file_ = config_file;
|
||||
auto quantSize = static_cast<size_t>(std::stoull(weightSize));
|
||||
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
|
||||
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
|
||||
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) {
|
||||
this->config_file_ = config.configFile;
|
||||
auto quantSize = config.quantWeightSize;
|
||||
this->bit_num_ = config.bitNum;
|
||||
auto convQuantWeightChannelThreshold = config.quantWeightChannel;
|
||||
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
|
||||
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
|
||||
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
|
||||
|
@ -222,7 +186,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
|
||||
STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
|
||||
|
@ -233,110 +197,29 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
|
|||
MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
{
|
||||
auto weight_i = cnode->input(2);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(weight_i, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
return RET_OK;
|
||||
}
|
||||
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
|
||||
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
|
||||
<< quant_strategy_->mWeightSize;
|
||||
return RET_OK;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, 1);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, 1);
|
||||
}
|
||||
|
||||
auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Process lstm weight i failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
status = ProcessLstmWeightByIndex(cnode, primitive_c, 3);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Process lstm weight h failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (cnode->inputs().size() > 4) {
|
||||
status = ProcessLstmWeightByIndex(cnode, primitive_c, 4);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
MS_LOG(ERROR) << "Process lstm bias failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
{
|
||||
auto weight_h = cnode->input(3);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(weight_h, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, 2);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, 2);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
{
|
||||
if (cnode->inputs().size() > 4) {
|
||||
auto bias = cnode->input(4);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(bias, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, 3);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, 3);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
|
||||
STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
MS_ASSERT(primitive_c != nullptr);
|
||||
|
||||
|
@ -375,6 +258,46 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
|
||||
const int &index) {
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
auto weight_i = cnode->input(index);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(weight_i, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
return RET_OK;
|
||||
}
|
||||
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
|
||||
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
|
||||
<< quant_strategy_->mWeightSize;
|
||||
return RET_OK;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, index - 1);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, index - 1);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
constexpr float relative_tolerance = 1e-5;
|
||||
constexpr float abs_tolerance = 1e-4;
|
||||
|
||||
|
@ -510,37 +433,28 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
|
||||
// 0.2 Parse input calib files
|
||||
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "CollectCalibInputs fail";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "run fp32 model";
|
||||
status = RunFp32Graph(func_graph);
|
||||
if (status != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
int status = RET_OK;
|
||||
for (auto &cnode : cnodes) {
|
||||
auto op_type = NodePrimitiveType(cnode);
|
||||
if (op_type == schema::PrimitiveType_Lstm) {
|
||||
status = DoLstmQuntize(cnode);
|
||||
status = DoLstmQuantize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoLstmQuntize error";
|
||||
MS_LOG(ERROR) << "DoLstmQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (op_type == schema::PrimitiveType_Gather) {
|
||||
status = DoGatherQuntize(cnode);
|
||||
status = DoGatherQuantize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoGatherQuntize error";
|
||||
MS_LOG(ERROR) << "DoGatherQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
STATUS WeightQuantizer::CheckImageCnt() {
|
||||
auto image_cnt = images_.at(0).size();
|
||||
if (!config_param_.input_shapes.empty()) {
|
||||
if (config_param_.input_shapes.size() != image_cnt) {
|
||||
|
@ -548,7 +462,62 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
|
||||
ParameterPtr *param_node, ParamValueLitePtr *param_value) {
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
MS_LOG(WARNING) << op_name << " the second input is not parameter";
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
*param_node = input_node->cast<ParameterPtr>();
|
||||
if (!(*param_node)->has_default()) {
|
||||
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
*param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
|
||||
if (*param_value == nullptr) {
|
||||
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(WARNING) << op_name << " the second input type is not float";
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node,
|
||||
const ParamValueLitePtr ¶m_value, const std::shared_ptr<PrimitiveC> &primitive_c) {
|
||||
int status;
|
||||
type_id_ = TypeId::kNumberTypeInt8;
|
||||
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
|
||||
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
|
||||
|
||||
if (type_id_ == TypeId::kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
|
||||
bit_num_t, true);
|
||||
} else if (type_id_ == TypeId::kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
|
||||
bit_num_t, true);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "quant filter failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
auto image_cnt = images_.at(0).size();
|
||||
int status = RET_OK;
|
||||
for (auto iter = cnodes.end(); iter != cnodes.begin();) {
|
||||
auto cnode = *(--iter);
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
|
@ -561,22 +530,10 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
|
|||
<< " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
|
||||
if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
|
||||
auto input_node = cnode->input(2);
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
MS_LOG(WARNING) << op_name << " the second input is not parameter";
|
||||
continue;
|
||||
}
|
||||
auto param_node = input_node->cast<ParameterPtr>();
|
||||
if (!param_node->has_default()) {
|
||||
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
|
||||
continue;
|
||||
}
|
||||
auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
||||
if (param_value == nullptr) {
|
||||
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
|
||||
continue;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(WARNING) << op_name << " the second input type is not float";
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
status = GetParamNodeAndValue(input_node, op_name, ¶m_node, ¶m_value);
|
||||
if (status == RET_CONTINUE) {
|
||||
continue;
|
||||
}
|
||||
// copy origin data in case to recover
|
||||
|
@ -591,27 +548,9 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
|
|||
}
|
||||
// 1. try quant
|
||||
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
|
||||
type_id_ = TypeId::kNumberTypeInt8;
|
||||
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
|
||||
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
|
||||
|
||||
if (type_id_ == TypeId::kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
|
||||
quant_min_t, bit_num_t, true);
|
||||
} else if (type_id_ == TypeId::kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
|
||||
quant_min_t, bit_num_t, true);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
status = TryQuant(bit_num_t, param_node, param_value, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "quant filter fail.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
MS_LOG(ERROR) << "TryQuant failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// 2. evaluate the quant
|
||||
|
@ -679,6 +618,41 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
|
|||
free(origin_data);
|
||||
} // if: conv and matmul
|
||||
} // end loop: all cnode
|
||||
return status;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
|
||||
// 0.2 Parse input calib files
|
||||
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "CollectCalibInputs failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
status = RunFp32Graph(func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "RunFp32Graph failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
status = DoMixedQuantize(func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoMixedQuantize failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
status = CheckImageCnt();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "CheckImageCnt failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
status = DoQuantSearch(func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoQuantSearch failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
for (const auto &kv : opname_bit_) {
|
||||
MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
|
||||
}
|
||||
|
@ -709,15 +683,15 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
} else if (op_type == schema::PrimitiveType_Lstm) {
|
||||
auto status = DoLstmQuntize(cnode);
|
||||
auto status = DoLstmQuantize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoLstmQuntize error";
|
||||
MS_LOG(ERROR) << "DoLstmQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (op_type == schema::PrimitiveType_Gather) {
|
||||
auto status = DoGatherQuntize(cnode);
|
||||
auto status = DoGatherQuantize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoGatherQuntize error";
|
||||
MS_LOG(ERROR) << "DoGatherQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -36,18 +36,18 @@
|
|||
namespace mindspore::lite::quant {
|
||||
class WeightQuantizer : public Quantizer {
|
||||
public:
|
||||
WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize,
|
||||
const std::string &covWeightChannelThreshold, const std::string &bitNum);
|
||||
WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
|
||||
WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config);
|
||||
~WeightQuantizer();
|
||||
|
||||
STATUS DoQuantize(FuncGraphPtr func_graph) override;
|
||||
STATUS DoConvQuantize(CNodePtr);
|
||||
STATUS DoMulQuantize(CNodePtr);
|
||||
STATUS DoLstmQuntize(CNodePtr cnode);
|
||||
STATUS DoGatherQuntize(CNodePtr cnode);
|
||||
static STATUS WeightQuantInputCheck(const converter::Flags *config);
|
||||
static bool IsPosNum(const std::string &str);
|
||||
STATUS DoLstmQuantize(CNodePtr cnode);
|
||||
STATUS DoGatherQuantize(CNodePtr cnode);
|
||||
|
||||
STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr<PrimitiveC> &primitive_c,
|
||||
const int &index);
|
||||
|
||||
int quant_max_{127};
|
||||
int quant_min_{-128};
|
||||
|
@ -66,6 +66,14 @@ class WeightQuantizer : public Quantizer {
|
|||
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
|
||||
STATUS DoFixedQuant(FuncGraphPtr);
|
||||
STATUS RunFp32Graph(FuncGraphPtr);
|
||||
|
||||
STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
|
||||
STATUS CheckImageCnt();
|
||||
STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
|
||||
ParameterPtr *param_node, ParamValueLitePtr *param_value);
|
||||
STATUS TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, const ParamValueLitePtr ¶m_value,
|
||||
const std::shared_ptr<PrimitiveC> &primitive_c);
|
||||
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue