From 564414aa984abc73cbf06db45d9db5e6adc84b67 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Wed, 7 Apr 2021 16:19:16 +0800 Subject: [PATCH] trt pass set device id --- .../optimizer/trt_pass/trt_converter_context.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc index af38f5b64ea..1e9403ca1af 100644 --- a/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc +++ b/mindspore/ccsrc/backend/optimizer/trt_pass/trt_converter_context.cc @@ -24,11 +24,13 @@ #include #include #include "runtime/device/gpu/trt_loader.h" +#include "runtime/device/gpu/cuda_driver.h" #include "backend/optimizer/trt_pass/trt_op_factory.h" #include "backend/kernel_compiler/gpu/trt/trt_utils.h" #include "utils/convert_utils.h" #include "utils/utils.h" #include "utils/singleton.h" +#include "utils/ms_context.h" namespace mindspore::opt { namespace { @@ -121,6 +123,15 @@ void GetRealInputs(const AnfNodePtr &node, std::vector } // namespace bool TrtConverterContext::Init() { + // Set device id before invoke trt api as cudaSetDevice is thread level config. + const auto &context = MsContext::GetInstance(); + const auto &device_id = context->get_param(MS_CTX_DEVICE_ID); + bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id)); + if (!ret) { + MS_LOG(ERROR) << "Failed to set device id:" << device_id; + return false; + } + auto trt_loader = Singleton::Instance(); builder_ = trt_loader.CreateInferBuilder(&Singleton::Instance()); MS_EXCEPTION_IF_NULL(builder_);