!33234 Don't set device when device has been set

Merge pull request !33234 from laiyongqiang/resetdev
This commit is contained in:
i-robot 2022-04-25 07:08:58 +00:00 committed by Gitee
commit f17109c8aa
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 91 additions and 29 deletions

View File

@ -184,6 +184,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
ModelRunner::Instance().UnloadModel(iter.first);
}
graph_model_map_.clear();
}
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
@ -268,9 +269,6 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
}
}
#endif
if (!initialized_) {
return;
}
SetCurrentContext();
// release ge runtime
@ -1164,17 +1162,74 @@ void AscendKernelRuntime::CreateContext() {
SetCurrentContext();
}
bool AscendKernelRuntime::InitDevice() {
bool AscendKernelRuntime::SetRtDevice(uint32_t device_id) {
MS_LOG(INFO) << "Enter SetRtDevice, current initialize device number:" << initialized_device_set_.size();
if (initialized_device_set_.count(device_id)) {
MS_LOG(INFO) << "Device " << device_id << " has been set";
return true;
}
int device_count = 0;
auto ret = rtGetDeviceCount(&device_count);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
}
ret = rtSetDevice(UintToInt(device_id_));
ret = rtSetDevice(UintToInt(device_id));
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
}
initialized_device_set_.insert(device_id);
return true;
}
bool AscendKernelRuntime::CreateDefaultStream(uint32_t device_id) {
auto iter = device_stream_id_map_.find(device_id);
if (iter != device_stream_id_map_.end()) {
auto stream_map = (*iter->second);
stream_id_map_[kDefaultStreamIndex] = stream_map[kDefaultStreamIndex];
stream_ = stream_map[kDefaultStreamIndex];
stream_id_map_[kIndependentStreamIndex] = stream_map[kIndependentStreamIndex];
independent_stream_ = stream_map[kIndependentStreamIndex];
stream_id_map_[kWorldGroupStreamIndex] = stream_map[kWorldGroupStreamIndex];
communication_stream_ = stream_map[kWorldGroupStreamIndex];
return true;
}
auto stream_id_map = std::make_shared<std::map<uint32_t, void *>>();
stream_id_map->clear();
device_stream_id_map_[device_id] = stream_id_map;
auto ret = rtStreamCreateWithFlags(&stream_, 0, RT_STREAM_HUGE);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
stream_id_map_[kDefaultStreamIndex] = stream_;
stream_id_map->insert(std::make_pair(kDefaultStreamIndex, stream_));
ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
stream_id_map_[kIndependentStreamIndex] = independent_stream_;
stream_id_map->insert(std::make_pair(kIndependentStreamIndex, independent_stream_));
ret = rtStreamCreate(&communication_stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
stream_id_map_[kWorldGroupStreamIndex] = communication_stream_;
stream_id_map->insert(std::make_pair(kWorldGroupStreamIndex, communication_stream_));
return true;
}
bool AscendKernelRuntime::InitDevice() {
auto ret = SetRtDevice(device_id_);
if (!ret) {
MS_LOG(ERROR) << "Set runtime device failed";
return false;
}
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -1196,39 +1251,37 @@ bool AscendKernelRuntime::InitDevice() {
return false;
}
ret = rtStreamCreateWithFlags(&stream_, 0, RT_STREAM_HUGE);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
ret = CreateDefaultStream(device_id_);
if (!ret) {
MS_LOG(ERROR) << "Create default stream failed";
return false;
}
ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
}
ret = rtStreamCreate(&communication_stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
}
stream_id_map_[kDefaultStreamIndex] = stream_;
stream_id_map_[kIndependentStreamIndex] = independent_stream_;
stream_id_map_[kWorldGroupStreamIndex] = communication_stream_;
return true;
}
bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
SetCurrentContext();
int32_t ret;
for (auto &iter : stream_id_map_) {
ret = rtStreamDestroy(iter.second);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
if (device_stream_id_map_.count(device_id)) {
auto stream_id_map = *(device_stream_id_map_[device_id]);
for (auto &iter : stream_id_map) {
ret = rtStreamDestroy(iter.second);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
}
iter.second = nullptr;
}
iter.second = nullptr;
device_stream_id_map_.erase(device_id);
}
ret = rtDeviceReset(UintToInt(device_id));
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
if (initialized_device_set_.count(device_id)) {
ret = rtDeviceReset(UintToInt(device_id));
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
}
initialized_device_set_.erase(device_id);
}
// set to nullptr as its not created, only bounded to existing context
rt_context_ = nullptr;
return true;

View File

@ -20,6 +20,7 @@
#include <vector>
#include <string>
#include <map>
#include <set>
#include <utility>
#include <unordered_map>
#include <unordered_set>
@ -92,6 +93,7 @@ class AscendKernelRuntime : public KernelRuntime {
private:
bool InitDevice();
bool SetRtDevice(uint32_t device_id);
bool ResetDevice(uint32_t device_id);
static bool HcclInit();
static bool NeedDestroyHccl();
@ -125,7 +127,10 @@ class AscendKernelRuntime : public KernelRuntime {
std::map<std::pair<uint32_t, uint32_t>, std::string> stream_id_task_id_op_name_map_;
static std::map<std::string, uint32_t> overflow_tasks_;
static std::vector<rtExceptionInfo> task_fail_infoes_;
std::map<uint32_t, std::shared_ptr<std::map<uint32_t, void *>>> device_stream_id_map_;
std::map<uint32_t, void *> stream_id_map_;
std::set<uint32_t> initialized_device_set_{};
bool CreateDefaultStream(uint32_t device_id);
};
MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);

View File

@ -112,7 +112,7 @@ bool AscendMemAdapter::Initialize() {
bool AscendMemAdapter::DeInitialize() {
if (!initialized_) {
MS_LOG(ERROR) << "DeInitialize Ascend Memory Adapter when it is not initialize";
MS_LOG(INFO) << "DeInitialize Ascend Memory Adapter when it is not initialize";
return false;
}

View File

@ -358,6 +358,8 @@ void AicpuOpKernelLoad::FreeDeviceMemory() {
}
}
}
allocated_mem_list_.clear();
for (auto stream : stream_list_) {
if (stream != nullptr) {
auto rt_error = rtStreamDestroy(stream);
@ -366,6 +368,8 @@ void AicpuOpKernelLoad::FreeDeviceMemory() {
}
}
}
stream_list_.clear();
so_name_and_realpath_map_.clear();
cust_aicpu_so_.clear();
}