!14730 trt set device id

From: @wilfchen
Reviewed-by: @limingqi107,@cristoval
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-09 10:18:51 +08:00 committed by Gitee
commit 76345b4a46
1 changed files with 11 additions and 0 deletions

View File

@ -24,11 +24,13 @@
#include <sstream>
#include <algorithm>
#include "runtime/device/gpu/trt_loader.h"
#include "runtime/device/gpu/cuda_driver.h"
#include "backend/optimizer/trt_pass/trt_op_factory.h"
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
#include "utils/convert_utils.h"
#include "utils/utils.h"
#include "utils/singleton.h"
#include "utils/ms_context.h"
namespace mindspore::opt {
namespace {
@ -121,6 +123,15 @@ void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex>
} // namespace
bool TrtConverterContext::Init() {
// Set device id before invoke trt api as cudaSetDevice is thread level config.
const auto &context = MsContext::GetInstance();
const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
if (!ret) {
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
return false;
}
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
MS_EXCEPTION_IF_NULL(builder_);