forked from mindspore-Ecosystem/mindspore
change gpu kernel runtime to support memory swap
This commit is contained in:
parent
8e36a4451e
commit
23a57476da
|
@ -28,10 +28,13 @@
|
|||
#include "common/utils.h"
|
||||
#include "device/gpu/gpu_memory_manager.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "device/gpu/gpu_memory_copy_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
using mindspore::device::memswap::MemSwapManager;
|
||||
using mindspore::device::memswap::SwapKind;
|
||||
// Synchronizes this runtime's stream (stream_) through the GPU device manager singleton and
// forwards the manager's result to the caller.
bool GPUKernelRuntime::SyncStream() {
  auto &&device_manager = GPUDeviceManager::GetInstance();
  return device_manager.SyncStream(stream_);
}
|
||||
|
||||
bool GPUKernelRuntime::Init() {
|
||||
|
@ -101,6 +104,12 @@ void GPUKernelRuntime::ReleaseDeviceRes() {
|
|||
}
|
||||
CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue.");
|
||||
}
|
||||
// destroy remaining memory swap events and free host memory
|
||||
if (mem_swap_manager_->trigger_swap()) {
|
||||
mem_swap_manager_->ClearSwapQueue();
|
||||
mem_swap_manager_->ReleaseHostPinnedMem();
|
||||
}
|
||||
|
||||
GPUDeviceManager::GetInstance().ReleaseDevice();
|
||||
if (mem_manager_ != nullptr) {
|
||||
mem_manager_->FreeDeviceMemory();
|
||||
|
@ -126,15 +135,29 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
|||
}
|
||||
|
||||
bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
|
||||
bool ret;
|
||||
bool ret = true;
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
|
||||
bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
|
||||
auto iter = mem_swap_map_.find(graph);
|
||||
if (iter == mem_swap_map_.end()) {
|
||||
GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared<GPUMemCopyManager>();
|
||||
iter = mem_swap_map_.emplace(graph, std::make_shared<MemSwapManager>(gpu_mem_copy_manager)).first;
|
||||
}
|
||||
mem_swap_manager_ = iter->second;
|
||||
struct timeval start_time, end_time;
|
||||
(void)gettimeofday(&start_time, nullptr);
|
||||
if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
|
||||
ret = LaunchKernelDynamic(graph);
|
||||
while (!LaunchKernelDynamic(graph)) {
|
||||
ClearKernelOutputAddress(graph);
|
||||
if (!mem_swap_manager_->mem_swap_init()) {
|
||||
mem_swap_manager_->Init(graph);
|
||||
}
|
||||
if (!mem_swap_manager_->RetreatSwapInfo()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ret = LaunchKernel(graph);
|
||||
}
|
||||
|
@ -181,6 +204,27 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph
|
|||
}
|
||||
}
|
||||
|
||||
// Returns every allocated kernel output buffer of `graph` to the memory pool and resets the
// status of each existing output address to kInDevice.
//
// Run() calls this after a failed LaunchKernelDynamic() attempt, before it retries the launch.
//
// graph: kernel graph whose execution order is walked; must not be null.
void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Checked up front, as the Alloc*DynamicRes functions do, because it is dereferenced below.
  MS_EXCEPTION_IF_NULL(mem_manager_);
  const auto &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    // Only the number of outputs is needed here, so bind by reference instead of copying the list.
    const auto &output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      // Outputs that never received a device address have nothing to release.
      if (!AnfAlgo::OutputAddrExist(kernel, i)) {
        continue;
      }

      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
      MS_EXCEPTION_IF_NULL(device_address);
      if (device_address->ptr_) {
        mem_manager_->FreeMemFromMemPool(device_address);
      }
      // The status is reset even when no memory was held, so the next attempt starts from kInDevice.
      device_address->set_status(DeviceAddressStatus::kInDevice);
    }
  }
}
|
||||
|
||||
bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto graph_id = graph->graph_id();
|
||||
|
@ -198,32 +242,157 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
|
|||
AddressPtrList kernel_inputs;
|
||||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
|
||||
MS_LOG(ERROR) << "Launch kernel failed.";
|
||||
auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
|
||||
MS_LOG(EXCEPTION) << "Launch kernel failed.";
|
||||
}
|
||||
FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id);
|
||||
|
||||
if (mem_swap_manager_->trigger_swap() && mem_swap_manager_->QueryKernelTriggerSwap(kernel)) {
|
||||
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
|
||||
if (!AddMemSwapTask(kernel)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (mem_swap_manager_->trigger_swap()) {
|
||||
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
|
||||
}
|
||||
}
|
||||
|
||||
if (!SyncStream()) {
|
||||
MS_LOG(ERROR) << "SyncStream failed.";
|
||||
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
|
||||
if (mem_swap_manager_->trigger_swap()) {
|
||||
mem_swap_manager_->ClearSwapQueue();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
|
||||
auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel);
|
||||
for (auto &mem_swap_info : mem_swap_info_list) {
|
||||
auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_);
|
||||
const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_];
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_);
|
||||
|
||||
if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
|
||||
mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address);
|
||||
} else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) {
|
||||
auto status = device_address->status();
|
||||
if (status == DeviceAddressStatus::kInDeviceToHost) {
|
||||
mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
|
||||
device_address->set_status(DeviceAddressStatus::kInDevice);
|
||||
} else if (status == DeviceAddressStatus::kInHost) {
|
||||
if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) {
|
||||
return false;
|
||||
}
|
||||
if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) {
|
||||
mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) {
|
||||
auto ret = mem_manager_->MallocMemFromMemPool(device_address, size);
|
||||
if (!ret) {
|
||||
if (!mem_swap_manager_->trigger_swap()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
|
||||
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
|
||||
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
|
||||
device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
|
||||
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
|
||||
}
|
||||
}
|
||||
|
||||
ret = mem_manager_->MallocMemFromMemPool(device_address, size);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Allocates a raw buffer of `size` bytes from the memory pool.
//
// If the first allocation fails and swapping has been triggered, the device-to-host copy stream
// is synchronized, the device memory of each finished swap-out is released (unless its pointer
// is null or on the swap-in black list), and the allocation is attempted once more.
//
// Returns the allocated pointer, or nullptr when no memory could be obtained.
void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
  auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
  if (device_ptr) {
    return device_ptr;
  }
  // Nothing can be reclaimed when swapping is not active.
  if (!mem_swap_manager_->trigger_swap()) {
    return nullptr;
  }

  mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
  while (auto swapped_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
    // Keep black-listed buffers on the device; entries without device memory have nothing to free.
    if (mem_swap_manager_->FindInSwapInBlackList(swapped_out->ptr_) || !swapped_out->ptr_) {
      continue;
    }
    swapped_out->set_status(DeviceAddressStatus::kInHost);
    mem_manager_->FreeMemFromMemPool(swapped_out);
  }

  // Second and final attempt after reclaiming swapped-out memory.
  device_ptr = mem_manager_->MallocMemFromMemPool(size);
  if (!device_ptr) {
    return nullptr;
  }
  return device_ptr;
}
|
||||
|
||||
// Prepares the launch arguments of `kernel` by filling the input, output and workspace address
// lists, in that order. Stops at the first stage that reports failure.
//
// Returns true only when all three stages succeed.
bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
                                             const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
                                             AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
  // && short-circuits left to right, so later stages are skipped once one stage fails.
  return AllocKernelInputDynamicRes(kernel, kernel_inputs) &&
         AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs) &&
         AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces);
}
|
||||
|
||||
void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
|
||||
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
|
||||
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
|
||||
bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_inputs);
|
||||
MS_EXCEPTION_IF_NULL(kernel_workspaces);
|
||||
MS_EXCEPTION_IF_NULL(kernel_outputs);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
|
||||
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, i);
|
||||
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (mem_swap_manager_->trigger_swap()) {
|
||||
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
|
||||
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
|
||||
}
|
||||
|
||||
auto status = device_address->status();
|
||||
switch (status) {
|
||||
case DeviceAddressStatus::kInDevice:
|
||||
break;
|
||||
case DeviceAddressStatus::kInHost:
|
||||
break;
|
||||
case DeviceAddressStatus::kInDeviceToHost: {
|
||||
mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
|
||||
device_address->set_status(DeviceAddressStatus::kInDevice);
|
||||
break;
|
||||
}
|
||||
case DeviceAddressStatus::kInHostToDevice: {
|
||||
while (device_address->status() != DeviceAddressStatus::kInDevice) {
|
||||
while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
|
||||
device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "Invaild device address status";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(device_address->ptr_);
|
||||
kernel::AddressPtr input = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
|
@ -231,15 +400,29 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
|
|||
input->size = device_address->size_;
|
||||
kernel_inputs->emplace_back(input);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
|
||||
const mindspore::AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_outputs) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_outputs);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
if (mem_swap_manager_->trigger_swap()) {
|
||||
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
|
||||
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
|
||||
device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
|
||||
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto output_sizes = kernel_mod.GetOutputSizeList();
|
||||
for (size_t i = 0; i < output_sizes.size(); ++i) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (device_address->ptr_ == nullptr) {
|
||||
auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Malloc device memory failed.";
|
||||
}
|
||||
if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) {
|
||||
return false;
|
||||
}
|
||||
kernel::AddressPtr output = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
|
@ -247,15 +430,24 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
|
|||
output->size = output_sizes[i];
|
||||
kernel_outputs->emplace_back(output);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
|
||||
const mindspore::AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_workspaces) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(kernel_workspaces);
|
||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||
auto workspace_sizes = kernel_mod.GetWorkspaceSizeList();
|
||||
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
|
||||
if (workspace_sizes[i] == 0) {
|
||||
kernel_workspaces->emplace_back(nullptr);
|
||||
continue;
|
||||
}
|
||||
auto device_ptr = mem_manager_->MallocMemFromMemPool(workspace_sizes[i]);
|
||||
auto device_ptr = AttemptMallocMem(workspace_sizes[i]);
|
||||
if (!device_ptr) {
|
||||
MS_LOG(EXCEPTION) << "Malloc device memory failed.";
|
||||
return false;
|
||||
}
|
||||
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
||||
MS_EXCEPTION_IF_NULL(workspace);
|
||||
|
@ -263,6 +455,7 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod
|
|||
workspace->size = workspace_sizes[i];
|
||||
kernel_workspaces->emplace_back(workspace);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) {
|
||||
|
@ -371,6 +564,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
|
|||
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
|
||||
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
|
||||
mem_manager_->FreeMemFromMemPool(device_address);
|
||||
device_address->set_status(DeviceAddressStatus::kInDevice);
|
||||
}
|
||||
}
|
||||
// Free the output of kernel, if output has no reference.
|
||||
|
@ -382,6 +576,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
|
|||
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
|
||||
mem_manager_->FreeMemFromMemPool(device_address);
|
||||
device_address->set_status(DeviceAddressStatus::kInDevice);
|
||||
}
|
||||
}
|
||||
// Free the workspace of kernel.
|
||||
|
|
|
@ -24,10 +24,12 @@
|
|||
#include <unordered_map>
|
||||
#include "device/kernel_runtime.h"
|
||||
#include "device/kernel_runtime_manager.h"
|
||||
#include "pre_activate/mem_reuse/mem_swap_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace gpu {
|
||||
using mindspore::device::memswap::MemSwapManagerPtr;
|
||||
class GPUKernelRuntime : public KernelRuntime {
|
||||
public:
|
||||
GPUKernelRuntime() = default;
|
||||
|
@ -51,10 +53,19 @@ class GPUKernelRuntime : public KernelRuntime {
|
|||
// The related functions and members for using dynamic memory pool.
|
||||
void InitKernelRefCount(const session::KernelGraph *graph);
|
||||
void InitKernelOutputAddress(const session::KernelGraph *graph);
|
||||
void ClearKernelOutputAddress(const session::KernelGraph *graph);
|
||||
bool LaunchKernelDynamic(const session::KernelGraph *graph);
|
||||
void AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
||||
bool AddMemSwapTask(const AnfNodePtr &kernel);
|
||||
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size);
|
||||
void *AttemptMallocMem(size_t size);
|
||||
bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
|
||||
AddressPtrList *kernel_outputs);
|
||||
bool AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs);
|
||||
bool AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
||||
AddressPtrList *kernel_outputs);
|
||||
bool AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
|
||||
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_workspaces);
|
||||
void AllocCommunicationOpDynamicRes(const session::KernelGraph *graph);
|
||||
void AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel);
|
||||
void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel);
|
||||
|
@ -64,6 +75,8 @@ class GPUKernelRuntime : public KernelRuntime {
|
|||
void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces,
|
||||
uint32_t graph_id);
|
||||
std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
|
||||
std::unordered_map<void *, MemSwapManagerPtr> mem_swap_map_;
|
||||
MemSwapManagerPtr mem_swap_manager_{nullptr};
|
||||
};
|
||||
MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);
|
||||
} // namespace gpu
|
||||
|
|
|
@ -25,10 +25,7 @@ namespace memswap {
|
|||
void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
execution_order_ = kernel_graph->execution_order();
|
||||
FuncGraphManagerPtr manager = kernel_graph->manager();
|
||||
NodeUsersMap user_map = manager->node_users();
|
||||
size_t kernel_index = 0;
|
||||
|
||||
for (const auto &kernel : execution_order_) {
|
||||
// parse topo order of kernel
|
||||
kernel_execution_info_.emplace(kernel.get(), kernel_index++);
|
||||
|
@ -44,6 +41,31 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|||
}
|
||||
|
||||
// parse topo order of user kernel
|
||||
SaveUserKernelTopoOrder(kernel_graph);
|
||||
|
||||
sort(ordered_tensors_.begin(), ordered_tensors_.end(),
|
||||
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });
|
||||
|
||||
auto cur_tensor_size = ordered_tensors_.front().tensor_size_;
|
||||
for (auto &tensor_info : ordered_tensors_) {
|
||||
if (cur_tensor_size != tensor_info.tensor_size_) {
|
||||
cur_tensor_size = tensor_info.tensor_size_;
|
||||
tensor_size_num_++;
|
||||
}
|
||||
}
|
||||
tensor_size_threshold_ = ordered_tensors_.front().tensor_size_;
|
||||
tensor_size_threshold_idx_ = 0;
|
||||
|
||||
distance_threshold_ = kernel_index / kDistanceInitFactor;
|
||||
mem_swap_initialized_ = true;
|
||||
MS_EXCEPTION_IF_NULL(mem_copy_manager_);
|
||||
mem_copy_manager_->Init();
|
||||
}
|
||||
|
||||
void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
FuncGraphManagerPtr manager = kernel_graph->manager();
|
||||
NodeUsersMap user_map = manager->node_users();
|
||||
for (const auto &kernel : execution_order_) {
|
||||
auto iter = user_map.find(kernel);
|
||||
if (iter == user_map.end()) {
|
||||
|
@ -66,24 +88,6 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|||
sort(node_user_pair.second.begin(), node_user_pair.second.end());
|
||||
}
|
||||
}
|
||||
|
||||
sort(ordered_tensors_.begin(), ordered_tensors_.end(),
|
||||
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });
|
||||
|
||||
auto cur_tensor_size = ordered_tensors_.front().tensor_size_;
|
||||
for (auto &tensor_info : ordered_tensors_) {
|
||||
if (cur_tensor_size != tensor_info.tensor_size_) {
|
||||
cur_tensor_size = tensor_info.tensor_size_;
|
||||
tensor_size_num_++;
|
||||
}
|
||||
}
|
||||
tensor_size_threshold_ = ordered_tensors_.front().tensor_size_;
|
||||
tensor_size_threshold_idx_ = 0;
|
||||
|
||||
distance_threshold_ = kernel_index / kDistanceInitFactor;
|
||||
mem_swap_initialized_ = true;
|
||||
MS_EXCEPTION_IF_NULL(mem_copy_manager_);
|
||||
mem_copy_manager_->Init();
|
||||
}
|
||||
|
||||
void MemSwapManager::AddSwapInfo() {
|
||||
|
@ -228,12 +232,12 @@ float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) cons
|
|||
return kernel_exec_info.execution_perform_;
|
||||
}
|
||||
|
||||
bool MemSwapManager::QueryKerneTriggerSwap(const AnfNodePtr &kernel) const {
|
||||
bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const {
|
||||
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||
return kernel_exec_info.trigger_swap_;
|
||||
}
|
||||
|
||||
bool MemSwapManager::QueryKerneNeedSwap(const AnfNodePtr &kernel) const {
|
||||
bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const {
|
||||
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||
return kernel_exec_info.need_swap_;
|
||||
}
|
||||
|
@ -254,7 +258,7 @@ const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kern
|
|||
return iter_output->second;
|
||||
}
|
||||
|
||||
const std::vector<MemSwapInfo> &MemSwapManager::QueryKerneMemSwapInfo(const AnfNodePtr &kernel) const {
|
||||
const std::vector<MemSwapInfo> &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
auto iter = mem_swap_info_.find(kernel.get());
|
||||
if (iter == mem_swap_info_.end()) {
|
||||
|
|
|
@ -63,11 +63,11 @@ class MemSwapManager {
|
|||
|
||||
const PerformPair &QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const;
|
||||
|
||||
bool QueryKerneTriggerSwap(const AnfNodePtr &kernel) const;
|
||||
bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const;
|
||||
|
||||
bool QueryKerneNeedSwap(const AnfNodePtr &kernel) const;
|
||||
bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const;
|
||||
|
||||
const std::vector<MemSwapInfo> &QueryKerneMemSwapInfo(const AnfNodePtr &kernel) const;
|
||||
const std::vector<MemSwapInfo> &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const;
|
||||
|
||||
void InsertSwapInBlackList(const void *device_ptr);
|
||||
|
||||
|
@ -90,6 +90,8 @@ class MemSwapManager {
|
|||
|
||||
void ResetSwapInfo();
|
||||
|
||||
void SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph);
|
||||
|
||||
void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap);
|
||||
|
||||
void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap);
|
||||
|
|
Loading…
Reference in New Issue