From 7c312bd38cac65f0fc3c4847f432954101d5f29d Mon Sep 17 00:00:00 2001
From: kswang
Date: Tue, 16 Nov 2021 11:39:37 +0800
Subject: [PATCH] add mem offload strategy

---
 mindspore/ccsrc/runtime/device/CMakeLists.txt |   2 +-
 .../ccsrc/runtime/device/kernel_runtime.cc    |  26 +-
 .../runtime/device/memory_offload_strategy.cc | 224 +++++++++++++
 .../runtime/device/memory_offload_strategy.h  |  85 +++++
 .../ccsrc/runtime/device/memory_scheduler.cc  | 304 ++++++------------
 .../ccsrc/runtime/device/memory_scheduler.h   |  64 ++--
 tests/ut/cpp/CMakeLists.txt                   |   1 +
 tests/ut/cpp/device/mem_scheduler_test.cc     |   6 +-
 8 files changed, 450 insertions(+), 262 deletions(-)
 create mode 100644 mindspore/ccsrc/runtime/device/memory_offload_strategy.cc
 create mode 100644 mindspore/ccsrc/runtime/device/memory_offload_strategy.h

diff --git a/mindspore/ccsrc/runtime/device/CMakeLists.txt b/mindspore/ccsrc/runtime/device/CMakeLists.txt
index d45b0ea50e7..81738e29d54 100644
--- a/mindspore/ccsrc/runtime/device/CMakeLists.txt
+++ b/mindspore/ccsrc/runtime/device/CMakeLists.txt
@@ -1,7 +1,7 @@
 file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "common/*.cc" "kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
     "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
-    "bucket.cc" "launch_kernel.cc" "launch_mul.cc"
+    "memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc"
 )

 if("${ENABLE_HIDDEN}" STREQUAL "OFF")

diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
index e48e28804c8..10e738d0aba 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc
@@ -42,9 +42,6 @@ using mindspore::kernel::AddressPtr;
 namespace mindspore {
 namespace device {
-constexpr float kMaxMemReuseFactor = 0.8;
-constexpr float kMinMemReuseFactor = 0.5;
-constexpr float kRetryFactor = 0.1;
 constexpr size_t kAtomicCleanInputSize = 2;
 namespace {
 std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
@@ -1494,13 +1491,15 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
+
   if (UseMemScheduler()) {
     mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
     MS_EXCEPTION_IF_NULL(mem_scheduler);
-    mem_scheduler->SetMemHandler(mem_manager_);
-    mem_scheduler->Reset();
+    mem_scheduler->ResetCurrentStep();
+    mem_scheduler->Update();
     InitGraphInputTensors(mem_scheduler, graph);
   }
+
   const auto &kernels = graph.execution_order();
   std::vector<DynamicKernelPtr> dynamic_kernel_list;
   auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
@@ -1555,7 +1554,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
     }
     LaunchKernelEvent(kernel_post_run_events, i);
   }
-
   return true;
 }

@@ -1566,21 +1564,15 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
     return;
   }
   auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
+  MS_EXCEPTION_IF_NULL(mem_scheduler);
+  mem_scheduler->SetMemHandler(mem_manager_);
+  mem_scheduler->SetTotalStep(graph.execution_order().size());
+
   if (mem_scheduler->need_record_event()) {
     (void)LaunchKernelMod(graph, true);
     mem_scheduler->set_need_record_event(false);
   }
-  float mem_used_factor = kMaxMemReuseFactor;
-  while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
-    mem_scheduler->SetMemUsedFactor(mem_used_factor);
-    mem_scheduler->OptMemUsage();
-    bool ret = LaunchKernelMod(graph, true);
-    if (ret) {
-      mem_scheduler->set_optimized(true);
-    } else {
-      mem_used_factor -= kRetryFactor;
-    }
-  }
+  mem_scheduler->Optimize();
  if (!mem_scheduler->optimized()) {
    MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
  }
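The kernel-runtime hunks above split work that used to be tangled inside LaunchKernelMod: UseMemSchedulerIfNeeded now wires up the handler and total step count, does one mock launch to record memory events, and calls Optimize(), while each real launch merely resets the step counter and calls Update(). A condensed sketch of the resulting call order (illustrative only; error handling and the surrounding runtime plumbing are omitted):

  // Condensed from the hunks above; not verbatim patch code.
  auto scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
  scheduler->SetMemHandler(mem_manager_);
  scheduler->SetTotalStep(graph.execution_order().size());
  if (scheduler->need_record_event()) {
    (void)LaunchKernelMod(graph, true);  // mock pass: records kInit/kMalloc/kGet events
    scheduler->set_need_record_event(false);
  }
  scheduler->Optimize();  // dry-runs the offload plan, relaxing the budget factor on failure

  // Later, at the start of every real execution of the graph:
  scheduler->ResetCurrentStep();
  scheduler->Update();  // no-op until optimized; later folds measured step times into the plan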
diff --git a/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc b/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc
new file mode 100644
index 00000000000..de0079da4d9
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/memory_offload_strategy.cc
@@ -0,0 +1,224 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "runtime/device/memory_offload_strategy.h"
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <set>
+#include <vector>
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace device {
+std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
+  if (pre_compute_events_.size() <= step) {
+    MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step
+                     << ", event size:" << pre_compute_events_.size();
+  }
+  return pre_compute_events_[step];
+}
+
+std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
+  if (post_compute_events_.size() <= step) {
+    MS_LOG_EXCEPTION << "Index out of post event range, index:" << step
+                     << ", event size:" << post_compute_events_.size();
+  }
+  return post_compute_events_[step];
+}
+
+void MemOffloadStrategy::Execute() {
+  CountMemUsage();
+  CheckMemSize();
+  if (need_swap_) {
+    GenEventSpan();
+    GenNoSwapEventSet();
+  }
+  GenComputeMemEvents();
+}
+
+void MemOffloadStrategy::CountMemUsage() {
+  if (!min_mem_used_.empty()) {
+    return;
+  }
+  if (mem_events_.empty() || total_step_ == 0) {
+    return;
+  }
+  min_mem_used_.resize(total_step_, 0);
+  std::vector<size_t> total_mem_used(total_step_, 0);
+  for (auto &item : mem_events_) {
+    auto &mem_events = item.second;
+    if (mem_events.empty()) {
+      continue;
+    }
+    auto first_event = mem_events[0];
+    size_t cur_index = 0;
+    if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
+      first_event = mem_events[1];
+      cur_index = 1;
+    }
+    auto last_event = mem_events[mem_events.size() - 1];
+    for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
+      if (start_index < total_step_) {
+        total_mem_used[start_index] += first_event->mem_size;
+      } else {
+        MS_LOG(ERROR) << "Error mem event index " << start_index;
+      }
+    }
+    for (; cur_index < mem_events.size(); ++cur_index) {
+      auto &event = mem_events[cur_index];
+      MS_EXCEPTION_IF_NULL(event);
+      if (event->index < total_step_) {
+        min_mem_used_[event->index] += first_event->mem_size;
+      } else {
+        MS_LOG(ERROR) << "Error mem event index " << event->index;
+      }
+    }
+  }
+  min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
+  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
+}
+
+void MemOffloadStrategy::CheckMemSize() {
+  if (mem_size_ < min_mem_needed_) {
+    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
+                      << min_mem_needed_;
+  }
+
+  if (mem_size_ < mem_used_without_swap_) {
+    need_swap_ = true;
+  }
+
+  MS_LOG(INFO) << "Available mem size: " << mem_size_ << ", graph needs mem size: " << mem_used_without_swap_
+               << " without swap, and needs at least " << min_mem_needed_ << " with swap.";
+}
+
+void MemOffloadStrategy::GenEventSpan() {
+  if (!event_span_.empty()) {
+    return;
+  }
+  for (auto &item : mem_events_) {
+    auto &tensor_events = item.second;
+    if (tensor_events.empty()) {
+      continue;
+    }
+    auto first_event = tensor_events[0];
+    size_t cur_index = 0;
+    if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
+      first_event = tensor_events[1];
+      cur_index = 1;
+    }
+    size_t last_index = first_event->index;
+    for (; cur_index < tensor_events.size(); ++cur_index) {
+      auto &event = tensor_events[cur_index];
+      MS_EXCEPTION_IF_NULL(event);
+      auto span = event->index - last_index;
+      if (span > 1) {
+        (void)event_span_.emplace(span, event);
+      }
+      last_index = event->index;
+    }
+  }
+}
+
+void MemOffloadStrategy::GenNoSwapEventSet() {
+  no_swap_events_.clear();
+  std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
+  for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
+    auto span = iter->first;
+    auto &event = iter->second;
+    auto start_index = event->index - span + 1;
+    bool revert = false;
+    for (size_t i = start_index; i < event->index; ++i) {
+      cur_mem_used[i] += event->mem_size;
+      if (cur_mem_used[i] > mem_size_) {
+        revert = true;
+      }
+    }
+    if (revert) {
+      for (size_t i = start_index; i < event->index; ++i) {
+        cur_mem_used[i] -= event->mem_size;
+      }
+    } else {
+      (void)no_swap_events_.emplace(event);
+    }
+  }
+}
+
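GenEventSpan and GenNoSwapEventSet above together implement a greedy keep-in-memory selection: every reuse gap larger than one step is recorded in event_span_, a multimap keyed by gap length, so shorter gaps are considered first; a tensor is kept resident across a gap only if that never pushes the running per-step total past mem_size_, otherwise the increments are reverted and the tensor stays eligible for swapping. A small worked example (all numbers invented for illustration):

  // Suppose mem_size_ = 6 and two low-priority tensors over 5 steps, giving a
  // min_mem_used_ baseline of {4, 4, 0, 4, 4}:
  //   A: mem_size 4, used at steps 0 and 4 -> span 4
  //   B: mem_size 4, used at steps 1 and 3 -> span 2
  // event_span_ orders by span, so B is tried first: keeping B resident only
  // raises step 2 to 4 (<= 6), so B joins no_swap_events_ and never swaps.
  // A is tried next: keeping it would raise steps 1..3 to 8 (> 6), so its
  // increments are reverted, and GenComputeMemEvents below emits a kSwapOut
  // after step 0 and a kSwapIn before step 4 for A.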
+void MemOffloadStrategy::GenComputeMemEvents() {
+  pre_compute_events_.clear();
+  post_compute_events_.clear();
+  pre_compute_events_.resize(total_step_);
+  post_compute_events_.resize(total_step_);
+  for (auto &item : mem_events_) {
+    auto &mem_events = item.second;
+    if (mem_events.empty()) {
+      continue;
+    }
+    auto first_event = mem_events[0];
+    MS_EXCEPTION_IF_NULL(first_event);
+    if (first_event->type == kInit) {
+      if (mem_events.size() > 1) {
+        auto &second_event = mem_events[1];
+        MS_EXCEPTION_IF_NULL(second_event);
+        first_event->index = second_event->index;
+      } else {
+        continue;
+      }
+    }
+    if ((first_event->type == kInit || first_event->type == kMalloc) &&
+        first_event->index < pre_compute_events_.size()) {
+      pre_compute_events_[first_event->index].emplace_back(first_event);
+    } else {
+      MS_LOG_EXCEPTION << "First event should be init or malloc!";
+    }
+    MemPriority priority = kMemPriorityLow;
+    auto iter = mem_priority_.find(first_event->key);
+    if (iter != mem_priority_.end()) {
+      priority = iter->second;
+    }
+    size_t pre_index = first_event->index;
+    for (size_t i = 1; i < mem_events.size(); ++i) {
+      auto &event = mem_events[i];
+      MS_EXCEPTION_IF_NULL(event);
+      if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
+          no_swap_events_.find(event) == no_swap_events_.end()) {
+        auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
+        swap_out_event->key = item.first;
+        swap_out_event->mem_size = first_event->mem_size;
+        post_compute_events_[pre_index].emplace_back(swap_out_event);
+        auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
+        swap_in_event->key = item.first;
+        swap_in_event->mem_size = first_event->mem_size;
+        (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
+      }
+      if (event->index < pre_compute_events_.size()) {
+        (void)pre_compute_events_[event->index].emplace_back(event);
+      }
+      pre_index = event->index;
+    }
+    if (priority != kMemPriorityLow) {
+      continue;
+    }
+    auto &last_event = mem_events[mem_events.size() - 1];
+    MS_EXCEPTION_IF_NULL(last_event);
+    auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
+    free_event->key = item.first;
+    if (last_event->index < post_compute_events_.size()) {
+      (void)post_compute_events_[last_event->index].emplace_back(free_event);
+    }
+  }
+}
+}  // namespace device
+}  // namespace mindspore

diff --git a/mindspore/ccsrc/runtime/device/memory_offload_strategy.h b/mindspore/ccsrc/runtime/device/memory_offload_strategy.h
new file mode 100644
index 00000000000..3f8f2ef7c6a
--- /dev/null
+++ b/mindspore/ccsrc/runtime/device/memory_offload_strategy.h
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
+#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
+#include <map>
+#include <memory>
+#include <set>
+#include <utility>
+#include <vector>
+
+namespace mindspore {
+namespace device {
+enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
+
+enum MemEventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
+
+struct MemEvent {
+  MemEvent(const MemEventType &in_type, size_t in_index) : type(in_type), index(in_index) {}
+
+  MemEventType type;
+  size_t index{0};
+  size_t mem_size{0};
+  const void *key{nullptr};
+};
+
+class MemOffloadStrategy {
+ public:
+  MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
+                     const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
+                     size_t total_step)
+      : mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {}
+
+  virtual ~MemOffloadStrategy() = default;
+
+  virtual void Execute();
+
+  void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; }
+
+  std::vector<std::shared_ptr<MemEvent>> &GetPreComputeEvents(size_t step);
+
+  std::vector<std::shared_ptr<MemEvent>> &GetPostComputeEvents(size_t step);
+
+  void set_mem_size(size_t mem_size) { mem_size_ = mem_size; }
+
+  bool need_swap() const { return need_swap_; }
+
+ private:
+  void CountMemUsage();
+  void CheckMemSize();
+  void GenEventSpan();
+  void GenNoSwapEventSet();
+  void GenComputeMemEvents();
+
+  const std::map<const void *, MemPriority> &mem_priority_;
+  const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
+  const size_t total_step_;
+  std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
+  std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
+
+  size_t mem_size_{0};
+  std::vector<double> compute_time_;
+  bool need_swap_{false};
+  std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_;
+  std::set<std::shared_ptr<MemEvent>> no_swap_events_;
+  std::vector<size_t> min_mem_used_;
+  size_t mem_used_without_swap_{0};
+  size_t min_mem_needed_{0};
+};
+}  // namespace device
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
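To make the new class's data flow concrete, here is a minimal usage sketch (not part of the patch; keys, sizes, and step counts are invented). Each tensor key maps to its chronologically ordered events; the strategy is given a byte budget and then queried per step:

  // Sketch only: exercises the MemOffloadStrategy API declared above.
  std::map<const void *, MemPriority> priority;
  std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> events;

  int tensor_a = 0;  // any stable address can serve as a key
  priority[&tensor_a] = kMemPriorityLow;
  auto malloc_ev = std::make_shared<MemEvent>(kMalloc, 0);
  malloc_ev->key = &tensor_a;
  malloc_ev->mem_size = 1 << 20;
  auto get_ev = std::make_shared<MemEvent>(kGet, 5);  // reused five steps later
  get_ev->key = &tensor_a;
  get_ev->mem_size = 1 << 20;
  events[&tensor_a] = {malloc_ev, get_ev};

  MemOffloadStrategy strategy(priority, events, /*total_step=*/6);
  strategy.set_mem_size(1 << 20);  // must cover min_mem_needed_, or Execute() raises
  strategy.Execute();
  for (size_t step = 0; step < 6; ++step) {
    for (auto &e : strategy.GetPreComputeEvents(step)) { /* malloc / swap-in here */ }
    for (auto &e : strategy.GetPostComputeEvents(step)) { /* free / swap-out here */ }
  }

With several overlapping tensors and a budget below the no-swap peak, Execute() additionally plans kSwapOut/kSwapIn pairs around each reuse gap, as the scheduler changes below rely on.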
diff --git a/mindspore/ccsrc/runtime/device/memory_scheduler.cc b/mindspore/ccsrc/runtime/device/memory_scheduler.cc
index e9fd7f54ffe..8cdda6a2279 100644
--- a/mindspore/ccsrc/runtime/device/memory_scheduler.cc
+++ b/mindspore/ccsrc/runtime/device/memory_scheduler.cc
@@ -17,9 +17,30 @@
 #include "runtime/device/memory_scheduler.h"
 #include <algorithm>
 #include "utils/log_adapter.h"
+#ifdef _MSC_VER
+#include <time.h>
+#else
+#include <sys/time.h>
+#endif

 namespace mindspore {
 namespace device {
+namespace {
+constexpr float kMaxMemReuseFactor = 0.9;
+constexpr float kMinMemReuseFactor = 0.5;
+constexpr float kRetryFactor = 0.1;
+
+double GetCurrentTime() {
+#ifdef _MSC_VER
+  return time(NULL) * 1.0e6;
+#else
+  struct timeval tv;
+  (void)gettimeofday(&tv, nullptr);
+  return tv.tv_sec * 1.0e6 + tv.tv_usec;
+#endif
+}
+}  // namespace
+
 void MemScheduler::Clear() {
   if (mem_handler_ == nullptr) {
     return;
@@ -40,23 +61,26 @@ bool MemScheduler::IsHighPriorityMem(const void *key) {
 void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }

-void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
+void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
   if (key == nullptr) {
     return;
   }
-  auto event = std::make_shared<Event>(event_type, compute_index_);
+  auto event = std::make_shared<MemEvent>(event_type, current_step_);
   event->mem_size = mem_size;
   event->key = key;
   mem_events_[key].emplace_back(event);
+  if (step_events_.size() < current_step_ + 1) {
+    step_events_.resize(current_step_ + 1);
+  }
+  step_events_[current_step_].emplace_back(event);
 }

 void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
   if (need_record_event_) {
     mem_priority_[key] = priority;
     Record(key, kInit, mem_size);
-  } else {
-    init_host_ptr_[key] = host_ptr;
   }
+  init_host_ptr_[key] = host_ptr;
 }

 void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
@@ -69,7 +93,7 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
     }
     return nullptr;
   }
-  if (!has_compute_mem_events_) {
+  if (strategy_ == nullptr) {
    return nullptr;
  }
  auto iter = mem_result_.find(key);
@@ -83,18 +107,14 @@
 }

 bool MemScheduler::PreCompute(void *stream) {
-  if (!has_compute_mem_events_) {
+  if (strategy_ == nullptr) {
     return true;
   }
   MS_EXCEPTION_IF_NULL(mem_handler_);
-  if (pre_compute_events_.size() <= compute_index_) {
-    MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
-                     << ", event size:" << pre_compute_events_.size();
-  }
-  auto &events = pre_compute_events_[compute_index_];
+  auto &events = strategy_->GetPreComputeEvents(current_step_);
   for (auto &event : events) {
     MS_EXCEPTION_IF_NULL(event);
-    MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
+    MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
     if (event->type == kInit || event->type == kMalloc) {
       auto priority = mem_priority_[event->key];
       auto iter = high_priority_device_ptr_.find(event->key);
@@ -137,23 +157,27 @@
     }
   }
+
+  if (record_compute_time_ && !updated_) {
+    compute_start_time_ = GetCurrentTime();
+  }
   return true;
 }

 bool MemScheduler::PostCompute(void *stream) {
-  if (!has_compute_mem_events_) {
-    ++compute_index_;
+  if (strategy_ == nullptr) {
+    ++current_step_;
     return true;
   }
-  MS_EXCEPTION_IF_NULL(mem_handler_);
-  if (post_compute_events_.size() <= compute_index_) {
-    MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
-                     << ", event size:" << post_compute_events_.size();
+
+  if (record_compute_time_ && !updated_) {
+    compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
   }
-  auto &events = post_compute_events_[compute_index_];
+
+  auto &events = strategy_->GetPostComputeEvents(current_step_);
   for (auto &event : events) {
     MS_EXCEPTION_IF_NULL(event);
-    MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
+    MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
     if (event->type == kFree) {
       auto ptr = mem_result_[event->key];
       if (ptr == nullptr) {
@@ -174,201 +198,81 @@
       MS_EXCEPTION_IF_NULL(host_ptr);
       mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
       mem_handler_->FreeDevice(device_ptr);
-      (void)mem_result_.erase(device_ptr);
+      (void)mem_result_.erase(event->key);
     }
   }
-  ++compute_index_;
+  ++current_step_;
   return true;
 }

-void MemScheduler::OptMemUsage() {
-  if (optimized_) {
-    return;
-  }
-  CountMemUsage();
-  CheckMemSize();
-  if (need_swap_) {
-    GenEventSpan();
-    GenNoSwapEventSet();
-  }
-  GenComputeMemEvents();
-  has_compute_mem_events_ = true;
-}
-
-void MemScheduler::CountMemUsage() {
-  if (!min_mem_used_.empty()) {
-    return;
-  }
-  if (mem_events_.empty() || compute_index_ == 0) {
-    return;
-  }
-  min_mem_used_.resize(compute_index_, 0);
-  std::vector<size_t> total_mem_used(compute_index_, 0);
-  for (auto &item : mem_events_) {
-    auto &mem_events = item.second;
-    if (mem_events.empty()) {
-      continue;
-    }
-    auto first_event = mem_events[0];
-    size_t cur_index = 0;
-    if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
-      first_event = mem_events[1];
-      cur_index = 1;
-    }
-    auto last_event = mem_events[mem_events.size() - 1];
-    for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
-      if (start_index < compute_index_) {
-        total_mem_used[start_index] += first_event->mem_size;
-      } else {
-        MS_LOG(ERROR) << "Error mem event index " << start_index;
-      }
-    }
-    for (; cur_index < mem_events.size(); ++cur_index) {
-      auto &event = mem_events[cur_index];
-      MS_EXCEPTION_IF_NULL(event);
-      if (event->index < compute_index_) {
-        min_mem_used_[event->index] += first_event->mem_size;
-      } else {
-        MS_LOG(ERROR) << "Error mem event index " << event->index;
-      }
-    }
-  }
-  min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
-  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
-}
-
-void MemScheduler::CheckMemSize() {
+void MemScheduler::OptMemUsage(float mem_used_factor) {
+  mem_used_factor_ = mem_used_factor;
   MS_EXCEPTION_IF_NULL(mem_handler_);
+
+  if (strategy_ == nullptr) {
+    strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
+    compute_time_.resize(total_step_);
+  }
+
   auto available_mem_size = mem_handler_->GetAvailableMemSize();
-  if (available_mem_size < min_mem_needed_) {
-    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
-                      << " while graph needs at least " << min_mem_needed_;
-  }
-  if (mem_used_without_swap_ > available_mem_size) {
-    need_swap_ = true;
-  }
-  MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size: " << mem_used_without_swap_
-               << " without swap, and needs at least " << min_mem_needed_ << " with swap.";
+  available_mem_size = available_mem_size * mem_used_factor_;
+  strategy_->set_mem_size(available_mem_size);
+  strategy_->Execute();
 }
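OptMemUsage above is where the reuse factor bites: the budget handed to the strategy is GetAvailableMemSize() scaled by the factor, and the Optimize() loop in the next hunk walks that factor down from kMaxMemReuseFactor to kMinMemReuseFactor. A smaller factor shrinks the planning budget, which makes the strategy swap more aggressively and therefore makes the device-side dry run more likely to fit. With invented numbers:

  // Assume GetAvailableMemSize() == 10 GiB.
  // Attempt 1: factor 0.9 -> plan against 9.0 GiB; dry run fails (some
  //            GetOrMalloc returns nullptr), so retry.
  // Attempt 2: factor 0.8 -> plan against 8.0 GiB; more tensors get swapped.
  // ...
  // Attempt 5: factor 0.5 -> plan against 5.0 GiB. If even this fails,
  //            optimized_ stays false and UseMemSchedulerIfNeeded raises
  //            "Can't run graph ... for memory limit."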
-void MemScheduler::GenEventSpan() {
-  if (!event_span_.empty()) {
+void MemScheduler::Optimize() {
+  float mem_used_factor = kMaxMemReuseFactor;
+  while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
+    OptMemUsage(mem_used_factor);
+    current_step_ = 0;
+    bool ret = true;
+    for (size_t step = 0; step < total_step_; ++step) {
+      ret = PreCompute(nullptr);
+      auto &step_events = step_events_[step];
+      for (auto &event : step_events) {
+        if (event->type != kGet) {
+          continue;
+        }
+        auto ptr = GetOrMalloc(event->key, event->mem_size);
+        if (ptr == nullptr) {
+          ret = false;
+          break;
+        }
+      }
+      if (!ret) {
+        break;
+      }
+      PostCompute(nullptr);
+    }
+    if (ret) {
+      optimized_ = true;
+    } else {
+      mem_used_factor -= kRetryFactor;
+    }
+  }
+}
+
+void MemScheduler::Update() {
+  if (!optimized_) {
     return;
   }
-  for (auto &item : mem_events_) {
-    auto &tensor_events = item.second;
-    if (tensor_events.empty()) {
-      continue;
-    }
-    auto first_event = tensor_events[0];
-    size_t cur_index = 0;
-    if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
-      first_event = tensor_events[1];
-      cur_index = 1;
-    }
-    size_t last_index = first_event->index;
-    for (; cur_index < tensor_events.size(); ++cur_index) {
-      auto &event = tensor_events[cur_index];
-      MS_EXCEPTION_IF_NULL(event);
-      auto span = event->index - last_index;
-      if (span > 1) {
-        (void)event_span_.emplace(std::pair<size_t, std::shared_ptr<Event>>(span, event));
-      }
-      last_index = event->index;
-    }
-  }
-}

-void MemScheduler::GenNoSwapEventSet() {
-  MS_EXCEPTION_IF_NULL(mem_handler_);
-  auto available_mem_size = mem_handler_->GetAvailableMemSize();
-  auto threshold = available_mem_size * mem_used_factor_;
-  no_swap_events_.clear();
-  std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
-  for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
-    auto span = iter->first;
-    auto &event = iter->second;
-    auto start_index = event->index - span + 1;
-    bool revert = false;
-    for (size_t i = start_index; i < event->index; ++i) {
-      cur_mem_used[i] += event->mem_size;
-      if (cur_mem_used[i] > threshold) {
-        revert = true;
-      }
-    }
-    if (revert) {
-      for (size_t i = start_index; i < event->index; ++i) {
-        cur_mem_used[i] -= event->mem_size;
-      }
-    } else {
-      (void)no_swap_events_.emplace(event);
-    }
+  if (strategy_ == nullptr || !strategy_->need_swap()) {
+    return;
   }
-}

-void MemScheduler::GenComputeMemEvents() {
-  pre_compute_events_.clear();
-  post_compute_events_.clear();
-  pre_compute_events_.resize(compute_index_);
-  post_compute_events_.resize(compute_index_);
-  for (auto &item : mem_events_) {
-    auto &mem_events = item.second;
-    if (mem_events.empty()) {
-      continue;
-    }
-    auto first_event = mem_events[0];
-    MS_EXCEPTION_IF_NULL(first_event);
-    if (first_event->type == kInit) {
-      if (mem_events.size() > 1) {
-        auto &second_event = mem_events[1];
-        MS_EXCEPTION_IF_NULL(second_event);
-        first_event->index = second_event->index;
-      } else {
-        continue;
-      }
-    }
-    if ((first_event->type == kInit || first_event->type == kMalloc) &&
-        first_event->index < pre_compute_events_.size()) {
-      pre_compute_events_[first_event->index].emplace_back(first_event);
-    } else {
-      MS_LOG_EXCEPTION << "First event should be init or malloc!";
-    }
-    MemPriority priority = kMemPriorityLow;
-    auto iter = mem_priority_.find(first_event->key);
-    if (iter != mem_priority_.end()) {
-      priority = iter->second;
-    }
-    size_t pre_index = first_event->index;
-    for (size_t i = 1; i < mem_events.size(); ++i) {
-      auto &event = mem_events[i];
-      MS_EXCEPTION_IF_NULL(event);
-      if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
-          no_swap_events_.find(event) == no_swap_events_.end()) {
-        auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
-        swap_out_event->key = item.first;
-        swap_out_event->mem_size = first_event->mem_size;
-        post_compute_events_[pre_index].emplace_back(swap_out_event);
-        auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
-        swap_in_event->key = item.first;
-        swap_in_event->mem_size = first_event->mem_size;
-        (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
-      }
-      if (event->index < pre_compute_events_.size()) {
-        (void)pre_compute_events_[event->index].emplace_back(event);
-      }
-      pre_index = event->index;
-    }
-    if (priority != kMemPriorityLow) {
-      continue;
-    }
-    auto &last_event = mem_events[mem_events.size() - 1];
-    MS_EXCEPTION_IF_NULL(last_event);
-    auto free_event = std::make_shared<Event>(kFree, last_event->index);
-    free_event->key = item.first;
-    if (last_event->index < post_compute_events_.size()) {
-      (void)post_compute_events_[last_event->index].emplace_back(free_event);
-    }
+  if (updated_) {
+    return;
   }
+
+  if (!record_compute_time_) {
+    record_compute_time_ = true;
+    return;
+  }
+
+  strategy_->SetComputeTime(compute_time_);
+  strategy_->Execute();
+  updated_ = true;
 }
 }  // namespace device
 }  // namespace mindspore
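The header diff below keeps MemHandler abstract, so any backend (or test) must supply the allocation and copy primitives the scheduler calls. A host-only sketch that would be enough to drive the scheduler in unit tests; GetAvailableMemSize, FreeDevice, and SwapOut are visible in this patch, while the names and signatures of MallocDevice, MallocHost, FreeHost, and SwapIn are assumptions made to mirror them:

  // Illustrative only: a MemHandler backed entirely by host heap memory, with
  // a fixed fake "device" capacity so the scheduler's dry run can actually
  // fail. Assumes <cstdlib>, <cstring>, and <map> are included.
  class HostOnlyMemHandler : public MemHandler {
   public:
    explicit HostOnlyMemHandler(size_t device_capacity) : capacity_(device_capacity) {}
    size_t GetAvailableMemSize() override { return capacity_ - used_; }
    void *MallocDevice(size_t mem_size) override {
      if (used_ + mem_size > capacity_) {
        return nullptr;  // over budget: lets Optimize() detect a failed plan
      }
      void *ptr = std::malloc(mem_size);
      used_ += mem_size;
      device_mem_size_[ptr] = mem_size;
      return ptr;
    }
    void FreeDevice(void *ptr) override {
      used_ -= device_mem_size_[ptr];
      (void)device_mem_size_.erase(ptr);
      std::free(ptr);
    }
    void *MallocHost(size_t mem_size) override { return std::malloc(mem_size); }
    void FreeHost(void *ptr) override { std::free(ptr); }
    void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *) override {
      (void)std::memcpy(device_ptr, host_ptr, mem_size);
    }
    void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *) override {
      (void)std::memcpy(host_ptr, device_ptr, mem_size);
    }

   private:
    size_t capacity_;
    size_t used_{0};
    std::map<void *, size_t> device_mem_size_;
  };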
diff --git a/mindspore/ccsrc/runtime/device/memory_scheduler.h b/mindspore/ccsrc/runtime/device/memory_scheduler.h
index e2783b5765b..8c7106f45fd 100644
--- a/mindspore/ccsrc/runtime/device/memory_scheduler.h
+++ b/mindspore/ccsrc/runtime/device/memory_scheduler.h
@@ -21,6 +21,7 @@
 #include <map>
 #include <memory>
 #include <set>
+#include "runtime/device/memory_offload_strategy.h"

 namespace mindspore {
 namespace device {
@@ -35,23 +36,7 @@ class MemHandler {
   virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
 };

-enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
-
 class MemScheduler {
-  enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
-
-  struct Event {
-    Event(const EventType &in_type, size_t in_index) {
-      type = in_type;
-      index = in_index;
-    }
-
-    EventType type;
-    size_t index{0};
-    size_t mem_size{0};
-    const void *key{nullptr};
-  };
-
  public:
   MemScheduler() = default;
   ~MemScheduler() = default;
@@ -62,7 +47,7 @@ class MemScheduler {

   bool optimized() const { return optimized_; }

-  void set_optimized(bool flag) { optimized_ = flag; }
+  void Update();

   void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }

@@ -70,13 +55,18 @@ class MemScheduler {

   void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);

-  void Reset() { compute_index_ = 0; }
+  void SetTotalStep(size_t step) {
+    total_step_ = step;
+    step_events_.resize(total_step_);
+  }
+
+  void ResetCurrentStep() { current_step_ = 0; }

   bool PreCompute(void *stream);

   bool PostCompute(void *stream);

-  void OptMemUsage();
+  void Optimize();

   void Clear();

@@ -84,37 +74,29 @@ class MemScheduler {

   void SetMemPriority(const void *key, MemPriority priority);

-  void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }
-
-  void SetNeedSwap(bool flag) { need_swap_ = flag; }
-
  private:
-  void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
-  void GenComputeMemEvents();
-  void CheckMemSize();
-  void CountMemUsage();
-  void GenEventSpan();
-  void GenNoSwapEventSet();
+  void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
+
+  void OptMemUsage(float mem_used_factor = 1.0f);
+
   std::map<const void *, MemPriority> mem_priority_;
-  std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
-  std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
-  std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
+  std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
+  std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
   std::map<const void *, void *> mem_result_;
   std::map<const void *, void *> init_host_ptr_;
   std::map<const void *, void *> swap_host_ptr_;
   std::map<const void *, void *> high_priority_device_ptr_;
-  size_t compute_index_{0};
+  size_t total_step_{0};
+  size_t current_step_{0};
   bool need_record_event_{true};
   bool optimized_{false};
-  bool has_compute_mem_events_{false};
-  std::shared_ptr<MemHandler> mem_handler_{nullptr};
-  bool need_swap_{false};
-  std::multimap<size_t, std::shared_ptr<Event>> event_span_;
-  std::set<std::shared_ptr<Event>> no_swap_events_;
-  std::vector<size_t> min_mem_used_;
-  size_t mem_used_without_swap_{0};
-  size_t min_mem_needed_{0};
   float mem_used_factor_{0.9};
+  double compute_start_time_{0};
+  std::vector<double> compute_time_;
+  bool record_compute_time_{false};
+  bool updated_{false};
+  std::shared_ptr<MemHandler> mem_handler_{nullptr};
+  std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
 };

 class MemSchedulerManager {

diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt
index 0290e791c16..bee31bd4321 100644
--- a/tests/ut/cpp/CMakeLists.txt
+++ b/tests/ut/cpp/CMakeLists.txt
@@ -114,6 +114,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
         "../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc"
         "../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
         "../../../mindspore/ccsrc/runtime/device/memory_scheduler.cc"
+        "../../../mindspore/ccsrc/runtime/device/memory_offload_strategy.cc"
         "../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
         "../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
         "../../../mindspore/ccsrc/runtime/device/bucket.cc"

diff --git a/tests/ut/cpp/device/mem_scheduler_test.cc b/tests/ut/cpp/device/mem_scheduler_test.cc
index a769c76e2fd..08e8711a9ed 100644
--- a/tests/ut/cpp/device/mem_scheduler_test.cc
+++ b/tests/ut/cpp/device/mem_scheduler_test.cc
@@ -104,6 +104,7 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
   std::vector<size_t> init_tensors = {0, 2, 4};
   std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
   void *stream = nullptr;
+  scheduler->SetTotalStep(kTimeSlice);
   // record
   for (auto index : init_tensors) {
     scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
@@ -118,11 +119,10 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
   scheduler->set_need_record_event(false);

   // optimize
-  scheduler->OptMemUsage();
-  scheduler->set_optimized(true);
+  scheduler->Optimize();

   // run
-  scheduler->Reset();
+  scheduler->ResetCurrentStep();
   for (auto index : init_tensors) {
     scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
   }
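A closing note on the new GetCurrentTime() helper in memory_scheduler.cc: the _MSC_VER branch is built on time(NULL), which only has one-second resolution, so per-step compute_time_ samples on Windows will be coarse. If that ever matters, a std::chrono-based helper (an alternative sketch, not what this patch uses) behaves identically on every platform and needs no #ifdef:

  #include <chrono>

  // Microseconds from a monotonic clock; unaffected by wall-clock changes,
  // unlike gettimeofday(), and uniform across platforms.
  double GetCurrentTime() {
    const auto now = std::chrono::steady_clock::now().time_since_epoch();
    return static_cast<double>(std::chrono::duration_cast<std::chrono::microseconds>(now).count());
  }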