!46806 support per-layer quantization && update_mindir

Merge pull request !46806 from yeyunpeng2020/master_quant_flag_ci_2
This commit is contained in:
i-robot 2022-12-16 01:57:29 +00:00 committed by Gitee
commit 49bf9a8ad2
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
19 changed files with 627 additions and 57 deletions

View File

@ -1654,6 +1654,7 @@ GVAR_DEF(PrimitivePtr, kPrimEltwise, std::make_shared<Primitive>("Eltwise"));
GVAR_DEF(PrimitivePtr, kPrimMatMulFusion, std::make_shared<Primitive>("MatMulFusion")); GVAR_DEF(PrimitivePtr, kPrimMatMulFusion, std::make_shared<Primitive>("MatMulFusion"));
GVAR_DEF(PrimitivePtr, kPrimDynamicQuant, std::make_shared<Primitive>("DynamicQuant")); GVAR_DEF(PrimitivePtr, kPrimDynamicQuant, std::make_shared<Primitive>("DynamicQuant"));
GVAR_DEF(PrimitivePtr, kPrimPartialFusion, std::make_shared<Primitive>("PartialFusion")); GVAR_DEF(PrimitivePtr, kPrimPartialFusion, std::make_shared<Primitive>("PartialFusion"));
GVAR_DEF(PrimitivePtr, kPrimFSEDecode, std::make_shared<Primitive>("FSEDecode"));
// Type introspection // Type introspection
GVAR_DEF(PrimitivePtr, kPrimTypeOf, std::make_shared<Primitive>("typeof")); GVAR_DEF(PrimitivePtr, kPrimTypeOf, std::make_shared<Primitive>("typeof"));

View File

