Adjust swap strategy

This commit is contained in:
tanghuikang 2021-11-22 16:20:41 +08:00
parent d996ad5e1e
commit 16ca537505
6 changed files with 191 additions and 110 deletions

View File

@ -1294,9 +1294,6 @@ void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_
kernel_addr->addr = device_address->ptr_;
} else {
kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_);
if (mem_scheduler->IsHighPriorityMem(device_address)) {
device_address->ptr_ = kernel_addr->addr;
}
}
}
@ -1343,37 +1340,29 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
}
void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
const session::KernelGraph &graph, const AnfNodePtr &kernel, bool mock) {
const session::KernelGraph &graph, const AnfNodePtr &kernel) {
MS_EXCEPTION_IF_NULL(mem_scheduler);
MS_EXCEPTION_IF_NULL(kernel);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
for (size_t input_idx = 0; input_idx < kernel_mod->GetInputSizeList().size(); ++input_idx) {
const auto input_node_index = AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true);
if (input_node_index.first == nullptr || !input_node_index.first->isa<Parameter>()) {
continue;
if (input_node_index.first != nullptr && input_node_index.first->isa<Parameter>()) {
SyncNodeOutputTensor(mem_scheduler, input_node_index, graph);
}
SyncNodeOutputTensor(mem_scheduler, input_node_index, graph, mock);
}
for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) {
SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph, mock);
SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph);
}
}
void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler,
const KernelWithIndex &node_output_index, const session::KernelGraph &graph,
bool mock) {
const KernelWithIndex &node_output_index, const session::KernelGraph &graph) {
MS_EXCEPTION_IF_NULL(mem_scheduler);
if (node_output_index.first == nullptr) {
return;
}
auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true);
if (mock) {
if (graph.IsInternalOutput(node_output_index.first, node_output_index.second) && device_address != nullptr) {
mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
}
return;
}
auto tensor = graph.GetNodeOutputTensor(node_output_index);
if (tensor == nullptr) {
return;
@ -1407,22 +1396,20 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
}
for (size_t i = 0; i < input_tensors.size(); ++i) {
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i];
if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
continue;
}
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
MemPriority priority = kMemPriorityLow;
auto tensor_address = tensor->device_address();
if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) {
tensor->data_sync(false);
}
if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) ||
MemPriority priority = kMemPriorityLow;
if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) &&
graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
priority = kMemPriorityHigh;
}
auto tensor_size = LongToSize(tensor->data().nbytes());
@ -1477,7 +1464,9 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
}
}
if (mem_scheduler != nullptr) {
SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock);
if (!mock) {
SyncNodeOutputTensors(mem_scheduler, graph, kernel);
}
ret = mem_scheduler->PostCompute(stream);
if (!ret) {
return ret;
@ -1553,9 +1542,43 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
}
LaunchKernelEvent(kernel_post_run_events, kernels[i]);
}
if (UseMemScheduler() && !mock) {
SyncUpdatedParameter(graph, mem_scheduler);
}
return true;
}
void KernelRuntime::SyncUpdatedParameter(const session::KernelGraph &graph,
const std::shared_ptr<MemScheduler> &mem_scheduler) {
MS_EXCEPTION_IF_NULL(mem_scheduler);
auto &input_nodes = graph.input_nodes();
auto &input_tensors = graph.input_tensors();
if (input_tensors.size() != input_nodes.size()) {
MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
}
for (size_t i = 0; i < input_tensors.size(); ++i) {
auto input_node = input_nodes[i];
if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
continue;
}
auto parameter = input_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
if (!graph.IsUpdatedParameter(parameter)) {
continue;
}
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
if (device_ptr != nullptr) {
device_address->set_ptr(device_ptr);
tensor->set_device_address(device_address);
tensor->set_sync_status(kNeedSyncDeviceToHost);
}
}
}
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);

View File

