From 16ca5375054bd8c925d155637ba38c43b4a74438 Mon Sep 17 00:00:00 2001 From: tanghuikang Date: Mon, 22 Nov 2021 16:20:41 +0800 Subject: [PATCH] Adjust swap strategy --- .../ccsrc/runtime/device/kernel_runtime.cc | 67 +++++--- .../ccsrc/runtime/device/kernel_runtime.h | 5 +- .../runtime/device/memory_offload_strategy.cc | 159 ++++++++++-------- .../runtime/device/memory_offload_strategy.h | 9 +- .../ccsrc/runtime/device/memory_scheduler.cc | 57 +++++-- .../ccsrc/runtime/device/memory_scheduler.h | 4 +- 6 files changed, 191 insertions(+), 110 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 69de9d34d05..c3615c45fbb 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -1294,9 +1294,6 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr &mem_ kernel_addr->addr = device_address->ptr_; } else { kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_); - if (mem_scheduler->IsHighPriorityMem(device_address)) { - device_address->ptr_ = kernel_addr->addr; - } } } @@ -1343,37 +1340,29 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr &mem } void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr &mem_scheduler, - const session::KernelGraph &graph, const AnfNodePtr &kernel, bool mock) { + const session::KernelGraph &graph, const AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(mem_scheduler); MS_EXCEPTION_IF_NULL(kernel); auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); for (size_t input_idx = 0; input_idx < kernel_mod->GetInputSizeList().size(); ++input_idx) { const auto input_node_index = AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true); - if (input_node_index.first == nullptr || !input_node_index.first->isa()) { - continue; + if (input_node_index.first != nullptr && input_node_index.first->isa()) { + 
SyncNodeOutputTensor(mem_scheduler, input_node_index, graph); } - SyncNodeOutputTensor(mem_scheduler, input_node_index, graph, mock); } for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) { - SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph, mock); + SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph); } } void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr &mem_scheduler, - const KernelWithIndex &node_output_index, const session::KernelGraph &graph, - bool mock) { + const KernelWithIndex &node_output_index, const session::KernelGraph &graph) { MS_EXCEPTION_IF_NULL(mem_scheduler); if (node_output_index.first == nullptr) { return; } auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true); - if (mock) { - if (graph.IsInternalOutput(node_output_index.first, node_output_index.second) && device_address != nullptr) { - mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh); - } - return; - } auto tensor = graph.GetNodeOutputTensor(node_output_index); if (tensor == nullptr) { return; @@ -1407,22 +1396,20 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr &m MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size(); } for (size_t i = 0; i < input_tensors.size(); ++i) { - auto tensor = input_tensors[i]; - MS_EXCEPTION_IF_NULL(tensor); auto input_node = input_nodes[i]; if (!input_node->isa() || !AnfAlgo::OutputAddrExist(input_node, 0)) { continue; } auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); + auto tensor = input_tensors[i]; MS_EXCEPTION_IF_NULL(tensor); - MemPriority priority = kMemPriorityLow; auto tensor_address = tensor->device_address(); if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) { tensor->data_sync(false); } - if (AnfAlgo::IsParameterWeight(input_node->cast()) || + 
MemPriority priority = kMemPriorityLow; + if (AnfAlgo::IsParameterWeight(input_node->cast()) && graph.IsUpdatedParameter(input_node->cast())) { - tensor->set_device_address(device_address); priority = kMemPriorityHigh; } auto tensor_size = LongToSize(tensor->data().nbytes()); @@ -1477,7 +1464,9 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod } } if (mem_scheduler != nullptr) { - SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock); + if (!mock) { + SyncNodeOutputTensors(mem_scheduler, graph, kernel); + } ret = mem_scheduler->PostCompute(stream); if (!ret) { return ret; @@ -1553,9 +1542,43 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock } LaunchKernelEvent(kernel_post_run_events, kernels[i]); } + if (UseMemScheduler() && !mock) { + SyncUpdatedParameter(graph, mem_scheduler); + } return true; } +void KernelRuntime::SyncUpdatedParameter(const session::KernelGraph &graph, + const std::shared_ptr &mem_scheduler) { + MS_EXCEPTION_IF_NULL(mem_scheduler); + auto &input_nodes = graph.input_nodes(); + auto &input_tensors = graph.input_tensors(); + if (input_tensors.size() != input_nodes.size()) { + MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size(); + } + + for (size_t i = 0; i < input_tensors.size(); ++i) { + auto input_node = input_nodes[i]; + if (!input_node->isa() || !AnfAlgo::OutputAddrExist(input_node, 0)) { + continue; + } + auto parameter = input_node->cast(); + MS_EXCEPTION_IF_NULL(parameter); + if (!graph.IsUpdatedParameter(parameter)) { + continue; + } + auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); + auto tensor = input_tensors[i]; + MS_EXCEPTION_IF_NULL(tensor); + auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh); + if (device_ptr != nullptr) { + device_address->set_ptr(device_ptr); + tensor->set_device_address(device_address); + 
tensor->set_sync_status(kNeedSyncDeviceToHost); + } + } +} + void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 957ef05cd5d..9e5130b6da3 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -95,6 +95,7 @@ class KernelRuntime { void set_device_id(uint32_t device_id) { device_id_ = device_id; } uint32_t device_id() { return device_id_; } static bool UseMemScheduler(); + void SyncUpdatedParameter(const session::KernelGraph &graph, const std::shared_ptr &mem_scheduler); #ifdef ENABLE_DEBUGGER // set debugger @@ -155,9 +156,9 @@ class KernelRuntime { const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr); void InitGraphInputTensors(const std::shared_ptr &mem_scheduler, const session::KernelGraph &graph); void SyncNodeOutputTensors(const std::shared_ptr &mem_scheduler, const session::KernelGraph &graph, - const AnfNodePtr &kernel, bool mock); + const AnfNodePtr &kernel); void SyncNodeOutputTensor(const std::shared_ptr &mem_scheduler, const KernelWithIndex &output, - const session::KernelGraph &graph, bool mock); + const session::KernelGraph &graph); void AssignCommunicationMem(const session::KernelGraph &graph); bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false); diff --git a/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc b/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc index de0079da4d9..0f0610a7360 100644 --- a/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc +++ b/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc @@ -43,7 +43,7 @@ void MemOffloadStrategy::Execute() { CheckMemSize(); if (need_swap_) { GenEventSpan(); - GenNoSwapEventSet(); + GenSwapEventSet(); } GenComputeMemEvents(); } @@ -57,37 
+57,41 @@ void MemOffloadStrategy::CountMemUsage() { } min_mem_used_.resize(total_step_, 0); std::vector total_mem_used(total_step_, 0); + size_t high_priority_mem_size = 0; for (auto &item : mem_events_) { auto &mem_events = item.second; if (mem_events.empty()) { continue; } auto first_event = mem_events[0]; - size_t cur_index = 0; - if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) { - first_event = mem_events[1]; - cur_index = 1; - } - auto last_event = mem_events[mem_events.size() - 1]; - for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) { - if (start_index < total_step_) { + const bool is_high_priority = IsHighPriorityMem(first_event->key); + if (is_high_priority) { + high_priority_mem_size += first_event->mem_size; + } else { + auto last_event = mem_events[mem_events.size() - 1]; + for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) { total_mem_used[start_index] += first_event->mem_size; - } else { - MS_LOG(ERROR) << "Error mem event index " << start_index; } } - for (; cur_index < mem_events.size(); ++cur_index) { - auto &event = mem_events[cur_index]; + // Calculate the minimum memory size for kernel execution. 
+ for (const auto &event : mem_events) { MS_EXCEPTION_IF_NULL(event); - if (event->index < total_step_) { - min_mem_used_[event->index] += first_event->mem_size; - } else { - MS_LOG(ERROR) << "Error mem event index " << event->index; + if (event->type != kGet) { + continue; } + min_mem_used_[event->index] += first_event->mem_size; } } min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end())); - mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())); + mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size; +} + +bool MemOffloadStrategy::IsHighPriorityMem(const void *key) { + auto iter = mem_priority_.find(key); + if (iter != mem_priority_.end()) { + return iter->second == kMemPriorityHigh; + } + return false; } void MemOffloadStrategy::CheckMemSize() { @@ -110,48 +114,60 @@ void MemOffloadStrategy::GenEventSpan() { } for (auto &item : mem_events_) { auto &tensor_events = item.second; - if (tensor_events.empty()) { + if (tensor_events.size() <= 1) { continue; } - auto first_event = tensor_events[0]; - size_t cur_index = 0; - if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) { - first_event = tensor_events[1]; - cur_index = 1; - } - size_t last_index = first_event->index; - for (; cur_index < tensor_events.size(); ++cur_index) { - auto &event = tensor_events[cur_index]; + const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key); + for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) { + auto &event = tensor_events[event_index]; MS_EXCEPTION_IF_NULL(event); - auto span = event->index - last_index; - if (span > 1) { - (void)event_span_.emplace(span, event); + if (event->type != kGet) { + MS_LOG(EXCEPTION) << "Event should be Get except first event."; + } + size_t span = 0; + if (event_index == 1 && is_high_priority) { + const auto &last_event = &#13;
tensor_events[tensor_events.size() - 1]; + span = event->index + total_step_ - last_event->index; + } else { + span = event->index - tensor_events[event_index - 1]->index; + } + if (span > 1) { + const size_t span_mul_size = (span - 1) * event->mem_size; + (void)event_span_.emplace(std::make_pair(span_mul_size, std::make_pair(event, span))); } - last_index = event->index; } } } -void MemOffloadStrategy::GenNoSwapEventSet() { - no_swap_events_.clear(); +void MemOffloadStrategy::GenSwapEventSet() { + swap_events_.clear(); std::vector cur_mem_used(min_mem_used_.begin(), min_mem_used_.end()); - for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) { - auto span = iter->first; - auto &event = iter->second; - auto start_index = event->index - span + 1; + for (const auto &iter : event_span_) { + auto span = iter.second.second; + auto &event = iter.second.first; + auto start_index = ((total_step_ + event->index - span) % total_step_) + 1; bool revert = false; - for (size_t i = start_index; i < event->index; ++i) { - cur_mem_used[i] += event->mem_size; - if (cur_mem_used[i] > mem_size_) { + size_t cur_index = start_index; + while (cur_index != event->index) { + cur_mem_used[cur_index] += event->mem_size; + if (cur_mem_used[cur_index] > mem_size_) { revert = true; } + cur_index += 1; + if (cur_index >= total_step_) { + cur_index = 0; + } } if (revert) { - for (size_t i = start_index; i < event->index; ++i) { - cur_mem_used[i] -= event->mem_size; + cur_index = start_index; + while (cur_index != event->index) { + cur_mem_used[cur_index] -= event->mem_size; + cur_index += 1; + if (cur_index >= total_step_) { + cur_index = 0; + } } - } else { - (void)no_swap_events_.emplace(event); + (void)swap_events_.emplace(event); } } } @@ -166,34 +182,31 @@ void MemOffloadStrategy::GenComputeMemEvents() { if (mem_events.empty()) { continue; } + // No need to generate events for memory that has only one event, which means it is never used by any kernel. 
+ if (mem_events.size() <= 1) { + continue; + } + + const bool is_high_priority = IsHighPriorityMem(item.first); auto first_event = mem_events[0]; MS_EXCEPTION_IF_NULL(first_event); - if (first_event->type == kInit) { - if (mem_events.size() > 1) { - auto &second_event = mem_events[1]; - MS_EXCEPTION_IF_NULL(second_event); - first_event->index = second_event->index; - } else { - continue; - } + const auto &second_event = mem_events[1]; + MS_EXCEPTION_IF_NULL(second_event); + if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) { + first_event->index = second_event->index; } - if ((first_event->type == kInit || first_event->type == kMalloc) && - first_event->index < pre_compute_events_.size()) { + if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) { pre_compute_events_[first_event->index].emplace_back(first_event); } else { MS_LOG_EXCEPTION << "First event should be init or malloc!"; } - MemPriority priority = kMemPriorityLow; - auto iter = mem_priority_.find(first_event->key); - if (iter != mem_priority_.end()) { - priority = iter->second; - } - size_t pre_index = first_event->index; + + const auto &last_event = mem_events[mem_events.size() - 1]; + size_t pre_index = is_high_priority ? 
last_event->index : first_event->index; for (size_t i = 1; i < mem_events.size(); ++i) { auto &event = mem_events[i]; MS_EXCEPTION_IF_NULL(event); - if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow && - no_swap_events_.find(event) == no_swap_events_.end()) { + if (need_swap_ && swap_events_.find(event) != swap_events_.end()) { auto swap_out_event = std::make_shared(kSwapOut, pre_index); swap_out_event->key = item.first; swap_out_event->mem_size = first_event->mem_size; @@ -208,17 +221,19 @@ void MemOffloadStrategy::GenComputeMemEvents() { } pre_index = event->index; } - if (priority != kMemPriorityLow) { - continue; - } - auto &last_event = mem_events[mem_events.size() - 1]; - MS_EXCEPTION_IF_NULL(last_event); - auto free_event = std::make_shared(kFree, last_event->index); - free_event->key = item.first; - if (last_event->index < post_compute_events_.size()) { - (void)post_compute_events_[last_event->index].emplace_back(free_event); + if (!is_high_priority) { + GenFreeEvent(last_event); } } } + +void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr &last_event) { + MS_EXCEPTION_IF_NULL(last_event); + auto free_event = std::make_shared(kFree, last_event->index); + free_event->key = last_event->key; + if (last_event->index < post_compute_events_.size()) { + (void)post_compute_events_[last_event->index].emplace_back(free_event); + } +} } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/memory_offload_strategy.h b/mindspore/ccsrc/runtime/device/memory_offload_strategy.h index 3f8f2ef7c6a..306d422edaf 100644 --- a/mindspore/ccsrc/runtime/device/memory_offload_strategy.h +++ b/mindspore/ccsrc/runtime/device/memory_offload_strategy.h @@ -58,12 +58,15 @@ class MemOffloadStrategy { bool need_swap() const { return need_swap_; } + bool IsHighPriorityMem(const void *key); + private: void CountMemUsage(); void CheckMemSize(); void GenEventSpan(); - void GenNoSwapEventSet(); + void GenSwapEventSet(); void 
GenComputeMemEvents(); + void GenFreeEvent(const std::shared_ptr &last_event); const std::map &mem_priority_; const std::map>> &mem_events_; @@ -74,8 +77,8 @@ class MemOffloadStrategy { size_t mem_size_{0}; std::vector compute_time_; bool need_swap_{false}; - std::multimap> event_span_; - std::set> no_swap_events_; + std::multimap, size_t>> event_span_; + std::set> swap_events_; std::vector min_mem_used_; size_t mem_used_without_swap_{0}; size_t min_mem_needed_{0}; diff --git a/mindspore/ccsrc/runtime/device/memory_scheduler.cc b/mindspore/ccsrc/runtime/device/memory_scheduler.cc index 8cdda6a2279..33d54db2dae 100644 --- a/mindspore/ccsrc/runtime/device/memory_scheduler.cc +++ b/mindspore/ccsrc/runtime/device/memory_scheduler.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace device { namespace { -constexpr float kMaxMemReuseFactor = 0.9; +constexpr float kMaxMemReuseFactor = 1.0; constexpr float kMinMemReuseFactor = 0.5; constexpr float kRetryFactor = 0.1; @@ -51,12 +51,25 @@ void MemScheduler::Clear() { high_priority_device_ptr_.clear(); } -bool MemScheduler::IsHighPriorityMem(const void *key) { - auto iter = mem_priority_.find(key); - if (iter != mem_priority_.end()) { - return iter->second == kMemPriorityHigh; +void MemScheduler::ClearTempMem() { + if (mem_handler_ == nullptr) { + return; } - return false; + for (auto &item : mem_result_) { + const auto device_ptr = item.second; + if (device_ptr != nullptr) { + mem_handler_->FreeDevice(device_ptr); + } + } + mem_result_.clear(); + high_priority_device_ptr_.clear(); + for (const auto &item : swap_host_ptr_) { + const auto host_ptr = item.second; + if (host_ptr != nullptr) { + mem_handler_->FreeHost(host_ptr); + } + } + swap_host_ptr_.clear(); } void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; } @@ -88,9 +101,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr if (mem_priority_.find(key) == mem_priority_.end()) { &#13;
mem_priority_[key] = priority; Record(key, kMalloc, mem_size); - } else { - Record(key, kGet, mem_size); } + Record(key, kGet, mem_size); return nullptr; } if (strategy_ == nullptr) { @@ -101,9 +113,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr auto ptr = iter->second; MS_EXCEPTION_IF_NULL(ptr); return ptr; - } else { - MS_LOG_EXCEPTION << "Mem extender get nullptr result!"; } + return nullptr; } bool MemScheduler::PreCompute(void *stream) { @@ -151,6 +162,9 @@ bool MemScheduler::PreCompute(void *stream) { MS_EXCEPTION_IF_NULL(host_ptr); mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream); mem_result_[event->key] = device_ptr; + if (mem_priority_[event->key] == kMemPriorityHigh) { + high_priority_device_ptr_[event->key] = device_ptr; + } if (!from_init) { mem_handler_->FreeHost(host_ptr); (void)swap_host_ptr_.erase(event->key); @@ -199,6 +213,9 @@ bool MemScheduler::PostCompute(void *stream) { mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream); mem_handler_->FreeDevice(device_ptr); (void)mem_result_.erase(event->key); + if (mem_priority_[event->key] == kMemPriorityHigh) { + high_priority_device_ptr_.erase(event->key); + } } } ++current_step_; @@ -221,6 +238,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) { } void MemScheduler::Optimize() { + AdjustFirstEventIndex(); float mem_used_factor = kMaxMemReuseFactor; while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) { OptMemUsage(mem_used_factor); @@ -247,11 +265,30 @@ void MemScheduler::Optimize() { if (ret) { optimized_ = true; } else { + ClearTempMem(); mem_used_factor -= kRetryFactor; } } } +void MemScheduler::AdjustFirstEventIndex() { + for (const auto &item : mem_events_) { + const auto &mem_events = item.second; + if (mem_events.empty()) { + continue; + } + auto &first_event = mem_events[0]; + MS_EXCEPTION_IF_NULL(first_event); + const auto &priority_iter = mem_priority_.find(item.first); + const bool is_high_priority = 
(priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh); + if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) { + const auto &second_event = mem_events[1]; + MS_EXCEPTION_IF_NULL(second_event); + first_event->index = second_event->index; + } + } +} + void MemScheduler::Update() { if (!optimized_) { return; diff --git a/mindspore/ccsrc/runtime/device/memory_scheduler.h b/mindspore/ccsrc/runtime/device/memory_scheduler.h index 8c7106f45fd..e4794bcc5bd 100644 --- a/mindspore/ccsrc/runtime/device/memory_scheduler.h +++ b/mindspore/ccsrc/runtime/device/memory_scheduler.h @@ -70,7 +70,7 @@ class MemScheduler { void Clear(); - bool IsHighPriorityMem(const void *key); + void ClearTempMem(); void SetMemPriority(const void *key, MemPriority priority); @@ -79,6 +79,8 @@ class MemScheduler { void OptMemUsage(float mem_used_factor = 1.0f); + void AdjustFirstEventIndex(); + std::map mem_priority_; std::map>> mem_events_; std::vector>> step_events_;