gpu inference mixed precision

This commit is contained in:
wilfChen 2021-04-25 09:53:47 +08:00
parent 6801ef61e0
commit ba9bbfadf8
6 changed files with 36 additions and 3 deletions

View File

@ -116,8 +116,20 @@ class MS_API NvidiaGPUDeviceInfo : public DeviceInfoContext {
void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
bool GetGpuTrtInferMode() const;
inline void SetPrecisionMode(const std::string &precison_mode);
inline std::string GetPrecisionMode() const;
private:
void SetPrecisionMode(const std::vector<char> &precision_mode);
std::vector<char> GetPrecisionModeChar() const;
};
void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
SetPrecisionMode(StringToChar(precision_mode));
}
std::string NvidiaGPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
public:
enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };

View File

@ -198,6 +198,15 @@ bool TrtConverterContext::Serialize(std::string *model) {
MS_EXCEPTION_IF_NULL(model);
builder_->setMaxBatchSize(batch_size_);
config_->setMaxWorkspaceSize(workspace_size_);
// Set precision mode
const auto &context = MsContext::GetInstance();
const auto &precision_mode = context->get_param<std::string>(MS_CTX_INFER_PRECISION_MODE);
if (precision_mode == "fp16") {
MS_LOG(WARNING) << "Inference with mixed precision mode. It will take few minutes for operators selection.";
config_->setFlag(nvinfer1::BuilderFlag::kFP16);
}
engine_ = TrtPtr(builder_->buildEngineWithConfig(*network_, *config_));
MS_EXCEPTION_IF_NULL(engine_);

View File

@ -27,6 +27,7 @@ constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequ
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
constexpr auto kModelOptionNvidiaGpuDeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionNvidiaGpuTrtInferMode = "mindspore.option.nvidia_gpu.trt_infer_mode";
constexpr auto kModelOptionNvidiaGpuPrecisionMode = "mindspore.option.nvidia_gpu.precision_mode";
constexpr auto kModelOptionAscend910DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscend310DumpCfgPath = "mindspore.option.ascend310.dump_config_file_path";
@ -153,6 +154,16 @@ bool NvidiaGPUDeviceInfo::GetGpuTrtInferMode() const {
return GetValue<bool>(data_, kModelOptionNvidiaGpuTrtInferMode);
}
void NvidiaGPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionNvidiaGpuPrecisionMode] = CharToString(precision_mode);
}
std::vector<char> NvidiaGPUDeviceInfo::GetPrecisionModeChar() const {
MS_EXCEPTION_IF_NULL(data_);
const std::string &ref = GetValue<std::string>(data_, kModelOptionNvidiaGpuPrecisionMode);
return StringToChar(ref);
}
void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(data_);
data_->params[kModelOptionAscend910DeviceID] = device_id;

View File

@ -63,8 +63,8 @@ Status GPUGraphImpl::InitEnv() {
if (gpu_info == nullptr) {
return kMCDeviceError;
}
auto enable_trt = gpu_info->GetGpuTrtInferMode();
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, enable_trt);
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, gpu_info->GetGpuTrtInferMode());
ms_context->set_param<std::string>(MS_CTX_INFER_PRECISION_MODE, gpu_info->GetPrecisionMode());
session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
if (session_impl_ == nullptr) {

View File

@ -31,7 +31,7 @@ class TrtLoader {
std::shared_ptr<nvinfer1::IBuilder> CreateInferBuilder(nvinfer1::ILogger *logger);
std::shared_ptr<nvinfer1::IRuntime> CreateInferRuntime(nvinfer1::ILogger *logger);
bool nvinfer_loaded() { return nvinfer_loaded_; }
bool nvinfer_loaded() const { return nvinfer_loaded_; }
private:
bool nvinfer_loaded_;

View File

@ -118,6 +118,7 @@ enum MsCtxParam : unsigned {
MS_CTX_ENV_CONFIG_PATH,
MS_CTX_TUNE_MODE,
MS_CTX_GRAPH_KERNEL_FLAGS,
MS_CTX_INFER_PRECISION_MODE, // GPU inference precision mode configured by Serving or Unify API.
MS_CTX_TYPE_STRING_END,
// parameter numbers of each type