forked from mindspore-Ecosystem/mindspore
!29119 converter support dynamic quant
Merge pull request !29119 from yeyunpeng2020/dynamic_quant
This commit is contained in: commit 1bec7aea97
@@ -0,0 +1,27 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_
#define MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_
#include "nnacl/op_base.h"

typedef struct DynamicQuantParameter {
  OpParameter op_parameter_;
  bool symmetric_;
  int64_t dst_type_;
} DynamicQuantParameter;

#endif  // MINDSPORE_NNACL_DYNAMIC_QUANT_PARAMETER_H_
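For orientation: `symmetric_` and `dst_type_` are the two attributes that drive dynamic quantization, where activation quant parameters are derived from each tensor's actual min/max at inference time rather than calibrated offline. A minimal C sketch of that derivation (a hypothetical helper for illustration, not part of this patch; assumes int8 output, which is what the converter sets `dst_type_` to):

#include <math.h>
#include <stdbool.h>
#include <stdint.h>

typedef struct {
  float scale_;
  int32_t zero_point_;
} DynamicQuantArg;

/* Derive per-tensor int8 quant params from the runtime min/max of `data` (num >= 1).
 * Symmetric mode pins the zero point to 0 and uses the range [-127, 127];
 * asymmetric mode uses the full int8 range [-128, 127]. */
DynamicQuantArg ComputeDynamicQuantArg(const float *data, int num, bool symmetric) {
  float min_val = data[0];
  float max_val = data[0];
  for (int i = 1; i < num; ++i) {
    if (data[i] < min_val) min_val = data[i];
    if (data[i] > max_val) max_val = data[i];
  }
  DynamicQuantArg arg;
  if (symmetric) {
    float abs_max = fmaxf(fabsf(min_val), fabsf(max_val));
    arg.scale_ = (abs_max > 0.0f) ? (abs_max / 127.0f) : 1.0f;  /* guard zero-range tensors */
    arg.zero_point_ = 0;
  } else {
    float range = max_val - min_val;
    arg.scale_ = (range > 0.0f) ? (range / 255.0f) : 1.0f;  /* guard zero-range tensors */
    arg.zero_point_ = (int32_t)roundf(-128.0f - min_val / arg.scale_);
  }
  return arg;
}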
@@ -128,8 +128,9 @@ enum NNACLQuantType {
  QuantType_PostTraining = 3,
  QuantType_QUANT_WEIGHT = 4,
  QuantType_QUANT_ALL = 5,
  QuantType_QUANT_DYNAMIC = 6,
  QuantType_MIN = QuantType_QUANT_NONE,
  QuantType_MAX = QuantType_QUANT_ALL
  QuantType_MAX = QuantType_QUANT_DYNAMIC
};

typedef struct vvector {
@@ -0,0 +1,42 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/infer/dynamic_quant_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/dynamic_quant_parameter.h"

int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                           OpParameter *parameter) {
  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  const TensorC *input = inputs[0];
  TensorC *output = outputs[0];

  DynamicQuantParameter *param = (DynamicQuantParameter *)parameter;
  output->data_type_ = param->dst_type_;
  MS_CHECK_TRUE_RET(output->data_type_ > kNumberTypeBegin && output->data_type_ < kNumberTypeEnd, NNACL_ERR);
  output->format_ = input->format_;
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(output, input);
  return NNACL_OK;
}

REG_INFER(DynamicQuant, PrimType_DynamicQuant, DynamicQuantInferShape)
@@ -1,5 +1,5 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -13,21 +13,19 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H
#define MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_CAST_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_CAST_H
#include "nnacl/infer/common_infer.h"

#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "ir/func_graph.h"
#ifdef __cplusplus
extern "C" {
#endif

namespace mindspore::lite::quant {
class QuantCast {
 public:
  QuantCast() = default;
  virtual ~QuantCast() = default;
  int Run(const FuncGraphPtr &graph);
};
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_CAST_H
int DynamicQuantInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                           OpParameter *parameter);

#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_NNACL_DYNAMIC_QUANT_INFER_H
@@ -56,27 +56,12 @@ int CheckMatmulInputShape(int *a_shape, size_t a_shape_size, int *b_shape, size_
  return NNACL_OK;
}

int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
  int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

int SetShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
             OpParameter *parameter) {
  TensorC *input0 = (TensorC *)inputs[0];
  TensorC *input1 = (TensorC *)inputs[1];
  TensorC *output = outputs[0];

  int diff = abs((int)input0->shape_size_ - (int)input1->shape_size_);
  TensorC *in = input0->shape_size_ > input1->shape_size_ ? input1 : input0;
  for (int i = 0; i < diff; ++i) {
    ShapeInsert(in->shape_, &in->shape_size_, 0, 1);
  }
  SetDataTypeFormat(output, input0);
  MatMulParameter *param = (MatMulParameter *)parameter;
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  int a_shape[MAX_SHAPE_SIZE] = {0};
  size_t a_shape_size = 0;
  ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_);

@@ -137,4 +122,30 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
  return NNACL_OK;
}

int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                     OpParameter *parameter) {
  int check_ret = CheckAugmentNullSizeInputTwo(inputs, inputs_size, outputs, outputs_size, parameter, 2, 3, 1);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }

  TensorC *input0 = (TensorC *)inputs[0];
  TensorC *input1 = (TensorC *)inputs[1];
  TensorC *output = outputs[0];

  int diff = abs((int)input0->shape_size_ - (int)input1->shape_size_);
  TensorC *in = input0->shape_size_ > input1->shape_size_ ? input1 : input0;
  for (int i = 0; i < diff; ++i) {
    ShapeInsert(in->shape_, &in->shape_size_, 0, 1);
  }
  SetDataTypeFormat(output, input0);
  if (parameter->quant_type_ == QuantType_QUANT_DYNAMIC) {
    output->data_type_ = kNumberTypeFloat32;
  }
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  return SetShape(inputs, inputs_size, outputs, outputs_size, parameter);
}

REG_INFER(MatMul, PrimType_MatMulFusion, MatmulInferShape)
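Worth noting in the MatMul infer change above: when `quant_type_` is `QuantType_QUANT_DYNAMIC`, the output data type is forced to `kNumberTypeFloat32` even though the first input now comes from an int8 DynamicQuant node. As far as this diff shows, the expectation is that the kernel quantizes activations on the fly and dequantizes its accumulator back to float before the next node.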
@@ -457,8 +457,9 @@ enum PrimType {
  PrimType_Affine = 200,
  PrimType_AllGather = 201,
  PrimType_ReduceScatter = 202,
  PrimType_DynamicQuant = 203,
  PrimType_MIN = PrimType_NONE,
  PrimType_MAX = PrimType_ReduceScatter + 1
  PrimType_MAX = PrimType_DynamicQuant + 1
};

typedef enum LiteDataType {
@@ -53,17 +53,17 @@ class MS_CORE_API DynamicQuant : public PrimitiveC {
  /// \param[in] symmetric Define whether symmetric quantization.
  void set_symmetric(const bool symmetric);

  /// \brief Method to get src_t attribute.
  /// \brief Method to get symmetric attribute.
  ///
  /// \return Whether symmetric quantization.
  bool get_symmetric() const;

  /// \brief Method to set dst_t attribute.
  /// \brief Method to set dst_type attribute.
  ///
  /// \param[in] dst_t Define the data type of output.
  void set_dst_type(const int64_t dst_type);

  /// \brief Method to get dst_t attribute.
  /// \brief Method to get dst_type attribute.
  ///
  /// \return the data type of output.
  int64_t get_dst_type() const;
@@ -0,0 +1,45 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/ops/populate/populate_register.h"
#include "nnacl/dynamic_quant_parameter.h"
using mindspore::schema::PrimitiveType_DynamicQuant;

namespace mindspore {
namespace lite {
OpParameter *PopulateDynamicQuantParameter(const void *prim) {
  auto primitive = static_cast<const schema::Primitive *>(prim);
  MS_ASSERT(primitive != nullptr);
  auto value = primitive->value_as_DynamicQuant();
  if (value == nullptr) {
    MS_LOG(ERROR) << "value is nullptr";
    return nullptr;
  }

  auto *param = reinterpret_cast<DynamicQuantParameter *>(malloc(sizeof(DynamicQuantParameter)));
  if (param == nullptr) {
    MS_LOG(ERROR) << "malloc DynamicQuantParameter failed.";
    return nullptr;
  }
  memset(param, 0, sizeof(DynamicQuantParameter));

  param->op_parameter_.type_ = primitive->value_type();
  param->dst_type_ = value->dst_type();
  param->symmetric_ = value->symmetric();
  return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_DynamicQuant, PopulateDynamicQuantParameter, SCHEMA_CUR);
}  // namespace lite
}  // namespace mindspore
@@ -18,6 +18,7 @@
#include "src/weight_decoder.h"
#include "src/huffman_decode.h"
#include "tools/converter/quantizer/fse_decoder.h"
#include "nnacl/conv_parameter.h"

namespace mindspore::lite {
namespace {
@@ -416,9 +417,24 @@ int WeightDecoder::GetMatMulPreferredDim(OpParameter *op_parameter, int input_in
  return 0;
}

int WeightDecoder::GetDeConvPreferredDim(OpParameter *op_parameter, const std::vector<int> &dims) {
  MS_ASSERT(op_parameter != nullptr);
  auto parameter = reinterpret_cast<ConvParameter *>(op_parameter);
  if (parameter->input_channel_ == parameter->group_ && parameter->output_channel_ == parameter->group_) {
    // DepthWise-DeConv (CO\CI) KH KW 1
    return 0;
  } else {
    // DeConv:CI KH KW CO
    return dims.size() - 1;
  }
}

int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims) {
  MS_ASSERT(op_parameter != nullptr);
  if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
    return GetMatMulPreferredDim(op_parameter, index, dims);
  } else if (op_parameter->type_ == schema::PrimitiveType_Conv2dTransposeFusion) {
    return 0;
  }
  // The first index.
  return 0;
@@ -241,7 +241,7 @@ class WeightDecoder {
  }

  static int GetMatMulPreferredDim(OpParameter *op_parameter, int input_index, const std::vector<int> &dims);

  static int GetDeConvPreferredDim(OpParameter *op_parameter, const std::vector<int> &dims);
  static int DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type = kNumberTypeFloat32);

  template <typename T1, typename T2>
@@ -1,4 +1,4 @@
ssd.r1.1.mindir 1.3 5401000
ssd.r1.1.mindir 1.3 5401040
ml_segmentation_matting 130 160224
ml_video_edit_enhance.pb 22 546552
hiai_ghostnet.tflite 4.7 5745336
@@ -435,5 +435,22 @@ ValueNodePtr GetCallAnfPrim() {
  MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr");
  return call_anf_prim;
}

int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type) {
  auto abstract_base = cnode->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
    return RET_NULL_PTR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    MS_LOG(ERROR) << "Abstract of node should be anstract tensor, " << cnode->fullname_with_scope();
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  CHECK_NULL_RETURN(abstract_tensor);
  CHECK_NULL_RETURN(abstract_tensor->element());
  abstract_tensor->element()->set_type(TypeIdToType(new_data_type));
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore
@@ -432,6 +432,8 @@ ValueNodePtr GetCallAnfPrim();
inline bool IsGraphInput(const AnfNodePtr &cnode) {
  return cnode->isa<Parameter>() && !cnode->cast<ParameterPtr>()->has_default();
}

int UpdateDataType(const AnfNodePtr &cnode, TypeId new_data_type);
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_COMMON_NODE_UTIL_H
@@ -75,8 +75,8 @@
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/optimizer/graph/dump_graph.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/dynamic_quantizer.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/spliter.h"
#include "tools/optimizer/fisson/iter_node_outputs.h"
@@ -375,101 +375,147 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
  }
}

int DoFullQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  auto quantizer = std::make_unique<quant::FullQuantQuantizer>(*config);
  if (quantizer == nullptr) {
    MS_LOG(ERROR) << "New FullQuantQuantizer failed";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
    return RET_ERROR;
  }
  auto status = quantizer->DoQuantize(old_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoQuantization failed " << status;
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return RET_ERROR;
  }
  return RET_OK;
}

int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  double init_scale = config->mixedBitWeightQuantParam.init_scale;
  if (config->commonQuantParam.bit_num == 0 && config->mixedBitWeightQuantParam.auto_tune) {
    quant::ParameterOptimizer optimizer;
    auto status = optimizer.GridSearchForScale(old_graph, const_cast<converter::Flags *>(config), &init_scale);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Grid search with scale failed.";
      return status;
    }
    auto quantizer = std::make_unique<quant::WeightQuantizer>(*config);
    if (quantizer == nullptr) {
      MS_LOG(ERROR) << "New WeightQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return RET_ERROR;
    }
    status = static_cast<quant::WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DoQuantization failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
    }
  } else {
    auto quantizer = std::make_unique<quant::WeightQuantizer>(*config);
    if (quantizer == nullptr) {
      MS_LOG(ERROR) << "New WeightQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return RET_ERROR;
    }
    auto status = quantizer->DoQuantize(old_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DoQuantization failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  auto quantizer = std::make_unique<quant::DynamicQuantizer>(*config);
  if (quantizer == nullptr) {
    MS_LOG(ERROR) << "New DynamicQuantizer failed";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
    return RET_ERROR;
  }
  auto status = quantizer->DoQuantize(old_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoQuantization failed " << status;
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return RET_ERROR;
  }
  return RET_OK;
}

int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config, const quant::SessionModel &origin) {
  auto quant = quant::CreateSessionByFuncGraph(old_graph, *config, config->commonQuantParam.thread_num);
  std::map<std::string, OpParameter *> op_parameters;
  FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
  DebugInfoManager manager;
  CHECK_NULL_RETURN(origin.model);
  CHECK_NULL_RETURN(origin.session);
  CHECK_NULL_RETURN(quant.model);
  CHECK_NULL_RETURN(quant.session);
  auto status = manager.CompareOriginWithQuant(
    origin, quant, op_parameters, config->commonQuantParam.debug_info_save_path, config->dataPreProcessParam);
  auto free_buffer = [&] {
    delete origin.session;
    delete origin.model;
    delete quant.session;
    delete quant.model;
    for (auto parameter : op_parameters) {
      if (parameter.second != nullptr) {
        free(parameter.second);
        parameter.second = nullptr;
      }
    }
    op_parameters.clear();
  };
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Compare origin with quant failed.";
    free_buffer();
    return status;
  }
  free_buffer();
  return RET_OK;
}

int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  // quant
  if (config->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL &&
      config->commonQuantParam.quant_type != schema::QuantType_QUANT_WEIGHT) {
  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
    return RET_OK;
  }
  int status;
  std::unique_ptr<quant::Quantizer> quantizer = nullptr;

  quant::SessionModel origin;
  quant::SessionModel quant;
  if (config->commonQuantParam.is_debug) {
    converter::Flags new_flag = *config;
    new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
    origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, config->commonQuantParam.thread_num);
  }
  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
    quantizer = std::make_unique<quant::FullQuantQuantizer>(*config);
    if (quantizer == nullptr) {
      MS_LOG(ERROR) << "New FullQuantQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return RET_ERROR;
    }
    status = quantizer->DoQuantize(old_graph);
    status = DoFullQuant(old_graph, config);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DoQuantization failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
      MS_LOG(ERROR) << "Do full quant failed.";
      return status;
    }
  } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
    double init_scale = config->mixedBitWeightQuantParam.init_scale;
    if (config->commonQuantParam.bit_num == 0 && config->mixedBitWeightQuantParam.auto_tune) {
      quant::ParameterOptimizer optimizer;
      status = optimizer.GridSearchForScale(old_graph, const_cast<converter::Flags *>(config), &init_scale);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Grid search with scale failed.";
        return status;
      }
      quantizer = std::make_unique<quant::WeightQuantizer>(*config);
      if (quantizer == nullptr) {
        MS_LOG(ERROR) << "New WeightQuantizer failed";
        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
        return RET_ERROR;
      }
      status = static_cast<quant::WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "DoQuantization failed " << status;
        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
        return RET_ERROR;
      }
    } else {
      quantizer = std::make_unique<quant::WeightQuantizer>(*config);
      if (quantizer == nullptr) {
        MS_LOG(ERROR) << "New WeightQuantizer failed";
        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
        return RET_ERROR;
      }
      status = quantizer->DoQuantize(old_graph);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "DoQuantization failed " << status;
        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
        return RET_ERROR;
      }
    }
    status = DoWeightQuant(old_graph, config);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do weight quant failed.";
      return status;
    }
  } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DANAMIC) {
    status = DoDynamicQuant(old_graph, config);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do dynamic quant failed.";
      return status;
    }
  }
  if (config->commonQuantParam.is_debug) {
    quant = quant::CreateSessionByFuncGraph(old_graph, *config, config->commonQuantParam.thread_num);
    std::map<std::string, OpParameter *> op_parameters;
    FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
    DebugInfoManager manager;
    CHECK_NULL_RETURN(origin.model);
    CHECK_NULL_RETURN(origin.session);
    CHECK_NULL_RETURN(quant.model);
    CHECK_NULL_RETURN(quant.session);
    status = manager.CompareOriginWithQuant(origin, quant, op_parameters, config->commonQuantParam.debug_info_save_path,
                                            config->dataPreProcessParam);
    auto free_buffer = [&] {
      delete origin.session;
      delete origin.model;
      delete quant.session;
      delete quant.model;
      for (auto parameter : op_parameters) {
        if (parameter.second != nullptr) {
          free(parameter.second);
          parameter.second = nullptr;
        }
      }
      op_parameters.clear();
    };
    status = DoQuantDebug(old_graph, config, origin);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Compare origin with quant failed.";
      free_buffer();
      MS_LOG(ERROR) << "Do quant debug failed.";
      return status;
    }
    free_buffer();
  }
  return RET_OK;
}
@@ -23,7 +23,6 @@
#include "tools/converter/preprocess/opencv_utils.h"
#include "src/common/log_adapter.h"
#include "mindspore/lite/tools/common/string_util.h"
#include "src/common/file_utils.h"
#include "include/errorcode.h"

namespace mindspore {
@@ -73,6 +73,11 @@ int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string,
    MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
    return RET_INPUT_PARAM_INVALID;
  }
  } else if (common_quant->quant_type == schema::QuantType_QUANT_DANAMIC) {
    if (common_quant->bit_num != kQuantBitNumInt8) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

@@ -155,11 +160,14 @@ int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::
  } else if (quant_type_str == "FULL_QUANT") {
    (*quant_type) = schema::QuantType_QUANT_ALL;
    return RET_OK;
  } else if (quant_type_str == "DYNAMIC_QUANT") {
    (*quant_type) = schema::QuantType_QUANT_DANAMIC;
    return RET_OK;
  } else if (quant_type_str.empty()) {
    (*quant_type) = schema::QuantType_QUANT_NONE;
    return RET_OK;
  } else {
    MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT.";
    MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT|DYNAMIC_QUANT.";
    return RET_INPUT_PARAM_INVALID;
  }
}
@@ -26,7 +26,6 @@
#include "ops/space_to_batch_nd.h"
#include "ops/space_to_depth.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
@@ -0,0 +1,7 @@
[common_quant_param]
quant_type=DYNAMIC_QUANT
bit_num=8
# Layers whose weight size exceeds the threshold `min_quant_weight_size` will be quantized.
min_quant_weight_size=0
# Layers whose weight channel count exceeds the threshold `min_quant_weight_channel` will be quantized.
min_quant_weight_channel=16
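A note on this sample config: `DYNAMIC_QUANT` only accepts `bit_num=8` (enforced by `QuantParamParser::ParseBitNum` above), and `min_quant_weight_size=0` disables the size threshold so every layer of a supported type is considered; only `min_quant_weight_channel=16` still filters out small layers. The file is passed to the converter through its quantization config-file option.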
@@ -0,0 +1,41 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/quantizer/dynamic_quantizer.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"

namespace mindspore::lite::quant {
int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  InsertQuantNodeManager manager;
  auto ret = manager.InsertDynamicQuantPass(func_graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Insert dynamic quant failed.";
    return ret;
  }
  auto quantizer = WeightQuantizer(flags_);
  flags_.commonQuantParam.min_quant_weight_channel = 0;
  flags_.commonQuantParam.min_quant_weight_size = 0;
  const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
  const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
  ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Weight Quant failed.";
    return ret;
  }
  return RET_OK;
}
}  // namespace mindspore::lite::quant
@@ -0,0 +1,59 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DYNAMIC_QUANTIZER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DYNAMIC_QUANTIZER_H_

#include <future>
#include <memory>
#include <unordered_map>
#include <map>
#include <list>
#include <string>
#include <utility>
#include <vector>
#include <set>
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quant_strategy.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "include/model.h"
#include "base/base.h"
#include "abstract/dshape.h"
#include "src/lite_session.h"
#include "src/common/quant_utils.h"

namespace mindspore::lite::quant {
class DynamicQuantizer : public Quantizer {
 public:
  explicit DynamicQuantizer(const converter::Flags &flags) : Quantizer(flags) {
    bit_num_ = flags.commonQuantParam.bit_num;
  }
  ~DynamicQuantizer() = default;

  int DoQuantize(FuncGraphPtr func_graph) override;

 private:
  size_t bit_num_{8};
  int quant_max_{127};
  int quant_min_{-128};
  TypeId type_id_{kNumberTypeInt8};
};
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DYNAMIC_QUANTIZER_H_
@@ -23,7 +23,7 @@
#include <vector>
#include "ops/tuple_get_item.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/log_adapter.h"
@@ -618,8 +618,8 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  }
  if (activation_target_data_type_ == kNumberTypeInt8 || activation_target_data_type_ == kNumberTypeUInt8) {
    // add quant_cast
    quant::QuantCast quant_cast;
    status = quant_cast.Run(func_graph);
    quant::InsertQuantNodeManager inset_quant_node_pass;
    status = inset_quant_node_pass.InsertQuantDtypeCastPass(func_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "add QuantCast error";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -1,5 +1,5 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -14,19 +14,22 @@
 * limitations under the License.
 */

#include "mindspore/lite/tools/converter/quantizer/quant_cast.h"
#include "mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h"
#include <memory>
#include <set>
#include <vector>
#include "ops/quant_dtype_cast.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "ops/dynamic_quant.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"

namespace mindspore::lite::quant {
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type_,
                                                           const std::vector<schema::QuantParamT> &quant_params) {
  auto prim_c = std::make_shared<ops::QuantDTypeCast>();
  MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
  prim_c->Init(src_type, dst_type);
  prim_c->Init(src_type, dst_type_);
  auto quant_params_holder = std::make_shared<QuantParamHolder>(quant_params.size(), quant_params.size());
  MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
  quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);

@@ -40,7 +43,8 @@ ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector
  return NewValueNode(prim_c);
}

int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, bool is_graph_input) {
int InsertQuantNodeManager::InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index,
                                           bool is_graph_input) {
  auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();

@@ -107,7 +111,8 @@ int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t inpu
  return RET_OK;
}

int CheckDataType(const AnfNodePtr &input_node, bool is_graph_input) {
int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) {
  bool is_graph_input = IsGraphInput(input_node);
  if (!input_node->isa<mindspore::CNode>() && !is_graph_input) {
    return RET_NO_CHANGE;
  }

@@ -120,27 +125,27 @@ int CheckDataType(const AnfNodePtr &input_node, bool is_graph_input) {
    MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
    return ret;
  }
  if (type_id != kNumberTypeFloat32) {
  if (type_id != check_type_id) {
    return RET_NO_CHANGE;
  }
  }
  return RET_OK;
}

int QuantCast::Run(const FuncGraphPtr &graph) {
int InsertQuantNodeManager::InsertQuantDtypeCastPass(const FuncGraphPtr &graph) {
  MS_ASSERT(graph != nullptr);
  auto cnodes = graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    for (size_t i = 1; i < cnode->inputs().size(); i++) {
      auto input_node = cnode->input(i);
      auto is_graph_input = input_node->isa<Parameter>() && !input_node->cast<ParameterPtr>()->has_default();
      auto ret = CheckDataType(input_node, is_graph_input);
      auto ret = CheckDataType(input_node, kNumberTypeFloat32);
      if (ret == RET_NO_CHANGE) {
        continue;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << "Check data type failed.";
        return ret;
      }
      bool is_graph_input = IsGraphInput(input_node);
      ret = InsertCastNode(graph, cnode, i, is_graph_input);
      if (ret == RET_NO_CHANGE) {
        continue;

@@ -152,4 +157,80 @@ int QuantCast::Run(const FuncGraphPtr &graph) {
  }
  return RET_OK;
}

int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
  auto primitive_c = std::make_shared<ops::DynamicQuant>();
  primitive_c->set_dst_type(dst_type_);
  primitive_c->set_symmetric(symmetric_);
  auto op_name = cnode->fullname_with_scope();
  if (cnode->size() <= kInputSize1) {
    MS_LOG(ERROR) << op_name << " cnode size <= 2.";
    return RET_ERROR;
  }
  auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(1)});
  dynamic_quant_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dynamic_cast_node");
  CHECK_NULL_RETURN(cnode->abstract());
  auto abstract = cnode->abstract()->Clone();
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
    return RET_NULL_PTR;
  }
  dynamic_quant_cnode->set_abstract(abstract);
  auto ret = UpdateDataType(cnode, dst_type_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " set new dtype failed.";
    return ret;
  }
  MarkDynamicQuantize(dynamic_quant_cnode);
  cnode->set_input(1, dynamic_quant_cnode);
  return RET_OK;
}

int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
  MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR);
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_ERROR;
  }
  auto quant_param_holder = GetCNodeQuantHolder(primitive);
  quant_param_holder->set_quant_type(schema::QuantType_QUANT_DANAMIC);
  return RET_OK;
}

int InsertQuantNodeManager::InsertDynamicQuantPass(const FuncGraphPtr &graph) {
  MS_ASSERT(graph != nullptr);
  auto cnodes = graph->GetOrderedCnodes();
  const std::set<PrimitivePtr> support_dynamic_quant_ops = {
    prim::kPrimMatMulFusion,
  };
  for (auto &cnode : cnodes) {
    auto ret = CheckDataType(cnode, kNumberTypeFloat32);
    if (ret == RET_NO_CHANGE) {
      continue;
    }
    auto is_support_node = CheckNodeInSet(cnode, support_dynamic_quant_ops);
    if (!is_support_node) {
      auto type = NodePrimitiveType(cnode);
      MS_LOG(INFO) << "node:" << cnode->fullname_with_scope() << " type:" << type << " will not quantify.";
      continue;
    }
    ret = NewDynamicQuantNode(graph, cnode);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new dynamic quant node failed.";
      return ret;
    }
    ret = MarkDynamicQuantize(cnode);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new mark dynamic quant node failed.";
      return ret;
    }
    ret = UpdateDataType(cnode, kNumberTypeFloat32);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " update datatype failed.";
      return ret;
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite::quant
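To summarize the new pass: `InsertDynamicQuantPass` rewrites each float32 node of a supported type (currently only `MatMulFusion`) from `MatMul(x, w)` into `MatMul(DynamicQuant(x), w)`. The inserted DynamicQuant node carries `dst_type_` (int8) and the symmetric flag, both it and the MatMul are marked `QuantType_QUANT_DANAMIC`, and the MatMul's output dtype is restored to float32 at the end of the pass; the weights are then quantized offline by the `WeightQuant` call in `DynamicQuantizer::DoQuantize`.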
@@ -0,0 +1,53 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
#include <vector>
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "ir/func_graph.h"
#include "tools/converter/quantizer/quantize_util.h"

namespace mindspore::lite::quant {
class InsertQuantNodeManager {
 public:
  InsertQuantNodeManager() = default;

  ~InsertQuantNodeManager() = default;

  int InsertQuantDtypeCastPass(const FuncGraphPtr &graph);

  int InsertDynamicQuantPass(const FuncGraphPtr &graph);

 private:
  ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);

  int InsertCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t input_index, bool is_graph_input);

  int CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id);

  int NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode);

  int MarkDynamicQuantize(const CNodePtr &cnode);

 private:
  TypeId dst_type_ = kNumberTypeInt8;
  bool symmetric_ = false;
};
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
@@ -22,8 +22,8 @@
#include <set>
#include <functional>
#include "include/version.h"
#include "ops/fusion/full_connection.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/converter/quantizer/bitpacking.h"

@@ -32,7 +32,6 @@
#include "abstract/abstract_value.h"
#include "securec/include/securec.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"

using std::string;
using std::vector;
@@ -138,6 +137,17 @@ QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
  return quant_params_holder;
}

int GetQuantType(const CNodePtr &cnode) {
  MS_CHECK_TRUE_RET(cnode != nullptr, RET_NULL_PTR);
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_ERROR;
  }
  auto quant_param_holder = GetCNodeQuantHolder(primitive);
  return quant_param_holder->quant_type();
}

bool TensorQuantParamsInited(const schema::TensorT &tensor) {
  if (tensor.quantParams.empty()) {
    return false;
@@ -194,10 +204,11 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
  std::vector<std::vector<float>> clusters_data(clusters.size());
  for (size_t i = 0; i < elem_count; i++) {
    size_t index = 0;
    float min_distance = pow(data[i] - clusters[0], 2);
    const int pow_index = 2;
    float min_distance = pow(data[i] - clusters[0], pow_index);
    for (size_t j = 1; j < clusters.size(); j++) {
      if (pow(data[i] - clusters[j], 2) < min_distance) {
        min_distance = pow(data[i] - clusters[j], 2);
      if (pow(data[i] - clusters[j], pow_index) < min_distance) {
        min_distance = pow(data[i] - clusters[j], pow_index);
        index = j;
      }
    }
@@ -341,6 +352,11 @@ int UpdateTensorDataAndSize(const AnfNodePtr &node, const tensor::TensorPtr &wei
    return RET_ERROR;
  }
  // set dtype
  auto ret = UpdateDataType(node, new_data_type);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << node->fullname_with_scope() << " set new dtype failed.";
    return ret;
  }
  auto abstract_base = node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of node is nullptr, " << node->fullname_with_scope();
@@ -381,6 +397,20 @@ int GetMatMulPreferredDim(const PrimitivePtr &primitive, int input_index, const
  return 0;
}

int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int> &dims) {
  auto prim = primitive->cast<std::shared_ptr<ops::Conv2DTranspose>>();
  MS_ASSERT(prim != nullptr);
  // For MatMul A
  if (prim->get_in_channel() == prim->get_group() && prim->get_out_channel() == prim->get_group()) {
    // DepthWise-DeConv (CO\CI) KH KW 1
    return 0;
  } else {
    // DeConv:CI KH KW CO
    return dims.size() - 1;
  }
  return 0;
}

int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first) {
  auto channels = dims[0];
  if (!(*channel_at_first)) {
@@ -399,6 +429,8 @@ int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_
int GetPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims) {
  if (primitive->name() == ops::kNameMatMulFusion) {
    return GetMatMulPreferredDim(primitive, input_index, dims);
  } else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
    return 0;
  }
  // The first index.
  return 0;
@@ -102,6 +102,8 @@ int DeQuantData(mindspore::tensor::MSTensor *tensor, std::vector<double> *dequan

int DoBitPack(const size_t &bit_num, schema::TensorT *tensor_input);

int GetQuantType(const CNodePtr &cnode);

template <typename T>
int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<lite::LiteQuantParam> quant_params,
                std::vector<T> *dequant_data, int preferred_dim = 0) {
@@ -31,29 +31,61 @@ WeightQuantizer::~WeightQuantizer() {
  }
}

int WeightQuantizer::DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
                                 const std::set<PrimitivePtr> &support_weight_quant_types,
                                 const std::set<PrimitivePtr> &per_layer_types,
                                 const std::set<PrimitivePtr> &symmetric_types) {
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
      continue;
    }
    if (!CheckNodeInSet(cnode, support_weight_quant_types)) {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " of type: " << primitive->name() << " dont need weight quant.";
      continue;
    }
    WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
    if (CheckNodeInSet(cnode, per_layer_types)) {
      weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
    }
    bool symmetric = false;
    int q_min = quant_min_;
    int q_max = quant_max_;
    if (CheckNodeInSet(cnode, symmetric_types)) {
      symmetric = true;
      q_min = symmetric_quant_min_;
      q_max = symmetric_quant_max_;
    }
    std::vector<int> weight_indices;
    if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
      weight_indices = {2, 3};
    } else if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
      weight_indices = {4, 6};
    } else if (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum)) {
      weight_indices = {2};
    } else {
      for (size_t i = 1; i < cnode->size(); ++i) {
        weight_indices.push_back(i);
      }
    }
    auto status = DoCNodeWeightQuant(func_graph, cnode, weight_indices, weight_quant_type, q_min, q_max, symmetric);
    if (status != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " do weight quantize error";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                        const std::vector<int> &weight_indices, WeightQuantType weight_quant_type,
                                        int q_min, int q_max, bool symmetric) {
  CHECK_NULL_RETURN(cnode);
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
  auto manager = api::FuncGraphManager::Manage(func_graph, true);
  CHECK_NULL_RETURN(manager);
  std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
  if (CheckNodeInSet(cnode, per_layer_primitive_types)) {
    weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
  }
  std::vector<int> weight_indices;
  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
    weight_indices = {2, 3};
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
    weight_indices = {4, 6};
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum)) {
    weight_indices = {2};
  } else {
    for (size_t i = 1; i < cnode->size(); ++i) {
      weight_indices.push_back(i);
    }
  }
  for (auto idx : weight_indices) {
    auto input = cnode->input(idx);
    ParameterPtr parameter;

@@ -85,11 +117,11 @@ int WeightQuantizer::DoWeightQuantize(const FuncGraphPtr &func_graph, const CNod
      status = MixedBitQuantFilter(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT,
                                   WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1);
    } else if (type_id_ == kNumberTypeInt8) {
      status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, quant_max_,
                                           quant_min_, bit_num_, tmp_weight_quant_type, type_id_, idx - 1);
      status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
                                           bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
    } else if (type_id_ == kNumberTypeInt16) {
      status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, quant_max_,
                                            quant_min_, bit_num_, tmp_weight_quant_type, type_id_, idx - 1);
      status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, QuantType_QUANT_WEIGHT, q_max, q_min,
                                            bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
    }
    if (status == RET_NO_CHANGE) {
      continue;

@@ -111,8 +143,9 @@ int WeightQuantizer::DoMarkWeightQuantizeIfQuantized(const CNodePtr &cnode) {
  }

  auto quant_param_holder = GetCNodeQuantHolder(primitive);
  if (quant_param_holder->quant_type() == schema::QuantType_QUANT_WEIGHT) {
    // already marked with QUANT_WEIGHT
  if (quant_param_holder->quant_type() == schema::QuantType_QUANT_WEIGHT ||
      quant_param_holder->quant_type() == schema::QuantType_QUANT_DANAMIC) {
    // already marked with QuantType_QUANT_WEIGHT or QuantType_QUANT_DANAMIC
    return RET_OK;
  }

@@ -153,28 +186,16 @@ int WeightQuantizer::DoQuantize(const FuncGraphPtr &func_graph, double init_scal
  mixed_bit_init_scale_ = init_scale;
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  weight_quantized_tensors_.clear();

  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
      continue;
    }
    auto op_name = cnode->fullname_with_scope();
    std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
                                                      prim::kPrimMatMulFusion, prim::kPrimFullConnection,
                                                      prim::kPrimLstm, prim::kPrimGather,
                                                      prim::kPrimAdam, prim::kPrimSGD,
                                                      prim::kPrimApplyMomentum};
    if (CheckNodeInSet(cnode, support_primitive_types)) {
      auto status = DoWeightQuantize(func_graph, cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "DoWeightQuantize error";
        return RET_ERROR;
      }
    } else {
      MS_LOG(DEBUG) << op_name << " of type: " << primitive->name() << " no need quant";
    }
  const std::set<PrimitivePtr> support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
                                                          prim::kPrimMatMulFusion, prim::kPrimFullConnection,
                                                          prim::kPrimLstm, prim::kPrimGather,
                                                          prim::kPrimAdam, prim::kPrimSGD,
                                                          prim::kPrimApplyMomentum};
  std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
  auto ret = WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types, {});
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Weight Quant failed.";
    return ret;
  }
  return MarkWeightQuantizationInNodes(func_graph);
}
@@ -51,8 +51,10 @@ class WeightQuantizer : public Quantizer {
    }
    // parse param for fixed bit quant.
    if (!is_mixed_bit_) {
      quant_max_ = QuantMax(bit_num_, false);
      quant_min_ = QuantMin(bit_num_, false, false);
      quant_max_ = QuantMax(bit_num_, false);
      symmetric_quant_min_ = QuantMin(bit_num_, false, true);
      symmetric_quant_max_ = QuantMax(bit_num_, false);
      // parse type_id_
      MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit);
      if (bit_num_ > 0 && bit_num_ <= k8Bit) {
@@ -67,10 +69,14 @@ class WeightQuantizer : public Quantizer {
  int DoQuantize(FuncGraphPtr func_graph) override;
  int DoQuantize(const FuncGraphPtr &func_graph, double init_scale);

  int WeightQuant(const FuncGraphPtr &func_graph, const std::set<PrimitivePtr> &support_weight_quant_types,
                  const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &symmetric_types);

 private:
  int DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  int MarkWeightQuantizationInNodes(const FuncGraphPtr &);
  int DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
  int DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &weight_indices,
                         WeightQuantType weight_quant_type, int q_min, int q_max, bool symmetric);

 private:
  size_t bit_num_{8};
@@ -79,9 +85,10 @@ class WeightQuantizer : public Quantizer {
  std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
  bool is_mixed_bit_ = false;
  double mixed_bit_init_scale_ = 0.02;

  int quant_max_{127};
  int quant_min_{-128};
  int quant_max_{127};
  int symmetric_quant_min_{-127};
  int symmetric_quant_max_{127};
  TypeId type_id_{kNumberTypeInt8};
};
}  // namespace mindspore::lite::quant
@@ -21,7 +21,6 @@
#include "ops/op_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
@@ -20,7 +20,6 @@
#include "ops/fusion/conv2d_fusion.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "securec/include/securec.h"

@@ -83,10 +82,10 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
  }
  MS_CHECK_TRUE_RET(data_shape.size() == DIMENSION_4D, false);
  MS_CHECK_TRUE_RET(weight_shape.size() == DIMENSION_4D, false);
  if (data_shape[3] == 1 || data_shape[3] != weight_shape[3]) {
  if (data_shape[kNHWC_C] == 1 || data_shape[kNHWC_C] != weight_shape[kNHWC_C]) {
    conv2d_fusion->EraseAttr(ops::kIsDepthWise);
    conv2d_fusion->set_group(static_cast<int64_t>(data_shape[3]));
    conv2d_fusion->set_in_channel(data_shape[3]);
    conv2d_fusion->set_group(static_cast<int64_t>(data_shape[kNHWC_C]));
    conv2d_fusion->set_in_channel(data_shape[kNHWC_C]);
    MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
    auto weight_node = conv_cnode->input(kConvWeightIndex);
    MS_ASSERT(weight_node != nullptr);