diff --git a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc index 7ff75ba0b4c..7bbe8af23ac 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/fullconnection_base.cc @@ -63,7 +63,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vectordata_c(); - if (!weight_tensor->GetQuantParams().empty()) { + if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor); if (dequant_weight == nullptr) { MS_LOG(ERROR) << "dequant data is nullptr."; @@ -91,7 +91,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vectorGetQuantParams().empty()) { + if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) { weight_tensor->FreeData(); weight_tensor->SetData(restore_data); } diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index b1d8d970ce6..4d3a4e45c90 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -93,7 +93,7 @@ std::vector GetNhwcAllInputOpList() { return nhwcOpAllInp std::vector GetUint8NhwcOpList() { return int8NeedNhwcOpList; } -std::vector GetUint8OpList() { return int8OpList; } +std::vector GetInt8OpList() { return int8OpList; } STATUS NodeUtils::ConvertDims(mindspore::schema::Format src_format, const std::vector &src_dims, mindspore::schema::Format dst_format, std::vector *dst_dims) { diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index 2e096a541ac..0a7bc94a9f0 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -42,7 +42,7 @@ std::vector Getfp32FullOpList(); std::vector GetUint8NhwcOpList(); -std::vector GetUint8OpList(); +std::vector GetInt8OpList(); class NodeUtils { public: diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index 9c2b4b924a9..05cfb8e9549 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -51,13 +51,7 @@ STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); - // modify inputTensor first auto &graphInIdxes = graph->inputIndex; - for (auto graphInIdx : graphInIdxes) { - MS_ASSERT(graph->allTensors.size() > graphInIdx); - auto &graphInTensor = graph->allTensors.at(graphInIdx); - graphInTensor->dataType = TypeId::kNumberTypeInt8; - } if (this->inputDataDType == TypeId::kNumberTypeInt8) { return RET_OK; @@ -70,7 +64,7 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { for (auto graphInIdx : graphInIdxes) { MS_ASSERT(graphInIdx < graph->allTensors.size()); auto &tensor = graph->allTensors.at(graphInIdx); - if (tensor->dims.size() != kNHWCDimNumber) { + if (tensor->dims.size() != kNHWCDimNumber || tensor->dataType != kNumberTypeInt8) { continue; } @@ -137,7 +131,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); // insert transNode before and after existNode for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { + if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { continue; } if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { @@ -157,10 +151,16 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) { MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i)); auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i)); + if (preTensor->dataType == TypeId::kNumberTypeInt || preTensor->dataType == TypeId::kNumberTypeInt32) { + continue; + } auto &graphInIdxes = graph->inputIndex; if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { continue; } + if (IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) { + continue; + } iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status); if (status != RET_OK) { MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed"; @@ -170,6 +170,10 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { if (needInsertPost) { for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) { + auto &postTensor = graph->allTensors.at((*iter)->outputIndex.at(i)); + if (postTensor->dataType == TypeId::kNumberTypeInt || postTensor->dataType == TypeId::kNumberTypeInt32) { + continue; + } iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status); if (status != RET_OK) { MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed"; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 564fa4f1388..82e4eb5d06f 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -79,6 +79,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptrtype) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { quant_param->zeroPoint = quant_param->zeroPoint - 128; + tensor->dataType = TypeId::kNumberTypeInt8; } if (!tflite_tensor->quantization->min.empty()) { @@ -164,11 +165,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr MS_LOG(ERROR) << "obtain const tensor failed"; return status; } - } else if (quantType == QuantType_AwareTraining && tensor->dataType == TypeId::kNumberTypeUInt8) { - // set in/out tensor to int8 to fit ms-lite op - tensor->dataType = TypeId::kNumberTypeInt8; } - // set tensor attr if (isInput || isConst) { tensor->nodeType = schema::NodeType::NodeType_ValueNode; diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index 21f16e4ba8f..0e40fba1319 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -145,7 +145,7 @@ STATUS AwareQuantizer::GenerateQuantParam() { STATUS AwareQuantizer::DoQuantize() { for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { auto &node = *iter; - if (!IsContain(GetUint8OpList(), GetCNodeTType(*node))) { + if (!IsContain(GetInt8OpList(), GetCNodeTType(*node))) { continue; } if (node->quantType != schema::QuantType_AwareTraining) { @@ -388,7 +388,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { } } - if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) { + if (canQuant && IsContain(GetInt8OpList(), GetCNodeTType(*node))) { node->quantType = schema::QuantType_AwareTraining; } else { node->quantType = schema::QuantType_QUANT_NONE;