forked from mindspore-Ecosystem/mindspore
!15612 gpu mixed precision config
From: @wilfchen
Reviewed-by: @cristoval, @limingqi107
Signed-off-by: @limingqi107

commit 6b5f1bc93a
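For orientation, here is a minimal usage sketch of the option this change exposes, assuming the unified C++ API of this MindSpore generation (mindspore::Context with MutableDeviceInfo(), plus the NvidiaGPUDeviceInfo class shown in the first hunk below); the device ID and the follow-up Model::Build step are illustrative assumptions, not part of the commit.

  #include <memory>

  #include "include/api/context.h"

  int main() {
    // Build a GPU inference context that uses the TensorRT backend with fp16 mixed precision.
    auto context = std::make_shared<mindspore::Context>();
    auto gpu_info = std::make_shared<mindspore::NvidiaGPUDeviceInfo>();
    gpu_info->SetDeviceID(0);             // illustrative device id
    gpu_info->SetGpuTrtInferMode(true);   // run the graph through TensorRT
    gpu_info->SetPrecisionMode("fp16");   // option added by this commit; "fp16" enables mixed precision
    context->MutableDeviceInfo().push_back(gpu_info);
    // The configured context would then be passed to Model::Build(...) as usual.
    return 0;
  }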
@@ -116,8 +116,20 @@ class MS_API NvidiaGPUDeviceInfo : public DeviceInfoContext {
   void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
   bool GetGpuTrtInferMode() const;
 
+  inline void SetPrecisionMode(const std::string &precision_mode);
+  inline std::string GetPrecisionMode() const;
+
+ private:
+  void SetPrecisionMode(const std::vector<char> &precision_mode);
+  std::vector<char> GetPrecisionModeChar() const;
 };
 
+void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
+  SetPrecisionMode(StringToChar(precision_mode));
+}
+std::string NvidiaGPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
+
 class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
  public:
   enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
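Note the split visible above: the public precision-mode accessors are inline std::string wrappers that forward to private std::vector<char> overloads (SetPrecisionMode(const std::vector<char> &), GetPrecisionModeChar()), converted with StringToChar/CharToString. This follows the existing convention in this header, presumably so that std::string objects never cross the library's binary interface.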
@@ -198,6 +198,15 @@ bool TrtConverterContext::Serialize(std::string *model) {
   MS_EXCEPTION_IF_NULL(model);
   builder_->setMaxBatchSize(batch_size_);
   config_->setMaxWorkspaceSize(workspace_size_);
 
+  // Set precision mode
+  const auto &context = MsContext::GetInstance();
+  const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
+  if (precision_mode == "fp16") {
+    MS_LOG(WARNING) << "Inference with mixed precision mode. It will take few minutes for operators selection.";
+    config_->setFlag(nvinfer1::BuilderFlag::kFP16);
+  }
+
   engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
   MS_EXCEPTION_IF_NULL(engine_);
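For context on the flag itself: nvinfer1::BuilderFlag::kFP16 does not force fp16, it permits TensorRT to choose fp16 kernels wherever its tactic selection finds them faster, which is why engine building becomes noticeably slower and the code logs a WARNING beforehand. On hardware without fast fp16 units the flag could additionally be guarded with IBuilder::platformHasFastFp16(), a check this commit does not include.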
@@ -27,6 +27,7 @@ constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequ
 constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
 constexpr auto kModelOptionNvidiaGpuDeviceID = kModelOptionDeviceID;
 constexpr auto kModelOptionNvidiaGpuTrtInferMode = "mindspore.option.nvidia_gpu.trt_infer_mode";
+constexpr auto kModelOptionNvidiaGpuPrecisionMode = "mindspore.option.nvidia_gpu.precision_mode";
 constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
 constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
 constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
@@ -153,6 +154,16 @@ bool NvidiaGPUDeviceInfo::GetGpuTrtInferMode() const {
   return GetValue<bool>(data_, kModelOptionNvidiaGpuTrtInferMode);
 }
 
+void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
+  MS_EXCEPTION_IF_NULL(data_);
+  data_->params[kModelOptionNvidiaGpuPrecisionMode] = CharToString(precision_mode);
+}
+std::vector<char> NvidiaGPUDeviceInfo::GetPrecisionModeChar() const {
+  MS_EXCEPTION_IF_NULL(data_);
+  const std::string &ref = GetValue<std::string>(data_, kModelOptionNvidiaGpuPrecisionMode);
+  return StringToChar(ref);
+}
+
 void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
   MS_EXCEPTION_IF_NULL(data_);
   data_->params[kModelOptionAscend910DeviceID] = device_id;
@@ -63,8 +63,8 @@ Status GPUGraphImpl::InitEnv() {
   if (gpu_info == nullptr) {
     return kMCDeviceError;
   }
-  auto enable_trt = gpu_info->GetGpuTrtInferMode();
-  ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, enable_trt);
+  ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, gpu_info->GetGpuTrtInferMode());
+  ms_context->set_param<std::string>(MS_CTX_INFER_PRECISION_MODE, gpu_info->GetPrecisionMode());
 
   session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
   if (session_impl_ == nullptr) {
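Taken together with the earlier hunks, the data flow is: the Unify API stores the user's choices on the device info object, GPUGraphImpl::InitEnv copies them into the global MsContext (MS_CTX_ENABLE_INFER_OPT and MS_CTX_INFER_PRECISION_MODE), and TrtConverterContext reads MS_CTX_INFER_PRECISION_MODE back when it builds the TensorRT engine.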
@@ -31,7 +31,7 @@ class TrtLoader {
 
   std::shared_ptr<nvinfer1::IBuilder> CreateInferBuilder(nvinfer1::ILogger *logger);
   std::shared_ptr<nvinfer1::IRuntime> CreateInferRuntime(nvinfer1::ILogger *logger);
-  bool nvinfer_loaded() { return nvinfer_loaded_; }
+  bool nvinfer_loaded() const { return nvinfer_loaded_; }
 
  private:
   bool nvinfer_loaded_;
@@ -118,6 +118,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_ENV_CONFIG_PATH,
   MS_CTX_TUNE_MODE,
   MS_CTX_GRAPH_KERNEL_FLAGS,
+  MS_CTX_INFER_PRECISION_MODE,  // GPU inference precision mode configured by Serving or Unify API.
   MS_CTX_TYPE_STRING_END,
 
   // parameter numbers of each type
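MS_CTX_INFER_PRECISION_MODE is inserted before MS_CTX_TYPE_STRING_END, i.e. inside the string-typed block of MsCtxParam, which is what makes the get_param<std::string> and set_param<std::string> calls in the other hunks well-typed.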