!26402 add mem offload strategy

Merge pull request !26402 from kisnwang/add-mem-offload-strategy
This commit is contained in:
i-robot 2021-11-19 03:42:52 +00:00 committed by Gitee
commit 55463892e4
8 changed files with 450 additions and 262 deletions

View File

@ -1,7 +1,7 @@
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
"bucket.cc" "launch_kernel.cc" "launch_mul.cc"
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc"
)
if("${ENABLE_HIDDEN}" STREQUAL "OFF")

View File

@ -42,9 +42,6 @@ using mindspore::kernel::AddressPtr;
namespace mindspore {
namespace device {
constexpr float kMaxMemReuseFactor = 0.8;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
constexpr size_t kAtomicCleanInputSize = 2;
namespace {
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
@ -1493,13 +1490,15 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
if (UseMemScheduler()) {
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->SetMemHandler(mem_manager_);
mem_scheduler->Reset();
mem_scheduler->ResetCurrentStep();
mem_scheduler->Update();
InitGraphInputTensors(mem_scheduler, graph);
}
const auto &kernels = graph.execution_order();
std::vector<DynamicKernelPtr> dynamic_kernel_list;
auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
@ -1554,7 +1553,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
}
LaunchKernelEvent(kernel_post_run_events, kernels[i]);
}
return true;
}
@ -1565,21 +1563,15 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
return;
}
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->SetMemHandler(mem_manager_);
mem_scheduler->SetTotalStep(graph.execution_order().size());
if (mem_scheduler->need_record_event()) {
(void)LaunchKernelMod(graph, true);
mem_scheduler->set_need_record_event(false);
}
float mem_used_factor = kMaxMemReuseFactor;
while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
mem_scheduler->SetMemUsedFactor(mem_used_factor);
mem_scheduler->OptMemUsage();
bool ret = LaunchKernelMod(graph, true);
if (ret) {
mem_scheduler->set_optimized(true);
} else {
mem_used_factor -= kRetryFactor;
}
}
mem_scheduler->Optimize();
if (!mem_scheduler->optimized()) {
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
}

View File

