forked from mindspore-Ecosystem/mindspore
add mixed quantization
This commit is contained in:
parent
06e86e34de
commit
a64eb7cd69
|
@ -41,7 +41,8 @@ table QuantParam {
|
|||
enum WeightQunatCompressType: int {
|
||||
NONE,
|
||||
INDEXING,
|
||||
SPARSE
|
||||
SPARSE,
|
||||
FSE
|
||||
}
|
||||
|
||||
table Tensor {
|
||||
|
|
|
@ -42,13 +42,14 @@
|
|||
#include "tools/common/node_util.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
|
||||
#include "tools/converter/quantizer/fse_encoder.h"
|
||||
|
||||
using mindspore::ops::PrimitiveC;
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr int kIndexOfValueInputOfGetTupleItem = 2;
|
||||
|
||||
std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
||||
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
||||
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
||||
|
@ -116,7 +117,17 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
|
|||
auto repetition_packed = false;
|
||||
MS_LOG(DEBUG) << dst_node->name;
|
||||
if (dst_node->quantType == schema::QuantType_QUANT_WEIGHT) {
|
||||
if (bit_num <= kBitNum8) {
|
||||
if (bit_num == 0) {
|
||||
if (tensor_input->data.empty() || tensor_input->dims.size() <= 1) {
|
||||
return RET_OK;
|
||||
}
|
||||
quant::FSEEncoder fse_encoder;
|
||||
if (dst_node->primitive->value.type == PrimitiveType_GRU) {
|
||||
fse_encoder.Compress(tensor_input);
|
||||
} else {
|
||||
fse_encoder.Compress(tensor_input);
|
||||
}
|
||||
} else if (bit_num <= kBitNum8) {
|
||||
repetition_packed = PackRepetition<int8_t>(bit_num, tensor_input);
|
||||
} else {
|
||||
repetition_packed = PackRepetition<int16_t>(bit_num, tensor_input);
|
||||
|
|
|
@ -286,7 +286,7 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
|
|||
m_quantizer_->flags = *config;
|
||||
auto status = m_quantizer_->DoQuantize(old_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Quant failed " << status;
|
||||
MS_LOG(ERROR) << "DoQuantization failed " << status;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -921,7 +921,6 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|||
if (op_type == "Placeholder" || op_type == "Const" || op_type == "Identity" || op_type == "StopGradient") {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "parse op : " << op_type;
|
||||
auto node_parser = TFNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
|
||||
if (node_parser == nullptr) {
|
||||
|
@ -1055,6 +1054,21 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
output_nodes.push_back(anf_node);
|
||||
// Get the name of node 'Identity' and 'StopGradient'.
|
||||
if (pair.second->op() == "Identity" || pair.second->op() == "StopGradient") {
|
||||
auto tmp_node = pair.second;
|
||||
bool found_input = true;
|
||||
while (tmp_node->name().empty() && (tmp_node->op() == "Identity" || tmp_node->op() == "StopGradient")) {
|
||||
auto flatten_input_name = TensorFlowUtils::GetFlattenNodeName(tmp_node->input(0));
|
||||
if (tf_root_graph_nodes_.find(flatten_input_name) != tf_root_graph_nodes_.end()) {
|
||||
tmp_node = tf_root_graph_nodes_.at(flatten_input_name);
|
||||
} else {
|
||||
found_input = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
origin_name = found_input ? tmp_node->name() : origin_name;
|
||||
}
|
||||
graph_output_names_.push_back(origin_name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,6 +12,10 @@ file(GLOB QUANTIZER
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/huffman_encode.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fse_decoder.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fse_bit_stream.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fse_encoder.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/fix_bit_weight_quantizer.cc
|
||||
)
|
||||
set_property(SOURCE ${QUANTIZER} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(quantizer_mid OBJECT ${QUANTIZER})
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
|
||||
#include <cmath>
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
|
||||
// the error is currently measured per channel.
|
||||
// it could be measured per layer but it would be less good.
|
||||
// the `preferred` dim should point to the output channels dimension.
|
||||
float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
|
||||
float scale) {
|
||||
int numel = 1;
|
||||
for (int i = 0; i < dims; i++) {
|
||||
numel *= shape[i];
|
||||
}
|
||||
int bucket_count = shape[preferred_dim];
|
||||
std::vector<float> norms2(bucket_count);
|
||||
std::vector<float> dnorms2(bucket_count);
|
||||
for (int i = 0; i < bucket_count; i++) {
|
||||
norms2[i] = 0.0;
|
||||
dnorms2[i] = 0.0;
|
||||
}
|
||||
double average_dequant = 0;
|
||||
double average_raw = 0;
|
||||
std::vector<float> dequant_datas(numel);
|
||||
int bucket_volume = 1;
|
||||
for (int i = preferred_dim; i < dims; i++) {
|
||||
bucket_volume *= shape[i];
|
||||
}
|
||||
for (int i = 0; i < numel; i++) {
|
||||
float dequant = scale * (floorf(weights[i] / scale + 0.5));
|
||||
dequant_datas[i] = dequant;
|
||||
average_raw += weights[i];
|
||||
average_dequant += dequant;
|
||||
}
|
||||
// mean
|
||||
average_dequant = average_dequant / numel;
|
||||
average_raw = average_raw / numel;
|
||||
// std
|
||||
double variance_dequant = 0;
|
||||
double variance_raw = 0;
|
||||
for (int i = 0; i < numel; i++) {
|
||||
variance_dequant += std::pow(dequant_datas[i] - average_dequant, 2);
|
||||
variance_raw += std::pow(weights[i] - average_raw, 2);
|
||||
}
|
||||
variance_dequant = std::sqrt(variance_dequant / numel);
|
||||
variance_raw = std::sqrt(variance_raw / numel);
|
||||
var_corr = variance_raw / variance_dequant;
|
||||
mean_corr = average_raw - average_dequant * var_corr;
|
||||
|
||||
for (int i = 0; i < numel; i++) {
|
||||
int bucket = (i / bucket_volume) % bucket_count;
|
||||
norms2[bucket] += weights[i] * weights[i];
|
||||
float dequant = var_corr * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr;
|
||||
float d = weights[i] - dequant;
|
||||
dnorms2[bucket] += d * d;
|
||||
}
|
||||
|
||||
int c = 0;
|
||||
float t = 1e-10;
|
||||
for (int i = 0; i < bucket_count; i++) {
|
||||
if (norms2[i] < 1.0e-10) continue;
|
||||
c += 1;
|
||||
t += sqrtf(dnorms2[i] / norms2[i]);
|
||||
}
|
||||
return t / (c + 1e-7);
|
||||
}
|
||||
|
||||
// Scans arr[0 .. arrc) and returns its smallest and largest element.
// For arrc <= 0 nothing is scanned and the result is {+INFINITY, -INFINITY}.
MinMax FixBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
  MinMax min_max = {INFINITY, -INFINITY};
  for (int i = 0; i < arrc; i++) {
    const float value = arr[i];
    // The bounds are tested independently. With an `else if`, an element that raises the maximum is never
    // considered for the minimum, so an ascending array would leave `min` at +INFINITY.
    if (value > min_max.max) {
      min_max.max = value;
    }
    if (value < min_max.min) {
      min_max.min = value;
    }
  }
  return min_max;
}
|
||||
|
||||
// Searches for the quantization step whose measured error (see MeasureQuantizationError) is within `rel_tol`
// (relative) of `target_err`.
//
// A bracket [lower, upper] is first grown by repeated doubling/halving around an initial guess and then
// bisected. The returned status is 0 when the tolerance was met and 1 when more than `max_iters` bisection
// steps were needed. A tensor whose value range is below 1e-5 short-circuits with status 0 and a scale of
// |max| + 1e-5.
BinarySearchResult FixBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
                                                                           int preferred_dim, int max_iters,
                                                                           float target_err, float rel_tol) {
  int element_num = 1;
  for (int d = 0; d < dims; ++d) {
    element_num *= shape[d];
  }
  const MinMax range = GetMinMax(weights, element_num);
  if (range.max < range.min + 1.0e-5) {
    return {0, static_cast<float>(std::fabs(range.max) + 1.0e-5)};
  }

  // Initial guess, then widen the bracket until the target error lies between its ends.
  const float initial_scale = (range.max - range.min) * target_err;
  float upper = initial_scale * 2.0;
  while (MeasureQuantizationError(weights, shape, dims, preferred_dim, upper) < target_err) {
    upper *= 2.0;
  }
  float lower = initial_scale / 2.0;
  while (MeasureQuantizationError(weights, shape, dims, preferred_dim, lower) > target_err) {
    lower /= 2.0;
  }

  // Bisection: a larger step gives a larger error.
  BinarySearchResult res = {0, initial_scale};
  for (int iter_count = 0;; ++iter_count) {
    const float curr_err = MeasureQuantizationError(weights, shape, dims, preferred_dim, res.scale);
    if (std::fabs(curr_err - target_err) / target_err < rel_tol) {
      return res;
    }
    if (iter_count > max_iters) {
      res.status = 1;
      return res;
    }
    if (curr_err > target_err) {
      upper = res.scale;
    } else {
      lower = res.scale;
    }
    res.scale = (lower + upper) / 2.0;
  }
}
|
||||
|
||||
int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
|
||||
std::vector<schema::QuantParamT> *quant_params,
|
||||
std::vector<int16_t> *quant_datas) {
|
||||
int weight_count = 1;
|
||||
int dims = shape.size();
|
||||
int input_shape[4] = {0, 0, 0, 0};
|
||||
for (int i = 0; i < dims; i++) {
|
||||
weight_count *= shape[i];
|
||||
input_shape[i] = shape[i];
|
||||
}
|
||||
|
||||
BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters,
|
||||
target_relative_err, target_search_tolerance);
|
||||
if (br.status != 0) {
|
||||
MS_LOG(ERROR) << "reached_max_iters";
|
||||
return RET_ERROR;
|
||||
}
|
||||
schema::QuantParamT quant_param;
|
||||
int qr = QuantizeByScale(weights, weight_count, br.scale, &quant_param, quant_datas);
|
||||
if (qr != 0) {
|
||||
MS_LOG(ERROR) << "quant failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// It is used to calculate the Shannon entropy.
|
||||
quant_params->push_back(quant_param);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FixBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
|
||||
schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
|
||||
for (int i = 0; i < weightsc; i++) {
|
||||
auto q = static_cast<int>(floorf(weights[i] / scale + 0.5));
|
||||
quant_datas->at(i) = q;
|
||||
}
|
||||
quant_params->meanCorr = mean_corr;
|
||||
quant_params->varCorr = var_corr;
|
||||
quant_params->scale = scale;
|
||||
quant_params->zeroPoint = 0;
|
||||
quant_params->numBits = 0;
|
||||
quant_params->inited = true;
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_FIX_BIT_WEIGHT_QUANTIZER_H
|
||||
#define LITE_FIX_BIT_WEIGHT_QUANTIZER_H
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
// Outcome of the quantization-scale search: a status code together with the scale that was reached.
struct BinarySearchResult {
  int status;
  float scale;
};
|
||||
|
||||
// Lower and upper bound of a set of float values.
struct MinMax {
  float min;
  float max;
};
|
||||
|
||||
class FixBitWeightQuantizer {
|
||||
public:
|
||||
explicit FixBitWeightQuantizer(float target_relative_err = 0.01, float target_search_tolerance = 0.01,
|
||||
int max_search_iters = 100)
|
||||
: target_relative_err(target_relative_err),
|
||||
target_search_tolerance(target_search_tolerance),
|
||||
max_search_iters(max_search_iters) {}
|
||||
~FixBitWeightQuantizer() = default;
|
||||
|
||||
int DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
|
||||
std::vector<schema::QuantParamT> *quant_params, std::vector<int16_t> *quant_datas);
|
||||
|
||||
private:
|
||||
// the error is currently measured per channel.
|
||||
// it could be measured per layer but it would be less good.
|
||||
// the `preferred` dim should point to the output channels dimension.
|
||||
float MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim, float scale);
|
||||
|
||||
MinMax GetMinMax(const float *arr, int arrc);
|
||||
|
||||
int QuantizeByScale(const float *weights, int weightsc, float scale, schema::QuantParamT *quant_params,
|
||||
std::vector<int16_t> *quant_datas);
|
||||
|
||||
BinarySearchResult BinarySearchForQuantizationScale(float *weights, int *shape, int dims, int preferred_dim,
|
||||
int max_iters, float target_err, float rel_tol);
|
||||
|
||||
private:
|
||||
float var_corr{1};
|
||||
float mean_corr{0};
|
||||
float target_relative_err;
|
||||
float target_search_tolerance;
|
||||
int max_search_iters;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // LITE_FIX_BIT_WEIGHT_QUANTIZER_H
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/quantizer/fse_bit_stream.h"
|
||||
#include <memory.h>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
int BitStream::Create(int bit_capacity) {
|
||||
chunk_count_ = (bit_capacity >> 6);
|
||||
chunks_ = static_cast<uint64_t *>(malloc(chunk_count_ * sizeof(uint64_t)));
|
||||
if (chunks_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc memory failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(chunks_, 0, chunk_count_ * sizeof(uint64_t));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void BitStream::Free() {
|
||||
curr_chunk_index_ = -1;
|
||||
curr_chunk_ = 0;
|
||||
curr_bit_count_ = 0;
|
||||
chunk_count_ = 0;
|
||||
if (chunks_ != nullptr) {
|
||||
free(chunks_);
|
||||
chunks_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void BitStream::Empty() {
|
||||
curr_chunk_index_ = -1;
|
||||
curr_chunk_ = 0;
|
||||
curr_bit_count_ = 0;
|
||||
for (int i = 0; i < chunk_count_; i++) {
|
||||
chunks_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Reads `bit_count` bits from the stream and returns them in the low bits of the result.
//
// The unread bits of the current chunk live in the top `curr_bit_count_` bits of `curr_chunk_`. When they
// run out, the next chunk is taken from chunks_[curr_chunk_index_] and the index is decremented.
// bit_count is expected to be at most 64. If the stream holds fewer bits than requested and no chunk is
// left, the bits that are still available are returned and the stream is left empty.
int64_t BitStream::Pop(uint8_t bit_count) {
  // Masks are built in 64 bits; `1 << n` on a plain int is undefined for n >= 31.
  const auto low_mask = [](int bits) -> uint64_t {
    if (bits <= 0) {
      return 0;
    }
    return bits >= 64 ? ~uint64_t{0} : (uint64_t{1} << bits) - 1;
  };

  const int available = curr_bit_count_;
  // Move the unread bits down to bit 0. Shifting by 64 is undefined, so an empty register yields 0.
  uint64_t pending = available > 0 ? (curr_chunk_ >> (64 - available)) : 0;
  const int remaining = available - bit_count;

  if (remaining > 0) {
    // Most likely branch: the current chunk still holds bits afterwards.
    curr_bit_count_ = remaining;
    return static_cast<int64_t>(pending & low_mask(bit_count));
  }
  if (remaining == 0) {
    // The current chunk is used up exactly; preload the next one when there is one.
    curr_bit_count_ = 0;
    if (curr_chunk_index_ > -1) {
      curr_bit_count_ = 64;
      curr_chunk_ = chunks_[curr_chunk_index_--];
    }
    return static_cast<int64_t>(pending & low_mask(bit_count));
  }

  // The request straddles two chunks: take what is left here and the rest from the low bits of the next one.
  if (curr_chunk_index_ < 0) {
    // Nothing left to read from; avoid indexing chunks_ with -1.
    curr_bit_count_ = 0;
    curr_chunk_ = 0;
    return static_cast<int64_t>(pending);
  }
  const int missing = bit_count - available;
  curr_chunk_ = chunks_[curr_chunk_index_--];
  pending |= (curr_chunk_ & low_mask(missing)) << available;
  curr_bit_count_ = 64 - missing;
  return static_cast<int64_t>(pending);
}
|
||||
|
||||
// Appends the low `bit_count` bits of `state` to the stream.
//
// Bits are collected in the 64-bit register `curr_chunk_`; each time it fills up it is stored at
// chunks_[++curr_chunk_index_]. bit_count is expected to be at most 64.
// No capacity check is performed here: the buffer must have been sized (see Create) for everything that is
// going to be pushed.
void BitStream::Push(int64_t state, uint8_t bit_count) {
  // Masks are built in 64 bits; `1 << n` on a plain int is undefined for n >= 31.
  const auto low_mask = [](int bits) -> uint64_t {
    if (bits <= 0) {
      return 0;
    }
    return bits >= 64 ? ~uint64_t{0} : (uint64_t{1} << bits) - 1;
  };

  const uint64_t value = static_cast<uint64_t>(state);
  const int total = curr_bit_count_ + bit_count;
  if (total <= 64) {
    // Happy path, no split. A shift by the full width is undefined, so a 64-bit push replaces the register.
    const uint64_t shifted = bit_count >= 64 ? 0 : (curr_chunk_ << bit_count);
    curr_chunk_ = shifted | (value & low_mask(bit_count));
    curr_bit_count_ = total;
    if (total == 64) {
      // The register is full: store it and start a new one.
      chunks_[++curr_chunk_index_] = curr_chunk_;
      curr_chunk_ = 0;
      curr_bit_count_ = 0;
    }
    return;
  }

  // Split: the upper `rightbits` of the value complete the current chunk, the lower `leftbits` start the next.
  const int leftbits = total - 64;
  const int rightbits = bit_count - leftbits;
  curr_chunk_ = (curr_chunk_ << rightbits) | ((value >> leftbits) & low_mask(rightbits));
  chunks_[++curr_chunk_index_] = curr_chunk_;
  curr_chunk_ = value & low_mask(leftbits);
  curr_bit_count_ = leftbits;
}
|
||||
|
||||
// Left-aligns the pending bits of the current register so that they occupy its top `curr_bit_count_` bits.
// With no pending bits there is nothing to align; shifting a 64-bit value by 64 would be undefined.
void BitStream::Flush() {
  if (curr_bit_count_ > 0) {
    curr_chunk_ <<= 64 - curr_bit_count_;
  }
}
|
||||
} // namespace mindspore::lite::quant
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_FSEBITSTREAM_H
|
||||
#define LITE_FSEBITSTREAM_H
|
||||
#include <cstdint>
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
// Bit container backed by an array of 64-bit chunks.
// Push appends bits and stores every completed chunk at the next higher index; Pop consumes chunks starting
// at the current index and moving downwards.
// The destructor does not release the chunk buffer: call Free() for a buffer obtained through Create().
class BitStream {
 public:
  BitStream() = default;

  ~BitStream() = default;

 public:
  int Create(int bit_capacity);
  void Free();
  void Empty();
  int64_t Pop(uint8_t bit_count);
  void Push(int64_t state, uint8_t bit_count);
  void Flush();

  // Read access to the stream state; const-qualified so that it also works through a const BitStream.
  int32_t GetCurrChunkIndex() const { return this->curr_chunk_index_; }
  uint64_t GetCurrChunk() const { return this->curr_chunk_; }
  int8_t GetCurrBitCount() const { return this->curr_bit_count_; }
  uint64_t *GetChunks() const { return this->chunks_; }
  int GetChunkCount() const { return this->chunk_count_; }

  // Direct write access to the stream state. SetChunks only stores the pointer.
  void SetCurrChunkIndex(int32_t curr_chunk_index) { this->curr_chunk_index_ = curr_chunk_index; }
  void SetCurrChunk(uint64_t curr_chunk) { this->curr_chunk_ = curr_chunk; }
  void SetCurrBitCount(int8_t curr_bit_count) { this->curr_bit_count_ = curr_bit_count; }
  void SetChunks(uint64_t *chunks) { this->chunks_ = chunks; }
  void SetChunkCount(int chunk_count) { this->chunk_count_ = chunk_count; }

 private:
  int32_t curr_chunk_index_{-1};  // index of the chunk stored most recently in chunks_; -1 when none is stored
  uint64_t curr_chunk_{0};        // register holding the bits of the chunk currently being written or read
  int8_t curr_bit_count_{0};      // the number of bits that are currently held in the register
  uint64_t *chunks_{nullptr};     // the actual memory
  int chunk_count_{0};            // the number of chunks
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // LITE_FSEBITSTREAM_H
|
|
@ -0,0 +1,350 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/quantizer/fse_encoder.h"
|
||||
#include <cstdint>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include "mindspore/core/ir/dtype/type_id.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
// Returns the zero-based index of the most significant `1` in the binary representation of x,
// e.g. 2 for the number 00100.
// x == 0 has no set bit; 0 is returned for it because __builtin_clz is undefined for a zero argument.
int fse_count_bits(int32_t x) {
  if (x == 0) {
    return 0;
  }
  return __builtin_clz(static_cast<unsigned int>(x)) ^ 31;
}
|
||||
|
||||
int FSEEncoder::FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log,
|
||||
uint32_t *delta_bit_count, int16_t *delta_state, uint16_t *coding_table,
|
||||
uint16_t *symbol_table) {
|
||||
int tablesize = 1 << table_log;
|
||||
int tablemask = tablesize - 1;
|
||||
int step = ((tablesize >> 1) + (tablesize >> 3) + 3);
|
||||
int pos = 0;
|
||||
// Separate the same symbols, coding will be better if the same characters are distributed evenly across the table.
|
||||
for (int sym = 0; sym < frequency_count; sym++) {
|
||||
for (int i = 0; i < frequency[sym]; i++) {
|
||||
symbol_table[pos] = sym;
|
||||
pos = (pos + step) & tablemask;
|
||||
while (pos > tablemask) pos = (pos + step) & tablemask;
|
||||
}
|
||||
}
|
||||
if (pos != 0) return 1;
|
||||
|
||||
std::vector<uint16_t> cfreqs(frequency_count + 2);
|
||||
cfreqs[0] = 0;
|
||||
for (int i = 1; i < frequency_count + 1; i++) {
|
||||
cfreqs[i] = cfreqs[i - 1] + frequency[i - 1];
|
||||
}
|
||||
cfreqs[frequency_count + 1] = cfreqs[frequency_count] + 1;
|
||||
for (int i = 0; i < tablesize; i++) {
|
||||
uint16_t sym = symbol_table[i];
|
||||
coding_table[cfreqs[sym]] = tablesize + i;
|
||||
cfreqs[sym] += 1;
|
||||
}
|
||||
|
||||
int total = 0;
|
||||
for (int sym = 0; sym < frequency_count; sym++) {
|
||||
if (frequency[sym] >= 2) {
|
||||
int max_bits_out = table_log - fse_count_bits(frequency[sym] - 1);
|
||||
int min_state_plus = frequency[sym] << max_bits_out;
|
||||
delta_bit_count[sym] = (max_bits_out << 16) - min_state_plus;
|
||||
delta_state[sym] = total - frequency[sym];
|
||||
total += frequency[sym];
|
||||
} else {
|
||||
// we assume minimum `frequency` is 1
|
||||
delta_bit_count[sym] = (table_log << 16) - (1 << table_log);
|
||||
delta_state[sym] = total - 1;
|
||||
total++;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Turns the int16 payload of `tensor_input` into the symbol representation used by the FSE encoder.
//
// The tensor data is read as an array of int16_t. Every distinct value becomes one symbol (numbered in
// ascending value order). `quants` receives the occurrence count and the dequantized centroid
// (varCorr * scale * value + meanCorr, taken from the first quant param) of every symbol, plus a freshly
// malloc'ed symbol_table holding one symbol per element. The caller owns that table and must free() it.
//
// Returns RET_OK on success. On failure RET_ERROR is returned and quants->symbol_table is nullptr.
int ConvertTensor2Quant(schema::TensorT *tensor_input, FSEQuant *quants) {
  if (tensor_input == nullptr || quants == nullptr) {
    MS_LOG(ERROR) << "input pointer is nullptr.";
    return RET_ERROR;
  }
  // Start from a defined state so that the caller can release the result unconditionally.
  quants->symbol_table = nullptr;
  quants->symbol_table_count = 0;
  quants->size = 0;

  const size_t element_count = tensor_input->data.size() / sizeof(int16_t);
  if (element_count == 0) {
    MS_LOG(ERROR) << "tensor holds no int16 data.";
    return RET_ERROR;
  }
  if (tensor_input->quantParams.empty() || tensor_input->quantParams.front() == nullptr) {
    MS_LOG(ERROR) << "tensor has no quant param.";
    return RET_ERROR;
  }

  // Decode the payload once and reuse it for every pass below.
  const auto *raw = reinterpret_cast<const int16_t *>(tensor_input->data.data());
  const std::vector<int16_t> values(raw, raw + element_count);

  const auto bounds = std::minmax_element(values.begin(), values.end());
  const int qmin = *bounds.first;
  const int qmax = *bounds.second;
  const int uncompressed_frequency_count = qmax - qmin + 1;

  // Histogram over the value range [qmin, qmax].
  std::vector<int> uncompressed_frequency(uncompressed_frequency_count, 0);
  for (const int16_t value : values) {
    uncompressed_frequency[value - qmin] += 1;
  }

  // Assign consecutive symbols to the values that actually occur.
  std::vector<uint16_t> uncompressed_freqs_to_compressed_sym(uncompressed_frequency_count);
  int sym = 0;
  for (int i = 0; i < uncompressed_frequency_count; i++) {
    if (uncompressed_frequency[i] == 0) {
      continue;
    }
    if (sym >= MAX_SYMS) {
      MS_LOG(ERROR) << "too many symbols.";
      return RET_ERROR;
    }
    uncompressed_freqs_to_compressed_sym[i] = sym;
    // NOTE(review): the count is an int; if the frequency array uses a narrower element type, very frequent
    // values are truncated here - confirm against the FSEQuant definition.
    quants->frequency[sym] = uncompressed_frequency[i];
    quants->centroids[sym] =
      tensor_input->quantParams.front()->varCorr * tensor_input->quantParams.front()->scale * (i + qmin) +
      tensor_input->quantParams.front()->meanCorr;
    sym++;
  }
  quants->size = sym;

  quants->symbol_table = static_cast<uint16_t *>(malloc(element_count * sizeof(uint16_t)));
  if (quants->symbol_table == nullptr) {
    MS_LOG(ERROR) << "malloc memory failed.";
    return RET_ERROR;
  }
  quants->symbol_table_count = element_count;
  for (size_t i = 0; i < element_count; i++) {
    quants->symbol_table[i] = uncompressed_freqs_to_compressed_sym[values[i] - qmin];
  }
  return RET_OK;
}
|
||||
|
||||
// Compresses the int16 payload of `tensor_input` with FSE.
//
// Steps: convert the tensor to symbols, normalize the symbol frequencies, encode the symbols into a bit
// stream and serialize the result through SerializingToOut.
// The bit stream and the symbol table are released on every exit path.
// Returns RET_OK on success and an error code when any step fails.
int FSEEncoder::Compress(schema::TensorT *tensor_input) {
  if (tensor_input == nullptr) {
    MS_LOG(ERROR) << "tensor_input is nullptr.";
    return RET_ERROR;
  }
  int table_log = 0;
  FSEQuant fse_quant;
  // Defined value for the cleanup below, whatever the conversion does on its error paths.
  fse_quant.symbol_table = nullptr;
  BitStream bs;
  const auto release = [&bs, &fse_quant]() {
    bs.Free();
    free(fse_quant.symbol_table);
    fse_quant.symbol_table = nullptr;
  };

  int ret = ConvertTensor2Quant(tensor_input, &fse_quant);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Convert tensor to quant failed.";
    release();
    return RET_ERROR;
  }
  NormalizeFrequency(&fse_quant, &table_log);
  ret = bs.Create(16 * fse_quant.symbol_table_count);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "BitStream Create failed.";
    release();
    return ret;
  }
  ret = FSEEncode(&bs, fse_quant.symbol_table, fse_quant.symbol_table_count, fse_quant.frequency, fse_quant.size,
                  table_log);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "FSE Encode failed.";
    release();
    return RET_ERROR;
  }
  bs.Flush();
  // Serializing to out:
  ret = SerializingToOut(tensor_input, &bs, fse_quant, table_log);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Serializing To Out failed.";
    release();
    return ret;
  }
  release();
  return RET_OK;
}
|
||||
|
||||
// Encodes one symbol: pushes the low bits of the current state to the bit stream and returns the next state.
uint16_t FSEEncoder::FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state,
                                                const uint32_t *delta_bit_count, const int16_t *delta_state,
                                                uint16_t *coding_table) {
  // Number of bits to flush: one of two values, n or n+1, depending on whether the state crosses a threshold.
  const uint32_t biased_state = state + delta_bit_count[sym];
  const uint8_t bits_out = biased_state >> 16;
  bs->Push(state, bits_out);
  // The remaining high bits of the state select the sub-range; delta_state moves it to the symbol's slots.
  const int next_slot = (state >> bits_out) + delta_state[sym];
  return coding_table[next_slot];
}
|
||||
|
||||
// Returns the index of the largest element of arr[0 .. arr_count); the first one wins on ties.
// Returns -1 when arr_count is not positive.
int GetMaxIndex(const uint16_t *arr, int arr_count) {
  if (arr_count <= 0) {
    return -1;
  }
  const uint16_t *largest = std::max_element(arr, arr + arr_count);
  return static_cast<int>(largest - arr);
}
|
||||
|
||||
// Chooses the table size for `q` and rescales the symbol frequencies so that they sum up to it.
//
// *table_log is set to min(MAX_TABLE_LOG, index of the highest set bit of q->size + 3). Every frequency is
// scaled proportionally (never below 1) and the rounding error is then removed by adjusting the most
// frequent symbol.
// When there is nothing to normalize (null argument or no symbols) the function logs an error and returns
// without touching its outputs. When all counts are zero it returns after *table_log has been set, leaving
// the frequencies unchanged.
void FSEEncoder::NormalizeFrequency(FSEQuant *q, int *table_log) {
  if (q == nullptr || table_log == nullptr) {
    MS_LOG(ERROR) << "input pointer is nullptr.";
    return;
  }
  if (q->size <= 0) {
    // Without symbols GetMaxIndex returns -1, which must not be used as an index below.
    MS_LOG(ERROR) << "no symbol to normalize.";
    return;
  }
  // The higher the number, the more accurate we'll be to the shannon entropy,
  // but also the larger the table, so `+3` is a good compromise.
  *table_log = std::min(MAX_TABLE_LOG, fse_count_bits(static_cast<int32_t>(q->size)) + 3);
  const int new_table_size = 1 << (*table_log);
  int curr_table_size = 0;
  for (int i = 0; i < q->size; i++) {
    curr_table_size += q->frequency[i];
  }
  if (curr_table_size <= 0) {
    MS_LOG(ERROR) << "symbol frequencies sum up to zero.";
    return;
  }

  // normalize
  int updated_table_size = 0;
  const float rat = (static_cast<float>(new_table_size)) / curr_table_size;
  for (int i = 0; i < q->size; i++) {
    q->frequency[i] = std::max(1, static_cast<int>(floorf(0.5 + rat * q->frequency[i])));
    updated_table_size += q->frequency[i];
  }

  // If the sum of the symbol frequencies is not equal to the power of two (almost always), the frequencies
  // are adjusted so that the power of two is obtained in total.
  // shrink
  // NOTE(review): if there were more symbols than table entries this loop would push frequencies to 0;
  // presumably MAX_SYMS rules that out - confirm.
  while (updated_table_size > new_table_size) {
    const int max_ix = GetMaxIndex(q->frequency, q->size);
    q->frequency[max_ix]--;
    updated_table_size--;
  }

  // grow
  if (updated_table_size < new_table_size) {
    const int max_ix = GetMaxIndex(q->frequency, q->size);
    q->frequency[max_ix] += new_table_size - updated_table_size;
  }
}
|
||||
|
||||
// Encoding is therefore just a repeat of this process :
|
||||
// - get Symbol to encode
|
||||
// - look at current state value
|
||||
// - determine nbBits, flush them
|
||||
// - determine sub-Range Id
|
||||
// - look for Symbol position of same Id : you get your next state
|
||||
int FSEEncoder::FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint16_t *frequency, int frequency_count,
|
||||
int table_log) {
|
||||
int table_size = 1 << table_log;
|
||||
// symbolTT.deltaNbBits stores a value which, when added with state,
|
||||
// makes the result of >> 16 produces either n or n+1, as required.
|
||||
std::vector<uint32_t> delta_number_bits(frequency_count);
|
||||
for (int i = 0; i < frequency_count; i++) {
|
||||
delta_number_bits[i] = 0;
|
||||
}
|
||||
// symbolTT.deltaFindState provides the offset to find the correct segment into the table.
|
||||
std::vector<int16_t> delta_find_state(frequency_count);
|
||||
for (int i = 0; i < frequency_count; i++) {
|
||||
delta_find_state[i] = 0;
|
||||
}
|
||||
// nextStateTable with symbol
|
||||
std::vector<uint16_t> coding_table(table_size);
|
||||
for (int i = 0; i < table_size; i++) {
|
||||
coding_table[i] = 0;
|
||||
}
|
||||
// position with symbol
|
||||
std::vector<uint16_t> symtable(table_size);
|
||||
for (int i = 0; i < table_size; i++) {
|
||||
symtable[i] = 0;
|
||||
}
|
||||
int ret = FSECreateStatesForEncoding(frequency, frequency_count, table_log, delta_number_bits.data(),
|
||||
delta_find_state.data(), coding_table.data(), symtable.data());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Create states table for encoding failed.";
|
||||
return ret;
|
||||
}
|
||||
uint16_t state = table_size;
|
||||
// The results of the 1st symbol encoding is not flushed to the bitstream,
|
||||
// It is just to get a valid 1 st state.
|
||||
state = FSEEncodeSymbolGetNewState(bs, data[0], state, delta_number_bits.data(), delta_find_state.data(),
|
||||
coding_table.data());
|
||||
bs->Empty();
|
||||
for (int i = 0; i < data_count; i++) {
|
||||
state = FSEEncodeSymbolGetNewState(bs, data[i], state, delta_number_bits.data(), delta_find_state.data(),
|
||||
coding_table.data());
|
||||
}
|
||||
bs->Push(state - table_size, table_log);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
 * Packs the FSE model and the encoded bit stream into one byte buffer, stores that buffer in
 * `tensor_input->data` and marks the tensor as FSE-compressed.
 *
 * Buffer layout, in write order:
 *   uint16  fse_quant.size
 *   uint16  table_log
 *   uint32  bs->GetCurrChunkIndex() + 2 (the number of uint64 words written at the end)
 *   uint16  fse_quant.frequency[0 .. size), then uint16 zeros up to the next multiple of 8 bytes
 *   float   fse_quant.centroids[0 .. size), then uint16 zeros up to the next multiple of 8 bytes
 *   uint64  bs->GetChunks()[0 .. GetCurrChunkIndex()]
 *   uint64  bs->GetCurrChunk()
 *   uint8   bs->GetCurrBitCount()
 *
 * The scratch buffer is twice the size of the current tensor data. Any write that would exceed it
 * aborts the serialization with RET_ERROR and leaves the tensor untouched.
 *
 * @param tensor_input tensor whose data, quant params, compress type and data type are rewritten.
 * @param bs bit stream holding the encoded symbols.
 * @param fse_quant symbol frequencies and centroids; only the first `size` entries are written.
 * @param table_log stored in the header as a uint16.
 * @return RET_OK on success, RET_ERROR on invalid arguments, allocation failure, buffer overflow
 *         or copy failure.
 */
int FSEEncoder::SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant,
                                 int table_log) {
  if (tensor_input == nullptr || bs == nullptr) {
    MS_LOG(ERROR) << "tensor_input or bs is nullptr.";
    return RET_ERROR;
  }
  const size_t max_size = tensor_input->data.size() * 2;
  auto *out8 = static_cast<uint8_t *>(malloc(max_size));
  if (out8 == nullptr) {
    MS_LOG(ERROR) << "malloc memory failed.";
    return RET_ERROR;
  }
  size_t offset = 0;

  // Writes one value at the current offset. The bounds check happens BEFORE the store, and a failed
  // check stops the serialization instead of only being logged.
  auto append = [&](auto value) -> bool {
    if (offset + sizeof(value) > max_size) {
      MS_LOG(ERROR) << "offset over max size"
                    << " offset:" << offset << " max_size:" << max_size;
      return false;
    }
    *(reinterpret_cast<decltype(value) *>(&out8[offset])) = value;
    offset += sizeof(value);
    return true;
  };

  // Pads with uint16 zeros until the offset is a multiple of 8 bytes.
  auto pad_to_8_bytes = [&]() -> bool {
    while (offset % 8 != 0) {
      if (!append(static_cast<uint16_t>(0))) {
        return false;
      }
    }
    return true;
  };

  auto serialize = [&]() -> bool {
    // Header.
    if (!append(static_cast<uint16_t>(fse_quant.size)) || !append(static_cast<uint16_t>(table_log)) ||
        !append(static_cast<uint32_t>(bs->GetCurrChunkIndex() + 2))) {
      return false;
    }
    // Per-symbol frequencies.
    for (int j = 0; j < fse_quant.size; j++) {
      if (!append(static_cast<uint16_t>(fse_quant.frequency[j]))) {
        return false;
      }
    }
    if (!pad_to_8_bytes()) {
      return false;
    }
    // Per-symbol centroids.
    for (int j = 0; j < fse_quant.size; j++) {
      if (!append(static_cast<float>(fse_quant.centroids[j]))) {
        return false;
      }
    }
    if (!pad_to_8_bytes()) {
      return false;
    }
    // Completed chunks, then the chunk that is still being filled and its bit count.
    for (int j = 0; j < bs->GetCurrChunkIndex() + 1; j++) {
      if (!append(static_cast<uint64_t>(bs->GetChunks()[j]))) {
        return false;
      }
    }
    return append(static_cast<uint64_t>(bs->GetCurrChunk())) &&
           append(static_cast<uint8_t>(bs->GetCurrBitCount()));
  };

  int ret = RET_OK;
  if (!serialize()) {
    ret = RET_ERROR;
  } else {
    // The tensor is tagged as FSE below, so its bytes must always be the FSE stream, even when the
    // stream is not smaller than the original data. Otherwise the metadata and the payload disagree.
    tensor_input->data.resize(offset);
    if (memcpy_s(tensor_input->data.data(), offset, out8, offset) != EOK) {
      MS_LOG(ERROR) << "memcpy failed.";
      ret = RET_ERROR;
    } else {
      tensor_input->quantParams.clear();
      tensor_input->weightQunatCompressType = schema::WeightQunatCompressType_FSE;
      tensor_input->dataType = TypeId::kNumberTypeFloat32;
    }
  }
  free(out8);
  return ret;
}
|
||||
} // namespace mindspore::lite::quant
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef LITE_FSE_ENCODER_H
|
||||
#define LITE_FSE_ENCODER_H
|
||||
|
||||
#include <vector>
|
||||
#include "tools/converter/quantizer/fse_bit_stream.h"
|
||||
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
|
||||
namespace mindspore::lite::quant {
|
||||
// Capacity of the fixed-size per-symbol arrays in FSEQuant.
constexpr int MAX_SYMS = 65534;
// NOTE(review): presumably the largest table_log the encoder accepts; confirm against its users.
constexpr int MAX_TABLE_LOG = 16;

// Quantization result consumed by the FSE encoder.
struct FSEQuant {
  // Storage for the quantized tensor (one symbol per element).
  uint16_t *symbol_table;
  // NOTE(review): the original comments on this field and on `size` look swapped; this is
  // presumably the element count of `symbol_table`. Confirm against the code that fills it.
  int symbol_table_count;
  // Per symbol: the mean of all values that were quantized into that symbol.
  float centroids[MAX_SYMS];
  // Per symbol: how many times the symbol appears in `*symbol_table`.
  uint16_t frequency[MAX_SYMS];
  // Number of leading entries of `centroids` / `frequency` that the serializer writes out.
  int size;
};
|
||||
|
||||
class FSEEncoder {
|
||||
public:
|
||||
FSEEncoder() = default;
|
||||
~FSEEncoder() = default;
|
||||
int Compress(schema::TensorT *tensor_input);
|
||||
|
||||
private:
|
||||
int FSECreateStatesForEncoding(uint16_t *frequency, int frequency_count, int table_log, uint32_t *delta_bit_count,
|
||||
int16_t *delta_state, uint16_t *coding_table, uint16_t *symbol_table);
|
||||
|
||||
uint16_t FSEEncodeSymbolGetNewState(BitStream *bs, uint16_t sym, uint16_t state, const uint32_t *delta_bit_count,
|
||||
const int16_t *delta_state, uint16_t *coding_table);
|
||||
|
||||
int FSEEncode(BitStream *bs, const uint16_t *data, int data_count, uint16_t *frequency, int frequency_count,
|
||||
int table_log);
|
||||
|
||||
void NormalizeFrequency(FSEQuant *q, int *table_log);
|
||||
|
||||
int SerializingToOut(schema::TensorT *tensor_input, BitStream *bs, const FSEQuant &fse_quant, int table_log);
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // LITE_FSE_ENCODER_H
|
|
@ -580,8 +580,9 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An
|
|||
quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
|
||||
}
|
||||
}
|
||||
auto weight_quant_type = perchanel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
auto status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t, bit_num_t,
|
||||
perchanel, kNumberTypeInt8);
|
||||
weight_quant_type, kNumberTypeInt8);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
return status;
|
||||
|
|
|
@ -1017,4 +1017,53 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Quantizes a float weight tensor with the mixed-bit per-layer scheme and records the resulting
 * quant params on the primitive.
 *
 * Only MIXED_BIT_PER_LAYER is implemented by this overload; any other `weight_quant_type` is
 * rejected. All validation happens before the tensor is modified, so a failed call no longer
 * replaces the original float weights with zero-filled int16 data.
 *
 * @param weight tensor holding float weights; on success its data becomes int16 quantized values.
 * @param primitive primitive whose quant-param holder receives the computed quant params.
 * @param quant_type QuantType_PostTraining appends the params as a new input entry; every other
 *        value stores them at input position `index`.
 * @param weight_quant_type must be MIXED_BIT_PER_LAYER.
 * @param quant_data_type currently unused: the tensor is always stored as kNumberTypeInt16.
 * @param index input index of the weight; `index - 1` is forwarded to the quantizer.
 * @return RET_OK on success, RET_ERROR otherwise.
 */
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
                   WeightQuantType weight_quant_type, TypeId quant_data_type, int index) {
  MS_ASSERT(weight != nullptr);
  MS_ASSERT(primitive != nullptr);
  // Reject unsupported schemes up front instead of logging and then clobbering the tensor.
  if (weight_quant_type != MIXED_BIT_PER_LAYER) {
    MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
    return RET_ERROR;
  }
  auto *raw_data = static_cast<float *>(weight->data_c());
  if (raw_data == nullptr) {
    MS_LOG(ERROR) << "rawDatas is nullptr";
    return RET_ERROR;
  }

  std::vector<schema::QuantParamT> quant_params;
  std::vector<int16_t> quant_data(weight->DataSize());
  // NOTE(review): 0.02 is the quantizer's constructor argument; its meaning is defined in
  // fix_bit_weight_quantizer.h. DoQuantization's result is not inspected here, so an empty
  // quant_params is the only failure signal.
  FixBitWeightQuantizer quantizer(0.02);
  quantizer.DoQuantization(raw_data, weight->shape_c(), index - 1, &quant_params, &quant_data);
  if (quant_params.empty()) {
    MS_LOG(ERROR) << "quant_params empty";
    return RET_ERROR;
  }

  auto status =
    UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), TypeId::kNumberTypeInt16);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
    return RET_ERROR;
  }

  auto quant_param_holder = GetCNodeQuantHolder(primitive);
  if (quant_param_holder == nullptr) {
    MS_LOG(ERROR) << "quant param holder is nullptr";
    return RET_ERROR;
  }
  if (quant_type == QuantType_PostTraining) {
    quant_param_holder->AddInputQuantParam(quant_params);
  } else {
    quant_param_holder->set_input_quant_param(index, quant_params);
  }
  return RET_OK;
}
|
||||
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -40,12 +40,18 @@
|
|||
#include "abstract/dshape.h"
|
||||
#include "tools/converter/quantizer/huffman_encode.h"
|
||||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
|
||||
#include "src/lite_session.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/quant_utils.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
// Selects how QuantFilter quantizes a weight tensor.
enum WeightQuantType {
  // Fixed bit width, handled by the per-channel quantization path.
  FIXED_BIT_PER_CHANNEL = 0,
  // Fixed bit width, handled by DoPerLayerQuant for the whole tensor.
  FIXED_BIT_PER_LAYER,
  // Handled by the non-template QuantFilter overload through FixBitWeightQuantizer.
  MIXED_BIT_PER_LAYER,
};
|
||||
constexpr size_t kUint8Quantization = 8;
|
||||
constexpr size_t kMaxBit = 8;
|
||||
constexpr size_t kMaxNum1024 = 1024;
|
||||
|
@ -155,17 +161,20 @@ STATUS DoBitPack(const tensor::TensorPtr &weight, const size_t &bit_num, const s
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// Mixed-bit weight quantization: turns the float data of `weight` into int16 values and stores the
// computed quant params on `primitive`. Only MIXED_BIT_PER_LAYER is handled by this overload; other
// values of `weight_quant_type` make the call fail with RET_ERROR. `index` is the input index of the
// weight on its node.
STATUS QuantFilter(const tensor::TensorPtr &weight,
                   const PrimitivePtr &primitive,
                   QuantType quant_type,
                   WeightQuantType weight_quant_type,
                   TypeId quant_data_type,
                   int index = 1);
|
||||
|
||||
template <typename T>
|
||||
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type, int quant_max,
|
||||
int quant_min, size_t bit_num, bool per_channel, TypeId quant_data_type, int index = 1,
|
||||
bool k_means = false) {
|
||||
int quant_min, size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type,
|
||||
int index = 1, bool k_means = false) {
|
||||
MS_ASSERT(weight != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto dims = weight->shape();
|
||||
if (per_channel) {
|
||||
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
|
||||
if (dims.size() <= 1) {
|
||||
MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel";
|
||||
per_channel = false;
|
||||
weight_quant_type = FIXED_BIT_PER_LAYER;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -179,7 +188,7 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
|
|||
|
||||
std::vector<T> quant_data(elem_count);
|
||||
int ret = RET_OK;
|
||||
if (per_channel) {
|
||||
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
|
||||
bool channel_at_first = true;
|
||||
int channel_cnt = -1;
|
||||
CalQuantAssitInfo(primitive, dims, index, &channel_at_first, &channel_cnt);
|
||||
|
@ -197,13 +206,15 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
|
|||
MS_LOG(ERROR) << "Do per channel quant failed.";
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
} else if (weight_quant_type == FIXED_BIT_PER_LAYER) {
|
||||
ret = DoPerLayerQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
|
||||
quant_min, bit_num, k_means, &quant_data);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do per layer quant failed.";
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
|
||||
}
|
||||
auto status = UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(T), quant_data_type);
|
||||
if (status != RET_OK) {
|
||||
|
|
|
@ -105,12 +105,15 @@ STATUS WeightQuantizer::DoConvQuantize(const CNodePtr &cnode) {
|
|||
return RET_OK;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
|
||||
type_id_);
|
||||
if (is_mixed_bit) {
|
||||
type_id_ = kNumberTypeInt16;
|
||||
status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_);
|
||||
} else if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
|
||||
type_id_);
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
return RET_OK;
|
||||
|
@ -142,16 +145,16 @@ STATUS WeightQuantizer::DoMulQuantize(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
auto status = RET_ERROR;
|
||||
auto per_channel = true;
|
||||
if (i == kInputSize2) {
|
||||
per_channel = false;
|
||||
auto weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
|
||||
if (i == 3) {
|
||||
weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
}
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, per_channel, type_id_, i - 1);
|
||||
bit_num_, weight_quant_type, type_id_, i - 1);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, per_channel, type_id_, i - 1);
|
||||
bit_num_, weight_quant_type, type_id_, i - 1);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
continue;
|
||||
|
@ -225,11 +228,11 @@ STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) {
|
|||
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false,
|
||||
type_id_, 0);
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, type_id_, 0);
|
||||
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
return RET_OK;
|
||||
|
@ -274,10 +277,10 @@ STATUS WeightQuantizer::DoOptimizerQuantize(const CNodePtr &cnode) {
|
|||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, type_id_, idx - 1);
|
||||
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false, type_id_, idx - 1);
|
||||
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
|
||||
}
|
||||
if (status != RET_OK && status != RET_CONTINUE) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
|
@ -344,11 +347,11 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
|
|||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
|
||||
type_id_, index - 1);
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true,
|
||||
type_id_, index - 1);
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
return RET_OK;
|
||||
|
@ -559,10 +562,10 @@ STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr ¶m
|
|||
|
||||
if (type_id_ == TypeId::kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
|
||||
bit_num_t, true, type_id_);
|
||||
bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
} else if (type_id_ == TypeId::kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
|
||||
bit_num_t, true, type_id_);
|
||||
bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -41,9 +41,9 @@ class WeightQuantizer : public Quantizer {
|
|||
~WeightQuantizer() override;
|
||||
|
||||
STATUS DoQuantize(FuncGraphPtr func_graph) override;
|
||||
STATUS DoConvQuantize(const CNodePtr &);
|
||||
STATUS DoMulQuantize(const CNodePtr &);
|
||||
STATUS DoOptimizerQuantize(const CNodePtr &);
|
||||
STATUS DoConvQuantize(const CNodePtr &cnode);
|
||||
STATUS DoMulQuantize(const CNodePtr &cnode);
|
||||
STATUS DoOptimizerQuantize(const CNodePtr &cnode);
|
||||
STATUS DoLstmQuantize(const CNodePtr &cnode);
|
||||
STATUS DoGatherQuantize(const CNodePtr &cnode);
|
||||
|
||||
|
@ -62,6 +62,7 @@ class WeightQuantizer : public Quantizer {
|
|||
PostQuantConfig config_param_;
|
||||
std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...]
|
||||
std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
|
||||
bool is_mixed_bit = false;
|
||||
|
||||
STATUS DoMixedQuant(const FuncGraphPtr &);
|
||||
STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr ¶m_node,
|
||||
|
@ -78,7 +79,6 @@ class WeightQuantizer : public Quantizer {
|
|||
STATUS TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, const tensor::TensorPtr &tensor_info,
|
||||
const PrimitivePtr &primitive);
|
||||
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
|
||||
STATUS DoTensorQuantize(const CNodePtr &);
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
|
||||
|
|
Loading…
Reference in New Issue