Set gpu device id for multiple threads

This commit is contained in:
ZPaC 2020-10-28 14:17:02 +08:00
parent 7b4ff1323b
commit 5059d8c3f9
3 changed files with 22 additions and 9 deletions

View File

@ -53,6 +53,8 @@
#include "runtime/device/gpu/gpu_stream_assign.h"
#include "runtime/device/gpu/kernel_info_setter.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/device/gpu/cuda_driver.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "utils/ms_utils.h"
#include "utils/config_manager.h"
#include "utils/ms_context.h"
@ -64,6 +66,25 @@ namespace mindspore {
namespace session {
namespace gpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using CollectiveInitializer = device::gpu::CollectiveInitializer;
using GetLocalRankId = device::gpu::GetLocalRankId;
void GPUSession::Init(uint32_t device_id) {
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
bool collective_inited = CollectiveInitializer::instance().collective_inited();
if (collective_inited && collective_handle_ != nullptr) {
auto get_local_rank_funcptr =
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
device_id = IntToUint((*get_local_rank_funcptr)());
}
bool ret = device::gpu::CudaDriver::set_current_device(UintToInt(device_id));
if (!ret) {
MS_LOG(EXCEPTION) << "GPUSession failed to set current device id.";
}
MS_LOG(INFO) << "Set device id " << device_id << " for gpu session.";
InitDevice(kGPUDevice, device_id);
}
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_EXCEPTION_IF_NULL(kernel_graph);

View File

@ -31,7 +31,7 @@ class GPUSession : public SessionBasic {
public:
GPUSession() = default;
~GPUSession() override = default;
void Init(uint32_t device_id) override { InitDevice(kGPUDevice, device_id); }
void Init(uint32_t device_id) override;
protected:
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;

View File

@ -187,14 +187,6 @@ bool GPUKernelRuntime::InitDevice() {
MS_LOG(ERROR) << "No GPU device found.";
return false;
}
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
bool collective_inited = CollectiveInitializer::instance().collective_inited();
if (collective_inited && collective_handle_ != nullptr) {
auto get_local_rank_funcptr =
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
device_id_ = IntToUint((*get_local_rank_funcptr)());
}
if (!GPUDeviceManager::GetInstance().is_device_id_init()) {
if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) {
MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_);