From 6eb3cf1aee8ff30b39ac0169774384cab8f04655 Mon Sep 17 00:00:00 2001
From: tanghuikang
Date: Wed, 1 Dec 2021 16:10:11 +0800
Subject: [PATCH] Optimize swap strategy

---
 .../backend/session/anf_runtime_algorithm.cc  | 12 +++
 .../backend/session/anf_runtime_algorithm.h   |  2 +
 .../ccsrc/backend/session/kernel_graph.cc     |  8 +-
 .../ccsrc/backend/session/kernel_graph.h      |  5 +-
 .../ccsrc/runtime/device/kernel_runtime.cc    | 34 ++++++-
 .../runtime/device/memory_offload_strategy.cc | 96 +++++++++++++------
 .../runtime/device/memory_offload_strategy.h  | 10 +-
 .../ccsrc/runtime/device/memory_scheduler.cc  | 41 ++++----
 .../ccsrc/runtime/device/memory_scheduler.h   | 15 ++-
 9 files changed, 153 insertions(+), 70 deletions(-)

diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 68d3b383a2f..8a754b7a892 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -1398,6 +1398,18 @@ bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t labe
   return false;
 }
 
+bool AnfRuntimeAlgorithm::IsUpdateParameterKernel(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto node_name = GetCNodeName(node);
+  if (HasNodeAttr(kAttrAsync, node) && GetNodeAttr<bool>(node, kAttrAsync)) {
+    return false;
+  }
+  if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end() && node_name.find("Assign") == string::npos) {
+    return false;
+  }
+  return true;
+}
+
 void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
   MS_EXCEPTION_IF_NULL(node);
   auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
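
The predicate added above centralizes a rule that previously lived inline in KernelGraph::SetOptimizerFlag (see the kernel_graph.cc hunk below): a kernel counts as a parameter update only if it is not asynchronous and is either a member of kOptOperatorSet or carries "Assign" in its name. A minimal standalone sketch of the same decision, using a toy op set in place of MindSpore's kOptOperatorSet:

    #include <cstdio>
    #include <set>
    #include <string>

    // Toy stand-in for kOptOperatorSet; the real set lives in the MindSpore headers.
    static const std::set<std::string> kToyOptOps = {"ApplyMomentum", "Adam", "SGD"};

    // Same decision shape as IsUpdateParameterKernel: asynchronous kernels never
    // count as updates; everything else must be an optimizer op or Assign-like.
    bool IsUpdateParameterKernelSketch(const std::string &node_name, bool is_async) {
      if (is_async) {
        return false;
      }
      return kToyOptOps.count(node_name) != 0 || node_name.find("Assign") != std::string::npos;
    }

    int main() {
      std::printf("%d %d\n", IsUpdateParameterKernelSketch("AssignAdd", false),
                  IsUpdateParameterKernelSketch("Adam", true));  // prints "1 0"
      return 0;
    }
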
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
index d1c40edee3f..a7ccc96ac0d 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
@@ -246,6 +246,8 @@ class AnfRuntimeAlgorithm {
   static bool IsParameterWeight(const ParameterPtr &node);
   // Check whether the anf node includes the label_index.
   static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
+  // Check whether the cnode updates a parameter
+  static bool IsUpdateParameterKernel(const CNodePtr &node);
   // set stream id of kernel, which will be set in stream assign and be used in stream generate
   static void SetStreamId(uint32_t stream_id, AnfNode *node);
   // get stream id
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index 31b6c74fcfc..2ba9864eb19 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -1335,13 +1335,7 @@ void KernelGraph::SetOptimizerFlag() {
   has_optimizer_ = false;
   for (const auto &cnode : execution_order_) {
     MS_EXCEPTION_IF_NULL(cnode);
-    auto node_name = AnfAlgo::GetCNodeName(cnode);
-    if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
-      continue;
-    }
-    if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
-      has_optimizer_ = true;
-    } else if (node_name.find("Assign") == string::npos) {
+    if (!AnfAlgo::IsUpdateParameterKernel(cnode)) {
       continue;
     }
     for (auto &input : cnode->inputs()) {
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h
index 39d8dd22fee..d369a636eae 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.h
+++ b/mindspore/ccsrc/backend/session/kernel_graph.h
@@ -303,10 +303,7 @@ class KernelGraph : public FuncGraph {
   bool has_optimizer() const { return has_optimizer_; }
   bool IsUpdatedParameter(const ParameterPtr &param) const {
-    if (updated_parameters_.find(param) != updated_parameters_.end()) {
-      return true;
-    }
-    return false;
+    return updated_parameters_.find(param) != updated_parameters_.end();
   }
   // handle graph dependency
   void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index 607fbcc78d5..57bfe56e894 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -1312,6 +1312,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
   auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
+  const auto update_parameter = AnfAlgo::IsUpdateParameterKernel(cnode);
   for (size_t j = 0; j < input_num; ++j) {
     auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
     auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
@@ -1323,6 +1324,14 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
       GetOrMallocAddress(mem_scheduler, device_address, input);
       input->size = device_address->size_;
       kernel_launch_info->inputs_.emplace_back(input);
+      if (update_parameter && input_node->isa<Parameter>()) {
+        auto param = input_node->cast<ParameterPtr>();
+        auto abstract = param->abstract();
+        MS_EXCEPTION_IF_NULL(abstract);
+        if (abstract->isa<abstract::AbstractRef>()) {
+          mem_scheduler->UpdateHighPriorityMem(device_address);
+        }
+      }
     }
 
     for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
@@ -1398,6 +1407,7 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
   if (input_tensors.size() != input_nodes.size()) {
     MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
   }
+  mem_scheduler->ClearMemNeedInit();
   for (size_t i = 0; i < input_tensors.size(); ++i) {
     auto input_node = input_nodes[i];
     if (!input_node->isa<Parameter>() ||
         !AnfAlgo::OutputAddrExist(input_node, 0)) {
@@ -1406,16 +1416,30 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
     auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
     auto tensor = input_tensors[i];
     MS_EXCEPTION_IF_NULL(tensor);
-    auto tensor_address = tensor->device_address();
-    if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) {
+    auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+    const auto tensor_size = LongToSize(tensor->data().nbytes());
+    if (tensor_address == device_address) {
+      if (tensor->NeedSyncHostToDevice()) {
+        tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(),
+                                         tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_);
+        tensor->set_sync_status(kNoNeedSync);
+      }
+      if (mem_scheduler->HasDeviceMem(tensor_address.get())) {
+        tensor_address->set_ptr(nullptr);
+      }
+      continue;
+    }
+    if (tensor->NeedSyncHostToDevice()) {
+      mem_scheduler->AddMemNeedInit(device_address.get());
+    } else if (tensor_address != nullptr) {
       tensor->data_sync(false);
+      mem_scheduler->AddMemNeedInit(device_address.get());
     }
     MemPriority priority = kMemPriorityLow;
-    if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) &&
-        graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) {
+    const auto &parameter = input_node->cast<ParameterPtr>();
+    if (AnfAlgo::IsParameterWeight(parameter) || graph.IsUpdatedParameter(parameter)) {
       priority = kMemPriorityHigh;
    }
-    auto tensor_size = LongToSize(tensor->data().nbytes());
     mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
     tensor->set_sync_status(kNoNeedSync);
   }
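
The InitGraphInputTensors rewrite distinguishes three cases per graph input instead of unconditionally re-initializing: a tensor already bound to the graph's device address is synced in place when dirty (and its cached pointer dropped when the scheduler owns the block); a tensor with a stale host copy, or one bound to some other address, gets its device address flagged via AddMemNeedInit so PreCompute refreshes the already-resident block; everything else needs only Init. Note also that the priority rule loosens from IsParameterWeight && IsUpdatedParameter to ||, so every weight is now pinned high priority. A condensed sketch of the case split (names are illustrative, not the runtime's types):

    #include <cstdio>

    // Which path the rewritten loop takes for one graph input (illustrative only).
    enum class InitAction {
      kSyncInPlace,   // tensor already bound to this device address: refresh if dirty
      kMarkNeedInit,  // stale host copy or bound elsewhere: AddMemNeedInit + Init
      kPlainInit      // clean, unbound tensor: Init alone is enough
    };

    InitAction ChooseAction(bool same_device_address, bool host_dirty, bool bound_elsewhere) {
      if (same_device_address) {
        return InitAction::kSyncInPlace;
      }
      if (host_dirty || bound_elsewhere) {
        return InitAction::kMarkNeedInit;
      }
      return InitAction::kPlainInit;
    }

    int main() {
      // A weight edited on host between launches, currently bound to another address:
      std::printf("%d\n", static_cast<int>(ChooseAction(false, true, true)));  // 1
      return 0;
    }
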
diff --git a/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc b/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc
index 5f49ac0ed4f..f2b8100f7aa 100644
--- a/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc
+++ b/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc
@@ -22,6 +22,9 @@
 
 namespace mindspore {
 namespace device {
+constexpr size_t kFirstGetMemEventIndex = 1;
+constexpr size_t kInitOrMallocMemEventIndex = 0;
+
 std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
   if (pre_compute_events_.size() <= step) {
     MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
   }
@@ -62,7 +65,7 @@ void MemOffloadStrategy::CountMemUsage() {
     if (mem_events.empty()) {
       continue;
     }
-    auto first_event = mem_events[0];
+    auto first_event = mem_events[kInitOrMallocMemEventIndex];
     const bool is_high_priority = IsHighPriorityMem(first_event->key);
     if (is_high_priority) {
       high_priority_mem_size += first_event->mem_size;
@@ -83,6 +86,10 @@ void MemOffloadStrategy::CountMemUsage() {
   }
   min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
   mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
+  if (mem_size_ < min_mem_needed_) {
+    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
+                      << min_mem_needed_;
+  }
 }
 
 bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
@@ -94,11 +101,6 @@ bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
 }
 
 void MemOffloadStrategy::CheckMemSize() {
-  if (mem_size_ < min_mem_needed_) {
-    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
-                      << min_mem_needed_;
-  }
-
   if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
     need_swap_ = true;
   }
@@ -116,19 +118,20 @@ void MemOffloadStrategy::GenEventSpan() {
     if (tensor_events.size() <= 1) {
       continue;
     }
-    const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key);
-    for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) {
-      auto &event = tensor_events[event_index];
+    const bool is_high_priority = IsHighPriorityMem(tensor_events[kInitOrMallocMemEventIndex]->key);
+    for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) {
+      auto &event = tensor_events[i];
       MS_EXCEPTION_IF_NULL(event);
       if (event->type != kGet) {
         MS_LOG(EXCEPTION) << "Event should be Get except first event.";
       }
-      size_t span = 0;
-      if (event_index == 1 && is_high_priority) {
-        const auto &last_event = tensor_events[tensor_events.size() - 1];
-        span = event->index + total_step_ - last_event->index;
-      } else {
-        span = event->index - tensor_events[event_index - 1]->index;
+      auto latest_event = tensor_events[i - 1];
+      if (i == kFirstGetMemEventIndex && is_high_priority) {
+        latest_event = tensor_events[tensor_events.size() - 1];
+      }
+      auto span = GetSpanBetweenMemEvents(latest_event->index, event->index);
+      if (is_high_priority && span == 0 && latest_event == event) {
+        span = total_step_;
       }
       if (span > 1) {
         const size_t span_mul_size = (span - 1) * event->mem_size;
@@ -156,7 +159,7 @@ void MemOffloadStrategy::GenSwapEventSet() {
   for (const auto &iter : event_span_) {
     auto span = iter.second.second;
     auto &event = iter.second.first;
-    auto start_index = ((total_step_ + event->index - span) % total_step_) + 1;
+    auto start_index = ((event->index + total_step_ - span + 1) % total_step_);
     bool revert = false;
     size_t cur_index = start_index;
     while (cur_index != event->index) {
@@ -196,12 +199,12 @@ void MemOffloadStrategy::GenComputeMemEvents() {
     }
 
     const bool is_high_priority = IsHighPriorityMem(item.first);
-    auto first_event = mem_events[0];
+    auto first_event = mem_events[kInitOrMallocMemEventIndex];
     MS_EXCEPTION_IF_NULL(first_event);
-    const auto &second_event = mem_events[1];
-    MS_EXCEPTION_IF_NULL(second_event);
-    if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) {
-      first_event->index = second_event->index;
+    const auto &first_get_event = mem_events[kFirstGetMemEventIndex];
+    MS_EXCEPTION_IF_NULL(first_get_event);
+    if (is_high_priority && swap_events_.find(first_get_event) != swap_events_.end()) {
+      first_event->index = first_get_event->index;
     }
     if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) {
       pre_compute_events_[first_event->index].emplace_back(first_event);
@@ -211,16 +214,21 @@ void MemOffloadStrategy::GenComputeMemEvents() {
 
     const auto &last_event = mem_events[mem_events.size() - 1];
     size_t pre_index = is_high_priority ? last_event->index : first_event->index;
-    for (size_t i = 1; i < mem_events.size(); ++i) {
+    const auto &swap_out_event_index = GetSwapOutEventIndex(item.first, mem_events);
+    for (size_t i = kFirstGetMemEventIndex; i < mem_events.size(); ++i) {
       auto &event = mem_events[i];
       MS_EXCEPTION_IF_NULL(event);
       if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
-        auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
-        swap_out_event->key = item.first;
-        swap_out_event->mem_size = first_event->mem_size;
-        post_compute_events_[pre_index].emplace_back(swap_out_event);
+        MemEventType event_type = kSwapOut;
+        if (is_high_priority && swap_out_event_index.count(i) == 0) {
+          event_type = kFree;
+        }
+        auto free_or_swap_out_event = std::make_shared<MemEvent>(event_type, pre_index);
+        free_or_swap_out_event->key = item.first;
+        free_or_swap_out_event->mem_size = first_event->mem_size;
+        post_compute_events_[pre_index].emplace_back(free_or_swap_out_event);
         // avoid swap-in-event follow init-event
-        if (first_event->type != kInit || i != 1) {
+        if (i != kFirstGetMemEventIndex || first_event->type != kInit) {
           auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
           swap_in_event->key = item.first;
           swap_in_event->mem_size = first_event->mem_size;
@@ -246,5 +254,39 @@ void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_even
     (void)post_compute_events_[last_event->index].emplace_back(free_event);
   }
 }
+
+std::set<size_t> MemOffloadStrategy::GetSwapOutEventIndex(const void *key,
+                                                          const std::vector<std::shared_ptr<MemEvent>> &mem_events) {
+  const auto &update_step_iter = high_priority_updated_step_.find(key);
+  if (update_step_iter == high_priority_updated_step_.end() || update_step_iter->second.empty()) {
+    return std::set<size_t>();
+  }
+  const auto &update_steps = update_step_iter->second;
+  size_t update_steps_index = 0;
+  std::set<size_t> swap_out_event_index;
+  size_t min_swap_index_before_update = SIZE_MAX;
+  size_t max_swap_out_step = 0;
+  for (size_t i = 0; i < mem_events.size(); ++i) {
+    const auto &mem_event = mem_events[i];
+    if (swap_events_.count(mem_event) == 0) {
+      continue;
+    }
+    if (mem_event->index <= update_steps[update_steps_index]) {
+      if (i <= min_swap_index_before_update) {
+        min_swap_index_before_update = i;
+      }
+    } else {
+      swap_out_event_index.insert(i);
+      max_swap_out_step = mem_event->index;
+      while (update_steps_index < update_steps.size() && update_steps[update_steps_index] < mem_event->index) {
+        ++update_steps_index;
+      }
+    }
+  }
+  if (max_swap_out_step <= update_steps[update_steps.size() - 1]) {
+    swap_out_event_index.insert(min_swap_index_before_update);
+  }
+  return swap_out_event_index;
+}
 }  // namespace device
 }  // namespace mindspore
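
GetSwapOutEventIndex is the heart of the optimization: for a high-priority (weight) tensor, only swap events that follow an optimizer update since the previous write-back need a real kSwapOut; GenComputeMemEvents downgrades the rest to kFree because the host copy is still current. The final check handles wrap-around: if an update lands after the last swap-out of the step cycle, the next iteration's first swap must write back the fresh value. A toy restatement over plain step lists (not MemEvent objects), with bounds guards added for clarity:

    #include <cstdint>
    #include <cstdio>
    #include <set>
    #include <vector>

    // Returns the positions (indices into swap_steps) that must stay kSwapOut;
    // all other swaps of this weight can become plain kFree events.
    std::set<size_t> PickRealSwapOuts(const std::vector<size_t> &swap_steps,
                                      const std::vector<size_t> &update_steps) {
      std::set<size_t> keep;
      if (update_steps.empty()) {
        return keep;  // never updated on device: host copy is always current
      }
      size_t u = 0;                                // next update step to pass
      size_t first_swap_before_update = SIZE_MAX;  // candidate for the wrap-around rule
      size_t last_swap_out_step = 0;
      for (size_t i = 0; i < swap_steps.size(); ++i) {
        if (u < update_steps.size() && swap_steps[i] > update_steps[u]) {
          keep.insert(i);  // an update happened since the previous swap-out
          last_swap_out_step = swap_steps[i];
          while (u < update_steps.size() && update_steps[u] < swap_steps[i]) {
            ++u;
          }
        } else if (i < first_swap_before_update) {
          first_swap_before_update = i;
        }
      }
      // Update after the last swap-out: next iteration's first swap carries the
      // fresh value, so it must also stay a real swap-out.
      if (last_swap_out_step <= update_steps.back() && first_swap_before_update != SIZE_MAX) {
        keep.insert(first_swap_before_update);
      }
      return keep;
    }

    int main() {
      // Swaps scheduled after steps 5 and 9; the optimizer updates the weight at 7:
      for (size_t i : PickRealSwapOuts({5, 9}, {7})) {
        std::printf("keep swap #%zu\n", i);  // prints "keep swap #1": step 5 -> kFree
      }
      return 0;
    }
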
diff --git a/mindspore/ccsrc/runtime/device/memory_offload_strategy.h b/mindspore/ccsrc/runtime/device/memory_offload_strategy.h
index 981ce34b43f..a0cf726905a 100644
--- a/mindspore/ccsrc/runtime/device/memory_offload_strategy.h
+++ b/mindspore/ccsrc/runtime/device/memory_offload_strategy.h
@@ -41,10 +41,12 @@ class MemOffloadStrategy {
  public:
   MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
                      const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
-                     const std::set<const void *> &manual_offload_keys, size_t total_step)
+                     const std::set<const void *> &manual_offload_keys,
+                     const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step)
       : mem_priority_(mem_priority),
         mem_events_(mem_events),
         manual_offload_keys_(manual_offload_keys),
+        high_priority_updated_step_(high_priority_updated_step),
         total_step_(total_step) {}
   virtual ~MemOffloadStrategy() = default;
@@ -75,10 +77,16 @@ class MemOffloadStrategy {
   void GenComputeMemEvents();
   void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
+  std::set<size_t> GetSwapOutEventIndex(const void *key, const std::vector<std::shared_ptr<MemEvent>> &mem_events);
+
+  size_t GetSpanBetweenMemEvents(size_t pre_step, size_t post_step) const {
+    return (post_step + total_step_ - pre_step) % total_step_;
+  }
 
   const std::map<const void *, MemPriority> &mem_priority_;
   const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
   const std::set<const void *> &manual_offload_keys_;
+  std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
   const size_t total_step_;
   std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
   std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
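
GetSpanBetweenMemEvents makes the span arithmetic circular, which is what lets GenEventSpan measure a high-priority tensor's idle window across the step boundary (from its last use in one iteration to its first use in the next). Mirrored as a free function so the numbers can be checked in isolation (total_step_ becomes a parameter):

    #include <cstddef>
    #include <cstdio>

    // Circular distance from pre_step to post_step in a graph that executes
    // total_step compute steps per iteration.
    size_t GetSpanBetweenMemEvents(size_t pre_step, size_t post_step, size_t total_step) {
      return (post_step + total_step - pre_step) % total_step;
    }

    int main() {
      // Weight last used at step 8, first used again at step 2 of the next
      // iteration of a 10-step graph: the memory sits idle for 4 steps.
      std::printf("%zu\n", GetSpanBetweenMemEvents(8, 2, 10));  // 4
      // Equal steps collapse to 0; GenEventSpan promotes that to total_step
      // when the two events are actually the same single use.
      std::printf("%zu\n", GetSpanBetweenMemEvents(5, 5, 10));  // 0
      return 0;
    }
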
diff --git a/mindspore/ccsrc/runtime/device/memory_scheduler.cc b/mindspore/ccsrc/runtime/device/memory_scheduler.cc
index 846418e37e9..97e0a6d72b6 100644
--- a/mindspore/ccsrc/runtime/device/memory_scheduler.cc
+++ b/mindspore/ccsrc/runtime/device/memory_scheduler.cc
@@ -45,10 +45,10 @@ void MemScheduler::Clear() {
   if (mem_handler_ == nullptr) {
     return;
   }
-  for (auto &item : high_priority_device_ptr_) {
+  for (auto &item : mem_result_) {
     mem_handler_->FreeDevice(item.second);
   }
-  high_priority_device_ptr_.clear();
+  mem_result_.clear();
 }
 
 void MemScheduler::ClearAllocatedMem() {
@@ -57,12 +57,11 @@ void MemScheduler::ClearAllocatedMem() {
   }
   for (auto &item : mem_result_) {
     const auto device_ptr = item.second;
-    if (device_ptr == nullptr) {
+    if (device_ptr != nullptr) {
       mem_handler_->FreeDevice(device_ptr);
     }
   }
   mem_result_.clear();
-  high_priority_device_ptr_.clear();
   for (const auto &item : swap_host_ptr_) {
     const auto host_ptr = item.second;
     if (host_ptr != nullptr) {
@@ -125,22 +124,19 @@ bool MemScheduler::PreCompute(void *stream) {
     MS_EXCEPTION_IF_NULL(event);
     MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
     if (event->type == kInit || event->type == kMalloc) {
-      auto priority = mem_priority_[event->key];
-      auto iter = high_priority_device_ptr_.find(event->key);
-      if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
-        MS_EXCEPTION_IF_NULL(iter->second);
-        mem_result_[event->key] = iter->second;
-        continue;
-      }
-      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
-      if (device_ptr == nullptr) {
-        return false;
-      }
-      if (priority != kMemPriorityLow) {
-        high_priority_device_ptr_[event->key] = device_ptr;
+      const auto &iter = mem_result_.find(event->key);
+      const bool new_malloc = iter == mem_result_.end();
+      void *device_ptr;
+      if (new_malloc) {
+        device_ptr = mem_handler_->MallocDevice(event->mem_size);
+        if (device_ptr == nullptr) {
+          return false;
+        }
+      } else {
+        device_ptr = iter->second;
       }
-      if (event->type == kInit) {
+      if (event->type == kInit && (new_malloc || high_priority_mem_need_init_.count(event->key) != 0)) {
         auto host_ptr = init_host_ptr_[event->key];
         MS_EXCEPTION_IF_NULL(host_ptr);
         mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
@@ -160,9 +156,6 @@ bool MemScheduler::PreCompute(void *stream) {
       MS_EXCEPTION_IF_NULL(host_ptr);
       mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
       mem_result_[event->key] = device_ptr;
-      if (mem_priority_[event->key] == kMemPriorityHigh) {
-        high_priority_device_ptr_[event->key] = device_ptr;
-      }
       if (!from_init) {
         mem_handler_->FreeHost(host_ptr);
         (void)swap_host_ptr_.erase(event->key);
@@ -211,9 +204,6 @@ bool MemScheduler::PostCompute(void *stream) {
       mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
      mem_handler_->FreeDevice(device_ptr);
       (void)mem_result_.erase(event->key);
-      if (mem_priority_[event->key] == kMemPriorityHigh) {
-        high_priority_device_ptr_.erase(event->key);
-      }
     }
   }
   ++current_step_;
@@ -225,7 +215,8 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
   MS_EXCEPTION_IF_NULL(mem_handler_);
 
   if (strategy_ == nullptr) {
-    strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_);
+    strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
+                                                     high_priority_updated_step_, total_step_);
     if (manual_offload_keys_.empty()) {
       compute_time_.resize(total_step_);
     } else {
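
With high_priority_device_ptr_ gone, mem_result_ is the single owner of resident device blocks: Clear and ClearAllocatedMem free it wholesale, PreCompute reuses an existing entry instead of consulting a side table, and a kInit event re-copies host data only when the block is newly allocated or explicitly flagged need-init. For orientation, the scheduler is driven roughly like this each iteration (only the MemScheduler calls are real API; the stream and LaunchStep helper are assumed from the caller):

    #include <cstddef>

    void LaunchStep(size_t step);  // hypothetical: launches the step's kernels

    // Paraphrased per-iteration driver around one MemScheduler instance.
    bool RunOneIteration(MemScheduler *scheduler, void *stream, size_t total_step) {
      for (size_t step = 0; step < total_step; ++step) {
        if (!scheduler->PreCompute(stream)) {   // swap-ins + (re)initialization
          return false;                         // allocation failed: caller re-plans
        }
        LaunchStep(step);
        if (!scheduler->PostCompute(stream)) {  // swap-outs / frees, advances the step
          return false;
        }
      }
      return true;
    }
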
diff --git a/mindspore/ccsrc/runtime/device/memory_scheduler.h b/mindspore/ccsrc/runtime/device/memory_scheduler.h
index d322775abac..e455f3021b4 100644
--- a/mindspore/ccsrc/runtime/device/memory_scheduler.h
+++ b/mindspore/ccsrc/runtime/device/memory_scheduler.h
@@ -53,6 +53,14 @@ class MemScheduler {
 
   void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
 
+  bool HasDeviceMem(const void *key) const { return mem_result_.find(key) != mem_result_.end(); }
+
+  void UpdateHighPriorityMem(const void *key) {
+    if (need_record_event_) {
+      high_priority_updated_step_[key].emplace_back(current_step_);
+    }
+  }
+
   void SetTotalStep(size_t step) {
     total_step_ = step;
     step_events_.resize(total_step_);
@@ -72,6 +80,10 @@ class MemScheduler {
 
   void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
 
+  void AddMemNeedInit(const void *key) { high_priority_mem_need_init_.insert(key); }
+
+  void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }
+
  private:
   void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
@@ -86,7 +98,8 @@ class MemScheduler {
   std::map<const void *, void *> mem_result_;
   std::map<const void *, void *> init_host_ptr_;
   std::map<const void *, void *> swap_host_ptr_;
-  std::map<const void *, void *> high_priority_device_ptr_;
+  std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
+  std::set<const void *> high_priority_mem_need_init_;
   size_t total_step_{0};
   size_t current_step_{0};
   bool need_record_event_{true};
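
The three small hooks above close the loop with the runtime changes earlier in the patch: UpdateHighPriorityMem records, during the event-recording pass, the steps at which optimizer kernels touch a weight (consumed by GetSwapOutEventIndex), while AddMemNeedInit and ClearMemNeedInit let InitGraphInputTensors tell PreCompute which resident high-priority blocks must be refreshed from host. A sketch of that per-launch handshake, assuming the Init signature as called in kernel_runtime.cc above (GraphInput and host_copy_changed are illustrative, not MindSpore types):

    #include <cstddef>
    #include <vector>

    struct GraphInput {  // illustrative aggregate, not a MindSpore type
      void *device_address;
      void *host_ptr;
      size_t size;
      MemPriority priority;
      bool host_copy_changed;
    };

    // Paraphrase of the handshake between InitGraphInputTensors and PreCompute.
    void InitInputs(MemScheduler *scheduler, const std::vector<GraphInput> &inputs) {
      scheduler->ClearMemNeedInit();  // drop the previous launch's flags
      for (const auto &input : inputs) {
        if (input.host_copy_changed) {
          scheduler->AddMemNeedInit(input.device_address);  // force host->device copy
        }
        scheduler->Init(input.device_address, input.host_ptr, input.size, input.priority);
      }
      // PreCompute then refreshes a kInit block only if it was newly allocated or
      // flagged above; untouched high-priority blocks stay resident as-is.
    }
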