!12464 fix 910 cpp inference multi device id

From: @zhoufeng54
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-22 21:02:53 +08:00 committed by Gitee
commit c82881cd8d
3 changed files with 28 additions and 23 deletions

View File

@ -276,7 +276,7 @@ Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MST
}
AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
MS_LOG(INFO) << "Start to init env.";
MS_LOG(INFO) << "Start to init device " << device_id;
device_id_ = device_id;
RegAllOp();
auto ms_context = MsContext::GetInstance();
@ -294,49 +294,54 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
}
MS_LOG(INFO) << "InitEnv success.";
MS_LOG(INFO) << "Device " << device_id << " init env success.";
errno_ = kSuccess;
}
AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
MS_LOG(INFO) << "Start finalize env";
MS_LOG(INFO) << "Start finalize device " << device_id_;
session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource();
auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return;
}
auto ret = rtDeviceReset(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
return;
}
errno_ = kSuccess;
MS_LOG(INFO) << "End finalize env";
MS_LOG(INFO) << "End finalize device " << device_id_;
}
std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) {
std::shared_ptr<MsEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_ms_env_mutex_);
acl_env = global_ms_env_.lock();
auto iter = global_ms_env_.find(device_id);
if (iter != global_ms_env_.end()) {
acl_env = iter->second.lock();
}
if (acl_env != nullptr) {
MS_LOG(INFO) << "Env has been initialized, skip.";
} else {
acl_env = std::make_shared<MsEnvGuard>(device_id);
if (acl_env->GetErrno() != kSuccess) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return nullptr;
}
global_ms_env_ = acl_env;
MS_LOG(INFO) << "Env init success";
return acl_env;
}
acl_env = std::make_shared<MsEnvGuard>(device_id);
if (acl_env->GetErrno() != kSuccess) {
MS_LOG(ERROR) << "Init ascend env Failed";
return nullptr;
}
global_ms_env_.emplace(device_id, acl_env);
MS_LOG(INFO) << "Env init success";
return acl_env;
}
std::weak_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
} // namespace mindspore

View File

@ -73,7 +73,7 @@ class AscendGraphImpl::MsEnvGuard {
static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id);
private:
static std::weak_ptr<MsEnvGuard> global_ms_env_;
static std::map<uint32_t, std::weak_ptr<MsEnvGuard>> global_ms_env_;
static std::mutex global_ms_env_mutex_;
Status errno_;

View File

@ -85,20 +85,20 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
}
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Buffer data = ReadFile(file);
if (data.Data() == nullptr) {
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
}
if (model_type == kMindIR) {
FuncGraphPtr anf_graph = nullptr;
try {
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize());
anf_graph = LoadMindIR(file);
} catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Load MindIR failed.";
}
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
} else if (model_type == kOM) {
Buffer data = ReadFile(file);
if (data.Data() == nullptr) {
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
}
return Graph(std::make_shared<Graph::GraphData>(data, kOM));
}
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;