@ -0,0 +1,98 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/fse_decode.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(FSEDecode, BaseOperator);
// Infers the output shape of FSEDecode. The output shape is carried by the
// 7th input (index 6) as a 1-D shape tensor; when that tensor's own shape is
// dynamic the output rank is reported as unknown.
abstract::ShapePtr FSEDecodeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
  std::vector<int64_t> output_shape;
  auto input_y = input_args[kInputIndex6];  // shape-carrying input
  MS_EXCEPTION_IF_NULL(input_y);
  if (input_y->isa<abstract::AbstractTensor>()) {
    auto y_value = input_y->BuildValue();
    MS_EXCEPTION_IF_NULL(y_value);  // only validated; the value itself is read via GetShapeValue below
    // NOTE(review): the shape tensor's shape is fetched from input index 1 here
    // while its value abstract comes from index 6 -- confirm which input is
    // intended to carry the output shape.
    abstract::ShapePtr y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
    auto shape_value = y_shape->shape();
    if (shape_value.size() != 1) {
      // The shape input must be a 1-D tensor (a flat list of dimensions).
      MS_EXCEPTION(TypeError) << "For '" << prim_name << "', the shape size must be 1, but got: " << shape_value.size()
                              << ".";
    }
    if (y_shape->IsDynamic()) {
      // Shape tensor length unknown at this point -> output rank unknown.
      output_shape.push_back(abstract::Shape::kShapeRankAny);
    } else {
      output_shape = GetShapeValue(primitive, input_y);
    }
    return std::make_shared<abstract::Shape>(output_shape);
  } else {
    MS_EXCEPTION(TypeError) << "input_y must be AbstractTensor" << input_y;
  }
  // Unreachable: the else branch above throws; kept to satisfy the compiler.
  return std::make_shared<abstract::Shape>(output_shape);
}
// Infers the output element type of FSEDecode from the "dst_t" attribute,
// which stores the destination TypeId as an int64 (see FSEDecode::set_dst_t).
TypePtr FSEDecodeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto dst_t = prim->GetAttr(kDstT);
  // The original dereferenced a possibly-missing attribute; fail loudly instead.
  MS_EXCEPTION_IF_NULL(dst_t);
  return TypeIdToType(static_cast<TypeId>(GetValue<int64_t>(dst_t)));
}
// Attribute accessors. Each setter stores the value on the underlying
// primitive; each getter assumes the attribute has already been set.
void FSEDecode::set_dst_t(const int64_t dst_t) { (void)AddAttr(kDstT, api::MakeValue(dst_t)); }
int64_t FSEDecode::get_dst_t() const { return GetValue<int64_t>(GetAttr(kDstT)); }
void FSEDecode::set_curr_chunk(const int64_t curr_chunk) { (void)AddAttr(KCurrChunk, api::MakeValue(curr_chunk)); }
int64_t FSEDecode::get_curr_chunk() const { return GetValue<int64_t>(GetAttr(KCurrChunk)); }
void FSEDecode::set_curr_chunk_index(const int64_t curr_chunk_index) {
  (void)AddAttr(KCurrChunkIndex, api::MakeValue(curr_chunk_index));
}
int64_t FSEDecode::get_curr_chunk_index() const { return GetValue<int64_t>(GetAttr(KCurrChunkIndex)); }
void FSEDecode::set_curr_bit_count(const int64_t curr_bit_count) {
  (void)AddAttr(KCurrBitCount, api::MakeValue(curr_bit_count));
}
int64_t FSEDecode::get_curr_bit_count() const { return GetValue<int64_t>(GetAttr(KCurrBitCount)); }
void FSEDecode::set_table_log(const int64_t table_log) { (void)AddAttr(KTableLog, api::MakeValue(table_log)); }
int64_t FSEDecode::get_table_log() const { return GetValue<int64_t>(GetAttr(KTableLog)); }
// Sets all decode-state attributes in one call; used when the primitive is
// built from a parsed FSE buffer (see NewFSEDecodePrimitive in the converter).
void FSEDecode::Init(const int64_t dst_t, const int64_t curr_chunk, const int64_t curr_chunk_index,
                     const int64_t curr_bit_count, const int64_t table_log) {
  this->set_dst_t(dst_t);
  this->set_curr_chunk(curr_chunk);
  this->set_curr_chunk_index(curr_chunk_index);
  this->set_curr_bit_count(curr_bit_count);
  this->set_table_log(table_log);
}
// Combined type/shape inference entry point registered for FSEDecode.
AbstractBasePtr FSEDecodeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  constexpr int64_t kFSEDecodeInputsNum = 7;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kFSEDecodeInputsNum, primitive->name());
  auto out_type = FSEDecodeInferType(primitive, input_args);
  auto out_shape = FSEDecodeInferShape(primitive, input_args);
  return abstract::MakeAbstract(out_shape, out_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(FSEDecode, prim::kPrimFSEDecode, FSEDecodeInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,100 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_FSE_DECODER_H_
#define MINDSPORE_CORE_OPS_FSE_DECODER_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameFSEDecode = "FSEDecode";
/// \brief FSEDecode defines the FSEDecode operator prototype, which decodes an
/// FSE (Finite State Entropy) compressed weight tensor back to plain values.
class MIND_API FSEDecode : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(FSEDecode);
  /// \brief Constructor.
  FSEDecode() : BaseOperator(kNameFSEDecode) {}
  /// \brief Method to init the op's attributes.
  ///
  /// \param[in] dst_t Define the data type of output.
  /// \param[in] curr_chunk Define the current 64-bit chunk of the compressed stream.
  /// \param[in] curr_chunk_index Define the index of the current chunk.
  /// \param[in] curr_bit_count Define the bit count consumed in the current chunk.
  /// \param[in] table_log Define the log2 size of the FSE state table.
  void Init(const int64_t dst_t, const int64_t curr_chunk, const int64_t curr_chunk_index, const int64_t curr_bit_count,
            const int64_t table_log);
  /// \brief Method to set dst_t attribute.
  ///
  /// \param[in] dst_t Define the data type of output.
  void set_dst_t(const int64_t dst_t);
  /// \brief Method to get dst_t attribute.
  ///
  /// \return the data type of output.
  int64_t get_dst_t() const;
  /// \brief Method to set curr_chunk attribute.
  ///
  /// \param[in] curr_chunk Define the curr_chunk attribute.
  void set_curr_chunk(const int64_t curr_chunk);
  /// \brief Method to get curr_chunk attribute.
  ///
  /// \return the curr_chunk attribute.
  int64_t get_curr_chunk() const;
  /// \brief Method to set curr_chunk_index attribute.
  ///
  /// \param[in] curr_chunk_index Define the curr_chunk_index attribute.
  void set_curr_chunk_index(const int64_t curr_chunk_index);
  /// \brief Method to get curr_chunk_index attribute.
  ///
  /// \return the curr_chunk_index attribute.
  int64_t get_curr_chunk_index() const;
  /// \brief Method to set curr_bit_count attribute.
  ///
  /// \param[in] curr_bit_count Define the curr_bit_count attribute.
  void set_curr_bit_count(const int64_t curr_bit_count);
  /// \brief Method to get curr_bit_count attribute.
  ///
  /// \return the curr_bit_count attribute.
  int64_t get_curr_bit_count() const;
  /// \brief Method to set table_log attribute.
  ///
  /// \param[in] table_log Define the table_log attribute.
  void set_table_log(const int64_t table_log);
  /// \brief Method to get table_log attribute.
  ///
  /// \return the table_log attribute.
  int64_t get_table_log() const;
};
abstract::AbstractBasePtr FSEDecodeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_FSE_DECODER_H_

View File

@ -381,6 +381,10 @@ constexpr auto kConjugate = "conjugate";
constexpr auto KExclusive = "exclusive"; constexpr auto KExclusive = "exclusive";
constexpr auto KReverse = "reverse"; constexpr auto KReverse = "reverse";
constexpr auto KComputeEigenvectors = "compute_eigenvectors"; constexpr auto KComputeEigenvectors = "compute_eigenvectors";
constexpr auto KCurrChunk = "curr_chunk";
constexpr auto KCurrChunkIndex = "curr_chunk_index";
constexpr auto KCurrBitCount = "curr_bit_count";
constexpr auto KTableLog = "table_log";
constexpr size_t kInputIndex0 = 0; constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1; constexpr size_t kInputIndex1 = 1;

View File

@ -565,7 +565,7 @@ STATUS AnfTransform::DoSingleGraphQATTransform(const FuncGraphPtr &func_graph,
prim::kPrimSGD, prim::kPrimApplyMomentum}; prim::kPrimSGD, prim::kPrimApplyMomentum};
auto weight_quantizer = quant::WeightQuantizer(); auto weight_quantizer = quant::WeightQuantizer();
ret = weight_quantizer.WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types, ret = weight_quantizer.WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types,
support_primitive_types, true, true, false); support_primitive_types, false, true, false);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Run supplement weight quant param pass failed."; MS_LOG(ERROR) << "Run supplement weight quant param pass failed.";
return ret; return ret;
@ -621,8 +621,8 @@ STATUS AnfTransform::QATTransform(const FuncGraphPtr &func_graph, const std::sha
} }
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) { int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
quant::QuantizationOptimizer optimizer(param); quant::QuantizationOptimizer quantization_optimizer(param);
auto ret = optimizer.Run(old_graph); auto ret = quantization_optimizer.Run(old_graph);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Post training quantization failed."; MS_LOG(ERROR) << "Post training quantization failed.";
return ret; return ret;
@ -797,6 +797,11 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
MS_LOG(ERROR) << "Proc online transform failed."; MS_LOG(ERROR) << "Proc online transform failed.";
return nullptr; return nullptr;
} }
auto status = DoQuantize(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";
return nullptr;
}
return old_graph; return old_graph;
} }
auto value = old_graph->get_attr(kIsOptimized); auto value = old_graph->get_attr(kIsOptimized);

View File

