!33234 Don't set device when device has been set
Merge pull request !33234 from laiyongqiang/resetdev
This commit is contained in:
commit
f17109c8aa
|
@ -184,6 +184,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
|
|||
MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
|
||||
ModelRunner::Instance().UnloadModel(iter.first);
|
||||
}
|
||||
graph_model_map_.clear();
|
||||
}
|
||||
|
||||
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
|
||||
|
@ -268,9 +269,6 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
|
|||
}
|
||||
}
|
||||
#endif
|
||||
if (!initialized_) {
|
||||
return;
|
||||
}
|
||||
SetCurrentContext();
|
||||
|
||||
// release ge runtime
|
||||
|
@ -1164,17 +1162,74 @@ void AscendKernelRuntime::CreateContext() {
|
|||
SetCurrentContext();
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::InitDevice() {
|
||||
// Selects Ascend device `device_id` for this runtime via rtSetDevice, at most once per device.
//
// Devices already recorded in initialized_device_set_ are skipped, so repeated calls for the
// same id do not call rtSetDevice again.
//
// @param device_id  Id of the device to select; converted with UintToInt for the rt API.
// @return true when the device was already set or has just been set. Failures of the rt calls
//         are reported through MS_EXCEPTION(DeviceProcessError).
bool AscendKernelRuntime::SetRtDevice(uint32_t device_id) {
  MS_LOG(INFO) << "Enter SetRtDevice, current initialize device number:" << initialized_device_set_.size();
  if (initialized_device_set_.count(device_id)) {
    MS_LOG(INFO) << "Device " << device_id << " has been set";
    return true;
  }

  // device_count is not consumed below; the query is kept so that a failing rtGetDeviceCount is
  // reported before any attempt to select a device.
  int device_count = 0;
  auto ret = rtGetDeviceCount(&device_count);
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
  }

  // Select the requested device exactly once, using the parameter rather than the member
  // device_id_: an extra call with device_id_ could select a different device and its return
  // code would be overwritten without ever being checked.
  ret = rtSetDevice(UintToInt(device_id));
  if (ret != RT_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
  }

  // Remember the device so that later calls take the early-return path above.
  initialized_device_set_.insert(device_id);
  return true;
}
|
||||
|
||||
bool AscendKernelRuntime::CreateDefaultStream(uint32_t device_id) {
|
||||
auto iter = device_stream_id_map_.find(device_id);
|
||||
if (iter != device_stream_id_map_.end()) {
|
||||
auto stream_map = (*iter->second);
|
||||
stream_id_map_[kDefaultStreamIndex] = stream_map[kDefaultStreamIndex];
|
||||
stream_ = stream_map[kDefaultStreamIndex];
|
||||
stream_id_map_[kIndependentStreamIndex] = stream_map[kIndependentStreamIndex];
|
||||
independent_stream_ = stream_map[kIndependentStreamIndex];
|
||||
stream_id_map_[kWorldGroupStreamIndex] = stream_map[kWorldGroupStreamIndex];
|
||||
communication_stream_ = stream_map[kWorldGroupStreamIndex];
|
||||
return true;
|
||||
}
|
||||
|
||||
auto stream_id_map = std::make_shared<std::map<uint32_t, void *>>();
|
||||
stream_id_map->clear();
|
||||
device_stream_id_map_[device_id] = stream_id_map;
|
||||
|
||||
auto ret = rtStreamCreateWithFlags(&stream_, 0, RT_STREAM_HUGE);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
|
||||
}
|
||||
stream_id_map_[kDefaultStreamIndex] = stream_;
|
||||
stream_id_map->insert(std::make_pair(kDefaultStreamIndex, stream_));
|
||||
|
||||
ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
|
||||
}
|
||||
stream_id_map_[kIndependentStreamIndex] = independent_stream_;
|
||||
stream_id_map->insert(std::make_pair(kIndependentStreamIndex, independent_stream_));
|
||||
|
||||
ret = rtStreamCreate(&communication_stream_, 0);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
|
||||
}
|
||||
stream_id_map_[kWorldGroupStreamIndex] = communication_stream_;
|
||||
stream_id_map->insert(std::make_pair(kWorldGroupStreamIndex, communication_stream_));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::InitDevice() {
|
||||
auto ret = SetRtDevice(device_id_);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Set runtime device failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
@ -1196,39 +1251,37 @@ bool AscendKernelRuntime::InitDevice() {
|
|||
return false;
|
||||
}
|
||||
|
||||
ret = rtStreamCreateWithFlags(&stream_, 0, RT_STREAM_HUGE);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
|
||||
ret = CreateDefaultStream(device_id_);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "Create default stream failed";
|
||||
return false;
|
||||
}
|
||||
ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
|
||||
}
|
||||
ret = rtStreamCreate(&communication_stream_, 0);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
|
||||
}
|
||||
|
||||
stream_id_map_[kDefaultStreamIndex] = stream_;
|
||||
stream_id_map_[kIndependentStreamIndex] = independent_stream_;
|
||||
stream_id_map_[kWorldGroupStreamIndex] = communication_stream_;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
|
||||
SetCurrentContext();
|
||||
int32_t ret;
|
||||
for (auto &iter : stream_id_map_) {
|
||||
ret = rtStreamDestroy(iter.second);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
|
||||
if (device_stream_id_map_.count(device_id)) {
|
||||
auto stream_id_map = *(device_stream_id_map_[device_id]);
|
||||
for (auto &iter : stream_id_map) {
|
||||
ret = rtStreamDestroy(iter.second);
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
|
||||
}
|
||||
iter.second = nullptr;
|
||||
}
|
||||
iter.second = nullptr;
|
||||
device_stream_id_map_.erase(device_id);
|
||||
}
|
||||
ret = rtDeviceReset(UintToInt(device_id));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
|
||||
|
||||
if (initialized_device_set_.count(device_id)) {
|
||||
ret = rtDeviceReset(UintToInt(device_id));
|
||||
if (ret != RT_ERROR_NONE) {
|
||||
MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
|
||||
}
|
||||
initialized_device_set_.erase(device_id);
|
||||
}
|
||||
|
||||
// set to nullptr as its not created, only bounded to existing context
|
||||
rt_context_ = nullptr;
|
||||
return true;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
@ -92,6 +93,7 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
|
||||
private:
|
||||
bool InitDevice();
|
||||
bool SetRtDevice(uint32_t device_id);
|
||||
bool ResetDevice(uint32_t device_id);
|
||||
static bool HcclInit();
|
||||
static bool NeedDestroyHccl();
|
||||
|
@ -125,7 +127,10 @@ class AscendKernelRuntime : public KernelRuntime {
|
|||
std::map<std::pair<uint32_t, uint32_t>, std::string> stream_id_task_id_op_name_map_;
|
||||
static std::map<std::string, uint32_t> overflow_tasks_;
|
||||
static std::vector<rtExceptionInfo> task_fail_infoes_;
|
||||
std::map<uint32_t, std::shared_ptr<std::map<uint32_t, void *>>> device_stream_id_map_;
|
||||
std::map<uint32_t, void *> stream_id_map_;
|
||||
std::set<uint32_t> initialized_device_set_{};
|
||||
bool CreateDefaultStream(uint32_t device_id);
|
||||
};
|
||||
|
||||
MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);
|
||||
|
|
|
@ -112,7 +112,7 @@ bool AscendMemAdapter::Initialize() {
|
|||
|
||||
bool AscendMemAdapter::DeInitialize() {
|
||||
if (!initialized_) {
|
||||
MS_LOG(ERROR) << "DeInitialize Ascend Memory Adapter when it is not initialize";
|
||||
MS_LOG(INFO) << "DeInitialize Ascend Memory Adapter when it is not initialize";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -358,6 +358,8 @@ void AicpuOpKernelLoad::FreeDeviceMemory() {
|
|||
}
|
||||
}
|
||||
}
|
||||
allocated_mem_list_.clear();
|
||||
|
||||
for (auto stream : stream_list_) {
|
||||
if (stream != nullptr) {
|
||||
auto rt_error = rtStreamDestroy(stream);
|
||||
|
@ -366,6 +368,8 @@ void AicpuOpKernelLoad::FreeDeviceMemory() {
|
|||
}
|
||||
}
|
||||
}
|
||||
stream_list_.clear();
|
||||
|
||||
so_name_and_realpath_map_.clear();
|
||||
cust_aicpu_so_.clear();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue