forked from mindspore-Ecosystem/mindspore
!46806 support perlayer && update_mindir
Merge pull request !46806 from yeyunpeng2020/master_quant_flag_ci_2
This commit is contained in: commit 49bf9a8ad2
@@ -1654,6 +1654,7 @@ GVAR_DEF(PrimitivePtr, kPrimEltwise, std::make_shared<Primitive>("Eltwise"));
 GVAR_DEF(PrimitivePtr, kPrimMatMulFusion, std::make_shared<Primitive>("MatMulFusion"));
 GVAR_DEF(PrimitivePtr, kPrimDynamicQuant, std::make_shared<Primitive>("DynamicQuant"));
 GVAR_DEF(PrimitivePtr, kPrimPartialFusion, std::make_shared<Primitive>("PartialFusion"));
+GVAR_DEF(PrimitivePtr, kPrimFSEDecode, std::make_shared<Primitive>("FSEDecode"));
 // Type introspection
 GVAR_DEF(PrimitivePtr, kPrimTypeOf, std::make_shared<Primitive>("typeof"));
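Note: later hunks in this change match the new primitive with the usual lookup pattern. A minimal sketch, assuming a CNodePtr cnode in scope (the same pattern this change already uses for kPrimQuantDTypeCast):

    // Sketch: recognizing the new operator in a graph pass.
    if (opt::CheckPrimitiveType(cnode, prim::kPrimFSEDecode)) {
      // cnode decodes an FSE-compressed weight on the fly.
    }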
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ops/fse_decode.h"
+#include "utils/check_convert_utils.h"
+#include "ops/op_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "mindapi/src/helper.h"
+
+namespace mindspore {
+namespace ops {
+MIND_API_OPERATOR_IMPL(FSEDecode, BaseOperator);
+abstract::ShapePtr FSEDecodeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
+  std::vector<int64_t> output_shape;
+  auto input_y = input_args[kInputIndex6];
+  MS_EXCEPTION_IF_NULL(input_y);
+  if (input_y->isa<abstract::AbstractTensor>()) {
+    auto y_value = input_y->BuildValue();
+    MS_EXCEPTION_IF_NULL(y_value);
+    abstract::ShapePtr y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
+    auto shape_value = y_shape->shape();
+    if (shape_value.size() != 1) {
+      MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the shape size must be 1, but got: " << shape_value.size()
+                              << ".";
+    }
+    if (y_shape->IsDynamic()) {
+      output_shape.push_back(abstract::Shape::kShapeRankAny);
+    } else {
+      output_shape = GetShapeValue(primitive, input_y);
+    }
+    return std::make_shared<abstract::Shape>(output_shape);
+  } else {
+    MS_EXCEPTION(TypeError) << "input_y must be an AbstractTensor, but got: " << input_y;
+  }
+  return std::make_shared<abstract::Shape>(output_shape);
+}
+
+TypePtr FSEDecodeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  auto dst_t = prim->GetAttr(kDstT);
+  return TypeIdToType(static_cast<TypeId>(GetValue<int64_t>(dst_t)));
+}
+
+void FSEDecode::set_dst_t(const int64_t dst_t) { (void)AddAttr(kDstT, api::MakeValue(dst_t)); }
+int64_t FSEDecode::get_dst_t() const { return GetValue<int64_t>(GetAttr(kDstT)); }
+
+void FSEDecode::set_curr_chunk(const int64_t curr_chunk) { (void)AddAttr(KCurrChunk, api::MakeValue(curr_chunk)); }
+int64_t FSEDecode::get_curr_chunk() const { return GetValue<int64_t>(GetAttr(KCurrChunk)); }
+
+void FSEDecode::set_curr_chunk_index(const int64_t curr_chunk_index) {
+  (void)AddAttr(KCurrChunkIndex, api::MakeValue(curr_chunk_index));
+}
+int64_t FSEDecode::get_curr_chunk_index() const { return GetValue<int64_t>(GetAttr(KCurrChunkIndex)); }
+
+void FSEDecode::set_curr_bit_count(const int64_t curr_bit_count) {
+  (void)AddAttr(KCurrBitCount, api::MakeValue(curr_bit_count));
+}
+int64_t FSEDecode::get_curr_bit_count() const { return GetValue<int64_t>(GetAttr(KCurrBitCount)); }
+
+void FSEDecode::set_table_log(const int64_t table_log) { (void)AddAttr(KTableLog, api::MakeValue(table_log)); }
+int64_t FSEDecode::get_table_log() const { return GetValue<int64_t>(GetAttr(KTableLog)); }
+
+void FSEDecode::Init(const int64_t dst_t, const int64_t curr_chunk, const int64_t curr_chunk_index,
+                     const int64_t curr_bit_count, const int64_t table_log) {
+  this->set_dst_t(dst_t);
+  this->set_curr_chunk(curr_chunk);
+  this->set_curr_chunk_index(curr_chunk_index);
+  this->set_curr_bit_count(curr_bit_count);
+  this->set_table_log(table_log);
+}
+
+AbstractBasePtr FSEDecodeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                               const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const int64_t kInputsNum = 7;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kInputsNum, primitive->name());
+  auto infertype = FSEDecodeInferType(primitive, input_args);
+  auto infershape = FSEDecodeInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infershape, infertype);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(FSEDecode, prim::kPrimFSEDecode, FSEDecodeInfer, nullptr, true);
+}  // namespace ops
+}  // namespace mindspore
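For orientation, a minimal sketch of the attribute API defined above; dst_t is any TypeId cast to int64_t, matching the TypeIdToType round-trip in FSEDecodeInferType (the values here are illustrative only):

    // Sketch: initialize an FSEDecode op and read an attribute back.
    ops::FSEDecode fse_decode;
    fse_decode.Init(/*dst_t=*/static_cast<int64_t>(kNumberTypeFloat32),
                    /*curr_chunk=*/0, /*curr_chunk_index=*/-1,
                    /*curr_bit_count=*/0, /*table_log=*/10);
    int64_t table_log = fse_decode.get_table_log();  // returns 10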
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_FSE_DECODER_H_
+#define MINDSPORE_CORE_OPS_FSE_DECODER_H_
+
+#include <map>
+#include <vector>
+#include <string>
+#include <memory>
+#include <algorithm>
+
+#include "ops/base_operator.h"
+#include "mindapi/base/types.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameFSEDecode = "FSEDecode";
+/// \brief FSEDecode defines the FSEDecode operator prototype.
+class MIND_API FSEDecode : public BaseOperator {
+ public:
+  MIND_API_BASE_MEMBER(FSEDecode);
+  /// \brief Constructor.
+  FSEDecode() : BaseOperator(kNameFSEDecode) {}
+
+  /// \brief Method to init the op's attributes.
+  ///
+  /// \param[in] dst_t Define the data type of output.
+  /// \param[in] curr_chunk Define the current chunk of the compressed bit stream.
+  /// \param[in] curr_chunk_index Define the index of the current chunk.
+  /// \param[in] curr_bit_count Define the bit count consumed in the current chunk.
+  /// \param[in] table_log Define the log2 size of the FSE decoding table.
+  void Init(const int64_t dst_t, const int64_t curr_chunk, const int64_t curr_chunk_index, const int64_t curr_bit_count,
+            const int64_t table_log);
+
+  /// \brief Method to set dst_t attribute.
+  ///
+  /// \param[in] dst_t Define the data type of output.
+  void set_dst_t(const int64_t dst_t);
+
+  /// \brief Method to get dst_t attribute.
+  ///
+  /// \return the data type of output.
+  int64_t get_dst_t() const;
+
+  /// \brief Method to set curr_chunk attribute.
+  ///
+  /// \param[in] curr_chunk Define the curr_chunk attribute.
+  void set_curr_chunk(const int64_t curr_chunk);
+
+  /// \brief Method to get curr_chunk attribute.
+  ///
+  /// \return the curr_chunk attribute.
+  int64_t get_curr_chunk() const;
+
+  /// \brief Method to set curr_chunk_index attribute.
+  ///
+  /// \param[in] curr_chunk_index Define the curr_chunk_index attribute.
+  void set_curr_chunk_index(const int64_t curr_chunk_index);
+
+  /// \brief Method to get curr_chunk_index attribute.
+  ///
+  /// \return the curr_chunk_index attribute.
+  int64_t get_curr_chunk_index() const;
+
+  /// \brief Method to set curr_bit_count attribute.
+  ///
+  /// \param[in] curr_bit_count Define the curr_bit_count attribute.
+  void set_curr_bit_count(const int64_t curr_bit_count);
+
+  /// \brief Method to get curr_bit_count attribute.
+  ///
+  /// \return the curr_bit_count attribute.
+  int64_t get_curr_bit_count() const;
+
+  /// \brief Method to set table_log attribute.
+  ///
+  /// \param[in] table_log Define the table_log attribute.
+  void set_table_log(const int64_t table_log);
+
+  /// \brief Method to get table_log attribute.
+  ///
+  /// \return the table_log attribute.
+  int64_t get_table_log() const;
+};
+abstract::AbstractBasePtr FSEDecodeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<abstract::AbstractBasePtr> &input_args);
+}  // namespace ops
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CORE_OPS_FSE_DECODER_H_
@@ -381,6 +381,10 @@ constexpr auto kConjugate = "conjugate";
 constexpr auto KExclusive = "exclusive";
 constexpr auto KReverse = "reverse";
 constexpr auto KComputeEigenvectors = "compute_eigenvectors";
+constexpr auto KCurrChunk = "curr_chunk";
+constexpr auto KCurrChunkIndex = "curr_chunk_index";
+constexpr auto KCurrBitCount = "curr_bit_count";
+constexpr auto KTableLog = "table_log";

 constexpr size_t kInputIndex0 = 0;
 constexpr size_t kInputIndex1 = 1;
@@ -565,7 +565,7 @@ STATUS AnfTransform::DoSingleGraphQATTransform(const FuncGraphPtr &func_graph,
                                                prim::kPrimSGD, prim::kPrimApplyMomentum};
   auto weight_quantizer = quant::WeightQuantizer();
   ret = weight_quantizer.WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types,
-                                     support_primitive_types, true, true, false);
+                                     support_primitive_types, false, true, false);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Run supplement weight quant param pass failed.";
     return ret;
@@ -621,8 +621,8 @@ STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::sha
 }

 int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
-  quant::QuantizationOptimizer optimizer(param);
-  auto ret = optimizer.Run(old_graph);
+  quant::QuantizationOptimizer quantization_optimizer(param);
+  auto ret = quantization_optimizer.Run(old_graph);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Post training quantization failed.";
     return ret;
@@ -797,6 +797,11 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
       MS_LOG(ERROR) << "Proc online transform failed.";
       return nullptr;
     }
+    auto status = DoQuantize(old_graph, param);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "Do Quantize failed.";
+      return nullptr;
+    }
     return old_graph;
   }
   auto value = old_graph->get_attr(kIsOptimized);
@@ -237,7 +237,8 @@ int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map
 int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
   if (maps.find(kWeightQuantParam) != maps.end()) {
     const auto &map = maps.at(kWeightQuantParam);
-    std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy}};
+    std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy},
+                                                   {"update_mindir", weight_quant_string_.update_mindir}};
     return SetMapData(map, parse_map, kWeightQuantParam);
   }
   return RET_OK;
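For context, the converter config that this parser consumes gains one key; a sketch, assuming kWeightQuantParam names the [weight_quant_param] section (the section string itself is not shown in this diff, and the values are illustrative):

    [weight_quant_param]
    dequant_strategy=ON_THE_FLY
    update_mindir=true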
@@ -55,6 +55,7 @@ struct MixedBitWeightQuantString {

 struct WeightQuantString {
   std::string dequant_strategy;
+  std::string update_mindir;
 };

 struct FullQuantString {
@@ -261,6 +261,11 @@ int QuantParamParser::ParseWeightQuant(const WeightQuantString &weight_quant_str
       return RET_INPUT_PARAM_INVALID;
     }
   }
+  if (!weight_quant_string.update_mindir.empty() &&
+      !ConvertBool(weight_quant_string.update_mindir, &weight_quant->update_mindir)) {
+    MS_LOG(ERROR) << "INPUT ILLEGAL: update_mindir should be true or false.";
+    return RET_INPUT_PARAM_INVALID;
+  }
   return RET_OK;
 }
 }  // namespace lite
@@ -279,6 +279,10 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const std::shared_ptr<ConverterPara

 FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr<ConverterPara> &param,
                                                         FuncGraphPtr func_graph) {
+  if (!param->weightQuantParam.update_mindir) {
+    MS_LOG(INFO) << "It will not update mindir.";
+    return func_graph;
+  }
   if (func_graph == nullptr) {
     MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
     MS_LOG(ERROR)
@@ -65,6 +65,84 @@ int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int
   return RET_OK;
 }

+int FSEDecoder::DecodeBuffer(int8_t *buffer, size_t data_size, FSEBuffer *fse_buffer) {
+  CHECK_NULL_RETURN(buffer);
+  CHECK_NULL_RETURN(fse_buffer);
+  if (data_size < sizeof(uint16_t)) {
+    MS_LOG(ERROR) << "data_size is invalid.";
+    return RET_ERROR;
+  }
+  size_t i = 0;
+  // 16bit for frequency_count
+  fse_buffer->frequency_count = *(reinterpret_cast<uint16_t *>(buffer + i));
+  i += sizeof(uint16_t);
+  if (i > data_size) {
+    MS_LOG(ERROR) << "index over total size"
+                  << " index:" << i << " total size:" << data_size;
+    return RET_ERROR;
+  }
+  // 16bit for table_log
+  fse_buffer->table_log = *(reinterpret_cast<uint16_t *>(buffer + i));
+  i += sizeof(uint16_t);
+  if (i > data_size) {
+    MS_LOG(ERROR) << "index over total size"
+                  << " index:" << i << " total size:" << data_size;
+    return RET_ERROR;
+  }
+  // 32bit for ChunkCount
+  fse_buffer->chunk_count = *(reinterpret_cast<uint32_t *>(buffer + i));
+  const size_t offset = 2;
+  // 32bit for CurrChunkIndex
+  fse_buffer->curr_chunk_index = fse_buffer->chunk_count - offset;
+  i += sizeof(uint32_t);
+  if (i > data_size) {
+    MS_LOG(ERROR) << "index over total size"
+                  << " index:" << i << " total size:" << data_size;
+    return RET_ERROR;
+  }
+  // 32bit * frequency_count for frequency
+  fse_buffer->frequency = reinterpret_cast<uint32_t *>(buffer + i);
+  i += fse_buffer->frequency_count * sizeof(uint32_t);
+  // Used for 8-byte(64bit) alignment
+  i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
+  if (i > data_size) {
+    MS_LOG(ERROR) << "index over total size"
+                  << " index:" << i << " total size:" << data_size;
+    return RET_ERROR;
+  }
+  // 32bit * frequency_count for centroids
+  fse_buffer->centroids = reinterpret_cast<void *>(buffer + i);
+  fse_buffer->centroid_size = fse_buffer->frequency_count * sizeof(float);
+  i += fse_buffer->centroid_size;
+  // Used for 8-byte(64bit) alignment
+  i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
+  if (i > data_size) {
+    MS_LOG(ERROR) << "index over total size"
+                  << " index:" << i << " total size:" << data_size;
+    return RET_ERROR;
+  }
+  // 64bit * bs_.GetCurrChunkIndex() + 1 for Chunks.
+  fse_buffer->chunks = reinterpret_cast<uint64_t *>(buffer + i);
+  fse_buffer->chunk_size = (fse_buffer->curr_chunk_index + 1) * sizeof(uint64_t);
+  i += fse_buffer->chunk_size;
+  if (i > data_size) {
+    MS_LOG(ERROR) << "index over total size"
+                  << " index:" << i << " total size:" << data_size;
+    return RET_ERROR;
+  }
+  // 64bit for CurrChunk
+  fse_buffer->curr_chunk = *(reinterpret_cast<uint64_t *>(buffer + i));
+  i += sizeof(uint64_t);
+  if (i > data_size) {
+    MS_LOG(ERROR) << "index over total size"
+                  << " index:" << i << " total size:" << data_size;
+    return RET_ERROR;
+  }
+  // 8bit for CurrBitCount
+  fse_buffer->curr_bit_count = *(reinterpret_cast<uint8_t *>(buffer + i));
+  return RET_OK;
+}
+
 int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor,
                            schema::WeightQuantCompressType compress_type) {
   CHECK_NULL_RETURN(src_tensor.handler());
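Read in order, DecodeBuffer implies the serialized layout below; this summary is a reading aid derived from the parsing code above, not part of the patch. The shift pair with kAlignOffset = 7 and kTableExtend = 3 is a standard align-up to 8 bytes: ((i + 7) >> 3) << 3 rounds i up to the next multiple of 8 (for example, i = 21 gives (28 >> 3) << 3 = 24).

    // Serialized FSE buffer layout, in DecodeBuffer's parsing order:
    //   uint16_t frequency_count
    //   uint16_t table_log
    //   uint32_t chunk_count              (curr_chunk_index = chunk_count - 2)
    //   uint32_t frequency[frequency_count]
    //   ...pad to an 8-byte boundary...
    //   float    centroids[frequency_count]
    //   ...pad to an 8-byte boundary...
    //   uint64_t chunks[curr_chunk_index + 1]
    //   uint64_t curr_chunk
    //   uint8_t  curr_bit_count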
@@ -24,6 +24,19 @@
 #include "src/litert/lite_model.h"

 namespace mindspore::lite::quant {
+struct FSEBuffer {
+  uint16_t frequency_count = 0;
+  size_t table_log = 0;
+  uint32_t chunk_count = 0;
+  int32_t curr_chunk_index = 0;
+  uint32_t *frequency = nullptr;
+  void *centroids = nullptr;
+  size_t centroid_size = 0;
+  uint64_t *chunks = nullptr;
+  size_t chunk_size = 0;
+  uint64_t curr_chunk = 0;
+  uint8_t curr_bit_count = 0;
+};
 class FSEDecoder {
  public:
  FSEDecoder() = default;
@@ -35,6 +48,8 @@ class FSEDecoder {
   static int FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count, size_t table_log,
                                         uint16_t *new_state_baseline, uint8_t *bit_count, uint16_t *symbol_table);

+  static int DecodeBuffer(int8_t *buffer, size_t data_size, FSEBuffer *fse_buffer);
+
  private:
   template <typename C_TYPE, typename OUT_TYPE>
   static int FSEDecode(FSEBitStream *bs, OUT_TYPE *buff, int buff_count, uint32_t *frequency, int frequency_count,
@@ -26,11 +26,16 @@
 #include "tools/optimizer/common/gllo_utils.h"
 #include "tools/optimizer/common/format_utils.h"
 #include "tools/common/node_util.h"
+#include "tools/common/tensor_util.h"
+#include "tools/converter/quantizer/fse_decoder.h"

 namespace mindspore::lite::quant {
 namespace {
 constexpr size_t kMinSize3 = 3;
 constexpr size_t kPrimitiveCOffset = 1;
+constexpr size_t kTableExtend = 3;
+constexpr size_t kAlignOffset = 7;
+constexpr size_t kInt32Mask = 31;
 }  // namespace
 int InsertQuantNodeManager::SetCastNodeAbstract(const CNodePtr &cnode, const AnfNodePtr &input_node,
                                                 const CNodePtr &cast_cnode) {
@@ -472,17 +477,41 @@ int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph
   }

   auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(primitive);
-  std::vector<schema::QuantParamT> input_quant_params;
-  if (curr_primitive_quant_param_holder->get_input_quant_params().size() >= input_index) {
-    input_quant_params = curr_primitive_quant_param_holder->get_input_quant_params().at(input_index - kPrimOffset);
+  if (curr_primitive_quant_param_holder == nullptr ||
+      curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) {
+    MS_LOG(ERROR) << input_node->fullname_with_scope() << " quant param is invalid.";
+    return RET_ERROR;
   }
+  auto input_quant_params = curr_primitive_quant_param_holder->get_input_quant_params().at(input_index - kPrimOffset);

   ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, {}, axis, false);
-  std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
+  std::vector<float> scales;
+  std::vector<int> zps;
+  std::vector<float> mean_corrs;
+  std::vector<float> var_corrs;
+  for (size_t i = 0; i < input_quant_params.size(); ++i) {
+    scales.push_back(static_cast<float>(input_quant_params.at(i).scale));
+    zps.push_back(static_cast<int64_t>(input_quant_params.at(i).zeroPoint));
+    mean_corrs.push_back(static_cast<float>(input_quant_params.at(i).meanCorr));
+    var_corrs.push_back(static_cast<float>(input_quant_params.at(i).varCorr));
+  }
+  auto scales_node = opt::BuildFloatVecParameterNode(func_graph, scales, "scales");
+  auto zps_node = opt::BuildIntVecParameterNode(func_graph, zps, "zps");
+  auto mean_corrs_node = opt::BuildFloatVecParameterNode(func_graph, mean_corrs, "mean_corrs");
+  auto var_corrs_node = opt::BuildFloatVecParameterNode(func_graph, var_corrs, "var_corrs");
+
+  std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node,      scales_node,
+                                       zps_node,      mean_corrs_node, var_corrs_node};
   auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
   CHECK_NULL_RETURN(quant_cast_cnode);
-  quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" +
-                                            std::to_string(input_index));
+  auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
+  int index = 0;
+  if (!ConvertIntNum(strings.at(strings.size() - 1), &index)) {
+    index = 0;
+  }
+  const int quant_dtype_cast_offset = 10000;
+  quant_cast_cnode->set_fullname_with_scope(strings.at(0) + "-QuantDtypeCast-op" +
+                                            std::to_string(index + quant_dtype_cast_offset));
   opt::NodeInferShape infer;
   auto status = infer.InferShape(quant_cast_cnode);
   if (status != RET_OK) {
@@ -502,4 +531,138 @@ int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph

   return RET_OK;
 }
+
+int InsertQuantNodeManager::InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                                size_t input_index, TypeId dst_dtype) {
+  auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
+  if (primitive == nullptr) {
+    MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
+    return RET_ERROR;
+  }
+  auto input_node = cnode->input(input_index);
+  if (!input_node->isa<mindspore::Parameter>()) {
+    MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
+    return RET_ERROR;
+  }
+  auto shape = input_node->Shape();
+  std::vector<AnfNodePtr> op_inputs;
+  int ret = CreateFSEInputs(func_graph, input_node, &op_inputs, dst_dtype);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "CreateFSEInputs failed.";
+    return RET_ERROR;
+  }
+
+  auto fse_decode_cnode = func_graph->NewCNode(op_inputs);
+  CHECK_NULL_RETURN(fse_decode_cnode);
+  auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
+  int index = 0;
+  if (!ConvertIntNum(strings.at(strings.size() - 1), &index)) {
+    index = 0;
+  }
+  const int fse_decode_offset = 20000;
+  fse_decode_cnode->set_fullname_with_scope(strings.at(0) + "-FSEDecode-op" +
+                                            std::to_string(index + fse_decode_offset));
+  CHECK_NULL_RETURN(cnode->abstract());
+  auto fse_abstract = cnode->abstract()->Clone();
+  fse_abstract->set_shape(shape);
+  fse_decode_cnode->set_abstract(fse_abstract);
+
+  auto manager = func_graph->manager();
+  CHECK_NULL_RETURN(manager);
+  ret = manager->Replace(input_node, fse_decode_cnode);
+  if (!ret) {
+    MS_LOG(ERROR) << "Replace QuantDtypeCast failed.";
+    return RET_ERROR;
+  }
+
+  return RET_OK;
+}
+
+int InsertQuantNodeManager::CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
+                                            std::vector<AnfNodePtr> *op_inputs, TypeId dst_dtype) {
+  if (!input_node->isa<mindspore::Parameter>()) {
+    MS_LOG(ERROR) << "FSEDecode input is not parameter node.";
+    return RET_ERROR;
+  }
+  auto parameter_ptr = input_node->cast<ParameterPtr>();
+  CHECK_NULL_RETURN(parameter_ptr);
+  if (!parameter_ptr->has_default()) {
+    MS_LOG(ERROR) << input_node->fullname_with_scope() << " parameter doesn't have default.";
+    return RET_ERROR;
+  }
+  auto tensor = parameter_ptr->default_param()->cast<tensor::TensorPtr>();
+  int8_t *data8 = reinterpret_cast<int8_t *>(tensor->data_c());
+  size_t data_size = tensor->DataSize();
+  FSEBuffer fse_buffer;
+  auto ret = FSEDecoder::DecodeBuffer(data8, data_size, &fse_buffer);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << input_node->fullname_with_scope() << " buffer decode failed.";
+    return RET_ERROR;
+  }
+  ValueNodePtr new_primitive = NewFSEDecodePrimitive(dst_dtype, fse_buffer.curr_chunk, fse_buffer.curr_chunk_index,
+                                                     fse_buffer.curr_bit_count, fse_buffer.table_log);
+  op_inputs->push_back(new_primitive);
+
+  // make shape to (1, chunk_size)
+  ShapeVector shape_vector;
+  shape_vector.push_back(1);
+  shape_vector.push_back(fse_buffer.chunk_size);
+  auto chunk_tensor_info =
+    lite::CreateTensorInfo(fse_buffer.chunks, fse_buffer.chunk_size, shape_vector, kNumberTypeInt8);
+  parameter_ptr->set_default_param(chunk_tensor_info);
+  parameter_ptr->set_abstract(chunk_tensor_info->ToAbstract());
+  op_inputs->push_back(input_node);
+
+  size_t table_size = 1u << fse_buffer.table_log;
+  uint16_t *states_table = static_cast<uint16_t *>(malloc(table_size * sizeof(uint16_t)));
+  CHECK_NULL_RETURN(states_table);
+  uint8_t *bit_count_table = static_cast<uint8_t *>(malloc(table_size * sizeof(uint8_t)));
+  CHECK_NULL_RETURN(bit_count_table);
+  uint16_t *symbol_table = static_cast<uint16_t *>(malloc(table_size * sizeof(uint16_t)));
+  CHECK_NULL_RETURN(symbol_table);
+
+  ret = FSEDecoder::FSECreateStatesForDecoding(fse_buffer.frequency, fse_buffer.frequency_count, fse_buffer.table_log,
+                                               states_table, bit_count_table, symbol_table);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "FSE create states for decoding failed.";
+    free(states_table);
+    free(bit_count_table);
+    free(symbol_table);
+    return RET_ERROR;
+  }
+  std::vector<int64_t> shape = {static_cast<int64_t>(table_size)};
+
+  auto states_table_tensor_info =
+    lite::CreateTensorInfo(states_table, sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
+  auto states_table_node = opt::BuildParameterNode(func_graph, states_table_tensor_info, "states_table");
+  op_inputs->push_back(states_table_node);
+
+  auto bit_count_table_tensor_info =
+    lite::CreateTensorInfo(bit_count_table, sizeof(uint8_t) * table_size, shape, kNumberTypeUInt8);
+  auto bit_count_table_node = opt::BuildParameterNode(func_graph, bit_count_table_tensor_info, "bit_count_table");
+  op_inputs->push_back(bit_count_table_node);
+
+  auto symbol_table_tensor_info =
+    lite::CreateTensorInfo(symbol_table, sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
+  auto symbol_table_node = opt::BuildParameterNode(func_graph, symbol_table_tensor_info, "symbol_table");
+  op_inputs->push_back(symbol_table_node);
+
+  auto centroids_tensor_info =
+    lite::CreateTensorInfo(fse_buffer.centroids, sizeof(float) * fse_buffer.centroid_size,
+                           {static_cast<int64_t>(fse_buffer.centroid_size)}, kNumberTypeFloat32);
+  auto centroids_node = opt::BuildParameterNode(func_graph, centroids_tensor_info, "centroids");
+  op_inputs->push_back(centroids_node);
+
+  auto shape_tensor_info = lite::CreateTensorInfo(ConvertShapeVectorToInt32(tensor->shape_c()).data(),
+                                                  sizeof(int32_t) * tensor->shape_c().size(),
+                                                  {static_cast<int64_t>(tensor->shape_c().size())}, kNumberTypeInt32);
+  auto shape_node = opt::BuildParameterNode(func_graph, shape_tensor_info, "input_shape");
+  op_inputs->push_back(shape_node);
+
+  // Free buffer
+  free(states_table);
+  free(bit_count_table);
+  free(symbol_table);
+  return RET_OK;
+}
 }  // namespace mindspore::lite::quant
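As a reading aid (derived from the push_back order in CreateFSEInputs, not part of the patch), the generated FSEDecode CNode is wired as:

    // op_inputs[0]: FSEDecode primitive (dst_t, curr_chunk, curr_chunk_index,
    //               curr_bit_count, table_log carried as attributes)
    // op_inputs[1]: the original weight Parameter, re-pointed at the raw chunk data
    // op_inputs[2]: states_table     (uint16, 1 << table_log entries)
    // op_inputs[3]: bit_count_table  (uint8,  1 << table_log entries)
    // op_inputs[4]: symbol_table     (uint16, 1 << table_log entries)
    // op_inputs[5]: centroids        (float)
    // op_inputs[6]: input_shape      (int32) -- the original weight shape, which
    //               InsertFSEDecodeNode also clones into the node's abstract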
@@ -49,6 +49,7 @@ class InsertQuantNodeManager {

   int InsertWeightQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId src_dtype,
                             TypeId dst_dtype, int axis);
+  int InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId dst_dtype);

  private:
   int CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) const;
@@ -68,6 +69,8 @@ class InsertQuantNodeManager {
                              const AnfNodePtr &output_node);
   int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction,
                          TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node);
+  int CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, std::vector<AnfNodePtr> *op_inputs,
+                      TypeId dst_dtype);

  private:
   TypeId dst_type_ = kNumberTypeInt8;
@@ -116,6 +116,7 @@ struct CommonQuantParam {

 struct WeightQuantParam {
   DequantStrategy dequant_strategy = DEFAULT;
+  bool update_mindir = true;
 };

 struct MixedBitWeightQuantParam {
@@ -22,6 +22,7 @@
 #include <deque>
 #include <map>
 #include <set>
+#include "tools/optimizer/graph/redundant_op_remove_pass.h"
 #include "tools/lite_exporter/fetch_content.h"
 #include "base/base.h"
 #include "tools/converter/quantizer/quantize_util.h"
@@ -237,27 +238,28 @@ int ConvertValueNodeToParameter(const FuncGraphPtr &func_graph) {
 }

 int PrepareQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
-  int status;
-
-  status = ConvertFp16ToFp32(old_graph);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Convert fp16 To fp32 failed.";
-    return status;
-  }
-
   if (!param->train_model) {
-    status = ConvertValueNodeToParameter(old_graph);
+    auto status = ConvertValueNodeToParameter(old_graph);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "Convert value node To parameter failed.";
       return status;
     }
   }
+
+  auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
+  convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  optimizer->AddPassManager(convert_pm);
+  if (optimizer->Optimize(old_graph) == nullptr) {
+    MS_LOG(ERROR) << "run graph pass failed";
+    return RET_ERROR;
+  }
+
   bool per_layer = param->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL &&
                    !param->fullQuantParam.per_channel && param->fullQuantParam.target_device != DSP;
   if (per_layer) {
     CLEStrategy cle_strategy(old_graph);
-    status = cle_strategy.Run();
+    auto status = cle_strategy.Run();
     if (status != RET_OK) {
       MS_LOG(ERROR) << "do pre process failed!";
       return status;
@@ -34,6 +34,7 @@
 #include "ops/fusion/conv2d_transpose_fusion.h"
 #include "ops/gather.h"
 #include "ops/op_utils.h"
+#include "ops/fse_decode.h"
 #include "src/common/utils.h"
 #include "src/litert/cxx_api/tensor/tensor_impl.h"
 #include "ir/anf.h"
@@ -165,6 +166,17 @@ ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
   return NewValueNode(prim);
 }

+ValueNodePtr NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk, const int64_t curr_chunk_index,
+                                   const int64_t curr_bit_count, const int64_t table_log) {
+  auto prim_c = std::make_shared<ops::FSEDecode>();
+  MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
+  prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log);
+
+  auto prim = prim_c->GetPrim();
+  MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
+  return NewValueNode(prim);
+}
+
 bool IsGraphInDTypeCast(const CNodePtr &cnode) {
   if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
     return false;
@@ -498,7 +510,7 @@ int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_
 int GetPreferredDim(const CNodePtr &cnode, int input_index, const std::vector<int> &dims) {
   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
   CHECK_NULL_RETURN(primitive);
-  if (primitive->name() == ops::kNameMatMulFusion) {
+  if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) {
     return GetMatMulPreferredDim(primitive, input_index, dims);
   } else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
     return GetDeConvPreferredDim(primitive, dims);
@@ -81,6 +81,9 @@ ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
                                    const std::vector<schema::QuantParamT> &output_quant_params, int axis = 0,
                                    bool set_quant_flag = true);

+ValueNodePtr NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk, const int64_t curr_chunk_index,
+                                   const int64_t curr_bit_count, const int64_t table_log);
+
 bool IsGraphInDTypeCast(const CNodePtr &cnode);

 bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
@@ -33,6 +33,33 @@
 #include "src/common/quant_utils.h"

 namespace mindspore::lite::quant {
+namespace {
+tensor::TensorPtr ConvertParameterFp16TensorToFp32(const ParameterPtr &parameter) {
+  if (!parameter->has_default()) {
+    MS_LOG(WARNING) << parameter->fullname_with_scope() << " not has_default";
+    return nullptr;
+  }
+  auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
+  if (tensor_info == nullptr) {
+    MS_LOG(WARNING) << "default_param can not cast to tensor::Tensor";
+    return nullptr;
+  }
+  if (tensor_info->data_type() == kNumberTypeFloat16) {
+    MS_LOG(INFO) << "convert " << parameter->fullname_with_scope() << " from fp16 to fp32.";
+    auto data = static_cast<float16 *>(tensor_info->data_c());
+    std::vector<float> fp32_data(tensor_info->DataSize());
+    for (size_t j = 0; j < tensor_info->DataSize(); j++) {
+      fp32_data[j] = mindspore::Float16::ToFloat32(data[j]);
+    }
+    mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
+      kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float));
+    parameter->set_default_param(tensor_ptr);
+    parameter->set_abstract(tensor_ptr->ToAbstract());
+    return tensor_ptr;
+  }
+  return tensor_info;
+}
+}  // namespace
 int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
                                  const std::set<PrimitivePtr> &support_weight_quant_types,
                                  const std::set<PrimitivePtr> &per_layer_types,
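This helper connects to the PrepareQuantize hunk earlier: the graph-wide ConvertFp16ToFp32 pass was dropped there, so fp16 weights are now widened lazily, one parameter at a time, at quantization time. A sketch of the call-site pattern, mirroring the LinearQuant hunk below:

    // Sketch: widen an fp16 weight in place before quantizing it.
    auto tensor_info = ConvertParameterFp16TensorToFp32(parameter);
    if (tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32) {
      continue;  // not a quantizable fp32 weight; skip this input
    }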
@@ -69,7 +96,7 @@ int WeightQuantizer::WeightQuantPerCNode(const FuncGraphPtr &func_graph, const C
                                          const std::set<PrimitivePtr> &per_layer_types,
                                          const std::set<PrimitivePtr> &symmetric_types, bool compression,
                                          bool update_tensor) {
-  auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
+  auto primitive = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
   if (primitive == nullptr) {
     MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
     return RET_OK;
@@ -121,9 +148,10 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
                                  const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
                                  bool compression, bool update_tensor) {
   CHECK_NULL_RETURN(cnode);
-  WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
+  // Avoid affecting other operators
+  auto tmp_weight_quant_type = weight_quant_type_;
   if (CheckNodeInSet(cnode, per_layer_types)) {
-    weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
+    tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
   }
   bool symmetric = false;
   int q_min = quant_min_;
@@ -143,11 +171,16 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
     ParameterPtr parameter;
     tensor::TensorPtr tensor_info;
     GetLiteParameter(input, &parameter, &tensor_info);
-    if (parameter == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32 ||
+    if (parameter == nullptr || tensor_info == nullptr ||
         tensor_info->compression_type() != mindspore::kNoCompression) {
       MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";
       continue;
     }
+    tensor_info = ConvertParameterFp16TensorToFp32(parameter);
+    if (tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32) {
+      MS_LOG(INFO) << "This op " << input->fullname_with_scope() << " is null or dtype is not fp32.";
+      continue;
+    }
     int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
     if (quant_strategy_ != nullptr && !quant_strategy_->CanTensorQuantized(cnode, input, preferred_dim)) {
       MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable";
@@ -156,7 +189,6 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
     // support for matmul shared weight
     auto node_map = manager->node_users();
     auto node_user = node_map[input];
-    auto tmp_weight_quant_type = weight_quant_type;
     if (node_user.size() > 1 && opt::CheckPrimitiveType(cnode, prim::kPrimMatMulFusion)) {
       MS_LOG(INFO) << input->fullname_with_scope() << " is shared weight.";
       tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
@@ -190,12 +222,41 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
       }
     }
     weight_quantized_tensors_.insert(tensor_info);
+    if (dequant_strategy_ == ON_THE_FLY) {
+      status = InsertDequantNode(func_graph, cnode, parameter, idx, tensor_info);
+      if (status == RET_NO_CHANGE) {
+        continue;
+      } else if (status != RET_OK) {
+        MS_LOG(ERROR) << cnode->fullname_with_scope() << " insert dequant node failed.";
+        return status;
+      }
+    }
   }
   return RET_OK;
 }

 int WeightQuantizer::DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter, int idx) {
   int ret = RET_OK;
+  if (dequant_strategy_ == ON_THE_FLY) {
+    if (bit_num_ < k8Bit) {
+      FSEEncoder fse_encoder;
+      auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
+      CHECK_NULL_RETURN(primitive);
+      auto quant_param_holder = GetCNodeQuantHolder(primitive);
+      auto tensor_quant_params = quant_param_holder->get_input_quant_params();
+      MS_CHECK_GT(static_cast<int>(tensor_quant_params.size()), idx - 1, RET_ERROR);
+      auto quant_params = tensor_quant_params.at(idx - 1);
+      mindspore::TensorCompressionType compress_type =
+        dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
+      ret = fse_encoder.Compress(parameter, quant_params, compress_type);
+      auto new_tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
+      CHECK_NULL_RETURN(new_tensor_info);
+      weight_quantized_tensors_.insert(new_tensor_info);
+      return ret;
+    } else {
+      return RET_OK;
+    }
+  }
   TensorCompressor compressor;
   auto quant_param_holder = GetCNodeQuantHolder(cnode);
   auto tensor_quant_params = quant_param_holder->get_input_quant_params();
@@ -239,7 +300,7 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
   MS_CHECK_GT(static_cast<int>(tensor_quant_params.size()), idx - 1, RET_ERROR);
   auto quant_params = tensor_quant_params.at(idx - 1);
   mindspore::TensorCompressionType compress_type =
-    param_->weightQuantParam.dequant_strategy == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
+    dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
   status = fse_encoder.Compress(parameter, quant_params, compress_type);
   if (status == RET_OK) {
     quant_param_holder->ClearQuantParams();
@@ -269,34 +330,40 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
   return status;
 }

-int WeightQuantizer::InsertQuantDtypeNode(const FuncGraphPtr &func_graph) {
+int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                       const ParameterPtr &parameter, int idx, const tensor::TensorPtr &tensor_info) {
+  InsertQuantNodeManager quant_manager;
   CHECK_NULL_RETURN(func_graph);
-  for (auto &cnode : func_graph->GetOrderedCnodes()) {
-    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
-    if (primitive == nullptr) {
-      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
-      continue;
-    }
-    // Support Share Weight Quant.
-    for (size_t i = kPrimOffset; i < cnode->size(); i++) {
-      auto inputNode = cnode->input(i);
-      if (inputNode->isa<Parameter>()) {
-        ParameterPtr param_node;
-        tensor::TensorPtr tensor_info;
-        GetLiteParameter(inputNode, &param_node, &tensor_info);
-        auto param = weight_quantized_tensors_.find(tensor_info);
-        if (param != weight_quantized_tensors_.end()) {
-          InsertQuantNodeManager manager;
-          auto ret = manager.InsertWeightQuantNode(
-            func_graph, cnode, i, kNumberTypeInt8, kNumberTypeFloat32,
-            GetPreferredDim(cnode, i - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c())));
-          if (ret != RET_OK) {
-            MS_LOG(ERROR) << "Insert weight quant node failed.";
-            return ret;
-          }
-          continue;
-        }
-      }
-    }
-  }
+  TypeId type_id;
+  int status;
+  auto tensor_name = parameter->fullname_with_scope();
+  if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
+    MS_LOG(WARNING) << cnode->fullname_with_scope() << " Get data type failed.";
+    return RET_NO_CHANGE;
+  }
+  if (parameter->has_default() &&
+      parameter->default_param()->cast<tensor::TensorPtr>()->compression_type() == mindspore::kFSEInfer) {
+    MS_LOG(INFO) << tensor_name << " insert FSEDecode node";
+    if (type_id == kNumberTypeFloat32) {
+      status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat32);
+    } else {
+      status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat16);
+    }
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << tensor_name << " insert FSEDecode node failed.";
+      return status;
+    }
+  } else {
+    MS_LOG(INFO) << tensor_name << " insert WeightQuant node";
+    auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()));
+    if (type_id == kNumberTypeFloat32) {
+      status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32, axis);
+    } else {
+      status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16, axis);
+    }
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
+      return status;
+    }
+  }
   return RET_OK;
@@ -387,9 +454,9 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     MS_LOG(ERROR) << "Weight Quant failed.";
     return ret;
   }
-  if (param_->weightQuantParam.dequant_strategy == ON_THE_FLY) {
-    return InsertQuantDtypeNode(func_graph);
-  }
-  return MarkGraphWeightQuantType(func_graph);
+  if (dequant_strategy_ != ON_THE_FLY) {
+    return MarkGraphWeightQuantType(func_graph);
+  }
+  return RET_OK;
 }
 }  // namespace mindspore::lite::quant
@@ -90,6 +90,10 @@ class WeightQuantizer : public Quantizer {
                    std::inserter(skip_quant_node_, skip_quant_node_.begin()));
     }
     quant_type_ = param_->commonQuantParam.quant_type;
+    dequant_strategy_ = param_->weightQuantParam.dequant_strategy;
+    if (param_->weightQuantParam.dequant_strategy == ON_THE_FLY) {
+      weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_LAYER;
+    }
   }

   ~WeightQuantizer() override = default;
@@ -117,7 +121,8 @@ class WeightQuantizer : public Quantizer {
                    int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true,
                    bool update_tensor = true);
   bool CheckWeightQuantExist(const CNodePtr &cnode);
-  int InsertQuantDtypeNode(const FuncGraphPtr &func_graph);
+  int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr &parameter, int idx,
+                        const tensor::TensorPtr &tensor_info);

  private:
   bool is_auto_tune_{false};
@@ -134,6 +139,8 @@ class WeightQuantizer : public Quantizer {
   std::unique_ptr<QuantStrategy> quant_strategy_;
   schema::QuantType quant_type_{schema::QuantType_WeightQuant};
   bool enable_encode_{true};
+  WeightQuantType weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_CHANNEL;
+  DequantStrategy dequant_strategy_ = DEFAULT;
   // Support for mark shared weight node.
   std::set<tensor::TensorPtr> weight_quantized_tensors_;
 };