!29314 dynamic only support symmetric & support debug & support skip & fix more input

Merge pull request !29314 from yeyunpeng2020/dynamic_quant
This commit is contained in:
i-robot 2022-01-20 06:11:19 +00:00 committed by Gitee
commit f618de798d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 114 additions and 55 deletions

View File

@ -331,29 +331,31 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
}
#endif
void DynamicMatmulInt8Opt(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, int deep16,
float input_scale, int input_zp, const float *filter_scale, size_t stride) {
void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep16, float input_scale, const float *filter_scale, size_t stride,
bool filter_per_channel) {
/* *
* row4x16-major * row16x4-major => (int8)row-major
* support activation per-layer asymmetric && weight per-channel symmetric
* support activation per-layer symmetric && weight per-layer/per-channel symmetric
* */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c4div = c / C4NUM, c4mod = c % C4NUM;
size_t ci = r * stride + c;
int filter_quant_index = filter_per_channel ? c : 0;
double multi_scale = input_scale * filter_scale[filter_quant_index];
double value = 0;
for (int d = 0; d < deep16; d++) {
int d16div = d / C16NUM, d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
int32_t value_1 = a[ai] * b[bi];
int32_t value_3 = input_zp * b[bi];
value += input_scale * filter_scale[c] * (value_1 - value_3);
value += multi_scale * value_1;
}
if (bias != NULL) {
value += bias[c];
}
size_t ci = r * stride + c;
dst[ci] = value;
}
}

View File

@ -46,8 +46,10 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
const int32_t *filter_zp);
void DynamicMatmulInt8Opt(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, int deep16,
float input_scale, int input_zp, const float *filter_scale, size_t stride);
void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
int deep16, float input_scale, const float *filter_scale, size_t stride,
bool filter_per_channel);
/* 8x4 4x8 -> 8x8 */
/* optimize conv */
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);

View File