@ -95,6 +95,7 @@ class KernelRuntime {
void set_device_id(uint32_t device_id) { device_id_ = device_id; }
uint32_t device_id() { return device_id_; }
static bool UseMemScheduler();
void SyncUpdatedParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler);
#ifdef ENABLE_DEBUGGER
// set debugger
@ -155,9 +156,9 @@ class KernelRuntime {
const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
const AnfNodePtr &kernel, bool mock);
const AnfNodePtr &kernel);
void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
const session::KernelGraph &graph, bool mock);
const session::KernelGraph &graph);
void AssignCommunicationMem(const session::KernelGraph &graph);
bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);

View File

@ -43,7 +43,7 @@ void MemOffloadStrategy::Execute() {
CheckMemSize();
if (need_swap_) {
GenEventSpan();
GenNoSwapEventSet();
GenSwapEventSet();
}
GenComputeMemEvents();
}
@ -57,37 +57,41 @@ void MemOffloadStrategy::CountMemUsage() {
}
min_mem_used_.resize(total_step_, 0);
std::vector<size_t> total_mem_used(total_step_, 0);
size_t high_priority_mem_size = 0;
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto first_event = mem_events[0];
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
cur_index = 1;
}
auto last_event = mem_events[mem_events.size() - 1];
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
if (start_index < total_step_) {
const bool is_high_priority = IsHighPriorityMem(first_event->key);
if (is_high_priority) {
high_priority_mem_size += first_event->mem_size;
} else {
auto last_event = mem_events[mem_events.size() - 1];
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
total_mem_used[start_index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << start_index;
}
}
for (; cur_index < mem_events.size(); ++cur_index) {
auto &event = mem_events[cur_index];
// Calculate the minimum memory size for kernel execution.
for (const auto &event : mem_events) {
MS_EXCEPTION_IF_NULL(event);
if (event->index < total_step_) {
min_mem_used_[event->index] += first_event->mem_size;
} else {
MS_LOG(ERROR) << "Error mem event index " << event->index;
if (event->type != kGet) {
continue;
}
min_mem_used_[event->index] += first_event->mem_size;
}
}
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
}
bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
auto iter = mem_priority_.find(key);
if (iter != mem_priority_.end()) {
return iter->second == kMemPriorityHigh;
}
return false;
}
void MemOffloadStrategy::CheckMemSize() {
@ -110,48 +114,60 @@ void MemOffloadStrategy::GenEventSpan() {
}
for (auto &item : mem_events_) {
auto &tensor_events = item.second;
if (tensor_events.empty()) {
if (tensor_events.size() <= 1) {
continue;
}
auto first_event = tensor_events[0];
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
first_event = tensor_events[1];
cur_index = 1;
}
size_t last_index = first_event->index;
for (; cur_index < tensor_events.size(); ++cur_index) {
auto &event = tensor_events[cur_index];
const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key);
for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) {
auto &event = tensor_events[event_index];
MS_EXCEPTION_IF_NULL(event);
auto span = event->index - last_index;
if (span > 1) {
(void)event_span_.emplace(span, event);
if (event->type != kGet) {
MS_LOG(EXCEPTION) << "Event should be Get except fist event.";
}
size_t span = 0;
if (event_index == 1 && is_high_priority) {
const auto &last_event = tensor_events[tensor_events.size() - 1];
span = event->index + total_step_ - last_event->index;
} else {
span = event->index - tensor_events[event_index - 1]->index;
}
if (span > 1) {
const size_t span_mul_size = (span - 1) * event->mem_size;
(void)event_span_.emplace(std::make_pair(span_mul_size, std::make_pair(event, span)));
}
last_index = event->index;
}
}
}
void MemOffloadStrategy::GenNoSwapEventSet() {
no_swap_events_.clear();
void MemOffloadStrategy::GenSwapEventSet() {
swap_events_.clear();
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
auto span = iter->first;
auto &event = iter->second;
auto start_index = event->index - span + 1;
for (const auto &iter : event_span_) {
auto span = iter.second.second;
auto &event = iter.second.first;
auto start_index = ((total_step_ + event->index - span) % total_step_) + 1;
bool revert = false;
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] += event->mem_size;
if (cur_mem_used[i] > mem_size_) {
size_t cur_index = start_index;
while (cur_index != event->index) {
cur_mem_used[cur_index] += event->mem_size;
if (cur_mem_used[cur_index] > mem_size_) {
revert = true;
}
cur_index += 1;
if (cur_index >= total_step_) {
cur_index = 0;
}
}
if (revert) {
for (size_t i = start_index; i < event->index; ++i) {
cur_mem_used[i] -= event->mem_size;
cur_index = start_index;
while (cur_index != event->index) {
cur_mem_used[cur_index] -= event->mem_size;
cur_index += 1;
if (cur_index >= total_step_) {
cur_index = 0;
}
}
} else {
(void)no_swap_events_.emplace(event);
(void)swap_events_.emplace(event);
}
}
}
@ -166,34 +182,31 @@ void MemOffloadStrategy::GenComputeMemEvents() {
if (mem_events.empty()) {
continue;
}
// No need to generate events for memory that has only one event, which means it is never used by any kernel.
if (mem_events.size() <= 1) {
continue;
}
const bool is_high_priority = IsHighPriorityMem(item.first);
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
if (first_event->type == kInit) {
if (mem_events.size() > 1) {
auto &second_event = mem_events[1];
MS_EXCEPTION_IF_NULL(second_event);
first_event->index = second_event->index;
} else {
continue;
}
const auto &second_event = mem_events[1];
MS_EXCEPTION_IF_NULL(second_event);
if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) {
first_event->index = second_event->index;
}
if ((first_event->type == kInit || first_event->type == kMalloc) &&
first_event->index < pre_compute_events_.size()) {
if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) {
pre_compute_events_[first_event->index].emplace_back(first_event);
} else {
MS_LOG_EXCEPTION << "First event should be init or malloc!";
}
MemPriority priority = kMemPriorityLow;
auto iter = mem_priority_.find(first_event->key);
if (iter != mem_priority_.end()) {
priority = iter->second;
}
size_t pre_index = first_event->index;
const auto &last_event = mem_events[mem_events.size() - 1];
size_t pre_index = is_high_priority ? last_event->index : first_event->index;
for (size_t i = 1; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
no_swap_events_.find(event) == no_swap_events_.end()) {
if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size;
@ -208,17 +221,19 @@ void MemOffloadStrategy::GenComputeMemEvents() {
}
pre_index = event->index;
}
if (priority != kMemPriorityLow) {
continue;
}
auto &last_event = mem_events[mem_events.size() - 1];
MS_EXCEPTION_IF_NULL(last_event);
auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
free_event->key = item.first;
if (last_event->index < post_compute_events_.size()) {
(void)post_compute_events_[last_event->index].emplace_back(free_event);
if (!is_high_priority) {
GenFreeEvent(last_event);
}
}
}
void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_event) {
MS_EXCEPTION_IF_NULL(last_event);
auto free_event = std::make_shared<MemEvent>(kFree, last_event->index);
free_event->key = last_event->key;
if (last_event->index < post_compute_events_.size()) {
(void)post_compute_events_[last_event->index].emplace_back(free_event);
}
}
} // namespace device
} // namespace mindspore

