!29314 dynamic only support symmetric & support debug & support skip & fix more input
Merge pull request !29314 from yeyunpeng2020/dynamic_quant
This commit is contained in:
commit
f618de798d
|
@ -331,29 +331,31 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
|
|||
}
|
||||
#endif
|
||||
|
||||
void DynamicMatmulInt8Opt(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, int deep16,
|
||||
float input_scale, int input_zp, const float *filter_scale, size_t stride) {
|
||||
void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
|
||||
int deep16, float input_scale, const float *filter_scale, size_t stride,
|
||||
bool filter_per_channel) {
|
||||
/* *
|
||||
* row4x16-major * row16x4-major => (int8)row-major
|
||||
* support activation per-layer asymmetric && weight per-channel symmetric
|
||||
* support activation per-layer symmetric && weight per-layer/per-channel symmetric
|
||||
* */
|
||||
for (int r = 0; r < row; r++) {
|
||||
for (int c = 0; c < col; c++) {
|
||||
int r4div = r / C4NUM, r4mod = r % C4NUM;
|
||||
int c4div = c / C4NUM, c4mod = c % C4NUM;
|
||||
size_t ci = r * stride + c;
|
||||
int filter_quant_index = filter_per_channel ? c : 0;
|
||||
double multi_scale = input_scale * filter_scale[filter_quant_index];
|
||||
double value = 0;
|
||||
for (int d = 0; d < deep16; d++) {
|
||||
int d16div = d / C16NUM, d16mod = d % C16NUM;
|
||||
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
|
||||
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
|
||||
int32_t value_1 = a[ai] * b[bi];
|
||||
int32_t value_3 = input_zp * b[bi];
|
||||
value += input_scale * filter_scale[c] * (value_1 - value_3);
|
||||
value += multi_scale * value_1;
|
||||
}
|
||||
if (bias != NULL) {
|
||||
value += bias[c];
|
||||
}
|
||||
size_t ci = r * stride + c;
|
||||
dst[ci] = value;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,8 +46,10 @@ void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int c
|
|||
const int32_t *left_shift, const int32_t *right_shift, size_t stride, size_t filter_peroc,
|
||||
const int32_t *filter_zp);
|
||||
|
||||
void DynamicMatmulInt8Opt(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col, int deep16,
|
||||
float input_scale, int input_zp, const float *filter_scale, size_t stride);
|
||||
void DynamicMatmulInt8AIWI(const int8_t *a, const int8_t *b, const float *bias, float *dst, int row, int col,
|
||||
int deep16, float input_scale, const float *filter_scale, size_t stride,
|
||||
bool filter_per_channel);
|
||||
|
||||
/* 8x4 4x8 -> 8x8 */
|
||||
/* optimize conv */
|
||||
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
|
||||
|
|
|
@ -50,9 +50,9 @@ int MatmulDynamicInt8CPUKernel::RunImpl(int task_id) {
|
|||
if (cur_oc <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
DynamicMatmulInt8Opt(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, fp32_bias_ptr_,
|
||||
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_, quant_param_->input_scale_,
|
||||
quant_param_->input_zp_, quant_param_->filter_scale_, param_->col_);
|
||||
DynamicMatmulInt8AIWI(pack_a_ptr_, batch_b_ptr_ + cur_stride * param_->deep_align_, fp32_bias_ptr_,
|
||||
batch_c_ptr_ + cur_stride, param_->row_, cur_oc, param_->deep_align_,
|
||||
quant_param_->input_scale_, quant_param_->filter_scale_, param_->col_, filter_per_channel_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -77,33 +77,47 @@ void MatmulDynamicInt8CPUKernel::FreeQuantParam() {
|
|||
}
|
||||
|
||||
int MatmulDynamicInt8CPUKernel::MallocQuantParam() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
auto weight_quant_params = weight_tensor->quant_params();
|
||||
auto w_shape = weight_tensor->shape();
|
||||
MS_CHECK_TRUE_MSG(weight_tensor->shape().size() >= DIMENSION_2D, lite::RET_ERROR, "weight dims should >=2");
|
||||
int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
|
||||
filter_per_channel_ = (weight_quant_params.size() > 1);
|
||||
channel_num_ = filter_per_channel_ ? col : 1;
|
||||
|
||||
quant_param_ = reinterpret_cast<MatmulDynamicQuantParameter *>(malloc(sizeof(MatmulQuantParameter)));
|
||||
if (quant_param_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc MatmulDynamicQuantParameter for Matmul int8 op failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
memset(quant_param_, 0, sizeof(MatmulQuantParameter));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulDynamicInt8CPUKernel::InitFilterQuantParam() {
|
||||
if (quant_param_->filter_scale_ != nullptr) {
|
||||
free(quant_param_->filter_scale_);
|
||||
quant_param_->filter_scale_ = nullptr;
|
||||
}
|
||||
if (quant_param_->filter_zp_ != nullptr) {
|
||||
free(quant_param_->filter_zp_);
|
||||
quant_param_->filter_zp_ = nullptr;
|
||||
}
|
||||
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
auto weight_quant_params = weight_tensor->quant_params();
|
||||
auto w_shape = weight_tensor->shape();
|
||||
if (w_shape.size() < DIMENSION_2D) {
|
||||
MS_LOG(ERROR) << weight_tensor->tensor_name() << " dims < 2.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int col = param_->b_transpose_ ? w_shape[w_shape.size() - kSize2] : w_shape[w_shape.size() - kSize1];
|
||||
filter_per_channel_ = (weight_quant_params.size() > 1);
|
||||
channel_num_ = filter_per_channel_ ? col : 1;
|
||||
if (static_cast<int>(weight_quant_params.size()) != channel_num_) {
|
||||
MS_LOG(ERROR) << weight_tensor->tensor_name() << " quant params size:" << weight_quant_params.size()
|
||||
<< " != channel_num_:" << channel_num_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
quant_param_->filter_scale_ = reinterpret_cast<float *>(malloc(channel_num_ * sizeof(float)));
|
||||
CHECK_NULL_RETURN(quant_param_->filter_scale_);
|
||||
memset(quant_param_->filter_scale_, 0, sizeof(channel_num_));
|
||||
quant_param_->filter_zp_ = reinterpret_cast<int32_t *>(malloc(channel_num_ * sizeof(int32_t)));
|
||||
CHECK_NULL_RETURN(quant_param_->filter_zp_);
|
||||
memset(quant_param_->filter_zp_, 0, sizeof(channel_num_));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulDynamicInt8CPUKernel::InitFilterQuantParam() {
|
||||
auto weight_tensor = in_tensors_.at(kWeightIndex);
|
||||
auto weight_quant_params = weight_tensor->quant_params();
|
||||
MS_CHECK_TRUE_RET(static_cast<int>(weight_quant_params.size()) == channel_num_, RET_ERROR);
|
||||
for (int i = 0; i < channel_num_; i++) {
|
||||
quant_param_->filter_scale_[i] = static_cast<float>(weight_quant_params[i].scale);
|
||||
quant_param_->filter_zp_[i] = weight_quant_params[i].zeroPoint;
|
||||
|
@ -212,7 +226,7 @@ int MatmulDynamicInt8CPUKernel::TransferB() {
|
|||
b_pack_func_(current_weight, current_b_pack, param_->deep_, param_->col_);
|
||||
}
|
||||
}
|
||||
return RET_ERROR;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MatmulDynamicInt8CPUKernel::InitTmpBuffer() {
|
||||
|
@ -252,25 +266,23 @@ int MatmulDynamicInt8CPUKernel::Prepare() {
|
|||
CHECK_LESS_RETURN(in_tensors_.size(), kMinInputSize);
|
||||
CHECK_LESS_RETURN(out_tensors_.size(), kOutputSize);
|
||||
InitParameter();
|
||||
|
||||
auto ret = MallocQuantParam();
|
||||
if (ret != RET_OK) {
|
||||
FreeQuantParam();
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (param_->b_const_) {
|
||||
ret = InitFilterQuantParam();
|
||||
if (ret != RET_OK) {
|
||||
FreeQuantParam();
|
||||
return ret;
|
||||
}
|
||||
|
||||
}
|
||||
ret = CopyBias();
|
||||
if (ret != RET_OK) {
|
||||
FreeQuantParam();
|
||||
return ret;
|
||||
}
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -313,6 +325,12 @@ int MatmulDynamicInt8CPUKernel::Run() {
|
|||
return ret;
|
||||
}
|
||||
if (!param_->b_const_) {
|
||||
ret = InitFilterQuantParam();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init filter quant param failed.";
|
||||
FreeQuantParam();
|
||||
return ret;
|
||||
}
|
||||
ret = TransferB();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "TransferB failed.";
|
||||
|
|
|
@ -80,6 +80,10 @@ kernel::InnerKernel *MatmulInt8CPUKernelCreator(const std::vector<lite::Tensor *
|
|||
kernel =
|
||||
new (std::nothrow) MatmulInt8CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
|
||||
} else if (parameter->quant_type_ == schema::QuantType_QUANT_DYNAMIC) {
|
||||
if (inputs.front()->IsConst()) {
|
||||
MS_LOG(ERROR) << "kernel: " << parameter->name_ << " is unsupported A is const.";
|
||||
return nullptr;
|
||||
}
|
||||
kernel = new (std::nothrow)
|
||||
MatmulDynamicInt8CPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
|
||||
} else {
|
||||
|
|
|
@ -23,7 +23,6 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
// Dynamic dont support filters.
|
||||
flags_.commonQuantParam.min_quant_weight_channel = 0;
|
||||
flags_.commonQuantParam.min_quant_weight_size = 0;
|
||||
flags_.commonQuantParam.skip_quant_node.clear();
|
||||
auto quantizer = WeightQuantizer(flags_);
|
||||
const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
|
||||
const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
|
||||
|
@ -36,7 +35,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
const std::set<PrimitivePtr> support_dynamic_quant_ops = {
|
||||
prim::kPrimMatMulFusion,
|
||||
};
|
||||
ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops);
|
||||
ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, flags_.commonQuantParam.skip_quant_node);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Insert dynamic quant failed.";
|
||||
return ret;
|
||||
|
|
|
@ -18,18 +18,22 @@
|
|||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ops/quant_dtype_cast.h"
|
||||
#include "ops/dynamic_quant.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/optimizer/common/format_utils.h"
|
||||
#include "tools/common/node_util.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type_,
|
||||
namespace {
|
||||
constexpr size_t kMinSize3 = 3;
|
||||
constexpr size_t kPrimitiveCOffset = 1;
|
||||
} // namespace
|
||||
ValueNodePtr InsertQuantNodeManager::NewQuantCastValueNode(int src_type, int dst_type,
|
||||
const std::vector<schema::QuantParamT> &quant_params) {
|
||||
auto prim_c = std::make_shared<ops::QuantDTypeCast>();
|
||||
MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
|
||||
prim_c->Init(src_type, dst_type_);
|
||||
prim_c->Init(src_type, dst_type);
|
||||
auto quant_params_holder = std::make_shared<QuantParamHolder>(quant_params.size(), quant_params.size());
|
||||
MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
|
||||
quant_params_holder->set_quant_type(schema::QuantType_QUANT_ALL);
|
||||
|
@ -158,17 +162,14 @@ int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph)
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
|
||||
int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode,
|
||||
size_t index) {
|
||||
auto primitive_c = std::make_shared<ops::DynamicQuant>();
|
||||
primitive_c->set_dst_type(dst_type_);
|
||||
primitive_c->set_symmetric(symmetric_);
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
if (cnode->size() <= kInputSize1) {
|
||||
MS_LOG(ERROR) << op_name << " cnode size <= 2.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(1)});
|
||||
dynamic_quant_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dynamic_cast_node");
|
||||
auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(index)});
|
||||
auto name = cnode->fullname_with_scope() + "_dynamic_cast_node_" + to_string(index);
|
||||
dynamic_quant_cnode->set_fullname_with_scope(name);
|
||||
CHECK_NULL_RETURN(cnode->abstract());
|
||||
auto abstract = cnode->abstract()->Clone();
|
||||
if (abstract == nullptr) {
|
||||
|
@ -182,7 +183,24 @@ int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const
|
|||
return ret;
|
||||
}
|
||||
MarkDynamicQuantize(dynamic_quant_cnode);
|
||||
cnode->set_input(1, dynamic_quant_cnode);
|
||||
cnode->set_input(index, dynamic_quant_cnode);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode) {
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
if (cnode->size() < kMinSize3) {
|
||||
MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input = cnode->input(kInputIndex + kPrimitiveCOffset);
|
||||
if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
|
||||
InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimitiveCOffset);
|
||||
}
|
||||
auto weight = cnode->input(kWeightIndex + kPrimitiveCOffset);
|
||||
if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
|
||||
InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimitiveCOffset);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -199,10 +217,16 @@ int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
|
|||
}
|
||||
|
||||
int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
|
||||
const std::set<PrimitivePtr> &support_dynamic_quant_ops) {
|
||||
const std::set<PrimitivePtr> &support_dynamic_quant_ops,
|
||||
const std::set<std::string> &skip_quant_node) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
auto cnodes = graph->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
if (skip_quant_node.find(op_name) != skip_quant_node.end()) {
|
||||
MS_LOG(INFO) << op_name << " is skip dynamic quant.";
|
||||
continue;
|
||||
}
|
||||
auto ret = CheckDataType(cnode, kNumberTypeFloat32);
|
||||
if (ret == RET_NO_CHANGE) {
|
||||
continue;
|
||||
|
@ -210,22 +234,22 @@ int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
|
|||
auto is_support_node = CheckNodeInSet(cnode, support_dynamic_quant_ops);
|
||||
if (!is_support_node) {
|
||||
auto type = NodePrimitiveType(cnode);
|
||||
MS_LOG(INFO) << "node:" << cnode->fullname_with_scope() << " type:" << type << " will not quantify.";
|
||||
MS_LOG(INFO) << "node:" << op_name << " type:" << type << " will not quantify.";
|
||||
continue;
|
||||
}
|
||||
ret = NewDynamicQuantNode(graph, cnode);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new dynamic quant node failed.";
|
||||
MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = MarkDynamicQuantize(cnode);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " new mark dynamic quant node failed.";
|
||||
MS_LOG(ERROR) << "node:" << op_name << " new mark dynamic quant node failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = UpdateDataType(cnode, kNumberTypeFloat32);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "node:" << cnode->fullname_with_scope() << " update datatype failed.";
|
||||
MS_LOG(ERROR) << "node:" << op_name << " update datatype failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,11 +18,13 @@
|
|||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "include/errorcode.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "ops/dynamic_quant.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
class InsertQuantNodeManager {
|
||||
|
@ -33,7 +35,8 @@ class InsertQuantNodeManager {
|
|||
|
||||
int InsertQuantDtypeCastNode(const FuncGraphPtr &graph);
|
||||
|
||||
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops);
|
||||
int InsertDynamicQuantNode(const FuncGraphPtr &graph, const std::set<PrimitivePtr> &support_dynamic_quant_ops,
|
||||
const std::set<std::string> &skip_quant_node);
|
||||
|
||||
private:
|
||||
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params);
|
||||
|
@ -46,9 +49,11 @@ class InsertQuantNodeManager {
|
|||
|
||||
int MarkDynamicQuantize(const CNodePtr &cnode);
|
||||
|
||||
int InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index);
|
||||
|
||||
private:
|
||||
TypeId dst_type_ = kNumberTypeInt8;
|
||||
bool symmetric_ = false;
|
||||
bool symmetric_ = true;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
|
||||
|
|
|
@ -41,6 +41,11 @@ int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
|
|||
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
|
||||
continue;
|
||||
}
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
if (flags_.commonQuantParam.skip_quant_node.find(op_name) != flags_.commonQuantParam.skip_quant_node.end()) {
|
||||
MS_LOG(INFO) << op_name << " is skip dynamic quant.";
|
||||
continue;
|
||||
}
|
||||
if (!CheckNodeInSet(cnode, support_weight_quant_types)) {
|
||||
MS_LOG(INFO) << cnode->fullname_with_scope() << " of type: " << primitive->name() << " dont need weight quant.";
|
||||
continue;
|
||||
|
|
Loading…
Reference in New Issue