forked from mindspore-Ecosystem/mindspore
add mem offload strategy
This commit is contained in:
parent
9522ee9686
commit
7c312bd38c
|
@ -1,7 +1,7 @@
|
|||
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
|
||||
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
|
||||
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
|
||||
"bucket.cc" "launch_kernel.cc" "launch_mul.cc"
|
||||
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc"
|
||||
)
|
||||
|
||||
if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
||||
|
|
|
@ -42,9 +42,6 @@ using mindspore::kernel::AddressPtr;
|
|||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
constexpr float kMaxMemReuseFactor = 0.8;
|
||||
constexpr float kMinMemReuseFactor = 0.5;
|
||||
constexpr float kRetryFactor = 0.1;
|
||||
constexpr size_t kAtomicCleanInputSize = 2;
|
||||
namespace {
|
||||
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
|
||||
|
@ -1494,13 +1491,15 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
|
|||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
|
||||
|
||||
if (UseMemScheduler()) {
|
||||
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
||||
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
||||
mem_scheduler->SetMemHandler(mem_manager_);
|
||||
mem_scheduler->Reset();
|
||||
mem_scheduler->ResetCurrentStep();
|
||||
mem_scheduler->Update();
|
||||
InitGraphInputTensors(mem_scheduler, graph);
|
||||
}
|
||||
|
||||
const auto &kernels = graph.execution_order();
|
||||
std::vector<DynamicKernelPtr> dynamic_kernel_list;
|
||||
auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
|
||||
|
@ -1555,7 +1554,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
|
|||
}
|
||||
LaunchKernelEvent(kernel_post_run_events, i);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1566,21 +1564,15 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
|
|||
return;
|
||||
}
|
||||
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
||||
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
||||
mem_scheduler->SetMemHandler(mem_manager_);
|
||||
mem_scheduler->SetTotalStep(graph.execution_order().size());
|
||||
|
||||
if (mem_scheduler->need_record_event()) {
|
||||
(void)LaunchKernelMod(graph, true);
|
||||
mem_scheduler->set_need_record_event(false);
|
||||
}
|
||||
float mem_used_factor = kMaxMemReuseFactor;
|
||||
while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
|
||||
mem_scheduler->SetMemUsedFactor(mem_used_factor);
|
||||
mem_scheduler->OptMemUsage();
|
||||
bool ret = LaunchKernelMod(graph, true);
|
||||
if (ret) {
|
||||
mem_scheduler->set_optimized(true);
|
||||
} else {
|
||||
mem_used_factor -= kRetryFactor;
|
||||
}
|
||||
}
|
||||
mem_scheduler->Optimize();
|
||||
if (!mem_scheduler->optimized()) {
|
||||
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
|
||||
}
|
||||
|
|
|
@ -0,0 +1,224 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/device/memory_offload_strategy.h"
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
|
||||
if (pre_compute_events_.size() <= step) {
|
||||
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
|
||||
}
|
||||
return pre_compute_events_[step];
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
|
||||
if (post_compute_events_.size() <= step) {
|
||||
MS_LOG_EXCEPTION << "Index out of post event range, index:" << step
|
||||
<< ", event size:" << post_compute_events_.size();
|
||||
}
|
||||
return post_compute_events_[step];
|
||||
}
|
||||
|
||||
void MemOffloadStrategy::Execute() {
|
||||
CountMemUsage();
|
||||
CheckMemSize();
|
||||
if (need_swap_) {
|
||||
GenEventSpan();
|
||||
GenNoSwapEventSet();
|
||||
}
|
||||
GenComputeMemEvents();
|
||||
}
|
||||
|
||||
void MemOffloadStrategy::CountMemUsage() {
|
||||
if (!min_mem_used_.empty()) {
|
||||
return;
|
||||
}
|
||||
if (mem_events_.empty() || total_step_ == 0) {
|
||||
return;
|
||||
}
|
||||
min_mem_used_.resize(total_step_, 0);
|
||||
std::vector<size_t> total_mem_used(total_step_, 0);
|
||||
for (auto &item : mem_events_) {
|
||||
auto &mem_events = item.second;
|
||||
if (mem_events.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto first_event = mem_events[0];
|
||||
size_t cur_index = 0;
|
||||
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
|
||||
first_event = mem_events[1];
|
||||
cur_index = 1;
|
||||
}
|
||||
auto last_event = mem_events[mem_events.size() - 1];
|
||||
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
|
||||
if (start_index < total_step_) {
|
||||
total_mem_used[start_index] += first_event->mem_size;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error mem event index " << start_index;
|
||||
}
|
||||
}
|
||||
for (; cur_index < mem_events.size(); ++cur_index) {
|
||||
auto &event = mem_events[cur_index];
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
if (event->index < total_step_) {
|
||||
min_mem_used_[event->index] += first_event->mem_size;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error mem event index " << event->index;
|
||||
}
|
||||
}
|
||||
}
|
||||
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
|
||||
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
|
||||
}
|
||||
|
||||
void MemOffloadStrategy::CheckMemSize() {
|
||||
if (mem_size_ < min_mem_needed_) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
|
||||
<< min_mem_needed_;
|
||||
}
|
||||
|
||||
if (mem_size_ < mem_used_without_swap_) {
|
||||
need_swap_ = true;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Available mem size: " << mem_size_ << ", graph needs mem size: " << mem_used_without_swap_
|
||||
<< " without swap, and needs at least " << min_mem_needed_ << " with swap.";
|
||||
}
|
||||
|
||||
void MemOffloadStrategy::GenEventSpan() {
|
||||
if (!event_span_.empty()) {
|
||||
return;
|
||||
}
|
||||
for (auto &item : mem_events_) {
|
||||
auto &tensor_events = item.second;
|
||||
if (tensor_events.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto first_event = tensor_events[0];
|
||||
size_t cur_index = 0;
|
||||
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
|
||||
first_event = tensor_events[1];
|
||||
cur_index = 1;
|
||||
}
|
||||
size_t last_index = first_event->index;
|
||||
for (; cur_index < tensor_events.size(); ++cur_index) {
|
||||
auto &event = tensor_events[cur_index];
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
auto span = event->index - last_index;
|
||||
if (span > 1) {
|
||||
(void)event_span_.emplace(span, event);
|
||||
}
|
||||
last_index = event->index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MemOffloadStrategy::GenNoSwapEventSet() {
|
||||
no_swap_events_.clear();
|
||||
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
|
||||
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
|
||||
auto span = iter->first;
|
||||
auto &event = iter->second;
|
||||
auto start_index = event->index - span + 1;
|
||||
bool revert = false;
|
||||
for (size_t i = start_index; i < event->index; ++i) {
|
||||
cur_mem_used[i] += event->mem_size;
|
||||
if (cur_mem_used[i] > mem_size_) {
|
||||
revert = true;
|
||||
}
|
||||
}
|
||||
if (revert) {
|
||||
for (size_t i = start_index; i < event->index; ++i) {
|
||||
cur_mem_used[i] -= event->mem_size;
|
||||
}
|
||||
} else {
|
||||
(void)no_swap_events_.emplace(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MemOffloadStrategy::GenComputeMemEvents() {
|
||||
pre_compute_events_.clear();
|
||||
post_compute_events_.clear();
|
||||
pre_compute_events_.resize(total_step_);
|
||||
post_compute_events_.resize(total_step_);
|
||||
for (auto &item : mem_events_) {
|
||||
auto &mem_events = item.second;
|
||||
if (mem_events.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto first_event = mem_events[0];
|
||||
MS_EXCEPTION_IF_NULL(first_event);
|
||||
if (first_event->type == kInit) {
|
||||
if (mem_events.size() > 1) {
|
||||
auto &second_event = mem_events[1];
|
||||
MS_EXCEPTION_IF_NULL(second_event);
|
||||
first_event->index = second_event->index;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if ((first_event->type == kInit || first_event->type == kMalloc) &&
|
||||
first_event->index < pre_compute_events_.size()) {
|
||||
pre_compute_events_[first_event->index].emplace_back(first_event);
|
||||
} else {
|
||||
MS_LOG_EXCEPTION << "First event should be init or malloc!";
|
||||
}
|
||||
MemPriority priority = kMemPriorityLow;
|
||||
auto iter = mem_priority_.find(first_event->key);
|
||||
if (iter != mem_priority_.end()) {
|
||||
priority = iter->second;
|
||||
}
|
||||
size_t pre_index = first_event->index;
|
||||
for (size_t i = 1; i < mem_events.size(); ++i) {
|
||||
auto &event = mem_events[i];
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
|
||||
no_swap_events_.find(event) == no_swap_events_.end()) {
|
||||
auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
|
||||
swap_out_event->key = item.first;
|
||||
swap_out_event->mem_size = first_event->mem_size;
|
||||
post_compute_events_[pre_index].emplace_back(swap_out_event);
|
||||
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
|
||||
swap_in_event->key = item.first;
|
||||
swap_in_event->mem_size = first_event->mem_size;
|
||||
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
|
||||
}
|
||||
if (event->index < pre_compute_events_.size()) {
|
||||
(void)pre_compute_events_[event->index].emplace_back(event);
|
||||
}
|
||||
pre_index = event->index;
|
||||
}
|
||||
if (priority != kMemPriorityLow) {
|
||||
continue;
|
||||
}
|
||||
auto &last_event = mem_events[mem_events.size() - 1];
|
||||
MS_EXCEPTION_IF_NULL(last_event);
|
||||
auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
|
||||
free_event->key = item.first;
|
||||
if (last_event->index < post_compute_events_.size()) {
|
||||
(void)post_compute_events_[last_event->index].emplace_back(free_event);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
|
||||
|
||||
enum MemEventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
|
||||
|
||||
struct MemEvent {
|
||||
MemEvent(const MemEventType &in_type, size_t in_index) : type(in_type), index(in_index) {}
|
||||
|
||||
MemEventType type;
|
||||
size_t index{0};
|
||||
size_t mem_size{0};
|
||||
const void *key{nullptr};
|
||||
};
|
||||
|
||||
class MemOffloadStrategy {
|
||||
public:
|
||||
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
|
||||
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
|
||||
size_t total_step)
|
||||
: mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {}
|
||||
|
||||
virtual ~MemOffloadStrategy() = default;
|
||||
|
||||
virtual void Execute();
|
||||
|
||||
void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; }
|
||||
|
||||
std::vector<std::shared_ptr<MemEvent>> &GetPreComputeEvents(size_t step);
|
||||
|
||||
std::vector<std::shared_ptr<MemEvent>> &GetPostComputeEvents(size_t step);
|
||||
|
||||
void set_mem_size(size_t mem_size) { mem_size_ = mem_size; }
|
||||
|
||||
bool need_swap() const { return need_swap_; }
|
||||
|
||||
private:
|
||||
void CountMemUsage();
|
||||
void CheckMemSize();
|
||||
void GenEventSpan();
|
||||
void GenNoSwapEventSet();
|
||||
void GenComputeMemEvents();
|
||||
|
||||
const std::map<const void *, MemPriority> &mem_priority_;
|
||||
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
|
||||
const size_t total_step_;
|
||||
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
|
||||
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
|
||||
|
||||
size_t mem_size_{0};
|
||||
std::vector<double> compute_time_;
|
||||
bool need_swap_{false};
|
||||
std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_;
|
||||
std::set<std::shared_ptr<MemEvent>> no_swap_events_;
|
||||
std::vector<size_t> min_mem_used_;
|
||||
size_t mem_used_without_swap_{0};
|
||||
size_t min_mem_needed_{0};
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
|
|
@ -17,9 +17,30 @@
|
|||
#include "runtime/device/memory_scheduler.h"
|
||||
#include <algorithm>
|
||||
#include "utils/log_adapter.h"
|
||||
#ifdef _MSC_VER
|
||||
#include <time.h>
|
||||
#else
|
||||
#include <sys/time.h>
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace {
|
||||
constexpr float kMaxMemReuseFactor = 0.9;
|
||||
constexpr float kMinMemReuseFactor = 0.5;
|
||||
constexpr float kRetryFactor = 0.1;
|
||||
|
||||
double GetCurrentTime() {
|
||||
#ifdef _MSC_VER
|
||||
return time(NULL) * 1.0e6;
|
||||
#else
|
||||
struct timeval tv;
|
||||
(void)gettimeofday(&tv, nullptr);
|
||||
return tv.tv_sec * 1.0e6 + tv.tv_usec;
|
||||
#endif
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void MemScheduler::Clear() {
|
||||
if (mem_handler_ == nullptr) {
|
||||
return;
|
||||
|
@ -40,23 +61,26 @@ bool MemScheduler::IsHighPriorityMem(const void *key) {
|
|||
|
||||
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
|
||||
|
||||
void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
|
||||
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
|
||||
if (key == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto event = std::make_shared<Event>(event_type, compute_index_);
|
||||
auto event = std::make_shared<MemEvent>(event_type, current_step_);
|
||||
event->mem_size = mem_size;
|
||||
event->key = key;
|
||||
mem_events_[key].emplace_back(event);
|
||||
if (step_events_.size() < current_step_ + 1) {
|
||||
step_events_.resize(current_step_ + 1);
|
||||
}
|
||||
step_events_[current_step_].emplace_back(event);
|
||||
}
|
||||
|
||||
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
|
||||
if (need_record_event_) {
|
||||
mem_priority_[key] = priority;
|
||||
Record(key, kInit, mem_size);
|
||||
} else {
|
||||
init_host_ptr_[key] = host_ptr;
|
||||
}
|
||||
init_host_ptr_[key] = host_ptr;
|
||||
}
|
||||
|
||||
void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
|
||||
|
@ -69,7 +93,7 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (!has_compute_mem_events_) {
|
||||
if (strategy_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto iter = mem_result_.find(key);
|
||||
|
@ -83,18 +107,14 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
|
|||
}
|
||||
|
||||
bool MemScheduler::PreCompute(void *stream) {
|
||||
if (!has_compute_mem_events_) {
|
||||
if (strategy_ == nullptr) {
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
||||
if (pre_compute_events_.size() <= compute_index_) {
|
||||
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
|
||||
<< ", event size:" << pre_compute_events_.size();
|
||||
}
|
||||
auto &events = pre_compute_events_[compute_index_];
|
||||
auto &events = strategy_->GetPreComputeEvents(current_step_);
|
||||
for (auto &event : events) {
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
|
||||
MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
|
||||
if (event->type == kInit || event->type == kMalloc) {
|
||||
auto priority = mem_priority_[event->key];
|
||||
auto iter = high_priority_device_ptr_.find(event->key);
|
||||
|
@ -137,23 +157,27 @@ bool MemScheduler::PreCompute(void *stream) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (record_compute_time_ && !updated_) {
|
||||
compute_start_time_ = GetCurrentTime();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MemScheduler::PostCompute(void *stream) {
|
||||
if (!has_compute_mem_events_) {
|
||||
++compute_index_;
|
||||
if (strategy_ == nullptr) {
|
||||
++current_step_;
|
||||
return true;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
||||
if (post_compute_events_.size() <= compute_index_) {
|
||||
MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
|
||||
<< ", event size:" << post_compute_events_.size();
|
||||
|
||||
if (record_compute_time_ && !updated_) {
|
||||
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
|
||||
}
|
||||
auto &events = post_compute_events_[compute_index_];
|
||||
|
||||
auto &events = strategy_->GetPostComputeEvents(current_step_);
|
||||
for (auto &event : events) {
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
|
||||
MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
|
||||
if (event->type == kFree) {
|
||||
auto ptr = mem_result_[event->key];
|
||||
if (ptr == nullptr) {
|
||||
|
@ -174,201 +198,81 @@ bool MemScheduler::PostCompute(void *stream) {
|
|||
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
|
||||
mem_handler_->FreeDevice(device_ptr);
|
||||
(void)mem_result_.erase(device_ptr);
|
||||
(void)mem_result_.erase(event->key);
|
||||
}
|
||||
}
|
||||
++compute_index_;
|
||||
++current_step_;
|
||||
return true;
|
||||
}
|
||||
|
||||
void MemScheduler::OptMemUsage() {
|
||||
if (optimized_) {
|
||||
return;
|
||||
}
|
||||
CountMemUsage();
|
||||
CheckMemSize();
|
||||
if (need_swap_) {
|
||||
GenEventSpan();
|
||||
GenNoSwapEventSet();
|
||||
}
|
||||
GenComputeMemEvents();
|
||||
has_compute_mem_events_ = true;
|
||||
}
|
||||
|
||||
void MemScheduler::CountMemUsage() {
|
||||
if (!min_mem_used_.empty()) {
|
||||
return;
|
||||
}
|
||||
if (mem_events_.empty() || compute_index_ == 0) {
|
||||
return;
|
||||
}
|
||||
min_mem_used_.resize(compute_index_, 0);
|
||||
std::vector<size_t> total_mem_used(compute_index_, 0);
|
||||
for (auto &item : mem_events_) {
|
||||
auto &mem_events = item.second;
|
||||
if (mem_events.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto first_event = mem_events[0];
|
||||
size_t cur_index = 0;
|
||||
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
|
||||
first_event = mem_events[1];
|
||||
cur_index = 1;
|
||||
}
|
||||
auto last_event = mem_events[mem_events.size() - 1];
|
||||
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
|
||||
if (start_index < compute_index_) {
|
||||
total_mem_used[start_index] += first_event->mem_size;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error mem event index " << start_index;
|
||||
}
|
||||
}
|
||||
for (; cur_index < mem_events.size(); ++cur_index) {
|
||||
auto &event = mem_events[cur_index];
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
if (event->index < compute_index_) {
|
||||
min_mem_used_[event->index] += first_event->mem_size;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error mem event index " << event->index;
|
||||
}
|
||||
}
|
||||
}
|
||||
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
|
||||
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
|
||||
}
|
||||
|
||||
void MemScheduler::CheckMemSize() {
|
||||
void MemScheduler::OptMemUsage(float mem_used_factor) {
|
||||
mem_used_factor_ = mem_used_factor;
|
||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
||||
|
||||
if (strategy_ == nullptr) {
|
||||
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
|
||||
compute_time_.resize(total_step_);
|
||||
}
|
||||
|
||||
auto available_mem_size = mem_handler_->GetAvailableMemSize();
|
||||
if (available_mem_size < min_mem_needed_) {
|
||||
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
|
||||
<< " while graph needs at least " << min_mem_needed_;
|
||||
}
|
||||
if (mem_used_without_swap_ > available_mem_size) {
|
||||
need_swap_ = true;
|
||||
}
|
||||
MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size: " << mem_used_without_swap_
|
||||
<< " without swap, and needs at least " << min_mem_needed_ << " with swap.";
|
||||
available_mem_size = available_mem_size * mem_used_factor_;
|
||||
strategy_->set_mem_size(available_mem_size);
|
||||
strategy_->Execute();
|
||||
}
|
||||
|
||||
void MemScheduler::GenEventSpan() {
|
||||
if (!event_span_.empty()) {
|
||||
void MemScheduler::Optimize() {
|
||||
float mem_used_factor = kMaxMemReuseFactor;
|
||||
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
|
||||
OptMemUsage(mem_used_factor);
|
||||
current_step_ = 0;
|
||||
bool ret = true;
|
||||
for (size_t step = 0; step < total_step_; ++step) {
|
||||
ret = PreCompute(nullptr);
|
||||
auto &step_events = step_events_[step];
|
||||
for (auto &event : step_events) {
|
||||
if (event->type != kGet) {
|
||||
continue;
|
||||
}
|
||||
auto ptr = GetOrMalloc(event->key, event->mem_size);
|
||||
if (ptr == nullptr) {
|
||||
ret = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!ret) {
|
||||
break;
|
||||
}
|
||||
PostCompute(nullptr);
|
||||
}
|
||||
if (ret) {
|
||||
optimized_ = true;
|
||||
} else {
|
||||
mem_used_factor -= kRetryFactor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MemScheduler::Update() {
|
||||
if (!optimized_) {
|
||||
return;
|
||||
}
|
||||
for (auto &item : mem_events_) {
|
||||
auto &tensor_events = item.second;
|
||||
if (tensor_events.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto first_event = tensor_events[0];
|
||||
size_t cur_index = 0;
|
||||
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
|
||||
first_event = tensor_events[1];
|
||||
cur_index = 1;
|
||||
}
|
||||
size_t last_index = first_event->index;
|
||||
for (; cur_index < tensor_events.size(); ++cur_index) {
|
||||
auto &event = tensor_events[cur_index];
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
auto span = event->index - last_index;
|
||||
if (span > 1) {
|
||||
(void)event_span_.emplace(std::pair<size_t, std::shared_ptr<Event>>(span, event));
|
||||
}
|
||||
last_index = event->index;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MemScheduler::GenNoSwapEventSet() {
|
||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
||||
auto available_mem_size = mem_handler_->GetAvailableMemSize();
|
||||
auto threshold = available_mem_size * mem_used_factor_;
|
||||
no_swap_events_.clear();
|
||||
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
|
||||
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
|
||||
auto span = iter->first;
|
||||
auto &event = iter->second;
|
||||
auto start_index = event->index - span + 1;
|
||||
bool revert = false;
|
||||
for (size_t i = start_index; i < event->index; ++i) {
|
||||
cur_mem_used[i] += event->mem_size;
|
||||
if (cur_mem_used[i] > threshold) {
|
||||
revert = true;
|
||||
}
|
||||
}
|
||||
if (revert) {
|
||||
for (size_t i = start_index; i < event->index; ++i) {
|
||||
cur_mem_used[i] -= event->mem_size;
|
||||
}
|
||||
} else {
|
||||
(void)no_swap_events_.emplace(event);
|
||||
}
|
||||
if (strategy_ == nullptr || !strategy_->need_swap()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void MemScheduler::GenComputeMemEvents() {
|
||||
pre_compute_events_.clear();
|
||||
post_compute_events_.clear();
|
||||
pre_compute_events_.resize(compute_index_);
|
||||
post_compute_events_.resize(compute_index_);
|
||||
for (auto &item : mem_events_) {
|
||||
auto &mem_events = item.second;
|
||||
if (mem_events.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto first_event = mem_events[0];
|
||||
MS_EXCEPTION_IF_NULL(first_event);
|
||||
if (first_event->type == kInit) {
|
||||
if (mem_events.size() > 1) {
|
||||
auto &second_event = mem_events[1];
|
||||
MS_EXCEPTION_IF_NULL(second_event);
|
||||
first_event->index = second_event->index;
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if ((first_event->type == kInit || first_event->type == kMalloc) &&
|
||||
first_event->index < pre_compute_events_.size()) {
|
||||
pre_compute_events_[first_event->index].emplace_back(first_event);
|
||||
} else {
|
||||
MS_LOG_EXCEPTION << "First event should be init or malloc!";
|
||||
}
|
||||
MemPriority priority = kMemPriorityLow;
|
||||
auto iter = mem_priority_.find(first_event->key);
|
||||
if (iter != mem_priority_.end()) {
|
||||
priority = iter->second;
|
||||
}
|
||||
size_t pre_index = first_event->index;
|
||||
for (size_t i = 1; i < mem_events.size(); ++i) {
|
||||
auto &event = mem_events[i];
|
||||
MS_EXCEPTION_IF_NULL(event);
|
||||
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
|
||||
no_swap_events_.find(event) == no_swap_events_.end()) {
|
||||
auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
|
||||
swap_out_event->key = item.first;
|
||||
swap_out_event->mem_size = first_event->mem_size;
|
||||
post_compute_events_[pre_index].emplace_back(swap_out_event);
|
||||
auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
|
||||
swap_in_event->key = item.first;
|
||||
swap_in_event->mem_size = first_event->mem_size;
|
||||
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
|
||||
}
|
||||
if (event->index < pre_compute_events_.size()) {
|
||||
(void)pre_compute_events_[event->index].emplace_back(event);
|
||||
}
|
||||
pre_index = event->index;
|
||||
}
|
||||
if (priority != kMemPriorityLow) {
|
||||
continue;
|
||||
}
|
||||
auto &last_event = mem_events[mem_events.size() - 1];
|
||||
MS_EXCEPTION_IF_NULL(last_event);
|
||||
auto free_event = std::make_shared<Event>(kFree, last_event->index);
|
||||
free_event->key = item.first;
|
||||
if (last_event->index < post_compute_events_.size()) {
|
||||
(void)post_compute_events_[last_event->index].emplace_back(free_event);
|
||||
}
|
||||
if (updated_) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!record_compute_time_) {
|
||||
record_compute_time_ = true;
|
||||
return;
|
||||
}
|
||||
|
||||
strategy_->SetComputeTime(compute_time_);
|
||||
strategy_->Execute();
|
||||
updated_ = true;
|
||||
}
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <set>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "runtime/device/memory_offload_strategy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -35,23 +36,7 @@ class MemHandler {
|
|||
virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
|
||||
};
|
||||
|
||||
enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
|
||||
|
||||
class MemScheduler {
|
||||
enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
|
||||
|
||||
struct Event {
|
||||
Event(const EventType &in_type, size_t in_index) {
|
||||
type = in_type;
|
||||
index = in_index;
|
||||
}
|
||||
|
||||
EventType type;
|
||||
size_t index{0};
|
||||
size_t mem_size{0};
|
||||
const void *key{nullptr};
|
||||
};
|
||||
|
||||
public:
|
||||
MemScheduler() = default;
|
||||
~MemScheduler() = default;
|
||||
|
@ -62,7 +47,7 @@ class MemScheduler {
|
|||
|
||||
bool optimized() const { return optimized_; }
|
||||
|
||||
void set_optimized(bool flag) { optimized_ = flag; }
|
||||
void Update();
|
||||
|
||||
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
|
||||
|
||||
|
@ -70,13 +55,18 @@ class MemScheduler {
|
|||
|
||||
void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
|
||||
|
||||
void Reset() { compute_index_ = 0; }
|
||||
void SetTotalStep(size_t step) {
|
||||
total_step_ = step;
|
||||
step_events_.resize(total_step_);
|
||||
}
|
||||
|
||||
void ResetCurrentStep() { current_step_ = 0; }
|
||||
|
||||
bool PreCompute(void *stream);
|
||||
|
||||
bool PostCompute(void *stream);
|
||||
|
||||
void OptMemUsage();
|
||||
void Optimize();
|
||||
|
||||
void Clear();
|
||||
|
||||
|
@ -84,37 +74,29 @@ class MemScheduler {
|
|||
|
||||
void SetMemPriority(const void *key, MemPriority priority);
|
||||
|
||||
void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }
|
||||
|
||||
void SetNeedSwap(bool flag) { need_swap_ = flag; }
|
||||
|
||||
private:
|
||||
void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
|
||||
void GenComputeMemEvents();
|
||||
void CheckMemSize();
|
||||
void CountMemUsage();
|
||||
void GenEventSpan();
|
||||
void GenNoSwapEventSet();
|
||||
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
|
||||
|
||||
void OptMemUsage(float mem_used_factor = 1.0f);
|
||||
|
||||
std::map<const void *, MemPriority> mem_priority_;
|
||||
std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
|
||||
std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
|
||||
std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
|
||||
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
|
||||
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
|
||||
std::map<const void *, void *> mem_result_;
|
||||
std::map<const void *, void *> init_host_ptr_;
|
||||
std::map<const void *, void *> swap_host_ptr_;
|
||||
std::map<const void *, void *> high_priority_device_ptr_;
|
||||
size_t compute_index_{0};
|
||||
size_t total_step_{0};
|
||||
size_t current_step_{0};
|
||||
bool need_record_event_{true};
|
||||
bool optimized_{false};
|
||||
bool has_compute_mem_events_{false};
|
||||
std::shared_ptr<MemHandler> mem_handler_{nullptr};
|
||||
bool need_swap_{false};
|
||||
std::multimap<size_t, std::shared_ptr<Event>> event_span_;
|
||||
std::set<std::shared_ptr<Event>> no_swap_events_;
|
||||
std::vector<size_t> min_mem_used_;
|
||||
size_t mem_used_without_swap_{0};
|
||||
size_t min_mem_needed_{0};
|
||||
float mem_used_factor_{0.9};
|
||||
double compute_start_time_{0};
|
||||
std::vector<double> compute_time_;
|
||||
bool record_compute_time_{false};
|
||||
bool updated_{false};
|
||||
std::shared_ptr<MemHandler> mem_handler_{nullptr};
|
||||
std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
|
||||
};
|
||||
|
||||
class MemSchedulerManager {
|
||||
|
|
|
@ -114,6 +114,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/memory_scheduler.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/memory_offload_strategy.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
|
||||
|
|
|
@ -104,6 +104,7 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
|
|||
std::vector<size_t> init_tensors = {0, 2, 4};
|
||||
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
|
||||
void *stream = nullptr;
|
||||
scheduler->SetTotalStep(kTimeSlice);
|
||||
// record
|
||||
for (auto index : init_tensors) {
|
||||
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
||||
|
@ -118,11 +119,10 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
|
|||
scheduler->set_need_record_event(false);
|
||||
|
||||
// optimize
|
||||
scheduler->OptMemUsage();
|
||||
scheduler->set_optimized(true);
|
||||
scheduler->Optimize();
|
||||
|
||||
// run
|
||||
scheduler->Reset();
|
||||
scheduler->ResetCurrentStep();
|
||||
for (auto index : init_tensors) {
|
||||
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue