forked from mindspore-Ecosystem/mindspore
!14730 trt set device id
From: @wilfchen Reviewed-by: @limingqi107,@cristoval Signed-off-by:
This commit is contained in:
commit
76345b4a46
|
@ -24,11 +24,13 @@
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "runtime/device/gpu/trt_loader.h"
|
#include "runtime/device/gpu/trt_loader.h"
|
||||||
|
#include "runtime/device/gpu/cuda_driver.h"
|
||||||
#include "backend/optimizer/trt_pass/trt_op_factory.h"
|
#include "backend/optimizer/trt_pass/trt_op_factory.h"
|
||||||
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
|
#include "backend/kernel_compiler/gpu/trt/trt_utils.h"
|
||||||
#include "utils/convert_utils.h"
|
#include "utils/convert_utils.h"
|
||||||
#include "utils/utils.h"
|
#include "utils/utils.h"
|
||||||
#include "utils/singleton.h"
|
#include "utils/singleton.h"
|
||||||
|
#include "utils/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -121,6 +123,15 @@ void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex>
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool TrtConverterContext::Init() {
|
bool TrtConverterContext::Init() {
|
||||||
|
// Set device id before invoke trt api as cudaSetDevice is thread level config.
|
||||||
|
const auto &context = MsContext::GetInstance();
|
||||||
|
const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||||
|
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
|
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance();
|
||||||
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
|
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance());
|
||||||
MS_EXCEPTION_IF_NULL(builder_);
|
MS_EXCEPTION_IF_NULL(builder_);
|
||||||
|
|
Loading…
Reference in New Issue