@ -237,7 +237,8 @@ int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map
int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) { int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
if (maps.find(kWeightQuantParam) != maps.end()) { if (maps.find(kWeightQuantParam) != maps.end()) {
const auto &map = maps.at(kWeightQuantParam); const auto &map = maps.at(kWeightQuantParam);
std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy}}; std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy},
{"update_mindir", weight_quant_string_.update_mindir}};
return SetMapData(map, parse_map, kWeightQuantParam); return SetMapData(map, parse_map, kWeightQuantParam);
} }
return RET_OK; return RET_OK;

View File

@ -55,6 +55,7 @@ struct MixedBitWeightQuantString {
struct WeightQuantString { struct WeightQuantString {
std::string dequant_strategy; std::string dequant_strategy;
std::string update_mindir;
}; };
struct FullQuantString { struct FullQuantString {

View File

@ -261,6 +261,11 @@ int QuantParamParser::ParseWeightQuant(const WeightQuantString &weight_quant_str
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
} }
if (!weight_quant_string.update_mindir.empty() &&
!ConvertBool(weight_quant_string.update_mindir, &weight_quant->update_mindir)) {
MS_LOG(ERROR) << "INPUT ILLEGAL: update_mindir should be true or false.";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK; return RET_OK;
} }
} // namespace lite } // namespace lite

View File

