diff --git a/include/api/types.h b/include/api/types.h index 10be863aeaf..f983f6270d1 100644 --- a/include/api/types.h +++ b/include/api/types.h @@ -58,6 +58,8 @@ struct QuantParam { int bit_num; double scale; int32_t zero_point; + double min; + double max; }; class Allocator; diff --git a/mindspore/lite/src/cxx_api/tensor/tensor_impl.h b/mindspore/lite/src/cxx_api/tensor/tensor_impl.h index 699919c8a71..3151a583e35 100644 --- a/mindspore/lite/src/cxx_api/tensor/tensor_impl.h +++ b/mindspore/lite/src/cxx_api/tensor/tensor_impl.h @@ -218,6 +218,8 @@ class MSTensor::Impl { param.bit_num = lite_quant_params[i].bitNum; param.scale = lite_quant_params[i].scale; param.zero_point = lite_quant_params[i].zeroPoint; + param.min = lite_quant_params[i].min; + param.max = lite_quant_params[i].max; quant_params.push_back(param); } return quant_params; diff --git a/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.cc index a518b03038b..6d671e995b8 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/activation_tensorrt.cc @@ -100,6 +100,7 @@ int ActivationTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { out_tensor->setName((op_name_ + "_output").c_str()); this->AddInnerOutTensors( ITensorHelper{out_tensor, tensorrt_in_tensors_[0].format_, tensorrt_in_tensors_[0].same_format_}); + tensor_name_map_[out_tensor->getName()] = op_name_; return RET_OK; } nvinfer1::IActivationLayer *ActivationTensorRT::AddActivation(nvinfer1::INetworkDefinition *network, diff --git a/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.cc index d75d8005205..f0427ac7d95 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/convolution_tensorrt.cc @@ -63,6 +63,7 @@ int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { return RET_ERROR; } transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str()); + tensor_name_map_[transpose_layer_in->getOutput(0)->getName()] = op_name_; conv_input = transpose_layer_in->getOutput(0); } @@ -105,6 +106,7 @@ int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { return RET_ERROR; } conv_layer->setName((op_name_ + "_conv").c_str()); + tensor_name_map_[conv_layer->getOutput(0)->getName()] = op_name_; // add params SetAttributes(conv_op, conv_layer); @@ -124,6 +126,8 @@ int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { } activation_layer->getOutput(0)->setName((op_name_ + "_output").c_str()); this->AddInnerOutTensors(ITensorHelper{activation_layer->getOutput(0), Format::NCHW, false}); + // add tensor name mapping + tensor_name_map_[activation_layer->getOutput(0)->getName()] = op_name_; return RET_OK; } diff --git a/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.cc index 9a6a7d04b6a..0c748b257d3 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/deconvolution_tensorrt.cc @@ -60,6 +60,7 @@ int DeconvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { return RET_ERROR; } transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str()); + tensor_name_map_[transpose_layer_in->getOutput(0)->getName()] = op_name_; deconv_input = transpose_layer_in->getOutput(0); 
} @@ -103,6 +104,8 @@ int DeconvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { } deconv_layer->setName((op_name_ + "_deconv").c_str()); + // add tensor name mapping + tensor_name_map_[deconv_layer->getOutput(0)->getName()] = op_name_; // set extra params SetAttributes(deconv_op, deconv_layer); @@ -121,6 +124,7 @@ int DeconvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { } activation_layer->getOutput(0)->setName((op_name_ + "_output").c_str()); this->AddInnerOutTensors(ITensorHelper{activation_layer->getOutput(0), Format::NCHW, false}); + tensor_name_map_[activation_layer->getOutput(0)->getName()] = op_name_; return RET_OK; } diff --git a/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.cc index 0be2bdd49e5..25be18f2d39 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/matmul_tensorrt.cc @@ -68,6 +68,7 @@ int MatMulTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { network->addMatrixMultiply(*matmul_a.trt_tensor_, transpose_a_, *matmul_b.trt_tensor_, transpose_b_); matmul_layer->setName(op_name_.c_str()); nvinfer1::ITensor *out_tensor = matmul_layer->getOutput(0); + tensor_name_map_[matmul_layer->getOutput(0)->getName()] = op_name_; if (in_tensors_.size() == BIAS_INDEX + 1) { nvinfer1::ITensor *bias = nullptr; diff --git a/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.cc b/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.cc index 5437fb6eb83..1488ec3c9ba 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/pool_tensorrt.cc @@ -63,6 +63,7 @@ int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { return RET_ERROR; } transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str()); + tensor_name_map_[transpose_layer_in->getOutput(0)->getName()] = op_name_; pool_input = transpose_layer_in->getOutput(0); } @@ -79,6 +80,7 @@ int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { } AddParams(pooling_layer); pooling_layer->setName(op_name_.c_str()); + tensor_name_map_[pooling_layer->getOutput(0)->getName()] = op_name_; // add activation nvinfer1::ILayer *activation_layer = nullptr; @@ -96,6 +98,7 @@ int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) { nvinfer1::ITensor *out_trt_tensor = activation_layer->getOutput(0); out_trt_tensor->setName((op_name_ + "_output").c_str()); this->AddInnerOutTensors(ITensorHelper{out_trt_tensor, Format::NCHW, false}); + tensor_name_map_[activation_layer->getOutput(0)->getName()] = op_name_; MS_LOG(DEBUG) << "output " << GetTensorFormat(tensorrt_out_tensors_[0]); return RET_OK; } diff --git a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc index bbe7e1cd832..843b5bcc76c 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc +++ b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.cc @@ -16,6 +16,7 @@ #include "src/delegate/tensorrt/op/tensorrt_op.h" #include "src/delegate/tensorrt/tensorrt_runtime.h" +#include <unordered_map> namespace mindspore::lite { const schema::Primitive *TensorRTOp::GetPrimitive() { return this->op_primitive_; } @@ -46,6 +47,8 @@ const std::vector<TensorRTOp *> &TensorRTOp::out_ops() const { return this->out_ void TensorRTOp::SetRuntime(TensorRTRuntime *runtime) { this->runtime_ = runtime; } +std::unordered_map<std::string, std::string> TensorRTOp::GetTensorNameMap() { return this->tensor_name_map_; } + bool
TensorRTOp::IsShapeKnown() { if (this->in_tensors_.size() == 1 && this->in_tensors_[0].Shape().size() == 0) { return false; diff --git a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h index 44bee82ba10..216964cf7a8 100644 --- a/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h +++ b/mindspore/lite/src/delegate/tensorrt/op/tensorrt_op.h @@ -20,6 +20,7 @@ #include #include #include +#include <unordered_map> #include "include/api/kernel.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" @@ -104,6 +105,8 @@ class TensorRTOp { DynamicShapeParams GetDynamicShapeParams() const; + std::unordered_map<std::string, std::string> GetTensorNameMap(); + protected: bool IsShapeKnown(); @@ -130,6 +133,8 @@ class TensorRTOp { TensorRTRuntime *runtime_{nullptr}; DynamicShapeParams dynamic_shape_params_; + + std::unordered_map<std::string, std::string> tensor_name_map_; }; template diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc index 50d1267653a..4964d30e7a0 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.cc @@ -18,6 +18,7 @@ #include #include #include +#include <algorithm> #include #include #include "src/delegate/delegate_utils.h" @@ -84,6 +85,51 @@ int TensorRTSubGraph::Init(cudaStream_t stream) { input_hw_index_ = -1; } } + if (GetInt8DynamicRange() != RET_OK) { + MS_LOG(WARNING) << "get tensorrt dynamic range failed."; + } + if (SetDeviceConfig(stream) != RET_OK) { + MS_LOG(WARNING) << "set tensorrt config failed."; + } + return RET_OK; +} + +int TensorRTSubGraph::GetInt8DynamicRange() { + if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) { + MS_LOG(DEBUG) << "not in int8 mode, no need to collect dynamic ranges."; + return RET_OK; + } + // input tensor + for (size_t i = 0; i < inputs_.size(); i++) { + auto quant_params = inputs_[i].QuantParams(); + auto tensor_name = inputs_[i].Name(); + for (auto param : quant_params) { + dynamic_range_map_[tensor_name] = param.max; + } + } + // output tensor + for (size_t i = 0; i < outputs_.size(); i++) { + auto quant_params = outputs_[i].QuantParams(); + auto tensor_name = outputs_[i].Name(); + for (auto param : quant_params) { + dynamic_range_map_[tensor_name] = param.max; + } + } + // op tensor + for (auto cur_op : all_ops_) { + for (auto in_tensor : cur_op->inputs()) { + auto tensor_name = in_tensor.Name(); + for (auto param : in_tensor.QuantParams()) { + dynamic_range_map_[tensor_name] = param.max; + } + } + + for (auto out_tensor : cur_op->outputs()) { + auto tensor_name = out_tensor.Name(); + for (auto param : out_tensor.QuantParams()) { + dynamic_range_map_[tensor_name] = param.max; + } + } + } return RET_OK; } @@ -122,6 +168,16 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) { config_->setFlag(nvinfer1::BuilderFlag::kFP16); } + // set int8 + if (IsInt8Mode() && runtime_->GetBuilder()->platformHasFastInt8()) { + MS_LOG(INFO) << "set int8 flag successfully for tensorrt."; + config_->setFlag(nvinfer1::BuilderFlag::kINT8); + // Mark calibrator as null + config_->setInt8Calibrator(nullptr); + input_hw_index_ = -1; + } else { + MS_LOG(INFO) << "inputs have no quant params or platform does not support int8."; + } config_->setProfileStream(stream); // config setMaxWorkspaceSize to 1152 MB for max limit @@ -129,6 +185,130 @@ return RET_OK; } +bool TensorRTSubGraph::SupportFP16() { + int deviceCnt = 0; + + cudaError ret =
cudaGetDeviceCount(&deviceCnt); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaGetDeviceCount failed."; + return false; + } + std::vector<std::string> supportFP16_versions{"5.3", "6.0", "6.2", "7.0", "7.2", "7.5", "8.0", "8.6"}; + cudaDeviceProp prop; + std::string version; + for (int dev = 0; dev < deviceCnt; dev++) { + ret = cudaGetDeviceProperties(&prop, dev); + if (ret != cudaSuccess) { + MS_LOG(ERROR) << "cudaGetDeviceProperties failed."; + return false; + } + version = std::to_string(prop.major) + "." + std::to_string(prop.minor); + if (std::find(supportFP16_versions.begin(), supportFP16_versions.end(), version) != supportFP16_versions.end()) { + MS_LOG(INFO) << "cuda device compute capability " << version << " supports FP16, enabling the FP16 flag"; + return true; + } + } + MS_LOG(WARNING) << "cuda device compute capability " << version << " does not support FP16, FP16 flag not enabled"; + return false; +} + +bool TensorRTSubGraph::IsInt8Mode() { + for (auto cur_op : all_ops_) { + for (auto in_tensor : cur_op->inputs()) { + if (in_tensor.QuantParams().empty()) { + continue; + } + if (in_tensor.QuantParams().front().max > 0) { + return true; + } + } + + for (auto out_tensor : cur_op->outputs()) { + if (out_tensor.QuantParams().empty()) { + continue; + } + if (out_tensor.QuantParams().front().max > 0) { + return true; + } + } + } + return false; +} + +void TensorRTSubGraph::SetInt8LayerPrecision() { + if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) { + MS_LOG(DEBUG) << "not in int8 mode, no need to set layer precision."; + return; + } + + for (int i = 0; i < this->network_->getNbLayers(); ++i) { + auto layer = this->network_->getLayer(i); + if (layer->getType() != nvinfer1::LayerType::kCONSTANT && layer->getType() != nvinfer1::LayerType::kCONCATENATION && + layer->getType() != nvinfer1::LayerType::kSHAPE) { + layer->setPrecision(nvinfer1::DataType::kINT8); + } + + for (int j = 0; j < layer->getNbOutputs(); ++j) { + std::string tensorName = layer->getOutput(j)->getName(); + MS_LOG(DEBUG) << "Tensor: " << tensorName << ". 
OutputType: INT8"; + if (layer->getOutput(j)->isExecutionTensor()) { + layer->setOutputType(j, nvinfer1::DataType::kINT8); + } + } + } +} + +int TensorRTSubGraph::SetInt8DynamicRange() { + if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) { + MS_LOG(DEBUG) << "not in int8 mode, no need to set dynamic range."; + return RET_OK; + } + + // set dynamic range for network input tensor + for (int i = 0; i < this->network_->getNbInputs(); ++i) { + std::string tensorName = this->network_->getInput(i)->getName(); + if (dynamic_range_map_.find(tensorName) != dynamic_range_map_.end()) { + if (!this->network_->getInput(i)->setDynamicRange(-dynamic_range_map_.at(tensorName), + dynamic_range_map_.at(tensorName))) { + return RET_ERROR; + } + MS_LOG(INFO) << "set dynamic range for input tensor: " << tensorName + << " value: " << dynamic_range_map_.at(tensorName); + } else { + MS_LOG(INFO) << "missing network dynamic range for input tensor: " << tensorName; + } + } + + // set dynamic range for network output tensor + for (int i = 0; i < this->network_->getNbLayers(); ++i) { + auto lyr = this->network_->getLayer(i); + for (int j = 0, e = lyr->getNbOutputs(); j < e; ++j) { + std::string tensorName = lyr->getOutput(j)->getName(); + if (tensor_name_map_.find(tensorName) == tensor_name_map_.end()) { + MS_LOG(INFO) << "missing network dynamic range for output tensor: " << tensorName; + continue; + } + std::string mappingName = tensor_name_map_.at(tensorName); + if (dynamic_range_map_.find(mappingName) != dynamic_range_map_.end()) { + if (!lyr->getOutput(j)->setDynamicRange(-dynamic_range_map_.at(mappingName), + dynamic_range_map_.at(mappingName))) { + return RET_ERROR; + } + MS_LOG(INFO) << "set dynamic range for output tensor: " << tensorName + << " value: " << dynamic_range_map_.at(mappingName); + } else { + MS_LOG(INFO) << "missing network dynamic range for output tensor: " << tensorName; + } + } + } + return RET_OK; +} + nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) { for (int i = 0; i < this->network_->getNbInputs(); i++) { if (in_tensor.Name().compare(this->network_->getInput(i)->getName()) == 0) { @@ -267,16 +447,29 @@ int TensorRTSubGraph::BuildTensorRTGraph() { MS_LOG(ERROR) << "Add op failed in TensorRT network"; return RET_ERROR; } + ret = GetTensorName(cur_op); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GetTensorName failed in TensorRT network"; + return RET_ERROR; + } } ret = MarkOutputs(); if (ret != RET_OK) { MS_LOG(ERROR) << "MarkOutputs failed in TensorRT network"; return ret; } + + ret = SetInt8DynamicRange(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetInt8DynamicRange failed in TensorRT network"; + return ret; + } + std::string network_name = "network_" + std::string(network_->getInput(0)->getName()) + "_" + std::string(network_->getOutput(0)->getName()); network_->setName(network_name.c_str()); this->name_ = network_name; + SetInt8LayerPrecision(); ret = BuildEngine(); if (ret != RET_OK) { MS_LOG(ERROR) << "Create engine failed in TensorRT network"; @@ -285,6 +478,18 @@ return RET_OK; } +int TensorRTSubGraph::GetTensorName(TensorRTOp *cur_op) { + auto op_tensor_map = cur_op->GetTensorNameMap(); + std::unordered_map<std::string, std::string>::iterator iter; + for (iter = op_tensor_map.begin(); iter != op_tensor_map.end(); ++iter) { + if (tensor_name_map_.find(iter->first) == tensor_name_map_.end()) { + tensor_name_map_[iter->first] = iter->second; + MS_LOG(DEBUG) << "UpdateTensorName mapping: " << iter->first << " 
value: " << iter->second; + } + } + return RET_OK; +} + int TensorRTSubGraph::MarkOutputs() { // Mark NetWork Output Tensor. for (auto out_tensor : outputs_) { diff --git a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.h b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.h index b82b2ce656f..242837097d1 100644 --- a/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.h +++ b/mindspore/lite/src/delegate/tensorrt/tensorrt_subgraph.h @@ -20,6 +20,7 @@ #include #include #include +#include <unordered_map> #include #include "include/api/kernel.h" #include "src/delegate/tensorrt/tensorrt_runtime.h" @@ -85,6 +86,18 @@ class TensorRTSubGraph : public kernel::Kernel { int SetDeviceConfig(cudaStream_t stream); + bool IsInt8Mode(); + void SetInt8LayerPrecision(); + int GetInt8DynamicRange(); + int SetInt8DynamicRange(); + int GetTensorName(TensorRTOp *cur_op); + bool SupportFP16(); + nvinfer1::ITensor *SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor); ITensorHelper FindTensorRTInputs(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor); @@ -124,6 +137,9 @@ class TensorRTSubGraph : public kernel::Kernel { std::vector cache_const_inputs_; std::map network_cache_tensor_info_; + std::unordered_map<std::string, double> dynamic_range_map_; + std::unordered_map<std::string, std::string> tensor_name_map_; + nvinfer1::INetworkDefinition *network_{nullptr}; nvinfer1::IBuilderConfig *config_{nullptr}; nvinfer1::ICudaEngine *engine_{nullptr}; diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index d597bcbf875..f6ac6d5b2dc 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -158,6 +158,8 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit quant_arg.roundType = quant_param->roundType(); quant_arg.multiplier = quant_param->multiplier(); quant_arg.dstDtype = quant_param->dstDtype(); + quant_arg.min = quant_param->min(); + quant_arg.max = quant_param->max(); } dst_tensor->AddQuantParam(quant_arg); } diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index 5b255912319..1d3339be782 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -48,6 +48,9 @@ struct LiteQuantParam { int roundType{1}; int multiplier{1}; int dstDtype{32}; + // dynamic range + double min{-255.0}; + double max{255.0}; }; class Tensor : public mindspore::tensor::MSTensor { diff --git a/mindspore/lite/test/config/models_gpu_fp32.cfg b/mindspore/lite/test/config/models_gpu_fp32.cfg index b9ea3394d7c..9a09a2c4a03 100644 --- a/mindspore/lite/test/config/models_gpu_fp32.cfg +++ b/mindspore/lite/test/config/models_gpu_fp32.cfg @@ -301,7 +301,7 @@ porseg_tmp.onnx;2:img,prev_mask hiai_nlu_onnx_model_v1_0.onnx;3:input_ids,segment_ids,position_ids ml_video_edit_makeup_mobilenetv203.onnx;1:input.1 Q888_CV_face_recognition_self.onnx;1:input -ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3 +#ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3 ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5:input,bbox,init_pose,init_shape,init_cam ml_video_edit_dimming_tech_model_345000_color.onnx;2:input.18,1 Ireland_gaze_corrector.onnx;3:image,target_angle,strength 1 diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc index 7a375474668..51f535303a2 100644 --- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc +++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc @@ -168,8 +168,11 @@ int
QuantParamParser::ParseTargetDevice(const std::string &target_device_str, qu if (target_device_str == "KIRIN") { (*target_device) = quant::KIRIN; return RET_OK; + } else if (target_device_str == "NVGPU") { + (*target_device) = quant::NVGPU; + return RET_OK; } else { - MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN."; + MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN or NVGPU."; return RET_INPUT_PARAM_INVALID; } } diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 934cc77e881..65406ce6141 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -149,7 +149,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { { Optimizer forming_model_optimizer; forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); - forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); + forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(ctx)); forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass()); forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.saveFP16)); status = forming_model_optimizer.Run(graph_defT_); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc index ff63bcee79d..3bc1675464e 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.cc @@ -23,8 +23,10 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) { for (auto &tensor : graph->allTensors) { bool has_quant_param = false; for (auto &quant_param : tensor->quantParams) { - quant_param->min = 0.0; - quant_param->max = 0.0; + if (ctx_.fullQuantParam.target_device != quant::NVGPU) { + quant_param->min = 0.0; + quant_param->max = 0.0; + } quant_param->narrowRange = true; if (quant_param->inited) { has_quant_param = true; diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h index fbefcf19613..ec397421aee 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h @@ -17,16 +17,21 @@ #define LITE_UNUSED_QUANT_PARAM_DATA_REMOVE_PASS_H #include #include "tools/converter/optimizer.h" +#include "tools/converter/converter_flags.h" #include "tools/common/graph_util.h" namespace mindspore { namespace lite { class SetUnusedQuantParamToDefaultPass : public GraphPass { public: SetUnusedQuantParamToDefaultPass() {} + explicit SetUnusedQuantParamToDefaultPass(const converter::Flags &ctx) : ctx_(ctx) {} ~SetUnusedQuantParamToDefaultPass() override = default; STATUS Run(schema::MetaGraphT *graph) override; + + private: + converter::Flags ctx_; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc index 5114f554f5a..00fa8fca0c2 100644 --- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc +++ 
b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.cc @@ -249,6 +249,10 @@ int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) { } } } else if (input_node->isa<Parameter>()) { + if (weight_data_type_ == kTypeUnknown) { + MS_LOG(INFO) << "weight does not need quantization."; + continue; + } ret = DoParameterNodeQuant(cnode, input_node->cast<ParameterPtr>(), i); if (ret == RET_NO_CHANGE) { continue; @@ -260,6 +264,10 @@ int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) { weight_quant_params_bak[input_node->fullname_with_scope()] = primitive_quant_holder->get_input_quant_params()[i - 1]; } else if (input_node->isa<ValueNode>()) { + if (weight_data_type_ == kTypeUnknown) { + MS_LOG(INFO) << "weight does not need quantization."; + continue; + } ret = DoValueNodeQuant(cnode, input_node->cast<ValueNodePtr>(), i); if (ret == RET_NO_CHANGE) { continue; @@ -410,6 +418,17 @@ void FullQuantQuantizer::InitKirinConfig() { per_channel_ops_ = {prim::kPrimConv2DFusion}; } +void FullQuantQuantizer::InitNvGpuConfig() { + // `kTypeUnknown` keeps activations and weights in their original data type; only dynamic ranges are recorded + activation_target_data_type_ = kTypeUnknown; + activation_symmetry_ = true; + weight_data_type_ = kTypeUnknown; + weight_symmetry_ = true; + support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection, prim::kPrimMatMul, + prim::kPrimConv2dTransposeFusion}; + per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMul, prim::kPrimFullConnection}; +} + void FullQuantQuantizer::InitQMinMax() { MS_ASSERT(activation_quant_data_type_ == kNumberTypeInt8 || activation_quant_data_type_ == kNumberTypeUInt8); if (activation_quant_data_type_ == kNumberTypeInt8) { @@ -464,6 +483,9 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) { case KIRIN: InitKirinConfig(); break; + case NVGPU: + InitNvGpuConfig(); + break; default: MS_LOG(ERROR) << " Unsupported device " << flags_.fullQuantParam.target_device; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h index fea462d97ae..25d0e4bee1f 100644 --- a/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/full_quant_quantizer.h @@ -66,6 +66,7 @@ class FullQuantQuantizer : public Quantizer { int DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &input_node, size_t input_index); int IsSupportWeightQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index); void InitQMinMax(); + void InitNvGpuConfig(); void InitCpuConfig(); void InitKirinConfig(); int MarkQuantNode(const FuncGraphPtr &func_graph);
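Note on how the pieces above fit together: the converter keeps per-tensor min/max in the schema, LiteSession copies them into LiteQuantParam, the delegate collects abs-max values into dynamic_range_map_, and SetInt8DynamicRange() hands them to TensorRT in place of a calibrator. The sketch below is illustrative only and is not code from this patch: ApplyDynamicRanges is a hypothetical helper, and it assumes the standard TensorRT C++ API (ITensor::setDynamicRange returns false on failure) together with a builder config that sets BuilderFlag::kINT8 and a null calibrator, as SetDeviceConfig() does above.

#include <string>
#include <unordered_map>
#include <NvInfer.h>

// Apply symmetric per-tensor dynamic ranges to the network inputs and to
// every layer output. `ranges` maps tensor names to abs-max values taken
// from QuantParam::max.
bool ApplyDynamicRanges(nvinfer1::INetworkDefinition *network,
                        const std::unordered_map<std::string, double> &ranges) {
  for (int i = 0; i < network->getNbInputs(); ++i) {
    nvinfer1::ITensor *tensor = network->getInput(i);
    auto it = ranges.find(tensor->getName());
    if (it == ranges.end()) {
      continue;  // tensors without a recorded range keep TensorRT's defaults
    }
    if (!tensor->setDynamicRange(-it->second, it->second)) {
      return false;
    }
  }
  for (int i = 0; i < network->getNbLayers(); ++i) {
    nvinfer1::ILayer *layer = network->getLayer(i);
    for (int j = 0; j < layer->getNbOutputs(); ++j) {
      nvinfer1::ITensor *tensor = layer->getOutput(j);
      auto it = ranges.find(tensor->getName());
      if (it == ranges.end()) {
        continue;
      }
      if (!tensor->setDynamicRange(-it->second, it->second)) {
        return false;
      }
    }
  }
  return true;
}

The patch's SetInt8DynamicRange() differs in one detail: layer-output names are first resolved through tensor_name_map_ (populated by the per-op changes above) before dynamic_range_map_ is consulted, because ranges are recorded under the original tensor names rather than the names TensorRT assigns to layer outputs.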
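For reference, the NVGPU path is selected through the converter's full-quant configuration file. The snippet below is only a guess at a minimal configuration: target_device is the key that ParseTargetDevice() consumes, but the surrounding section and key names are assumed from the existing full-quant config format and should be checked against the converter documentation.

[common_quant_param]
quant_type=FULL_QUANT
bit_num=8

[full_quant_param]
target_device=NVGPU

With target_device=NVGPU, SetUnusedQuantParamToDefaultPass keeps each tensor's min/max pair instead of zeroing it (see the change above), so the TensorRT delegate can later rebuild per-tensor dynamic ranges from QuantParam::max at runtime.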