@ -50,9 +50,9 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
DynamicMatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, fp32_bias_ptr_,
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_, quant_param_->input_scale_,
quant_param_->input_zp_, quant_param_->filter_scale_, param_->col_);
DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, fp32_bias_ptr_,
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
quant_param_->input_scale_, quant_param_->filter_scale_, param_->col_, filter_per_channel_);
return RET_OK;
}
@ -77,33 +77,47 @@ void MatmulDynamicInt8CPUKernel::FreeQuantParam() {
}
int MatmulDynamicInt8CPUKernel::MallocQuantParam() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto weight_quant_params = weight_tensor->quant_params();
auto w_shape = weight_tensor->shape();
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() >= DIMENSION_2D, lite::RET_ERROR, "weight dims should >=2");
int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
filter_per_channel_ = (weight_quant_params.size() > 1);
channel_num_ = filter_per_channel_ ? col : 1;
quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
if (quant_param_ == nullptr) {
MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
return RET_ERROR;
}
memset(quant_param_, 0, sizeof(MatmulQuantParameter));
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::InitFilterQuantParam() {
if (quant_param_->filter_scale_ != nullptr) {
free(quant_param_->filter_scale_);
quant_param_->filter_scale_ = nullptr;
}
if (quant_param_->filter_zp_ != nullptr) {
free(quant_param_->filter_zp_);
quant_param_->filter_zp_ = nullptr;
}
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto weight_quant_params = weight_tensor->quant_params();
auto w_shape = weight_tensor->shape();
if (w_shape.size() < DIMENSION_2D) {
MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2.";
return RET_ERROR;
}
int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
filter_per_channel_ = (weight_quant_params.size() > 1);
channel_num_ = filter_per_channel_ ? col : 1;
if (static_cast<int>(weight_quant_params.size()) != channel_num_) {
MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
<< " != channel_num_:" << channel_num_;
return RET_ERROR;
}
quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
CHECK_NULL_RETURN(quant_param_->filter_scale_);
memset(quant_param_->filter_scale_, 0, sizeof(channel_num_));
quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
CHECK_NULL_RETURN(quant_param_->filter_zp_);
memset(quant_param_->filter_zp_, 0, sizeof(channel_num_));
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::InitFilterQuantParam() {
auto weight_tensor = in_tensors_.at(kWeightIndex);
auto weight_quant_params = weight_tensor->quant_params();
MS_CHECK_TRUE_RET(static_cast<int>(weight_quant_params.size()) == channel_num_, RET_ERROR);
for (int i = 0; i < channel_num_; i++) {
quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
@ -212,7 +226,7 @@ int MatmulDynamicInt8CPUKernel::TransferB() {
b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
}
}
return RET_ERROR;
return RET_OK;
}
int MatmulDynamicInt8CPUKernel::InitTmpBuffer() {
@ -252,25 +266,23 @@ int MatmulDynamicInt8CPUKernel::Prepare() {
CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
InitParameter();
auto ret = MallocQuantParam();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (param_->b_const_) {
ret = InitFilterQuantParam();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
}
ret = CopyBias();
if (ret != RET_OK) {
FreeQuantParam();
return ret;
}
if (!InferShapeDone()) {
return RET_OK;
}
@ -313,6 +325,12 @@ int MatmulDynamicInt8CPUKernel::Run() {
return ret;
}
if (!param_->b_const_) {
ret = InitFilterQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init filter quant param failed.";
FreeQuantParam();
return ret;
}
ret = TransferB();
if (ret != RET_OK) {
MS_LOG(ERROR) << "TransferB failed.";

View File

@ -80,6 +80,10 @@ kernel::InnerKernel *MatmulInt8CPUKernelCreator(const std::vector<lite::Tensor *
kernel =
new (std::nothrow) MatmulInt8CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
} else if (parameter->quant_type_ == schema::QuantType_QUANT_DYNAMIC) {
if (inputs.front()->IsConst()) {
MS_LOG(ERROR) << "kernel: " << parameter->name_ << " is unsupported A is const.";
return nullptr;
}
kernel = new (std::nothrow)
MatmulDynamicInt8CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
} else {

View File

@ -23,7 +23,6 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
// Dynamic dont support filters.
flags_.commonQuantParam.min_quant_weight_channel = 0;
flags_.commonQuantParam.min_quant_weight_size = 0;
flags_.commonQuantParam.skip_quant_node.clear();
auto quantizer = WeightQuantizer(flags_);
const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
@ -36,7 +35,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
const std::set<PrimitivePtr> support_dynamic_quant_ops = {
prim::kPrimMatMulFusion,
};
ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops);
ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, flags_.commonQuantParam.skip_quant_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Insert dynamic quant failed.";
return ret;

View File

@ -18,18 +18,22 @@
#include <memory>
#include <set>
#include <vector>
#include <string>
#include "ops/quant_dtype_cast.h"
#include "ops/dynamic_quant.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/format_utils.h"
#include "tools/common/node_util.h"
namespace mindspore::lite::quant {
ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type_,
namespace {
constexpr size_t kMinSize3 = 3;
constexpr size_t kPrimitiveCOffset = 1;
} // namespace
ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type,
const std::vector<schema::QuantParamT> &quant_params) {
auto prim_c = std::make_shared<ops::QuantDTypeCast>();
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
prim_c->Init(src_type, dst_type_);
prim_c->Init(src_type, dst_type);
auto quant_params_holder = std::make_shared<QuantParamHolder>(quant_params.size(), quant_params.size());
MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
@ -158,17 +162,14 @@ int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph)
return RET_OK;
}
int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode,
size_t index) {
auto primitive_c = std::make_shared<ops::DynamicQuant>();
primitive_c->set_dst_type(dst_type_);
primitive_c->set_symmetric(symmetric_);
auto op_name = cnode->fullname_with_scope();
if (cnode->size() <= kInputSize1) {
MS_LOG(ERROR) << op_name << " cnode size <= 2.";
return RET_ERROR;
}
auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(1)});
dynamic_quant_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dynamic_cast_node");
auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(index)});
auto name = cnode->fullname_with_scope() + "_dynamic_cast_node_" + to_string(index);
dynamic_quant_cnode->set_fullname_with_scope(name);
CHECK_NULL_RETURN(cnode->abstract());
auto abstract = cnode->abstract()->Clone();
if (abstract == nullptr) {
@ -182,7 +183,24 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
return ret;
}
MarkDynamicQuantize(dynamic_quant_cnode);
cnode->set_input(1, dynamic_quant_cnode);
cnode->set_input(index, dynamic_quant_cnode);
return RET_OK;
}
int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
auto op_name = cnode->fullname_with_scope();
if (cnode->size() < kMinSize3) {
MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
return RET_ERROR;
}
auto input = cnode->input(kInputIndex + kPrimitiveCOffset);
if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset);
}
auto weight = cnode->input(kWeightIndex + kPrimitiveCOffset);
if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset);
}
return RET_OK;
}
@ -199,10 +217,16 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
}
int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
const std::set<PrimitivePtr> &support_dynamic_quant_ops) {
const std::set<PrimitivePtr> &support_dynamic_quant_ops,
const std::set<std::string> &skip_quant_node) {
MS_ASSERT(graph != nullptr);
auto cnodes = graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto op_name = cnode->fullname_with_scope();
if (skip_quant_node.find(op_name) != skip_quant_node.end()) {
MS_LOG(INFO) << op_name << " is skip dynamic quant.";
continue;
}
auto ret = CheckDataType(cnode, kNumberTypeFloat32);
if (ret == RET_NO_CHANGE) {
continue;
@ -210,22 +234,22 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
auto is_support_node = CheckNodeInSet(cnode, support_dynamic_quant_ops);
if (!is_support_node) {
auto type = NodePrimitiveType(cnode);
MS_LOG(INFO) << "node:" << cnode->fullname_with_scope() << " type:" << type << " will not quantify.";
MS_LOG(INFO) << "node:" << op_name << " type:" << type << " will not quantify.";
continue;
}
ret = NewDynamicQuantNode(graph, cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new dynamic quant node failed.";
MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed.";
return ret;
}
ret = MarkDynamicQuantize(cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new mark dynamic quant node failed.";
MS_LOG(ERROR) << "node:" << op_name << " new mark dynamic quant node failed.";
return ret;
}
ret = UpdateDataType(cnode, kNumberTypeFloat32);
if (ret != RET_OK) {
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " update datatype failed.";
MS_LOG(ERROR) << "node:" << op_name << " update datatype failed.";
return ret;
}
}

View File

@ -18,11 +18,13 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
#include <vector>
#include <set>
#include <string>
#include "include/errorcode.h"
#include "ir/anf.h"
#include "ir/dtype/type_id.h"
#include "ir/func_graph.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "ops/dynamic_quant.h"
namespace mindspore::lite::quant {
class InsertQuantNodeManager {
@ -33,7 +35,8 @@ class InsertQuantNodeManager {
int InsertQuantDtypeCastNode(const FuncGraphPtr &graph);
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops);
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
const std::set<std::string> &skip_quant_node);
private:
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);
@ -46,9 +49,11 @@ class InsertQuantNodeManager {
int MarkDynamicQuantize(const CNodePtr &cnode);
int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
private:
TypeId dst_type_ = kNumberTypeInt8;
bool symmetric_ = false;
bool symmetric_ = true;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H

View File

@ -41,6 +41,11 @@ int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
continue;
}
auto op_name = cnode->fullname_with_scope();
if (flags_.commonQuantParam.skip_quant_node.find(op_name) != flags_.commonQuantParam.skip_quant_node.end()) {
MS_LOG(INFO) << op_name << " is skip dynamic quant.";
continue;
}
if (!CheckNodeInSet(cnode, support_weight_quant_types)) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " of type: " << primitive->name() << " dont need weight quant.";
continue;