@ -279,6 +279,10 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const std::shared_ptr<ConverterPara
FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr<ConverterPara> &param, FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr<ConverterPara> &param,
FuncGraphPtr func_graph) { FuncGraphPtr func_graph) {
if (!param->weightQuantParam.update_mindir) {
MS_LOG(INFO) << "It will not update mindir.";
return func_graph;
}
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR"; MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
MS_LOG(ERROR) MS_LOG(ERROR)

View File

@ -65,6 +65,84 @@ int FSEDecoder::FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int
return RET_OK; return RET_OK;
} }
int FSEDecoder::DecodeBuffer(int8_t *buffer, size_t data_size, FSEBuffer *fse_buffer) {
CHECK_NULL_RETURN(buffer);
CHECK_NULL_RETURN(fse_buffer);
if (data_size < sizeof(uint16_t)) {
MS_LOG(ERROR) << "data_size is invalid.";
return RET_ERROR;
}
size_t i = 0;
// 16bit for frequency_count
fse_buffer->frequency_count = *(reinterpret_cast<uint16_t *>(buffer + i));
i += sizeof(uint16_t);
if (i > data_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << data_size;
return RET_ERROR;
}
// 16bit for table_log
fse_buffer->table_log = *(reinterpret_cast<uint16_t *>(buffer + i));
i += sizeof(uint16_t);
if (i > data_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << data_size;
return RET_ERROR;
}
// 32bit for ChunkCount
fse_buffer->chunk_count = *(reinterpret_cast<uint32_t *>(buffer + i));
const size_t offset = 2;
// 32bit for CurrChunkIndex
fse_buffer->curr_chunk_index = fse_buffer->chunk_count - offset;
i += sizeof(uint32_t);
if (i > data_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << data_size;
return RET_ERROR;
}
// 32bit * frequency_count for frequency
fse_buffer->frequency = reinterpret_cast<uint32_t *>(buffer + i);
i += fse_buffer->frequency_count * sizeof(uint32_t);
// Used for 8-byte(64bit) alignment
i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
if (i > data_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << data_size;
return RET_ERROR;
}
// 32bit * frequency_count for centroids
fse_buffer->centroids = reinterpret_cast<void *>(buffer + i);
fse_buffer->centroid_size = fse_buffer->frequency_count * sizeof(float);
i += fse_buffer->centroid_size;
// Used for 8-byte(64bit) alignment
i = ((i + kAlignOffset) >> kTableExtend) << kTableExtend;
if (i > data_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << data_size;
return RET_ERROR;
}
// 64bit * bs_.GetCurrChunkIndex() + 1 for Chunks.
fse_buffer->chunks = reinterpret_cast<uint64_t *>(buffer + i);
fse_buffer->chunk_size = (fse_buffer->curr_chunk_index + 1) * sizeof(uint64_t);
i += fse_buffer->chunk_size;
if (i > data_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << data_size;
return RET_ERROR;
}
// 64bit for CurrChunk
fse_buffer->curr_chunk = *(reinterpret_cast<uint64_t *>(buffer + i));
i += sizeof(uint64_t);
if (i > data_size) {
MS_LOG(ERROR) << "index over total size"
<< " index:" << i << " total size:" << data_size;
return RET_ERROR;
}
// 8bit for CurrBitCount
fse_buffer->curr_bit_count = *(reinterpret_cast<uint8_t *>(buffer + i));
return RET_OK;
}
int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor, int FSEDecoder::DeCompress(const SchemaTensorWrapper &src_tensor, Tensor *dst_tensor,
schema::WeightQuantCompressType compress_type) { schema::WeightQuantCompressType compress_type) {
CHECK_NULL_RETURN(src_tensor.handler()); CHECK_NULL_RETURN(src_tensor.handler());

View File

@ -24,6 +24,19 @@
#include "src/litert/lite_model.h" #include "src/litert/lite_model.h"
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
// View over a serialized FSE-compressed tensor buffer, filled in by
// FSEDecoder::DecodeBuffer. Pointer members alias the caller's buffer and are
// NOT owned by this struct; they remain valid only while that buffer lives.
struct FSEBuffer {
  uint16_t frequency_count = 0;   // number of symbol-frequency entries
  size_t table_log = 0;           // log2 of the FSE state-table size
  uint32_t chunk_count = 0;       // total number of 64-bit chunks in the stream
  int32_t curr_chunk_index = 0;   // starting chunk index (chunk_count - 2)
  uint32_t *frequency = nullptr;  // symbol frequencies, frequency_count entries (aliases input)
  void *centroids = nullptr;      // quantization centroids (aliases input)
  size_t centroid_size = 0;       // centroid block size in BYTES
  uint64_t *chunks = nullptr;     // compressed bit-stream chunks (aliases input)
  size_t chunk_size = 0;          // chunk block size in BYTES
  uint64_t curr_chunk = 0;        // chunk value the decoder starts from
  uint8_t curr_bit_count = 0;     // bits consumed from curr_chunk -- TODO confirm semantics
};
class FSEDecoder { class FSEDecoder {
public: public:
FSEDecoder() = default; FSEDecoder() = default;
@ -35,6 +48,8 @@ class FSEDecoder {
static int FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count, size_t table_log, static int FSECreateStatesForDecoding(const uint32_t *symbol_frequency, int symbol_frequency_count, size_t table_log,
uint16_t *new_state_baseline, uint8_t *bit_count, uint16_t *symbol_table); uint16_t *new_state_baseline, uint8_t *bit_count, uint16_t *symbol_table);
static int DecodeBuffer(int8_t *buffer, size_t data_size, FSEBuffer *fse_buffer);
private: private:
template <typename C_TYPE, typename OUT_TYPE> template <typename C_TYPE, typename OUT_TYPE>
static int FSEDecode(FSEBitStream *bs, OUT_TYPE *buff, int buff_count, uint32_t *frequency, int frequency_count, static int FSEDecode(FSEBitStream *bs, OUT_TYPE *buff, int buff_count, uint32_t *frequency, int frequency_count,

View File

@ -26,11 +26,16 @@
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h" #include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h" #include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/fse_decoder.h"
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
namespace { namespace {
constexpr size_t kMinSize3 = 3; constexpr size_t kMinSize3 = 3;
constexpr size_t kPrimitiveCOffset = 1; constexpr size_t kPrimitiveCOffset = 1;
constexpr size_t kTableExtend = 3;
constexpr size_t kAlignOffset = 7;
constexpr size_t kInt32Mask = 31;
} // namespace } // namespace
int InsertQuantNodeManager::SetCastNodeAbstract(const CNodePtr &cnode, const AnfNodePtr &input_node, int InsertQuantNodeManager::SetCastNodeAbstract(const CNodePtr &cnode, const AnfNodePtr &input_node,
const CNodePtr &cast_cnode) { const CNodePtr &cast_cnode) {
@ -472,17 +477,41 @@ int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph
} }
auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(primitive); auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(primitive);
std::vector<schema::QuantParamT> input_quant_params; if (curr_primitive_quant_param_holder == nullptr ||
if (curr_primitive_quant_param_holder->get_input_quant_params().size() >= input_index) { curr_primitive_quant_param_holder->get_input_quant_params().size() < input_index) {
input_quant_params = curr_primitive_quant_param_holder->get_input_quant_params().at(input_index - kPrimOffset); MS_LOG(ERROR) << input_node->fullname_with_scope() << " quant param is invalid.";
return RET_ERROR;
} }
auto input_quant_params = curr_primitive_quant_param_holder->get_input_quant_params().at(input_index - kPrimOffset);
ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, {}, axis, false); ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, {}, axis, false);
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node}; std::vector<float> scales;
std::vector<int> zps;
std::vector<float> mean_corrs;
std::vector<float> var_corrs;
for (size_t i = 0; i < input_quant_params.size(); ++i) {
scales.push_back(static_cast<float>(input_quant_params.at(i).scale));
zps.push_back(static_cast<int64_t>(input_quant_params.at(i).zeroPoint));
mean_corrs.push_back(static_cast<float>(input_quant_params.at(i).meanCorr));
var_corrs.push_back(static_cast<float>(input_quant_params.at(i).varCorr));
}
auto scales_node = opt::BuildFloatVecParameterNode(func_graph, scales, "scales");
auto zps_node = opt::BuildIntVecParameterNode(func_graph, zps, "zps");
auto mean_corrs_node = opt::BuildFloatVecParameterNode(func_graph, mean_corrs, "mean_corrs");
auto var_corrs_node = opt::BuildFloatVecParameterNode(func_graph, var_corrs, "var_corrs");
std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node, scales_node,
zps_node, mean_corrs_node, var_corrs_node};
auto quant_cast_cnode = func_graph->NewCNode(op_inputs); auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
CHECK_NULL_RETURN(quant_cast_cnode); CHECK_NULL_RETURN(quant_cast_cnode);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast_" + auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
std::to_string(input_index)); int index = 0;
if (!ConvertIntNum(strings.at(strings.size() - 1), &index)) {
index = 0;
}
const int quant_dtype_cast_offset = 10000;
quant_cast_cnode->set_fullname_with_scope(strings.at(0) + "-QuantDtypeCast-op" +
std::to_string(index + quant_dtype_cast_offset));
opt::NodeInferShape infer; opt::NodeInferShape infer;
auto status = infer.InferShape(quant_cast_cnode); auto status = infer.InferShape(quant_cast_cnode);
if (status != RET_OK) { if (status != RET_OK) {
@ -502,4 +531,138 @@ int InsertQuantNodeManager::InsertWeightQuantNode(const FuncGraphPtr &func_graph
return RET_OK; return RET_OK;
} }
// Replaces the weight parameter at `input_index` of `cnode` with an FSEDecode
// CNode that decompresses the FSE-encoded buffer at runtime; the decode output
// keeps the original weight's shape and the consumer's abstract.
int InsertQuantNodeManager::InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                size_t input_index, TypeId dst_dtype) {
  auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
  if (primitive == nullptr) {  // validation only; primitive is not used below
    MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  auto input_node = cnode->input(input_index);
  if (!input_node->isa<mindspore::Parameter>()) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
    return RET_ERROR;
  }
  // Shape of the original (uncompressed) weight, restored on the decode output.
  auto shape = input_node->Shape();
  std::vector<AnfNodePtr> op_inputs;
  int ret = CreateFSEInputs(func_graph, input_node, &op_inputs, dst_dtype);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "CreateFSEInputs failed.";
    return RET_ERROR;
  }
  auto fse_decode_cnode = func_graph->NewCNode(op_inputs);
  CHECK_NULL_RETURN(fse_decode_cnode);
  // Name the node "<prefix>-FSEDecode-op<N + 20000>", where N is the consumer's
  // trailing "-op" number (0 if unparsable); the offset avoids name collisions.
  auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
  int index = 0;
  if (!ConvertIntNum(strings.at(strings.size() - 1), &index)) {
    index = 0;
  }
  const int fse_decode_offset = 20000;
  fse_decode_cnode->set_fullname_with_scope(strings.at(0) + "-FSEDecode-op" +
                                            std::to_string(index + fse_decode_offset));
  CHECK_NULL_RETURN(cnode->abstract());
  auto fse_abstract = cnode->abstract()->Clone();
  fse_abstract->set_shape(shape);
  fse_decode_cnode->set_abstract(fse_abstract);
  auto manager = func_graph->manager();
  CHECK_NULL_RETURN(manager);
  // NOTE(review): Replace returns bool, stored in the int `ret` declared above
  // and tested with `!ret` -- works, but confirm a dedicated bool was intended.
  ret = manager->Replace(input_node, fse_decode_cnode);
  if (!ret) {
    MS_LOG(ERROR) << "Replace QuantDtypeCast failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
int InsertQuantNodeManager::CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
std::vector<AnfNodePtr> *op_inputs, TypeId dst_dtype) {
if (!input_node->isa<mindspore::Parameter>()) {
MS_LOG(ERROR) << "FSEDecode input is not parameter node.";
return RET_ERROR;
}
auto parameter_ptr = input_node->cast<ParameterPtr>();
CHECK_NULL_RETURN(parameter_ptr);
if (!parameter_ptr->has_default()) {
MS_LOG(ERROR) << input_node->fullname_with_scope() << " parameter dont have default.";
return RET_ERROR;
}
auto tensor = parameter_ptr->default_param()->cast<tensor::TensorPtr>();
int8_t *data8 = reinterpret_cast<int8_t *>(tensor->data_c());
size_t data_size = tensor->DataSize();
FSEBuffer fse_buffer;
auto ret = FSEDecoder::DecodeBuffer(data8, data_size, &fse_buffer);
if (ret != RET_OK) {
MS_LOG(ERROR) << input_node->fullname_with_scope() << " buffer decode failed.";
return RET_ERROR;
}
ValueNodePtr new_primitive = NewFSEDecodePrimitive(dst_dtype, fse_buffer.curr_chunk, fse_buffer.curr_chunk_index,
fse_buffer.curr_bit_count, fse_buffer.table_log);
op_inputs->push_back(new_primitive);
// make shape to (1,chunk_size)
ShapeVector shape_vector;
shape_vector.push_back(1);
shape_vector.push_back(fse_buffer.chunk_size);
auto chunk_tensor_info =
lite::CreateTensorInfo(fse_buffer.chunks, fse_buffer.chunk_size, shape_vector, kNumberTypeInt8);
parameter_ptr->set_default_param(chunk_tensor_info);
parameter_ptr->set_abstract(chunk_tensor_info->ToAbstract());
op_inputs->push_back(input_node);
size_t table_size = 1u << fse_buffer.table_log;
uint16_t *states_table = static_cast<uint16_t *>(malloc(table_size * sizeof(uint16_t)));
CHECK_NULL_RETURN(states_table);
uint8_t *bit_count_table = static_cast<uint8_t *>(malloc(table_size * sizeof(uint8_t)));
CHECK_NULL_RETURN(bit_count_table);
uint16_t *symbol_table = static_cast<uint16_t *>(malloc(table_size * sizeof(uint16_t)));
CHECK_NULL_RETURN(symbol_table);
ret = FSEDecoder::FSECreateStatesForDecoding(fse_buffer.frequency, fse_buffer.frequency_count, fse_buffer.table_log,
states_table, bit_count_table, symbol_table);
if (ret != RET_OK) {
MS_LOG(ERROR) << "FSE create states for decoding failed.";
free(states_table);
free(bit_count_table);
free(symbol_table);
return RET_ERROR;
}
std::vector<int64_t> shape = {static_cast<int64_t>(table_size)};
auto states_table_tensor_info =
lite::CreateTensorInfo(states_table, sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
auto states_table_node = opt::BuildParameterNode(func_graph, states_table_tensor_info, "states_table");
op_inputs->push_back(states_table_node);
auto bit_count_table_tensor_info =
lite::CreateTensorInfo(bit_count_table, sizeof(uint8_t) * table_size, shape, kNumberTypeUInt8);
auto bit_count_table_node = opt::BuildParameterNode(func_graph, bit_count_table_tensor_info, "bit_count_table");
op_inputs->push_back(bit_count_table_node);
auto symbol_table_tensor_info =
lite::CreateTensorInfo(symbol_table, sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
auto symbol_table_node = opt::BuildParameterNode(func_graph, symbol_table_tensor_info, "symbol_table");
op_inputs->push_back(symbol_table_node);
auto centroids_tensor_info =
lite::CreateTensorInfo(fse_buffer.centroids, sizeof(float) * fse_buffer.centroid_size,
{static_cast<int64_t>(fse_buffer.centroid_size)}, kNumberTypeFloat32);
auto centroids_node = opt::BuildParameterNode(func_graph, centroids_tensor_info, "centroids");
op_inputs->push_back(centroids_node);
auto shape_tensor_info = lite::CreateTensorInfo(ConvertShapeVectorToInt32(tensor->shape_c()).data(),
sizeof(int32_t) * tensor->shape_c().size(),
{static_cast<int64_t>(tensor->shape_c().size())}, kNumberTypeInt32);
auto shape_node = opt::BuildParameterNode(func_graph, shape_tensor_info, "input_shape");
op_inputs->push_back(shape_node);
// Free buffer
free(states_table);
free(bit_count_table);
free(symbol_table);
return RET_OK;
}
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant

View File

@ -49,6 +49,7 @@ class InsertQuantNodeManager {
int InsertWeightQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId src_dtype, int InsertWeightQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId src_dtype,
TypeId dst_dtype, int axis); TypeId dst_dtype, int axis);
int InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, TypeId dst_dtype);
private: private:
int CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) const; int CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) const;
@ -68,6 +69,8 @@ class InsertQuantNodeManager {
const AnfNodePtr &output_node); const AnfNodePtr &output_node);
int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction, int InserQuantCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, InsertDirection insert_direction,
TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node); TypeId cast_dtype, CastNodeType cast_node_type, size_t index, const AnfNodePtr &output_node);
int CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, std::vector<AnfNodePtr> *op_inputs,
TypeId dst_dtype);
private: private:
TypeId dst_type_ = kNumberTypeInt8; TypeId dst_type_ = kNumberTypeInt8;

