forked from mindspore-Ecosystem/mindspore
!3357 modify device id
Merge pull request !3357 from changzherui/mod_device_id
This commit is contained in:
commit
0e3a39c223
|
@ -82,6 +82,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
if (IsCloudTransDeviceId()) {
|
||||
device_id_ = 0;
|
||||
}
|
||||
MS_LOG(INFO) << "context logic id: " << device_id_ << "context physics id: " << physics_id_;
|
||||
|
||||
backend_policy_ = policy_map_[policy];
|
||||
device_target_ = target;
|
||||
|
@ -172,7 +173,7 @@ bool MsContext::set_device_id(uint32_t device_id) {
|
|||
if (IsCloudTransDeviceId()) {
|
||||
device_id_ = 0;
|
||||
}
|
||||
MS_LOG(INFO) << "ms set context logic id:" << device_id;
|
||||
MS_LOG(INFO) << "ms set context logic id:" << device_id_;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -334,12 +334,33 @@ Backend::Backend(const std::string &name) : name_(name) {
|
|||
simu_flag_ = false;
|
||||
}
|
||||
|
||||
bool IsCloudTransSessDeviceId() {
|
||||
auto deploy_mode = common::GetEnv("DEPLOY_MODE");
|
||||
if (deploy_mode.empty() || deploy_mode != "1") {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto rank_size = common::GetEnv("RANK_SIZE");
|
||||
if (rank_size.empty() || rank_size != "1") {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
|
||||
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
|
||||
target_sess_ = session::SessionFactory::Get().Create(target);
|
||||
if (target_sess_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available.";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Before trans, device id: " << device_id;
|
||||
if (IsCloudTransSessDeviceId()) {
|
||||
device_id = 0;
|
||||
}
|
||||
MS_LOG(INFO) << "After trans, device id: " << device_id;
|
||||
|
||||
target_sess_->Init(device_id);
|
||||
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
|
||||
target_device_ = target;
|
||||
|
|
Loading…
Reference in New Issue