support nvgpu quant

albert.liyan@huawei.com 2022-01-08 15:24:11 +08:00
parent f6a20f1e62
commit 7b03487a68
20 changed files with 289 additions and 5 deletions

View File

@ -58,6 +58,8 @@ struct QuantParam {
int bit_num;
double scale;
int32_t zero_point;
double min;
double max;
};
class Allocator;
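Note (illustrative, not part of this change): the new min and max fields record the calibrated float range of a tensor. For an 8-bit affine quant param they relate to scale and zero_point roughly as min = scale * (q_min - zero_point) and max = scale * (q_max - zero_point); the TensorRT delegate below only consumes max, as a symmetric dynamic range. A minimal sketch, assuming signed 8-bit values:

#include <cstdint>
#include <utility>

// Hypothetical helper: the float range implied by an 8-bit affine quant param
// (scale, zero_point). Only illustrates what the min/max fields encode.
std::pair<double, double> FloatRangeFromQuantParam(double scale, int32_t zero_point) {
  constexpr int kQMin = -128;
  constexpr int kQMax = 127;
  return {scale * (kQMin - zero_point), scale * (kQMax - zero_point)};
}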

View File

@ -218,6 +218,8 @@ class MSTensor::Impl {
param.bit_num = lite_quant_params[i].bitNum;
param.scale = lite_quant_params[i].scale;
param.zero_point = lite_quant_params[i].zeroPoint;
param.min = lite_quant_params[i].min;
param.max = lite_quant_params[i].max;
quant_params.push_back(param);
}
return quant_params;

View File

@ -100,6 +100,7 @@ int ActivationTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
out_tensor->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(
ITensorHelper{out_tensor, tensorrt_in_tensors_[0].format_, tensorrt_in_tensors_[0].same_format_});
tensor_name_map_[out_tensor->getName()] = op_name_;
return RET_OK;
}
nvinfer1::IActivationLayer *ActivationTensorRT::AddActivation(nvinfer1::INetworkDefinition *network,

View File

@ -63,6 +63,7 @@ int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
return RET_ERROR;
}
transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str());
tensor_name_map_[transpose_layer_in->getOutput(0)->getName()] = op_name_;
conv_input = transpose_layer_in->getOutput(0);
}
@ -105,6 +106,7 @@ int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
return RET_ERROR;
}
conv_layer->setName((op_name_ + "_conv").c_str());
tensor_name_map_[conv_layer->getOutput(0)->getName()] = op_name_;
// add params
SetAttributes(conv_op, conv_layer);
@ -124,6 +126,8 @@ int ConvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
}
activation_layer->getOutput(0)->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{activation_layer->getOutput(0), Format::NCHW, false});
// add tensor name mapping
tensor_name_map_[activation_layer->getOutput(0)->getName()] = op_name_;
return RET_OK;
}

View File

@ -60,6 +60,7 @@ int DeconvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
return RET_ERROR;
}
transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str());
tensor_name_map_[transpose_layer_in->getOutput(0)->getName()] = op_name_;
deconv_input = transpose_layer_in->getOutput(0);
}
@ -103,6 +104,8 @@ int DeconvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
}
deconv_layer->setName((op_name_ + "_deconv").c_str());
// add tensor name mapping
tensor_name_map_[deconv_layer->getOutput(0)->getName()] = op_name_;
// set extra params
SetAttributes(deconv_op, deconv_layer);
@ -121,6 +124,7 @@ int DeconvolutionTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
}
activation_layer->getOutput(0)->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{activation_layer->getOutput(0), Format::NCHW, false});
tensor_name_map_[activation_layer->getOutput(0)->getName()] = op_name_;
return RET_OK;
}

View File

@ -68,6 +68,7 @@ int MatMulTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
network->addMatrixMultiply(*matmul_a.trt_tensor_, transpose_a_, *matmul_b.trt_tensor_, transpose_b_);
matmul_layer->setName(op_name_.c_str());
nvinfer1::ITensor *out_tensor = matmul_layer->getOutput(0);
tensor_name_map_[matmul_layer->getOutput(0)->getName()] = op_name_;
if (in_tensors_.size() == BIAS_INDEX + 1) {
nvinfer1::ITensor *bias = nullptr;

View File

@ -63,6 +63,7 @@ int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
return RET_ERROR;
}
transpose_layer_in->setName((op_name_ + "_transpose2NCHW").c_str());
tensor_name_map_[transpose_layer_in->getOutput(0)->getName()] = op_name_;
pool_input = transpose_layer_in->getOutput(0);
}
@ -79,6 +80,7 @@ int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
}
AddParams(pooling_layer);
pooling_layer->setName(op_name_.c_str());
tensor_name_map_[pooling_layer->getOutput(0)->getName()] = op_name_;
// add activation
nvinfer1::ILayer *activation_layer = nullptr;
@ -96,6 +98,7 @@ int PoolTensorRT::AddInnerOp(nvinfer1::INetworkDefinition *network) {
nvinfer1::ITensor *out_trt_tensor = activation_layer->getOutput(0);
out_trt_tensor->setName((op_name_ + "_output").c_str());
this->AddInnerOutTensors(ITensorHelper{out_trt_tensor, Format::NCHW, false});
tensor_name_map_[activation_layer->getOutput(0)->getName()] = op_name_;
MS_LOG(DEBUG) << "output " << GetTensorFormat(tensorrt_out_tensors_[0]);
return RET_OK;
}

View File

@ -16,6 +16,7 @@
#include "src/delegate/tensorrt/op/tensorrt_op.h"
#include "src/delegate/tensorrt/tensorrt_runtime.h"
#include <unordered_map>
namespace mindspore::lite {
const schema::Primitive *TensorRTOp::GetPrimitive() { return this->op_primitive_; }
@ -46,6 +47,8 @@ const std::vector<TensorRTOp *> &TensorRTOp::out_ops() const { return this->out_
void TensorRTOp::SetRuntime(TensorRTRuntime *runtime) { this->runtime_ = runtime; }
std::unordered_map<std::string, std::string> TensorRTOp::GetTensorNameMap() { return this->tensor_name_map_; }
bool TensorRTOp::IsShapeKnown() {
if (this->in_tensors_.size() == 1 && this->in_tensors_[0].Shape().size() == 0) {
return false;

View File

@ -20,6 +20,7 @@
#include <NvInfer.h>
#include <string>
#include <vector>
#include <unordered_map>
#include "include/api/kernel.h"
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
@ -104,6 +105,8 @@ class TensorRTOp {
DynamicShapeParams GetDynamicShapeParams() const;
std::unordered_map<std::string, std::string> GetTensorNameMap();
protected:
bool IsShapeKnown();
@ -130,6 +133,8 @@ class TensorRTOp {
TensorRTRuntime *runtime_{nullptr};
DynamicShapeParams dynamic_shape_params_;
std::unordered_map<std::string, std::string> tensor_name_map_;
};
template <class T>

View File

@ -18,6 +18,7 @@
#include <cuda_runtime_api.h>
#include <string>
#include <vector>
#include <unordered_map>
#include <set>
#include <queue>
#include "src/delegate/delegate_utils.h"
@ -84,6 +85,51 @@ int TensorRTSubGraph::Init(cudaStream_t stream) {
input_hw_index_ = -1;
}
}
if (GetInt8DynamicRange() != RET_OK) {
MS_LOG(WARNING) << "get tensorrt dynamic range failed.";
}
if (SetDeviceConfig(stream) != RET_OK) {
MS_LOG(WARNING) << "set tensorrt config failed.";
}
return RET_OK;
}
int TensorRTSubGraph::GetInt8DynamicRange() {
if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) {
MS_LOG(WARNING) << "no int8 mode, not need dynamic range.";
}
// input tensor
for (size_t i = 0; i < inputs_.size(); i++) {
auto quant_params = inputs_[i].QuantParams();
auto tensor_name = inputs_[i].Name();
for (auto param : quant_params) {
dynamic_range_map_[tensor_name] = param.max;
}
}
// output tensor
for (size_t i = 0; i < outputs_.size(); i++) {
auto quant_params = outputs_[i].QuantParams();
auto tensor_name = outputs_[i].Name();
for (auto param : quant_params) {
dynamic_range_map_[tensor_name] = param.max;
}
}
// op tensor
for (auto cur_op : all_ops_) {
for (auto in_tensor : cur_op->inputs()) {
auto tensor_name = in_tensor.Name();
for (auto param : in_tensor.QuantParams()) {
dynamic_range_map_[tensor_name] = param.max;
}
}
for (auto out_tensor : cur_op->outputs()) {
auto tensor_name = out_tensor.Name();
for (auto param : out_tensor.QuantParams()) {
dynamic_range_map_[tensor_name] = param.max;
}
}
}
return RET_OK;
}
@ -122,6 +168,16 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
config_->setFlag(nvinfer1::BuilderFlag::kFP16);
}
// set int8
if (IsInt8Mode() && runtime_->GetBuilder()->platformHasFastInt8()) {
MS_LOG(INFO) << "set int8 flag successfully for tensorrt.";
config_->setFlag(nvinfer1::BuilderFlag::kINT8);
// no calibrator: per-tensor dynamic ranges are set explicitly via setDynamicRange
config_->setInt8Calibrator(nullptr);
input_hw_index_ = -1;
} else {
MS_LOG(INFO) << "inputs no quant params or platform not support int8.";
}
config_->setProfileStream(stream);
// config setMaxWorkspaceSize to 1152 MB for max limit
@ -129,6 +185,130 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
return RET_OK;
}
bool TensorRTSubGraph::SupportFP16() {
int deviceCnt = 0;
cudaError ret = cudaGetDeviceCount(&deviceCnt);
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "cudaGetDeviceCount failed.";
return false;
}
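// compute capability versions treated as supporting FP16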
std::vector<std::string> supportFP16_versions{"5.3", "6.0", "6.2", "7.0", "7.2", "7.5", "8.0", "8.6"};
cudaDeviceProp prop;
std::string version;
for (int dev = 0; dev < deviceCnt; dev++) {
ret = cudaGetDeviceProperties(&prop, dev);
if (ret != cudaSuccess) {
MS_LOG(ERROR) << "cuDeviceGetAttribute failed.";
return false;
}
version = std::to_string(prop.major) + "." + std::to_string(prop.minor);
if (std::find(supportFP16_versions.begin(), supportFP16_versions.end(), version) != supportFP16_versions.end()) {
MS_LOG(INFO) << "cuda device version is: " << version << ", support FP16, set enable FP16 tag successful";
return true;
}
}
MS_LOG(WARNING) << "cuda device version is: " << version << ", don't support FP16, set enable FP16 tag failed";
return false;
}
bool TensorRTSubGraph::IsInt8Mode() {
for (auto cur_op : all_ops_) {
for (auto in_tensor : cur_op->inputs()) {
if (in_tensor.QuantParams().empty()) {
continue;
}
if (in_tensor.QuantParams().front().max > 0) {
return true;
}
}
for (auto out_tensor : cur_op->outputs()) {
if (out_tensor.QuantParams().empty()) {
continue;
}
if (out_tensor.QuantParams().front().max > 0) {
return true;
}
}
}
return false;
}
void TensorRTSubGraph::SetInt8LayerPrecision() {
if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) {
MS_LOG(WARNING) << "not in int8 mode, no need to set layer precision.";
return;
}
for (int i = 0; i < this->network_->getNbLayers(); ++i) {
auto layer = this->network_->getLayer(i);
if (layer->getType() != nvinfer1::LayerType::kCONSTANT && layer->getType() != nvinfer1::LayerType::kCONCATENATION &&
layer->getType() != nvinfer1::LayerType::kSHAPE) {
layer->setPrecision(nvinfer1::DataType::kINT8);
}
for (int j = 0; j < layer->getNbOutputs(); ++j) {
if (layer->getOutput(j)->isExecutionTensor()) {
MS_LOG(DEBUG) << "Tensor: " << layer->getOutput(j)->getName() << ". OutputType: INT8";
layer->setOutputType(j, nvinfer1::DataType::kINT8);
}
}
}
}
int TensorRTSubGraph::SetInt8DynamicRange() {
if (!IsInt8Mode() || !runtime_->GetBuilder()->platformHasFastInt8()) {
MS_LOG(DEBUG) << "no int8 mode, not need dynamic range.";
return RET_OK;
}
// set dynamic range for network input tensor
for (int i = 0; i < this->network_->getNbInputs(); ++i) {
std::string tensorName = this->network_->getInput(i)->getName();
if (dynamic_range_map_.find(tensorName) != dynamic_range_map_.end()) {
if (!this->network_->getInput(i)->setDynamicRange(-dynamic_range_map_.at(tensorName),
dynamic_range_map_.at(tensorName))) {
return RET_ERROR;
}
MS_LOG(INFO) << "set dynamic range for input tensor: " << tensorName
<< " value: " << dynamic_range_map_.at(tensorName);
} else {
MS_LOG(INFO) << "missing network dynamic range for input tensor: " << tensorName;
}
}
// set dynamic range for each layer's output tensors
for (int i = 0; i < this->network_->getNbLayers(); ++i) {
auto lyr = this->network_->getLayer(i);
for (int j = 0, e = lyr->getNbOutputs(); j < e; ++j) {
std::string tensorName = lyr->getOutput(j)->getName();
if (tensor_name_map_.find(tensorName) == tensor_name_map_.end()) {
MS_LOG(INFO) << "missing network dynamic range for output tensor: " << tensorName;
continue;
}
std::string mappingName = tensor_name_map_.at(tensorName);
if (dynamic_range_map_.find(mappingName) != dynamic_range_map_.end()) {
if (!lyr->getOutput(j)->setDynamicRange(-dynamic_range_map_.at(mappingName),
dynamic_range_map_.at(mappingName))) {
return RET_ERROR;
}
MS_LOG(INFO) << "set dynamic range for output tensor: " << tensorName
<< " value: " << dynamic_range_map_.at(mappingName);
} else {
MS_LOG(INFO) << "missing network dynamic range for output tensor: " << tensorName;
}
}
}
return RET_OK;
}
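Note (illustrative, not part of this change): because SetDeviceConfig sets the kINT8 builder flag with a null calibrator, TensorRT relies entirely on the per-tensor ranges supplied above via setDynamicRange. A symmetric range [-max, max] implies an INT8 scale of max / 127; a minimal sketch of the quantization this corresponds to:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative only: quantization implied by a symmetric dynamic range [-max_abs, max_abs].
int8_t QuantizeWithDynamicRange(float value, float max_abs) {
  const float scale = max_abs / 127.0f;
  const float clipped = std::min(std::max(value, -max_abs), max_abs);
  return static_cast<int8_t>(std::lround(clipped / scale));
}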
nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) {
for (int i = 0; i < this->network_->getNbInputs(); i++) {
if (in_tensor.Name().compare(this->network_->getInput(i)->getName()) == 0) {
@ -267,16 +447,29 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
MS_LOG(ERROR) << "Add op failed in TensorRT network";
return RET_ERROR;
}
ret = GetTensorName(cur_op);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetTensorName failed in TensorRT network";
return RET_ERROR;
}
}
ret = MarkOutputs();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MarkOutputs failed in TensorRT network";
return ret;
}
ret = SetInt8DynamicRange();
if (ret != RET_OK) {
MS_LOG(ERROR) << "SetInt8DynamicRange failed in TensorRT network";
return ret;
}
std::string network_name =
"network_" + std::string(network_->getInput(0)->getName()) + "_" + std::string(network_->getOutput(0)->getName());
network_->setName(network_name.c_str());
this->name_ = network_name;
SetInt8LayerPrecision();
ret = BuildEngine();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Create engine failed in TensorRT network";
@ -285,6 +478,18 @@ int TensorRTSubGraph::BuildTensorRTGraph() {
return RET_OK;
}
int TensorRTSubGraph::GetTensorName(TensorRTOp *cur_op) {
auto op_tensor_map = cur_op->GetTensorNameMap();
std::unordered_map<std::string, std::string>::iterator iter;
for (iter = op_tensor_map.begin(); iter != op_tensor_map.end(); ++iter) {
if (tensor_name_map_.find(iter->first) == tensor_name_map_.end()) {
tensor_name_map_[iter->first] = iter->second;
MS_LOG(DEBUG) << "UpdateTensorName mapping: " << iter->first << " value: " << iter->second;
}
}
return RET_OK;
}
int TensorRTSubGraph::MarkOutputs() {
// Mark network output tensors.
for (auto out_tensor : outputs_) {

View File

@ -20,6 +20,7 @@
#include <map>
#include <string>
#include <vector>
#include <unordered_map>
#include <memory>
#include "include/api/kernel.h"
#include "src/delegate/tensorrt/tensorrt_runtime.h"
@ -85,6 +86,18 @@ class TensorRTSubGraph : public kernel::Kernel {
int SetDeviceConfig(cudaStream_t stream);
bool IsInt8Mode();
void SetInt8LayerPrecision();
int GetInt8DynamicRange();
int SetInt8DynamicRange();
int GetTensorName(TensorRTOp *cur_op);
bool SupportFP16();
nvinfer1::ITensor *SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor);
ITensorHelper FindTensorRTInputs(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor);
@ -124,6 +137,9 @@ class TensorRTSubGraph : public kernel::Kernel {
std::vector<mindspore::MSTensor> cache_const_inputs_;
std::map<std::string, CacheTensorInfo> network_cache_tensor_info_;
std::unordered_map<std::string, float> dynamic_range_map_;
std::unordered_map<std::string, std::string> tensor_name_map_;
nvinfer1::INetworkDefinition *network_{nullptr};
nvinfer1::IBuilderConfig *config_{nullptr};
nvinfer1::ICudaEngine *engine_{nullptr};

View File

@ -158,6 +158,8 @@ void LiteSession::ConvertTensorsQuantParam(const schema::Tensor *src_tensor, lit
quant_arg.roundType = quant_param->roundType();
quant_arg.multiplier = quant_param->multiplier();
quant_arg.dstDtype = quant_param->dstDtype();
quant_arg.min = quant_param->min();
quant_arg.max = quant_param->max();
}
dst_tensor->AddQuantParam(quant_arg);
}

View File

@ -48,6 +48,9 @@ struct LiteQuantParam {
int roundType{1};
int multiplier{1};
int dstDtype{32};
// dynamic range
double min{-255.0};
double max{255.0};
};
class Tensor : public mindspore::tensor::MSTensor {

View File

@ -301,7 +301,7 @@ porseg_tmp.onnx;2:img,prev_mask
hiai_nlu_onnx_model_v1_0.onnx;3:input_ids,segment_ids,position_ids
ml_video_edit_makeup_mobilenetv203.onnx;1:input.1
Q888_CV_face_recognition_self.onnx;1:input
ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3
#ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3
ml_motion_capture_spin_mobile_mv3_v3_57mm_sim.onnx;5:input,bbox,init_pose,init_shape,init_cam
ml_video_edit_dimming_tech_model_345000_color.onnx;2:input.18,1
Ireland_gaze_corrector.onnx;3:image,target_angle,strength 1

View File

@ -168,8 +168,11 @@ int QuantParamParser::ParseTargetDevice(const std::string &target_device_str, qu
if (target_device_str == "KIRIN") {
(*target_device) = quant::KIRIN;
return RET_OK;
} else if (target_device_str == "NVGPU") {
(*target_device) = quant::NVGPU;
return RET_OK;
} else {
MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN.";
MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN or NVGPU.";
return RET_INPUT_PARAM_INVALID;
}
}

View File

@ -149,7 +149,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
{
Optimizer forming_model_optimizer;
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(ctx));
forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.saveFP16));
status = forming_model_optimizer.Run(graph_defT_);

View File

@ -23,8 +23,10 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) {
for (auto &tensor : graph->allTensors) {
bool has_quant_param = false;
for (auto &quant_param : tensor->quantParams) {
quant_param->min = 0.0;
quant_param->max = 0.0;
if (ctx_.fullQuantParam.target_device != quant::NVGPU) {
quant_param->min = 0.0;
quant_param->max = 0.0;
}
quant_param->narrowRange = true;
if (quant_param->inited) {
has_quant_param = true;

View File

@ -17,16 +17,21 @@
#define LITE_UNUSED_QUANT_PARAM_DATA_REMOVE_PASS_H
#include <memory>
#include "tools/converter/optimizer.h"
#include "tools/converter/converter_flags.h"
#include "tools/common/graph_util.h"
namespace mindspore {
namespace lite {
class SetUnusedQuantParamToDefaultPass : public GraphPass {
public:
SetUnusedQuantParamToDefaultPass() {}
explicit SetUnusedQuantParamToDefaultPass(const converter::Flags &ctx) : ctx_(ctx) {}
~SetUnusedQuantParamToDefaultPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
converter::Flags ctx_;
};
} // namespace lite
} // namespace mindspore

View File

@ -249,6 +249,10 @@ int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
}
}
} else if (input_node->isa<mindspore::Parameter>()) {
if (weight_data_type_ == kTypeUnknown) {
MS_LOG(INFO) << "weight not need parameter quant.";
continue;
}
ret = DoParameterNodeQuant(cnode, input_node->cast<ParameterPtr>(), i);
if (ret == RET_NO_CHANGE) {
continue;
@ -260,6 +264,10 @@ int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
weight_quant_params_bak[input_node->fullname_with_scope()] =
primitive_quant_holder->get_input_quant_params()[i - 1];
} else if (input_node->isa<mindspore::ValueNode>()) {
if (weight_data_type_ == kTypeUnknown) {
MS_LOG(INFO) << "weight not need parameter quant.";
continue;
}
ret = DoValueNodeQuant(cnode, input_node->cast<ValueNodePtr>(), i);
if (ret == RET_NO_CHANGE) {
continue;
@ -410,6 +418,17 @@ void FullQuantQuantizer::InitKirinConfig() {
per_channel_ops_ = {prim::kPrimConv2DFusion};
}
void FullQuantQuantizer::InitNvGpuConfig() {
// `kTypeUnknown` represents the original data type
activation_target_data_type_ = kTypeUnknown;
activation_symmetry_ = true;
weight_data_type_ = kTypeUnknown;
weight_symmetry_ = true;
support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection, prim::kPrimMatMul,
prim::kPrimConv2dTransposeFusion};
per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMul, prim::kPrimFullConnection};
}
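Note (illustrative, not part of this change): with both target data types left as kTypeUnknown, the NVGPU path keeps weights and activations in their original data type and only records calibrated ranges, which the TensorRT delegate later passes to setDynamicRange. A hypothetical sketch of the kind of symmetric range statistic involved:

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical abs-max statistic; a value of this kind would populate
// LiteQuantParam::max for a symmetric INT8 range. Not the converter's actual calibrator.
double AbsMaxRange(const std::vector<float> &data) {
  double max_abs = 0.0;
  for (float v : data) {
    max_abs = std::max(max_abs, static_cast<double>(std::fabs(v)));
  }
  return max_abs;
}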
void FullQuantQuantizer::InitQMinMax() {
MS_ASSERT(activation_quant_data_type_ == kNumberTypeInt8 || activation_quant_data_type_ == kNumberTypeUInt8);
if (activation_quant_data_type_ == kNumberTypeInt8) {
@ -464,6 +483,9 @@ int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
case KIRIN:
InitKirinConfig();
break;
case NVGPU:
InitNvGpuConfig();
break;
default:
MS_LOG(ERROR) << " Unsupported device " << flags_.fullQuantParam.target_device;
return RET_ERROR;

View File

@ -66,6 +66,7 @@ class FullQuantQuantizer : public Quantizer {
int DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &input_node, size_t input_index);
int IsSupportWeightQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index);
void InitQMinMax();
void InitNvGpuConfig();
void InitCpuConfig();
void InitKirinConfig();
int MarkQuantNode(const FuncGraphPtr &func_graph);