forked from mindspore-Ecosystem/mindspore
Set GPU device ID for multiple threads
This commit is contained in:
parent
7b4ff1323b
commit
5059d8c3f9
|
@ -53,6 +53,8 @@
|
|||
#include "runtime/device/gpu/gpu_stream_assign.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
#include "runtime/device/kernel_runtime_manager.h"
|
||||
#include "runtime/device/gpu/cuda_driver.h"
|
||||
#include "runtime/device/gpu/distribution/collective_init.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/config_manager.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -64,6 +66,25 @@ namespace mindspore {
|
|||
namespace session {
|
||||
namespace gpu {
|
||||
// File-local shorthand aliases: AnfAlgo names the session-level AnfRuntimeAlgorithm,
// while CollectiveInitializer and GetLocalRankId are pulled in from device::gpu.
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
||||
using CollectiveInitializer = device::gpu::CollectiveInitializer;
|
||||
using GetLocalRankId = device::gpu::GetLocalRankId;
|
||||
|
||||
void GPUSession::Init(uint32_t device_id) {
|
||||
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
|
||||
bool collective_inited = CollectiveInitializer::instance().collective_inited();
|
||||
if (collective_inited && collective_handle_ != nullptr) {
|
||||
auto get_local_rank_funcptr =
|
||||
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
|
||||
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
|
||||
device_id = IntToUint((*get_local_rank_funcptr)());
|
||||
}
|
||||
bool ret = device::gpu::CudaDriver::set_current_device(UintToInt(device_id));
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "GPUSession failed to set current device id.";
|
||||
}
|
||||
MS_LOG(INFO) << "Set device id " << device_id << " for gpu session.";
|
||||
InitDevice(kGPUDevice, device_id);
|
||||
}
|
||||
|
||||
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
|
|
@ -31,7 +31,7 @@ class GPUSession : public SessionBasic {
|
|||
public:
|
||||
GPUSession() = default;
|
||||
~GPUSession() override = default;
|
||||
void Init(uint32_t device_id) override { InitDevice(kGPUDevice, device_id); }
|
||||
void Init(uint32_t device_id) override;
|
||||
|
||||
protected:
|
||||
GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
|
||||
|
|
|
@ -187,14 +187,6 @@ bool GPUKernelRuntime::InitDevice() {
|
|||
MS_LOG(ERROR) << "No GPU device found.";
|
||||
return false;
|
||||
}
|
||||
const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
|
||||
bool collective_inited = CollectiveInitializer::instance().collective_inited();
|
||||
if (collective_inited && collective_handle_ != nullptr) {
|
||||
auto get_local_rank_funcptr =
|
||||
reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
|
||||
MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
|
||||
device_id_ = IntToUint((*get_local_rank_funcptr)());
|
||||
}
|
||||
if (!GPUDeviceManager::GetInstance().is_device_id_init()) {
|
||||
if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) {
|
||||
MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_);
|
||||
|
|
Loading…
Reference in New Issue