!29119 converter support dynamic quant

Merge pull request !29119 from yeyunpeng2020/dynamic_quant
i-robot 2022-01-15 07:41:34 +00:00 committed by Gitee
commit 1bec7aea97
29 changed files with 704 additions and 191 deletions

View File

@ -0,0 +1,27 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_
#define MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct DynamicQuantParameter {
OpParameter op_parameter_;
bool symmetric_;
int64_t dst_type_;
} DynamicQuantParameter;
#endif // MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_
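This follows the usual nnacl pattern: op_parameter_ is the first member, so the framework can hand around a generic OpParameter * and kernels can cast it back to the concrete type. A minimal sketch of that round trip, not part of the patch:

#include <assert.h>
#include <string.h>
#include "nnacl/dynamic_quant_parameter.h"

void Sketch(void) {
  DynamicQuantParameter param;
  memset(&param, 0, sizeof(param));
  param.symmetric_ = true;
  // The framework only sees the embedded base struct...
  OpParameter *base = (OpParameter *)&param;
  // ...and the kernel casts it back, exactly as dynamic_quant_infer.c does below.
  DynamicQuantParameter *concrete = (DynamicQuantParameter *)base;
  assert(concrete->symmetric_);
}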

View File

@ -128,8 +128,9 @@ enum NNACLQuantType {
QuantType_PostTraining = 3,
QuantType_QUANT_WEIGHT = 4,
QuantType_QUANT_ALL = 5,
QuantType_QUANT_DYNAMIC = 6,
QuantType_MIN = QuantType_QUANT_NONE,
QuantType_MAX = QuantType_QUANT_ALL
QuantType_MAX = QuantType_QUANT_DYNAMIC
};
typedef struct vvector {

View File

@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/infer/dynamic_quant_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/dynamic_quant_parameter.h"
int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
const TensorC *input = inputs[0];
TensorC *output = outputs[0];
DynamicQuantParameter *param = (DynamicQuantParameter *)parameter;
output->data_type_ = param->dst_type_;
MS_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR);
output->format_ = input->format_;
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
SetShapeTensor(output, input);
return NNACL_OK;
}
REG_INFER(DynamicQuant, PrimType_DynamicQuant, DynamicQuantInferShape)
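A hypothetical driver for this infer function (not from the patch; field and enum names follow the diff, values are illustrative): the output keeps the input's shape and format, while its dtype comes from dst_type_.

TensorC input, output;
memset(&input, 0, sizeof(input));
memset(&output, 0, sizeof(output));
input.data_type_ = kNumberTypeFloat32;
input.format_ = Format_NHWC;  // assumption: any fully known format works
input.shape_size_ = 2;
input.shape_[0] = 4;
input.shape_[1] = 8;
DynamicQuantParameter param;
memset(&param, 0, sizeof(param));
param.dst_type_ = kNumberTypeInt8;
const TensorC *inputs[] = {&input};
TensorC *outputs[] = {&output};
int ret = DynamicQuantInferShape(inputs, 1, outputs, 1, (OpParameter *)&param);
// On success: ret == NNACL_OK, output.shape_ == {4, 8},
// output.data_type_ == kNumberTypeInt8, output.format_ == Format_NHWC.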

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,21 +13,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H
#define MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_CAST_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_CAST_H
#include "nnacl/infer/common_infer.h"
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "ir/func_graph.h"
#ifdef __cplusplus
extern "C" {
#endif
namespace mindspore::lite::quant {
class QuantCast {
public:
QuantCast() = default;
virtual ~QuantCast() = default;
int Run(const FuncGraphPtr &graph);
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_CAST_H
int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H

View File

@ -56,27 +56,12 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
return NNACL_OK;
}
int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
TensorC *input0 = (TensorC *)inputs[0];
TensorC *input1 = (TensorC *)inputs[1];
TensorC *output = outputs[0];
int diff = abs((int)input0->shape_size_ - (int)input1->shape_size_);
TensorC *in = input0->shape_size_ > input1->shape_size_ ? input1 : input0;
for (int i = 0; i < diff; ++i) {
ShapeInsert(in->shape_, &in->shape_size_, 0, 1);
}
SetDataTypeFormat(output, input0);
MatMulParameter *param = (MatMulParameter *)parameter;
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
int a_shape[MAX_SHAPE_SIZE] = {0};
size_t a_shape_size = 0;
ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_);
@ -137,4 +122,30 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
return NNACL_OK;
}
int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
TensorC *input0 = (TensorC *)inputs[0];
TensorC *input1 = (TensorC *)inputs[1];
TensorC *output = outputs[0];
int diff = abs((int)input0->shape_size_ - (int)input1->shape_size_);
TensorC *in = input0->shape_size_ > input1->shape_size_ ? input1 : input0;
for (int i = 0; i < diff; ++i) {
ShapeInsert(in->shape_, &in->shape_size_, 0, 1);
}
SetDataTypeFormat(output, input0);
if (parameter->quant_type_ == QuantType_QUANT_DYNAMIC) {
output->data_type_ = kNumberTypeFloat32;
}
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
return SetShape(inputs, inputs_size, outputs, outputs_size, parameter);
}
REG_INFER(MatMul, PrimType_MatMulFusion, MatmulInferShape)
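Note the new quant_type check above: for QUANT_DYNAMIC nodes the output dtype is forced to kNumberTypeFloat32 instead of being propagated from input0 by SetDataTypeFormat, since under dynamic quantization the inputs may be int8 tensors quantized at runtime while the kernel still emits float32.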

View File

@ -457,8 +457,9 @@ enum PrimType {
PrimType_Affine = 200,
PrimType_AllGather = 201,
PrimType_ReduceScatter = 202,
PrimType_DynamicQuant = 203,
PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_ReduceScatter + 1
PrimType_MAX = PrimType_DynamicQuant + 1
};
typedef enum LiteDataType {

View File

@ -53,17 +53,17 @@ class MS_CORE_API DynamicQuant : public PrimitiveC {
/// \param[in] symmetric Define whether symmetric quantization.
void set_symmetric(const bool symmetric);
/// \brief Method to get src_t attribute.
/// \brief Method to get symmetric attribute.
///
/// \return Whether symmetric quantization.
bool get_symmetric() const;
/// \brief Method to set dst_t attribute.
/// \brief Method to set dst_type attribute.
///
/// \param[in] dst_t Define the data type of output.
void set_dst_type(const int64_t dst_type);
/// \brief Method to get dst_t attribute.
/// \brief Method to get dst_type attribute.
///
/// \return the data type of output.
int64_t get_dst_type() const;

View File

@ -0,0 +1,45 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
#include "nnacl/dynamic_quant_parameter.h"
using mindspore::schema::PrimitiveType_DynamicQuant;
namespace mindspore {
namespace lite {
OpParameter *PopulateDynamicQuantParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto value = primitive->value_as_DynamicQuant();
if (value == nullptr) {
MS_LOG(ERROR) << "value is nullptr";
return nullptr;
}
auto *param = reinterpret_cast<DynamicQuantParameter *>(malloc(sizeof(DynamicQuantParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc DynamicQuantParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(DynamicQuantParameter));
param->op_parameter_.type_ = primitive->value_type();
param->dst_type_ = value->dst_type();
param->symmetric_ = value->symmetric();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_DynamicQuant, PopulateDynamicQuantParameter, SCHEMA_CUR);
} // namespace lite
} // namespace mindspore

View File

@ -18,6 +18,7 @@
#include "src/weight_decoder.h"
#include "src/huffman_decode.h"
#include "tools/converter/quantizer/fse_decoder.h"
#include "nnacl/conv_parameter.h"
namespace mindspore::lite {
namespace {
@ -416,9 +417,24 @@ int WeightDecoder::GetMatMulPreferredDim(OpParameter *op_parameter, int input_in
return 0;
}
int WeightDecoder::GetDeConvPreferredDim(OpParameter *op_parameter, const std::vector<int> &dims) {
MS_ASSERT(op_parameter != nullptr);
auto parameter = reinterpret_cast<ConvParameter *>(op_parameter);
if (parameter->input_channel_ == parameter->group_ && parameter->output_channel_ == parameter->group_) {
// DepthWise-DeConv (CO\CI) KH KW 1
return 0;
} else {
// DeConv:CI KH KW CO
return dims.size() - 1;
}
}
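For example (shapes made up): a regular deconv weight laid out CI KH KW CO, say {64, 3, 3, 128}, has input_channel_ != group_, so the preferred dim is dims.size() - 1 = 3 and scales are computed per output channel; a depthwise deconv with group_ == input_channel_ == output_channel_ and weight {64, 3, 3, 1} gets dim 0.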
int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims) {
MS_ASSERT(op_parameter != nullptr);
if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
return GetMatMulPreferredDim(op_parameter, index, dims);
} else if (op_parameter->type_ == schema::PrimitiveType_Conv2dTransposeFusion) {
return 0;
}
// The first index.
return 0;

View File

@ -241,7 +241,7 @@ class WeightDecoder {
}
static int GetMatMulPreferredDim(OpParameter *op_parameter, int input_index, const std::vector<int> &dims);
static int GetDeConvPreferredDim(OpParameter *op_parameter, const std::vector<int> &dims);
static int DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type = kNumberTypeFloat32);
template <typename T1, typename T2>

View File

@ -1,4 +1,4 @@
ssd.r1.1.mindir 1.3 5401000
ssd.r1.1.mindir 1.3 5401040
ml_segmentation_matting 130 160224
ml_video_edit_enhance.pb 22 546552
hiai_ghostnet.tflite 4.7 5745336

View File

@ -435,5 +435,22 @@ ValueNodePtr GetCallAnfPrim() {
MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
return call_anf_prim;
}
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) {
auto abstract_base = cnode->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
return RET_NULL_PTR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of node should be anstract tensor, " << cnode->fullname_with_scope();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
CHECK_NULL_RETURN(abstract_tensor);
CHECK_NULL_RETURN(abstract_tensor->element());
abstract_tensor->element()->set_type(TypeIdToType(new_data_type));
return RET_OK;
}
} // namespace lite
} // namespace mindspore
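UpdateDataType is the helper the quantization passes below share: it validates that the node carries an AbstractTensor and then rewrites its element type in place. Both UpdateTensorDataAndSize in quantize_util.cc and the new dynamic-quant insertion pass call it to retag node dtypes.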

View File

@ -432,6 +432,8 @@ ValueNodePtr GetCallAnfPrim();
inline bool IsGraphInput(const AnfNodePtr &cnode) {
return cnode->isa<Parameter>() && !cnode->cast<ParameterPtr>()->has_default();
}
int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H

View File

@ -75,8 +75,8 @@
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/optimizer/graph/dump_graph.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/dynamic_quantizer.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/spliter.h"
#include "tools/optimizer/fisson/iter_node_outputs.h"
@ -375,101 +375,147 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
}
}
int DoFullQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto quantizer = std::make_unique<quant::FullQuantQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New FullQuantQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
auto status = quantizer->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
return RET_OK;
}
int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
double init_scale = config->mixedBitWeightQuantParam.init_scale;
if (config->commonQuantParam.bit_num == 0 && config->mixedBitWeightQuantParam.auto_tune) {
quant::ParameterOptimizer optimizer;
auto status = optimizer.GridSearchForScale(old_graph, const_cast<converter::Flags *>(config), &init_scale);
if (status != RET_OK) {
MS_LOG(ERROR) << "Grid search with scale failed.";
return status;
}
auto quantizer = std::make_unique<quant::WeightQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
status = static_cast<quant::WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
} else {
auto quantizer = std::make_unique<quant::WeightQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
auto status = quantizer->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
}
return RET_OK;
}
int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto quantizer = std::make_unique<quant::DynamicQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New DynamicQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
auto status = quantizer->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
return RET_OK;
}
int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config, const quant::SessionModel &origin) {
auto quant = quant::CreateSessionByFuncGraph(old_graph, *config, config->commonQuantParam.thread_num);
std::map<std::string, OpParameter *> op_parameters;
FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
DebugInfoManager manager;
CHECK_NULL_RETURN(origin.model);
CHECK_NULL_RETURN(origin.session);
CHECK_NULL_RETURN(quant.model);
CHECK_NULL_RETURN(quant.session);
auto status = manager.CompareOriginWithQuant(
origin, quant, op_parameters, config->commonQuantParam.debug_info_save_path, config->dataPreProcessParam);
auto free_buffer = [&] {
delete origin.session;
delete origin.model;
delete quant.session;
delete quant.model;
for (auto parameter : op_parameters) {
if (parameter.second != nullptr) {
free(parameter.second);
parameter.second = nullptr;
}
}
op_parameters.clear();
};
if (status != RET_OK) {
MS_LOG(ERROR) << "Compare origin with quant failed.";
free_buffer();
return status;
}
free_buffer();
return RET_OK;
}
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// quant
if (config->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL &&
config->commonQuantParam.quant_type != schema::QuantType_QUANT_WEIGHT) {
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
return RET_OK;
}
int status;
std::unique_ptr<quant::Quantizer> quantizer = nullptr;
quant::SessionModel origin;
quant::SessionModel quant;
if (config->commonQuantParam.is_debug) {
converter::Flags new_flag = *config;
new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, config->commonQuantParam.thread_num);
}
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
quantizer = std::make_unique<quant::FullQuantQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New FullQuantQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
status = quantizer->DoQuantize(old_graph);
status = DoFullQuant(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
MS_LOG(ERROR) << "Do full quant failed.";
return status;
}
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
double init_scale = config->mixedBitWeightQuantParam.init_scale;
if (config->commonQuantParam.bit_num == 0 && config->mixedBitWeightQuantParam.auto_tune) {
quant::ParameterOptimizer optimizer;
status = optimizer.GridSearchForScale(old_graph, const_cast<converter::Flags *>(config), &init_scale);
if (status != RET_OK) {
MS_LOG(ERROR) << "Grid search with scale failed.";
return status;
}
quantizer = std::make_unique<quant::WeightQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
status = static_cast<quant::WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
} else {
quantizer = std::make_unique<quant::WeightQuantizer>(*config);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
status = quantizer->DoQuantize(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantization failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return RET_ERROR;
}
status = DoWeightQuant(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do weight quant failed.";
return status;
}
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DANAMIC) {
status = DoDynamicQuant(old_graph, config);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do dynamic quant failed.";
return status;
}
}
if (config->commonQuantParam.is_debug) {
quant = quant::CreateSessionByFuncGraph(old_graph, *config, config->commonQuantParam.thread_num);
std::map<std::string, OpParameter *> op_parameters;
FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
DebugInfoManager manager;
CHECK_NULL_RETURN(origin.model);
CHECK_NULL_RETURN(origin.session);
CHECK_NULL_RETURN(quant.model);
CHECK_NULL_RETURN(quant.session);
status = manager.CompareOriginWithQuant(origin, quant, op_parameters, config->commonQuantParam.debug_info_save_path,
config->dataPreProcessParam);
auto free_buffer = [&] {
delete origin.session;
delete origin.model;
delete quant.session;
delete quant.model;
for (auto parameter : op_parameters) {
if (parameter.second != nullptr) {
free(parameter.second);
parameter.second = nullptr;
}
}
op_parameters.clear();
};
status = DoQuantDebug(old_graph, config, origin);
if (status != RET_OK) {
MS_LOG(ERROR) << "Compare origin with quant failed.";
free_buffer();
MS_LOG(ERROR) << "Do quant debug failed.";
return status;
}
free_buffer();
}
return RET_OK;
}

View File

@ -23,7 +23,6 @@
#include "tools/converter/preprocess/opencv_utils.h"
#include "src/common/log_adapter.h"
#include "mindspore/lite/tools/common/string_util.h"
#include "src/common/file_utils.h"
#include "include/errorcode.h"
namespace mindspore {

View File

@ -73,6 +73,11 @@ int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string,
MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
return RET_INPUT_PARAM_INVALID;
}
} else if (common_quant->quant_type == schema::QuantType_QUANT_DANAMIC) {
if (common_quant->bit_num != kQuantBitNumInt8) {
MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
@ -155,11 +160,14 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
} else if (quant_type_str == "FULL_QUANT") {
(*quant_type) = schema::QuantType_QUANT_ALL;
return RET_OK;
} else if (quant_type_str == "DYNAMIC_QUANT") {
(*quant_type) = schema::QuantType_QUANT_DANAMIC;
return RET_OK;
} else if (quant_type_str.empty()) {
(*quant_type) = schema::QuantType_QUANT_NONE;
return RET_OK;
} else {
MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT.";
MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT|DYNAMIC_QUANT.";
return RET_INPUT_PARAM_INVALID;
}
}

View File

@ -26,7 +26,6 @@
#include "ops/space_to_batch_nd.h"
#include "ops/space_to_depth.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"

View File

@ -0,0 +1,7 @@
[common_quant_param]
quant_type=DYNAMIC_QUANT
bit_num=8
# Layers whose weight size exceeds the threshold `min_quant_weight_size` will be quantized.
min_quant_weight_size=0
# Layers whose weight channel count exceeds the threshold `min_quant_weight_channel` will be quantized.
min_quant_weight_channel=16
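Note that for DYNAMIC_QUANT the two thresholds above are effectively advisory: DynamicQuantizer::DoQuantize (below) resets min_quant_weight_channel and min_quant_weight_size to 0 before weight quantization, so every supported weight is quantized regardless of size.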

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/dynamic_quantizer.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
namespace mindspore::lite::quant {
int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
InsertQuantNodeManager manager;
auto ret = manager.InsertDynamicQuantPass(func_graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Insert dynamic quant failed.";
return ret;
}
flags_.commonQuantParam.min_quant_weight_channel = 0;
flags_.commonQuantParam.min_quant_weight_size = 0;
auto quantizer = WeightQuantizer(flags_);
const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Weight Quant failed.";
return ret;
}
return RET_OK;
}
} // namespace mindspore::lite::quant
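The quantizer is thus two-phase: InsertDynamicQuantPass rewrites the graph so activations are quantized at runtime, then the existing WeightQuantizer quantizes MatMulFusion and Gather weights offline, with MatMulFusion using the symmetric int8 range.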

View File

@ -0,0 +1,59 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DYNAMIC_QUANTIZER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DYNAMIC_QUANTIZER_H_
#include <future>
#include <memory>
#include <unordered_map>
#include <map>
#include <list>
#include <string>
#include <utility>
#include <vector>
#include <set>
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quant_strategy.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "include/model.h"
#include "base/base.h"
#include "abstract/dshape.h"
#include "src/lite_session.h"
#include "src/common/quant_utils.h"
namespace mindspore::lite::quant {
class DynamicQuantizer : public Quantizer {
public:
explicit DynamicQuantizer(const converter::Flags &flags) : Quantizer(flags) {
bit_num_ = flags.commonQuantParam.bit_num;
}
~DynamicQuantizer() = default;
int DoQuantize(FuncGraphPtr func_graph) override;
private:
size_t bit_num_{8};
int quant_max_{127};
int quant_min_{-128};
TypeId type_id_{kNumberTypeInt8};
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DYNAMIC_QUANTIZER_H_

View File

@ -23,7 +23,7 @@
#include <vector>
#include "ops/tuple_get_item.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/log_adapter.h"
@ -618,8 +618,8 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
}
if (activation_target_data_type_ == kNumberTypeInt8 || activation_target_data_type_ == kNumberTypeUInt8) {
// add quant_cast
quant::QuantCast quant_cast;
status = quant_cast.Run(func_graph);
quant::InsertQuantNodeManager insert_quant_node_pass;
status = insert_quant_node_pass.InsertQuantDtypeCastPass(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,19 +14,22 @@
* limitations under the License.
*/
#include "mindspore/lite/tools/converter/quantizer/quant_cast.h"
#include "mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h"
#include <memory>
#include <set>
#include <vector>
#include "ops/quant_dtype_cast.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "ops/dynamic_quant.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"
namespace mindspore::lite::quant {
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type,
const std::vector<schema::QuantParamT> &quant_params) {
auto prim_c = std::make_shared<ops::QuantDTypeCast>();
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
prim_c->Init(src_type, dst_type);
prim_c->Init(src_type, dst_type);
auto quant_params_holder = std::make_shared<QuantParamHolder>(quant_params.size(), quant_params.size());
MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
@ -40,7 +43,8 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
return NewValueNode(prim_c);
}
int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, bool is_graph_input) {
int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index,
bool is_graph_input) {
auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
@ -107,7 +111,8 @@ int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t inpu
return RET_OK;
}
int CheckDataType(const AnfNodePtr &input_node, bool is_graph_input) {
int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) {
bool is_graph_input = IsGraphInput(input_node);
if (!input_node->isa<mindspore::CNode>() && !is_graph_input) {
return RET_NO_CHANGE;
}
@ -120,27 +125,27 @@ int CheckDataType(const AnfNodePtr &input_node, bool is_graph_input) {
MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
return ret;
}
if (type_id != kNumberTypeFloat32) {
if (type_id != check_type_id) {
return RET_NO_CHANGE;
}
}
return RET_OK;
}
int QuantCast::Run(const FuncGraphPtr &graph) {
int InsertQuantNodeManager::InsertQuantDtypeCastPass(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto cnodes = graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
auto is_graph_input = input_node->isa<Parameter>() && !input_node->cast<ParameterPtr>()->has_default();
auto ret = CheckDataType(input_node, is_graph_input);
auto ret = CheckDataType(input_node, kNumberTypeFloat32);
if (ret == RET_NO_CHANGE) {
continue;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "Check data type failed.";
return ret;
}
bool is_graph_input = IsGraphInput(input_node);
ret = InsertCastNode(graph, cnode, i, is_graph_input);
if (ret == RET_NO_CHANGE) {
continue;
@ -152,4 +157,80 @@ int QuantCast::Run(const FuncGraphPtr &graph) {
}
return RET_OK;
}
int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
auto primitive_c = std::make_shared<ops::DynamicQuant>();
primitive_c->set_dst_type(dst_type_);
primitive_c->set_symmetric(symmetric_);
auto op_name = cnode->fullname_with_scope();
if (cnode->size() <= kInputSize1) {
MS_LOG(ERROR) << op_name << " cnode size <= 2.";
return RET_ERROR;
}
auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(1)});
CHECK_NULL_RETURN(dynamic_quant_cnode);
dynamic_quant_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dynamic_cast_node");
CHECK_NULL_RETURN(cnode->abstract());
auto abstract = cnode->abstract()->Clone();
if (abstract == nullptr) {
MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
return RET_NULL_PTR;
}
dynamic_quant_cnode->set_abstract(abstract);
auto ret = UpdateDataType(cnode, dst_type_);
if (ret != RET_OK) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " set new dtype failed.";
return ret;
}
ret = MarkDynamicQuantize(dynamic_quant_cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Mark dynamic quant failed.";
return ret;
}
cnode->set_input(1, dynamic_quant_cnode);
return RET_OK;
}
int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR);
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_ERROR;
}
auto quant_param_holder = GetCNodeQuantHolder(primitive);
quant_param_holder->set_quant_type(schema::QuantType_QUANT_DANAMIC);
return RET_OK;
}
int InsertQuantNodeManager::InsertDynamicQuantPass(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto cnodes = graph->GetOrderedCnodes();
const std::set<PrimitivePtr> support_dynamic_quant_ops = {
prim::kPrimMatMulFusion,
};
for (auto &cnode : cnodes) {
auto ret = CheckDataType(cnode, kNumberTypeFloat32);
if (ret == RET_NO_CHANGE) {
continue;
} else if (ret != RET_OK) {
MS_LOG(ERROR) << "Check data type failed.";
return ret;
}
auto is_support_node = CheckNodeInSet(cnode, support_dynamic_quant_ops);
if (!is_support_node) {
auto type = NodePrimitiveType(cnode);
MS_LOG(INFO) << "node:" << cnode->fullname_with_scope() << " type:" << type << " will not quantify.";
continue;
}
ret = NewDynamicQuantNode(graph, cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new dynamic quant node failed.";
return ret;
}
ret = MarkDynamicQuantize(cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new mark dynamic quant node failed.";
return ret;
}
ret = UpdateDataType(cnode, kNumberTypeFloat32);
if (ret != RET_OK) {
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " update datatype failed.";
return ret;
}
}
return RET_OK;
}
} // namespace mindspore::lite::quant
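The pass only touches float32 nodes in support_dynamic_quant_ops (currently just MatMulFusion): a DynamicQuant node is spliced in front of input 1, both the new node and the consumer are marked QuantType_QUANT_DANAMIC, and the consumer's output dtype is pinned back to kNumberTypeFloat32.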

View File

@ -0,0 +1,53 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
#include <vector>
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "ir/func_graph.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite::quant {
class InsertQuantNodeManager {
public:
InsertQuantNodeManager() = default;
~InsertQuantNodeManager() = default;
int InsertQuantDtypeCastPass(const FuncGraphPtr &graph);
int InsertDynamicQuantPass(const FuncGraphPtr &graph);
private:
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);
int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, bool is_graph_input);
int CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id);
int NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode);
int MarkDynamicQuantize(const CNodePtr &cnode);
private:
TypeId dst_type_ = kNumberTypeInt8;
bool symmetric_ = false;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H

View File

@ -22,8 +22,8 @@
#include <set>
#include <functional>
#include "include/version.h"
#include "ops/fusion/full_connection.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/converter/quantizer/bitpacking.h"
@ -32,7 +32,6 @@
#include "abstract/abstract_value.h"
#include "securec/include/securec.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
using std::string;
using std::vector;
@ -138,6 +137,17 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
return quant_params_holder;
}
int GetQuantType(const CNodePtr &cnode) {
MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR);
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_ERROR;
}
auto quant_param_holder = GetCNodeQuantHolder(primitive);
return quant_param_holder->quant_type();
}
bool TensorQuantParamsInited(const schema::TensorT &tensor) {
if (tensor.quantParams.empty()) {
return false;
@ -194,10 +204,11 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
std::vector<std::vector<float>> clusters_data(clusters.size());
for (size_t i = 0; i < elem_count; i++) {
size_t index = 0;
float min_distance = pow(data[i] - clusters[0], 2);
const int pow_index = 2;
float min_distance = pow(data[i] - clusters[0], pow_index);
for (size_t j = 1; j < clusters.size(); j++) {
if (pow(data[i] - clusters[j], 2) < min_distance) {
min_distance = pow(data[i] - clusters[j], 2);
if (pow(data[i] - clusters[j], pow_index) < min_distance) {
min_distance = pow(data[i] - clusters[j], pow_index);
index = j;
}
}
@ -341,6 +352,11 @@ int UpdateTensorDataAndSize(const AnfNodePtr &node, const tensor::TensorPtr &wei
return RET_ERROR;
}
// set dtype
auto ret = UpdateDataType(node, new_data_type);
if (ret != RET_OK) {
MS_LOG(ERROR) << node->fullname_with_scope() << " set new dtype failed.";
return ret;
}
auto abstract_base = node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of node is nullptr, " << node->fullname_with_scope();
@ -381,6 +397,20 @@ int GetMatMulPreferredDim(const PrimitivePtr &primitive, int input_index, const
return 0;
}
int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int> &dims) {
auto prim = primitive->cast<std::shared_ptr<ops::Conv2DTranspose>>();
MS_ASSERT(prim != nullptr);
if (prim->get_in_channel() == prim->get_group() && prim->get_out_channel() == prim->get_group()) {
// DepthWise-DeConv (CO\CI) KH KW 1
return 0;
} else {
// DeConv:CI KH KW CO
return dims.size() - 1;
}
}
int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first) {
auto channels = dims[0];
if (!(*channel_at_first)) {
@ -399,6 +429,8 @@ int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_
int GetPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims) {
if (primitive->name() == ops::kNameMatMulFusion) {
return GetMatMulPreferredDim(primitive, input_index, dims);
} else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
return 0;
}
// The first index.
return 0;

View File

@ -102,6 +102,8 @@ int DeQuantData(mindspore::tensor::MSTensor *tensor, std::vector<double> *dequan
int DoBitPack(const size_t &bit_num, schema::TensorT *tensor_input);
int GetQuantType(const CNodePtr &cnode);
template <typename T>
int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<lite::LiteQuantParam> quant_params,
std::vector<T> *dequant_data, int preferred_dim = 0) {

View File

@ -31,29 +31,61 @@ WeightQuantizer::~WeightQuantizer() {
}
}
int WeightQuantizer::DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
const std::set<PrimitivePtr> &support_weight_quant_types,
const std::set<PrimitivePtr> &per_layer_types,
const std::set<PrimitivePtr> &symmetric_types) {
for (auto &cnode : func_graph->GetOrderedCnodes()) {
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
continue;
}
if (!CheckNodeInSet(cnode, support_weight_quant_types)) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " of type: " << primitive->name() << " doesn't need weight quant.";
continue;
}
WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
if (CheckNodeInSet(cnode, per_layer_types)) {
weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
}
bool symmetric = false;
int q_min = quant_min_;
int q_max = quant_max_;
if (CheckNodeInSet(cnode, symmetric_types)) {
symmetric = true;
q_min = symmetric_quant_min_;
q_max = symmetric_quant_max_;
}
std::vector<int> weight_indices;
if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
weight_indices = {2, 3};
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
weight_indices = {4, 6};
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum)) {
weight_indices = {2};
} else {
for (size_t i = 1; i < cnode->size(); ++i) {
weight_indices.push_back(i);
}
}
auto status = DoCNodeWeightQuant(func_graph, cnode, weight_indices, weight_quant_type, q_min, q_max, symmetric);
if (status != RET_OK) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " do weight quantize error";
return RET_ERROR;
}
}
return RET_OK;
}
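This generalizes the old DoWeightQuantize: callers now supply the set of ops to quantize, the ops that fall back to per-layer quantization, and the ops that get the symmetric range, which is what lets DynamicQuantizer reuse the same path with {MatMulFusion, Gather} and a symmetric MatMulFusion.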
int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
const std::vector<int> &weight_indices, WeightQuantType weight_quant_type,
int q_min, int q_max, bool symmetric) {
CHECK_NULL_RETURN(cnode);
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
CHECK_NULL_RETURN(primitive);
WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
auto manager = api::FuncGraphManager::Manage(func_graph, true);
CHECK_NULL_RETURN(manager);
std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
if (CheckNodeInSet(cnode, per_layer_primitive_types)) {
weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
}
std::vector<int> weight_indices;
if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
weight_indices = {2, 3};
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
weight_indices = {4, 6};
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum)) {
weight_indices = {2};
} else {
for (size_t i = 1; i < cnode->size(); ++i) {
weight_indices.push_back(i);
}
}
for (auto idx : weight_indices) {
auto input = cnode->input(idx);
ParameterPtr parameter;
@ -85,11 +117,11 @@ int WeightQuantizer::DoWeightQuantize(const FuncGraphPtr &func_graph, const CNod
status = MixedBitQuantFilter(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT,
WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1);
} else if (type_id_ == kNumberTypeInt8) {
status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, quant_max_,
quant_min_, bit_num_, tmp_weight_quant_type, type_id_, idx - 1);
status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
} else if (type_id_ == kNumberTypeInt16) {
status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, quant_max_,
quant_min_, bit_num_, tmp_weight_quant_type, type_id_, idx - 1);
status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
}
if (status == RET_NO_CHANGE) {
continue;
@ -111,8 +143,9 @@ int WeightQuantizer::DoMarkWeightQuantizeIfQuantized(const CNodePtr &cnode) {
}
auto quant_param_holder = GetCNodeQuantHolder(primitive);
if (quant_param_holder->quant_type() == schema::QuantType_QUANT_WEIGHT) {
// already marked with QUANT_WEIGHT
if (quant_param_holder->quant_type() == schema::QuantType_QUANT_WEIGHT ||
quant_param_holder->quant_type() == schema::QuantType_QUANT_DANAMIC) {
// already marked with QuantType_QUANT_WEIGHT or QuantType_QUANT_DANAMIC
return RET_OK;
}
@ -153,28 +186,16 @@ int WeightQuantizer::DoQuantize(const FuncGraphPtr &func_graph, double init_scal
mixed_bit_init_scale_ = init_scale;
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
weight_quantized_tensors_.clear();
for (auto &cnode : func_graph->GetOrderedCnodes()) {
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
continue;
}
auto op_name = cnode->fullname_with_scope();
std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
prim::kPrimLstm, prim::kPrimGather,
prim::kPrimAdam, prim::kPrimSGD,
prim::kPrimApplyMomentum};
if (CheckNodeInSet(cnode, support_primitive_types)) {
auto status = DoWeightQuantize(func_graph, cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoWeightQuantize error";
return RET_ERROR;
}
} else {
MS_LOG(DEBUG) << op_name << " of type: " << primitive->name() << " no need quant";
}
const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimMatMulFusion, prim::kPrimFullConnection,
prim::kPrimLstm, prim::kPrimGather,
prim::kPrimAdam, prim::kPrimSGD,
prim::kPrimApplyMomentum};
std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
auto ret = WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types, {});
if (ret != RET_OK) {
MS_LOG(ERROR) << "Weight Quant failed.";
return ret;
}
return MarkWeightQuantizationInNodes(func_graph);
}

View File

@ -51,8 +51,10 @@ class WeightQuantizer : public Quantizer {
}
// parse param for fixed bit quant.
if (!is_mixed_bit_) {
quant_max_ = QuantMax(bit_num_, false);
quant_min_ = QuantMin(bit_num_, false, false);
quant_max_ = QuantMax(bit_num_, false);
symmetric_quant_min_ = QuantMin(bit_num_, false, true);
symmetric_quant_max_ = QuantMax(bit_num_, false);
// parse type_id_
MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit);
if (bit_num_ > 0 && bit_num_ <= k8Bit) {
@ -67,10 +69,14 @@ class WeightQuantizer : public Quantizer {
int DoQuantize(FuncGraphPtr func_graph) override;
int DoQuantize(const FuncGraphPtr &func_graph, double init_scale);
int WeightQuant(const FuncGraphPtr &func_graph, const std::set<PrimitivePtr> &support_weight_quant_types,
const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &symmetric_types);
private:
int DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
int MarkWeightQuantizationInNodes(const FuncGraphPtr &);
int DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
int DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &weight_indices,
WeightQuantType weight_quant_type, int q_min, int q_max, bool symmetric);
private:
size_t bit_num_{8};
@ -79,9 +85,10 @@ class WeightQuantizer : public Quantizer {
std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
bool is_mixed_bit_ = false;
double mixed_bit_init_scale_ = 0.02;
int quant_max_{127};
int quant_min_{-128};
int quant_max_{127};
int symmetric_quant_min_{-127};
int symmetric_quant_max_{127};
TypeId type_id_{kNumberTypeInt8};
};
} // namespace mindspore::lite::quant

View File

@ -21,7 +21,6 @@
#include "ops/op_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"

View File

@ -20,7 +20,6 @@
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "securec/include/securec.h"
@ -83,10 +82,10 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
}
MS_CHECK_TRUE_RET(data_shape.size() == DIMENSION_4D, false);
MS_CHECK_TRUE_RET(weight_shape.size() == DIMENSION_4D, false);
if (data_shape[3] == 1 || data_shape[3] != weight_shape[3]) {
if (data_shape[kNHWC_C] == 1 || data_shape[kNHWC_C] != weight_shape[kNHWC_C]) {
conv2d_fusion->EraseAttr(ops::kIsDepthWise);
conv2d_fusion->set_group(static_cast<int64_t>(data_shape[3]));
conv2d_fusion->set_in_channel(data_shape[3]);
conv2d_fusion->set_group(static_cast<int64_t>(data_shape[kNHWC_C]));
conv2d_fusion->set_in_channel(data_shape[kNHWC_C]);
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);