@ -0,0 +1,224 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/device/memory_offload_strategy.h"
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
if (pre_compute_events_.size() <= step) {
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
}
return pre_compute_events_[step];
}
std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
if (post_compute_events_.size() <= step) {
MS_LOG_EXCEPTION << "Index out of post event range, index:" << step
<< ", event size:" << post_compute_events_.size();
}
return post_compute_events_[step];
}
void MemOffloadStrategy::Execute() {
CountMemUsage();
CheckMemSize();
if (need_swap_) {
GenEventSpan();
GenNoSwapEventSet();
}
GenComputeMemEvents();
}
void MemOffloadStrategy::CountMemUsage() {
if (!min_mem_used_.empty()) {
return;
}
if (mem_events_.empty() || total_step_ == 0) {
return;
}
min_mem_used_.resize(total_step_, 0);
std::vector<size_t> total_mem_used(total_step_, 0);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
cur_index = 1;
}
auto last_event = mem_events[mem_events.size() - 1];
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
if (start_index < total_step_) {
total_mem_used[start_index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << start_index;
}
}
for (; cur_index < mem_events.size(); ++cur_index) {
auto &event = mem_events[cur_index];
MS_EXCEPTION_IF_NULL(event);
if (event->index < total_step_) {
min_mem_used_[event->index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << event->index;
}
}
}
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
}
void MemOffloadStrategy::CheckMemSize() {
if (mem_size_ < min_mem_needed_) {
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
<< min_mem_needed_;
}
if (mem_size_ < mem_used_without_swap_) {
need_swap_ = true;
}
MS_LOG(INFO) << "Available mem size: " << mem_size_ << ", graph needs mem size: " << mem_used_without_swap_
<< " without swap, and needs at least " << min_mem_needed_ << " with swap.";
}
void MemOffloadStrategy::GenEventSpan() {
if (!event_span_.empty()) {
return;
}
for (auto &item : mem_events_) {
auto &tensor_events = item.second;
if (tensor_events.empty()) {
continue;
}
auto first_event = tensor_events[0];
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
first_event = tensor_events[1];
cur_index = 1;
}
size_t last_index = first_event->index;
for (; cur_index < tensor_events.size(); ++cur_index) {
auto &event = tensor_events[cur_index];
MS_EXCEPTION_IF_NULL(event);
auto span = event->index - last_index;
if (span > 1) {
(void)event_span_.emplace(span, event);
}
last_index = event->index;
}
}
}
void MemOffloadStrategy::GenNoSwapEventSet() {
no_swap_events_.clear();
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
auto span = iter->first;
auto &event = iter->second;
auto start_index = event->index - span + 1;
bool revert = false;
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] += event->mem_size;
if (cur_mem_used[i] > mem_size_) {
revert = true;
}
}
if (revert) {
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] -= event->mem_size;
}
} else {
(void)no_swap_events_.emplace(event);
}
}
}
void MemOffloadStrategy::GenComputeMemEvents() {
pre_compute_events_.clear();
post_compute_events_.clear();
pre_compute_events_.resize(total_step_);
post_compute_events_.resize(total_step_);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
if (first_event->type == kInit) {
if (mem_events.size() > 1) {
auto &second_event = mem_events[1];
MS_EXCEPTION_IF_NULL(second_event);
first_event->index = second_event->index;
} else {
continue;
}
}
if ((first_event->type == kInit || first_event->type == kMalloc) &&
first_event->index < pre_compute_events_.size()) {
pre_compute_events_[first_event->index].emplace_back(first_event);
} else {
MS_LOG_EXCEPTION << "First event should be init or malloc!";
}
MemPriority priority = kMemPriorityLow;
auto iter = mem_priority_.find(first_event->key);
if (iter != mem_priority_.end()) {
priority = iter->second;
}
size_t pre_index = first_event->index;
for (size_t i = 1; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
no_swap_events_.find(event) == no_swap_events_.end()) {
auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size;
post_compute_events_[pre_index].emplace_back(swap_out_event);
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
}
if (event->index < pre_compute_events_.size()) {
(void)pre_compute_events_[event->index].emplace_back(event);
}
pre_index = event->index;
}
if (priority != kMemPriorityLow) {
continue;
}
auto &last_event = mem_events[mem_events.size() - 1];
MS_EXCEPTION_IF_NULL(last_event);
auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
free_event->key = item.first;
if (last_event->index < post_compute_events_.size()) {
(void)post_compute_events_[last_event->index].emplace_back(free_event);
}
}
}
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>
namespace mindspore {
namespace device {
enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
enum MemEventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
struct MemEvent {
MemEvent(const MemEventType &in_type, size_t in_index) : type(in_type), index(in_index) {}
MemEventType type;
size_t index{0};
size_t mem_size{0};
const void *key{nullptr};
};
class MemOffloadStrategy {
public:
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
size_t total_step)
: mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {}
virtual ~MemOffloadStrategy() = default;
virtual void Execute();
void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; }
std::vector<std::shared_ptr<MemEvent>> &GetPreComputeEvents(size_t step);
std::vector<std::shared_ptr<MemEvent>> &GetPostComputeEvents(size_t step);
void set_mem_size(size_t mem_size) { mem_size_ = mem_size; }
bool need_swap() const { return need_swap_; }
private:
void CountMemUsage();
void CheckMemSize();
void GenEventSpan();
void GenNoSwapEventSet();
void GenComputeMemEvents();
const std::map<const void *, MemPriority> &mem_priority_;
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
const size_t total_step_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
size_t mem_size_{0};
std::vector<double> compute_time_;
bool need_swap_{false};
std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_;
std::set<std::shared_ptr<MemEvent>> no_swap_events_;
std::vector<size_t> min_mem_used_;
size_t mem_used_without_swap_{0};
size_t min_mem_needed_{0};
};
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_

View File