View File

@ -116,6 +116,7 @@ struct CommonQuantParam {
struct WeightQuantParam { struct WeightQuantParam {
DequantStrategy dequant_strategy = DEFAULT; DequantStrategy dequant_strategy = DEFAULT;
bool update_mindir = true;
}; };
struct MixedBitWeightQuantParam { struct MixedBitWeightQuantParam {

View File

@ -22,6 +22,7 @@
#include <deque> #include <deque>
#include <map> #include <map>
#include <set> #include <set>
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/lite_exporter/fetch_content.h" #include "tools/lite_exporter/fetch_content.h"
#include "base/base.h" #include "base/base.h"
#include "tools/converter/quantizer/quantize_util.h" #include "tools/converter/quantizer/quantize_util.h"
@ -237,27 +238,28 @@ int ConvertValueNodeToParameter(const FuncGraphPtr &func_graph) {
} }
int PrepareQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) { int PrepareQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
int status;
status = ConvertFp16ToFp32(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Convert fp16 To fp32 failed.";
return status;
}
if (!param->train_model) { if (!param->train_model) {
status = ConvertValueNodeToParameter(old_graph); auto status = ConvertValueNodeToParameter(old_graph);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Convert value node To parameter failed."; MS_LOG(ERROR) << "Convert value node To parameter failed.";
return status; return status;
} }
} }
auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
auto optimizer = std::make_shared<opt::GraphOptimizer>();
optimizer->AddPassManager(convert_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run graph pass failed";
return RET_ERROR;
}
bool per_layer = param->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL && bool per_layer = param->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL &&
!param->fullQuantParam.per_channel && param->fullQuantParam.target_device != DSP; !param->fullQuantParam.per_channel && param->fullQuantParam.target_device != DSP;
if (per_layer) { if (per_layer) {
CLEStrategy cle_strategy(old_graph); CLEStrategy cle_strategy(old_graph);
status = cle_strategy.Run(); auto status = cle_strategy.Run();
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "do pre process failed!"; MS_LOG(ERROR) << "do pre process failed!";
return status; return status;

View File

@ -34,6 +34,7 @@
#include "ops/fusion/conv2d_transpose_fusion.h" #include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/gather.h" #include "ops/gather.h"
#include "ops/op_utils.h" #include "ops/op_utils.h"
#include "ops/fse_decode.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/litert/cxx_api/tensor/tensor_impl.h" #include "src/litert/cxx_api/tensor/tensor_impl.h"
#include "ir/anf.h" #include "ir/anf.h"
@ -165,6 +166,17 @@ ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
return NewValueNode(prim); return NewValueNode(prim);
} }
// Builds a ValueNode wrapping a freshly constructed FSEDecode primitive.
// The primitive is initialized via Init() with the target output dtype and the
// FSE decoder bit-stream state: the current chunk word, its index, the number
// of bits still pending in that chunk, and the FSE table log size.
// Returns nullptr if either the primitive wrapper or its underlying Primitive
// cannot be created.
ValueNodePtr NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk, const int64_t curr_chunk_index,
const int64_t curr_bit_count, const int64_t table_log) {
auto prim_c = std::make_shared<ops::FSEDecode>();
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
// Init forwards the decode-state arguments onto the primitive's attributes.
prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log);
auto prim = prim_c->GetPrim();
MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
return NewValueNode(prim);
}
bool IsGraphInDTypeCast(const CNodePtr &cnode) { bool IsGraphInDTypeCast(const CNodePtr &cnode) {
if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) { if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
return false; return false;
@ -498,7 +510,7 @@ int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_
int GetPreferredDim(const CNodePtr &cnode, int input_index, const std::vector<int> &dims) { int GetPreferredDim(const CNodePtr &cnode, int input_index, const std::vector<int> &dims) {
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
CHECK_NULL_RETURN(primitive); CHECK_NULL_RETURN(primitive);
if (primitive->name() == ops::kNameMatMulFusion) { if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) {
return GetMatMulPreferredDim(primitive, input_index, dims); return GetMatMulPreferredDim(primitive, input_index, dims);
} else if (primitive->name() == ops::kNameConv2dTransposeFusion) { } else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
return GetDeConvPreferredDim(primitive, dims); return GetDeConvPreferredDim(primitive, dims);

View File

@ -81,6 +81,9 @@ ValueNodePtr NewQuantCastPrimitive(int src_type, int dst_type,
const std::vector<schema::QuantParamT> &output_quant_params, int axis = 0, const std::vector<schema::QuantParamT> &output_quant_params, int axis = 0,
bool set_quant_flag = true); bool set_quant_flag = true);
ValueNodePtr NewFSEDecodePrimitive(int dst_type, const uint64_t curr_chunk, const int64_t curr_chunk_index,
const int64_t curr_bit_count, const int64_t table_log);
bool IsGraphInDTypeCast(const CNodePtr &cnode); bool IsGraphInDTypeCast(const CNodePtr &cnode);
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode); bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode);

View File

@ -33,6 +33,33 @@
#include "src/common/quant_utils.h" #include "src/common/quant_utils.h"
namespace mindspore::lite::quant { namespace mindspore::lite::quant {
namespace {
// If the parameter's default tensor holds fp16 data, converts it element-wise
// to fp32, replaces the parameter's default param and abstract with the new
// fp32 tensor, and returns that tensor. If the tensor is already some other
// dtype it is returned unchanged. Returns nullptr when the parameter has no
// default value or its default cannot be cast to tensor::Tensor.
// NOTE(review): assumes `parameter` itself is non-null — callers appear to
// null-check before invoking; confirm at each call site.
tensor::TensorPtr ConvertParameterFp16TensorToFp32(const ParameterPtr &parameter) {
if (!parameter->has_default()) {
MS_LOG(WARNING) << parameter->fullname_with_scope() << " not has_default";
return nullptr;
}
auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
if (tensor_info == nullptr) {
MS_LOG(WARNING) << "default_param can not cast to tensor::Tensor";
return nullptr;
}
if (tensor_info->data_type() == kNumberTypeFloat16) {
MS_LOG(INFO) << "convert " << parameter->fullname_with_scope() << " from fp16 to fp32.";
auto data = static_cast<float16 *>(tensor_info->data_c());
// Widen every fp16 element into a temporary fp32 buffer, then build a new
// fp32 tensor with the same shape from that buffer.
std::vector<float> fp32_data(tensor_info->DataSize());
for (size_t j = 0; j < tensor_info->DataSize(); j++) {
fp32_data[j] = mindspore::Float16::ToFloat32(data[j]);
}
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float));
// Rebind the parameter to the converted tensor so downstream passes see fp32.
parameter->set_default_param(tensor_ptr);
parameter->set_abstract(tensor_ptr->ToAbstract());
return tensor_ptr;
}
return tensor_info;
}
} // namespace
int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph, int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
const std::set<PrimitivePtr> &support_weight_quant_types, const std::set<PrimitivePtr> &support_weight_quant_types,
const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &per_layer_types,
@ -69,7 +96,7 @@ int WeightQuantizer::WeightQuantPerCNode(const FuncGraphPtr &func_graph, const C
const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &per_layer_types,
const std::set<PrimitivePtr> &symmetric_types, bool compression, const std::set<PrimitivePtr> &symmetric_types, bool compression,
bool update_tensor) { bool update_tensor) {
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0)); auto primitive = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
if (primitive == nullptr) { if (primitive == nullptr) {
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr"; MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
return RET_OK; return RET_OK;
@ -121,9 +148,10 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices, const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
bool compression, bool update_tensor) { bool compression, bool update_tensor) {
CHECK_NULL_RETURN(cnode); CHECK_NULL_RETURN(cnode);
WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL; // Avoid affecting other operators
auto tmp_weight_quant_type = weight_quant_type_;
if (CheckNodeInSet(cnode, per_layer_types)) { if (CheckNodeInSet(cnode, per_layer_types)) {
weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER; tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
} }
bool symmetric = false; bool symmetric = false;
int q_min = quant_min_; int q_min = quant_min_;
@ -143,11 +171,16 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
ParameterPtr parameter; ParameterPtr parameter;
tensor::TensorPtr tensor_info; tensor::TensorPtr tensor_info;
GetLiteParameter(input, &parameter, &tensor_info); GetLiteParameter(input, &parameter, &tensor_info);
if (parameter == nullptr || tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32 || if (parameter == nullptr || tensor_info == nullptr ||
tensor_info->compression_type() != mindspore::kNoCompression) { tensor_info->compression_type() != mindspore::kNoCompression) {
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight"; MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " dont need quant weight";
continue; continue;
} }
tensor_info = ConvertParameterFp16TensorToFp32(parameter);
if (tensor_info == nullptr || tensor_info->data_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(INFO) << "This op " << input->fullname_with_scope() << " is null or dtype is not fp32.";
continue;
}
int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape())); int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
if (quant_strategy_ != nullptr && !quant_strategy_->CanTensorQuantized(cnode, input, preferred_dim)) { if (quant_strategy_ != nullptr && !quant_strategy_->CanTensorQuantized(cnode, input, preferred_dim)) {
MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable"; MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable";
@ -156,7 +189,6 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
// support for matmul shared weight // support for matmul shared weight
auto node_map = manager->node_users(); auto node_map = manager->node_users();
auto node_user = node_map[input]; auto node_user = node_map[input];
auto tmp_weight_quant_type = weight_quant_type;
if (node_user.size() > 1 && opt::CheckPrimitiveType(cnode, prim::kPrimMatMulFusion)) { if (node_user.size() > 1 && opt::CheckPrimitiveType(cnode, prim::kPrimMatMulFusion)) {
MS_LOG(INFO) << input->fullname_with_scope() << " is shared weight."; MS_LOG(INFO) << input->fullname_with_scope() << " is shared weight.";
tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER; tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
@ -190,12 +222,41 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
} }
} }
weight_quantized_tensors_.insert(tensor_info); weight_quantized_tensors_.insert(tensor_info);
if (dequant_strategy_ == ON_THE_FLY) {
status = InsertDequantNode(func_graph, cnode, parameter, idx, tensor_info);
if (status == RET_NO_CHANGE) {
continue;
} else if (status != RET_OK) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " insert dequan node failed.";
return status;
}
}
} }
return RET_OK; return RET_OK;
} }
int WeightQuantizer::DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter, int idx) { int WeightQuantizer::DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter, int idx) {
int ret = RET_OK; int ret = RET_OK;
if (dequant_strategy_ == ON_THE_FLY) {
if (bit_num_ < k8Bit) {
FSEEncoder fse_encoder;
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
CHECK_NULL_RETURN(primitive);
auto quant_param_holder = GetCNodeQuantHolder(primitive);
auto tensor_quant_params = quant_param_holder->get_input_quant_params();
MS_CHECK_GT(static_cast<int>(tensor_quant_params.size()), idx - 1, RET_ERROR);
auto quant_params = tensor_quant_params.at(idx - 1);
mindspore::TensorCompressionType compress_type =
dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
ret = fse_encoder.Compress(parameter, quant_params, compress_type);
auto new_tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
CHECK_NULL_RETURN(new_tensor_info);
weight_quantized_tensors_.insert(new_tensor_info);
return ret;
} else {
return RET_OK;
}
}
TensorCompressor compressor; TensorCompressor compressor;
auto quant_param_holder = GetCNodeQuantHolder(cnode); auto quant_param_holder = GetCNodeQuantHolder(cnode);
auto tensor_quant_params = quant_param_holder->get_input_quant_params(); auto tensor_quant_params = quant_param_holder->get_input_quant_params();
@ -239,7 +300,7 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
MS_CHECK_GT(static_cast<int>(tensor_quant_params.size()), idx - 1, RET_ERROR); MS_CHECK_GT(static_cast<int>(tensor_quant_params.size()), idx - 1, RET_ERROR);
auto quant_params = tensor_quant_params.at(idx - 1); auto quant_params = tensor_quant_params.at(idx - 1);
mindspore::TensorCompressionType compress_type = mindspore::TensorCompressionType compress_type =
param_->weightQuantParam.dequant_strategy == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE; dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
status = fse_encoder.Compress(parameter, quant_params, compress_type); status = fse_encoder.Compress(parameter, quant_params, compress_type);
if (status == RET_OK) { if (status == RET_OK) {
quant_param_holder->ClearQuantParams(); quant_param_holder->ClearQuantParams();
@ -269,34 +330,40 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
return status; return status;
} }
int WeightQuantizer::InsertQuantDtypeNode(const FuncGraphPtr &func_graph) { int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const ParameterPtr &parameter, int idx, const tensor::TensorPtr &tensor_info) {
InsertQuantNodeManager quant_manager;
CHECK_NULL_RETURN(func_graph); CHECK_NULL_RETURN(func_graph);
for (auto &cnode : func_graph->GetOrderedCnodes()) { TypeId type_id;
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0)); int status;
if (primitive == nullptr) { auto tensor_name = parameter->fullname_with_scope();
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr"; if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
continue; MS_LOG(WARNING) << cnode->fullname_with_scope() << " Get data type failed.";
return RET_NO_CHANGE;
}
if (parameter->has_default() &&
parameter->default_param()->cast<tensor::TensorPtr>()->compression_type() == mindspore::kFSEInfer) {
MS_LOG(INFO) << tensor_name << " insert FSEDecode node";
if (type_id == kNumberTypeFloat32) {
status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat32);
} else {
status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat16);
} }
// Support Share Weight Quant. if (status != RET_OK) {
for (size_t i = kPrimOffset; i < cnode->size(); i++) { MS_LOG(ERROR) << tensor_name << " insert FSEDecode node failed.";
auto inputNode = cnode->input(i); return status;
if (inputNode->isa<Parameter>()) { }
ParameterPtr param_node; } else {
tensor::TensorPtr tensor_info; MS_LOG(INFO) << tensor_name << " insert WeightQuant node";
GetLiteParameter(inputNode, &param_node, &tensor_info); auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()));
auto param = weight_quantized_tensors_.find(tensor_info); if (type_id == kNumberTypeFloat32) {
if (param != weight_quantized_tensors_.end()) { status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32, axis);
InsertQuantNodeManager manager; } else {
auto ret = manager.InsertWeightQuantNode( status = quant_manager.InsertWeightQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16, axis);
func_graph, cnode, i, kNumberTypeInt8, kNumberTypeFloat32, }
GetPreferredDim(cnode, i - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()))); if (status != RET_OK) {
if (ret != RET_OK) { MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
MS_LOG(ERROR) << "Insert weight quant node failed."; return status;
return ret;
}
continue;
}
}
} }
} }
return RET_OK; return RET_OK;
@ -387,9 +454,9 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(ERROR) << "Weight Quant failed."; MS_LOG(ERROR) << "Weight Quant failed.";
return ret; return ret;
} }
if (param_->weightQuantParam.dequant_strategy == ON_THE_FLY) { if (dequant_strategy_ != ON_THE_FLY) {
return InsertQuantDtypeNode(func_graph); return MarkGraphWeightQuantType(func_graph);
} }
return MarkGraphWeightQuantType(func_graph); return RET_OK;
} }
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant

