!8018 full quant support multi input && add matmul_int8 creator

Merge pull request !8018 from jianghui58/master
This commit is contained in:
mindspore-ci-bot 2020-10-30 16:42:51 +08:00 committed by Gitee
commit 74a2e767d6
3 changed files with 137 additions and 83 deletions

View File

@ -19,9 +19,12 @@
#include "nnacl/common_func.h" #include "nnacl/common_func.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/kernel_registry.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_MatMul;
namespace mindspore::kernel { namespace mindspore::kernel {
MatmulInt8CPUKernel::~MatmulInt8CPUKernel() { FreeTmpBuffer(); } MatmulInt8CPUKernel::~MatmulInt8CPUKernel() { FreeTmpBuffer(); }
@ -193,4 +196,29 @@ int MatmulInt8CPUKernel::Run() {
} }
return RET_OK; return RET_OK;
} }
// Factory for the int8 MatMul CPU kernel, registered below via REG_KERNEL.
// Constructs a MatmulInt8CPUKernel and runs Init(); on any failure it cleans
// up and returns nullptr so the registry falls back / reports the error.
// NOTE(review): on the construction-failure path we free(opParameter) here,
// but on the Init-failure path ownership of opParameter is presumed to have
// passed to the kernel (freed by its destructor) — confirm against LiteKernel.
kernel::LiteKernel *CpuMatMulInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                               const lite::InnerContext *ctx, const KernelKey &desc,
                                               const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
  auto *kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (kernel == nullptr) {
    MS_LOG(ERROR) << "kernel is nullptr.";
    // Kernel never took ownership, so opParameter must be released here.
    free(opParameter);
    return nullptr;
  }
  auto ret = kernel->Init();
  if (ret != RET_OK) {
    // Log BEFORE deleting the kernel: the kernel owns opParameter and its
    // destructor frees it, so reading opParameter->name_ / ->type_ after
    // `delete kernel` would be a use-after-free.
    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    delete kernel;
    return nullptr;
  }
  return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatMulInt8KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel

View File