@ -17,9 +17,30 @@
#include "runtime/device/memory_scheduler.h"
#include <algorithm>
#include "utils/log_adapter.h"
#ifdef _MSC_VER
#include <time.h>
#else
#include <sys/time.h>
#endif
namespace mindspore {
namespace device {
namespace {
constexpr float kMaxMemReuseFactor = 0.9;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
double GetCurrentTime() {
#ifdef _MSC_VER
return time(NULL) * 1.0e6;
#else
struct timeval tv;
(void)gettimeofday(&tv, nullptr);
return tv.tv_sec * 1.0e6 + tv.tv_usec;
#endif
}
} // namespace
void MemScheduler::Clear() {
if (mem_handler_ == nullptr) {
return;
@ -40,23 +61,26 @@ bool MemScheduler::IsHighPriorityMem(const void *key) {
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
if (key == nullptr) {
return;
}
auto event = std::make_shared<Event>(event_type, compute_index_);
auto event = std::make_shared<MemEvent>(event_type, current_step_);
event->mem_size = mem_size;
event->key = key;
mem_events_[key].emplace_back(event);
if (step_events_.size() < current_step_ + 1) {
step_events_.resize(current_step_ + 1);
}
step_events_[current_step_].emplace_back(event);
}
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
if (need_record_event_) {
mem_priority_[key] = priority;
Record(key, kInit, mem_size);
} else {
init_host_ptr_[key] = host_ptr;
}
init_host_ptr_[key] = host_ptr;
}
void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
@ -69,7 +93,7 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
}
return nullptr;
}
if (!has_compute_mem_events_) {
if (strategy_ == nullptr) {
return nullptr;
}
auto iter = mem_result_.find(key);
@ -83,18 +107,14 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
}
bool MemScheduler::PreCompute(void *stream) {
if (!has_compute_mem_events_) {
if (strategy_ == nullptr) {
return true;
}
MS_EXCEPTION_IF_NULL(mem_handler_);
if (pre_compute_events_.size() <= compute_index_) {
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
<< ", event size:" << pre_compute_events_.size();
}
auto &events = pre_compute_events_[compute_index_];
auto &events = strategy_->GetPreComputeEvents(current_step_);
for (auto &event : events) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
if (event->type == kInit || event->type == kMalloc) {
auto priority = mem_priority_[event->key];
auto iter = high_priority_device_ptr_.find(event->key);
@ -137,23 +157,27 @@ bool MemScheduler::PreCompute(void *stream) {
}
}
}
if (record_compute_time_ && !updated_) {
compute_start_time_ = GetCurrentTime();
}
return true;
}
bool MemScheduler::PostCompute(void *stream) {
if (!has_compute_mem_events_) {
++compute_index_;
if (strategy_ == nullptr) {
++current_step_;
return true;
}
MS_EXCEPTION_IF_NULL(mem_handler_);
if (post_compute_events_.size() <= compute_index_) {
MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
<< ", event size:" << post_compute_events_.size();
if (record_compute_time_ && !updated_) {
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
}
auto &events = post_compute_events_[compute_index_];
auto &events = strategy_->GetPostComputeEvents(current_step_);
for (auto &event : events) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
if (event->type == kFree) {
auto ptr = mem_result_[event->key];
if (ptr == nullptr) {
@ -174,201 +198,81 @@ bool MemScheduler::PostCompute(void *stream) {
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
mem_handler_->FreeDevice(device_ptr);
(void)mem_result_.erase(device_ptr);
(void)mem_result_.erase(event->key);
}
}
++compute_index_;
++current_step_;
return true;
}
void MemScheduler::OptMemUsage() {
if (optimized_) {
return;
}
CountMemUsage();
CheckMemSize();
if (need_swap_) {
GenEventSpan();
GenNoSwapEventSet();
}
GenComputeMemEvents();
has_compute_mem_events_ = true;
}
void MemScheduler::CountMemUsage() {
if (!min_mem_used_.empty()) {
return;
}
if (mem_events_.empty() || compute_index_ == 0) {
return;
}
min_mem_used_.resize(compute_index_, 0);
std::vector<size_t> total_mem_used(compute_index_, 0);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
cur_index = 1;
}
auto last_event = mem_events[mem_events.size() - 1];
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
if (start_index < compute_index_) {
total_mem_used[start_index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << start_index;
}
}
for (; cur_index < mem_events.size(); ++cur_index) {
auto &event = mem_events[cur_index];
MS_EXCEPTION_IF_NULL(event);
if (event->index < compute_index_) {
min_mem_used_[event->index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << event->index;
}
}
}
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
}
void MemScheduler::CheckMemSize() {
void MemScheduler::OptMemUsage(float mem_used_factor) {
mem_used_factor_ = mem_used_factor;
MS_EXCEPTION_IF_NULL(mem_handler_);
if (strategy_ == nullptr) {
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
compute_time_.resize(total_step_);
}
auto available_mem_size = mem_handler_->GetAvailableMemSize();
if (available_mem_size < min_mem_needed_) {
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
<< " while graph needs at least " << min_mem_needed_;
}
if (mem_used_without_swap_ > available_mem_size) {
need_swap_ = true;
}
MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size: " << mem_used_without_swap_
<< " without swap, and needs at least " << min_mem_needed_ << " with swap.";
available_mem_size = available_mem_size * mem_used_factor_;
strategy_->set_mem_size(available_mem_size);
strategy_->Execute();
}
void MemScheduler::GenEventSpan() {
if (!event_span_.empty()) {
void MemScheduler::Optimize() {
float mem_used_factor = kMaxMemReuseFactor;
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
OptMemUsage(mem_used_factor);
current_step_ = 0;
bool ret = true;
for (size_t step = 0; step < total_step_; ++step) {
ret = PreCompute(nullptr);
auto &step_events = step_events_[step];
for (auto &event : step_events) {
if (event->type != kGet) {
continue;
}
auto ptr = GetOrMalloc(event->key, event->mem_size);
if (ptr == nullptr) {
ret = false;
break;
}
}
if (!ret) {
break;
}
PostCompute(nullptr);
}
if (ret) {
optimized_ = true;
} else {
mem_used_factor -= kRetryFactor;
}
}
}
void MemScheduler::Update() {
if (!optimized_) {
return;
}
for (auto &item : mem_events_) {
auto &tensor_events = item.second;
if (tensor_events.empty()) {
continue;
}
auto first_event = tensor_events[0];
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
first_event = tensor_events[1];
cur_index = 1;
}
size_t last_index = first_event->index;
for (; cur_index < tensor_events.size(); ++cur_index) {
auto &event = tensor_events[cur_index];
MS_EXCEPTION_IF_NULL(event);
auto span = event->index - last_index;
if (span > 1) {
(void)event_span_.emplace(std::pair<size_t, std::shared_ptr<Event>>(span, event));
}
last_index = event->index;
}
}
}
void MemScheduler::GenNoSwapEventSet() {
MS_EXCEPTION_IF_NULL(mem_handler_);
auto available_mem_size = mem_handler_->GetAvailableMemSize();
auto threshold = available_mem_size * mem_used_factor_;
no_swap_events_.clear();
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
auto span = iter->first;
auto &event = iter->second;
auto start_index = event->index - span + 1;
bool revert = false;
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] += event->mem_size;
if (cur_mem_used[i] > threshold) {
revert = true;
}
}
if (revert) {
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] -= event->mem_size;
}
} else {
(void)no_swap_events_.emplace(event);
}
if (strategy_ == nullptr || !strategy_->need_swap()) {
return;
}
}
void MemScheduler::GenComputeMemEvents() {
pre_compute_events_.clear();
post_compute_events_.clear();
pre_compute_events_.resize(compute_index_);
post_compute_events_.resize(compute_index_);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
if (first_event->type == kInit) {
if (mem_events.size() > 1) {
auto &second_event = mem_events[1];
MS_EXCEPTION_IF_NULL(second_event);
first_event->index = second_event->index;
} else {
continue;
}
}
if ((first_event->type == kInit || first_event->type == kMalloc) &&
first_event->index < pre_compute_events_.size()) {
pre_compute_events_[first_event->index].emplace_back(first_event);
} else {
MS_LOG_EXCEPTION << "First event should be init or malloc!";
}
MemPriority priority = kMemPriorityLow;
auto iter = mem_priority_.find(first_event->key);
if (iter != mem_priority_.end()) {
priority = iter->second;
}
size_t pre_index = first_event->index;
for (size_t i = 1; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
no_swap_events_.find(event) == no_swap_events_.end()) {
auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size;
post_compute_events_[pre_index].emplace_back(swap_out_event);
auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
}
if (event->index < pre_compute_events_.size()) {
(void)pre_compute_events_[event->index].emplace_back(event);
}
pre_index = event->index;
}
if (priority != kMemPriorityLow) {
continue;
}
auto &last_event = mem_events[mem_events.size() - 1];
MS_EXCEPTION_IF_NULL(last_event);
auto free_event = std::make_shared<Event>(kFree, last_event->index);
free_event->key = item.first;
if (last_event->index < post_compute_events_.size()) {
(void)post_compute_events_[last_event->index].emplace_back(free_event);
}
if (updated_) {
return;
}
if (!record_compute_time_) {
record_compute_time_ = true;
return;
}
strategy_->SetComputeTime(compute_time_);
strategy_->Execute();
updated_ = true;
}
} // namespace device
} // namespace mindspore

