forked from OSSInnovation/mindspore
!3340 modify device id
Merge pull request !3340 from changzherui/mod_device_id
This commit is contained in:
commit
3f916bddd3
|
@ -45,6 +45,20 @@ std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBacke
|
|||
{"ge_only", kMsBackendGeOnly},
|
||||
{"vm_prior", kMsBackendVmPrior}};
|
||||
|
||||
bool IsCloudTransDeviceId() {
|
||||
auto deploy_mode = common::GetEnv("DEPLOY_MODE");
|
||||
if (deploy_mode.empty() || deploy_mode != "1") {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto rank_size = common::GetEnv("RANK_SIZE");
|
||||
if (rank_size.empty() || rank_size != "1") {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
MsContext::MsContext(const std::string &policy, const std::string &target) {
|
||||
save_graphs_flag_ = false;
|
||||
save_graphs_path_ = ".";
|
||||
|
@ -63,6 +77,12 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
|
|||
} else {
|
||||
device_id_ = 0;
|
||||
}
|
||||
|
||||
physics_id_ = device_id_;
|
||||
if (IsCloudTransDeviceId()) {
|
||||
device_id_ = 0;
|
||||
}
|
||||
|
||||
backend_policy_ = policy_map_[policy];
|
||||
device_target_ = target;
|
||||
execution_mode_ = kPynativeMode;
|
||||
|
@ -147,6 +167,13 @@ bool MsContext::set_device_target(const std::string &target) {
|
|||
bool MsContext::set_device_id(uint32_t device_id) {
|
||||
device_id_ = device_id;
|
||||
MS_LOG(INFO) << "ms set context device id:" << device_id;
|
||||
|
||||
physics_id_ = device_id_;
|
||||
if (IsCloudTransDeviceId()) {
|
||||
device_id_ = 0;
|
||||
}
|
||||
MS_LOG(INFO) << "ms set context logic id:" << device_id;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -166,7 +193,8 @@ bool MsContext::OpenTsd() {
|
|||
unsigned int device_id;
|
||||
unsigned int rank_size = 1;
|
||||
|
||||
device_id = device_id_;
|
||||
device_id = physics_id_;
|
||||
MS_LOG(INFO) << "Open and init tsd, device = " << device_id << ".";
|
||||
|
||||
auto rank_size_env = common::GetEnv("RANK_SIZE");
|
||||
if (rank_size_env.empty()) {
|
||||
|
|
|
@ -172,6 +172,7 @@ class MsContext {
|
|||
MsBackendPolicy backend_policy_;
|
||||
std::string device_target_;
|
||||
uint32_t device_id_;
|
||||
uint32_t physics_id_;
|
||||
int execution_mode_;
|
||||
bool enable_pynative_infer_;
|
||||
bool enable_pynative_hook_;
|
||||
|
|
Loading…
Reference in New Issue