forked from mindspore-Ecosystem/mindspore
!27390 Optimize swap strategy
Merge pull request !27390 from tanghuikang/swap_strategy_adjust
commit 5655c9f972
@@ -1406,6 +1406,18 @@ bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t labe
  return false;
}

bool AnfRuntimeAlgorithm::IsUpdateParameterKernel(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto node_name = GetCNodeName(node);
  if (HasNodeAttr(kAttrAsync, node) && GetNodeAttr<bool>(node, kAttrAsync)) {
    return false;
  }
  if (kOptOperatorSet.find(node_name) == kOptOperatorSet.end() && node_name.find("Assign") == string::npos) {
    return false;
  }
  return true;
}

void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
@@ -246,6 +246,8 @@ class AnfRuntimeAlgorithm {
  static bool IsParameterWeight(const ParameterPtr &node);
  // Check whether the anf node includes the label_index.
  static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
  // Check whether the cnode updates a parameter.
  static bool IsUpdateParameterKernel(const CNodePtr &node);
  // Set the stream id of a kernel, which is assigned in stream assign and used in stream generate.
  static void SetStreamId(uint32_t stream_id, AnfNode *node);
  // get stream id
@@ -1346,13 +1346,7 @@ void KernelGraph::SetOptimizerFlag() {
  has_optimizer_ = false;
  for (const auto &cnode : execution_order_) {
    MS_EXCEPTION_IF_NULL(cnode);
    auto node_name = AnfAlgo::GetCNodeName(cnode);
    if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
      continue;
    }
    if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
      has_optimizer_ = true;
    } else if (node_name.find("Assign") == string::npos) {
    if (!AnfAlgo::IsUpdateParameterKernel(cnode)) {
      continue;
    }
    for (auto &input : cnode->inputs()) {
@@ -307,10 +307,7 @@ class KernelGraph : public FuncGraph {

  bool has_optimizer() const { return has_optimizer_; }
  bool IsUpdatedParameter(const ParameterPtr &param) const {
    if (updated_parameters_.find(param) != updated_parameters_.end()) {
      return true;
    }
    return false;
    return updated_parameters_.find(param) != updated_parameters_.end();
  }
  // handle graph dependency
  void AddPreGraph(const std::shared_ptr<session::KernelGraph> &graph) {
@@ -1324,6 +1324,7 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
  const auto update_parameter = AnfAlgo::IsUpdateParameterKernel(cnode);
  for (size_t j = 0; j < input_num; ++j) {
    auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
    auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
@@ -1335,6 +1336,14 @@ void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem
    GetOrMallocAddress(mem_scheduler, device_address, input);
    input->size = device_address->size_;
    kernel_launch_info->inputs_.emplace_back(input);
    if (update_parameter && input_node->isa<Parameter>()) {
      auto param = input_node->cast<ParameterPtr>();
      auto abstract = param->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      if (abstract->isa<abstract::AbstractRef>()) {
        mem_scheduler->UpdateHighPriorityMem(device_address);
      }
    }
  }

  for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
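
The Ref check above is the piece that records which weights are actually rewritten on device. As a rough standalone sketch (hypothetical stand-in types, not the MindSpore classes): a kernel flagged as parameter-updating reports every Ref-typed parameter input to the memory scheduler, so the scheduler knows at which step that weight's device copy changes.

// Toy model only: DeviceAddress, Parameter and Scheduler are hypothetical
// stand-ins for the types used in AssignKernelAddress.
#include <iostream>
#include <memory>
#include <set>
#include <vector>

struct DeviceAddress {};

struct Parameter {
  bool is_ref{false};  // stands in for abstract()->isa<abstract::AbstractRef>()
  std::shared_ptr<DeviceAddress> addr;
};

struct Scheduler {
  std::set<const DeviceAddress *> updated_high_priority;
  void UpdateHighPriorityMem(const DeviceAddress *key) { updated_high_priority.insert(key); }
};

// Mirrors the shape of the new code path: only Ref parameters of a
// parameter-updating kernel are reported to the scheduler.
void ReportUpdatedInputs(bool update_parameter, const std::vector<Parameter> &inputs, Scheduler *sched) {
  for (const auto &param : inputs) {
    if (update_parameter && param.is_ref) {
      sched->UpdateHighPriorityMem(param.addr.get());
    }
  }
}

int main() {
  Scheduler sched;
  std::vector<Parameter> inputs{{true, std::make_shared<DeviceAddress>()},
                                {false, std::make_shared<DeviceAddress>()}};
  ReportUpdatedInputs(/*update_parameter=*/true, inputs, &sched);
  std::cout << "updated inputs: " << sched.updated_high_priority.size() << "\n";  // prints 1
}
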
@@ -1410,6 +1419,7 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
  if (input_tensors.size() != input_nodes.size()) {
    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
  }
  mem_scheduler->ClearMemNeedInit();
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto input_node = input_nodes[i];
    if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
@@ -1418,16 +1428,30 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
    auto tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_address = tensor->device_address();
    if (!tensor->NeedSyncHostToDevice() && tensor_address != nullptr && tensor_address != device_address) {
    auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    const auto tensor_size = LongToSize(tensor->data().nbytes());
    if (tensor_address == device_address) {
      if (tensor->NeedSyncHostToDevice()) {
        tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), tensor->data().nbytes(),
                                         tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_);
        tensor->set_sync_status(kNoNeedSync);
      }
      if (mem_scheduler->HasDeviceMem(tensor_address.get())) {
        tensor_address->set_ptr(nullptr);
      }
      continue;
    }
    if (tensor->NeedSyncHostToDevice()) {
      mem_scheduler->AddMemNeedInit(device_address.get());
    } else if (tensor_address != nullptr) {
      tensor->data_sync(false);
      mem_scheduler->AddMemNeedInit(device_address.get());
    }
    MemPriority priority = kMemPriorityLow;
    if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) &&
        graph.IsUpdatedParameter(input_node->cast<ParameterPtr>())) {
    const auto &parameter = input_node->cast<ParameterPtr>();
    if (AnfAlgo::IsParameterWeight(parameter) || graph.IsUpdatedParameter(parameter)) {
      priority = kMemPriorityHigh;
    }
    auto tensor_size = LongToSize(tensor->data().nbytes());
    mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
    tensor->set_sync_status(kNoNeedSync);
  }
@@ -22,6 +22,9 @@

namespace mindspore {
namespace device {
constexpr size_t kFirstGetMemEventIndex = 1;
constexpr size_t kInitOrMallocMemEventIndex = 0;

std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
  if (pre_compute_events_.size() <= step) {
    MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
@@ -62,7 +65,7 @@ void MemOffloadStrategy::CountMemUsage() {
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[0];
    auto first_event = mem_events[kInitOrMallocMemEventIndex];
    const bool is_high_priority = IsHighPriorityMem(first_event->key);
    if (is_high_priority) {
      high_priority_mem_size += first_event->mem_size;
@@ -83,6 +86,10 @@ void MemOffloadStrategy::CountMemUsage() {
  }
  min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
  if (mem_size_ < min_mem_needed_) {
    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
                      << min_mem_needed_;
  }
}

bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
@@ -94,11 +101,6 @@ bool MemOffloadStrategy::IsHighPriorityMem(const void *key) {
}

void MemOffloadStrategy::CheckMemSize() {
  if (mem_size_ < min_mem_needed_) {
    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
                      << min_mem_needed_;
  }

  if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
    need_swap_ = true;
  }
@@ -116,19 +118,20 @@ void MemOffloadStrategy::GenEventSpan() {
    if (tensor_events.size() <= 1) {
      continue;
    }
    const bool is_high_priority = IsHighPriorityMem(tensor_events[0]->key);
    for (size_t event_index = 1; event_index < tensor_events.size(); ++event_index) {
      auto &event = tensor_events[event_index];
    const bool is_high_priority = IsHighPriorityMem(tensor_events[kInitOrMallocMemEventIndex]->key);
    for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) {
      auto &event = tensor_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (event->type != kGet) {
        MS_LOG(EXCEPTION) << "Event should be Get except the first event.";
      }
      size_t span = 0;
      if (event_index == 1 && is_high_priority) {
        const auto &last_event = tensor_events[tensor_events.size() - 1];
        span = event->index + total_step_ - last_event->index;
      } else {
        span = event->index - tensor_events[event_index - 1]->index;
      auto latest_event = tensor_events[i - 1];
      if (i == kFirstGetMemEventIndex && is_high_priority) {
        latest_event = tensor_events[tensor_events.size() - 1];
      }
      auto span = GetSpanBetweenMemEvents(latest_event->index, event->index);
      if (is_high_priority && span == 0 && latest_event == event) {
        span = total_step_;
      }
      if (span > 1) {
        const size_t span_mul_size = (span - 1) * event->mem_size;
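
The high-priority branch above measures the span from the last use of a tensor in one iteration to its first Get in the next, using the new GetSpanBetweenMemEvents helper (declared in the header change further below), which is a plain modular distance over total_step_. A minimal standalone sketch of that arithmetic, with made-up step numbers:

#include <cstddef>
#include <iostream>

// Same formula as the GetSpanBetweenMemEvents helper: distance from
// pre_step to post_step when steps wrap around after total_step.
size_t SpanBetween(size_t pre_step, size_t post_step, size_t total_step) {
  return (post_step + total_step - pre_step) % total_step;
}

int main() {
  constexpr size_t kTotalStep = 10;
  // Ordinary case: Get at step 7 following a use at step 3 -> span 4.
  std::cout << SpanBetween(3, 7, kTotalStep) << "\n";
  // Wrap-around case for high-priority memory: last use at step 8 of the
  // previous iteration, first Get at step 2 of the next one -> span 4.
  std::cout << SpanBetween(8, 2, kTotalStep) << "\n";
  // Degenerate case: only one Get event, so latest_event == event and the
  // span is 0; the strategy then substitutes total_step (10).
  std::cout << SpanBetween(2, 2, kTotalStep) << "\n";
}
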
@@ -156,7 +159,7 @@ void MemOffloadStrategy::GenSwapEventSet() {
  for (const auto &iter : event_span_) {
    auto span = iter.second.second;
    auto &event = iter.second.first;
    auto start_index = ((total_step_ + event->index - span) % total_step_) + 1;
    auto start_index = ((event->index + total_step_ - span + 1) % total_step_);
    bool revert = false;
    size_t cur_index = start_index;
    while (cur_index != event->index) {
@@ -196,12 +199,12 @@ void MemOffloadStrategy::GenComputeMemEvents() {
    }

    const bool is_high_priority = IsHighPriorityMem(item.first);
    auto first_event = mem_events[0];
    auto first_event = mem_events[kInitOrMallocMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_event);
    const auto &second_event = mem_events[1];
    MS_EXCEPTION_IF_NULL(second_event);
    if (is_high_priority && swap_events_.find(second_event) != swap_events_.end()) {
      first_event->index = second_event->index;
    const auto &first_get_event = mem_events[kFirstGetMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_get_event);
    if (is_high_priority && swap_events_.find(first_get_event) != swap_events_.end()) {
      first_event->index = first_get_event->index;
    }
    if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_step_) {
      pre_compute_events_[first_event->index].emplace_back(first_event);
@@ -211,16 +214,21 @@ void MemOffloadStrategy::GenComputeMemEvents() {

    const auto &last_event = mem_events[mem_events.size() - 1];
    size_t pre_index = is_high_priority ? last_event->index : first_event->index;
    for (size_t i = 1; i < mem_events.size(); ++i) {
    const auto &swap_out_event_index = GetSwapOutEventIndex(item.first, mem_events);
    for (size_t i = kFirstGetMemEventIndex; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
        auto swap_out_event = std::make_shared<MemEvent>(kSwapOut, pre_index);
        swap_out_event->key = item.first;
        swap_out_event->mem_size = first_event->mem_size;
        post_compute_events_[pre_index].emplace_back(swap_out_event);
        MemEventType event_type = kSwapOut;
        if (is_high_priority && swap_out_event_index.count(i) == 0) {
          event_type = kFree;
        }
        auto free_or_swap_out_event = std::make_shared<MemEvent>(event_type, pre_index);
        free_or_swap_out_event->key = item.first;
        free_or_swap_out_event->mem_size = first_event->mem_size;
        post_compute_events_[pre_index].emplace_back(free_or_swap_out_event);
        // avoid a swap-in event following the init event
        if (first_event->type != kInit || i != 1) {
        if (i != kFirstGetMemEventIndex || first_event->type != kInit) {
          auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
          swap_in_event->key = item.first;
          swap_in_event->mem_size = first_event->mem_size;
@@ -246,5 +254,39 @@ void MemOffloadStrategy::GenFreeEvent(const std::shared_ptr<MemEvent> &last_even
    (void)post_compute_events_[last_event->index].emplace_back(free_event);
  }
}

std::set<size_t> MemOffloadStrategy::GetSwapOutEventIndex(const void *key,
                                                          const std::vector<std::shared_ptr<MemEvent>> &mem_events) {
  const auto &update_step_iter = high_priority_updated_step_.find(key);
  if (update_step_iter == high_priority_updated_step_.end() || update_step_iter->second.empty()) {
    return std::set<size_t>();
  }
  const auto &update_steps = update_step_iter->second;
  size_t update_steps_index = 0;
  std::set<size_t> swap_out_event_index;
  size_t min_swap_index_before_update = SIZE_MAX;
  size_t max_swap_out_step = 0;
  for (size_t i = 0; i < mem_events.size(); ++i) {
    const auto &mem_event = mem_events[i];
    if (swap_events_.count(mem_event) == 0) {
      continue;
    }
    if (mem_event->index <= update_steps[update_steps_index]) {
      if (i <= min_swap_index_before_update) {
        min_swap_index_before_update = i;
      }
    } else {
      swap_out_event_index.insert(i);
      max_swap_out_step = mem_event->index;
      while (update_steps_index < update_steps.size() && update_steps[update_steps_index] < mem_event->index) {
        ++update_steps_index;
      }
    }
  }
  if (max_swap_out_step <= update_steps[update_steps.size() - 1]) {
    swap_out_event_index.insert(min_swap_index_before_update);
  }
  return swap_out_event_index;
}
}  // namespace device
}  // namespace mindspore
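
One reading of GetSwapOutEventIndex, restated as a standalone toy (plain structs instead of the MemEvent and strategy types, with an explicit bounds check added): for a weight rewritten on device at the given update steps, a swapped Get only needs a real SwapOut when an update has happened since the previous swap-out; the other swapped events of high-priority memory are downgraded to kFree by GenComputeMemEvents above, and one swap-out is kept if nothing is swapped out after the last update.

// Toy re-implementation of the selection idea, not the MindSpore types.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

struct ToyEvent {
  size_t index;  // compute step at which the event happens
  bool swapped;  // stands in for membership in swap_events_
};

// update_steps: steps at which the high-priority tensor is rewritten on device.
std::set<size_t> SwapOutEventIndex(const std::vector<ToyEvent> &events, const std::vector<size_t> &update_steps) {
  if (update_steps.empty()) {
    return {};
  }
  size_t update_steps_index = 0;
  std::set<size_t> swap_out_event_index;
  size_t min_swap_index_before_update = SIZE_MAX;
  size_t max_swap_out_step = 0;
  for (size_t i = 0; i < events.size(); ++i) {
    if (!events[i].swapped) {
      continue;
    }
    // Bounds check added in this toy: once all update steps are passed,
    // every later swapped event falls into the "needs SwapOut" branch.
    if (update_steps_index < update_steps.size() && events[i].index <= update_steps[update_steps_index]) {
      // No update since the previous swap-out: remember the earliest such event.
      if (i <= min_swap_index_before_update) {
        min_swap_index_before_update = i;
      }
    } else {
      // The device copy changed before this event, so it needs a real SwapOut.
      swap_out_event_index.insert(i);
      max_swap_out_step = events[i].index;
      while (update_steps_index < update_steps.size() && update_steps[update_steps_index] < events[i].index) {
        ++update_steps_index;
      }
    }
  }
  // If no swap happens after the last update, keep one swap-out before it so
  // the value still reaches host memory at least once.
  if (max_swap_out_step <= update_steps.back()) {
    swap_out_event_index.insert(min_swap_index_before_update);
  }
  return swap_out_event_index;
}

int main() {
  std::vector<ToyEvent> events{{0, false}, {2, true}, {5, true}, {9, true}};
  for (size_t i : SwapOutEventIndex(events, /*update_steps=*/{4})) {
    std::cout << "event " << i << " keeps SwapOut\n";  // events 2 and 3 (steps 5 and 9)
  }
}
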
@@ -41,10 +41,12 @@ class MemOffloadStrategy {
 public:
  MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
                     const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
                     const std::set<const void *> &manual_offload_keys, size_t total_step)
                     const std::set<const void *> &manual_offload_keys,
                     const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step)
      : mem_priority_(mem_priority),
        mem_events_(mem_events),
        manual_offload_keys_(manual_offload_keys),
        high_priority_updated_step_(high_priority_updated_step),
        total_step_(total_step) {}

  virtual ~MemOffloadStrategy() = default;
@@ -75,10 +77,16 @@ class MemOffloadStrategy {
  void GenComputeMemEvents();

  void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
  std::set<size_t> GetSwapOutEventIndex(const void *key, const std::vector<std::shared_ptr<MemEvent>> &mem_events);

  size_t GetSpanBetweenMemEvents(size_t pre_step, size_t post_step) const {
    return (post_step + total_step_ - pre_step) % total_step_;
  }

  const std::map<const void *, MemPriority> &mem_priority_;
  const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
  const std::set<const void *> &manual_offload_keys_;
  std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
  const size_t total_step_;
  std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
  std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
@@ -45,10 +45,10 @@ void MemScheduler::Clear() {
  if (mem_handler_ == nullptr) {
    return;
  }
  for (auto &item : high_priority_device_ptr_) {
  for (auto &item : mem_result_) {
    mem_handler_->FreeDevice(item.second);
  }
  high_priority_device_ptr_.clear();
  mem_result_.clear();
}

void MemScheduler::ClearAllocatedMem() {
@@ -57,12 +57,11 @@ void MemScheduler::ClearAllocatedMem() {
  }
  for (auto &item : mem_result_) {
    const auto device_ptr = item.second;
    if (device_ptr == nullptr) {
    if (device_ptr != nullptr) {
      mem_handler_->FreeDevice(device_ptr);
    }
  }
  mem_result_.clear();
  high_priority_device_ptr_.clear();
  for (const auto &item : swap_host_ptr_) {
    const auto host_ptr = item.second;
    if (host_ptr != nullptr) {
@@ -125,22 +124,19 @@ bool MemScheduler::PreCompute(void *stream) {
    MS_EXCEPTION_IF_NULL(event);
    MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
    if (event->type == kInit || event->type == kMalloc) {
      auto priority = mem_priority_[event->key];
      auto iter = high_priority_device_ptr_.find(event->key);
      if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
        MS_EXCEPTION_IF_NULL(iter->second);
        mem_result_[event->key] = iter->second;
        continue;
      }
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
      if (device_ptr == nullptr) {
        return false;
      }
      if (priority != kMemPriorityLow) {
        high_priority_device_ptr_[event->key] = device_ptr;
      const auto &iter = mem_result_.find(event->key);
      const bool new_malloc = iter == mem_result_.end();
      void *device_ptr;
      if (new_malloc) {
        device_ptr = mem_handler_->MallocDevice(event->mem_size);
        if (device_ptr == nullptr) {
          return false;
        }
      } else {
        device_ptr = iter->second;
      }

      if (event->type == kInit) {
      if (event->type == kInit && (new_malloc || high_priority_mem_need_init_.count(event->key) != 0)) {
        auto host_ptr = init_host_ptr_[event->key];
        MS_EXCEPTION_IF_NULL(host_ptr);
        mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
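
The reshaped kInit/kMalloc branch above appears to reduce to: reuse the device block already recorded in mem_result_, and only redo the host-to-device init when the block is newly allocated or the key was flagged via AddMemNeedInit. A stripped-down sketch with stand-in containers (hypothetical names, not the MemScheduler API):

#include <cstddef>
#include <iostream>
#include <map>
#include <set>

// Stand-ins for the scheduler state touched by the PreCompute change.
std::map<const void *, void *> mem_result;           // key -> device block
std::set<const void *> high_priority_mem_need_init;  // keys whose host data changed

void *MallocDevice(size_t size) { return new char[size]; }  // toy allocator, never freed

// Returns true when the caller should copy host data to the device block,
// mirroring: event->type == kInit && (new_malloc || need_init.count(key) != 0).
bool PrepareInit(const void *key, size_t mem_size, bool is_init_event) {
  const auto iter = mem_result.find(key);
  const bool new_malloc = iter == mem_result.end();
  void *device_ptr = new_malloc ? MallocDevice(mem_size) : iter->second;
  mem_result[key] = device_ptr;
  return is_init_event && (new_malloc || high_priority_mem_need_init.count(key) != 0);
}

int main() {
  int weight_key = 0;
  std::cout << PrepareInit(&weight_key, 16, true) << "\n";  // 1: first step allocates and inits
  std::cout << PrepareInit(&weight_key, 16, true) << "\n";  // 0: block reused, no re-init
  high_priority_mem_need_init.insert(&weight_key);
  std::cout << PrepareInit(&weight_key, 16, true) << "\n";  // 1: host data changed, re-init
}
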
@@ -160,9 +156,6 @@ bool MemScheduler::PreCompute(void *stream) {
      MS_EXCEPTION_IF_NULL(host_ptr);
      mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
      mem_result_[event->key] = device_ptr;
      if (mem_priority_[event->key] == kMemPriorityHigh) {
        high_priority_device_ptr_[event->key] = device_ptr;
      }
      if (!from_init) {
        mem_handler_->FreeHost(host_ptr);
        (void)swap_host_ptr_.erase(event->key);
@@ -211,9 +204,6 @@ bool MemScheduler::PostCompute(void *stream) {
      mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
      mem_handler_->FreeDevice(device_ptr);
      (void)mem_result_.erase(event->key);
      if (mem_priority_[event->key] == kMemPriorityHigh) {
        high_priority_device_ptr_.erase(event->key);
      }
    }
  }
  ++current_step_;
@@ -225,7 +215,8 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
  MS_EXCEPTION_IF_NULL(mem_handler_);

  if (strategy_ == nullptr) {
    strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_);
    strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
                                                     high_priority_updated_step_, total_step_);
    if (manual_offload_keys_.empty()) {
      compute_time_.resize(total_step_);
    } else {
@@ -53,6 +53,14 @@ class MemScheduler {

  void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);

  bool HasDeviceMem(const void *key) const { return mem_result_.find(key) != mem_result_.end(); }

  void UpdateHighPriorityMem(const void *key) {
    if (need_record_event_) {
      high_priority_updated_step_[key].emplace_back(current_step_);
    }
  }

  void SetTotalStep(size_t step) {
    total_step_ = step;
    step_events_.resize(total_step_);
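
UpdateHighPriorityMem only records while need_record_event_ is true, that is, during the pass in which the scheduler is still building its event list; the recorded steps are later handed to MemOffloadStrategy through the new constructor argument. A stripped-down sketch of that record-then-consume pattern (hypothetical names, not the real class):

#include <cstddef>
#include <iostream>
#include <map>
#include <vector>

// Stripped-down recorder: collects the steps at which each high-priority
// buffer is rewritten, but only while the first (recording) pass is running.
struct UpdateStepRecorder {
  std::map<const void *, std::vector<size_t>> high_priority_updated_step;
  size_t current_step{0};
  bool need_record_event{true};

  void UpdateHighPriorityMem(const void *key) {
    if (need_record_event) {
      high_priority_updated_step[key].emplace_back(current_step);
    }
  }
};

int main() {
  UpdateStepRecorder recorder;
  int weight = 0;
  // Recording pass: an optimizer kernel updates the weight at steps 3 and 7.
  for (recorder.current_step = 0; recorder.current_step < 10; ++recorder.current_step) {
    if (recorder.current_step == 3 || recorder.current_step == 7) {
      recorder.UpdateHighPriorityMem(&weight);
    }
  }
  recorder.need_record_event = false;  // later passes replay events, nothing is re-recorded
  std::cout << recorder.high_priority_updated_step[&weight].size() << "\n";  // prints 2
}
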
@@ -72,6 +80,10 @@ class MemScheduler {

  void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }

  void AddMemNeedInit(const void *key) { high_priority_mem_need_init_.insert(key); }

  void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }

 private:
  void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
@@ -86,7 +98,8 @@ class MemScheduler {
  std::map<const void *, void *> mem_result_;
  std::map<const void *, void *> init_host_ptr_;
  std::map<const void *, void *> swap_host_ptr_;
  std::map<const void *, void *> high_priority_device_ptr_;
  std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
  std::set<const void *> high_priority_mem_need_init_;
  size_t total_step_{0};
  size_t current_step_{0};
  bool need_record_event_{true};