View File

@ -21,6 +21,7 @@
#include <set>
#include <memory>
#include <utility>
#include "runtime/device/memory_offload_strategy.h"
namespace mindspore {
namespace device {
@ -35,23 +36,7 @@ class MemHandler {
virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};
enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
class MemScheduler {
enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
struct Event {
Event(const EventType &in_type, size_t in_index) {
type = in_type;
index = in_index;
}
EventType type;
size_t index{0};
size_t mem_size{0};
const void *key{nullptr};
};
public:
MemScheduler() = default;
~MemScheduler() = default;
@ -62,7 +47,7 @@ class MemScheduler {
bool optimized() const { return optimized_; }
void set_optimized(bool flag) { optimized_ = flag; }
void Update();
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
@ -70,13 +55,18 @@ class MemScheduler {
void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
void Reset() { compute_index_ = 0; }
void SetTotalStep(size_t step) {
total_step_ = step;
step_events_.resize(total_step_);
}
void ResetCurrentStep() { current_step_ = 0; }
bool PreCompute(void *stream);
bool PostCompute(void *stream);
void OptMemUsage();
void Optimize();
void Clear();
@ -84,37 +74,29 @@ class MemScheduler {
void SetMemPriority(const void *key, MemPriority priority);
void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }
void SetNeedSwap(bool flag) { need_swap_ = flag; }
private:
void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
void GenComputeMemEvents();
void CheckMemSize();
void CountMemUsage();
void GenEventSpan();
void GenNoSwapEventSet();
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
void OptMemUsage(float mem_used_factor = 1.0f);
std::map<const void *, MemPriority> mem_priority_;
std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
std::map<const void *, void *> mem_result_;
std::map<const void *, void *> init_host_ptr_;
std::map<const void *, void *> swap_host_ptr_;
std::map<const void *, void *> high_priority_device_ptr_;
size_t compute_index_{0};
size_t total_step_{0};
size_t current_step_{0};
bool need_record_event_{true};
bool optimized_{false};
bool has_compute_mem_events_{false};
std::shared_ptr<MemHandler> mem_handler_{nullptr};
bool need_swap_{false};
std::multimap<size_t, std::shared_ptr<Event>> event_span_;
std::set<std::shared_ptr<Event>> no_swap_events_;
std::vector<size_t> min_mem_used_;
size_t mem_used_without_swap_{0};
size_t min_mem_needed_{0};
float mem_used_factor_{0.9};
double compute_start_time_{0};
std::vector<double> compute_time_;
bool record_compute_time_{false};
bool updated_{false};
std::shared_ptr<MemHandler> mem_handler_{nullptr};
std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
};
class MemSchedulerManager {

View File

@ -114,6 +114,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc"
"../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
"../../../mindspore/ccsrc/runtime/device/memory_scheduler.cc"
"../../../mindspore/ccsrc/runtime/device/memory_offload_strategy.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
"../../../mindspore/ccsrc/runtime/device/bucket.cc"

View File

@ -104,6 +104,7 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
void *stream = nullptr;
scheduler->SetTotalStep(kTimeSlice);
// record
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
@ -118,11 +119,10 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
scheduler->set_need_record_event(false);
// optimize
scheduler->OptMemUsage();
scheduler->set_optimized(true);
scheduler->Optimize();
// run
scheduler->Reset();
scheduler->ResetCurrentStep();
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
}