forked from mindspore-Ecosystem/mindspore
gpu inference mixed precision
This commit is contained in:
parent
6801ef61e0
commit
ba9bbfadf8
|
@ -116,8 +116,20 @@ class MS_API NvidiaGPUDeviceInfo : public DeviceInfoContext {
|
||||||
|
|
||||||
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
|
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
|
||||||
bool GetGpuTrtInferMode() const;
|
bool GetGpuTrtInferMode() const;
|
||||||
|
|
||||||
|
// Sets the GPU inference precision mode from a std::string.
// The out-of-class definition forwards to the private std::vector<char> overload.
// Parameter name spelled `precision_mode` to match that definition.
inline void SetPrecisionMode(const std::string &precision_mode);
|
||||||
|
// Returns the precision mode as a std::string; the out-of-class definition
// converts the result of GetPrecisionModeChar() with CharToString.
inline std::string GetPrecisionMode() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// std::vector<char> overload; the public std::string overload forwards here
// after converting its argument with StringToChar.
void SetPrecisionMode(const std::vector<char> &precision_mode);
|
||||||
|
// std::vector<char> form of the precision-mode getter; the public
// GetPrecisionMode() converts its result to std::string.
std::vector<char> GetPrecisionModeChar() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Public std::string overload: converts `precision_mode` with StringToChar and
// forwards it to the std::vector<char> overload of SetPrecisionMode.
void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
|
||||||
|
SetPrecisionMode(StringToChar(precision_mode));
|
||||||
|
}
|
||||||
|
// Public std::string getter: fetches the std::vector<char> form of the
// precision mode and converts it to std::string with CharToString.
std::string NvidiaGPUDeviceInfo::GetPrecisionMode() const {
  const std::vector<char> mode_chars = GetPrecisionModeChar();
  return CharToString(mode_chars);
}
|
||||||
|
|
||||||
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
|
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
|
||||||
public:
|
public:
|
||||||
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
|
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
|
||||||
|
|
|
@ -198,6 +198,15 @@ bool TrtConverterContext::Serialize(std::string *model) {
|
||||||
MS_EXCEPTION_IF_NULL(model);
|
MS_EXCEPTION_IF_NULL(model);
|
||||||
builder_->setMaxBatchSize(batch_size_);
|
builder_->setMaxBatchSize(batch_size_);
|
||||||
config_->setMaxWorkspaceSize(workspace_size_);
|
config_->setMaxWorkspaceSize(workspace_size_);
|
||||||
|
|
||||||
|
// Set precision mode
|
||||||
|
const auto &context = MsContext::GetInstance();
|
||||||
|
const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
|
||||||
|
if (precision_mode == "fp16") {
|
||||||
|
MS_LOG(WARNING) << "Inference with mixed precision mode. It will take few minutes for operators selection.";
|
||||||
|
config_->setFlag(nvinfer1::BuilderFlag::kFP16);
|
||||||
|
}
|
||||||
|
|
||||||
engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
|
engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
|
||||||
MS_EXCEPTION_IF_NULL(engine_);
|
MS_EXCEPTION_IF_NULL(engine_);
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequ
|
||||||
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
|
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
|
||||||
constexpr auto kModelOptionNvidiaGpuDeviceID = kModelOptionDeviceID;
|
constexpr auto kModelOptionNvidiaGpuDeviceID = kModelOptionDeviceID;
|
||||||
constexpr auto kModelOptionNvidiaGpuTrtInferMode = "mindspore.option.nvidia_gpu.trt_infer_mode";
|
constexpr auto kModelOptionNvidiaGpuTrtInferMode = "mindspore.option.nvidia_gpu.trt_infer_mode";
|
||||||
|
constexpr auto kModelOptionNvidiaGpuPrecisionMode = "mindspore.option.nvidia_gpu.precision_mode";
|
||||||
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
|
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
|
||||||
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
|
||||||
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
|
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
|
||||||
|
@ -153,6 +154,16 @@ bool NvidiaGPUDeviceInfo::GetGpuTrtInferMode() const {
|
||||||
return GetValue<bool>(data_, kModelOptionNvidiaGpuTrtInferMode);
|
return GetValue<bool>(data_, kModelOptionNvidiaGpuTrtInferMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stores the precision mode in data_->params under
// kModelOptionNvidiaGpuPrecisionMode, after converting it with CharToString.
// MS_EXCEPTION_IF_NULL guards against a null data_ before the write.
void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
data_->params[kModelOptionNvidiaGpuPrecisionMode] = CharToString(precision_mode);
|
||||||
|
}
|
||||||
|
// Reads the value stored under kModelOptionNvidiaGpuPrecisionMode through
// GetValue<std::string> and returns it converted with StringToChar.
// NOTE(review): the result when the option was never set depends on GetValue,
// which is not visible here — confirm.
std::vector<char> NvidiaGPUDeviceInfo::GetPrecisionModeChar() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
|
const std::string &ref = GetValue<std::string>(data_, kModelOptionNvidiaGpuPrecisionMode);
|
||||||
|
return StringToChar(ref);
|
||||||
|
}
|
||||||
|
|
||||||
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
|
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
|
||||||
MS_EXCEPTION_IF_NULL(data_);
|
MS_EXCEPTION_IF_NULL(data_);
|
||||||
data_->params[kModelOptionAscend910DeviceID] = device_id;
|
data_->params[kModelOptionAscend910DeviceID] = device_id;
|
||||||
|
|
|
@ -63,8 +63,8 @@ Status GPUGraphImpl::InitEnv() {
|
||||||
if (gpu_info == nullptr) {
|
if (gpu_info == nullptr) {
|
||||||
return kMCDeviceError;
|
return kMCDeviceError;
|
||||||
}
|
}
|
||||||
auto enable_trt = gpu_info->GetGpuTrtInferMode();
|
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, gpu_info->GetGpuTrtInferMode());
|
||||||
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, enable_trt);
|
ms_context->set_param<std::string>(MS_CTX_INFER_PRECISION_MODE, gpu_info->GetPrecisionMode());
|
||||||
|
|
||||||
session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
|
session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
|
||||||
if (session_impl_ == nullptr) {
|
if (session_impl_ == nullptr) {
|
||||||
|
|
|
@ -31,7 +31,7 @@ class TrtLoader {
|
||||||
|
|
||||||
std::shared_ptr<nvinfer1::IBuilder> CreateInferBuilder(nvinfer1::ILogger *logger);
|
std::shared_ptr<nvinfer1::IBuilder> CreateInferBuilder(nvinfer1::ILogger *logger);
|
||||||
std::shared_ptr<nvinfer1::IRuntime> CreateInferRuntime(nvinfer1::ILogger *logger);
|
std::shared_ptr<nvinfer1::IRuntime> CreateInferRuntime(nvinfer1::ILogger *logger);
|
||||||
bool nvinfer_loaded() { return nvinfer_loaded_; }
|
bool nvinfer_loaded() const { return nvinfer_loaded_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool nvinfer_loaded_;
|
bool nvinfer_loaded_;
|
||||||
|
|
|
@ -118,6 +118,7 @@ enum MsCtxParam : unsigned {
|
||||||
MS_CTX_ENV_CONFIG_PATH,
|
MS_CTX_ENV_CONFIG_PATH,
|
||||||
MS_CTX_TUNE_MODE,
|
MS_CTX_TUNE_MODE,
|
||||||
MS_CTX_GRAPH_KERNEL_FLAGS,
|
MS_CTX_GRAPH_KERNEL_FLAGS,
|
||||||
|
MS_CTX_INFER_PRECISION_MODE, // GPU inference precision mode configured by Serving or Unify API.
|
||||||
MS_CTX_TYPE_STRING_END,
|
MS_CTX_TYPE_STRING_END,
|
||||||
|
|
||||||
// parameter numbers of each type
|
// parameter numbers of each type
|
||||||
|
|
Loading…
Reference in New Issue