@ -279,14 +279,16 @@ std::map<CNodePtr, MaxMin> Calibrator::GetMinMax(
} }
void Calibrator::Dump() { void Calibrator::Dump() {
for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { for (auto &kv : this->inputs_diverg_info_) {
DivergInfo *info = iter->second.get(); auto &infos = kv.second;
info->DumpHistogram(); for (auto &info : infos) {
info->DumpHistogram();
}
} }
} }
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *Calibrator::GetInputDivergInfo() { std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
return &this->input_diverg_info_; return &this->inputs_diverg_info_;
} }
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() { std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() {
@ -307,39 +309,41 @@ STATUS Calibrator::ComputeThreshold() {
} }
} }
// node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as // node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as
for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) { for (auto &kv : this->inputs_diverg_info_) {
DivergInfo *info = iter->second.get(); auto &input_infos = kv.second;
auto cnode = info->cnode; for (size_t i = 0; i < input_infos.size(); i++) {
auto cnode = input_infos[i]->cnode;
bool already_computed = false; bool already_computed = false;
auto input = cnode->input(1); auto input = cnode->input(i + 1);
if (input->isa<mindspore::CNode>()) { if (input->isa<mindspore::CNode>()) {
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input); auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input);
for (const auto &outputs_diverg_info : outputs_diverg_info_) { for (const auto &outputs_diverg_info : outputs_diverg_info_) {
if (already_computed) { if (already_computed) {
break; break;
} }
for (const auto &output_diverg_info : outputs_diverg_info.second) { for (const auto &output_diverg_info : outputs_diverg_info.second) {
auto output_diverg_cnode = output_diverg_info->cnode; auto output_diverg_cnode = output_diverg_info->cnode;
if (output_diverg_cnode == input_cnode) { if (output_diverg_cnode == input_cnode) {
if (NodePrimitiveType(input_cnode) != schema::PrimitiveType_TupleGetItem) { if (NodePrimitiveType(input_cnode) != schema::PrimitiveType_TupleGetItem) {
*info = *output_diverg_info; *(input_infos[i]) = *output_diverg_info;
info->cnode = cnode; input_infos[i]->cnode = cnode;
already_computed = true; already_computed = true;
break; break;
}
} }
} }
} }
} }
} if (!already_computed) {
if (!already_computed) { input_infos[i]->ComputeThreshold();
info->ComputeThreshold(); }
} }
} }
return RET_OK; return RET_OK;
} }
STATUS Calibrator::UpdateOutputDivergInverval( STATUS Calibrator::UpdateDivergInverval(
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) { std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
for (auto &kv : *diverg_info) { for (auto &kv : *diverg_info) {
for (auto &info : kv.second) { for (auto &info : kv.second) {
@ -349,14 +353,6 @@ STATUS Calibrator::UpdateOutputDivergInverval(
return RET_OK; return RET_OK;
} }
STATUS Calibrator::UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
for (auto iter = (*diverg_info).begin(); iter != (*diverg_info).end(); iter++) {
DivergInfo *info = iter->second.get();
info->UpdateInterval();
}
return RET_OK;
}
STATUS Calibrator::UpdateDataFrequency(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) { STATUS Calibrator::UpdateDataFrequency(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
diverg_info->UpdateHistogram(data); diverg_info->UpdateHistogram(data);
return RET_OK; return RET_OK;
@ -373,7 +369,7 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
std::unique_ptr<DivergInfo> output_diverg = std::unique_ptr<DivergInfo>( std::unique_ptr<DivergInfo> output_diverg = std::unique_ptr<DivergInfo>(
new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x)); new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x));
input_diverg_info_.insert(std::make_pair(node_name, std::move(input_diverg))); inputs_diverg_info_[node_name].push_back(std::move(input_diverg));
outputs_diverg_info_[node_name].push_back(std::move(output_diverg)); outputs_diverg_info_[node_name].push_back(std::move(output_diverg));
return RET_OK; return RET_OK;
} }
@ -746,10 +742,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
} }
STATUS PostTrainingQuantizer::QuantNode() { STATUS PostTrainingQuantizer::QuantNode() {
auto input_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo()); auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
auto input_scale = this->calibrator_->GetScale(this->calibrator_->GetInputDivergInfo());
auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo());
auto outputs_diverg_info = calibrator_->GetOutputDivergInfo(); auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();
auto cnodes = funcGraph->GetOrderedCnodes(); auto cnodes = funcGraph->GetOrderedCnodes();
@ -764,7 +757,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
MS_LOG(ERROR) << "primitive_c is nullptr"; MS_LOG(ERROR) << "primitive_c is nullptr";
continue; continue;
} }
if (input_scale.find(cnode) == input_scale.end()) { if (inputs_diverg_info->find(op_name) == inputs_diverg_info->end()) {
primitive_c->SetQuantType(schema::QuantType_QUANT_NONE); primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
continue; continue;
} }
@ -803,7 +796,42 @@ STATUS PostTrainingQuantizer::QuantNode() {
op_type != PrimitiveType_FullConnection) { op_type != PrimitiveType_FullConnection) {
for (size_t i = 1; i < cnode->inputs().size(); i++) { for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i); auto input_node = cnode->input(i);
if (!input_node->isa<mindspore::CNode>()) { bool is_graph_input = false;
if (input_node->isa<Parameter>()) {
if (!input_node->cast<ParameterPtr>()->has_default()) {
is_graph_input = true;
}
}
if (input_node->isa<mindspore::CNode>()) {
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitive_c == nullptr) {
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
<< " PrimitiveC is null";
continue;
}
if (input_cnode_primitive_c->IsOutputQuantParamsInited()) {
auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front();
primitive_c->AddInputQuantParam(quant_param);
} else {
// do input quant
auto &info = (*inputs_diverg_info)[op_name][i - 1];
auto input_scale = info->GetScale().second;
auto input_zp = info->GetZeropoint().second;
struct MaxMin input_min_max {};
input_min_max.max = info->max;
input_min_max.min = info->min;
DoQuantInput(input_scale, input_zp, &input_min_max, primitive_c);
}
} else if (is_graph_input) {
auto &info = (*inputs_diverg_info)[op_name][i - 1];
auto input_scale = info->GetScale().second;
auto input_zp = info->GetZeropoint().second;
struct MaxMin input_min_max {};
input_min_max.max = info->max;
input_min_max.min = info->min;
DoQuantInput(input_scale, input_zp, &input_min_max, primitive_c);
} else {
MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode"; MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
// get dtype // get dtype
auto abstractBase = input_node->abstract(); auto abstractBase = input_node->abstract();
@ -822,30 +850,17 @@ STATUS PostTrainingQuantizer::QuantNode() {
} else { } else {
MS_LOG(DEBUG) << "this parameter no need to do quant"; MS_LOG(DEBUG) << "this parameter no need to do quant";
} }
continue;
}
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (input_cnode_primitive_c == nullptr) {
MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
<< " PrimitiveC is null";
continue;
}
if (input_cnode_primitive_c->IsOutputQuantParamsInited()) {
auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front();
primitive_c->AddInputQuantParam(quant_param);
} else {
// do input quant
double scale = input_scale[cnode];
int32_t zp = input_zero_point[cnode];
DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c);
} }
} }
} else { } else {
// do input quant // do input quant
double scale = input_scale[cnode]; auto &info = (*inputs_diverg_info)[op_name][0];
int32_t convInputzeropoint = input_zero_point[cnode]; auto input_scale = info->GetScale().second;
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); auto input_zp = info->GetZeropoint().second;
struct MaxMin input_min_max {};
input_min_max.max = info->max;
input_min_max.min = info->min;
DoQuantInput(input_scale, input_zp, &input_min_max, primitive_c);
// do weight quant // do weight quant
auto weight = cnode->input(2); auto weight = cnode->input(2);
bool perchannel = per_channel_; bool perchannel = per_channel_;
@ -878,7 +893,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
STATUS PostTrainingQuantizer::UpdateDivergInverval() { STATUS PostTrainingQuantizer::UpdateDivergInverval() {
this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo()); this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo());
this->calibrator_->UpdateOutputDivergInverval(this->calibrator_->GetOutputDivergInfo()); this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo());
return RET_OK; return RET_OK;
} }
@ -975,11 +990,21 @@ STATUS PostTrainingQuantizer::DoInference() {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) { if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
return false; return false;
} }
auto tensor = beforeInputs[0]; if ((*diverg_info_map)[callParam.node_name].size() == 1 &&
const float *tData = static_cast<const float *>(tensor->MutableData()); (callParam.node_type == kTypeConcat || callParam.node_type == kTypeAdd)) {
size_t elem_count = tensor->ElementsNum(); for (size_t i = 1; i < beforeInputs.size(); i++) {
vector<float> data(tData, tData + elem_count); auto input_diverg = std::make_unique<DivergInfo>();
this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name]); *input_diverg = *((*diverg_info_map)[callParam.node_name][0]);
(*diverg_info_map)[callParam.node_name].push_back(std::move(input_diverg));
}
}
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
auto tensor = beforeInputs[i];
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
size_t elem_count = tensor->ElementsNum();
vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name][i]);
}
return true; return true;
}; };
// func // func
@ -993,10 +1018,10 @@ STATUS PostTrainingQuantizer::DoInference() {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) { if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
return false; return false;
} }
if (afterOutputs.size() > 1) { if ((*diverg_info_map)[callParam.node_name].size() == 1 && afterOutputs.size() > 1) {
auto output_diverg = std::make_unique<DivergInfo>();
*output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
for (size_t i = 1; i < afterOutputs.size(); i++) { for (size_t i = 1; i < afterOutputs.size(); i++) {
auto output_diverg = std::make_unique<DivergInfo>();
*output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
(*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg)); (*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg));
} }
} }
@ -1397,11 +1422,13 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) { if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
return false; return false;
} }
auto tensor = beforeInputs[0]; for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
const float *tensor_data = static_cast<const float *>(tensor->MutableData()); auto tensor = beforeInputs[i];
size_t shape_size = tensor->ElementsNum(); const float *tensor_data = static_cast<const float *>(tensor->MutableData());
vector<float> data(tensor_data, tensor_data + shape_size); size_t elem_count = tensor->ElementsNum();
this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name]); vector<float> data(tensor_data, tensor_data + elem_count);
this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name][i]);
}
return true; return true;
}; };

View File

@ -91,6 +91,8 @@ class PostTrainingQuantizer : public Quantizer {
const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2D); const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2D);
const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_DepthwiseConv2D); const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_DepthwiseConv2D);
const std::string kTypeConcat = schema::EnumNamePrimitiveType(schema::PrimitiveType_Concat);
const std::string kTypeAdd = schema::EnumNamePrimitiveType(schema::PrimitiveType_Add);
STATUS PreProcess(); STATUS PreProcess();
@ -191,10 +193,7 @@ class Calibrator {
STATUS RecordMaxValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info); STATUS RecordMaxValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
STATUS UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info); STATUS UpdateDivergInverval(std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
STATUS UpdateOutputDivergInverval(
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info); STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
void Dump(); void Dump();
@ -209,7 +208,7 @@ class Calibrator {
std::map<CNodePtr, MaxMin> GetMinMax(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info); std::map<CNodePtr, MaxMin> GetMinMax(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *GetInputDivergInfo(); std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetInputDivergInfo();
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo(); std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
@ -220,7 +219,7 @@ class Calibrator {
ConfigParam config_param_; ConfigParam config_param_;
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> input_diverg_info_; std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_;
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_; std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;