diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h index 1952415515f..66d37a74984 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_copy_manager.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -40,29 +41,58 @@ struct TensorInfo { struct KernelExecutionInfo { size_t topo_order_{0}; float execution_perform_{0.0}; - bool trigger_swap_{false}; - bool need_swap_{false}; - // output index to topo orders of node users + bool trigger_swap_out_{false}; + bool trigger_swap_in_{false}; + size_t swap_in_task_num_{0}; + // Key: output index, value: topo orders of node users std::map> node_users_map_; - // kernel output idx to host addr - std::map host_addrs_; + // Key: output idx, value: (host addr, dirty or not) + std::map> host_addrs_; - KernelExecutionInfo() : KernelExecutionInfo(0, 0.0, false, false) {} - explicit KernelExecutionInfo(size_t topo_order) - : topo_order_(topo_order), execution_perform_(0.0), trigger_swap_(false), need_swap_(false) {} - KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap, bool need_swap) + KernelExecutionInfo() {} + explicit KernelExecutionInfo(size_t topo_order) : KernelExecutionInfo(topo_order, 0.0, false, false, 0) {} + KernelExecutionInfo(size_t topo_order, float execution_perform, bool trigger_swap_out, bool trigger_swap_in, + size_t swap_in_task_num) : topo_order_(topo_order), execution_perform_(execution_perform), - trigger_swap_(trigger_swap), - need_swap_(need_swap) {} + trigger_swap_out_(trigger_swap_out), + trigger_swap_in_(trigger_swap_in), + swap_in_task_num_(swap_in_task_num) {} }; -// trigger swap struct MemSwapInfo { SwapKind swap_kind_; - // kernel need to be swapped - AnfNodePtr kernel_{nullptr}; + // Topo order of kernel need be swapped + size_t topo_order_; size_t output_idx_{0}; + // Record the swapping out position of swapping in tensor + size_t swap_out_pos_; +}; + +struct SwapInfoComp { + bool operator()(const MemSwapInfo &a, const MemSwapInfo &b) { + int swap_kind_a = static_cast(a.swap_kind_); + int swap_kind_b = static_cast(b.swap_kind_); + if (swap_kind_a < swap_kind_b) { + return true; + } else if (swap_kind_a > swap_kind_b) { + return false; + } + + if (a.swap_out_pos_ < b.swap_out_pos_) { + return true; + } else if (a.swap_out_pos_ > b.swap_out_pos_) { + return false; + } + + if (a.topo_order_ < b.topo_order_) { + return true; + } else if (a.topo_order_ > b.topo_order_) { + return false; + } + + return a.output_idx_ < b.output_idx_; + } }; class MemCopyManager { @@ -90,6 +120,7 @@ class MemCopyManager { virtual void ClearSwapQueue() {} }; using MemCopyManagerPtr = std::shared_ptr; +using MemSwapInfoSet = std::set; } // namespace memswap } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc index 41bf5460c3c..5f0569a41d1 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc @@ -22,22 +22,17 @@ namespace mindspore { namespace device { namespace memswap { -void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { +bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size) { MS_EXCEPTION_IF_NULL(kernel_graph); graph_manager_ = kernel_graph->manager(); MS_EXCEPTION_IF_NULL(graph_manager_); - auto &kernels = kernel_graph->execution_order(); - for (const auto &kernel : kernels) { - if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) { - execution_order_.push_back(kernel); - } - } + execution_order_ = kernel_graph->execution_order(); size_t kernel_index = 0; for (const auto &kernel : execution_order_) { - // parse topo order of kernel + // Parse topo order of kernel (void)kernel_execution_info_.emplace(kernel.get(), kernel_index++); - // parse tensor info + // Parse tensor info auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); @@ -48,7 +43,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { } } - // parse topo order of user kernel + // Parse topo order of user kernel SaveUserKernelTopoOrder(); sort(ordered_tensors_.begin(), ordered_tensors_.end(), @@ -61,17 +56,103 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) { tensor_size_num_++; } } - tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; - tensor_size_threshold_idx_ = 0; - - distance_threshold_ = kernel_index / kDistanceInitFactor; + if (!InitSwapThreshold(0)) { + return false; + } mem_swap_initialized_ = true; MS_EXCEPTION_IF_NULL(mem_copy_manager_); mem_copy_manager_->Init(); + return true; +} + +bool MemSwapManager::InitSwapThreshold(size_t swap_mem_size) { + distance_threshold_ = execution_order_.size() / kDistanceInitFactor; + distance_decay_step_ = execution_order_.size() / kDistanceInitFactor / tensor_size_num_; + if (distance_decay_step_ <= 1) { + distance_decay_step_ = 1; + } + tensor_size_threshold_ = ordered_tensors_.front().tensor_size_; + tensor_size_threshold_idx_ = 0; + + size_t accumulation = 0; + while (accumulation < swap_mem_size) { + accumulation = 0; + for (const auto &tensor_info : ordered_tensors_) { + size_t tensor_size = tensor_info.tensor_size_; + if (tensor_size < tensor_size_threshold_) { + break; + } + if (!CheckDistanceBetweenKernels(tensor_info)) { + continue; + } + + accumulation += tensor_info.tensor_size_; + if (accumulation >= swap_mem_size) { + return true; + } + } + RetreatSwapThreshold(); + if (tensor_size_threshold_idx_ == ordered_tensors_.size() - 1 && distance_threshold_ < kDistanceLowerBound) { + MS_LOG(ERROR) << "Init swap threshold info failed"; + return false; + } + } + return true; +} + +void MemSwapManager::RetreatSwapThreshold() { + if (distance_threshold_ >= kDistanceLowerBound) { + bool update_one_decay_step = (distance_threshold_ > distance_decay_step_) && + (distance_threshold_ - distance_decay_step_ >= kDistanceLowerBound); + if (update_one_decay_step) { + distance_threshold_ -= distance_decay_step_; + } else if (distance_threshold_ >= kDistanceLowerBound) { + size_t new_distance_decay_step = (distance_threshold_ - kDistanceLowerBound) / 4; + if (new_distance_decay_step < 1) { + new_distance_decay_step = 1; + } + distance_threshold_ -= new_distance_decay_step; + } + } + + while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) { + ++tensor_size_threshold_idx_; + if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) { + tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_; + break; + } + } +} + +bool MemSwapManager::CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const { + const AnfNodePtr &kernel = tensor_info.kernel_; + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &node_users_map = kernel_exec_info.node_users_map_; + + auto iter = node_users_map.find(tensor_info.output_idx_); + if (iter == node_users_map.end()) { + return false; + } + + auto &node_users = iter->second; + if (node_users.front() - kernel_exec_info.topo_order_ > distance_threshold_) { + return true; + } + + for (size_t i = 1; i < node_users.size(); ++i) { + if (node_users[i] - node_users[i - 1] > distance_threshold_) { + return true; + } + } + return false; } bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { MS_EXCEPTION_IF_NULL(kernel); + if (AnfAlgo::IsCommunicationOp(kernel)) { + return true; + } + NodeUsersMap &user_map = graph_manager_->node_users(); auto iter = user_map.find(kernel); bool adjacent_with_communication_op = false; @@ -81,7 +162,7 @@ bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const { node_set.begin(), node_set.end(), [](const std::pair &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); }); } - return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op; + return adjacent_with_communication_op; } void MemSwapManager::SaveUserKernelTopoOrder() { @@ -95,7 +176,7 @@ void MemSwapManager::SaveUserKernelTopoOrder() { auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); for (auto &node_pair : node_set) { auto user_kernel = node_pair.first; - if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) { + if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) { continue; } @@ -138,21 +219,18 @@ void MemSwapManager::AddSwapInfo() { if (!need_swap) { continue; } - AddKernelNeedSwap(kernel, true); HostAddress host_addr; host_addr.size = tensor_size; auto ret = AllocHostPinnedMem(tensor_size, reinterpret_cast(&host_addr.addr)); if (!ret) { MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << tensor_size << "] failed."; } - kernel_exec_info.host_addrs_[output_idx] = host_addr; - MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel, output_idx}; + kernel_exec_info.host_addrs_[output_idx] = std::make_pair(host_addr, true); + MemSwapInfo mem_swap_out_info = {SwapKind::kDeviceToHost, kernel_exec_info.topo_order_, output_idx, 0}; if (node_users.size() > 1) { AddKernelMemSwapInfo(execution_order_[node_users[0]], mem_swap_out_info); - AddKernelTriggerSwap(execution_order_[node_users[0]], true); } else { AddKernelMemSwapInfo(kernel, mem_swap_out_info); - AddKernelTriggerSwap(kernel, true); } size_t swap_in_order = node_users.size() == 1 ? node_users[0] - 1 : node_users[1] - 1; @@ -160,9 +238,8 @@ void MemSwapManager::AddSwapInfo() { MS_LOG(EXCEPTION) << "Select swap in point failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]"; } auto swap_in_kernel = execution_order_[swap_in_order]; - MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel, output_idx}; + MemSwapInfo mem_swap_in_info = {SwapKind::kHostToDevice, kernel_exec_info.topo_order_, output_idx, 0}; AddKernelMemSwapInfo(swap_in_kernel, mem_swap_in_info); - AddKernelTriggerSwap(swap_in_kernel, true); host_addrs_list_.push_back(host_addr); } @@ -189,7 +266,7 @@ DeviceAddressPtr MemSwapManager::UpdateSwapQueue(SwapKind swap_kind) const { } } -// retreat to find a workable swap scheme +// Retreat to find a workable swap scheme bool MemSwapManager::RetreatSwapInfo() { if (!trigger_swap_) { trigger_swap_ = true; @@ -220,6 +297,114 @@ bool MemSwapManager::RetreatSwapInfo() { return true; } +void MemSwapManager::AdjustSwapInPos(const AnfNodePtr &kernel, size_t index) { + if (kernel_first_move_cache_map_.find(kernel.get()) == kernel_first_move_cache_map_.end()) { + CacheCurSwapInfoSet(kernel); + } + + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + size_t kernel_pos = kernel_exec_info.topo_order_; + auto &mem_swap_info = mem_swap_info_cache_list_[index]; + + if (QueryFirstTimeMovePos(kernel, index)) { + best_and_cur_pos_cache_.first = BestSwapInPerformPos(kernel, mem_swap_info); + best_and_cur_pos_cache_.second = best_and_cur_pos_cache_.first; + size_t best_pos = best_and_cur_pos_cache_.first; + if (best_pos != kernel_pos) { + MoveSwapInfoPos(best_pos, kernel_pos, mem_swap_info); + } + AddFirstTimeMovePos(kernel, index, false); + return; + } + + auto &cur_pos = best_and_cur_pos_cache_.second; + if (cur_pos < kernel_pos) { + MoveSwapInfoPos(cur_pos + 1, cur_pos, mem_swap_info); + cur_pos++; + } +} + +void MemSwapManager::CacheCurSwapInfoSet(const AnfNodePtr &kernel) { + if (!kernel_first_move_cache_map_.empty()) { + kernel_first_move_cache_map_.clear(); + } + if (!mem_swap_info_cache_list_.empty()) { + mem_swap_info_cache_list_.clear(); + } + + auto mem_swap_info_set = QueryKernelMemSwapInfo(kernel); + size_t swap_in_task_cnt = 0; + for (auto &mem_swap_info : mem_swap_info_set) { + if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + (void)mem_swap_info_cache_list_.push_back(mem_swap_info); + kernel_first_move_cache_map_[kernel.get()].push_back(true); + swap_in_task_cnt++; + } + } + size_t swap_in_task_num = QueryKernelTriggerSwapInTaskNum(kernel); + if (swap_in_task_cnt != swap_in_task_num) { + MS_LOG(EXCEPTION) << "Swap_in_task_cnt :" << swap_in_task_cnt + << "must equal Swap_in_task_num: " << swap_in_task_num; + } +} + +void MemSwapManager::AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time) { + auto iter = kernel_first_move_cache_map_.find(kernel.get()); + if (iter == kernel_first_move_cache_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + auto &first_move_list = iter->second; + if (index >= first_move_list.size()) { + MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; + } + first_move_list[index] = first_time; +} + +bool MemSwapManager::QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const { + auto iter = kernel_first_move_cache_map_.find(kernel.get()); + if (iter == kernel_first_move_cache_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find first time move pos info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + const auto &first_move_list = iter->second; + if (index >= first_move_list.size()) { + MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; + } + return first_move_list[index]; +} + +size_t MemSwapManager::BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const { + auto need_swap_kernel = QueryKerneByTopoOrder(mem_swap_info.topo_order_); + const PerformPair &perform_pair = QueryKernelSwapPerform(need_swap_kernel, mem_swap_info.output_idx_); + float swap_in_cost_time = perform_pair.second; + size_t swap_out_pos = mem_swap_info.swap_out_pos_; + auto &kernel_exec_info = SearchKernelExecutionInfo(trigger_kernel); + size_t trigger_kernel_pos = kernel_exec_info.topo_order_; + float kernel_execution_time = 0; + + size_t pos = trigger_kernel_pos; + for (; pos > swap_out_pos + 1; pos--) { + auto kernel = QueryKerneByTopoOrder(pos - 1); + if (QueryKernelTriggerSwapIn(kernel)) { + return pos; + } + kernel_execution_time += QueryKernelExecutionPerform(QueryKerneByTopoOrder(pos)); + if (kernel_execution_time >= swap_in_cost_time) { + return pos - 1; + } + } + return pos; +} + +void MemSwapManager::MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info) { + if (des_pos == src_pos) { + MS_LOG(EXCEPTION) << "destination pos can not equal source pos"; + } + auto des_kernel = QueryKerneByTopoOrder(des_pos); + auto src_kernel = QueryKerneByTopoOrder(src_pos); + AddKernelMemSwapInfo(des_kernel, mem_swap_info); + RemoveKernelMemSwapInfo(src_kernel, mem_swap_info); +} + KernelExecutionInfo &MemSwapManager::SearchKernelExecutionInfo(const AnfNodePtr &kernel) const { MS_EXCEPTION_IF_NULL(kernel); auto iter = kernel_execution_info_.find(kernel.get()); @@ -234,16 +419,6 @@ void MemSwapManager::AddKernelExecutionPerform(const AnfNodePtr &kernel, float p kernel_exec_info.execution_perform_ = perform; } -void MemSwapManager::AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.trigger_swap_ = trigger_swap; -} - -void MemSwapManager::AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap) { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - kernel_exec_info.need_swap_ = need_swap; -} - void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx, const std::pair &perform) { MS_EXCEPTION_IF_NULL(kernel); @@ -252,7 +427,42 @@ void MemSwapManager::AddKernelSwapPerform(const AnfNodePtr &kernel, size_t outpu void MemSwapManager::AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { MS_EXCEPTION_IF_NULL(kernel); - mem_swap_info_[kernel.get()].push_back(mem_swap_info); + (void)mem_swap_info_map_[kernel.get()].insert(mem_swap_info); + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { + kernel_exec_info.trigger_swap_out_ = true; + } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + kernel_exec_info.swap_in_task_num_++; + kernel_exec_info.trigger_swap_in_ = true; + } +} + +void MemSwapManager::RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info) { + MS_EXCEPTION_IF_NULL(kernel); + if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) { + auto map_iter = mem_swap_info_map_.find(kernel.get()); + if (map_iter == mem_swap_info_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + MemSwapInfoSet &mem_swap_info_set = map_iter->second; + + auto set_iter = mem_swap_info_set.find(mem_swap_info); + if (set_iter == mem_swap_info_set.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information in mem swap info set"; + } + mem_swap_info_set.erase(set_iter); + + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + if (kernel_exec_info.swap_in_task_num_ > 0) { + kernel_exec_info.swap_in_task_num_--; + } + if (kernel_exec_info.swap_in_task_num_ == 0) { + kernel_exec_info.trigger_swap_in_ = false; + } + if (mem_swap_info_set.empty()) { + (void)mem_swap_info_map_.erase(kernel.get()); + } + } } float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) const { @@ -262,12 +472,24 @@ float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) cons bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const { const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.trigger_swap_; + return kernel_exec_info.trigger_swap_out_ || kernel_exec_info.trigger_swap_in_; } -bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const { +bool MemSwapManager::QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const { const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - return kernel_exec_info.need_swap_; + return kernel_exec_info.trigger_swap_in_; +} + +size_t MemSwapManager::QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const { + const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + return kernel_exec_info.swap_in_task_num_; +} + +const AnfNodePtr MemSwapManager::QueryKerneByTopoOrder(size_t index) const { + if (index >= execution_order_.size()) { + MS_LOG(EXCEPTION) << "Index [" << index << "] out of range"; + } + return execution_order_[index]; } const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kernel, size_t output_idx) const { @@ -286,15 +508,70 @@ const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kern return iter_output->second; } -const std::vector &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { +const MemSwapInfoSet &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const { MS_EXCEPTION_IF_NULL(kernel); - auto iter = mem_swap_info_.find(kernel.get()); - if (iter == mem_swap_info_.end()) { - MS_LOG(EXCEPTION) << "Can not find memory swap information data of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + auto iter = mem_swap_info_map_.find(kernel.get()); + if (iter == mem_swap_info_map_.end()) { + MS_LOG(EXCEPTION) << "Can not find memory swap information of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; } return iter->second; } +void MemSwapManager::AssignHostMemory() { + for (auto &kernel_exec_info_pair : kernel_execution_info_) { + auto &kernel_exec_info = kernel_exec_info_pair.second; + auto &host_addrs_map = kernel_exec_info.host_addrs_; + for (auto &host_addr_pair : host_addrs_map) { + auto &host_addr = host_addr_pair.second.first; + auto ret = AllocHostPinnedMem(host_addr.size, reinterpret_cast(&host_addr.addr)); + if (!ret) { + MS_LOG(EXCEPTION) << "Alloc host pinned memory[" << host_addr.size << "] failed."; + } + host_addrs_list_.push_back(host_addr); + } + } +} + +const HostAddress &MemSwapManager::QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &host_addrs = kernel_exec_info.host_addrs_; + auto iter = host_addrs.find(output_idx); + if (iter == host_addrs.end()) { + MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return (iter->second).first; +} + +void MemSwapManager::AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty) { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &host_addrs = kernel_exec_info.host_addrs_; + auto iter = host_addrs.find(output_idx); + if (iter == host_addrs.end()) { + MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + (iter->second).second = dirty; +} + +bool MemSwapManager::QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const { + auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); + auto &host_addrs = kernel_exec_info.host_addrs_; + auto iter = host_addrs.find(output_idx); + if (iter == host_addrs.end()) { + MS_LOG(EXCEPTION) << "Can not find host memory dirty info of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; + } + return (iter->second).second; +} + +void MemSwapManager::ResetHostAddrIsDirty() { + for (auto &kernel_exec_info_pair : kernel_execution_info_) { + auto &kernel_exec_info = kernel_exec_info_pair.second; + auto &host_addrs = kernel_exec_info.host_addrs_; + for (auto &host_addr : host_addrs) { + host_addr.second.second = true; + } + } +} + void MemSwapManager::InsertSwapInBlackList(const void *device_ptr) { swap_in_blacklist_.insert(device_ptr); } bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { @@ -302,16 +579,6 @@ bool MemSwapManager::FindInSwapInBlackList(const void *device_ptr) const { return iter != swap_in_blacklist_.end(); } -const HostAddress &MemSwapManager::kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const { - auto &kernel_exec_info = SearchKernelExecutionInfo(kernel); - auto &host_addrs = kernel_exec_info.host_addrs_; - auto iter = host_addrs.find(output_idx); - if (iter == host_addrs.end()) { - MS_LOG(EXCEPTION) << "Can not find host address of op[" << AnfAlgo::GetCNodeName(kernel) << "]"; - } - return iter->second; -} - bool MemSwapManager::AllocHostPinnedMem(size_t size, void **addr) const { return mem_copy_manager_->AllocHostPinnedMem(size, addr); } @@ -331,13 +598,14 @@ void MemSwapManager::ResetSwapInfo() { ClearSwapQueue(); for (auto &kernel_exec_info_pair : kernel_execution_info_) { auto &kernel_exec_info = kernel_exec_info_pair.second; - kernel_exec_info.trigger_swap_ = false; - kernel_exec_info.need_swap_ = false; + kernel_exec_info.trigger_swap_out_ = false; + kernel_exec_info.trigger_swap_in_ = false; + kernel_exec_info.swap_in_task_num_ = 0; kernel_exec_info.host_addrs_.clear(); } ReleaseHostPinnedMem(); swap_in_blacklist_.clear(); - mem_swap_info_.clear(); + mem_swap_info_map_.clear(); } } // namespace memswap } // namespace device diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h index d8620c85162..0f7f2678839 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.h @@ -32,7 +32,11 @@ namespace memswap { class MemSwapManager { public: explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager) - : tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) { + : tensor_size_threshold_(0), + tensor_size_threshold_idx_(0), + tensor_size_num_(1), + distance_threshold_(1), + distance_decay_step_(1) { mem_copy_manager_ = mem_copy_manager; } @@ -42,7 +46,7 @@ class MemSwapManager { ~MemSwapManager() = default; - void Init(const mindspore::session::KernelGraph *kernel_graph); + bool Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size = 0); void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address, const HostAddress &host_address) const; @@ -51,9 +55,10 @@ class MemSwapManager { DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const; - // retreat to find a workable swap scheme bool RetreatSwapInfo(); + void AdjustSwapInPos(const AnfNodePtr &kernel, size_t index); + bool trigger_swap() const { return trigger_swap_; } bool mem_swap_init() const { return mem_swap_initialized_; } @@ -70,16 +75,28 @@ class MemSwapManager { bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const; - bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const; + bool QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const; - const std::vector &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; + size_t QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const; + + const AnfNodePtr QueryKerneByTopoOrder(size_t index) const; + + const MemSwapInfoSet &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const; + + void AssignHostMemory(); + + const HostAddress &QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const; + + void AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty); + + bool QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const; + + void ResetHostAddrIsDirty(); void InsertSwapInBlackList(const void *device_ptr); bool FindInSwapInBlackList(const void *device_ptr) const; - const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const; - bool AllocHostPinnedMem(size_t size, void **addr) const; void ReleaseHostPinnedMem(); @@ -93,27 +110,47 @@ class MemSwapManager { void SaveUserKernelTopoOrder(); - void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap); + bool InitSwapThreshold(size_t swap_mem_size); - void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap); + void RetreatSwapThreshold(); + + void CacheCurSwapInfoSet(const AnfNodePtr &kernel); + + void AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time); + + bool QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const; + + size_t BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const; + + void MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info); void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); + void RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info); + + bool CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const; + bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const; std::vector execution_order_; std::vector ordered_tensors_; std::unordered_map kernel_execution_info_; std::unordered_map> kernel_swap_perform_; - // trigger swap kernel key : MemSwapInfo of kernel need to be swapped - std::unordered_map> mem_swap_info_; + // Key: trigger swap kernel, value: MemSwapInfoSet of kernel need to be swapped + std::unordered_map mem_swap_info_map_; std::vector host_addrs_list_; std::unordered_set swap_in_blacklist_; + // Key: cache kernel address, value: lists of first time move pos or not + std::map> kernel_first_move_cache_map_; + std::vector mem_swap_info_cache_list_; + std::pair best_and_cur_pos_cache_; + size_t tensor_size_threshold_; size_t tensor_size_threshold_idx_; size_t tensor_size_num_; size_t distance_threshold_; + size_t distance_decay_step_; MemCopyManagerPtr mem_copy_manager_{nullptr}; FuncGraphManagerPtr graph_manager_{nullptr}; diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 38c040e6b15..15842335865 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -707,6 +707,18 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz return addr; } +// get workspace device mutable addr of anf_node +DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) { + MS_EXCEPTION_IF_NULL(node); + auto kernel_info = dynamic_cast(node->kernel_info()); + MS_EXCEPTION_IF_NULL(kernel_info); + auto addr = kernel_info->GetMutableWorkspaceAddr(index); + if (addr == nullptr) { + MS_LOG(EXCEPTION) << "Index " << index << " of node " << node->DebugString() << "] workspace addr is not exist"; + } + return addr; +} + // set infer shapes and types of anf node void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector &types, const std::vector> &shapes, AnfNode *node) { diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 4fa3150e367..c7791d8ecba 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -149,6 +149,8 @@ class AnfRuntimeAlgorithm { static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node); // get workspace device addr of anf_node static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx); + // get workspace device mutable addr of anf_node + static DeviceAddressPtr GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index); // set infer shapes and types of anf node static void SetOutputInferTypeAndShape(const std::vector &types, const std::vector> &shapes, AnfNode *node); diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc index 1f5e5e3c22a..9eef3d6f108 100644 --- a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.cc @@ -209,6 +209,16 @@ bool CudaDriver::QueryEvent(const DeviceEvent &event) { } } +bool CudaDriver::ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end) { + auto ret = cudaEventElapsedTime(cost_time, (cudaEvent_t)start, (cudaEvent_t)end); + if (ret == cudaSuccess) { + return true; + } else { + MS_LOG(ERROR) << "cudaEventElapsedTime failed, ret[" << static_cast(ret) << "], " << cudaGetErrorString(ret); + return false; + } +} + int CudaDriver::device_count() { int dev_count; auto ret = cudaGetDeviceCount(&dev_count); diff --git a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h index fb5d60f6cfb..f626468545d 100644 --- a/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h +++ b/mindspore/ccsrc/runtime/device/gpu/cuda_driver.h @@ -57,6 +57,7 @@ class CudaDriver { static bool RecordEvent(DeviceEvent event, DeviceStream stream = 0); static bool SyncEvent(const DeviceEvent &event); static bool QueryEvent(const DeviceEvent &event); + static bool ElapsedTime(float *cost_time, const DeviceEvent &start, const DeviceEvent &end); // Encapsulate the cuda APIs associated with device management. static int device_count(); diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 185df37e4df..2b9f4379771 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -33,6 +33,7 @@ namespace mindspore { namespace device { namespace gpu { +using mindspore::device::memswap::MemSwapInfoSet; using mindspore::device::memswap::MemSwapManager; using mindspore::device::memswap::SwapKind; bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } @@ -139,6 +140,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { InitKernelRefCount(graph); InitMemorySwapInfo(graph); InitKernelOutputAddress(graph); + InitKernelWorkspaceAddress(graph); } else { AssignDynamicMemory(graph); } @@ -183,6 +185,56 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph) { return ret; } +bool GPUKernelRuntime::SearchMemSwapScheme(const session::KernelGraph *graph) { + bool ret = false; + ClearKernelOldOutputAndWorkspace(graph); + if (!mem_swap_manager_->mem_swap_init()) { + if (!mem_swap_manager_->Init(graph)) { + return false; + } + } + + while (!ret) { + if (!mem_swap_manager_->RetreatSwapInfo()) { + return false; + } + ret = LaunchKernelDynamic(graph, true, false); + if (!ret) { + ClearKernelOldOutputAndWorkspace(graph); + } + } + mem_swap_manager_->AssignHostMemory(); + + // Time profiling + ret = LaunchKernelDynamic(graph, false, true); + if (!ret) { + return ret; + } + return RefineMemSwapScheme(graph); +} + +bool GPUKernelRuntime::RefineMemSwapScheme(const session::KernelGraph *graph) { + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + if (!mem_swap_manager_->QueryKernelTriggerSwapIn(kernel)) { + continue; + } + + size_t swap_in_task_num = mem_swap_manager_->QueryKernelTriggerSwapInTaskNum(kernel); + for (size_t swap_in_task_idx = 0; swap_in_task_idx < swap_in_task_num; swap_in_task_idx++) { + bool ret = false; + while (!ret) { + mem_swap_manager_->AdjustSwapInPos(kernel, swap_in_task_idx); + ret = LaunchKernelDynamic(graph, true, false); + if (!ret) { + ClearKernelOldOutputAndWorkspace(graph); + } + } + } + } + return true; +} + void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared(); @@ -209,6 +261,7 @@ void GPUKernelRuntime::InitMemorySwapInfo(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(mem_swap_manager); auto graph_id = graph->graph_id(); mem_swap_map_[graph_id] = mem_swap_manager; + is_first_step_map_[graph_id] = true; } void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) { @@ -230,6 +283,25 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph } } +void GPUKernelRuntime::InitKernelWorkspaceAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + auto device_address = CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown); + AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); + } + } +} + +void GPUKernelRuntime::ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph) { + ClearKernelOutputAddress(graph); + ClearKernelWorkspaceAddress(graph); +} + void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto &kernels = graph->execution_order(); @@ -242,6 +314,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap continue; } auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false); + MS_EXCEPTION_IF_NULL(device_address); if (device_address->ptr_) { mem_manager_->FreeMemFromMemPool(device_address); } @@ -250,7 +323,24 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap } } -bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { +void GPUKernelRuntime::ClearKernelWorkspaceAddress(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_) { + mem_manager_->FreeMemFromMemPool(device_address); + } + } + } +} + +bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bool mock, bool profiling) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(mem_reuse_util_); // Reset the reference count. @@ -271,7 +361,7 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) { MS_LOG(EXCEPTION) << "Launch kernel failed."; } - FreeKernelDynamicRes(kernel, kernel_workspaces); + FreeKernelDynamicRes(kernel); UpdateMemorySwapTask(kernel); } CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed."); @@ -279,13 +369,39 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { return true; } +void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, + const AddressPtrList &workspace, const AddressPtrList &outputs) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + float cost_time = 0; + DeviceEvent start = nullptr; + DeviceEvent end = nullptr; + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&start), "Failed to create event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateEvent(&end), "Failed to create event."); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(start, stream_), "Failed to record event to stream."); + CHECK_OP_RET_WITH_EXCEPT(kernel_mod->Launch(inputs, workspace, outputs, stream_), "Launch kernel failed."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::RecordEvent(end, stream_), "Failed to record event to stream."); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(start), "Failed to sync event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::SyncEvent(end), "Failed to sync event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::ElapsedTime(&cost_time, start, end), "Failed to record elapsed time."); + + mem_swap_manager_->AddKernelExecutionPerform(kernel, cost_time); + + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(start), "Failed to destroy event."); + CHECK_OP_RET_WITH_EXCEPT(CudaDriver::DestroyEvent(end), "Failed to destroy event."); +} + bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(mem_swap_manager_); - auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); - for (auto &mem_swap_info : mem_swap_info_list) { - auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_); - const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_]; - auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false); + const MemSwapInfoSet &mem_swap_info_set = mem_swap_manager_->QueryKernelMemSwapInfo(kernel); + for (auto &mem_swap_info : mem_swap_info_set) { + auto need_swap_kernel = mem_swap_manager_->QueryKerneByTopoOrder(mem_swap_info.topo_order_); + MS_EXCEPTION_IF_NULL(need_swap_kernel); + const HostAddress &host_address = + mem_swap_manager_->QueryKernelHostAddr(need_swap_kernel, mem_swap_info.output_idx_); + auto device_address = AnfAlgo::GetMutableOutputAddr(need_swap_kernel, mem_swap_info.output_idx_, false); if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address); @@ -309,9 +425,11 @@ bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel) { bool GPUKernelRuntime::UpdateMemorySwapInfo(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(mem_swap_manager_); - ClearKernelOutputAddress(graph); + ClearKernelOldOutputAndWorkspace(graph); if (!mem_swap_manager_->mem_swap_init()) { - mem_swap_manager_->Init(graph); + if (!mem_swap_manager_->Init(graph)) { + return false; + } } return mem_swap_manager_->RetreatSwapInfo(); } @@ -408,29 +526,6 @@ bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, return true; } -void *GPUKernelRuntime::AttemptMallocMem(size_t size) { - MS_EXCEPTION_IF_NULL(mem_manager_); - MS_EXCEPTION_IF_NULL(mem_swap_manager_); - auto device_ptr = mem_manager_->MallocMemFromMemPool(size); - if (!device_ptr) { - if (!mem_swap_manager_->trigger_swap()) { - return nullptr; - } - mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost); - while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) { - if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) { - device_address_swap_out->set_status(DeviceAddressStatus::kInHost); - mem_manager_->FreeMemFromMemPool(device_address_swap_out); - } - } - device_ptr = mem_manager_->MallocMemFromMemPool(size); - if (!device_ptr) { - return nullptr; - } - } - return device_ptr; -} - bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) { @@ -504,13 +599,13 @@ bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::K kernel_workspaces->emplace_back(nullptr); continue; } - auto device_ptr = AttemptMallocMem(workspace_sizes[i]); - if (!device_ptr) { + auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); + if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, workspace_sizes[i])) { return false; } kernel::AddressPtr workspace = std::make_shared(); MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_ptr; + workspace->addr = device_address->ptr_; workspace->size = workspace_sizes[i]; kernel_workspaces->emplace_back(workspace); } @@ -606,8 +701,7 @@ void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, boo } } -void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, - const AddressPtrList &kernel_workspaces) { +void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); MS_EXCEPTION_IF_NULL(mem_reuse_util_); @@ -652,12 +746,13 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } } // Free the workspace of kernel. - for (size_t i = 0; i < kernel_workspaces.size(); ++i) { - auto workspace = kernel_workspaces[i]; - if (workspace != nullptr) { - MS_EXCEPTION_IF_NULL(workspace->addr); - mem_manager_->FreeMemFromMemPool(workspace->addr); - workspace->addr = nullptr; + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { + auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel, i); + MS_EXCEPTION_IF_NULL(device_address); + if (device_address->ptr_) { + mem_manager_->FreeMemFromMemPool(device_address); } } } diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index e1ba3458661..2b041b86f54 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -53,11 +53,17 @@ class GPUKernelRuntime : public KernelRuntime { // The related functions and members for using dynamic memory pool. void InitKernelRefCount(const session::KernelGraph *graph); void InitKernelOutputAddress(const session::KernelGraph *graph); + void InitKernelWorkspaceAddress(const session::KernelGraph *graph); void InitMemorySwapInfo(const session::KernelGraph *graph); void ClearKernelOutputAddress(const session::KernelGraph *graph); - bool LaunchKernelDynamic(const session::KernelGraph *graph); + void ClearKernelWorkspaceAddress(const session::KernelGraph *graph); + void ClearKernelOldOutputAndWorkspace(const session::KernelGraph *graph); + bool SearchMemSwapScheme(const session::KernelGraph *graph); + bool RefineMemSwapScheme(const session::KernelGraph *graph); + bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false); + void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, + const AddressPtrList &workspace, const AddressPtrList &outputs); bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); - void *AttemptMallocMem(size_t size); bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); @@ -72,7 +78,7 @@ class GPUKernelRuntime : public KernelRuntime { void AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory, const DeviceAddressPtrList addr_list, size_t total_size, std::vector size_list); - void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces); + void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel); bool AddMemorySwapTask(const AnfNodePtr &kernel); bool UpdateMemorySwapInfo(const session::KernelGraph *graph); bool UpdateMemorySwapTask(const AnfNodePtr &kernel); @@ -81,6 +87,7 @@ class GPUKernelRuntime : public KernelRuntime { void ClearSwapQueue(); std::unordered_map mem_reuse_util_map_; std::unordered_map mem_swap_map_; + std::unordered_map is_first_step_map_; MemReuseUtilPtr mem_reuse_util_{nullptr}; MemSwapManagerPtr mem_swap_manager_{nullptr}; }; diff --git a/mindspore/ccsrc/runtime/device/kernel_info.cc b/mindspore/ccsrc/runtime/device/kernel_info.cc index 692532e70b3..a7a500ff95c 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.cc +++ b/mindspore/ccsrc/runtime/device/kernel_info.cc @@ -73,6 +73,14 @@ DeviceAddress *KernelInfo::GetWorkspaceAddr(size_t index) const { return workspace_address_list_[index].get(); } +DeviceAddressPtr KernelInfo::GetMutableWorkspaceAddr(size_t index) const { + if (index >= workspace_address_list_.size()) { + MS_LOG(ERROR) << "Index [" << index << "] out of range"; + return nullptr; + } + return workspace_address_list_[index]; +} + bool KernelInfo::SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index) { if (workspace_address_list_.empty()) { // parameter and valuenode diff --git a/mindspore/ccsrc/runtime/device/kernel_info.h b/mindspore/ccsrc/runtime/device/kernel_info.h index baded9d9a3a..e9d997cb5e7 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.h +++ b/mindspore/ccsrc/runtime/device/kernel_info.h @@ -54,6 +54,7 @@ class KernelInfo : public KernelInfoDevice { bool OutputAddrExist(size_t index) const; bool SetOutputAddr(const DeviceAddressPtr &output_address, size_t index); DeviceAddress *GetWorkspaceAddr(size_t index) const; + DeviceAddressPtr GetMutableWorkspaceAddr(size_t index) const; bool SetWorkspaceAddr(const DeviceAddressPtr &output_address, size_t index); void set_kernel_mod(const kernel::KernelModPtr &kernel_mod); kernel::KernelMod *MutableKernelMod() const;