diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
index 1d4e96f8afb..3dd81a1eac8 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_int8.cc
@@ -19,9 +19,12 @@
 #include "nnacl/common_func.h"
 #include "src/runtime/runtime_api.h"
 #include "include/errorcode.h"
+#include "src/kernel_registry.h"
 
+using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_MEMORY_FAILED;
 using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_MatMul;
 
 namespace mindspore::kernel {
 MatmulInt8CPUKernel::~MatmulInt8CPUKernel() { FreeTmpBuffer(); }
@@ -193,4 +196,29 @@ int MatmulInt8CPUKernel::Run() {
   }
   return RET_OK;
 }
+
+kernel::LiteKernel *CpuMatMulInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
+                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
+                                               const lite::InnerContext *ctx, const KernelKey &desc,
+                                               const mindspore::lite::PrimitiveC *primitive) {
+  MS_ASSERT(opParameter != nullptr);
+  MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
+  auto *kernel = new (std::nothrow) MatmulInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+
+  if (kernel == nullptr) {
+    MS_LOG(ERROR) << "kernel is nullptr.";
+    free(opParameter);
+    return nullptr;
+  }
+  auto ret = kernel->Init();
+  if (ret != RET_OK) {
+    delete kernel;
+    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
+                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
+    return nullptr;
+  }
+  return kernel;
+}
+
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatMulInt8KernelCreator)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
index e0b21371447..3dff798424e 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc
@@ -279,14 +279,16 @@ std::map<CNodePtr, MaxMin> Calibrator::GetMinMax(
 }
 
 void Calibrator::Dump() {
-  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
-    DivergInfo *info = iter->second.get();
-    info->DumpHistogram();
+  for (auto &kv : this->inputs_diverg_info_) {
+    auto &infos = kv.second;
+    for (auto &info : infos) {
+      info->DumpHistogram();
+    }
   }
 }
 
-std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *Calibrator::GetInputDivergInfo() {
-  return &this->input_diverg_info_;
+std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
+  return &this->inputs_diverg_info_;
 }
 
 std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() {
@@ -307,39 +309,41 @@ STATUS Calibrator::ComputeThreshold() {
     }
   }
   // node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as