View File

@ -58,12 +58,15 @@ class MemOffloadStrategy {
bool need_swap() const { return need_swap_; }
bool IsHighPriorityMem(const void *key);
private:
void CountMemUsage();
void CheckMemSize();
void GenEventSpan();
void GenNoSwapEventSet();
void GenSwapEventSet();
void GenComputeMemEvents();
void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
const std::map<const void *, MemPriority> &mem_priority_;
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
@ -74,8 +77,8 @@ class MemOffloadStrategy {
size_t mem_size_{0};
std::vector<double> compute_time_;
bool need_swap_{false};
std::multimap<size_t, std::shared_ptr<MemEvent>> event_span_;
std::set<std::shared_ptr<MemEvent>> no_swap_events_;
std::multimap<size_t, std::pair<std::shared_ptr<MemEvent>, size_t>> event_span_;
std::set<std::shared_ptr<MemEvent>> swap_events_;
std::vector<size_t> min_mem_used_;
size_t mem_used_without_swap_{0};
size_t min_mem_needed_{0};

View File

@ -26,7 +26,7 @@
namespace mindspore {
namespace device {
namespace {
constexpr float kMaxMemReuseFactor = 0.9;
constexpr float kMaxMemReuseFactor = 1.0;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
@ -51,12 +51,25 @@ void MemScheduler::Clear() {
high_priority_device_ptr_.clear();
}
bool MemScheduler::IsHighPriorityMem(const void *key) {
auto iter = mem_priority_.find(key);
if (iter != mem_priority_.end()) {
return iter->second == kMemPriorityHigh;
void MemScheduler::ClearTempMem() {
if (mem_handler_ == nullptr) {
return;
}
return false;
for (auto &item : mem_result_) {
const auto device_ptr = item.second;
if (device_ptr == nullptr) {
mem_handler_->FreeDevice(device_ptr);
}
}
mem_result_.clear();
high_priority_device_ptr_.clear();
for (const auto &item : swap_host_ptr_) {
const auto host_ptr = item.second;
if (host_ptr != nullptr) {
mem_handler_->FreeHost(host_ptr);
}
}
swap_host_ptr_.clear();
}
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
@ -88,9 +101,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
if (mem_priority_.find(key) == mem_priority_.end()) {
mem_priority_[key] = priority;
Record(key, kMalloc, mem_size);
} else {
Record(key, kGet, mem_size);
}
Record(key, kGet, mem_size);
return nullptr;
}
if (strategy_ == nullptr) {
@ -101,9 +113,8 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
auto ptr = iter->second;
MS_EXCEPTION_IF_NULL(ptr);
return ptr;
} else {
MS_LOG_EXCEPTION << "Mem extender get nullptr result!";
}
return nullptr;
}
bool MemScheduler::PreCompute(void *stream) {
@ -151,6 +162,9 @@ bool MemScheduler::PreCompute(void *stream) {
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
mem_result_[event->key] = device_ptr;
if (mem_priority_[event->key] == kMemPriorityHigh) {
high_priority_device_ptr_[event->key] = device_ptr;
}
if (!from_init) {
mem_handler_->FreeHost(host_ptr);
(void)swap_host_ptr_.erase(event->key);
@ -199,6 +213,9 @@ bool MemScheduler::PostCompute(void *stream) {
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
mem_handler_->FreeDevice(device_ptr);
(void)mem_result_.erase(event->key);
if (mem_priority_[event->key] == kMemPriorityHigh) {
high_priority_device_ptr_.erase(event->key);
}
}
}
++current_step_;
@ -221,6 +238,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
}
void MemScheduler::Optimize() {
AdjustFirstEventIndex();
float mem_used_factor = kMaxMemReuseFactor;
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
OptMemUsage(mem_used_factor);
@ -247,11 +265,30 @@ void MemScheduler::Optimize() {
if (ret) {
optimized_ = true;
} else {
ClearTempMem();
mem_used_factor -= kRetryFactor;
}
}
}
void MemScheduler::AdjustFirstEventIndex() {
for (const auto &item : mem_events_) {
const auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
auto &first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
const auto &priority_iter = mem_priority_.find(item.first);
const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
const auto &second_event = mem_events[1];
MS_EXCEPTION_IF_NULL(second_event);
first_event->index = second_event->index;
}
}
}
void MemScheduler::Update() {
if (!optimized_) {
return;

View File

@ -70,7 +70,7 @@ class MemScheduler {
void Clear();
bool IsHighPriorityMem(const void *key);
void ClearTempMem();
void SetMemPriority(const void *key, MemPriority priority);
@ -79,6 +79,8 @@ class MemScheduler {
void OptMemUsage(float mem_used_factor = 1.0f);
void AdjustFirstEventIndex();
std::map<const void *, MemPriority> mem_priority_;
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;