forked from mindspore-Ecosystem/mindspore
refine gpu memory swap performance
This commit is contained in:
parent
bbfcbbe26d
commit
3ace75509b
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
@ -40,29 +41,58 @@ struct TensorInfo {
|
||||||
struct KernelExecutionInfo {
|
struct KernelExecutionInfo {
|
||||||
size_t topo_order_{0};
|
size_t topo_order_{0};
|
||||||
float execution_perform_{0.0};
|
float execution_perform_{0.0};
|
||||||
bool trigger_swap_{false};
|
bool trigger_swap_out_{false};
|
||||||
bool need_swap_{false};
|
bool trigger_swap_in_{false};
|
||||||
// output index to topo orders of node users
|
size_t swap_in_task_num_{0};
|
||||||
|
// Key: output index, value: topo orders of node users
|
||||||
std::map<size_t, std::vector<size_t>> node_users_map_;
|
std::map<size_t, std::vector<size_t>> node_users_map_;
|
||||||
// kernel output idx to host addr
|
// Key: output idx, value: (host addr, dirty or not)
|
||||||
std::map<size_t, HostAddress> host_addrs_;
|
std::map<size_t, std::pair<HostAddress, bool>> host_addrs_;
|
||||||
|
|
||||||
KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {}
|
KernelExecutionInfo() {}
|
||||||
explicit KernelExecutionInfo(size_t topo_order)
|
explicit KernelExecutionInfo(size_t topo_order) : KernelExecutionInfo(topo_order, 0.0, false, false, 0) {}
|
||||||
: topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {}
|
KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap_out, bool trigger_swap_in,
|
||||||
KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap)
|
size_t swap_in_task_num)
|
||||||
: topo_order_(topo_order),
|
: topo_order_(topo_order),
|
||||||
execution_perform_(execution_perform),
|
execution_perform_(execution_perform),
|
||||||
trigger_swap_(trigger_swap),
|
trigger_swap_out_(trigger_swap_out),
|
||||||
need_swap_(need_swap) {}
|
trigger_swap_in_(trigger_swap_in),
|
||||||
|
swap_in_task_num_(swap_in_task_num) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
// trigger swap
|
|
||||||
struct MemSwapInfo {
|
struct MemSwapInfo {
|
||||||
SwapKind swap_kind_;
|
SwapKind swap_kind_;
|
||||||
// kernel need to be swapped
|
// Topo order of kernel need be swapped
|
||||||
AnfNodePtr kernel_{nullptr};
|
size_t topo_order_;
|
||||||
size_t output_idx_{0};
|
size_t output_idx_{0};
|
||||||
|
// Record the swapping out position of swapping in tensor
|
||||||
|
size_t swap_out_pos_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct SwapInfoComp {
|
||||||
|
bool operator()(const MemSwapInfo &a, const MemSwapInfo &b) {
|
||||||
|
int swap_kind_a = static_cast<int>(a.swap_kind_);
|
||||||
|
int swap_kind_b = static_cast<int>(b.swap_kind_);
|
||||||
|
if (swap_kind_a < swap_kind_b) {
|
||||||
|
return true;
|
||||||
|
} else if (swap_kind_a > swap_kind_b) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.swap_out_pos_ < b.swap_out_pos_) {
|
||||||
|
return true;
|
||||||
|
} else if (a.swap_out_pos_ > b.swap_out_pos_) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (a.topo_order_ < b.topo_order_) {
|
||||||
|
return true;
|
||||||
|
} else if (a.topo_order_ > b.topo_order_) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.output_idx_ < b.output_idx_;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class MemCopyManager {
|
class MemCopyManager {
|
||||||
|
@ -90,6 +120,7 @@ class MemCopyManager {
|
||||||
virtual void ClearSwapQueue() {}
|
virtual void ClearSwapQueue() {}
|
||||||
};
|
};
|
||||||
using MemCopyManagerPtr = std::shared_ptr<MemCopyManager>;
|
using MemCopyManagerPtr = std::shared_ptr<MemCopyManager>;
|
||||||
|
using MemSwapInfoSet = std::set<MemSwapInfo, SwapInfoComp>;
|
||||||
} // namespace memswap
|
} // namespace memswap
|
||||||
} // namespace device
|
} // namespace device
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -22,22 +22,17 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace memswap {
|
namespace memswap {
|
||||||
void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
graph_manager_ = kernel_graph->manager();
|
graph_manager_ = kernel_graph->manager();
|
||||||
MS_EXCEPTION_IF_NULL(graph_manager_);
|
MS_EXCEPTION_IF_NULL(graph_manager_);
|
||||||
auto &kernels = kernel_graph->execution_order();
|
execution_order_ = kernel_graph->execution_order();
|
||||||
for (const auto &kernel : kernels) {
|
|
||||||
if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) {
|
|
||||||
execution_order_.push_back(kernel);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t kernel_index = 0;
|
size_t kernel_index = 0;
|
||||||
for (const auto &kernel : execution_order_) {
|
for (const auto &kernel : execution_order_) {
|
||||||
// parse topo order of kernel
|
// Parse topo order of kernel
|
||||||
(void)kernel_execution_info_.emplace(kernel.get(), kernel_index++);
|
(void)kernel_execution_info_.emplace(kernel.get(), kernel_index++);
|
||||||
// parse tensor info
|
// Parse tensor info
|
||||||
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
auto output_sizes = kernel_mod->GetOutputSizeList();
|
auto output_sizes = kernel_mod->GetOutputSizeList();
|
||||||
|
@ -48,7 +43,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse topo order of user kernel
|
// Parse topo order of user kernel
|
||||||
SaveUserKernelTopoOrder();
|
SaveUserKernelTopoOrder();
|
||||||
|
|
||||||
sort(ordered_tensors_.begin(), ordered_tensors_.end(),
|
sort(ordered_tensors_.begin(), ordered_tensors_.end(),
|
||||||
|
@ -61,17 +56,103 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
||||||
tensor_size_num_++;
|
tensor_size_num_++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tensor_size_threshold_ = ordered_tensors_.front().tensor_size_;
|
if (!InitSwapThreshold(0)) {
|
||||||
tensor_size_threshold_idx_ = 0;
|
return false;
|
||||||
|
}
|
||||||
distance_threshold_ = kernel_index / kDistanceInitFactor;
|
|
||||||
mem_swap_initialized_ = true;
|
mem_swap_initialized_ = true;
|
||||||
MS_EXCEPTION_IF_NULL(mem_copy_manager_);
|
MS_EXCEPTION_IF_NULL(mem_copy_manager_);
|
||||||
mem_copy_manager_->Init();
|
mem_copy_manager_->Init();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MemSwapManager::InitSwapThreshold(size_t swap_mem_size) {
|
||||||
|
distance_threshold_ = execution_order_.size() / kDistanceInitFactor;
|
||||||
|
distance_decay_step_ = execution_order_.size() / kDistanceInitFactor / tensor_size_num_;
|
||||||
|
if (distance_decay_step_ <= 1) {
|
||||||
|
distance_decay_step_ = 1;
|
||||||
|
}
|
||||||
|
tensor_size_threshold_ = ordered_tensors_.front().tensor_size_;
|
||||||
|
tensor_size_threshold_idx_ = 0;
|
||||||
|
|
||||||
|
size_t accumulation = 0;
|
||||||
|
while (accumulation < swap_mem_size) {
|
||||||
|
accumulation = 0;
|
||||||
|
for (const auto &tensor_info : ordered_tensors_) {
|
||||||
|
size_t tensor_size = tensor_info.tensor_size_;
|
||||||
|
if (tensor_size < tensor_size_threshold_) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (!CheckDistanceBetweenKernels(tensor_info)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
accumulation += tensor_info.tensor_size_;
|
||||||
|
if (accumulation >= swap_mem_size) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RetreatSwapThreshold();
|
||||||
|
if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) {
|
||||||
|
MS_LOG(ERROR) << "Init swap threshold info failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::RetreatSwapThreshold() {
|
||||||
|
if (distance_threshold_ >= kDistanceLowerBound) {
|
||||||
|
bool update_one_decay_step = (distance_threshold_ > distance_decay_step_) &&
|
||||||
|
(distance_threshold_ - distance_decay_step_ >= kDistanceLowerBound);
|
||||||
|
if (update_one_decay_step) {
|
||||||
|
distance_threshold_ -= distance_decay_step_;
|
||||||
|
} else if (distance_threshold_ >= kDistanceLowerBound) {
|
||||||
|
size_t new_distance_decay_step = (distance_threshold_ - kDistanceLowerBound) / 4;
|
||||||
|
if (new_distance_decay_step < 1) {
|
||||||
|
new_distance_decay_step = 1;
|
||||||
|
}
|
||||||
|
distance_threshold_ -= new_distance_decay_step;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) {
|
||||||
|
++tensor_size_threshold_idx_;
|
||||||
|
if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
|
||||||
|
tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MemSwapManager::CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const {
|
||||||
|
const AnfNodePtr &kernel = tensor_info.kernel_;
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
auto &node_users_map = kernel_exec_info.node_users_map_;
|
||||||
|
|
||||||
|
auto iter = node_users_map.find(tensor_info.output_idx_);
|
||||||
|
if (iter == node_users_map.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &node_users = iter->second;
|
||||||
|
if (node_users.front() - kernel_exec_info.topo_order_ > distance_threshold_) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 1; i < node_users.size(); ++i) {
|
||||||
|
if (node_users[i] - node_users[i - 1] > distance_threshold_) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
|
bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
|
if (AnfAlgo::IsCommunicationOp(kernel)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
NodeUsersMap &user_map = graph_manager_->node_users();
|
NodeUsersMap &user_map = graph_manager_->node_users();
|
||||||
auto iter = user_map.find(kernel);
|
auto iter = user_map.find(kernel);
|
||||||
bool adjacent_with_communication_op = false;
|
bool adjacent_with_communication_op = false;
|
||||||
|
@ -81,7 +162,7 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
|
||||||
node_set.begin(), node_set.end(),
|
node_set.begin(), node_set.end(),
|
||||||
[](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); });
|
[](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); });
|
||||||
}
|
}
|
||||||
return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op;
|
return adjacent_with_communication_op;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemSwapManager::SaveUserKernelTopoOrder() {
|
void MemSwapManager::SaveUserKernelTopoOrder() {
|
||||||
|
@ -95,7 +176,7 @@ void MemSwapManager::SaveUserKernelTopoOrder() {
|
||||||
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
for (auto &node_pair : node_set) {
|
for (auto &node_pair : node_set) {
|
||||||
auto user_kernel = node_pair.first;
|
auto user_kernel = node_pair.first;
|
||||||
if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) {
|
if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,21 +219,18 @@ void MemSwapManager::AddSwapInfo() {
|
||||||
if (!need_swap) {
|
if (!need_swap) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
AddKernelNeedSwap(kernel, true);
|
|
||||||
HostAddress host_addr;
|
HostAddress host_addr;
|
||||||
host_addr.size = tensor_size;
|
host_addr.size = tensor_size;
|
||||||
auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast<void **>(&host_addr.addr));
|
auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast<void **>(&host_addr.addr));
|
||||||
if (!ret) {
|
if (!ret) {
|
||||||
MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed.";
|
MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed.";
|
||||||
}
|
}
|
||||||
kernel_exec_info.host_addrs_[output_idx] = host_addr;
|
kernel_exec_info.host_addrs_[output_idx] = std::make_pair(host_addr, true);
|
||||||
MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx};
|
MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel_exec_info.topo_order_, output_idx, 0};
|
||||||
if (node_users.size() > 1) {
|
if (node_users.size() > 1) {
|
||||||
AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info);
|
AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info);
|
||||||
AddKernelTriggerSwap(execution_order_[node_users[0]], true);
|
|
||||||
} else {
|
} else {
|
||||||
AddKernelMemSwapInfo(kernel, mem_swap_out_info);
|
AddKernelMemSwapInfo(kernel, mem_swap_out_info);
|
||||||
AddKernelTriggerSwap(kernel, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1;
|
size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1;
|
||||||
|
@ -160,9 +238,8 @@ void MemSwapManager::AddSwapInfo() {
|
||||||
MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
}
|
}
|
||||||
auto swap_in_kernel = execution_order_[swap_in_order];
|
auto swap_in_kernel = execution_order_[swap_in_order];
|
||||||
MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx};
|
MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel_exec_info.topo_order_, output_idx, 0};
|
||||||
AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info);
|
AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info);
|
||||||
AddKernelTriggerSwap(swap_in_kernel, true);
|
|
||||||
|
|
||||||
host_addrs_list_.push_back(host_addr);
|
host_addrs_list_.push_back(host_addr);
|
||||||
}
|
}
|
||||||
|
@ -189,7 +266,7 @@ DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// retreat to find a workable swap scheme
|
// Retreat to find a workable swap scheme
|
||||||
bool MemSwapManager::RetreatSwapInfo() {
|
bool MemSwapManager::RetreatSwapInfo() {
|
||||||
if (!trigger_swap_) {
|
if (!trigger_swap_) {
|
||||||
trigger_swap_ = true;
|
trigger_swap_ = true;
|
||||||
|
@ -220,6 +297,114 @@ bool MemSwapManager::RetreatSwapInfo() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::AdjustSwapInPos(const AnfNodePtr &kernel, size_t index) {
|
||||||
|
if (kernel_first_move_cache_map_.find(kernel.get()) == kernel_first_move_cache_map_.end()) {
|
||||||
|
CacheCurSwapInfoSet(kernel);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
size_t kernel_pos = kernel_exec_info.topo_order_;
|
||||||
|
auto &mem_swap_info = mem_swap_info_cache_list_[index];
|
||||||
|
|
||||||
|
if (QueryFirstTimeMovePos(kernel, index)) {
|
||||||
|
best_and_cur_pos_cache_.first = BestSwapInPerformPos(kernel, mem_swap_info);
|
||||||
|
best_and_cur_pos_cache_.second = best_and_cur_pos_cache_.first;
|
||||||
|
size_t best_pos = best_and_cur_pos_cache_.first;
|
||||||
|
if (best_pos != kernel_pos) {
|
||||||
|
MoveSwapInfoPos(best_pos, kernel_pos, mem_swap_info);
|
||||||
|
}
|
||||||
|
AddFirstTimeMovePos(kernel, index, false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &cur_pos = best_and_cur_pos_cache_.second;
|
||||||
|
if (cur_pos < kernel_pos) {
|
||||||
|
MoveSwapInfoPos(cur_pos + 1, cur_pos, mem_swap_info);
|
||||||
|
cur_pos++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::CacheCurSwapInfoSet(const AnfNodePtr &kernel) {
|
||||||
|
if (!kernel_first_move_cache_map_.empty()) {
|
||||||
|
kernel_first_move_cache_map_.clear();
|
||||||
|
}
|
||||||
|
if (!mem_swap_info_cache_list_.empty()) {
|
||||||
|
mem_swap_info_cache_list_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto mem_swap_info_set = QueryKernelMemSwapInfo(kernel);
|
||||||
|
size_t swap_in_task_cnt = 0;
|
||||||
|
for (auto &mem_swap_info : mem_swap_info_set) {
|
||||||
|
if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) {
|
||||||
|
(void)mem_swap_info_cache_list_.push_back(mem_swap_info);
|
||||||
|
kernel_first_move_cache_map_[kernel.get()].push_back(true);
|
||||||
|
swap_in_task_cnt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
size_t swap_in_task_num = QueryKernelTriggerSwapInTaskNum(kernel);
|
||||||
|
if (swap_in_task_cnt != swap_in_task_num) {
|
||||||
|
MS_LOG(EXCEPTION) << "Swap_in_task_cnt :" << swap_in_task_cnt
|
||||||
|
<< "must equal Swap_in_task_num: " << swap_in_task_num;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time) {
|
||||||
|
auto iter = kernel_first_move_cache_map_.find(kernel.get());
|
||||||
|
if (iter == kernel_first_move_cache_map_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
|
}
|
||||||
|
auto &first_move_list = iter->second;
|
||||||
|
if (index >= first_move_list.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Index [" << index << "] out of range";
|
||||||
|
}
|
||||||
|
first_move_list[index] = first_time;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MemSwapManager::QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const {
|
||||||
|
auto iter = kernel_first_move_cache_map_.find(kernel.get());
|
||||||
|
if (iter == kernel_first_move_cache_map_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
|
}
|
||||||
|
const auto &first_move_list = iter->second;
|
||||||
|
if (index >= first_move_list.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Index [" << index << "] out of range";
|
||||||
|
}
|
||||||
|
return first_move_list[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t MemSwapManager::BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const {
|
||||||
|
auto need_swap_kernel = QueryKerneByTopoOrder(mem_swap_info.topo_order_);
|
||||||
|
const PerformPair &perform_pair = QueryKernelSwapPerform(need_swap_kernel, mem_swap_info.output_idx_);
|
||||||
|
float swap_in_cost_time = perform_pair.second;
|
||||||
|
size_t swap_out_pos = mem_swap_info.swap_out_pos_;
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(trigger_kernel);
|
||||||
|
size_t trigger_kernel_pos = kernel_exec_info.topo_order_;
|
||||||
|
float kernel_execution_time = 0;
|
||||||
|
|
||||||
|
size_t pos = trigger_kernel_pos;
|
||||||
|
for (; pos > swap_out_pos + 1; pos--) {
|
||||||
|
auto kernel = QueryKerneByTopoOrder(pos - 1);
|
||||||
|
if (QueryKernelTriggerSwapIn(kernel)) {
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
kernel_execution_time += QueryKernelExecutionPerform(QueryKerneByTopoOrder(pos));
|
||||||
|
if (kernel_execution_time >= swap_in_cost_time) {
|
||||||
|
return pos - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info) {
|
||||||
|
if (des_pos == src_pos) {
|
||||||
|
MS_LOG(EXCEPTION) << "destination pos can not equal source pos";
|
||||||
|
}
|
||||||
|
auto des_kernel = QueryKerneByTopoOrder(des_pos);
|
||||||
|
auto src_kernel = QueryKerneByTopoOrder(src_pos);
|
||||||
|
AddKernelMemSwapInfo(des_kernel, mem_swap_info);
|
||||||
|
RemoveKernelMemSwapInfo(src_kernel, mem_swap_info);
|
||||||
|
}
|
||||||
|
|
||||||
KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const {
|
KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
auto iter = kernel_execution_info_.find(kernel.get());
|
auto iter = kernel_execution_info_.find(kernel.get());
|
||||||
|
@ -234,16 +419,6 @@ void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float p
|
||||||
kernel_exec_info.execution_perform_ = perform;
|
kernel_exec_info.execution_perform_ = perform;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) {
|
|
||||||
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
|
||||||
kernel_exec_info.trigger_swap_ = trigger_swap;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) {
|
|
||||||
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
|
||||||
kernel_exec_info.need_swap_ = need_swap;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx,
|
void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx,
|
||||||
const std::pair<float, float> &perform) {
|
const std::pair<float, float> &perform) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
|
@ -252,7 +427,42 @@ void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t outpu
|
||||||
|
|
||||||
void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) {
|
void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
mem_swap_info_[kernel.get()].push_back(mem_swap_info);
|
(void)mem_swap_info_map_[kernel.get()].insert(mem_swap_info);
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
|
||||||
|
kernel_exec_info.trigger_swap_out_ = true;
|
||||||
|
} else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) {
|
||||||
|
kernel_exec_info.swap_in_task_num_++;
|
||||||
|
kernel_exec_info.trigger_swap_in_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) {
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
|
if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) {
|
||||||
|
auto map_iter = mem_swap_info_map_.find(kernel.get());
|
||||||
|
if (map_iter == mem_swap_info_map_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
|
}
|
||||||
|
MemSwapInfoSet &mem_swap_info_set = map_iter->second;
|
||||||
|
|
||||||
|
auto set_iter = mem_swap_info_set.find(mem_swap_info);
|
||||||
|
if (set_iter == mem_swap_info_set.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find memory swap information in mem swap info set";
|
||||||
|
}
|
||||||
|
mem_swap_info_set.erase(set_iter);
|
||||||
|
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
if (kernel_exec_info.swap_in_task_num_ > 0) {
|
||||||
|
kernel_exec_info.swap_in_task_num_--;
|
||||||
|
}
|
||||||
|
if (kernel_exec_info.swap_in_task_num_ == 0) {
|
||||||
|
kernel_exec_info.trigger_swap_in_ = false;
|
||||||
|
}
|
||||||
|
if (mem_swap_info_set.empty()) {
|
||||||
|
(void)mem_swap_info_map_.erase(kernel.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const {
|
float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const {
|
||||||
|
@ -262,12 +472,24 @@ float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) cons
|
||||||
|
|
||||||
bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const {
|
bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const {
|
||||||
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
return kernel_exec_info.trigger_swap_;
|
return kernel_exec_info.trigger_swap_out_ || kernel_exec_info.trigger_swap_in_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const {
|
bool MemSwapManager::QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const {
|
||||||
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
return kernel_exec_info.need_swap_;
|
return kernel_exec_info.trigger_swap_in_;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t MemSwapManager::QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const {
|
||||||
|
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
return kernel_exec_info.swap_in_task_num_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr MemSwapManager::QueryKerneByTopoOrder(size_t index) const {
|
||||||
|
if (index >= execution_order_.size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Index [" << index << "] out of range";
|
||||||
|
}
|
||||||
|
return execution_order_[index];
|
||||||
}
|
}
|
||||||
|
|
||||||
const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const {
|
const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const {
|
||||||
|
@ -286,15 +508,70 @@ const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kern
|
||||||
return iter_output->second;
|
return iter_output->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<MemSwapInfo> &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const {
|
const MemSwapInfoSet &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
auto iter = mem_swap_info_.find(kernel.get());
|
auto iter = mem_swap_info_map_.find(kernel.get());
|
||||||
if (iter == mem_swap_info_.end()) {
|
if (iter == mem_swap_info_map_.end()) {
|
||||||
MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
}
|
}
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::AssignHostMemory() {
|
||||||
|
for (auto &kernel_exec_info_pair : kernel_execution_info_) {
|
||||||
|
auto &kernel_exec_info = kernel_exec_info_pair.second;
|
||||||
|
auto &host_addrs_map = kernel_exec_info.host_addrs_;
|
||||||
|
for (auto &host_addr_pair : host_addrs_map) {
|
||||||
|
auto &host_addr = host_addr_pair.second.first;
|
||||||
|
auto ret = AllocHostPinnedMem(host_addr.size, reinterpret_cast<void **>(&host_addr.addr));
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << host_addr.size << "] failed.";
|
||||||
|
}
|
||||||
|
host_addrs_list_.push_back(host_addr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const HostAddress &MemSwapManager::QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const {
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
auto &host_addrs = kernel_exec_info.host_addrs_;
|
||||||
|
auto iter = host_addrs.find(output_idx);
|
||||||
|
if (iter == host_addrs.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
|
}
|
||||||
|
return (iter->second).first;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty) {
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
auto &host_addrs = kernel_exec_info.host_addrs_;
|
||||||
|
auto iter = host_addrs.find(output_idx);
|
||||||
|
if (iter == host_addrs.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
|
}
|
||||||
|
(iter->second).second = dirty;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MemSwapManager::QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const {
|
||||||
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
||||||
|
auto &host_addrs = kernel_exec_info.host_addrs_;
|
||||||
|
auto iter = host_addrs.find(output_idx);
|
||||||
|
if (iter == host_addrs.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
||||||
|
}
|
||||||
|
return (iter->second).second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemSwapManager::ResetHostAddrIsDirty() {
|
||||||
|
for (auto &kernel_exec_info_pair : kernel_execution_info_) {
|
||||||
|
auto &kernel_exec_info = kernel_exec_info_pair.second;
|
||||||
|
auto &host_addrs = kernel_exec_info.host_addrs_;
|
||||||
|
for (auto &host_addr : host_addrs) {
|
||||||
|
host_addr.second.second = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); }
|
void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); }
|
||||||
|
|
||||||
bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const {
|
bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const {
|
||||||
|
@ -302,16 +579,6 @@ bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const {
|
||||||
return iter != swap_in_blacklist_.end();
|
return iter != swap_in_blacklist_.end();
|
||||||
}
|
}
|
||||||
|
|
||||||
const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const {
|
|
||||||
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
|
||||||
auto &host_addrs = kernel_exec_info.host_addrs_;
|
|
||||||
auto iter = host_addrs.find(output_idx);
|
|
||||||
if (iter == host_addrs.end()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
|
||||||
}
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const {
|
bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const {
|
||||||
return mem_copy_manager_->AllocHostPinnedMem(size, addr);
|
return mem_copy_manager_->AllocHostPinnedMem(size, addr);
|
||||||
}
|
}
|
||||||
|
@ -331,13 +598,14 @@ void MemSwapManager::ResetSwapInfo() {
|
||||||
ClearSwapQueue();
|
ClearSwapQueue();
|
||||||
for (auto &kernel_exec_info_pair : kernel_execution_info_) {
|
for (auto &kernel_exec_info_pair : kernel_execution_info_) {
|
||||||
auto &kernel_exec_info = kernel_exec_info_pair.second;
|
auto &kernel_exec_info = kernel_exec_info_pair.second;
|
||||||
kernel_exec_info.trigger_swap_ = false;
|
kernel_exec_info.trigger_swap_out_ = false;
|
||||||
kernel_exec_info.need_swap_ = false;
|
kernel_exec_info.trigger_swap_in_ = false;
|
||||||
|
kernel_exec_info.swap_in_task_num_ = 0;
|
||||||
kernel_exec_info.host_addrs_.clear();
|
kernel_exec_info.host_addrs_.clear();
|
||||||
}
|
}
|
||||||
ReleaseHostPinnedMem();
|
ReleaseHostPinnedMem();
|
||||||
swap_in_blacklist_.clear();
|
swap_in_blacklist_.clear();
|
||||||
mem_swap_info_.clear();
|
mem_swap_info_map_.clear();
|
||||||
}
|
}
|
||||||
} // namespace memswap
|
} // namespace memswap
|
||||||
} // namespace device
|
} // namespace device
|
||||||
|
|
|
@ -32,7 +32,11 @@ namespace memswap {
|
||||||
class MemSwapManager {
|
class MemSwapManager {
|
||||||
public:
|
public:
|
||||||
explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager)
|
explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager)
|
||||||
: tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) {
|
: tensor_size_threshold_(0),
|
||||||
|
tensor_size_threshold_idx_(0),
|
||||||
|
tensor_size_num_(1),
|
||||||
|
distance_threshold_(1),
|
||||||
|
distance_decay_step_(1) {
|
||||||
mem_copy_manager_ = mem_copy_manager;
|
mem_copy_manager_ = mem_copy_manager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +46,7 @@ class MemSwapManager {
|
||||||
|
|
||||||
~MemSwapManager() = default;
|
~MemSwapManager() = default;
|
||||||
|
|
||||||
void Init(const mindspore::session::KernelGraph *kernel_graph);
|
bool Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size = 0);
|
||||||
|
|
||||||
void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address,
|
void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address,
|
||||||
const HostAddress &host_address) const;
|
const HostAddress &host_address) const;
|
||||||
|
@ -51,9 +55,10 @@ class MemSwapManager {
|
||||||
|
|
||||||
DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const;
|
DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const;
|
||||||
|
|
||||||
// retreat to find a workable swap scheme
|
|
||||||
bool RetreatSwapInfo();
|
bool RetreatSwapInfo();
|
||||||
|
|
||||||
|
void AdjustSwapInPos(const AnfNodePtr &kernel, size_t index);
|
||||||
|
|
||||||
bool trigger_swap() const { return trigger_swap_; }
|
bool trigger_swap() const { return trigger_swap_; }
|
||||||
|
|
||||||
bool mem_swap_init() const { return mem_swap_initialized_; }
|
bool mem_swap_init() const { return mem_swap_initialized_; }
|
||||||
|
@ -70,16 +75,28 @@ class MemSwapManager {
|
||||||
|
|
||||||
bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const;
|
bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const;
|
||||||
|
|
||||||
bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const;
|
bool QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const;
|
||||||
|
|
||||||
const std::vector<MemSwapInfo> &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const;
|
size_t QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const;
|
||||||
|
|
||||||
|
const AnfNodePtr QueryKerneByTopoOrder(size_t index) const;
|
||||||
|
|
||||||
|
const MemSwapInfoSet &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const;
|
||||||
|
|
||||||
|
void AssignHostMemory();
|
||||||
|
|
||||||
|
const HostAddress &QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const;
|
||||||
|
|
||||||
|
void AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty);
|
||||||
|
|
||||||
|
bool QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const;
|
||||||
|
|
||||||
|
void ResetHostAddrIsDirty();
|
||||||
|
|
||||||
void InsertSwapInBlackList(const void *device_ptr);
|
void InsertSwapInBlackList(const void *device_ptr);
|
||||||
|
|
||||||
bool FindInSwapInBlackList(const void *device_ptr) const;
|
bool FindInSwapInBlackList(const void *device_ptr) const;
|
||||||
|
|
||||||
const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const;
|
|
||||||
|
|
||||||
bool AllocHostPinnedMem(size_t size, void **addr) const;
|
bool AllocHostPinnedMem(size_t size, void **addr) const;
|
||||||
|
|
||||||
void ReleaseHostPinnedMem();
|
void ReleaseHostPinnedMem();
|
||||||
|
@ -93,27 +110,47 @@ class MemSwapManager {
|
||||||
|
|
||||||
void SaveUserKernelTopoOrder();
|
void SaveUserKernelTopoOrder();
|
||||||
|
|
||||||
void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap);
|
bool InitSwapThreshold(size_t swap_mem_size);
|
||||||
|
|
||||||
void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap);
|
void RetreatSwapThreshold();
|
||||||
|
|
||||||
|
void CacheCurSwapInfoSet(const AnfNodePtr &kernel);
|
||||||
|
|
||||||
|
void AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time);
|
||||||
|
|
||||||
|
bool QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const;
|
||||||
|
|
||||||
|
size_t BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const;
|
||||||
|
|
||||||
|
void MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info);
|
||||||
|
|
||||||
void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info);
|
void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info);
|
||||||
|
|
||||||
|
void RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info);
|
||||||
|
|
||||||
|
bool CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const;
|
||||||
|
|
||||||
bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;
|
bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;
|
||||||
|
|
||||||
std::vector<CNodePtr> execution_order_;
|
std::vector<CNodePtr> execution_order_;
|
||||||
std::vector<TensorInfo> ordered_tensors_;
|
std::vector<TensorInfo> ordered_tensors_;
|
||||||
std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
|
std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
|
||||||
std::unordered_map<void *, std::map<size_t, PerformPair>> kernel_swap_perform_;
|
std::unordered_map<void *, std::map<size_t, PerformPair>> kernel_swap_perform_;
|
||||||
// trigger swap kernel key : MemSwapInfo of kernel need to be swapped
|
// Key: trigger swap kernel, value: MemSwapInfoSet of kernel need to be swapped
|
||||||
std::unordered_map<void *, std::vector<MemSwapInfo>> mem_swap_info_;
|
std::unordered_map<void *, MemSwapInfoSet> mem_swap_info_map_;
|
||||||
std::vector<HostAddress> host_addrs_list_;
|
std::vector<HostAddress> host_addrs_list_;
|
||||||
std::unordered_set<const void *> swap_in_blacklist_;
|
std::unordered_set<const void *> swap_in_blacklist_;
|
||||||
|
|
||||||
|
// Key: cache kernel address, value: lists of first time move pos or not
|
||||||
|
std::map<void *, std::vector<bool>> kernel_first_move_cache_map_;
|
||||||
|
std::vector<MemSwapInfo> mem_swap_info_cache_list_;
|
||||||
|
std::pair<size_t, size_t> best_and_cur_pos_cache_;
|
||||||
|
|
||||||
size_t tensor_size_threshold_;
|
size_t tensor_size_threshold_;
|
||||||
size_t tensor_size_threshold_idx_;
|
size_t tensor_size_threshold_idx_;
|
||||||
size_t tensor_size_num_;
|
size_t tensor_size_num_;
|
||||||
size_t distance_threshold_;
|
size_t distance_threshold_;
|
||||||
|
size_t distance_decay_step_;
|
||||||
|
|
||||||
MemCopyManagerPtr mem_copy_manager_{nullptr};
|
MemCopyManagerPtr mem_copy_manager_{nullptr};
|
||||||
FuncGraphManagerPtr graph_manager_{nullptr};
|
FuncGraphManagerPtr graph_manager_{nullptr};
|
||||||
|
|
|
@ -707,6 +707,18 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz
|
||||||
return addr;
|
return addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get workspace device mutable addr of anf_node
|
||||||
|
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||||
|
auto addr = kernel_info->GetMutableWorkspaceAddr(index);
|
||||||
|
if (addr == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Index " << index << " of node " << node->DebugString() << "] workspace addr is not exist";
|
||||||
|
}
|
||||||
|
return addr;
|
||||||
|
}
|
||||||
|
|
||||||
// set infer shapes and types of anf node
|
// set infer shapes and types of anf node
|
||||||
void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
|
void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
|
||||||
const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
|
const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
|
||||||
|
|
|
@ -149,6 +149,8 @@ class AnfRuntimeAlgorithm {
|
||||||
static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
|
static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
|
||||||
// get workspace device addr of anf_node
|
// get workspace device addr of anf_node
|
||||||
static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
|
static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
|
||||||
|
// get workspace device mutable addr of anf_node
|
||||||
|
static DeviceAddressPtr GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index);
|
||||||
// set infer shapes and types of anf node
|
// set infer shapes and types of anf node
|
||||||
static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
|
static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
|
||||||
const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
|
const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
|
||||||
|
|
|
@ -209,6 +209,16 @@ bool CudaDriver::QueryEvent(const DeviceEvent &event) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool CudaDriver::ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end) {
|
||||||
|
auto ret = cudaEventElapsedTime(cost_time, (cudaEvent_t)start, (cudaEvent_t)end);
|
||||||
|
if (ret == cudaSuccess) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "cudaEventElapsedTime failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int CudaDriver::device_count() {
|
int CudaDriver::device_count() {
|
||||||
int dev_count;
|
int dev_count;
|
||||||
auto ret = cudaGetDeviceCount(&dev_count);
|
auto ret = cudaGetDeviceCount(&dev_count);
|
||||||
|
|
|
@ -57,6 +57,7 @@ class CudaDriver {
|
||||||
static bool RecordEvent(DeviceEvent event, DeviceStream stream = 0);
|
static bool RecordEvent(DeviceEvent event, DeviceStream stream = 0);
|
||||||
static bool SyncEvent(const DeviceEvent &event);
|
static bool SyncEvent(const DeviceEvent &event);
|
||||||
static bool QueryEvent(const DeviceEvent &event);
|
static bool QueryEvent(const DeviceEvent &event);
|
||||||
|
static bool ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end);
|
||||||
|
|
||||||
// Encapsulate the cuda APIs associated with device management.
|
// Encapsulate the cuda APIs associated with device management.
|
||||||
static int device_count();
|
static int device_count();
|
||||||
|
|
|
@ -33,6 +33,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
using mindspore::device::memswap::MemSwapInfoSet;
|
||||||
using mindspore::device::memswap::MemSwapManager;
|
using mindspore::device::memswap::MemSwapManager;
|
||||||
using mindspore::device::memswap::SwapKind;
|
using mindspore::device::memswap::SwapKind;
|
||||||
bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); }
|
bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); }
|
||||||
|
@ -139,6 +140,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
||||||
InitKernelRefCount(graph);
|
InitKernelRefCount(graph);
|
||||||
InitMemorySwapInfo(graph);
|
InitMemorySwapInfo(graph);
|
||||||
InitKernelOutputAddress(graph);
|
InitKernelOutputAddress(graph);
|
||||||
|
InitKernelWorkspaceAddress(graph);
|
||||||
} else {
|
} else {
|
||||||
AssignDynamicMemory(graph);
|
AssignDynamicMemory(graph);
|
||||||
}
|
}
|
||||||
|
@ -183,6 +185,56 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) {
|
||||||
|
bool ret = false;
|
||||||
|
ClearKernelOldOutputAndWorkspace(graph);
|
||||||
|
if (!mem_swap_manager_->mem_swap_init()) {
|
||||||
|
if (!mem_swap_manager_->Init(graph)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!ret) {
|
||||||
|
if (!mem_swap_manager_->RetreatSwapInfo()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ret = LaunchKernelDynamic(graph, true, false);
|
||||||
|
if (!ret) {
|
||||||
|
ClearKernelOldOutputAndWorkspace(graph);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mem_swap_manager_->AssignHostMemory();
|
||||||
|
|
||||||
|
// Time profiling
|
||||||
|
ret = LaunchKernelDynamic(graph, false, true);
|
||||||
|
if (!ret) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
return RefineMemSwapScheme(graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) {
|
||||||
|
auto &kernels = graph->execution_order();
|
||||||
|
for (const auto &kernel : kernels) {
|
||||||
|
if (!mem_swap_manager_->QueryKernelTriggerSwapIn(kernel)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t swap_in_task_num = mem_swap_manager_->QueryKernelTriggerSwapInTaskNum(kernel);
|
||||||
|
for (size_t swap_in_task_idx = 0; swap_in_task_idx < swap_in_task_num; swap_in_task_idx++) {
|
||||||
|
bool ret = false;
|
||||||
|
while (!ret) {
|
||||||
|
mem_swap_manager_->AdjustSwapInPos(kernel, swap_in_task_idx);
|
||||||
|
ret = LaunchKernelDynamic(graph, true, false);
|
||||||
|
if (!ret) {
|
||||||
|
ClearKernelOldOutputAndWorkspace(graph);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
|
void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>();
|
MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>();
|
||||||
|
@ -209,6 +261,7 @@ void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) {
|
||||||
MS_EXCEPTION_IF_NULL(mem_swap_manager);
|
MS_EXCEPTION_IF_NULL(mem_swap_manager);
|
||||||
auto graph_id = graph->graph_id();
|
auto graph_id = graph->graph_id();
|
||||||
mem_swap_map_[graph_id] = mem_swap_manager;
|
mem_swap_map_[graph_id] = mem_swap_manager;
|
||||||
|
is_first_step_map_[graph_id] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) {
|
void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) {
|
||||||
|
@ -230,6 +283,25 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GPUKernelRuntime::InitKernelWorkspaceAddress(const session::KernelGraph *graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
auto &kernels = graph->execution_order();
|
||||||
|
for (const auto &kernel : kernels) {
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
|
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
|
||||||
|
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
|
||||||
|
auto device_address = CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
|
||||||
|
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void GPUKernelRuntime::ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph) {
|
||||||
|
ClearKernelOutputAddress(graph);
|
||||||
|
ClearKernelWorkspaceAddress(graph);
|
||||||
|
}
|
||||||
|
|
||||||
void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) {
|
void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto &kernels = graph->execution_order();
|
auto &kernels = graph->execution_order();
|
||||||
|
@ -242,6 +314,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
||||||
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
if (device_address->ptr_) {
|
if (device_address->ptr_) {
|
||||||
mem_manager_->FreeMemFromMemPool(device_address);
|
mem_manager_->FreeMemFromMemPool(device_address);
|
||||||
}
|
}
|
||||||
|
@ -250,7 +323,24 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
|
void GPUKernelRuntime::ClearKernelWorkspaceAddress(const session::KernelGraph *graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
auto &kernels = graph->execution_order();
|
||||||
|
for (const auto &kernel : kernels) {
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
|
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
|
||||||
|
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
|
||||||
|
auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i);
|
||||||
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
|
if (device_address->ptr_) {
|
||||||
|
mem_manager_->FreeMemFromMemPool(device_address);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bool mock, bool profiling) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(mem_reuse_util_);
|
MS_EXCEPTION_IF_NULL(mem_reuse_util_);
|
||||||
// Reset the reference count.
|
// Reset the reference count.
|
||||||
|
@ -271,7 +361,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
|
||||||
if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
|
if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
|
||||||
MS_LOG(EXCEPTION) << "Launch kernel failed.";
|
MS_LOG(EXCEPTION) << "Launch kernel failed.";
|
||||||
}
|
}
|
||||||
FreeKernelDynamicRes(kernel, kernel_workspaces);
|
FreeKernelDynamicRes(kernel);
|
||||||
UpdateMemorySwapTask(kernel);
|
UpdateMemorySwapTask(kernel);
|
||||||
}
|
}
|
||||||
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
|
CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
|
||||||
|
@ -279,13 +369,39 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs,
|
||||||
|
const AddressPtrList &workspace, const AddressPtrList &outputs) {
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
|
float cost_time = 0;
|
||||||
|
DeviceEvent start = nullptr;
|
||||||
|
DeviceEvent end = nullptr;
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create event.");
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create event.");
|
||||||
|
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, stream_), "Failed to record event to stream.");
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(kernel_mod->Launch(inputs, workspace, outputs, stream_), "Launch kernel failed.");
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(end, stream_), "Failed to record event to stream.");
|
||||||
|
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(start), "Failed to sync event.");
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(end), "Failed to sync event.");
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ElapsedTime(&cost_time, start, end), "Failed to record elapsed time.");
|
||||||
|
|
||||||
|
mem_swap_manager_->AddKernelExecutionPerform(kernel, cost_time);
|
||||||
|
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(start), "Failed to destroy event.");
|
||||||
|
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(end), "Failed to destroy event.");
|
||||||
|
}
|
||||||
|
|
||||||
bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) {
|
bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) {
|
||||||
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
|
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
|
||||||
auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel);
|
const MemSwapInfoSet &mem_swap_info_set = mem_swap_manager_->QueryKernelMemSwapInfo(kernel);
|
||||||
for (auto &mem_swap_info : mem_swap_info_list) {
|
for (auto &mem_swap_info : mem_swap_info_set) {
|
||||||
auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_);
|
auto need_swap_kernel = mem_swap_manager_->QueryKerneByTopoOrder(mem_swap_info.topo_order_);
|
||||||
const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_];
|
MS_EXCEPTION_IF_NULL(need_swap_kernel);
|
||||||
auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false);
|
const HostAddress &host_address =
|
||||||
|
mem_swap_manager_->QueryKernelHostAddr(need_swap_kernel, mem_swap_info.output_idx_);
|
||||||
|
auto device_address = AnfAlgo::GetMutableOutputAddr(need_swap_kernel, mem_swap_info.output_idx_, false);
|
||||||
|
|
||||||
if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
|
if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
|
||||||
mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address);
|
mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address);
|
||||||
|
@ -309,9 +425,11 @@ bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) {
|
||||||
|
|
||||||
bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) {
|
bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) {
|
||||||
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
|
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
|
||||||
ClearKernelOutputAddress(graph);
|
ClearKernelOldOutputAndWorkspace(graph);
|
||||||
if (!mem_swap_manager_->mem_swap_init()) {
|
if (!mem_swap_manager_->mem_swap_init()) {
|
||||||
mem_swap_manager_->Init(graph);
|
if (!mem_swap_manager_->Init(graph)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return mem_swap_manager_->RetreatSwapInfo();
|
return mem_swap_manager_->RetreatSwapInfo();
|
||||||
}
|
}
|
||||||
|
@ -408,29 +526,6 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
|
|
||||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
|
||||||
MS_EXCEPTION_IF_NULL(mem_swap_manager_);
|
|
||||||
auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
|
|
||||||
if (!device_ptr) {
|
|
||||||
if (!mem_swap_manager_->trigger_swap()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
|
|
||||||
while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
|
|
||||||
if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
|
|
||||||
device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
|
|
||||||
mem_manager_->FreeMemFromMemPool(device_address_swap_out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device_ptr = mem_manager_->MallocMemFromMemPool(size);
|
|
||||||
if (!device_ptr) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return device_ptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
|
bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
|
||||||
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
|
const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
|
||||||
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
|
AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
|
||||||
|
@ -504,13 +599,13 @@ bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::K
|
||||||
kernel_workspaces->emplace_back(nullptr);
|
kernel_workspaces->emplace_back(nullptr);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto device_ptr = AttemptMallocMem(workspace_sizes[i]);
|
auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i);
|
||||||
if (!device_ptr) {
|
if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, workspace_sizes[i])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
|
||||||
MS_EXCEPTION_IF_NULL(workspace);
|
MS_EXCEPTION_IF_NULL(workspace);
|
||||||
workspace->addr = device_ptr;
|
workspace->addr = device_address->ptr_;
|
||||||
workspace->size = workspace_sizes[i];
|
workspace->size = workspace_sizes[i];
|
||||||
kernel_workspaces->emplace_back(workspace);
|
kernel_workspaces->emplace_back(workspace);
|
||||||
}
|
}
|
||||||
|
@ -606,8 +701,7 @@ void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, boo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
|
void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) {
|
||||||
const AddressPtrList &kernel_workspaces) {
|
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
MS_EXCEPTION_IF_NULL(mem_manager_);
|
MS_EXCEPTION_IF_NULL(mem_manager_);
|
||||||
MS_EXCEPTION_IF_NULL(mem_reuse_util_);
|
MS_EXCEPTION_IF_NULL(mem_reuse_util_);
|
||||||
|
@ -652,12 +746,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Free the workspace of kernel.
|
// Free the workspace of kernel.
|
||||||
for (size_t i = 0; i < kernel_workspaces.size(); ++i) {
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
|
||||||
auto workspace = kernel_workspaces[i];
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
if (workspace != nullptr) {
|
for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
|
||||||
MS_EXCEPTION_IF_NULL(workspace->addr);
|
auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i);
|
||||||
mem_manager_->FreeMemFromMemPool(workspace->addr);
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
workspace->addr = nullptr;
|
if (device_address->ptr_) {
|
||||||
|
mem_manager_->FreeMemFromMemPool(device_address);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,11 +53,17 @@ class GPUKernelRuntime : public KernelRuntime {
|
||||||
// The related functions and members for using dynamic memory pool.
|
// The related functions and members for using dynamic memory pool.
|
||||||
void InitKernelRefCount(const session::KernelGraph *graph);
|
void InitKernelRefCount(const session::KernelGraph *graph);
|
||||||
void InitKernelOutputAddress(const session::KernelGraph *graph);
|
void InitKernelOutputAddress(const session::KernelGraph *graph);
|
||||||
|
void InitKernelWorkspaceAddress(const session::KernelGraph *graph);
|
||||||
void InitMemorySwapInfo(const session::KernelGraph *graph);
|
void InitMemorySwapInfo(const session::KernelGraph *graph);
|
||||||
void ClearKernelOutputAddress(const session::KernelGraph *graph);
|
void ClearKernelOutputAddress(const session::KernelGraph *graph);
|
||||||
bool LaunchKernelDynamic(const session::KernelGraph *graph);
|
void ClearKernelWorkspaceAddress(const session::KernelGraph *graph);
|
||||||
|
void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph);
|
||||||
|
bool SearchMemSwapScheme(const session::KernelGraph *graph);
|
||||||
|
bool RefineMemSwapScheme(const session::KernelGraph *graph);
|
||||||
|
bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false);
|
||||||
|
void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs,
|
||||||
|
const AddressPtrList &workspace, const AddressPtrList &outputs);
|
||||||
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size);
|
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size);
|
||||||
void *AttemptMallocMem(size_t size);
|
|
||||||
bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
|
||||||
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
|
AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
|
||||||
AddressPtrList *kernel_outputs);
|
AddressPtrList *kernel_outputs);
|
||||||
|
@ -72,7 +78,7 @@ class GPUKernelRuntime : public KernelRuntime {
|
||||||
void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory,
|
void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory,
|
||||||
const DeviceAddressPtrList addr_list, size_t total_size,
|
const DeviceAddressPtrList addr_list, size_t total_size,
|
||||||
std::vector<size_t> size_list);
|
std::vector<size_t> size_list);
|
||||||
void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces);
|
void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel);
|
||||||
bool AddMemorySwapTask(const AnfNodePtr &kernel);
|
bool AddMemorySwapTask(const AnfNodePtr &kernel);
|
||||||
bool UpdateMemorySwapInfo(const session::KernelGraph *graph);
|
bool UpdateMemorySwapInfo(const session::KernelGraph *graph);
|
||||||
bool UpdateMemorySwapTask(const AnfNodePtr &kernel);
|
bool UpdateMemorySwapTask(const AnfNodePtr &kernel);
|
||||||
|
@ -81,6 +87,7 @@ class GPUKernelRuntime : public KernelRuntime {
|
||||||
void ClearSwapQueue();
|
void ClearSwapQueue();
|
||||||
std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
|
std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
|
||||||
std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
|
std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
|
||||||
|
std::unordered_map<uint32_t, bool> is_first_step_map_;
|
||||||
MemReuseUtilPtr mem_reuse_util_{nullptr};
|
MemReuseUtilPtr mem_reuse_util_{nullptr};
|
||||||
MemSwapManagerPtr mem_swap_manager_{nullptr};
|
MemSwapManagerPtr mem_swap_manager_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
|
@ -73,6 +73,14 @@ DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const {
|
||||||
return workspace_address_list_[index].get();
|
return workspace_address_list_[index].get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const {
|
||||||
|
if (index >= workspace_address_list_.size()) {
|
||||||
|
MS_LOG(ERROR) << "Index [" << index << "] out of range";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return workspace_address_list_[index];
|
||||||
|
}
|
||||||
|
|
||||||
bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
|
bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) {
|
||||||
if (workspace_address_list_.empty()) {
|
if (workspace_address_list_.empty()) {
|
||||||
// parameter and valuenode
|
// parameter and valuenode
|
||||||
|
|
|
@ -54,6 +54,7 @@ class KernelInfo : public KernelInfoDevice {
|
||||||
bool OutputAddrExist(size_t index) const;
|
bool OutputAddrExist(size_t index) const;
|
||||||
bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
|
bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index);
|
||||||
DeviceAddress *GetWorkspaceAddr(size_t index) const;
|
DeviceAddress *GetWorkspaceAddr(size_t index) const;
|
||||||
|
DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const;
|
||||||
bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
|
bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index);
|
||||||
void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
|
void set_kernel_mod(const kernel::KernelModPtr &kernel_mod);
|
||||||
kernel::KernelMod *MutableKernelMod() const;
|
kernel::KernelMod *MutableKernelMod() const;
|
||||||
|
|
Loading…
Reference in New Issue