-  for (auto iter = this->input_diverg_info_.begin(); iter != this->input_diverg_info_.end(); iter++) {
-    DivergInfo *info = iter->second.get();
-    auto cnode = info->cnode;
+  for (auto &kv : this->inputs_diverg_info_) {
+    auto &input_infos = kv.second;
+    for (size_t i = 0; i < input_infos.size(); i++) {
+      auto cnode = input_infos[i]->cnode;
 
-    bool already_computed = false;
-    auto input = cnode->input(1);
-    if (input->isa<CNode>()) {
-      auto input_cnode = std::dynamic_pointer_cast<CNode>(input);
-      for (const auto &outputs_diverg_info : outputs_diverg_info_) {
-        if (already_computed) {
-          break;
-        }
-        for (const auto &output_diverg_info : outputs_diverg_info.second) {
-          auto output_diverg_cnode = output_diverg_info->cnode;
-          if (output_diverg_cnode == input_cnode) {
-            if (NodePrimitiveType(input_cnode) != schema::PrimitiveType_TupleGetItem) {
-              *info = *output_diverg_info;
-              info->cnode = cnode;
-              already_computed = true;
-              break;
+      bool already_computed = false;
+      auto input = cnode->input(i + 1);
+      if (input->isa<CNode>()) {
+        auto input_cnode = std::dynamic_pointer_cast<CNode>(input);
+        for (const auto &outputs_diverg_info : outputs_diverg_info_) {
+          if (already_computed) {
+            break;
+          }
+          for (const auto &output_diverg_info : outputs_diverg_info.second) {
+            auto output_diverg_cnode = output_diverg_info->cnode;
+            if (output_diverg_cnode == input_cnode) {
+              if (NodePrimitiveType(input_cnode) != schema::PrimitiveType_TupleGetItem) {
+                *(input_infos[i]) = *output_diverg_info;
+                input_infos[i]->cnode = cnode;
+                already_computed = true;
+                break;
+              }
             }
           }
         }
       }
-    }
-    if (!already_computed) {
-      info->ComputeThreshold();
+      if (!already_computed) {
+        input_infos[i]->ComputeThreshold();
+      }
     }
   }
   return RET_OK;
 }
 
-STATUS Calibrator::UpdateOutputDivergInverval(
+STATUS Calibrator::UpdateDivergInverval(
     std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
   for (auto &kv : *diverg_info) {
     for (auto &info : kv.second) {
@@ -349,14 +353,6 @@ STATUS Calibrator::UpdateOutputDivergInverval(
   return RET_OK;
 }
 
-STATUS Calibrator::UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
-  for (auto iter = (*diverg_info).begin(); iter != (*diverg_info).end(); iter++) {
-    DivergInfo *info = iter->second.get();
-    info->UpdateInterval();
-  }
-  return RET_OK;
-}
-
 STATUS Calibrator::UpdateDataFrequency(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
   diverg_info->UpdateHistogram(data);
   return RET_OK;
 }
@@ -373,7 +369,7 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
   std::unique_ptr<DivergInfo> output_diverg = std::unique_ptr<DivergInfo>(
     new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x));
 
-  input_diverg_info_.insert(std::make_pair(node_name, std::move(input_diverg)));
+  inputs_diverg_info_[node_name].push_back(std::move(input_diverg));
   outputs_diverg_info_[node_name].push_back(std::move(output_diverg));
   return RET_OK;
 }
@@ -746,10 +742,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC
 }
 
 STATUS PostTrainingQuantizer::QuantNode() {
-  auto input_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo());
-  auto input_scale = this->calibrator_->GetScale(this->calibrator_->GetInputDivergInfo());
-  auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo());
-
+  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
   auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();
 
   auto cnodes = funcGraph->GetOrderedCnodes();
@@ -764,7 +757,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
       MS_LOG(ERROR) << "primitive_c is nullptr";
       continue;
     }
-    if (input_scale.find(cnode) == input_scale.end()) {
+    if (inputs_diverg_info->find(op_name) == inputs_diverg_info->end()) {
      primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
      continue;
    }
@@ -803,7 +796,42 @@ STATUS PostTrainingQuantizer::QuantNode() {
         op_type != PrimitiveType_FullConnection) {
       for (size_t i = 1; i < cnode->inputs().size(); i++) {
         auto input_node = cnode->input(i);
-        if (!input_node->isa<CNode>()) {
+        bool is_graph_input = false;
+        if (input_node->isa<Parameter>()) {
+          if (!input_node->cast<ParameterPtr>()->has_default()) {
+            is_graph_input = true;
+          }
+        }
+        if (input_node->isa<CNode>()) {
+          auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
+          auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
+          if (input_cnode_primitive_c == nullptr) {
+            MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
+                          << " PrimitiveC is null";
+            continue;
+          }
+          if (input_cnode_primitive_c->IsOutputQuantParamsInited()) {
+            auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front();
+            primitive_c->AddInputQuantParam(quant_param);
+          } else {
+            // do input quant
+            auto &info = (*inputs_diverg_info)[op_name][i - 1];
+            auto input_scale = info->GetScale().second;
+            auto input_zp = info->GetZeropoint().second;
+            struct MaxMin input_min_max {};
+            input_min_max.max = info->max;
+            input_min_max.min = info->min;
+            DoQuantInput(input_scale, input_zp, &input_min_max, primitive_c);
+          }
+        } else if (is_graph_input) {
+          auto &info = (*inputs_diverg_info)[op_name][i - 1];
+          auto input_scale = info->GetScale().second;
+          auto input_zp = info->GetZeropoint().second;
+          struct MaxMin input_min_max {};
+          input_min_max.max = info->max;
+          input_min_max.min = info->min;
+          DoQuantInput(input_scale, input_zp, &input_min_max, primitive_c);
+        } else {
           MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
           // get dtype
           auto abstractBase = input_node->abstract();
@@ -822,30 +850,17 @@ STATUS PostTrainingQuantizer::QuantNode() {
           } else {
             MS_LOG(DEBUG) << "this parameter no need to do quant";
           }
-          continue;
-        }
-        auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
-        auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
-        if (input_cnode_primitive_c == nullptr) {
-          MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
-                        << " PrimitiveC is null";
-          continue;
-        }
-        if (input_cnode_primitive_c->IsOutputQuantParamsInited()) {
-          auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front();
-          primitive_c->AddInputQuantParam(quant_param);
-        } else {
-          // do input quant
-          double scale = input_scale[cnode];
-          int32_t zp = input_zero_point[cnode];
-          DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c);
         }
       }
     } else {
       // do input quant
-      double scale = input_scale[cnode];
-      int32_t convInputzeropoint = input_zero_point[cnode];
-      DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
+      auto &info = (*inputs_diverg_info)[op_name][0];
+      auto input_scale = info->GetScale().second;
+      auto input_zp = info->GetZeropoint().second;
+      struct MaxMin input_min_max {};
+      input_min_max.max = info->max;
+      input_min_max.min = info->min;
+      DoQuantInput(input_scale, input_zp, &input_min_max, primitive_c);
       // do weight quant
       auto weight = cnode->input(2);
       bool perchannel = per_channel_;
@@ -878,7 +893,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
 }
 
 STATUS PostTrainingQuantizer::UpdateDivergInverval() {
   this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo());
-  this->calibrator_->UpdateOutputDivergInverval(this->calibrator_->GetOutputDivergInfo());
+  this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo());
   return RET_OK;
 }
@@ -975,11 +990,21 @@ STATUS PostTrainingQuantizer::DoInference() {
     if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
       return false;
     }
-    auto tensor = beforeInputs[0];
-    const float *tData = static_cast<float *>(tensor->MutableData());
-    size_t elem_count = tensor->ElementsNum();
-    vector<float> data(tData, tData + elem_count);
-    this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name]);
+    if ((*diverg_info_map)[callParam.node_name].size() == 1 &&
+        (callParam.node_type == kTypeConcat || callParam.node_type == kTypeAdd)) {
+      for (size_t i = 1; i < beforeInputs.size(); i++) {
+        auto input_diverg = std::make_unique<DivergInfo>();
+        *input_diverg = *((*diverg_info_map)[callParam.node_name][0]);
+        (*diverg_info_map)[callParam.node_name].push_back(std::move(input_diverg));
+      }
+    }
+    for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
+      auto tensor = beforeInputs[i];
+      const float *tensor_data = static_cast<float *>(tensor->MutableData());
+      size_t elem_count = tensor->ElementsNum();
+      vector<float> data(tensor_data, tensor_data + elem_count);
+      this->calibrator_->RecordMaxValue(data, (*diverg_info_map)[callParam.node_name][i]);
+    }
     return true;
   };
   // func
