forked from mindspore-Ecosystem/mindspore

Adjust swap strategy

parent d996ad5e1e
commit 16ca537505
kernel_runtime.cc

@@ -1294,9 +1294,6 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_
     kernel_addr->addr = device_address->ptr_;
   } else {
     kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_);
-    if (mem_scheduler->IsHighPriorityMem(device_address)) {
-      device_address->ptr_ = kernel_addr->addr;
-    }
   }
 }
@@ -1343,37 +1340,29 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
 }
 
 void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
-                                          const session::KernelGraph &graph, const AnfNodePtr &kernel, bool mock) {
+                                          const session::KernelGraph &graph, const AnfNodePtr &kernel) {
   MS_EXCEPTION_IF_NULL(mem_scheduler);
   MS_EXCEPTION_IF_NULL(kernel);
   auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   for (size_t input_idx = 0; input_idx < kernel_mod->GetInputSizeList().size(); ++input_idx) {
     const auto input_node_index = AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true);
-    if (input_node_index.first == nullptr || !input_node_index.first->isa<Parameter>()) {
-      continue;
+    if (input_node_index.first != nullptr && input_node_index.first->isa<Parameter>()) {
+      SyncNodeOutputTensor(mem_scheduler, input_node_index, graph);
     }
-    SyncNodeOutputTensor(mem_scheduler, input_node_index, graph, mock);
   }
   for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) {
-    SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph, mock);
+    SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph);
   }
 }
 
 void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler,
-                                         const KernelWithIndex &node_output_index, const session::KernelGraph &graph,
-                                         bool mock) {
+                                         const KernelWithIndex &node_output_index, const session::KernelGraph &graph) {
   MS_EXCEPTION_IF_NULL(mem_scheduler);
   if (node_output_index.first == nullptr) {
     return;
   }
   auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true);
-  if (mock) {
-    if (graph.IsInternalOutput(node_output_index.first, node_output_index.second) && device_address != nullptr) {
-      mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
-    }
-    return;
-  }
   auto tensor = graph.GetNodeOutputTensor(node_output_index);
   if (tensor == nullptr) {
     return;
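With `mock` gone from these signatures, output syncing happens only on the real pass; the mock pass is left to record memory events and derive the swap strategy. A toy driver showing the intended two-pass contract (the names and printout are illustrative, not the MindSpore API):

    #include <cstdio>
    #include <vector>

    int main() {
      const std::vector<const char *> kernels = {"Conv2D", "ReLU"};
      for (bool mock : {true, false}) {
        for (const char *kernel : kernels) {
          std::printf("%s %s\n", mock ? "record" : "execute", kernel);
          if (!mock) {
            std::printf("  sync outputs of %s to host tensors\n", kernel);
          }
        }
      }
      return 0;
    }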
@@ -1407,22 +1396,20 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
     MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
   }
   for (size_t i = 0; i < input_tensors.size(); ++i) {
+    auto tensor = input_tensors[i];
+    MS_EXCEPTION_IF_NULL(tensor);
     auto input_node = input_nodes[i];
     if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
       continue;
     }
     auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
-    auto tensor = input_tensors[i];
-    MS_EXCEPTION_IF_NULL(tensor);
-    MemPriority priority = kMemPriorityLow;
-    if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) ||
+    MemPriority priority = kMemPriorityLow;
+    auto tensor_address = tensor->device_address();
+    if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) {
+      tensor->data_sync(false);
+    }
+    if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) &&
         graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) {
+      tensor->set_device_address(device_address);
       priority = kMemPriorityHigh;
     }
     auto tensor_size = LongToSize(tensor->data().nbytes());
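The new guard keeps a freshly produced device copy from being lost: if the tensor is still bound to some other device address and has no pending host-to-device sync, its data is pulled back to host before the input is re-initialized. A minimal model of that decision (toy types, not the real Tensor/DeviceAddress classes):

    #include <cstdio>

    struct Tensor {
      bool need_sync_host_to_device = false;
      const void *device_address = nullptr;
      void data_sync() { std::printf("pull latest values device -> host\n"); }
    };

    int main() {
      int node_buffer = 0;
      int other_buffer = 0;
      Tensor tensor;
      tensor.device_address = &other_buffer;  // bound to a different device buffer
      const void *node_address = &node_buffer;
      if (!tensor.need_sync_host_to_device && tensor.device_address != nullptr &&
          tensor.device_address != node_address) {
        tensor.data_sync();  // mirrors tensor->data_sync(false) in the hunk above
      }
      return 0;
    }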
@@ -1477,7 +1464,9 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
     }
   }
   if (mem_scheduler != nullptr) {
-    SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock);
+    if (!mock) {
+      SyncNodeOutputTensors(mem_scheduler, graph, kernel);
+    }
     ret = mem_scheduler->PostCompute(stream);
     if (!ret) {
       return ret;
@@ -1553,9 +1542,43 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
     }
     LaunchKernelEvent(kernel_post_run_events, kernels[i]);
   }
+  if (UseMemScheduler() && !mock) {
+    SyncUpdatedParameter(graph, mem_scheduler);
+  }
   return true;
 }
 
+void KernelRuntime::SyncUpdatedParameter(const session::KernelGraph &graph,
+                                         const std::shared_ptr<MemScheduler> &mem_scheduler) {
+  MS_EXCEPTION_IF_NULL(mem_scheduler);
+  auto &input_nodes = graph.input_nodes();
+  auto &input_tensors = graph.input_tensors();
+  if (input_tensors.size() != input_nodes.size()) {
+    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
+  }
+
+  for (size_t i = 0; i < input_tensors.size(); ++i) {
+    auto input_node = input_nodes[i];
+    if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
+      continue;
+    }
+    auto parameter = input_node->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(parameter);
+    if (!graph.IsUpdatedParameter(parameter)) {
+      continue;
+    }
+    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
+    auto tensor = input_tensors[i];
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
+    if (device_ptr != nullptr) {
+      device_address->set_ptr(device_ptr);
+      tensor->set_device_address(device_address);
+      tensor->set_sync_status(kNeedSyncDeviceToHost);
+    }
+  }
+}
+
 void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
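SyncUpdatedParameter re-pins each updated weight after a real launch: the parameter's buffer is fetched (or re-allocated) as high priority, the tensor is re-bound to that device address, and it is flagged kNeedSyncDeviceToHost so the host copy is refreshed lazily on the next read rather than copied eagerly. A toy of that lazy-refresh contract (stand-in types; the copy is simulated):

    #include <cstdio>

    struct Tensor {
      bool need_sync_device_to_host = false;
      float host_data[2] = {0.0f, 0.0f};
      void MarkUpdated() { need_sync_device_to_host = true; }
      float Read(int i) {
        if (need_sync_device_to_host) {
          host_data[i] = 42.0f;  // stand-in for the device -> host copy
          need_sync_device_to_host = false;
        }
        return host_data[i];
      }
    };

    int main() {
      Tensor weight;
      weight.MarkUpdated();                 // what set_sync_status(kNeedSyncDeviceToHost) arranges
      std::printf("%f\n", weight.Read(0));  // the copy happens here, on first host read
      return 0;
    }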
kernel_runtime.h

@@ -95,6 +95,7 @@ class KernelRuntime {
   void set_device_id(uint32_t device_id) { device_id_ = device_id; }
   uint32_t device_id() { return device_id_; }
   static bool UseMemScheduler();
+  void SyncUpdatedParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler);
 
 #ifdef ENABLE_DEBUGGER
   // set debugger
@@ -155,9 +156,9 @@ class KernelRuntime {
                              const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
   void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
   void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
-                             const AnfNodePtr &kernel, bool mock);
+                             const AnfNodePtr &kernel);
   void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
-                            const session::KernelGraph &graph, bool mock);
+                            const session::KernelGraph &graph);
 
   void AssignCommunicationMem(const session::KernelGraph &graph);
   bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
mem_offload_strategy.cc

@@ -43,7 +43,7 @@ void MemOffloadStrategy::Execute() {
   CheckMemSize();
   if (need_swap_) {
     GenEventSpan();
-    GenNoSwapEventSet();
+    GenSwapEventSet();
   }
   GenComputeMemEvents();
 }
@@ -57,37 +57,41 @@ void MemOffloadStrategy::CountMemUsage() {
   }
   min_mem_used_.resize(total_step_, 0);
   std::vector<size_t> total_mem_used(total_step_, 0);
+  size_t high_priority_mem_size = 0;
   for (auto &item : mem_events_) {
     auto &mem_events = item.second;
     if (mem_events.empty()) {
       continue;
     }
     auto first_event = mem_events[0];
-    size_t cur_index = 0;
-    if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
-      first_event = mem_events[1];
-      cur_index = 1;
-    }
-    auto last_event = mem_events[mem_events.size() - 1];
-    for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
-      if (start_index < total_step_) {
-        total_mem_used[start_index] += first_event->mem_size;
-      } else {
-        MS_LOG(ERROR) << "Error mem event index " << start_index;
+    const bool is_high_priority = IsHighPriorityMem(first_event->key);
+    if (is_high_priority) {
+      high_priority_mem_size += first_event->mem_size;
+    } else {
+      auto last_event = mem_events[mem_events.size() - 1];
+      for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
+        if (start_index < total_step_) {
+          total_mem_used[start_index] += first_event->mem_size;
+        } else {
+          MS_LOG(ERROR) << "Error mem event index " << start_index;
+        }
       }
     }
-    for (; cur_index < mem_events.size(); ++cur_index) {
-      auto &event = mem_events[cur_index];
+    // Calculate the minimum memory size for kernel execution.
+    for (const auto &event : mem_events) {
       MS_EXCEPTION_IF_NULL(event);
+      if (event->type != kGet) {
+        continue;
+      }
       if (event->index < total_step_) {
         min_mem_used_[event->index] += first_event->mem_size;
       } else {
         MS_LOG(ERROR) << "Error mem event index " << event->index;
       }
     }
   }
   min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
-  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
+  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
 }
 
+bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
+  auto iter = mem_priority_.find(key);
+  if (iter != mem_priority_.end()) {
+    return iter->second == kMemPriorityHigh;
+  }
+  return false;
+}
+
 void MemOffloadStrategy::CheckMemSize() {
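Two accounting changes land here: high-priority (resident) memory no longer inflates per-step totals but is added once to the no-swap peak, and min_mem_used_ now counts only kGet events, that is, the steps where a kernel actually touches the buffer. A small worked example of the new arithmetic (standalone toy, not the strategy class):

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      const size_t total_step = 6;
      std::vector<size_t> total_mem_used(total_step, 0);
      std::vector<size_t> min_mem_used(total_step, 0);
      // Low-priority buffer, 100 bytes, Get events at steps 1 and 3: occupies steps 1..3.
      for (size_t i = 1; i <= 3; ++i) total_mem_used[i] += 100;
      min_mem_used[1] += 100;  // only the steps with a kGet need it resident
      min_mem_used[3] += 100;
      // High-priority buffer, 40 bytes: resident everywhere, counted once at the end.
      const size_t high_priority_mem_size = 40;
      const size_t peak = *std::max_element(total_mem_used.begin(), total_mem_used.end());
      std::printf("mem_used_without_swap = %zu\n", peak + high_priority_mem_size);  // 140
      const size_t min_needed = *std::max_element(min_mem_used.begin(), min_mem_used.end());
      std::printf("min_mem_needed = %zu\n", min_needed);  // 100
      return 0;
    }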
@@ -110,48 +114,60 @@ void MemOffloadStrategy::GenEventSpan() {
   }
   for (auto &item : mem_events_) {
     auto &tensor_events = item.second;
-    if (tensor_events.empty()) {
+    if (tensor_events.size() <= 1) {
       continue;
     }
-    auto first_event = tensor_events[0];
-    size_t cur_index = 0;
-    if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
-      first_event = tensor_events[1];
-      cur_index = 1;
-    }
-    size_t last_index = first_event->index;
-    for (; cur_index < tensor_events.size(); ++cur_index) {
-      auto &event = tensor_events[cur_index];
+    const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key);
+    for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) {
+      auto &event = tensor_events[event_index];
       MS_EXCEPTION_IF_NULL(event);
-      auto span = event->index - last_index;
-      if (span > 1) {
-        (void)event_span_.emplace(span, event);
+      if (event->type != kGet) {
+        MS_LOG(EXCEPTION) << "Event should be Get except first event.";
+      }
+      size_t span = 0;
+      if (event_index == 1 && is_high_priority) {
+        const auto &last_event = tensor_events[tensor_events.size() - 1];
+        span = event->index + total_step_ - last_event->index;
+      } else {
+        span = event->index - tensor_events[event_index - 1]->index;
+      }
+      if (span > 1) {
+        const size_t span_mul_size = (span - 1) * event->mem_size;
+        (void)event_span_.emplace(std::make_pair(span_mul_size, std::make_pair(event, span)));
       }
-      last_index = event->index;
     }
   }
 }
 
-void MemOffloadStrategy::GenNoSwapEventSet() {
-  no_swap_events_.clear();
+void MemOffloadStrategy::GenSwapEventSet() {
+  swap_events_.clear();
   std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
-  for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
-    auto span = iter->first;
-    auto &event = iter->second;
-    auto start_index = event->index - span + 1;
+  for (const auto &iter : event_span_) {
+    auto span = iter.second.second;
+    auto &event = iter.second.first;
+    auto start_index = ((total_step_ + event->index - span) % total_step_) + 1;
     bool revert = false;
-    for (size_t i = start_index; i < event->index; ++i) {
-      cur_mem_used[i] += event->mem_size;
-      if (cur_mem_used[i] > mem_size_) {
+    size_t cur_index = start_index;
+    while (cur_index != event->index) {
+      cur_mem_used[cur_index] += event->mem_size;
+      if (cur_mem_used[cur_index] > mem_size_) {
         revert = true;
       }
+      cur_index += 1;
+      if (cur_index >= total_step_) {
+        cur_index = 0;
+      }
     }
     if (revert) {
-      for (size_t i = start_index; i < event->index; ++i) {
-        cur_mem_used[i] -= event->mem_size;
+      cur_index = start_index;
+      while (cur_index != event->index) {
+        cur_mem_used[cur_index] -= event->mem_size;
+        cur_index += 1;
+        if (cur_index >= total_step_) {
+          cur_index = 0;
+        }
       }
-    } else {
-      (void)no_swap_events_.emplace(event);
+      (void)swap_events_.emplace(event);
     }
   }
 }
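This is the core of the adjusted strategy: spans are now ranked by reclaimable bytes, (span - 1) * mem_size, instead of raw step distance; for high-priority memory the first Get's span wraps around the end of the schedule (from the last use of one iteration to the first use of the next); and the selection is inverted, collecting the events that must swap because their window overflows mem_size_, where the old code collected the ones that need not. A standalone sketch of the circular arithmetic, using toy values:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
      const size_t total_step = 10;
      const std::vector<size_t> use_steps = {1, 4, 8};  // Get events of one resident buffer
      // First Get's span wraps: last use (8) -> end of schedule -> first use (1).
      const size_t wrap_span = use_steps.front() + total_step - use_steps.back();  // 3
      const size_t mem_size = 100;
      std::printf("rank key = %zu\n", (wrap_span - 1) * mem_size);  // 200 reclaimable byte-steps
      // The circular window GenSwapEventSet walks before deciding to swap:
      size_t cur = ((total_step + use_steps.front() - wrap_span) % total_step) + 1;  // 9
      while (cur != use_steps.front()) {
        std::printf("buffer held at step %zu\n", cur);  // steps 9, then 0
        cur += 1;
        if (cur >= total_step) cur = 0;
      }
      return 0;
    }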
@@ -166,34 +182,31 @@ void MemOffloadStrategy::GenComputeMemEvents() {
-    if (mem_events.empty()) {
+    // No need to generate events for memory that has only one event, which means it is never used by any kernel.
+    if (mem_events.size() <= 1) {
       continue;
     }
+
+    const bool is_high_priority = IsHighPriorityMem(item.first);
     auto first_event = mem_events[0];
     MS_EXCEPTION_IF_NULL(first_event);
-    if (first_event->type == kInit) {
-      if (mem_events.size() > 1) {
-        auto &second_event = mem_events[1];
-        MS_EXCEPTION_IF_NULL(second_event);
-        first_event->index = second_event->index;
-      } else {
-        continue;
-      }
+    const auto &second_event = mem_events[1];
+    MS_EXCEPTION_IF_NULL(second_event);
+    if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) {
+      first_event->index = second_event->index;
     }
-    if ((first_event->type == kInit || first_event->type == kMalloc) &&
-        first_event->index < pre_compute_events_.size()) {
+    if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) {
       pre_compute_events_[first_event->index].emplace_back(first_event);
     } else {
       MS_LOG_EXCEPTION << "First event should be init or malloc!";
     }
-    MemPriority priority = kMemPriorityLow;
-    auto iter = mem_priority_.find(first_event->key);
-    if (iter != mem_priority_.end()) {
-      priority = iter->second;
-    }
-    size_t pre_index = first_event->index;
+
+    const auto &last_event = mem_events[mem_events.size() - 1];
+    size_t pre_index = is_high_priority ? last_event->index : first_event->index;
     for (size_t i = 1; i < mem_events.size(); ++i) {
       auto &event = mem_events[i];
       MS_EXCEPTION_IF_NULL(event);
-      if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
-          no_swap_events_.find(event) == no_swap_events_.end()) {
+      if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
         auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
         swap_out_event->key = item.first;
         swap_out_event->mem_size = first_event->mem_size;
@@ -208,17 +221,19 @@ void MemOffloadStrategy::GenComputeMemEvents() {
       }
       pre_index = event->index;
     }
-    if (priority != kMemPriorityLow) {
-      continue;
-    }
-    auto &last_event = mem_events[mem_events.size() - 1];
-    MS_EXCEPTION_IF_NULL(last_event);
-    auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
-    free_event->key = item.first;
-    if (last_event->index < post_compute_events_.size()) {
-      (void)post_compute_events_[last_event->index].emplace_back(free_event);
+    if (!is_high_priority) {
+      GenFreeEvent(last_event);
     }
   }
 }
 
+void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_event) {
+  MS_EXCEPTION_IF_NULL(last_event);
+  auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
+  free_event->key = last_event->key;
+  if (last_event->index < post_compute_events_.size()) {
+    (void)post_compute_events_[last_event->index].emplace_back(free_event);
+  }
+}
+
 } // namespace device
 } // namespace mindspore
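GenFreeEvent factors out the kFree emission, and it now runs only for low-priority memory: weights and other resident buffers keep their allocation across iterations instead of being freed after their last use. A toy of the policy (stand-in struct, not the real MemEvent):

    #include <cstddef>
    #include <cstdio>

    struct LastUse {
      size_t index;
      bool high_priority;
    };

    int main() {
      const LastUse activation{7, false};
      const LastUse weight{7, true};
      for (const auto &m : {activation, weight}) {
        if (!m.high_priority) {
          std::printf("emit kFree at step %zu\n", m.index);  // only the activation is freed
        }
      }
      return 0;
    }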
mem_offload_strategy.h

@@ -58,12 +58,15 @@ class MemOffloadStrategy {
 
   bool need_swap() const { return need_swap_; }
 
+  bool IsHighPriorityMem(const void *key);
+
  private:
   void CountMemUsage();
   void CheckMemSize();
   void GenEventSpan();
-  void GenNoSwapEventSet();
+  void GenSwapEventSet();
   void GenComputeMemEvents();
+  void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
 
   const std::map<const void *, MemPriority> &mem_priority_;
   const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;

@@ -74,8 +77,8 @@ class MemOffloadStrategy {
   size_t mem_size_{0};
   std::vector<double> compute_time_;
   bool need_swap_{false};
-  std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_;
-  std::set<std::shared_ptr<MemEvent>> no_swap_events_;
+  std::multimap<size_t, std::pair<std::shared_ptr<MemEvent>, size_t>> event_span_;
+  std::set<std::shared_ptr<MemEvent>> swap_events_;
   std::vector<size_t> min_mem_used_;
   size_t mem_used_without_swap_{0};
   size_t min_mem_needed_{0};
mem_scheduler.cc

@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace device {
 namespace {
-constexpr float kMaxMemReuseFactor = 0.9;
+constexpr float kMaxMemReuseFactor = 1.0;
 constexpr float kMinMemReuseFactor = 0.5;
 constexpr float kRetryFactor = 0.1;
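Raising kMaxMemReuseFactor to 1.0 means the first optimization attempt assumes the whole pool is usable; failures back off by kRetryFactor down to kMinMemReuseFactor. The implied schedule, as a sketch:

    #include <cstdio>

    int main() {
      const float kMaxMemReuseFactor = 1.0f;
      const float kMinMemReuseFactor = 0.5f;
      const float kRetryFactor = 0.1f;
      // Epsilon guards against float drift across the repeated subtraction.
      for (float f = kMaxMemReuseFactor; f >= kMinMemReuseFactor - 1e-6f; f -= kRetryFactor) {
        std::printf("try mem_used_factor = %.1f\n", f);  // 1.0, 0.9, 0.8, 0.7, 0.6, 0.5
      }
      return 0;
    }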
@@ -51,12 +51,25 @@ void MemScheduler::Clear() {
   high_priority_device_ptr_.clear();
 }
 
-bool MemScheduler::IsHighPriorityMem(const void *key) {
-  auto iter = mem_priority_.find(key);
-  if (iter != mem_priority_.end()) {
-    return iter->second == kMemPriorityHigh;
+void MemScheduler::ClearTempMem() {
+  if (mem_handler_ == nullptr) {
+    return;
   }
-  return false;
+  for (auto &item : mem_result_) {
+    const auto device_ptr = item.second;
+    if (device_ptr != nullptr) {
+      mem_handler_->FreeDevice(device_ptr);
+    }
+  }
+  mem_result_.clear();
+  high_priority_device_ptr_.clear();
+  for (const auto &item : swap_host_ptr_) {
+    const auto host_ptr = item.second;
+    if (host_ptr != nullptr) {
+      mem_handler_->FreeHost(host_ptr);
+    }
+  }
+  swap_host_ptr_.clear();
 }
 
 void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
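ClearTempMem exists so a failed trial at one reuse factor does not leak its device results or swap staging buffers into the next attempt. A toy of where it sits in the retry loop; TrialRun is a stand-in, since the real validation call is outside the visible hunks:

    #include <cstdio>

    static bool TrialRun(float factor) { return factor <= 0.85f; }  // stand-in validation
    static void ClearTempMem() { std::printf("  free trial device/host buffers\n"); }

    int main() {
      bool optimized = false;
      float mem_used_factor = 1.0f;                             // kMaxMemReuseFactor
      while (!optimized && mem_used_factor >= 0.5f - 1e-6f) {   // kMinMemReuseFactor
        std::printf("OptMemUsage(%.1f)\n", mem_used_factor);
        if (TrialRun(mem_used_factor)) {
          optimized = true;
        } else {
          ClearTempMem();              // reset before the tighter retry
          mem_used_factor -= 0.1f;     // kRetryFactor
        }
      }
      return 0;
    }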
@@ -88,9 +101,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
     if (mem_priority_.find(key) == mem_priority_.end()) {
       mem_priority_[key] = priority;
       Record(key, kMalloc, mem_size);
-    } else {
-      Record(key, kGet, mem_size);
     }
+    Record(key, kGet, mem_size);
     return nullptr;
   }
   if (strategy_ == nullptr) {
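Recording now always appends a kGet on every access, with kMalloc added once on first touch (previously a first touch recorded kMalloc only). That keeps the kGet-only accounting in CountMemUsage and the "every later event is a Get" invariant in GenEventSpan consistent. The resulting event stream for one buffer, sketched with toy types:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    enum EventType { kMalloc, kGet };
    struct Ev {
      EventType type;
      size_t step;
    };

    int main() {
      // First touch at step 3 records kMalloc and kGet; the touch at step 7 records kGet.
      const std::vector<Ev> events = {{kMalloc, 3}, {kGet, 3}, {kGet, 7}};
      for (const auto &e : events) {
        std::printf("%s at step %zu\n", e.type == kMalloc ? "kMalloc" : "kGet", e.step);
      }
      return 0;
    }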
@@ -101,9 +113,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
     auto ptr = iter->second;
     MS_EXCEPTION_IF_NULL(ptr);
     return ptr;
-  } else {
-    MS_LOG_EXCEPTION << "Mem extender get nullptr result!";
   }
   return nullptr;
 }
 
 bool MemScheduler::PreCompute(void *stream) {
@@ -151,6 +162,9 @@ bool MemScheduler::PreCompute(void *stream) {
       MS_EXCEPTION_IF_NULL(host_ptr);
       mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
       mem_result_[event->key] = device_ptr;
+      if (mem_priority_[event->key] == kMemPriorityHigh) {
+        high_priority_device_ptr_[event->key] = device_ptr;
+      }
       if (!from_init) {
         mem_handler_->FreeHost(host_ptr);
         (void)swap_host_ptr_.erase(event->key);
@@ -199,6 +213,9 @@ bool MemScheduler::PostCompute(void *stream) {
       mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
       mem_handler_->FreeDevice(device_ptr);
       (void)mem_result_.erase(event->key);
+      if (mem_priority_[event->key] == kMemPriorityHigh) {
+        high_priority_device_ptr_.erase(event->key);
+      }
     }
   }
   ++current_step_;
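Together, the PreCompute and PostCompute changes keep the high-priority pointer cache honest: swap-in publishes the fresh device pointer, swap-out removes the stale entry, so no consumer can read a dangling cached address. A minimal model of the bookkeeping:

    #include <cstdio>
    #include <map>

    int main() {
      std::map<const void *, void *> high_priority_device_ptr;
      int key = 0;
      int device_mem = 0;
      high_priority_device_ptr[&key] = &device_mem;  // PreCompute: published after SwapIn
      std::printf("cached: %zu\n", high_priority_device_ptr.size());  // 1
      high_priority_device_ptr.erase(&key);          // PostCompute: dropped after SwapOut
      std::printf("cached: %zu\n", high_priority_device_ptr.size());  // 0
      return 0;
    }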
@@ -221,6 +238,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
 }
 
 void MemScheduler::Optimize() {
+  AdjustFirstEventIndex();
   float mem_used_factor = kMaxMemReuseFactor;
   while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
     OptMemUsage(mem_used_factor);
@@ -247,11 +265,30 @@ void MemScheduler::Optimize() {
     if (ret) {
       optimized_ = true;
     } else {
+      ClearTempMem();
       mem_used_factor -= kRetryFactor;
     }
   }
 }
 
+void MemScheduler::AdjustFirstEventIndex() {
+  for (const auto &item : mem_events_) {
+    const auto &mem_events = item.second;
+    if (mem_events.empty()) {
+      continue;
+    }
+    auto &first_event = mem_events[0];
+    MS_EXCEPTION_IF_NULL(first_event);
+    const auto &priority_iter = mem_priority_.find(item.first);
+    const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
+    if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
+      const auto &second_event = mem_events[1];
+      MS_EXCEPTION_IF_NULL(second_event);
+      first_event->index = second_event->index;
+    }
+  }
+}
+
 void MemScheduler::Update() {
   if (!optimized_) {
     return;
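AdjustFirstEventIndex moves the kInit of low-priority memory up to its first real use, so the buffer is not treated as live from step 0; high-priority inits keep their original index because those buffers stay resident anyway. A toy illustration with stand-in types:

    #include <cstdio>
    #include <memory>
    #include <vector>

    enum EventType { kInit, kGet };
    struct MemEvent {
      EventType type;
      size_t index;
    };

    int main() {
      std::vector<std::shared_ptr<MemEvent>> events = {
          std::make_shared<MemEvent>(MemEvent{kInit, 0}),
          std::make_shared<MemEvent>(MemEvent{kGet, 5})};
      const bool is_high_priority = false;  // a weight would keep its early Init
      if (events[0]->type == kInit && !is_high_priority && events.size() > 1) {
        events[0]->index = events[1]->index;  // defer Init to first use
      }
      std::printf("init now at step %zu\n", events[0]->index);  // 5
      return 0;
    }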
mem_scheduler.h

@@ -70,7 +70,7 @@ class MemScheduler {
 
   void Clear();
 
-  bool IsHighPriorityMem(const void *key);
+  void ClearTempMem();
 
   void SetMemPriority(const void *key, MemPriority priority);
@@ -79,6 +79,8 @@ class MemScheduler {
 
   void OptMemUsage(float mem_used_factor = 1.0f);
 
+  void AdjustFirstEventIndex();
+
   std::map<const void *, MemPriority> mem_priority_;
   std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
   std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;