forked from mindspore-Ecosystem/mindspore
!14730 trt set device id
From: @wilfchen Reviewed-by: @limingqi107,@cristoval Signed-off-by:
This commit is contained in:
commit
76345b4a46
|
@ -24,11 +24,13 @@
|
|||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include "runtime/device/gpu/trt_loader.h"
|
||||
#include "runtime/device/gpu/cuda_driver.h"
|
||||
#include "backend/optimizer/trt_pass/trt_op_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
|
||||
#include "utils/convert_utils.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/singleton.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
|
@ -121,6 +123,15 @@ void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex>
|
|||
} // namespace
|
||||
|
||||
bool TrtConverterContext::Init() {
|
||||
// Set device id before invoke trt api as cudaSetDevice is thread level config.
|
||||
const auto &context = MsContext::GetInstance();
|
||||
const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
|
||||
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
|
||||
MS_EXCEPTION_IF_NULL(builder_);
|
||||
|
|
Loading…
Reference in New Issue