!27390 Optimize swap strategy

Merge pull request !27390 from tanghuikang/swap_strategy_adjust
i-robot 2021-12-13 07:03:55 +00:00 committed by Gitee
commit 5655c9f972
9 changed files with 153 additions and 70 deletions

View File

@@ -1406,6 +1406,18 @@ bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index)
return false;
}
+ bool AnfRuntimeAlgorithm::IsUpdateParameterKernel(const CNodePtr &node) {
+ MS_EXCEPTION_IF_NULL(node);
+ auto node_name = GetCNodeName(node);
+ if (HasNodeAttr(kAttrAsync, node) && GetNodeAttr<bool>(node, kAttrAsync)) {
+ return false;
+ }
+ if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end() && node_name.find("Assign") == string::npos) {
+ return false;
+ }
+ return true;
+ }
void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
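
Note: the new predicate treats a kernel as parameter-updating only when it is synchronous and its name either belongs to kOptOperatorSet or contains "Assign". A minimal standalone sketch of that rule (the operator set below is a hypothetical sample, not the real kOptOperatorSet):

#include <set>
#include <string>

// Illustrative re-implementation of the rule on plain strings. The sample
// operator set is hypothetical; the real kOptOperatorSet is defined elsewhere
// in MindSpore and is much larger.
bool IsUpdateParameterName(const std::string &name, bool is_async) {
  static const std::set<std::string> opt_operator_set = {"ApplyMomentum", "Adam", "SGD"};
  if (is_async) {
    return false;  // async kernels are excluded, as in the diff above
  }
  return opt_operator_set.count(name) != 0 || name.find("Assign") != std::string::npos;
}

// IsUpdateParameterName("ApplyMomentum", false) -> true  (optimizer kernel)
// IsUpdateParameterName("AssignAdd", false)     -> true  ("Assign" in the name)
// IsUpdateParameterName("Conv2D", false)        -> false (does not write a parameter)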

View File

@@ -246,6 +246,8 @@ class AnfRuntimeAlgorithm {
static bool IsParameterWeight(const ParameterPtr &node);
// Check whether the anf node includes the label_index.
static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
+ // Check whether the cnode updates a parameter.
+ static bool IsUpdateParameterKernel(const CNodePtr &node);
// Set stream id of kernel, which will be set in stream assign and used in stream generate.
static void SetStreamId(uint32_t stream_id, AnfNode *node);
// get stream id

View File

@@ -1346,13 +1346,7 @@ void KernelGraph::SetOptimizerFlag() {
has_optimizer_ = false;
for (const auto &cnode : execution_order_) {
MS_EXCEPTION_IF_NULL(cnode);
- auto node_name = AnfAlgo::GetCNodeName(cnode);
- if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
- continue;
- }
- if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
- has_optimizer_ = true;
- } else if (node_name.find("Assign") == string::npos) {
+ if (!AnfAlgo::IsUpdateParameterKernel(cnode)) {
continue;
}
for (auto &input : cnode->inputs()) {

View File

@@ -307,10 +307,7 @@ class KernelGraph : public FuncGraph {
bool has_optimizer() const { return has_optimizer_; }
bool IsUpdatedParameter(const ParameterPtr &param) const {
- if (updated_parameters_.find(param) != updated_parameters_.end()) {
- return true;
- }
- return false;
+ return updated_parameters_.find(param) != updated_parameters_.end();
}
// handle graph dependency
void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {

View File

@@ -1324,6 +1324,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
+ const auto update_parameter = AnfAlgo::IsUpdateParameterKernel(cnode);
for (size_t j = 0; j < input_num; ++j) {
auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
@@ -1335,6 +1336,14 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
GetOrMallocAddress(mem_scheduler, device_address, input);
input->size = device_address->size_;
kernel_launch_info->inputs_.emplace_back(input);
+ if (update_parameter && input_node->isa<Parameter>()) {
+ auto param = input_node->cast<ParameterPtr>();
+ auto abstract = param->abstract();
+ MS_EXCEPTION_IF_NULL(abstract);
+ if (abstract->isa<abstract::AbstractRef>()) {
+ mem_scheduler->UpdateHighPriorityMem(device_address);
+ }
+ }
}
for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
@@ -1410,6 +1419,7 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
if (input_tensors.size() != input_nodes.size()) {
MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
}
+ mem_scheduler->ClearMemNeedInit();
for (size_t i = 0; i < input_tensors.size(); ++i) {
auto input_node = input_nodes[i];
if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
@@ -1418,16 +1428,30 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
auto tensor = input_tensors[i];
MS_EXCEPTION_IF_NULL(tensor);
- auto tensor_address = tensor->device_address();
- if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) {
+ auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+ const auto tensor_size = LongToSize(tensor->data().nbytes());
+ if (tensor_address == device_address) {
+ if (tensor->NeedSyncHostToDevice()) {
+ tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(),
+ tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_);
+ tensor->set_sync_status(kNoNeedSync);
+ }
+ if (mem_scheduler->HasDeviceMem(tensor_address.get())) {
+ tensor_address->set_ptr(nullptr);
+ }
+ continue;
+ }
+ if (tensor->NeedSyncHostToDevice()) {
+ mem_scheduler->AddMemNeedInit(device_address.get());
+ } else if (tensor_address != nullptr) {
tensor->data_sync(false);
+ mem_scheduler->AddMemNeedInit(device_address.get());
}
MemPriority priority = kMemPriorityLow;
- if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) &&
- graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) {
+ const auto &parameter = input_node->cast<ParameterPtr>();
+ if (AnfAlgo::IsParameterWeight(parameter) || graph.IsUpdatedParameter(parameter)) {
priority = kMemPriorityHigh;
}
- auto tensor_size = LongToSize(tensor->data().nbytes());
mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
tensor->set_sync_status(kNoNeedSync);
}
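
Note: InitGraphInputTensors now separates inputs already bound to the graph's device address (at most a host-to-device sync in place) from inputs whose data must be re-copied at the next PreCompute via AddMemNeedInit. A minimal sketch of that decision, assuming three booleans that summarize the checks above:

// Minimal sketch (assumed simplification of the checks above, not MindSpore API).
enum class InitAction { kSyncInPlace, kMarkNeedInit, kNothing };

InitAction ClassifyInput(bool bound_to_graph_address, bool dirty_on_host,
                         bool has_other_device_copy) {
  if (bound_to_graph_address) {
    // tensor_address == device_address: at most a host-to-device sync in place.
    return dirty_on_host ? InitAction::kSyncInPlace : InitAction::kNothing;
  }
  // Otherwise the data must reach the graph's address at the next PreCompute,
  // flagged via AddMemNeedInit when the host copy is dirty or the tensor lives
  // at some other device address.
  if (dirty_on_host || has_other_device_copy) {
    return InitAction::kMarkNeedInit;
  }
  return InitAction::kNothing;  // still registered via mem_scheduler->Init, no extra flag
}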

View File

@@ -22,6 +22,9 @@
namespace mindspore {
namespace device {
+ constexpr size_t kFirstGetMemEventIndex = 1;
+ constexpr size_t kInitOrMallocMemEventIndex = 0;
std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
if (pre_compute_events_.size() <= step) {
MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
@@ -62,7 +65,7 @@ void MemOffloadStrategy::CountMemUsage() {
if (mem_events.empty()) {
continue;
}
- auto first_event = mem_events[0];
+ auto first_event = mem_events[kInitOrMallocMemEventIndex];
const bool is_high_priority = IsHighPriorityMem(first_event->key);
if (is_high_priority) {
high_priority_mem_size += first_event->mem_size;
@@ -83,6 +86,10 @@ void MemOffloadStrategy::CountMemUsage() {
}
min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
+ if (mem_size_ < min_mem_needed_) {
+ MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
+ << min_mem_needed_;
+ }
}
bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
@@ -94,11 +101,6 @@ bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
}
void MemOffloadStrategy::CheckMemSize() {
- if (mem_size_ < min_mem_needed_) {
- MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
- << min_mem_needed_;
- }
if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
need_swap_ = true;
}
@@ -116,19 +118,20 @@ void MemOffloadStrategy::GenEventSpan() {
if (tensor_events.size() <= 1) {
continue;
}
- const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key);
- for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) {
- auto &event = tensor_events[event_index];
+ const bool is_high_priority = IsHighPriorityMem(tensor_events[kInitOrMallocMemEventIndex]->key);
+ for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) {
+ auto &event = tensor_events[i];
MS_EXCEPTION_IF_NULL(event);
if (event->type != kGet) {
MS_LOG(EXCEPTION) << "Event should be Get except fist event.";
}
- size_t span = 0;
- if (event_index == 1 && is_high_priority) {
- const auto &last_event = tensor_events[tensor_events.size() - 1];
- span = event->index + total_step_ - last_event->index;
- } else {
- span = event->index - tensor_events[event_index - 1]->index;
+ auto latest_event = tensor_events[i - 1];
+ if (i == kFirstGetMemEventIndex && is_high_priority) {
+ latest_event = tensor_events[tensor_events.size() - 1];
}
+ auto span = GetSpanBetweenMemEvents(latest_event->index, event->index);
+ if (is_high_priority && span == 0 && latest_event == event) {
+ span = total_step_;
+ }
if (span > 1) {
const size_t span_mul_size = (span - 1) * event->mem_size;
@@ -156,7 +159,7 @@ void MemOffloadStrategy::GenSwapEventSet() {
for (const auto &iter : event_span_) {
auto span = iter.second.second;
auto &event = iter.second.first;
- auto start_index = ((total_step_ + event->index - span) % total_step_) + 1;
+ auto start_index = ((event->index + total_step_ - span + 1) % total_step_);
bool revert = false;
size_t cur_index = start_index;
while (cur_index != event->index) {
@@ -196,12 +199,12 @@ void MemOffloadStrategy::GenComputeMemEvents() {
}
const bool is_high_priority = IsHighPriorityMem(item.first);
- auto first_event = mem_events[0];
+ auto first_event = mem_events[kInitOrMallocMemEventIndex];
MS_EXCEPTION_IF_NULL(first_event);
- const auto &second_event = mem_events[1];
- MS_EXCEPTION_IF_NULL(second_event);
- if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) {
- first_event->index = second_event->index;
+ const auto &first_get_event = mem_events[kFirstGetMemEventIndex];
+ MS_EXCEPTION_IF_NULL(first_get_event);
+ if (is_high_priority && swap_events_.find(first_get_event) != swap_events_.end()) {
+ first_event->index = first_get_event->index;
}
if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) {
pre_compute_events_[first_event->index].emplace_back(first_event);
@@ -211,16 +214,21 @@ void MemOffloadStrategy::GenComputeMemEvents() {
const auto &last_event = mem_events[mem_events.size() - 1];
size_t pre_index = is_high_priority ? last_event->index : first_event->index;
- for (size_t i = 1; i < mem_events.size(); ++i) {
+ const auto &swap_out_event_index = GetSwapOutEventIndex(item.first, mem_events);
+ for (size_t i = kFirstGetMemEventIndex; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
MS_EXCEPTION_IF_NULL(event);
if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
- auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
- swap_out_event->key = item.first;
- swap_out_event->mem_size = first_event->mem_size;
- post_compute_events_[pre_index].emplace_back(swap_out_event);
+ MemEventType event_type = kSwapOut;
+ if (is_high_priority && swap_out_event_index.count(i) == 0) {
+ event_type = kFree;
+ }
+ auto free_or_swap_out_event = std::make_shared<MemEvent>(event_type, pre_index);
+ free_or_swap_out_event->key = item.first;
+ free_or_swap_out_event->mem_size = first_event->mem_size;
+ post_compute_events_[pre_index].emplace_back(free_or_swap_out_event);
// avoid a swap-in event following an init event
- if (first_event->type != kInit || i != 1) {
+ if (i != kFirstGetMemEventIndex || first_event->type != kInit) {
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
@@ -246,5 +254,39 @@ void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_event) {
(void)post_compute_events_[last_event->index].emplace_back(free_event);
}
}
+ std::set<size_t> MemOffloadStrategy::GetSwapOutEventIndex(const void *key,
+ const std::vector<std::shared_ptr<MemEvent>> &mem_events) {
+ const auto &update_step_iter = high_priority_updated_step_.find(key);
+ if (update_step_iter == high_priority_updated_step_.end() || update_step_iter->second.empty()) {
+ return std::set<size_t>();
+ }
+ const auto &update_steps = update_step_iter->second;
+ size_t update_steps_index = 0;
+ std::set<size_t> swap_out_event_index;
+ size_t min_swap_index_before_update = SIZE_MAX;
+ size_t max_swap_out_step = 0;
+ for (size_t i = 0; i < mem_events.size(); ++i) {
+ const auto &mem_event = mem_events[i];
+ if (swap_events_.count(mem_event) == 0) {
+ continue;
+ }
+ if (mem_event->index <= update_steps[update_steps_index]) {
+ if (i <= min_swap_index_before_update) {
+ min_swap_index_before_update = i;
+ }
+ } else {
+ swap_out_event_index.insert(i);
+ max_swap_out_step = mem_event->index;
+ while (update_steps_index < update_steps.size() && update_steps[update_steps_index] < mem_event->index) {
+ ++update_steps_index;
+ }
+ }
+ }
+ if (max_swap_out_step <= update_steps[update_steps.size() - 1]) {
+ swap_out_event_index.insert(min_swap_index_before_update);
+ }
+ return swap_out_event_index;
+ }
} // namespace device
} // namespace mindspore
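
Note: for high-priority (weight) memory, GetSwapOutEventIndex picks which swap candidates must stay real kSwapOut events; candidates whose value has not changed since the previous offload degrade to kFree in GenComputeMemEvents, saving device-to-host copies. A standalone model with assumed step numbers (it treats every listed event as a member of swap_events_, assumes a non-empty update-step list as the caller checks, and guards the out-of-range and no-candidate corner cases):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Standalone model of the selection rule. Returns the indices of swap
// candidates that must stay kSwapOut; the rest become kFree.
std::set<size_t> KeepAsSwapOut(const std::vector<size_t> &event_steps,
                               const std::vector<size_t> &update_steps) {
  std::set<size_t> keep;
  size_t update_idx = 0;
  size_t min_before_update = SIZE_MAX;
  size_t max_swap_out_step = 0;
  for (size_t i = 0; i < event_steps.size(); ++i) {
    // Clamped so the model stays in range after the last update step is passed.
    const size_t cur_update = update_steps[std::min(update_idx, update_steps.size() - 1)];
    if (event_steps[i] <= cur_update) {
      if (i <= min_before_update) min_before_update = i;  // earliest pre-update candidate
    } else {
      keep.insert(i);  // value changed since the last offload: must swap out
      max_swap_out_step = event_steps[i];
      while (update_idx < update_steps.size() && update_steps[update_idx] < event_steps[i]) ++update_idx;
    }
  }
  // No swap-out after the final update: keep the earliest candidate so the
  // updated value still reaches host memory once.
  if (min_before_update != SIZE_MAX && max_swap_out_step <= update_steps.back()) {
    keep.insert(min_before_update);
  }
  return keep;
}

int main() {
  // A weight updated at step 40, with swap candidates at steps 10, 25 and 60:
  // only the step-60 candidate (index 2) stays a real swap-out.
  for (size_t i : KeepAsSwapOut({10, 25, 60}, {40})) std::cout << i << '\n';  // prints 2
}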

View File

@@ -41,10 +41,12 @@ class MemOffloadStrategy {
public:
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
- const std::set<const void *> &manual_offload_keys, size_t total_step)
+ const std::set<const void *> &manual_offload_keys,
+ const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step)
: mem_priority_(mem_priority),
mem_events_(mem_events),
manual_offload_keys_(manual_offload_keys),
+ high_priority_updated_step_(high_priority_updated_step),
total_step_(total_step) {}
virtual ~MemOffloadStrategy() = default;
@@ -75,10 +77,16 @@ class MemOffloadStrategy {
void GenComputeMemEvents();
void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
+ std::set<size_t> GetSwapOutEventIndex(const void *key, const std::vector<std::shared_ptr<MemEvent>> &mem_events);
+ size_t GetSpanBetweenMemEvents(size_t pre_step, size_t post_step) const {
+ return (post_step + total_step_ - pre_step) % total_step_;
+ }
const std::map<const void *, MemPriority> &mem_priority_;
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
const std::set<const void *> &manual_offload_keys_;
+ std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
const size_t total_step_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
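
Note: GetSpanBetweenMemEvents measures the distance from pre_step to post_step on a ring of total_step_ steps, so the gap between a tensor's last use in one iteration and its first use in the next wraps around. A short worked example with assumed values:

#include <cstddef>
#include <iostream>

// Same arithmetic as GetSpanBetweenMemEvents, with total_step fixed by the caller.
size_t GetSpan(size_t pre_step, size_t post_step, size_t total_step) {
  return (post_step + total_step - pre_step) % total_step;
}

int main() {
  std::cout << GetSpan(30, 70, 100) << '\n';  // 40: plain forward distance
  std::cout << GetSpan(90, 5, 100) << '\n';   // 15: wraps into the next iteration
  std::cout << GetSpan(42, 42, 100) << '\n';  // 0: GenEventSpan promotes this to
                                              // total_step_ when a high-priority
                                              // tensor's only get event wraps to itself
}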

View File

@@ -45,10 +45,10 @@ void MemScheduler::Clear() {
if (mem_handler_ == nullptr) {
return;
}
- for (auto &item : high_priority_device_ptr_) {
+ for (auto &item : mem_result_) {
mem_handler_->FreeDevice(item.second);
}
- high_priority_device_ptr_.clear();
+ mem_result_.clear();
}
void MemScheduler::ClearAllocatedMem() {
@@ -57,12 +57,11 @@ void MemScheduler::ClearAllocatedMem() {
}
for (auto &item : mem_result_) {
const auto device_ptr = item.second;
- if (device_ptr == nullptr) {
+ if (device_ptr != nullptr) {
mem_handler_->FreeDevice(device_ptr);
}
}
mem_result_.clear();
- high_priority_device_ptr_.clear();
for (const auto &item : swap_host_ptr_) {
const auto host_ptr = item.second;
if (host_ptr != nullptr) {
@@ -125,22 +124,19 @@ bool MemScheduler::PreCompute(void *stream) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
if (event->type == kInit || event->type == kMalloc) {
- auto priority = mem_priority_[event->key];
- auto iter = high_priority_device_ptr_.find(event->key);
- if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
- MS_EXCEPTION_IF_NULL(iter->second);
- mem_result_[event->key] = iter->second;
- continue;
- }
- auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
- if (device_ptr == nullptr) {
- return false;
- }
- if (priority != kMemPriorityLow) {
- high_priority_device_ptr_[event->key] = device_ptr;
+ const auto &iter = mem_result_.find(event->key);
+ const bool new_malloc = iter == mem_result_.end();
+ void *device_ptr;
+ if (new_malloc) {
+ device_ptr = mem_handler_->MallocDevice(event->mem_size);
+ if (device_ptr == nullptr) {
+ return false;
+ }
+ } else {
+ device_ptr = iter->second;
}
- if (event->type == kInit) {
+ if (event->type == kInit && (new_malloc || high_priority_mem_need_init_.count(event->key) != 0)) {
auto host_ptr = init_host_ptr_[event->key];
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
@@ -160,9 +156,6 @@ bool MemScheduler::PreCompute(void *stream) {
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
mem_result_[event->key] = device_ptr;
- if (mem_priority_[event->key] == kMemPriorityHigh) {
- high_priority_device_ptr_[event->key] = device_ptr;
- }
if (!from_init) {
mem_handler_->FreeHost(host_ptr);
(void)swap_host_ptr_.erase(event->key);
@@ -211,9 +204,6 @@ bool MemScheduler::PostCompute(void *stream) {
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
mem_handler_->FreeDevice(device_ptr);
(void)mem_result_.erase(event->key);
- if (mem_priority_[event->key] == kMemPriorityHigh) {
- high_priority_device_ptr_.erase(event->key);
- }
}
}
++current_step_;
@@ -225,7 +215,8 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
MS_EXCEPTION_IF_NULL(mem_handler_);
if (strategy_ == nullptr) {
- strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_);
+ strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
+ high_priority_updated_step_, total_step_);
if (manual_offload_keys_.empty()) {
compute_time_.resize(total_step_);
} else {
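
Note: PreCompute now reuses any device pointer still present in mem_result_ (the separate high_priority_device_ptr_ map is gone) and performs the host-to-device copy for a kInit event only when the buffer is freshly allocated or the key was flagged through AddMemNeedInit. A compact sketch of that choice, with assumed types:

#include <map>
#include <set>

struct Plan {
  bool malloc_new;  // allocate a fresh device buffer
  bool copy_host;   // run the host-to-device SwapIn for a kInit event
};

Plan PlanInitOrMalloc(const void *key, bool is_init_event,
                      const std::map<const void *, void *> &mem_result,
                      const std::set<const void *> &mem_need_init) {
  const bool new_malloc = mem_result.find(key) == mem_result.end();
  return {new_malloc,
          is_init_event && (new_malloc || mem_need_init.count(key) != 0)};
}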

View File

@@ -53,6 +53,14 @@ class MemScheduler {
void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
bool HasDeviceMem(const void *key) const { return mem_result_.find(key) != mem_result_.end(); }
+ void UpdateHighPriorityMem(const void *key) {
+ if (need_record_event_) {
+ high_priority_updated_step_[key].emplace_back(current_step_);
+ }
+ }
void SetTotalStep(size_t step) {
total_step_ = step;
step_events_.resize(total_step_);
@@ -72,6 +80,10 @@ class MemScheduler {
void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
+ void AddMemNeedInit(const void *key) { high_priority_mem_need_init_.insert(key); }
+ void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }
private:
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
@@ -86,7 +98,8 @@ class MemScheduler {
std::map<const void *, void *> mem_result_;
std::map<const void *, void *> init_host_ptr_;
std::map<const void *, void *> swap_host_ptr_;
- std::map<const void *, void *> high_priority_device_ptr_;
+ std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
+ std::set<const void *> high_priority_mem_need_init_;
size_t total_step_{0};
size_t current_step_{0};
bool need_record_event_{true};
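
Note: the two new members replace high_priority_device_ptr_: high_priority_updated_step_ records, during the recording phase, at which steps an update-parameter kernel wrote each high-priority key (consumed by GetSwapOutEventIndex), while high_priority_mem_need_init_ marks keys whose host data must be copied in again (consumed by PreCompute). A toy standalone model of the bookkeeping (the class is illustrative; the member names mirror this diff):

#include <cstddef>
#include <map>
#include <set>
#include <vector>

struct ToySchedulerBookkeeping {
  bool need_record_event_{true};
  size_t current_step_{0};
  std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
  std::set<const void *> high_priority_mem_need_init_;

  // Called when an update-parameter kernel writes a high-priority input
  // (see AssignKernelAddress above): remember which step mutated the key.
  void UpdateHighPriorityMem(const void *key) {
    if (need_record_event_) high_priority_updated_step_[key].emplace_back(current_step_);
  }
  // Flagged from InitGraphInputTensors when host data must reach the device
  // again even though the key may already own device memory.
  void AddMemNeedInit(const void *key) { high_priority_mem_need_init_.insert(key); }
  void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }
};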