forked from mindspore-Ecosystem/mindspore
add mem offload strategy
This commit is contained in:
parent
9522ee9686
commit
7c312bd38c
|
@ -1,7 +1,7 @@
|
||||||
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
|
file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
|
||||||
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
|
"kernel_info.cc" "executor/dynamic_kernel.cc" "executor/executor_callback.cc" "kernel_runtime.cc"
|
||||||
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
|
"memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" "memory_scheduler.cc"
|
||||||
"bucket.cc" "launch_kernel.cc" "launch_mul.cc"
|
"memory_offload_strategy.cc" "bucket.cc" "launch_kernel.cc" "launch_mul.cc"
|
||||||
)
|
)
|
||||||
|
|
||||||
if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
if("${ENABLE_HIDDEN}" STREQUAL "OFF")
|
||||||
|
|
|
@ -42,9 +42,6 @@ using mindspore::kernel::AddressPtr;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
constexpr float kMaxMemReuseFactor = 0.8;
|
|
||||||
constexpr float kMinMemReuseFactor = 0.5;
|
|
||||||
constexpr float kRetryFactor = 0.1;
|
|
||||||
constexpr size_t kAtomicCleanInputSize = 2;
|
constexpr size_t kAtomicCleanInputSize = 2;
|
||||||
namespace {
|
namespace {
|
||||||
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
|
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
|
||||||
|
@ -1494,13 +1491,15 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
|
||||||
auto context_ptr = MsContext::GetInstance();
|
auto context_ptr = MsContext::GetInstance();
|
||||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||||
std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
|
std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
|
||||||
|
|
||||||
if (UseMemScheduler()) {
|
if (UseMemScheduler()) {
|
||||||
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
||||||
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
||||||
mem_scheduler->SetMemHandler(mem_manager_);
|
mem_scheduler->ResetCurrentStep();
|
||||||
mem_scheduler->Reset();
|
mem_scheduler->Update();
|
||||||
InitGraphInputTensors(mem_scheduler, graph);
|
InitGraphInputTensors(mem_scheduler, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto &kernels = graph.execution_order();
|
const auto &kernels = graph.execution_order();
|
||||||
std::vector<DynamicKernelPtr> dynamic_kernel_list;
|
std::vector<DynamicKernelPtr> dynamic_kernel_list;
|
||||||
auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
|
auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
|
||||||
|
@ -1555,7 +1554,6 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
|
||||||
}
|
}
|
||||||
LaunchKernelEvent(kernel_post_run_events, i);
|
LaunchKernelEvent(kernel_post_run_events, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1566,21 +1564,15 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
||||||
|
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
||||||
|
mem_scheduler->SetMemHandler(mem_manager_);
|
||||||
|
mem_scheduler->SetTotalStep(graph.execution_order().size());
|
||||||
|
|
||||||
if (mem_scheduler->need_record_event()) {
|
if (mem_scheduler->need_record_event()) {
|
||||||
(void)LaunchKernelMod(graph, true);
|
(void)LaunchKernelMod(graph, true);
|
||||||
mem_scheduler->set_need_record_event(false);
|
mem_scheduler->set_need_record_event(false);
|
||||||
}
|
}
|
||||||
float mem_used_factor = kMaxMemReuseFactor;
|
mem_scheduler->Optimize();
|
||||||
while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
|
|
||||||
mem_scheduler->SetMemUsedFactor(mem_used_factor);
|
|
||||||
mem_scheduler->OptMemUsage();
|
|
||||||
bool ret = LaunchKernelMod(graph, true);
|
|
||||||
if (ret) {
|
|
||||||
mem_scheduler->set_optimized(true);
|
|
||||||
} else {
|
|
||||||
mem_used_factor -= kRetryFactor;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!mem_scheduler->optimized()) {
|
if (!mem_scheduler->optimized()) {
|
||||||
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
|
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,224 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "runtime/device/memory_offload_strategy.h"
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
|
||||||
|
if (pre_compute_events_.size() <= step) {
|
||||||
|
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
|
||||||
|
}
|
||||||
|
return pre_compute_events_[step];
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
|
||||||
|
if (post_compute_events_.size() <= step) {
|
||||||
|
MS_LOG_EXCEPTION << "Index out of post event range, index:" << step
|
||||||
|
<< ", event size:" << post_compute_events_.size();
|
||||||
|
}
|
||||||
|
return post_compute_events_[step];
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemOffloadStrategy::Execute() {
|
||||||
|
CountMemUsage();
|
||||||
|
CheckMemSize();
|
||||||
|
if (need_swap_) {
|
||||||
|
GenEventSpan();
|
||||||
|
GenNoSwapEventSet();
|
||||||
|
}
|
||||||
|
GenComputeMemEvents();
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemOffloadStrategy::CountMemUsage() {
|
||||||
|
if (!min_mem_used_.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (mem_events_.empty() || total_step_ == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
min_mem_used_.resize(total_step_, 0);
|
||||||
|
std::vector<size_t> total_mem_used(total_step_, 0);
|
||||||
|
for (auto &item : mem_events_) {
|
||||||
|
auto &mem_events = item.second;
|
||||||
|
if (mem_events.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto first_event = mem_events[0];
|
||||||
|
size_t cur_index = 0;
|
||||||
|
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
|
||||||
|
first_event = mem_events[1];
|
||||||
|
cur_index = 1;
|
||||||
|
}
|
||||||
|
auto last_event = mem_events[mem_events.size() - 1];
|
||||||
|
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
|
||||||
|
if (start_index < total_step_) {
|
||||||
|
total_mem_used[start_index] += first_event->mem_size;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Error mem event index " << start_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (; cur_index < mem_events.size(); ++cur_index) {
|
||||||
|
auto &event = mem_events[cur_index];
|
||||||
|
MS_EXCEPTION_IF_NULL(event);
|
||||||
|
if (event->index < total_step_) {
|
||||||
|
min_mem_used_[event->index] += first_event->mem_size;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Error mem event index " << event->index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
|
||||||
|
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemOffloadStrategy::CheckMemSize() {
|
||||||
|
if (mem_size_ < min_mem_needed_) {
|
||||||
|
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
|
||||||
|
<< min_mem_needed_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mem_size_ < mem_used_without_swap_) {
|
||||||
|
need_swap_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Available mem size: " << mem_size_ << ", graph needs mem size: " << mem_used_without_swap_
|
||||||
|
<< " without swap, and needs at least " << min_mem_needed_ << " with swap.";
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemOffloadStrategy::GenEventSpan() {
|
||||||
|
if (!event_span_.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (auto &item : mem_events_) {
|
||||||
|
auto &tensor_events = item.second;
|
||||||
|
if (tensor_events.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto first_event = tensor_events[0];
|
||||||
|
size_t cur_index = 0;
|
||||||
|
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
|
||||||
|
first_event = tensor_events[1];
|
||||||
|
cur_index = 1;
|
||||||
|
}
|
||||||
|
size_t last_index = first_event->index;
|
||||||
|
for (; cur_index < tensor_events.size(); ++cur_index) {
|
||||||
|
auto &event = tensor_events[cur_index];
|
||||||
|
MS_EXCEPTION_IF_NULL(event);
|
||||||
|
auto span = event->index - last_index;
|
||||||
|
if (span > 1) {
|
||||||
|
(void)event_span_.emplace(span, event);
|
||||||
|
}
|
||||||
|
last_index = event->index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemOffloadStrategy::GenNoSwapEventSet() {
|
||||||
|
no_swap_events_.clear();
|
||||||
|
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
|
||||||
|
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
|
||||||
|
auto span = iter->first;
|
||||||
|
auto &event = iter->second;
|
||||||
|
auto start_index = event->index - span + 1;
|
||||||
|
bool revert = false;
|
||||||
|
for (size_t i = start_index; i < event->index; ++i) {
|
||||||
|
cur_mem_used[i] += event->mem_size;
|
||||||
|
if (cur_mem_used[i] > mem_size_) {
|
||||||
|
revert = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (revert) {
|
||||||
|
for (size_t i = start_index; i < event->index; ++i) {
|
||||||
|
cur_mem_used[i] -= event->mem_size;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(void)no_swap_events_.emplace(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemOffloadStrategy::GenComputeMemEvents() {
|
||||||
|
pre_compute_events_.clear();
|
||||||
|
post_compute_events_.clear();
|
||||||
|
pre_compute_events_.resize(total_step_);
|
||||||
|
post_compute_events_.resize(total_step_);
|
||||||
|
for (auto &item : mem_events_) {
|
||||||
|
auto &mem_events = item.second;
|
||||||
|
if (mem_events.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto first_event = mem_events[0];
|
||||||
|
MS_EXCEPTION_IF_NULL(first_event);
|
||||||
|
if (first_event->type == kInit) {
|
||||||
|
if (mem_events.size() > 1) {
|
||||||
|
auto &second_event = mem_events[1];
|
||||||
|
MS_EXCEPTION_IF_NULL(second_event);
|
||||||
|
first_event->index = second_event->index;
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ((first_event->type == kInit || first_event->type == kMalloc) &&
|
||||||
|
first_event->index < pre_compute_events_.size()) {
|
||||||
|
pre_compute_events_[first_event->index].emplace_back(first_event);
|
||||||
|
} else {
|
||||||
|
MS_LOG_EXCEPTION << "First event should be init or malloc!";
|
||||||
|
}
|
||||||
|
MemPriority priority = kMemPriorityLow;
|
||||||
|
auto iter = mem_priority_.find(first_event->key);
|
||||||
|
if (iter != mem_priority_.end()) {
|
||||||
|
priority = iter->second;
|
||||||
|
}
|
||||||
|
size_t pre_index = first_event->index;
|
||||||
|
for (size_t i = 1; i < mem_events.size(); ++i) {
|
||||||
|
auto &event = mem_events[i];
|
||||||
|
MS_EXCEPTION_IF_NULL(event);
|
||||||
|
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
|
||||||
|
no_swap_events_.find(event) == no_swap_events_.end()) {
|
||||||
|
auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
|
||||||
|
swap_out_event->key = item.first;
|
||||||
|
swap_out_event->mem_size = first_event->mem_size;
|
||||||
|
post_compute_events_[pre_index].emplace_back(swap_out_event);
|
||||||
|
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
|
||||||
|
swap_in_event->key = item.first;
|
||||||
|
swap_in_event->mem_size = first_event->mem_size;
|
||||||
|
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
|
||||||
|
}
|
||||||
|
if (event->index < pre_compute_events_.size()) {
|
||||||
|
(void)pre_compute_events_[event->index].emplace_back(event);
|
||||||
|
}
|
||||||
|
pre_index = event->index;
|
||||||
|
}
|
||||||
|
if (priority != kMemPriorityLow) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto &last_event = mem_events[mem_events.size() - 1];
|
||||||
|
MS_EXCEPTION_IF_NULL(last_event);
|
||||||
|
auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
|
||||||
|
free_event->key = item.first;
|
||||||
|
if (last_event->index < post_compute_events_.size()) {
|
||||||
|
(void)post_compute_events_[last_event->index].emplace_back(free_event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,85 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
|
||||||
|
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
// Priority of a piece of memory. The offload strategy only generates swap and
// free events for kMemPriorityLow memory.
enum MemPriority {
  kMemPriorityLow,
  kMemPriorityHigh
};

// Kind of a memory event. kInit and kMalloc are the event types accepted as the
// first event of a tensor; kSwapOut, kSwapIn and kFree are generated by the
// offload strategy itself.
enum MemEventType {
  kInit,
  kMalloc,
  kGet,
  kFree,
  kSwapIn,
  kSwapOut
};

// A single memory event of one tensor at one compute step.
struct MemEvent {
  MemEvent(const MemEventType &in_type, size_t in_index) : type{in_type}, index{in_index} {}

  MemEventType type;         // what happens to the memory
  size_t index{0};           // compute step the event belongs to
  size_t mem_size{0};        // size of the memory concerned
  const void *key{nullptr};  // identifies the tensor/memory the event refers to
};
|
||||||
|
|
||||||
|
class MemOffloadStrategy {
|
||||||
|
public:
|
||||||
|
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
|
||||||
|
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
|
||||||
|
size_t total_step)
|
||||||
|
: mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {}
|
||||||
|
|
||||||
|
virtual ~MemOffloadStrategy() = default;
|
||||||
|
|
||||||
|
virtual void Execute();
|
||||||
|
|
||||||
|
void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; }
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<MemEvent>> &GetPreComputeEvents(size_t step);
|
||||||
|
|
||||||
|
std::vector<std::shared_ptr<MemEvent>> &GetPostComputeEvents(size_t step);
|
||||||
|
|
||||||
|
void set_mem_size(size_t mem_size) { mem_size_ = mem_size; }
|
||||||
|
|
||||||
|
bool need_swap() const { return need_swap_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void CountMemUsage();
|
||||||
|
void CheckMemSize();
|
||||||
|
void GenEventSpan();
|
||||||
|
void GenNoSwapEventSet();
|
||||||
|
void GenComputeMemEvents();
|
||||||
|
|
||||||
|
const std::map<const void *, MemPriority> &mem_priority_;
|
||||||
|
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
|
||||||
|
const size_t total_step_;
|
||||||
|
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
|
||||||
|
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
|
||||||
|
|
||||||
|
size_t mem_size_{0};
|
||||||
|
std::vector<double> compute_time_;
|
||||||
|
bool need_swap_{false};
|
||||||
|
std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_;
|
||||||
|
std::set<std::shared_ptr<MemEvent>> no_swap_events_;
|
||||||
|
std::vector<size_t> min_mem_used_;
|
||||||
|
size_t mem_used_without_swap_{0};
|
||||||
|
size_t min_mem_needed_{0};
|
||||||
|
};
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
|
|
@ -17,9 +17,30 @@
|
||||||
#include "runtime/device/memory_scheduler.h"
|
#include "runtime/device/memory_scheduler.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#include <time.h>
|
||||||
|
#else
|
||||||
|
#include <sys/time.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
|
namespace {
|
||||||
|
constexpr float kMaxMemReuseFactor = 0.9;
|
||||||
|
constexpr float kMinMemReuseFactor = 0.5;
|
||||||
|
constexpr float kRetryFactor = 0.1;
|
||||||
|
|
||||||
|
// Returns the current time in microseconds as a double.
// With _MSC_VER defined the value is derived from time() and therefore only
// changes in whole seconds; otherwise it is taken from gettimeofday().
double GetCurrentTime() {
#ifdef _MSC_VER
  return time(nullptr) * 1.0e6;
#else
  struct timeval now;
  (void)gettimeofday(&now, nullptr);
  const double whole_seconds_us = now.tv_sec * 1.0e6;
  return whole_seconds_us + now.tv_usec;
#endif
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void MemScheduler::Clear() {
|
void MemScheduler::Clear() {
|
||||||
if (mem_handler_ == nullptr) {
|
if (mem_handler_ == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
@ -40,23 +61,26 @@ bool MemScheduler::IsHighPriorityMem(const void *key) {
|
||||||
|
|
||||||
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
|
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
|
||||||
|
|
||||||
void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
|
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
|
||||||
if (key == nullptr) {
|
if (key == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto event = std::make_shared<Event>(event_type, compute_index_);
|
auto event = std::make_shared<MemEvent>(event_type, current_step_);
|
||||||
event->mem_size = mem_size;
|
event->mem_size = mem_size;
|
||||||
event->key = key;
|
event->key = key;
|
||||||
mem_events_[key].emplace_back(event);
|
mem_events_[key].emplace_back(event);
|
||||||
|
if (step_events_.size() < current_step_ + 1) {
|
||||||
|
step_events_.resize(current_step_ + 1);
|
||||||
|
}
|
||||||
|
step_events_[current_step_].emplace_back(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
|
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
|
||||||
if (need_record_event_) {
|
if (need_record_event_) {
|
||||||
mem_priority_[key] = priority;
|
mem_priority_[key] = priority;
|
||||||
Record(key, kInit, mem_size);
|
Record(key, kInit, mem_size);
|
||||||
} else {
|
|
||||||
init_host_ptr_[key] = host_ptr;
|
|
||||||
}
|
}
|
||||||
|
init_host_ptr_[key] = host_ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
|
void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
|
||||||
|
@ -69,7 +93,7 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (!has_compute_mem_events_) {
|
if (strategy_ == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto iter = mem_result_.find(key);
|
auto iter = mem_result_.find(key);
|
||||||
|
@ -83,18 +107,14 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MemScheduler::PreCompute(void *stream) {
|
bool MemScheduler::PreCompute(void *stream) {
|
||||||
if (!has_compute_mem_events_) {
|
if (strategy_ == nullptr) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
MS_EXCEPTION_IF_NULL(mem_handler_);
|
||||||
if (pre_compute_events_.size() <= compute_index_) {
|
auto &events = strategy_->GetPreComputeEvents(current_step_);
|
||||||
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
|
|
||||||
<< ", event size:" << pre_compute_events_.size();
|
|
||||||
}
|
|
||||||
auto &events = pre_compute_events_[compute_index_];
|
|
||||||
for (auto &event : events) {
|
for (auto &event : events) {
|
||||||
MS_EXCEPTION_IF_NULL(event);
|
MS_EXCEPTION_IF_NULL(event);
|
||||||
MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
|
MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
|
||||||
if (event->type == kInit || event->type == kMalloc) {
|
if (event->type == kInit || event->type == kMalloc) {
|
||||||
auto priority = mem_priority_[event->key];
|
auto priority = mem_priority_[event->key];
|
||||||
auto iter = high_priority_device_ptr_.find(event->key);
|
auto iter = high_priority_device_ptr_.find(event->key);
|
||||||
|
@ -137,23 +157,27 @@ bool MemScheduler::PreCompute(void *stream) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (record_compute_time_ && !updated_) {
|
||||||
|
compute_start_time_ = GetCurrentTime();
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool MemScheduler::PostCompute(void *stream) {
|
bool MemScheduler::PostCompute(void *stream) {
|
||||||
if (!has_compute_mem_events_) {
|
if (strategy_ == nullptr) {
|
||||||
++compute_index_;
|
++current_step_;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
|
||||||
if (post_compute_events_.size() <= compute_index_) {
|
if (record_compute_time_ && !updated_) {
|
||||||
MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
|
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
|
||||||
<< ", event size:" << post_compute_events_.size();
|
|
||||||
}
|
}
|
||||||
auto &events = post_compute_events_[compute_index_];
|
|
||||||
|
auto &events = strategy_->GetPostComputeEvents(current_step_);
|
||||||
for (auto &event : events) {
|
for (auto &event : events) {
|
||||||
MS_EXCEPTION_IF_NULL(event);
|
MS_EXCEPTION_IF_NULL(event);
|
||||||
MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
|
MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
|
||||||
if (event->type == kFree) {
|
if (event->type == kFree) {
|
||||||
auto ptr = mem_result_[event->key];
|
auto ptr = mem_result_[event->key];
|
||||||
if (ptr == nullptr) {
|
if (ptr == nullptr) {
|
||||||
|
@ -174,201 +198,81 @@ bool MemScheduler::PostCompute(void *stream) {
|
||||||
MS_EXCEPTION_IF_NULL(host_ptr);
|
MS_EXCEPTION_IF_NULL(host_ptr);
|
||||||
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
|
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
|
||||||
mem_handler_->FreeDevice(device_ptr);
|
mem_handler_->FreeDevice(device_ptr);
|
||||||
(void)mem_result_.erase(device_ptr);
|
(void)mem_result_.erase(event->key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
++compute_index_;
|
++current_step_;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemScheduler::OptMemUsage() {
|
void MemScheduler::OptMemUsage(float mem_used_factor) {
|
||||||
if (optimized_) {
|
mem_used_factor_ = mem_used_factor;
|
||||||
return;
|
|
||||||
}
|
|
||||||
CountMemUsage();
|
|
||||||
CheckMemSize();
|
|
||||||
if (need_swap_) {
|
|
||||||
GenEventSpan();
|
|
||||||
GenNoSwapEventSet();
|
|
||||||
}
|
|
||||||
GenComputeMemEvents();
|
|
||||||
has_compute_mem_events_ = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void MemScheduler::CountMemUsage() {
|
|
||||||
if (!min_mem_used_.empty()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (mem_events_.empty() || compute_index_ == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
min_mem_used_.resize(compute_index_, 0);
|
|
||||||
std::vector<size_t> total_mem_used(compute_index_, 0);
|
|
||||||
for (auto &item : mem_events_) {
|
|
||||||
auto &mem_events = item.second;
|
|
||||||
if (mem_events.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto first_event = mem_events[0];
|
|
||||||
size_t cur_index = 0;
|
|
||||||
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
|
|
||||||
first_event = mem_events[1];
|
|
||||||
cur_index = 1;
|
|
||||||
}
|
|
||||||
auto last_event = mem_events[mem_events.size() - 1];
|
|
||||||
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
|
|
||||||
if (start_index < compute_index_) {
|
|
||||||
total_mem_used[start_index] += first_event->mem_size;
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Error mem event index " << start_index;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (; cur_index < mem_events.size(); ++cur_index) {
|
|
||||||
auto &event = mem_events[cur_index];
|
|
||||||
MS_EXCEPTION_IF_NULL(event);
|
|
||||||
if (event->index < compute_index_) {
|
|
||||||
min_mem_used_[event->index] += first_event->mem_size;
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "Error mem event index " << event->index;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
|
|
||||||
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
|
|
||||||
}
|
|
||||||
|
|
||||||
void MemScheduler::CheckMemSize() {
|
|
||||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
MS_EXCEPTION_IF_NULL(mem_handler_);
|
||||||
|
|
||||||
|
if (strategy_ == nullptr) {
|
||||||
|
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
|
||||||
|
compute_time_.resize(total_step_);
|
||||||
|
}
|
||||||
|
|
||||||
auto available_mem_size = mem_handler_->GetAvailableMemSize();
|
auto available_mem_size = mem_handler_->GetAvailableMemSize();
|
||||||
if (available_mem_size < min_mem_needed_) {
|
available_mem_size = available_mem_size * mem_used_factor_;
|
||||||
MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
|
strategy_->set_mem_size(available_mem_size);
|
||||||
<< " while graph needs at least " << min_mem_needed_;
|
strategy_->Execute();
|
||||||
}
|
|
||||||
if (mem_used_without_swap_ > available_mem_size) {
|
|
||||||
need_swap_ = true;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size: " << mem_used_without_swap_
|
|
||||||
<< " without swap, and needs at least " << min_mem_needed_ << " with swap.";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemScheduler::GenEventSpan() {
|
void MemScheduler::Optimize() {
|
||||||
if (!event_span_.empty()) {
|
float mem_used_factor = kMaxMemReuseFactor;
|
||||||
|
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
|
||||||
|
OptMemUsage(mem_used_factor);
|
||||||
|
current_step_ = 0;
|
||||||
|
bool ret = true;
|
||||||
|
for (size_t step = 0; step < total_step_; ++step) {
|
||||||
|
ret = PreCompute(nullptr);
|
||||||
|
auto &step_events = step_events_[step];
|
||||||
|
for (auto &event : step_events) {
|
||||||
|
if (event->type != kGet) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto ptr = GetOrMalloc(event->key, event->mem_size);
|
||||||
|
if (ptr == nullptr) {
|
||||||
|
ret = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!ret) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
PostCompute(nullptr);
|
||||||
|
}
|
||||||
|
if (ret) {
|
||||||
|
optimized_ = true;
|
||||||
|
} else {
|
||||||
|
mem_used_factor -= kRetryFactor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void MemScheduler::Update() {
|
||||||
|
if (!optimized_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (auto &item : mem_events_) {
|
|
||||||
auto &tensor_events = item.second;
|
|
||||||
if (tensor_events.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto first_event = tensor_events[0];
|
|
||||||
size_t cur_index = 0;
|
|
||||||
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
|
|
||||||
first_event = tensor_events[1];
|
|
||||||
cur_index = 1;
|
|
||||||
}
|
|
||||||
size_t last_index = first_event->index;
|
|
||||||
for (; cur_index < tensor_events.size(); ++cur_index) {
|
|
||||||
auto &event = tensor_events[cur_index];
|
|
||||||
MS_EXCEPTION_IF_NULL(event);
|
|
||||||
auto span = event->index - last_index;
|
|
||||||
if (span > 1) {
|
|
||||||
(void)event_span_.emplace(std::pair<size_t, std::shared_ptr<Event>>(span, event));
|
|
||||||
}
|
|
||||||
last_index = event->index;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void MemScheduler::GenNoSwapEventSet() {
|
if (strategy_ == nullptr || !strategy_->need_swap()) {
|
||||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
return;
|
||||||
auto available_mem_size = mem_handler_->GetAvailableMemSize();
|
|
||||||
auto threshold = available_mem_size * mem_used_factor_;
|
|
||||||
no_swap_events_.clear();
|
|
||||||
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
|
|
||||||
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
|
|
||||||
auto span = iter->first;
|
|
||||||
auto &event = iter->second;
|
|
||||||
auto start_index = event->index - span + 1;
|
|
||||||
bool revert = false;
|
|
||||||
for (size_t i = start_index; i < event->index; ++i) {
|
|
||||||
cur_mem_used[i] += event->mem_size;
|
|
||||||
if (cur_mem_used[i] > threshold) {
|
|
||||||
revert = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (revert) {
|
|
||||||
for (size_t i = start_index; i < event->index; ++i) {
|
|
||||||
cur_mem_used[i] -= event->mem_size;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
(void)no_swap_events_.emplace(event);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
void MemScheduler::GenComputeMemEvents() {
|
if (updated_) {
|
||||||
pre_compute_events_.clear();
|
return;
|
||||||
post_compute_events_.clear();
|
|
||||||
pre_compute_events_.resize(compute_index_);
|
|
||||||
post_compute_events_.resize(compute_index_);
|
|
||||||
for (auto &item : mem_events_) {
|
|
||||||
auto &mem_events = item.second;
|
|
||||||
if (mem_events.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto first_event = mem_events[0];
|
|
||||||
MS_EXCEPTION_IF_NULL(first_event);
|
|
||||||
if (first_event->type == kInit) {
|
|
||||||
if (mem_events.size() > 1) {
|
|
||||||
auto &second_event = mem_events[1];
|
|
||||||
MS_EXCEPTION_IF_NULL(second_event);
|
|
||||||
first_event->index = second_event->index;
|
|
||||||
} else {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ((first_event->type == kInit || first_event->type == kMalloc) &&
|
|
||||||
first_event->index < pre_compute_events_.size()) {
|
|
||||||
pre_compute_events_[first_event->index].emplace_back(first_event);
|
|
||||||
} else {
|
|
||||||
MS_LOG_EXCEPTION << "First event should be init or malloc!";
|
|
||||||
}
|
|
||||||
MemPriority priority = kMemPriorityLow;
|
|
||||||
auto iter = mem_priority_.find(first_event->key);
|
|
||||||
if (iter != mem_priority_.end()) {
|
|
||||||
priority = iter->second;
|
|
||||||
}
|
|
||||||
size_t pre_index = first_event->index;
|
|
||||||
for (size_t i = 1; i < mem_events.size(); ++i) {
|
|
||||||
auto &event = mem_events[i];
|
|
||||||
MS_EXCEPTION_IF_NULL(event);
|
|
||||||
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
|
|
||||||
no_swap_events_.find(event) == no_swap_events_.end()) {
|
|
||||||
auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
|
|
||||||
swap_out_event->key = item.first;
|
|
||||||
swap_out_event->mem_size = first_event->mem_size;
|
|
||||||
post_compute_events_[pre_index].emplace_back(swap_out_event);
|
|
||||||
auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
|
|
||||||
swap_in_event->key = item.first;
|
|
||||||
swap_in_event->mem_size = first_event->mem_size;
|
|
||||||
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
|
|
||||||
}
|
|
||||||
if (event->index < pre_compute_events_.size()) {
|
|
||||||
(void)pre_compute_events_[event->index].emplace_back(event);
|
|
||||||
}
|
|
||||||
pre_index = event->index;
|
|
||||||
}
|
|
||||||
if (priority != kMemPriorityLow) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
auto &last_event = mem_events[mem_events.size() - 1];
|
|
||||||
MS_EXCEPTION_IF_NULL(last_event);
|
|
||||||
auto free_event = std::make_shared<Event>(kFree, last_event->index);
|
|
||||||
free_event->key = item.first;
|
|
||||||
if (last_event->index < post_compute_events_.size()) {
|
|
||||||
(void)post_compute_events_[last_event->index].emplace_back(free_event);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!record_compute_time_) {
|
||||||
|
record_compute_time_ = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
strategy_->SetComputeTime(compute_time_);
|
||||||
|
strategy_->Execute();
|
||||||
|
updated_ = true;
|
||||||
}
|
}
|
||||||
} // namespace device
|
} // namespace device
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include "runtime/device/memory_offload_strategy.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
|
@ -35,23 +36,7 @@ class MemHandler {
|
||||||
virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
|
virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
|
|
||||||
|
|
||||||
class MemScheduler {
|
class MemScheduler {
|
||||||
enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
|
|
||||||
|
|
||||||
struct Event {
|
|
||||||
Event(const EventType &in_type, size_t in_index) {
|
|
||||||
type = in_type;
|
|
||||||
index = in_index;
|
|
||||||
}
|
|
||||||
|
|
||||||
EventType type;
|
|
||||||
size_t index{0};
|
|
||||||
size_t mem_size{0};
|
|
||||||
const void *key{nullptr};
|
|
||||||
};
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MemScheduler() = default;
|
MemScheduler() = default;
|
||||||
~MemScheduler() = default;
|
~MemScheduler() = default;
|
||||||
|
@ -62,7 +47,7 @@ class MemScheduler {
|
||||||
|
|
||||||
bool optimized() const { return optimized_; }
|
bool optimized() const { return optimized_; }
|
||||||
|
|
||||||
void set_optimized(bool flag) { optimized_ = flag; }
|
void Update();
|
||||||
|
|
||||||
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
|
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
|
||||||
|
|
||||||
|
@ -70,13 +55,18 @@ class MemScheduler {
|
||||||
|
|
||||||
void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
|
void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
|
||||||
|
|
||||||
void Reset() { compute_index_ = 0; }
|
void SetTotalStep(size_t step) {
|
||||||
|
total_step_ = step;
|
||||||
|
step_events_.resize(total_step_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ResetCurrentStep() { current_step_ = 0; }
|
||||||
|
|
||||||
bool PreCompute(void *stream);
|
bool PreCompute(void *stream);
|
||||||
|
|
||||||
bool PostCompute(void *stream);
|
bool PostCompute(void *stream);
|
||||||
|
|
||||||
void OptMemUsage();
|
void Optimize();
|
||||||
|
|
||||||
void Clear();
|
void Clear();
|
||||||
|
|
||||||
|
@ -84,37 +74,29 @@ class MemScheduler {
|
||||||
|
|
||||||
void SetMemPriority(const void *key, MemPriority priority);
|
void SetMemPriority(const void *key, MemPriority priority);
|
||||||
|
|
||||||
void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }
|
|
||||||
|
|
||||||
void SetNeedSwap(bool flag) { need_swap_ = flag; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
|
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
|
||||||
void GenComputeMemEvents();
|
|
||||||
void CheckMemSize();
|
void OptMemUsage(float mem_used_factor = 1.0f);
|
||||||
void CountMemUsage();
|
|
||||||
void GenEventSpan();
|
|
||||||
void GenNoSwapEventSet();
|
|
||||||
std::map<const void *, MemPriority> mem_priority_;
|
std::map<const void *, MemPriority> mem_priority_;
|
||||||
std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
|
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
|
||||||
std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
|
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
|
||||||
std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
|
|
||||||
std::map<const void *, void *> mem_result_;
|
std::map<const void *, void *> mem_result_;
|
||||||
std::map<const void *, void *> init_host_ptr_;
|
std::map<const void *, void *> init_host_ptr_;
|
||||||
std::map<const void *, void *> swap_host_ptr_;
|
std::map<const void *, void *> swap_host_ptr_;
|
||||||
std::map<const void *, void *> high_priority_device_ptr_;
|
std::map<const void *, void *> high_priority_device_ptr_;
|
||||||
size_t compute_index_{0};
|
size_t total_step_{0};
|
||||||
|
size_t current_step_{0};
|
||||||
bool need_record_event_{true};
|
bool need_record_event_{true};
|
||||||
bool optimized_{false};
|
bool optimized_{false};
|
||||||
bool has_compute_mem_events_{false};
|
|
||||||
std::shared_ptr<MemHandler> mem_handler_{nullptr};
|
|
||||||
bool need_swap_{false};
|
|
||||||
std::multimap<size_t, std::shared_ptr<Event>> event_span_;
|
|
||||||
std::set<std::shared_ptr<Event>> no_swap_events_;
|
|
||||||
std::vector<size_t> min_mem_used_;
|
|
||||||
size_t mem_used_without_swap_{0};
|
|
||||||
size_t min_mem_needed_{0};
|
|
||||||
float mem_used_factor_{0.9};
|
float mem_used_factor_{0.9};
|
||||||
|
double compute_start_time_{0};
|
||||||
|
std::vector<double> compute_time_;
|
||||||
|
bool record_compute_time_{false};
|
||||||
|
bool updated_{false};
|
||||||
|
std::shared_ptr<MemHandler> mem_handler_{nullptr};
|
||||||
|
std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
|
||||||
};
|
};
|
||||||
|
|
||||||
class MemSchedulerManager {
|
class MemSchedulerManager {
|
||||||
|
|
|
@ -114,6 +114,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
"../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc"
|
"../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc"
|
||||||
"../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
|
"../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
|
||||||
"../../../mindspore/ccsrc/runtime/device/memory_scheduler.cc"
|
"../../../mindspore/ccsrc/runtime/device/memory_scheduler.cc"
|
||||||
|
"../../../mindspore/ccsrc/runtime/device/memory_offload_strategy.cc"
|
||||||
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
|
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
|
||||||
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
|
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
|
||||||
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
|
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
|
||||||
|
|
|
@ -104,6 +104,7 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
|
||||||
std::vector<size_t> init_tensors = {0, 2, 4};
|
std::vector<size_t> init_tensors = {0, 2, 4};
|
||||||
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
|
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
|
||||||
void *stream = nullptr;
|
void *stream = nullptr;
|
||||||
|
scheduler->SetTotalStep(kTimeSlice);
|
||||||
// record
|
// record
|
||||||
for (auto index : init_tensors) {
|
for (auto index : init_tensors) {
|
||||||
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
||||||
|
@ -118,11 +119,10 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
|
||||||
scheduler->set_need_record_event(false);
|
scheduler->set_need_record_event(false);
|
||||||
|
|
||||||
// optimize
|
// optimize
|
||||||
scheduler->OptMemUsage();
|
scheduler->Optimize();
|
||||||
scheduler->set_optimized(true);
|
|
||||||
|
|
||||||
// run
|
// run
|
||||||
scheduler->Reset();
|
scheduler->ResetCurrentStep();
|
||||||
for (auto index : init_tensors) {
|
for (auto index : init_tensors) {
|
||||||
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue