!12464 fix 910 cpp inference multi device id
From: @zhoufeng54 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
c82881cd8d
|
@ -276,7 +276,7 @@ Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MST
|
|||
}
|
||||
|
||||
AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
|
||||
MS_LOG(INFO) << "Start to init env.";
|
||||
MS_LOG(INFO) << "Start to init device " << device_id;
|
||||
device_id_ = device_id;
|
||||
RegAllOp();
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
|
@ -294,49 +294,54 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
|
|||
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "InitEnv success.";
|
||||
MS_LOG(INFO) << "Device " << device_id << " init env success.";
|
||||
errno_ = kSuccess;
|
||||
}
|
||||
|
||||
AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
|
||||
MS_LOG(INFO) << "Start finalize env";
|
||||
MS_LOG(INFO) << "Start finalize device " << device_id_;
|
||||
session::ExecutorManager::Instance().Clear();
|
||||
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
if (ms_context == nullptr) {
|
||||
MS_LOG(ERROR) << "Get Context failed!";
|
||||
errno_ = kMCFailed;
|
||||
return;
|
||||
}
|
||||
|
||||
auto ret = rtDeviceReset(device_id_);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
|
||||
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
|
||||
return;
|
||||
}
|
||||
|
||||
errno_ = kSuccess;
|
||||
MS_LOG(INFO) << "End finalize env";
|
||||
MS_LOG(INFO) << "End finalize device " << device_id_;
|
||||
}
|
||||
|
||||
std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) {
|
||||
std::shared_ptr<MsEnvGuard> acl_env;
|
||||
std::lock_guard<std::mutex> lock(global_ms_env_mutex_);
|
||||
acl_env = global_ms_env_.lock();
|
||||
auto iter = global_ms_env_.find(device_id);
|
||||
if (iter != global_ms_env_.end()) {
|
||||
acl_env = iter->second.lock();
|
||||
}
|
||||
|
||||
if (acl_env != nullptr) {
|
||||
MS_LOG(INFO) << "Env has been initialized, skip.";
|
||||
} else {
|
||||
acl_env = std::make_shared<MsEnvGuard>(device_id);
|
||||
if (acl_env->GetErrno() != kSuccess) {
|
||||
MS_LOG(ERROR) << "Execute aclInit Failed";
|
||||
return nullptr;
|
||||
}
|
||||
global_ms_env_ = acl_env;
|
||||
MS_LOG(INFO) << "Env init success";
|
||||
return acl_env;
|
||||
}
|
||||
|
||||
acl_env = std::make_shared<MsEnvGuard>(device_id);
|
||||
if (acl_env->GetErrno() != kSuccess) {
|
||||
MS_LOG(ERROR) << "Init ascend env Failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
global_ms_env_.emplace(device_id, acl_env);
|
||||
MS_LOG(INFO) << "Env init success";
|
||||
return acl_env;
|
||||
}
|
||||
|
||||
std::weak_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::global_ms_env_;
|
||||
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
|
||||
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -73,7 +73,7 @@ class AscendGraphImpl::MsEnvGuard {
|
|||
static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id);
|
||||
|
||||
private:
|
||||
static std::weak_ptr<MsEnvGuard> global_ms_env_;
|
||||
static std::map<uint32_t, std::weak_ptr<MsEnvGuard>> global_ms_env_;
|
||||
static std::mutex global_ms_env_mutex_;
|
||||
|
||||
Status errno_;
|
||||
|
|
|
@ -85,20 +85,20 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
|
|||
}
|
||||
|
||||
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
|
||||
Buffer data = ReadFile(file);
|
||||
if (data.Data() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
|
||||
}
|
||||
if (model_type == kMindIR) {
|
||||
FuncGraphPtr anf_graph = nullptr;
|
||||
try {
|
||||
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize());
|
||||
anf_graph = LoadMindIR(file);
|
||||
} catch (const std::exception &) {
|
||||
MS_LOG(EXCEPTION) << "Load MindIR failed.";
|
||||
}
|
||||
|
||||
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
|
||||
} else if (model_type == kOM) {
|
||||
Buffer data = ReadFile(file);
|
||||
if (data.Data() == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
|
||||
}
|
||||
return Graph(std::make_shared<Graph::GraphData>(data, kOM));
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;
|
||||
|
|
Loading…
Reference in New Issue