@@ -993,10 +1018,10 @@ STATUS PostTrainingQuantizer::DoInference() {
     if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
      return false;
    }
-    if (afterOutputs.size() > 1) {
-      auto output_diverg = std::make_unique<DivergInfo>();
-      *output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
+    if ((*diverg_info_map)[callParam.node_name].size() == 1 && afterOutputs.size() > 1) {
       for (size_t i = 1; i < afterOutputs.size(); i++) {
+        auto output_diverg = std::make_unique<DivergInfo>();
+        *output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
         (*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg));
       }
     }
@@ -1397,11 +1422,13 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
     if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
       return false;
     }
-    auto tensor = beforeInputs[0];
-    const float *tensor_data = static_cast<float *>(tensor->MutableData());
-    size_t shape_size = tensor->ElementsNum();
-    vector<float> data(tensor_data, tensor_data + shape_size);
-    this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name]);
+    for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
+      auto tensor = beforeInputs[i];
+      const float *tensor_data = static_cast<float *>(tensor->MutableData());
+      size_t elem_count = tensor->ElementsNum();
+      vector<float> data(tensor_data, tensor_data + elem_count);
+      this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name][i]);
+    }
     return true;
   };
diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h
index 4ae4362a793..d9b75356331 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h
+++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h
@@ -91,6 +91,8 @@ class PostTrainingQuantizer : public Quantizer {
   const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2D);
   const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_DepthwiseConv2D);
+  const std::string kTypeConcat = schema::EnumNamePrimitiveType(schema::PrimitiveType_Concat);
+  const std::string kTypeAdd = schema::EnumNamePrimitiveType(schema::PrimitiveType_Add);
 
   STATUS PreProcess();
@@ -191,10 +193,7 @@ class Calibrator {
   STATUS RecordMaxValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
 
-  STATUS UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
-
-  STATUS UpdateOutputDivergInverval(
-      std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
+  STATUS UpdateDivergInverval(std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
 
   STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
 
   void Dump();
@@ -209,7 +208,7 @@
   std::map<CNodePtr, MaxMin> GetMinMax(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
 
-  std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *GetInputDivergInfo();
+  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetInputDivergInfo();
 
   std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
@@ -220,7 +219,7 @@
   ConfigParam config_param_;
 
-  std::unordered_map<std::string, std::unique_ptr<DivergInfo>> input_diverg_info_;
+  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_;
 
   std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;