forked from OSSInnovation/mindspore
modify device id
This commit is contained in:
parent
f689648872
commit
bdc67ee2ca
|
@ -82,6 +82,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
||||||
if (IsCloudTransDeviceId()) {
|
if (IsCloudTransDeviceId()) {
|
||||||
device_id_ = 0;
|
device_id_ = 0;
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "context logic id: " << device_id_ << "context physics id: " << physics_id_;
|
||||||
|
|
||||||
backend_policy_ = policy_map_[policy];
|
backend_policy_ = policy_map_[policy];
|
||||||
device_target_ = target;
|
device_target_ = target;
|
||||||
|
@ -172,7 +173,7 @@ bool MsContext::set_device_id(uint32_t device_id) {
|
||||||
if (IsCloudTransDeviceId()) {
|
if (IsCloudTransDeviceId()) {
|
||||||
device_id_ = 0;
|
device_id_ = 0;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "ms set context logic id:" << device_id;
|
MS_LOG(INFO) << "ms set context logic id:" << device_id_;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -334,12 +334,33 @@ Backend::Backend(const std::string &name) : name_(name) {
|
||||||
simu_flag_ = false;
|
simu_flag_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsCloudTransSessDeviceId() {
|
||||||
|
auto deploy_mode = common::GetEnv("DEPLOY_MODE");
|
||||||
|
if (deploy_mode.empty() || deploy_mode != "1") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rank_size = common::GetEnv("RANK_SIZE");
|
||||||
|
if (rank_size.empty() || rank_size != "1") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
||||||
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
|
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
|
||||||
target_sess_ = session::SessionFactory::Get().Create(target);
|
target_sess_ = session::SessionFactory::Get().Create(target);
|
||||||
if (target_sess_ == nullptr) {
|
if (target_sess_ == nullptr) {
|
||||||
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
|
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Before trans, device id: " << device_id;
|
||||||
|
if (IsCloudTransSessDeviceId()) {
|
||||||
|
device_id = 0;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "After trans, device id: " << device_id;
|
||||||
|
|
||||||
target_sess_->Init(device_id);
|
target_sess_->Init(device_id);
|
||||||
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
||||||
target_device_ = target;
|
target_device_ = target;
|
||||||
|
|
Loading…
Reference in New Issue