add mixed quantization

lz 2021-08-16 16:10:57 +08:00
parent 06e86e34de
commit a64eb7cd69
16 changed files with 944 additions and 37 deletions

View File

@ -41,7 +41,8 @@ table QuantParam {
enum WeightQunatCompressType: int {
NONE,
INDEXING,
SPARSE
SPARSE,
FSE
}
table Tensor {

View File

@ -42,13 +42,14 @@
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "tools/converter/quantizer/fse_encoder.h"
using mindspore::ops::PrimitiveC;
namespace mindspore::lite {
namespace {
constexpr int kIndexOfValueInputOfGetTupleItem = 2;
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
@ -116,7 +117,17 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
auto repetition_packed = false;
MS_LOG(DEBUG) << dst_node->name;
if (dst_node->quantType == schema::QuantType_QUANT_WEIGHT) {
if (bit_num <= kBitNum8) {
if (bit_num == 0) {
if (tensor_input->data.empty() || tensor_input->dims.size() <= 1) {
return RET_OK;
}
quant::FSEEncoder fse_encoder;
if (dst_node->primitive->value.type == PrimitiveType_GRU) {
fse_encoder.Compress(tensor_input);
} else {
fse_encoder.Compress(tensor_input);
}
} else if (bit_num <= kBitNum8) {
repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
} else {
repetition_packed = PackRepetition<int16_t>(bit_num, tensor_input);

View File

@ -286,7 +286,7 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
m_quantizer_->flags = *config;
auto status = m_quantizer_->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}

View File

@ -921,7 +921,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") {
return RET_OK;
}
MS_LOG(INFO) << "parse op : " << op_type;
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) {
@ -1055,6 +1054,21 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
return RET_ERROR;
}
output_nodes.push_back(anf_node);
// When the output node is an 'Identity' or 'StopGradient', trace back to get the original node name.
if (pair.second->op() == "Identity" || pair.second->op() == "StopGradient") {
auto tmp_node = pair.second;
bool found_input = true;
while (tmp_node->name().empty() && (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient")) {
auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(tmp_node->input(0));
if (tf_root_graph_nodes_.find(flatten_input_name) != tf_root_graph_nodes_.end()) {
tmp_node = tf_root_graph_nodes_.at(flatten_input_name);
} else {
found_input = false;
break;
}
}
origin_name = found_input ? tmp_node->name() : origin_name;
}
graph_output_names_.push_back(origin_name);
}
}

View File

@ -12,6 +12,10 @@ file(GLOB QUANTIZER
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/huffman_encode.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_decoder.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_bit_stream.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_encoder.cc
${CMAKE_CURRENT_SOURCE_DIR}/fix_bit_weight_quantizer.cc
)
set_property(SOURCE ${QUANTIZER} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(quantizer_mid OBJECT ${QUANTIZER})

View File

@ -0,0 +1,178 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include <cmath>
namespace mindspore::lite::quant {
// The error is currently measured per channel.
// It could be measured per layer, but that would be less accurate.
// The `preferred` dim should point to the output-channels dimension.
float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
float scale) {
int numel = 1;
for (int i = 0; i < dims; i++) {
numel *= shape[i];
}
int bucket_count = shape[preferred_dim];
std::vector<float> norms2(bucket_count);
std::vector<float> dnorms2(bucket_count);
for (int i = 0; i < bucket_count; i++) {
norms2[i] = 0.0;
dnorms2[i] = 0.0;
}
double average_dequant = 0;
double average_raw = 0;
std::vector<float> dequant_datas(numel);
int bucket_volume = 1;
for (int i = preferred_dim; i < dims; i++) {
bucket_volume *= shape[i];
}
for (int i = 0; i < numel; i++) {
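// Round-to-nearest dequantization: e.g. with scale = 0.1 a weight of 0.37 becomes 0.1 * floorf(0.37 / 0.1 + 0.5) = 0.4.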
float dequant = scale * (floorf(weights[i] / scale + 0.5));
dequant_datas[i] = dequant;
average_raw += weights[i];
average_dequant += dequant;
}
// mean
average_dequant = average_dequant / numel;
average_raw = average_raw / numel;
// std
double variance_dequant = 0;
double variance_raw = 0;
for (int i = 0; i < numel; i++) {
variance_dequant += std::pow(dequant_datas[i] - average_dequant, 2);
variance_raw += std::pow(weights[i] - average_raw, 2);
}
variance_dequant = std::sqrt(variance_dequant / numel);
variance_raw = std::sqrt(variance_raw / numel);
var_corr = variance_raw / variance_dequant;
mean_corr = average_raw - average_dequant * var_corr;
for (int i = 0; i < numel; i++) {
int bucket = (i / bucket_volume) % bucket_count;
norms2[bucket] += weights[i] * weights[i];
float dequant = var_corr * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr;
float d = weights[i] - dequant;
dnorms2[bucket] += d * d;
}
int c = 0;
float t = 1e-10;
for (int i = 0; i < bucket_count; i++) {
if (norms2[i] < 1.0e-10) continue;
c += 1;
t += sqrtf(dnorms2[i] / norms2[i]);
}
return t / (c + 1e-7);
}
MinMax FixBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
MinMax min_max = {INFINITY, -INFINITY};
for (int i = 0; i < arrc; i++)
if (arr[i] > min_max.max)
min_max.max = arr[i];
else if (arr[i] < min_max.min)
min_max.min = arr[i];
return min_max;
}
BinarySearchResult FixBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
int preferred_dim, int max_iters,
float target_err, float rel_tol) {
int element_num = 1;
for (int i = 0; i < dims; i++) {
element_num *= shape[i];
}
MinMax mm = GetMinMax(weights, element_num);
if (mm.max < mm.min + 1.0e-5) {
return {0, static_cast<float>(std::fabs(mm.max) + 1.0e-5)};
}
// start a binary search
float curr_scale = (mm.max - mm.min) * target_err;
float right_hs_dx = curr_scale * 2.0;
while (MeasureQuantizationError(weights, shape, dims, preferred_dim, right_hs_dx) < target_err) {
right_hs_dx *= 2.0;
}
float left_hs_dx = curr_scale / 2.0;
while (MeasureQuantizationError(weights, shape, dims, preferred_dim, left_hs_dx) > target_err) {
left_hs_dx /= 2.0;
}
int iter_count = 0;
BinarySearchResult res = {0, curr_scale};
while (true) {
float curr_err = MeasureQuantizationError(weights, shape, dims, preferred_dim, res.scale);
if (std::fabs(curr_err - target_err) / target_err < rel_tol) {
return res;
}
if (iter_count > max_iters) {
res.status = 1;
return res;
}
if (curr_err > target_err)
right_hs_dx = res.scale;
else
left_hs_dx = res.scale;
res.scale = (left_hs_dx + right_hs_dx) / 2.0;
iter_count += 1;
}
}
int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
std::vector<schema::QuantParamT> *quant_params,
std::vector<int16_t> *quant_datas) {
int weight_count = 1;
int dims = shape.size();
int input_shape[4] = {0, 0, 0, 0};
for (int i = 0; i < dims; i++) {
weight_count *= shape[i];
input_shape[i] = shape[i];
}
BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters,
target_relative_err, target_search_tolerance);
if (br.status != 0) {
MS_LOG(ERROR) << "reached_max_iters";
return RET_ERROR;
}
schema::QuantParamT quant_param;
int qr = QuantizeByScale(weights, weight_count, br.scale, &quant_param, quant_datas);
if (qr != 0) {
MS_LOG(ERROR) << "quant failed.";
return RET_ERROR;
}
// It is used to calculate the Shannon entropy.
quant_params->push_back(quant_param);
return RET_OK;
}
int FixBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
for (int i = 0; i < weightsc; i++) {
auto q = static_cast<int>(floorf(weights[i] / scale + 0.5));
quant_datas->at(i) = q;
}
quant_params->meanCorr = mean_corr;
quant_params->varCorr = var_corr;
quant_params->scale = scale;
quant_params->zeroPoint = 0;
quant_params->numBits = 0;
quant_params->inited = true;
return RET_OK;
}
} // namespace mindspore::lite::quant

View File

@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_FIX_BIT_WEIGHT_QUANTIZER_H
#define LITE_FIX_BIT_WEIGHT_QUANTIZER_H
#include <cstdint>
#include <vector>
#include <cmath>
#include "schema/inner/model_generated.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
namespace mindspore::lite::quant {
typedef struct {
int status;
float scale;
} BinarySearchResult;
typedef struct {
float min;
float max;
} MinMax;
class FixBitWeightQuantizer {
public:
explicit FixBitWeightQuantizer(float target_relative_err = 0.01, float target_search_tolerance = 0.01,
int max_search_iters = 100)
: target_relative_err(target_relative_err),
target_search_tolerance(target_search_tolerance),
max_search_iters(max_search_iters) {}
~FixBitWeightQuantizer() = default;
int DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
std::vector<schema::QuantParamT> *quant_params, std::vector<int16_t> *quant_datas);
private:
// The error is currently measured per channel.
// It could be measured per layer, but that would be less accurate.
// The `preferred` dim should point to the output-channels dimension.
float MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim, float scale);
MinMax GetMinMax(const float *arr, int arrc);
int QuantizeByScale(const float *weights, int weightsc, float scale, schema::QuantParamT *quant_params,
std::vector<int16_t> *quant_datas);
BinarySearchResult BinarySearchForQuantizationScale(float *weights, int *shape, int dims, int preferred_dim,
int max_iters, float target_err, float rel_tol);
private:
float var_corr{1};
float mean_corr{0};
float target_relative_err;
float target_search_tolerance;
int max_search_iters;
};
} // namespace mindspore::lite::quant
#endif // LITE_FIX_BIT_WEIGHT_QUANTIZER_H
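A minimal usage sketch of this class (the weight values, shape, and preferred dimension below are hypothetical; inside the converter the call is made from QuantFilter with the tensor's real data):
#include <vector>
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
using mindspore::lite::quant::FixBitWeightQuantizer;
int QuantizeOneWeightExample() {
  // Hypothetical weight tensor: 4 output channels with 3 values each.
  std::vector<float> weights = {0.12f, -0.40f, 0.33f, 0.05f, 0.91f, -0.27f,
                                0.64f, -0.08f, 0.19f, -0.55f, 0.77f, 0.02f};
  std::vector<int64_t> shape = {4, 3};
  std::vector<mindspore::schema::QuantParamT> quant_params;
  std::vector<int16_t> quant_data(weights.size());
  // Ask for roughly 2% relative quantization error; the scale is found by binary search.
  FixBitWeightQuantizer quantizer(0.02);
  return quantizer.DoQuantization(weights.data(), shape, /*preferred_dim=*/0, &quant_params, &quant_data);
}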

View File

@ -0,0 +1,103 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/fse_bit_stream.h"
#include <memory.h>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
namespace mindspore::lite::quant {
int BitStream::Create(int bit_capacity) {
chunk_count_ = (bit_capacity >> 6);
chunks_ = static_cast<uint64_t *>(malloc(chunk_count_ * sizeof(uint64_t)));
if (chunks_ == nullptr) {
MS_LOG(ERROR) << "malloc memory failed.";
return RET_ERROR;
}
memset(chunks_, 0, chunk_count_ * sizeof(uint64_t));
return RET_OK;
}
void BitStream::Free() {
curr_chunk_index_ = -1;
curr_chunk_ = 0;
curr_bit_count_ = 0;
chunk_count_ = 0;
if (chunks_ != nullptr) {
free(chunks_);
chunks_ = nullptr;
}
}
void BitStream::Empty() {
curr_chunk_index_ = -1;
curr_chunk_ = 0;
curr_bit_count_ = 0;
for (int i = 0; i < chunk_count_; i++) {
chunks_[i] = 0;
}
}
int64_t BitStream::Pop(uint8_t bit_count) {
int64_t right = curr_chunk_ >> (64 - curr_bit_count_);
int64_t res = right & ((1 << bit_count) - 1);
curr_bit_count_ -= bit_count;
if (curr_bit_count_ > 0) {
// most likely branch
return res;
}
if (curr_bit_count_ == 0) {
// not so often...
if (curr_chunk_index_ > -1) {
// rare...
curr_bit_count_ = 64;
curr_chunk_ = chunks_[curr_chunk_index_--];
}
return res;
}
// sad path :(
curr_bit_count_ += bit_count;
curr_chunk_ = chunks_[curr_chunk_index_--];
right |= (curr_chunk_ & ((1 << (bit_count - curr_bit_count_)) - 1)) << curr_bit_count_;
curr_bit_count_ = 64 - (bit_count - curr_bit_count_);
return right;
}
void BitStream::Push(int64_t state, uint8_t bit_count) {
curr_bit_count_ += bit_count;
if (curr_bit_count_ <= 64) {
// happy path, no split
curr_chunk_ = (curr_chunk_ << bit_count) | (state & ((1 << bit_count) - 1));
if (curr_bit_count_ == 64) {
// flush (rare)
chunks_[++curr_chunk_index_] = curr_chunk_;
curr_chunk_ = 0;
curr_bit_count_ = 0;
}
} else {
// split, rare
int leftbits = curr_bit_count_ - 64;
int rightbits = bit_count - leftbits;
curr_chunk_ = (curr_chunk_ << rightbits) | ((state >> leftbits) & ((1 << rightbits) - 1));
// flush left
chunks_[++curr_chunk_index_] = curr_chunk_;
curr_chunk_ = state & ((1 << leftbits) - 1);
curr_bit_count_ = leftbits;
}
}
void BitStream::Flush() { curr_chunk_ <<= 64 - curr_bit_count_; }
} // namespace mindspore::lite::quant
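A short standalone trace of the push/flush semantics above (a sketch only; the values are chosen for illustration):
#include "tools/converter/quantizer/fse_bit_stream.h"
using mindspore::lite::quant::BitStream;
void BitStreamTraceExample() {
  BitStream bs;
  bs.Create(128);     // capacity for two 64-bit chunks
  bs.Push(0b101, 3);  // register now holds 0b101 (3 bits used)
  bs.Push(0b11, 2);   // register now holds 0b10111 (5 bits used)
  bs.Flush();         // left-align the pending bits in the register
  // The 5 pushed bits now sit at the top of the register:
  // bs.GetCurrChunk() >> 59 == 0b10111, bs.GetCurrBitCount() == 5,
  // and no full chunk has been written yet (bs.GetCurrChunkIndex() == -1).
  bs.Free();
}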

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_FSEBITSTREAM_H
#define LITE_FSEBITSTREAM_H
#include <cstdint>
namespace mindspore::lite::quant {
class BitStream {
public:
BitStream() = default;
~BitStream() = default;
public:
int Create(int bit_capacity);
void Free();
void Empty();
int64_t Pop(uint8_t bit_count);
void Push(int64_t state, uint8_t bit_count);
void Flush();
int32_t GetCurrChunkIndex() { return this->curr_chunk_index_; }
uint64_t GetCurrChunk() { return this->curr_chunk_; }
int8_t GetCurrBitCount() { return this->curr_bit_count_; }
uint64_t *GetChunks() { return this->chunks_; }
int GetChunkCount() { return this->chunk_count_; }
void SetCurrChunkIndex(int32_t curr_chunk_index) { this->curr_chunk_index_ = curr_chunk_index; }
void SetCurrChunk(uint64_t curr_chunk) { this->curr_chunk_ = curr_chunk; }
void SetCurrBitCount(int8_t curr_bit_count) { this->curr_bit_count_ = curr_bit_count; }
void SetChunks(uint64_t *chunks) { this->chunks_ = chunks; }
void SetChunkCount(int chunk_count) { this->chunk_count_ = chunk_count; }
private:
int32_t curr_chunk_index_{-1}; // the index of the last chunk written (-1 when none has been written yet)
uint64_t curr_chunk_{0};
int8_t curr_bit_count_{0}; // the number of bits that are currently written in the register.
uint64_t *chunks_{nullptr}; // the actual memory
int chunk_count_{0}; // the number of chunks
};
} // namespace mindspore::lite::quant
#endif // LITE_FSEBITSTREAM_H

View File

@ -0,0 +1,350 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/fse_encoder.h"
#include <cstdint>
#include <algorithm>
#include <cmath>
#include "mindspore/core/ir/dtype/type_id.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
namespace mindspore::lite::quant {
// The function gives the index of the most significant `1` in the binary representation.
// e.g. for the number 00100 it gives 2.
int fse_count_bits(int32_t x) { return __builtin_clz(x) ^ 31; }
int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log,
uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
uint16_t *symbol_table) {
int tablesize = 1 << table_log;
int tablemask = tablesize - 1;
int step = ((tablesize >> 1) + (tablesize >> 3) + 3);
int pos = 0;
// Spread out identical symbols; coding is better when the same symbol is distributed evenly across the table.
for (int sym = 0; sym < frequency_count; sym++) {
for (int i = 0; i < frequency[sym]; i++) {
symbol_table[pos] = sym;
pos = (pos + step) & tablemask;
while (pos > tablemask) pos = (pos + step) & tablemask;
}
}
if (pos != 0) return 1;
std::vector<uint16_t> cfreqs(frequency_count + 2);
cfreqs[0] = 0;
for (int i = 1; i < frequency_count + 1; i++) {
cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
}
cfreqs[frequency_count + 1] = cfreqs[frequency_count] + 1;
for (int i = 0; i < tablesize; i++) {
uint16_t sym = symbol_table[i];
coding_table[cfreqs[sym]] = tablesize + i;
cfreqs[sym] += 1;
}
int total = 0;
for (int sym = 0; sym < frequency_count; sym++) {
if (frequency[sym] >= 2) {
int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
int min_state_plus = frequency[sym] << max_bits_out;
delta_bit_count[sym] = (max_bits_out << 16) - min_state_plus;
delta_state[sym] = total - frequency[sym];
total += frequency[sym];
} else {
// we assume minimum `frequency` is 1
delta_bit_count[sym] = (table_log << 16) - (1 << table_log);
delta_state[sym] = total - 1;
total++;
}
}
return 0;
}
int ConvertTensor2Quant(schema::TensorT *tensor_input, FSEQuant *quants) {
std::vector<int16_t> dequants;
for (size_t i = 0; i < tensor_input->data.size() / sizeof(int16_t); ++i) {
auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
dequants.push_back(data);
}
int qmin = *min_element(dequants.begin(), dequants.end());
int qmax = *max_element(dequants.begin(), dequants.end());
int uncompressed_frequency_count = qmax - qmin + 1;
std::vector<int> uncompressed_frequency(uncompressed_frequency_count);
for (int i = 0; i < uncompressed_frequency_count; i++) {
uncompressed_frequency[i] = 0;
}
for (size_t i = 0; i < tensor_input->data.size() / sizeof(int16_t); i++) {
auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
int q = data - qmin;
uncompressed_frequency[q] += 1;
}
std::vector<uint16_t> uncompressed_freqs_to_compressed_sym(uncompressed_frequency_count);
int sym = 0;
for (int i = 0; i < uncompressed_frequency_count; i++) {
if (uncompressed_frequency[i]) {
if (sym >= MAX_SYMS) return 1; // too many symbols!
uncompressed_freqs_to_compressed_sym[i] = sym;
quants->frequency[sym] = uncompressed_frequency[i];
quants->centroids[sym] =
tensor_input->quantParams.front()->varCorr * tensor_input->quantParams.front()->scale * (i + qmin) +
tensor_input->quantParams.front()->meanCorr;
sym++;
}
}
quants->size = sym;
quants->symbol_table_count = tensor_input->data.size() / sizeof(int16_t);
quants->symbol_table = static_cast<uint16_t *>(malloc(quants->symbol_table_count * sizeof(uint16_t)));
if (quants->symbol_table == nullptr) {
MS_LOG(ERROR) << "malloc memory failed.";
return RET_ERROR;
}
for (int i = 0; i < quants->symbol_table_count; i++) {
auto data = static_cast<int16_t>(reinterpret_cast<int16_t *>(tensor_input->data.data())[i]);
int q = data - qmin;
sym = uncompressed_freqs_to_compressed_sym[q];
quants->symbol_table[i] = sym;
}
return RET_OK;
}
int FSEEncoder::Compress(schema::TensorT *tensor_input) {
int table_log = 0;
FSEQuant fse_quant;
ConvertTensor2Quant(tensor_input, &fse_quant);
NormalizeFrequency(&fse_quant, &table_log);
BitStream bs;
int ret;
ret = bs.Create(16 * fse_quant.symbol_table_count);
if (ret != RET_OK) {
MS_LOG(ERROR) << "BitStream Create failed.";
return ret;
}
ret = FSEEncode(&bs, fse_quant.symbol_table, fse_quant.symbol_table_count, fse_quant.frequency, fse_quant.size,
table_log);
if (ret != RET_OK) {
MS_LOG(ERROR) << "FSE Encode failed.";
return RET_ERROR;
}
bs.Flush();
// Serializing to out:
ret = SerializingToOut(tensor_input, &bs, fse_quant, table_log);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Serializing To Out failed.";
return ret;
}
bs.Free();
free(fse_quant.symbol_table);
return RET_OK;
}
uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state,
const uint32_t *delta_bit_count, const int16_t *delta_state,
uint16_t *coding_table) {
// Determine the number of bits to flush.
// This is basically one of two values, n or n+1, depending on whether state crosses a threshold.
uint8_t bits_out = (state + delta_bit_count[sym]) >> 16;
bs->Push(state, bits_out);
// subrangeID = state >> nbBitsOut
return coding_table[(state >> bits_out) + delta_state[sym]];
}
int GetMaxIndex(const uint16_t *arr, int arr_count) {
float max = -INFINITY;
int index = -1;
for (int i = 0; i < arr_count; i++) {
if (arr[i] > max) {
max = arr[i];
index = i;
}
}
return index;
}
void FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
// The higher the number, the closer we get to the Shannon entropy,
// but also the larger the table, so `+3` is a good compromise.
*table_log = std::min(MAX_TABLE_LOG, fse_count_bits((uint32_t)q->size) + 3);
int new_table_size = 1 << (*table_log);
int curr_table_size = 0;
for (int i = 0; i < q->size; i++) curr_table_size += q->frequency[i];
// normalize
int updated_table_size = 0;
float rat = (static_cast<float>(new_table_size)) / curr_table_size;
for (int i = 0; i < q->size; i++) {
q->frequency[i] = std::max(1, static_cast<int>(floorf(0.5 + rat * q->frequency[i])));
updated_table_size += q->frequency[i];
}
// If the sum of the symbol frequencies is not a power of two (which is almost always the case),
// the frequencies must be normalized: proportionally reduced (or increased) so that their total
// equals the table size exactly.
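// Worked example: frequencies {3, 1, 6} (size = 3) give table_log = 4, i.e. new_table_size = 16;
// with rat = 1.6 they first become {5, 2, 10} (sum 17), and the shrink step below then
// decrements the largest entry, leaving {5, 2, 9} which sums to exactly 16.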
// shrink
while (updated_table_size > new_table_size) {
int max_ix = GetMaxIndex(q->frequency, q->size);
q->frequency[max_ix]--;
updated_table_size--;
}
// grow
if (updated_table_size < new_table_size) {
int max_ix = GetMaxIndex(q->frequency, q->size);
q->frequency[max_ix] += new_table_size - updated_table_size;
}
}
// Encoding is therefore just a repeat of this process:
// - get the symbol to encode
// - look at the current state value
// - determine nbBits, flush them
// - determine the sub-range Id
// - look for the symbol position of the same Id: you get your next state
int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint16_t *frequency, int frequency_count,
int table_log) {
int table_size = 1 << table_log;
// symbolTT.deltaNbBits stores a value which, when added with state,
// makes the result of >> 16 produces either n or n+1, as required.
std::vector<uint32_t> delta_number_bits(frequency_count);
for (int i = 0; i < frequency_count; i++) {
delta_number_bits[i] = 0;
}
// symbolTT.deltaFindState provides the offset to find the correct segment into the table.
std::vector<int16_t> delta_find_state(frequency_count);
for (int i = 0; i < frequency_count; i++) {
delta_find_state[i] = 0;
}
// nextStateTable with symbol
std::vector<uint16_t> coding_table(table_size);
for (int i = 0; i < table_size; i++) {
coding_table[i] = 0;
}
// position with symbol
std::vector<uint16_t> symtable(table_size);
for (int i = 0; i < table_size; i++) {
symtable[i] = 0;
}
int ret = FSECreateStatesForEncoding(frequency, frequency_count, table_log, delta_number_bits.data(),
delta_find_state.data(), coding_table.data(), symtable.data());
if (ret != RET_OK) {
MS_LOG(ERROR) << "Create states table for encoding failed.";
return ret;
}
uint16_t state = table_size;
// The result of encoding the 1st symbol is not flushed to the bitstream;
// it is only used to obtain a valid initial state.
state = FSEEncodeSymbolGetNewState(bs, data[0], state, delta_number_bits.data(), delta_find_state.data(),
coding_table.data());
bs->Empty();
for (int i = 0; i < data_count; i++) {
state = FSEEncodeSymbolGetNewState(bs, data[i], state, delta_number_bits.data(), delta_find_state.data(),
coding_table.data());
}
bs->Push(state - table_size, table_log);
return ret;
}
int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
int table_log) {
auto max_size = tensor_input->data.size() * 2;
auto *out8 = static_cast<uint8_t *>(malloc(max_size));
if (out8 == nullptr) {
MS_LOG(ERROR) << "malloc memory failed.";
return RET_ERROR;
}
int offset = 0;
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.size;
offset += sizeof(uint16_t);
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)table_log;
offset += sizeof(uint16_t);
int chunksc = bs->GetCurrChunkIndex() + 2;
if (offset + sizeof(uint32_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint32_t *>(&out8[offset])) = (uint32_t)chunksc;
offset += sizeof(uint32_t);
for (int j = 0; j < fse_quant.size; j++) {
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)fse_quant.frequency[j];
offset += sizeof(uint16_t);
}
while (offset % 8 != 0) {
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
offset += sizeof(uint16_t);
}
for (int j = 0; j < fse_quant.size; j++) {
if (offset + sizeof(float) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<float *>(&out8[offset])) = static_cast<float>(fse_quant.centroids[j]);
offset += sizeof(float);
}
while (offset % 8 != 0) {
if (offset + sizeof(uint16_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint16_t *>(&out8[offset])) = (uint16_t)0;
offset += sizeof(uint16_t);
}
for (int j = 0; j < bs->GetCurrChunkIndex() + 1; j++) {
if (offset + sizeof(uint64_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetChunks()[j];
offset += sizeof(uint64_t);
}
if (offset + sizeof(uint64_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint64_t *>(&out8[offset])) = (uint64_t)bs->GetCurrChunk();
offset += sizeof(uint64_t);
if (offset + sizeof(uint8_t) > max_size) {
MS_LOG(ERROR) << "offset over max size"
<< " offset:" << offset << " max_size:" << max_size;
}
*(reinterpret_cast<uint8_t *>(&out8[offset])) = (uint8_t)bs->GetCurrBitCount();
offset += sizeof(uint8_t);
if (static_cast<int>(offset) < static_cast<int>(tensor_input->data.size())) {
tensor_input->data.resize(offset);
if (memcpy_s(tensor_input->data.data(), offset, out8, offset) != EOK) {
MS_LOG(ERROR) << "memcpy failed.";
}
}
tensor_input->quantParams.clear();
tensor_input->weightQunatCompressType = schema::WeightQunatCompressType_FSE;
tensor_input->dataType = TypeId::kNumberTypeFloat32;
free(out8);
return RET_OK;
}
} // namespace mindspore::lite::quant
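For orientation, the byte layout that SerializingToOut writes into the tensor data, as reconstructed from the code above (field names are descriptive only):
// uint16                       symbol count (fse_quant.size)
// uint16                       table_log
// uint32                       chunk count (curr_chunk_index + 2)
// uint16  x size               frequency table
// uint16  zeros                padding up to the next 8-byte boundary
// float32 x size               centroids
// uint16  zeros                padding up to the next 8-byte boundary
// uint64  x (chunk count - 1)  bit-stream chunks
// uint64                       current (flushed) chunk
// uint8                        current bit count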

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_FSE_ENCODER_H
#define LITE_FSE_ENCODER_H
#include <vector>
#include "tools/converter/quantizer/fse_bit_stream.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
namespace mindspore::lite::quant {
constexpr int MAX_SYMS = 65534;
constexpr int MAX_TABLE_LOG = 16;
typedef struct {
uint16_t *symbol_table; // the place to store the quantized tensor
int symbol_table_count; // the number of entries in `symbol_table`
float centroids[MAX_SYMS]; // the mean of all the numbers that got quantized into it
uint16_t frequency[MAX_SYMS]; // holds the number of times each symbol appears in `*symbol_table`
int size; // the number of distinct symbols (entries used in `frequency` and `centroids`)
} FSEQuant;
class FSEEncoder {
public:
FSEEncoder() = default;
~FSEEncoder() = default;
int Compress(schema::TensorT *tensor_input);
private:
int FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log, uint32_t *delta_bit_count,
int16_t *delta_state, uint16_t *coding_table, uint16_t *symbol_table);
uint16_t FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state, const uint32_t *delta_bit_count,
const int16_t *delta_state, uint16_t *coding_table);
int FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint16_t *frequency, int frequency_count,
int table_log);
void NormalizeFrequency(FSEQuant *q, int *table_log);
int SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant, int table_log);
};
} // namespace mindspore::lite::quant
#endif // LITE_FSE_ENCODER_H

View File

@ -580,8 +580,9 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An
quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
}
}
auto weight_quant_type = perchanel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
auto status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t, bit_num_t,
perchanel, kNumberTypeInt8);
weight_quant_type, kNumberTypeInt8);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;

View File

@ -1017,4 +1017,53 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
}
}
}
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
WeightQuantType weight_quant_type, TypeId quant_data_type, int index) {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive != nullptr);
auto dims = weight->shape();
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
if (dims.size() <= 1) {
MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel";
weight_quant_type = FIXED_BIT_PER_LAYER;
}
}
std::vector<schema::QuantParamT> quant_params;
size_t elem_count = weight->DataSize();
auto *raw_data = static_cast<float *>(weight->data_c());
if (raw_data == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR;
}
std::vector<int16_t> quant_data(elem_count);
int ret = RET_OK;
if (weight_quant_type == MIXED_BIT_PER_LAYER) {
FixBitWeightQuantizer quantizer(0.02);
quantizer.DoQuantization(static_cast<float *>(weight->data_c()), weight->shape_c(), index - 1, &quant_params,
&quant_data);
} else {
MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
}
auto status =
UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), TypeId::kNumberTypeInt16);
if (status != RET_OK) {
MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
return RET_ERROR;
}
if (quant_params.empty()) {
MS_LOG(ERROR) << "quant_params empty";
return RET_ERROR;
}
auto quant_param_holder = GetCNodeQuantHolder(primitive);
if (quant_type == QuantType_PostTraining) {
quant_param_holder->AddInputQuantParam(quant_params);
} else {
quant_param_holder->set_input_quant_param(index, quant_params);
}
return ret;
}
} // namespace mindspore::lite::quant

View File

@ -40,12 +40,18 @@
#include "abstract/dshape.h"
#include "tools/converter/quantizer/huffman_encode.h"
#include "tools/converter/quantizer/bitpacking.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "src/lite_session.h"
#include "tools/converter/graphdef_transform.h"
#include "src/common/file_utils.h"
#include "src/common/quant_utils.h"
namespace mindspore::lite::quant {
enum WeightQuantType {
FIXED_BIT_PER_CHANNEL = 0,
FIXED_BIT_PER_LAYER = 1,
MIXED_BIT_PER_LAYER = 2,
};
constexpr size_t kUint8Quantization = 8;
constexpr size_t kMaxBit = 8;
constexpr size_t kMaxNum1024 = 1024;
@ -155,17 +161,20 @@ STATUS DoBitPack(const tensor::TensorPtr &weight, const size_t &bit_num, const s
return RET_OK;
}
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
WeightQuantType weight_quant_type, TypeId quant_data_type, int index = 1);
template <typename T>
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type, int quant_max,
int quant_min, size_t bit_num, bool per_channel, TypeId quant_data_type, int index = 1,
bool k_means = false) {
int quant_min, size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type,
int index = 1, bool k_means = false) {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive != nullptr);
auto dims = weight->shape();
if (per_channel) {
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
if (dims.size() <= 1) {
MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel";
per_channel = false;
weight_quant_type = FIXED_BIT_PER_LAYER;
}
}
@ -179,7 +188,7 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
std::vector<T> quant_data(elem_count);
int ret = RET_OK;
if (per_channel) {
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
bool channel_at_first = true;
int channel_cnt = -1;
CalQuantAssitInfo(primitive, dims, index, &channel_at_first, &channel_cnt);
@ -197,13 +206,15 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
MS_LOG(ERROR) << "Do per channel quant failed.";
return ret;
}
} else {
} else if (weight_quant_type == FIXED_BIT_PER_LAYER) {
ret = DoPerLayerQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
quant_min, bit_num, k_means, &quant_data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Do per layer quant failed.";
return ret;
}
} else {
MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
}
auto status = UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(T), quant_data_type);
if (status != RET_OK) {

View File

@ -105,12 +105,15 @@ STATUS WeightQuantizer::DoConvQuantize(const CNodePtr &cnode) {
return RET_OK;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
type_id_);
if (is_mixed_bit) {
type_id_ = kNumberTypeInt16;
status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_);
} else if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
type_id_);
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
}
if (status == RET_CONTINUE) {
return RET_OK;
@ -142,16 +145,16 @@ STATUS WeightQuantizer::DoMulQuantize(const CNodePtr &cnode) {
}
auto status = RET_ERROR;
auto per_channel = true;
if (i == kInputSize2) {
per_channel = false;
auto weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
if (i == 3) {
weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
}
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, per_channel, type_id_, i - 1);
bit_num_, weight_quant_type, type_id_, i - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, per_channel, type_id_, i - 1);
bit_num_, weight_quant_type, type_id_, i - 1);
}
if (status == RET_CONTINUE) {
continue;
@ -225,11 +228,11 @@ STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) {
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false,
type_id_, 0);
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, type_id_, 0);
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
}
if (status == RET_CONTINUE) {
return RET_OK;
@ -274,10 +277,10 @@ STATUS WeightQuantizer::DoOptimizerQuantize(const CNodePtr &cnode) {
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, type_id_, idx - 1);
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false, type_id_, idx - 1);
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
}
if (status != RET_OK && status != RET_CONTINUE) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
@ -344,11 +347,11 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
type_id_, index - 1);
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
type_id_, index - 1);
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
}
if (status == RET_CONTINUE) {
return RET_OK;
@ -559,10 +562,10 @@ STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, true, type_id_);
bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, true, type_id_);
bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
} else {
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;

View File

@ -41,9 +41,9 @@ class WeightQuantizer : public Quantizer {
~WeightQuantizer() override;
STATUS DoQuantize(FuncGraphPtr func_graph) override;
STATUS DoConvQuantize(const CNodePtr &);
STATUS DoMulQuantize(const CNodePtr &);
STATUS DoOptimizerQuantize(const CNodePtr &);
STATUS DoConvQuantize(const CNodePtr &cnode);
STATUS DoMulQuantize(const CNodePtr &cnode);
STATUS DoOptimizerQuantize(const CNodePtr &cnode);
STATUS DoLstmQuantize(const CNodePtr &cnode);
STATUS DoGatherQuantize(const CNodePtr &cnode);
@ -62,6 +62,7 @@ class WeightQuantizer : public Quantizer {
PostQuantConfig config_param_;
std::vector<std::vector<std::string>> images_; // multi_input, [[model_input_0], [model_input_1]...]
std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
bool is_mixed_bit = false;
STATUS DoMixedQuant(const FuncGraphPtr &);
STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr &param_node,
@ -78,7 +79,6 @@ class WeightQuantizer : public Quantizer {
STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info,
const PrimitivePtr &primitive);
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
STATUS DoTensorQuantize(const CNodePtr &);
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_