Adjust swap strategy

tanghuikang 2021-11-22 16:20:41 +08:00
parent d996ad5e1e
commit 16ca537505
6 changed files with 191 additions and 110 deletions

View File

@@ -1294,9 +1294,6 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_
     kernel_addr->addr = device_address->ptr_;
   } else {
     kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_);
-    if (mem_scheduler->IsHighPriorityMem(device_address)) {
-      device_address->ptr_ = kernel_addr->addr;
-    }
   }
 }
@@ -1343,37 +1340,29 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
 }
 void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
-                                          const session::KernelGraph &graph, const AnfNodePtr &kernel, bool mock) {
+                                          const session::KernelGraph &graph, const AnfNodePtr &kernel) {
   MS_EXCEPTION_IF_NULL(mem_scheduler);
   MS_EXCEPTION_IF_NULL(kernel);
   auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   for (size_t input_idx = 0; input_idx < kernel_mod->GetInputSizeList().size(); ++input_idx) {
     const auto input_node_index = AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true);
-    if (input_node_index.first == nullptr || !input_node_index.first->isa<Parameter>()) {
-      continue;
+    if (input_node_index.first != nullptr && input_node_index.first->isa<Parameter>()) {
+      SyncNodeOutputTensor(mem_scheduler, input_node_index, graph);
     }
-    SyncNodeOutputTensor(mem_scheduler, input_node_index, graph, mock);
   }
   for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) {
-    SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph, mock);
+    SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph);
   }
 }
 void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler,
-                                         const KernelWithIndex &node_output_index, const session::KernelGraph &graph,
-                                         bool mock) {
+                                         const KernelWithIndex &node_output_index, const session::KernelGraph &graph) {
   MS_EXCEPTION_IF_NULL(mem_scheduler);
   if (node_output_index.first == nullptr) {
     return;
   }
   auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true);
-  if (mock) {
-    if (graph.IsInternalOutput(node_output_index.first, node_output_index.second) && device_address != nullptr) {
-      mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
-    }
-    return;
-  }
   auto tensor = graph.GetNodeOutputTensor(node_output_index);
   if (tensor == nullptr) {
     return;
@@ -1407,22 +1396,20 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
     MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
   }
   for (size_t i = 0; i < input_tensors.size(); ++i) {
-    auto tensor = input_tensors[i];
-    MS_EXCEPTION_IF_NULL(tensor);
     auto input_node = input_nodes[i];
     if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
       continue;
     }
     auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
+    auto tensor = input_tensors[i];
     MS_EXCEPTION_IF_NULL(tensor);
-    MemPriority priority = kMemPriorityLow;
     auto tensor_address = tensor->device_address();
     if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) {
       tensor->data_sync(false);
     }
-    if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) ||
+    MemPriority priority = kMemPriorityLow;
+    if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) &&
         graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) {
-      tensor->set_device_address(device_address);
       priority = kMemPriorityHigh;
     }
     auto tensor_size = LongToSize(tensor->data().nbytes());
@@ -1477,7 +1464,9 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
     }
   }
   if (mem_scheduler != nullptr) {
-    SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock);
+    if (!mock) {
+      SyncNodeOutputTensors(mem_scheduler, graph, kernel);
+    }
     ret = mem_scheduler->PostCompute(stream);
     if (!ret) {
       return ret;
@@ -1553,9 +1542,43 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
     }
     LaunchKernelEvent(kernel_post_run_events, kernels[i]);
   }
+  if (UseMemScheduler() && !mock) {
+    SyncUpdatedParameter(graph, mem_scheduler);
+  }
   return true;
 }
+
+void KernelRuntime::SyncUpdatedParameter(const session::KernelGraph &graph,
+                                         const std::shared_ptr<MemScheduler> &mem_scheduler) {
+  MS_EXCEPTION_IF_NULL(mem_scheduler);
+  auto &input_nodes = graph.input_nodes();
+  auto &input_tensors = graph.input_tensors();
+  if (input_tensors.size() != input_nodes.size()) {
+    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
+  }
+  for (size_t i = 0; i < input_tensors.size(); ++i) {
+    auto input_node = input_nodes[i];
+    if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
+      continue;
+    }
+    auto parameter = input_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(parameter);
+    if (!graph.IsUpdatedParameter(parameter)) {
+      continue;
+    }
+    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
+    auto tensor = input_tensors[i];
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
+    if (device_ptr != nullptr) {
+      device_address->set_ptr(device_ptr);
+      tensor->set_device_address(device_address);
+      tensor->set_sync_status(kNeedSyncDeviceToHost);
+    }
+  }
+}
+
 void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
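Taken together, the KernelRuntime changes move tensor syncing off the mock path: a mock (planning) run no longer touches output tensors at all, and after the real run SyncUpdatedParameter re-attaches each updated weight's device pointer to its host tensor and flags it kNeedSyncDeviceToHost, so the host copy is refreshed lazily on the next read instead of eagerly after every kernel. A minimal sketch of that deferred-sync pattern (toy types invented for illustration; not the MindSpore classes):

#include <cstdio>
#include <memory>
#include <vector>

// Toy model of the deferred device-to-host sync that SyncUpdatedParameter
// sets up: the tensor only records *where* the fresh data lives; the copy
// happens on the first host-side read.
enum SyncStatus { kNoNeedSync, kNeedSyncDeviceToHost };

struct ToyDeviceAddress {
  std::vector<char> device_mem;  // stand-in for real device memory
};

struct ToyTensor {
  std::vector<char> host_data;
  std::shared_ptr<ToyDeviceAddress> device_address;
  SyncStatus sync_status = kNoNeedSync;

  // Copies device -> host only when the flag says the host copy is stale.
  const std::vector<char> &data_sync() {
    if (sync_status == kNeedSyncDeviceToHost && device_address != nullptr) {
      host_data = device_address->device_mem;
      sync_status = kNoNeedSync;
    }
    return host_data;
  }
};

int main() {
  auto addr = std::make_shared<ToyDeviceAddress>();
  addr->device_mem = {1, 2, 3, 4};  // freshly written by a kernel

  ToyTensor weight;
  weight.host_data = {0, 0, 0, 0};  // stale host copy
  // What SyncUpdatedParameter arranges after a real (non-mock) run:
  weight.device_address = addr;
  weight.sync_status = kNeedSyncDeviceToHost;

  // The copy is paid only here, on the first host-side read.
  std::printf("first element after lazy sync: %d\n", weight.data_sync()[0]);
  return 0;
}

The real code reaches the same state through DeviceAddress::set_ptr and Tensor::set_sync_status; the point is that the copy cost is deferred until the value is actually needed on the host.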

View File

@@ -95,6 +95,7 @@ class KernelRuntime {
   void set_device_id(uint32_t device_id) { device_id_ = device_id; }
   uint32_t device_id() { return device_id_; }
   static bool UseMemScheduler();
+  void SyncUpdatedParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler);
 #ifdef ENABLE_DEBUGGER
   // set debugger
@@ -155,9 +156,9 @@ class KernelRuntime {
                           const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
   void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
   void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
-                             const AnfNodePtr &kernel, bool mock);
+                             const AnfNodePtr &kernel);
   void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
-                            const session::KernelGraph &graph, bool mock);
+                            const session::KernelGraph &graph);
   void AssignCommunicationMem(const session::KernelGraph &graph);
   bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);

View File

@@ -43,7 +43,7 @@ void MemOffloadStrategy::Execute() {
   CheckMemSize();
   if (need_swap_) {
     GenEventSpan();
-    GenNoSwapEventSet();
+    GenSwapEventSet();
   }
   GenComputeMemEvents();
 }
@@ -57,37 +57,41 @@ void MemOffloadStrategy::CountMemUsage() {
   }
   min_mem_used_.resize(total_step_, 0);
   std::vector<size_t> total_mem_used(total_step_, 0);
+  size_t high_priority_mem_size = 0;
   for (auto &item : mem_events_) {
     auto &mem_events = item.second;
     if (mem_events.empty()) {
       continue;
     }
     auto first_event = mem_events[0];
-    size_t cur_index = 0;
-    if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
-      first_event = mem_events[1];
-      cur_index = 1;
-    }
-    auto last_event = mem_events[mem_events.size() - 1];
-    for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
-      if (start_index < total_step_) {
+    const bool is_high_priority = IsHighPriorityMem(first_event->key);
+    if (is_high_priority) {
+      high_priority_mem_size += first_event->mem_size;
+    } else {
+      auto last_event = mem_events[mem_events.size() - 1];
+      for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
         total_mem_used[start_index] += first_event->mem_size;
-      } else {
-        MS_LOG(ERROR) << "Error mem event index " << start_index;
       }
     }
-    for (; cur_index < mem_events.size(); ++cur_index) {
-      auto &event = mem_events[cur_index];
+    // Calculate the minimum memory size for kernel execution.
+    for (const auto &event : mem_events) {
       MS_EXCEPTION_IF_NULL(event);
-      if (event->index < total_step_) {
-        min_mem_used_[event->index] += first_event->mem_size;
-      } else {
-        MS_LOG(ERROR) << "Error mem event index " << event->index;
+      if (event->type != kGet) {
+        continue;
       }
+      min_mem_used_[event->index] += first_event->mem_size;
     }
   }
   min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
-  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
+  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
+}
+
+bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
+  auto iter = mem_priority_.find(key);
+  if (iter != mem_priority_.end()) {
+    return iter->second == kMemPriorityHigh;
+  }
+  return false;
 }
 void MemOffloadStrategy::CheckMemSize() {
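CountMemUsage now separates two bounds: mem_used_without_swap_ is the peak usage if nothing is offloaded, with high-priority (resident) buffers added once on top of the per-step maximum, and min_mem_needed_ is the floor obtained when a buffer only has to be on the device at the steps where a kernel actually reads it (its kGet events). A toy calculation, with buffer sizes and the step layout invented for illustration:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Toy version of the two bounds. Buffer A (low priority, 100 bytes) is
// alive at steps 0..3 but only read (kGet) at steps 0 and 3; buffer B
// (high priority, 50 bytes) stays resident and is read at steps 1 and 2.
int main() {
  const size_t total_step = 4;
  std::vector<size_t> total_mem_used(total_step, 0);
  std::vector<size_t> min_mem_used(total_step, 0);

  for (size_t step = 0; step < total_step; ++step) {
    total_mem_used[step] += 100;  // A charged over its whole live range
  }
  const size_t high_priority_mem_size = 50;  // B counted once, on top

  min_mem_used[0] += 100;  // A's kGet steps
  min_mem_used[3] += 100;
  min_mem_used[1] += 50;  // B's kGet steps
  min_mem_used[2] += 50;

  const size_t without_swap =
      *std::max_element(total_mem_used.begin(), total_mem_used.end()) + high_priority_mem_size;
  const size_t minimum = *std::max_element(min_mem_used.begin(), min_mem_used.end());
  std::printf("without swap: %zu bytes, minimum with swap: %zu bytes\n", without_swap, minimum);
  return 0;  // prints 150 and 100
}

Any device budget between the two numbers is feasible; closing that gap is exactly what the swap-event selection below works with.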
@@ -110,48 +114,60 @@ void MemOffloadStrategy::GenEventSpan() {
   }
   for (auto &item : mem_events_) {
     auto &tensor_events = item.second;
-    if (tensor_events.empty()) {
+    if (tensor_events.size() <= 1) {
       continue;
     }
-    auto first_event = tensor_events[0];
-    size_t cur_index = 0;
-    if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
-      first_event = tensor_events[1];
-      cur_index = 1;
-    }
-    size_t last_index = first_event->index;
-    for (; cur_index < tensor_events.size(); ++cur_index) {
-      auto &event = tensor_events[cur_index];
+    const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key);
+    for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) {
+      auto &event = tensor_events[event_index];
       MS_EXCEPTION_IF_NULL(event);
-      auto span = event->index - last_index;
-      if (span > 1) {
-        (void)event_span_.emplace(span, event);
+      if (event->type != kGet) {
+        MS_LOG(EXCEPTION) << "Event should be Get except first event.";
+      }
+      size_t span = 0;
+      if (event_index == 1 && is_high_priority) {
+        const auto &last_event = tensor_events[tensor_events.size() - 1];
+        span = event->index + total_step_ - last_event->index;
+      } else {
+        span = event->index - tensor_events[event_index - 1]->index;
+      }
+      if (span > 1) {
+        const size_t span_mul_size = (span - 1) * event->mem_size;
+        (void)event_span_.emplace(std::make_pair(span_mul_size, std::make_pair(event, span)));
       }
-      last_index = event->index;
     }
   }
 }
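The subtle case in the new GenEventSpan is the first kGet of a high-priority buffer: since such memory stays resident across iterations, its idle gap runs from its last use in one iteration, across the step-ring boundary, to its first use in the next, hence span = event->index + total_step_ - last_event->index. Keying event_span_ by (span - 1) * mem_size also changes the ordering from pure distance to reclaimable memory-time. A worked example of the wrap-around span, with invented numbers:

#include <cstddef>
#include <cstdio>

// Wrap-around span for a high-priority (resident) buffer. With
// total_step = 10, a weight whose last use sits at step 8 and whose
// first kGet of the next iteration sits at step 2 is idle for the
// steps 9, 0 and 1 in between.
int main() {
  const size_t total_step = 10;
  const size_t first_get_index = 2;  // event->index
  const size_t last_use_index = 8;   // last_event->index
  const size_t span = first_get_index + total_step - last_use_index;  // 4
  const size_t mem_size = 1024;
  // Sort key used by the new event_span_ multimap: idle steps times size.
  const size_t span_mul_size = (span - 1) * mem_size;  // 3 * 1024
  std::printf("span=%zu, reclaimable memory-time=%zu\n", span, span_mul_size);
  return 0;
}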
-void MemOffloadStrategy::GenNoSwapEventSet() {
-  no_swap_events_.clear();
+void MemOffloadStrategy::GenSwapEventSet() {
+  swap_events_.clear();
   std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
-  for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
-    auto span = iter->first;
-    auto &event = iter->second;
-    auto start_index = event->index - span + 1;
+  for (const auto &iter : event_span_) {
+    auto span = iter.second.second;
+    auto &event = iter.second.first;
+    auto start_index = ((total_step_ + event->index - span) % total_step_) + 1;
     bool revert = false;
-    for (size_t i = start_index; i < event->index; ++i) {
-      cur_mem_used[i] += event->mem_size;
-      if (cur_mem_used[i] > mem_size_) {
+    size_t cur_index = start_index;
+    while (cur_index != event->index) {
+      cur_mem_used[cur_index] += event->mem_size;
+      if (cur_mem_used[cur_index] > mem_size_) {
         revert = true;
       }
+      cur_index += 1;
+      if (cur_index >= total_step_) {
+        cur_index = 0;
+      }
     }
     if (revert) {
-      for (size_t i = start_index; i < event->index; ++i) {
-        cur_mem_used[i] -= event->mem_size;
+      cur_index = start_index;
+      while (cur_index != event->index) {
+        cur_mem_used[cur_index] -= event->mem_size;
+        cur_index += 1;
+        if (cur_index >= total_step_) {
+          cur_index = 0;
+        }
       }
-    } else {
-      (void)no_swap_events_.emplace(event);
+      (void)swap_events_.emplace(event);
     }
   }
 }
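GenSwapEventSet inverts the old GenNoSwapEventSet: candidates are visited in ascending order of the memory-time they would pin, each is tentatively kept resident by charging its size to every step of its idle span (walking the step ring with wrap-around), and only if some step would exceed mem_size_ is the charge rolled back and the event added to swap_events_. A compact standalone sketch of that greedy budget check, with toy types and numbers in place of the MindSpore classes:

#include <cstddef>
#include <cstdio>
#include <set>
#include <vector>

// Greedy budget check in the spirit of GenSwapEventSet. Candidates arrive
// ordered by the memory-time they would pin; each is kept resident if
// charging its size to every step of its idle span stays within budget,
// otherwise the charge is reverted and the buffer is marked for swap.
struct ToyEvent {
  size_t index;     // step of the kGet ending the idle span
  size_t span;      // idle span length, may wrap around the step ring
  size_t mem_size;  // buffer size in bytes
};

int main() {
  const size_t total_step = 6;
  const size_t budget = 200;
  std::vector<size_t> cur_mem_used(total_step, 80);  // per-step base load
  const std::vector<ToyEvent> candidates = {{3, 3, 100}, {5, 4, 100}};
  std::set<const ToyEvent *> swap_events;

  for (const auto &event : candidates) {
    const size_t start = ((total_step + event.index - event.span) % total_step) + 1;
    bool revert = false;
    for (size_t i = start; i != event.index; i = (i + 1) % total_step) {
      cur_mem_used[i] += event.mem_size;
      if (cur_mem_used[i] > budget) {
        revert = true;
      }
    }
    if (revert) {  // roll back the charges and schedule a swap instead
      for (size_t i = start; i != event.index; i = (i + 1) % total_step) {
        cur_mem_used[i] -= event.mem_size;
      }
      (void)swap_events.insert(&event);
    }
  }
  std::printf("%zu of %zu buffers get swapped\n", swap_events.size(), candidates.size());
  return 0;  // here the second candidate exceeds the budget and is swapped
}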
@@ -166,34 +182,31 @@ void MemOffloadStrategy::GenComputeMemEvents() {
     if (mem_events.empty()) {
       continue;
     }
+    // No need to generate events for memory that has only one event, which means it is never used by any kernel.
+    if (mem_events.size() <= 1) {
+      continue;
+    }
+    const bool is_high_priority = IsHighPriorityMem(item.first);
     auto first_event = mem_events[0];
     MS_EXCEPTION_IF_NULL(first_event);
-    if (first_event->type == kInit) {
-      if (mem_events.size() > 1) {
-        auto &second_event = mem_events[1];
-        MS_EXCEPTION_IF_NULL(second_event);
-        first_event->index = second_event->index;
-      } else {
-        continue;
-      }
+    const auto &second_event = mem_events[1];
+    MS_EXCEPTION_IF_NULL(second_event);
+    if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) {
+      first_event->index = second_event->index;
     }
-    if ((first_event->type == kInit || first_event->type == kMalloc) &&
-        first_event->index < pre_compute_events_.size()) {
+    if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) {
       pre_compute_events_[first_event->index].emplace_back(first_event);
     } else {
       MS_LOG_EXCEPTION << "First event should be init or malloc!";
     }
-    MemPriority priority = kMemPriorityLow;
-    auto iter = mem_priority_.find(first_event->key);
-    if (iter != mem_priority_.end()) {
-      priority = iter->second;
-    }
-    size_t pre_index = first_event->index;
+    const auto &last_event = mem_events[mem_events.size() - 1];
+    size_t pre_index = is_high_priority ? last_event->index : first_event->index;
     for (size_t i = 1; i < mem_events.size(); ++i) {
       auto &event = mem_events[i];
       MS_EXCEPTION_IF_NULL(event);
-      if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
-          no_swap_events_.find(event) == no_swap_events_.end()) {
+      if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
         auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
         swap_out_event->key = item.first;
         swap_out_event->mem_size = first_event->mem_size;
@@ -208,17 +221,19 @@ void MemOffloadStrategy::GenComputeMemEvents() {
       }
       pre_index = event->index;
     }
-    if (priority != kMemPriorityLow) {
-      continue;
+    if (!is_high_priority) {
+      GenFreeEvent(last_event);
     }
-    auto &last_event = mem_events[mem_events.size() - 1];
-    MS_EXCEPTION_IF_NULL(last_event);
-    auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
-    free_event->key = item.first;
-    if (last_event->index < post_compute_events_.size()) {
-      (void)post_compute_events_[last_event->index].emplace_back(free_event);
-    }
   }
 }
+
+void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_event) {
+  MS_EXCEPTION_IF_NULL(last_event);
+  auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
+  free_event->key = last_event->key;
+  if (last_event->index < post_compute_events_.size()) {
+    (void)post_compute_events_[last_event->index].emplace_back(free_event);
+  }
+}
 }  // namespace device
 }  // namespace mindspore

View File

@@ -58,12 +58,15 @@ class MemOffloadStrategy {
   bool need_swap() const { return need_swap_; }
+
+  bool IsHighPriorityMem(const void *key);
  private:
   void CountMemUsage();
   void CheckMemSize();
   void GenEventSpan();
-  void GenNoSwapEventSet();
+  void GenSwapEventSet();
   void GenComputeMemEvents();
+  void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
   const std::map<const void *, MemPriority> &mem_priority_;
   const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
@@ -74,8 +77,8 @@ class MemOffloadStrategy {
   size_t mem_size_{0};
   std::vector<double> compute_time_;
   bool need_swap_{false};
-  std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_;
-  std::set<std::shared_ptr<MemEvent>> no_swap_events_;
+  std::multimap<size_t, std::pair<std::shared_ptr<MemEvent>, size_t>> event_span_;
+  std::set<std::shared_ptr<MemEvent>> swap_events_;
   std::vector<size_t> min_mem_used_;
   size_t mem_used_without_swap_{0};
   size_t min_mem_needed_{0};

View File

@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace device {
 namespace {
-constexpr float kMaxMemReuseFactor = 0.9;
+constexpr float kMaxMemReuseFactor = 1.0;
 constexpr float kMinMemReuseFactor = 0.5;
 constexpr float kRetryFactor = 0.1;
@@ -51,12 +51,25 @@ void MemScheduler::Clear() {
   high_priority_device_ptr_.clear();
 }
-bool MemScheduler::IsHighPriorityMem(const void *key) {
-  auto iter = mem_priority_.find(key);
-  if (iter != mem_priority_.end()) {
-    return iter->second == kMemPriorityHigh;
+void MemScheduler::ClearTempMem() {
+  if (mem_handler_ == nullptr) {
+    return;
   }
-  return false;
+  for (auto &item : mem_result_) {
+    const auto device_ptr = item.second;
+    if (device_ptr != nullptr) {
+      mem_handler_->FreeDevice(device_ptr);
+    }
+  }
+  mem_result_.clear();
+  high_priority_device_ptr_.clear();
+  for (const auto &item : swap_host_ptr_) {
+    const auto host_ptr = item.second;
+    if (host_ptr != nullptr) {
+      mem_handler_->FreeHost(host_ptr);
+    }
+  }
+  swap_host_ptr_.clear();
 }
 void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
@@ -88,9 +101,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
     if (mem_priority_.find(key) == mem_priority_.end()) {
       mem_priority_[key] = priority;
       Record(key, kMalloc, mem_size);
-    } else {
-      Record(key, kGet, mem_size);
     }
+    Record(key, kGet, mem_size);
     return nullptr;
   }
   if (strategy_ == nullptr) {
@@ -101,9 +113,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
     auto ptr = iter->second;
     MS_EXCEPTION_IF_NULL(ptr);
     return ptr;
-  } else {
-    MS_LOG_EXCEPTION << "Mem extender get nullptr result!";
   }
+  return nullptr;
 }
@@ -151,6 +162,9 @@ bool MemScheduler::PreCompute(void *stream) {
       MS_EXCEPTION_IF_NULL(host_ptr);
       mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
       mem_result_[event->key] = device_ptr;
+      if (mem_priority_[event->key] == kMemPriorityHigh) {
+        high_priority_device_ptr_[event->key] = device_ptr;
+      }
       if (!from_init) {
         mem_handler_->FreeHost(host_ptr);
         (void)swap_host_ptr_.erase(event->key);
@@ -199,6 +213,9 @@ bool MemScheduler::PostCompute(void *stream) {
       mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
      mem_handler_->FreeDevice(device_ptr);
       (void)mem_result_.erase(event->key);
+      if (mem_priority_[event->key] == kMemPriorityHigh) {
+        high_priority_device_ptr_.erase(event->key);
+      }
     }
   }
   ++current_step_;
@@ -221,6 +238,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
 }
 void MemScheduler::Optimize() {
+  AdjustFirstEventIndex();
   float mem_used_factor = kMaxMemReuseFactor;
   while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
     OptMemUsage(mem_used_factor);
@@ -247,11 +265,30 @@
     if (ret) {
       optimized_ = true;
     } else {
+      ClearTempMem();
       mem_used_factor -= kRetryFactor;
     }
   }
 }
+
+void MemScheduler::AdjustFirstEventIndex() {
+  for (const auto &item : mem_events_) {
+    const auto &mem_events = item.second;
+    if (mem_events.empty()) {
+      continue;
+    }
+    auto &first_event = mem_events[0];
+    MS_EXCEPTION_IF_NULL(first_event);
+    const auto &priority_iter = mem_priority_.find(item.first);
+    const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
+    if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
+      const auto &second_event = mem_events[1];
+      MS_EXCEPTION_IF_NULL(second_event);
+      first_event->index = second_event->index;
+    }
+  }
+}
 void MemScheduler::Update() {
   if (!optimized_) {
     return;
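Two scheduler-side changes round this out. AdjustFirstEventIndex now runs once before strategy generation, folding each low-priority buffer's kInit index forward to its first real use (the old code special-cased this inside CountMemUsage and GenEventSpan). And the Optimize retry loop starts from kMaxMemReuseFactor = 1.0, calling the new ClearTempMem after every failed mock attempt so buffers allocated during the failed trial cannot leak into the next one. The loop's shape, as a schematic with stand-in functions rather than the real MindSpore API:

#include <cstdio>

// Schematic of MemScheduler::Optimize's retry loop: try the full device
// budget first; on a failed mock execution, release temporary buffers and
// retry with a smaller fraction of device memory available for reuse.
namespace {
constexpr float kMaxMemReuseFactor = 1.0;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;

bool MockRun(float factor) { return factor < 0.85f; }  // pretend ~0.8 fits
void ClearTempMem() { std::puts("freeing temp device/host buffers"); }
}  // namespace

int main() {
  bool optimized = false;
  float mem_used_factor = kMaxMemReuseFactor;
  while (!optimized && mem_used_factor >= kMinMemReuseFactor) {
    if (MockRun(mem_used_factor)) {
      optimized = true;
    } else {
      ClearTempMem();  // new in this commit: drop leftovers before retrying
      mem_used_factor -= kRetryFactor;
    }
  }
  std::printf("settled at factor %.1f\n", mem_used_factor);
  return 0;
}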

View File

@@ -70,7 +70,7 @@ class MemScheduler {
   void Clear();
-  bool IsHighPriorityMem(const void *key);
+  void ClearTempMem();
   void SetMemPriority(const void *key, MemPriority priority);
@@ -79,6 +79,8 @@ class MemScheduler {
   void OptMemUsage(float mem_used_factor = 1.0f);
+  void AdjustFirstEventIndex();
+
   std::map<const void *, MemPriority> mem_priority_;
   std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
   std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;