View File

@ -90,6 +90,10 @@ class WeightQuantizer : public Quantizer {
std::inserter(skip_quant_node_, skip_quant_node_.begin())); std::inserter(skip_quant_node_, skip_quant_node_.begin()));
} }
quant_type_ = param_->commonQuantParam.quant_type; quant_type_ = param_->commonQuantParam.quant_type;
dequant_strategy_ = param_->weightQuantParam.dequant_strategy;
if (param_->weightQuantParam.dequant_strategy == ON_THE_FLY) {
weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_LAYER;
}
} }
~WeightQuantizer() override = default; ~WeightQuantizer() override = default;
@ -117,7 +121,8 @@ class WeightQuantizer : public Quantizer {
int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true, int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true,
bool update_tensor = true); bool update_tensor = true);
bool CheckWeightQuantExist(const CNodePtr &cnode); bool CheckWeightQuantExist(const CNodePtr &cnode);
int InsertQuantDtypeNode(const FuncGraphPtr &func_graph); int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr &parameter, int idx,
const tensor::TensorPtr &tensor_info);
private: private:
bool is_auto_tune_{false}; bool is_auto_tune_{false};
@ -134,6 +139,8 @@ class WeightQuantizer : public Quantizer {
std::unique_ptr<QuantStrategy> quant_strategy_; std::unique_ptr<QuantStrategy> quant_strategy_;
schema::QuantType quant_type_{schema::QuantType_WeightQuant}; schema::QuantType quant_type_{schema::QuantType_WeightQuant};
bool enable_encode_{true}; bool enable_encode_{true};
WeightQuantType weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_CHANNEL;
DequantStrategy dequant_strategy_ = DEFAULT;
// Support for mark shared weight node. // Support for mark shared weight node.
std::set<tensor::TensorPtr> weight_quantized_tensors_; std::set<tensor::TensorPtr> weight_quantized_tensors_;
}; };