MemOffload: reuse continuous memory

This commit is contained in:
tanghuikang 2022-06-09 19:56:47 +08:00
parent 508f3d7b5f
commit 7522bb2c9b
10 changed files with 696 additions and 169 deletions

View File

@@ -48,11 +48,6 @@ uint64_t AscendMemoryManager::GetMsMaxMemSize() { return AscendMemAdapter::GetIn
 uint64_t AscendMemoryManager::GetMsUsedHbmSize() { return AscendMemAdapter::GetInstance().GetMsUsedHbmSize(); }
 
-void *AscendMemoryManager::MallocDevice(size_t size) {
-  auto align_size = GetCommonAlignSize(size);
-  return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
-}
-
 void *AscendMemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
   auto align_size = GetCommonAlignSize(size);
   const auto device_addr = AscendMemoryPool::GetInstance().AllocTensorMem(align_size, from_persistent_mem);

View File

@@ -33,7 +33,6 @@ class AscendMemoryManager : public MemoryManager {
   void ResetDynamicMemory() override;
   void ClearGlobalIdleMem() override;
   void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override;
-  void *MallocDevice(size_t size) override;
   void FreeMemFromMemPool(void *device_ptr) override;
   uint64_t GetMsMaxMemSize();
   void MallocSomasDynamicMem(const session::KernelGraph &graph) override;

View File

@@ -15,6 +15,7 @@
  */
 #include "runtime/device/kernel_runtime.h"
+#include <algorithm>
 #include <functional>
 #include <utility>
 #include <vector>
@@ -110,7 +111,7 @@ void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
   if (UseMemScheduler()) {
     AssignStaticMemoryValueNode(graph);
     ResetNodeAddress(graph);
-    AssignCommunicationMem(graph);
+    AddCommunicationMemInfo(graph);
   } else {
     MS_EXCEPTION_IF_NULL(mem_manager_);
     mem_manager_->ResetDynamicMemory();
@@ -1088,6 +1089,53 @@ DeviceAddressPtr KernelRuntime::CreateDeviceAddressForStringValue(const ValuePtr
   return address;
 }
 
+bool KernelRuntime::MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
+                                           void *stream, bool mock, KernelLaunchInfo *kernel_launch_info) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(mem_scheduler);
+  MS_EXCEPTION_IF_NULL(stream);
+  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
+    MS_LOG(ERROR) << "SyncStream failed";
+    return false;
+  }
+  bool ret = mem_scheduler->PreCompute(stream);
+  if (!ret) {
+    return ret;
+  }
+  AssignKernelAddress(mem_scheduler, kernel, kernel_launch_info);
+  auto cnode = kernel->cast<CNodePtr>();
+  if (mock && common::AnfAlgo::HasNodeAttr(kAttrOffload, cnode) &&
+      common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
+    for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
+      auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
+      mem_scheduler->SetOffload(device_address);
+    }
+  }
+  return true;
+}
+
+bool KernelRuntime::MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
+                                            const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream,
+                                            bool mock) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(mem_scheduler);
+  MS_EXCEPTION_IF_NULL(stream);
+  if (!mock) {
+    SyncNodeOutputTensors(mem_scheduler, graph, kernel);
+  }
+  bool ret = mem_scheduler->PostCompute(stream);
+  if (!ret) {
+    return ret;
+  }
+  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
+    MS_LOG(ERROR) << "SyncStream failed";
+    return false;
+  }
+  return true;
+}
+
 void KernelRuntime::AssignDynamicMemory(const session::KernelGraph &graph) {
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto context_ptr = MsContext::GetInstance();
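Note: the two helpers above pull the mem-scheduler handling out of LaunchKernel, so the pre-compute phase (stream sync for communication ops, PreCompute, kernel address assignment, offload marking during mock runs) and the post-compute phase (output tensor sync, PostCompute) each sit behind one call. A minimal standalone sketch of that call order — SchedulerStub and LaunchOneKernel are illustrative stand-ins, not the MindSpore API:

#include <iostream>
#include <memory>

// Stand-in for MemScheduler: PreCompute swaps in / allocates what the next
// kernel needs; PostCompute swaps out / frees what the strategy scheduled.
struct SchedulerStub {
  bool PreCompute() { std::cout << "pre: malloc/swap-in step inputs\n"; return true; }
  bool PostCompute() { std::cout << "post: swap-out/free step outputs\n"; return true; }
};

bool LaunchOneKernel(const std::shared_ptr<SchedulerStub> &sched) {
  if (sched != nullptr && !sched->PreCompute()) {   // plays MemSchedulerPreCompute
    return false;
  }
  std::cout << "launch kernel\n";                   // kernel_mod->Launch(...) in the real code
  return sched == nullptr || sched->PostCompute();  // plays MemSchedulerPostCompute
}

int main() { return LaunchOneKernel(std::make_shared<SchedulerStub>()) ? 0 : 1; }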
@@ -1524,13 +1572,37 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
   }
 }
 
-void KernelRuntime::AssignCommunicationMem(const session::KernelGraph &graph) {
-  for (const auto &kernel : graph.execution_order()) {
+void KernelRuntime::AddCommunicationMemInfo(const session::KernelGraph &graph) {
+  const auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
+  for (size_t compute_index = 0; compute_index < graph.execution_order().size(); ++compute_index) {
+    const auto &kernel = graph.execution_order()[compute_index];
+    MS_EXCEPTION_IF_NULL(kernel);
     if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
       continue;
     }
-    AssignCommunicationInputFromMemoryPool(kernel);
-    AssignCommunicationOutputFromMemoryPool(kernel);
+    auto device_address_to_key = [](const DeviceAddressPtr &device_address) -> void * { return device_address.get(); };
+    size_t input_total_size = 0;
+    DeviceAddressPtrList input_address_list;
+    std::vector<size_t> input_align_size_list;
+    GetCommunicationInputInfo(kernel, &input_total_size, &input_address_list, &input_align_size_list);
+    if (input_address_list.size() > 1) {
+      std::vector<const void *> input_address_key_list;
+      std::transform(input_address_list.begin(), input_address_list.end(),
+                     std::back_inserter(input_address_key_list), device_address_to_key);
+      mem_scheduler->AddContinuousMemInfo(true, compute_index, input_total_size, input_align_size_list,
+                                          input_address_key_list);
+    }
+    size_t output_total_size = 0;
+    DeviceAddressPtrList output_address_list;
+    std::vector<size_t> output_align_size_list;
+    GetCommunicationOutputInfo(kernel, &output_total_size, &output_address_list, &output_align_size_list);
+    if (output_address_list.size() > 1) {
+      std::vector<const void *> output_address_key_list;
+      std::transform(output_address_list.begin(), output_address_list.end(),
+                     std::back_inserter(output_address_key_list), device_address_to_key);
+      mem_scheduler->AddContinuousMemInfo(false, compute_index, output_total_size, output_align_size_list,
+                                          output_address_key_list);
+    }
   }
 }
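Note: AssignCommunicationMem used to allocate the fused input/output buffers of a communication op straight from the pool; AddCommunicationMemInfo instead only records the requirement (direction, step index, aligned member sizes, address keys) with the scheduler, which decides later at which step the block is actually allocated. A small sketch of the bookkeeping, assuming the 512-byte alignment declared in memory_manager.h; the helper name AlignUp is illustrative:

#include <cstdio>
#include <vector>

constexpr size_t kMemAlignSize = 512;  // same constant as memory_manager.h
size_t AlignUp(size_t size) { return ((size + kMemAlignSize - 1) / kMemAlignSize) * kMemAlignSize; }

int main() {
  // Raw sizes of three AllReduce inputs that must end up contiguous.
  std::vector<size_t> raw = {100, 2000, 4096};
  std::vector<size_t> align_size_list;
  size_t total_size = 0;
  for (size_t s : raw) {
    align_size_list.push_back(AlignUp(s));
    total_size += align_size_list.back();
  }
  // With more than one member, the scheduler gets one record via
  //   AddContinuousMemInfo(/*is_input=*/true, compute_index, total_size,
  //                        align_size_list, address_key_list);
  // a single member needs no continuity guarantee and is skipped.
  printf("members=%zu total=%zu bytes\n", align_size_list.size(), total_size);
}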
@@ -1550,19 +1622,10 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
   }
   bool ret = true;
   if (mem_scheduler != nullptr) {
-    ret = mem_scheduler->PreCompute(stream);
+    ret = MemSchedulerPreCompute(kernel, mem_scheduler, stream, mock, &kernel_launch_info);
     if (!ret) {
       return ret;
     }
-    AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
-    auto cnode = kernel->cast<CNodePtr>();
-    if (mock && common::AnfAlgo::HasNodeAttr(kAttrOffload, cnode) &&
-        common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
-      for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
-        auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
-        mem_scheduler->SetOffload(device_address);
-      }
-    }
   } else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
     kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
     kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
@@ -1581,10 +1644,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
     }
   }
   if (mem_scheduler != nullptr) {
-    if (!mock) {
-      SyncNodeOutputTensors(mem_scheduler, graph, kernel);
-    }
-    ret = mem_scheduler->PostCompute(stream);
+    ret = MemSchedulerPostCompute(graph, kernel, mem_scheduler, stream, mock);
   }
   return ret;
 }
@@ -1739,7 +1799,7 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
   if (mem_scheduler->optimized()) {
     return;
   }
-  mem_scheduler->SetMemHandler(mem_manager_);
+  mem_scheduler->SetMemHandler(std::make_shared<MemHandler>(mem_manager_));
  mem_scheduler->SetTotalStep(graph.execution_order().size());
  if (mem_scheduler->need_record_event()) {

View File

@@ -178,7 +178,7 @@ class KernelRuntime {
   void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
                             const session::KernelGraph &graph);
-  void AssignCommunicationMem(const session::KernelGraph &graph);
+  void AddCommunicationMemInfo(const session::KernelGraph &graph);
   bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
   void LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &run_events,
                          const AnfNodePtr &node) const;
@@ -204,6 +204,10 @@ class KernelRuntime {
   void GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
                                   std::vector<size_t> *align_size_list) const;
   DeviceAddressPtr CreateDeviceAddressForStringValue(const ValuePtr &value, bool use_mem_pool, uint32_t graph_id);
+  bool MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
+                              void *stream, bool mock, KernelLaunchInfo *kernel_launch_info);
+  bool MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
+                               const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream, bool mock);
+
  protected:
   uint32_t device_id_{0};

View File

@@ -23,7 +23,6 @@
 #include <queue>
 #include "common/mem_reuse/mem_reuse.h"
 #include "backend/common/somas/somas.h"
-#include "runtime/device/memory_scheduler.h"
 namespace mindspore {
 namespace device {
 enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
@@ -32,7 +31,7 @@ constexpr uint64_t kMemAlignSize = 512;
 constexpr uint64_t kTwiceMemAlignSize = kMemAlignSize << 1;
 using SomasPtr = mindspore::somas::SomasPtr;
 
-class MemoryManager : public MemHandler {
+class MemoryManager {
  public:
   MemoryManager() = default;
   virtual ~MemoryManager() = default;
@@ -50,7 +49,6 @@ class MemoryManager : public MemHandler {
   virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
     return MallocMem(type, size, address, kInvalidGraphId);
   }
-
   // param address is the address type of each device
   // param from_persistent_mem shows whether the tensor is a parameter in Pynative mode
   virtual bool MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size);
@@ -65,45 +63,13 @@ class MemoryManager : public MemHandler {
   static size_t GetCommonAlignSize(size_t input_size);
   static size_t GetCommunicationAlignSize(size_t input_size);
 
-  // swap manager interface
-  void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size, false); }
-  void FreeDevice(void *ptr) override {
-    MS_EXCEPTION_IF_NULL(ptr);
-    FreeMemFromMemPool(ptr);
-  }
-  void *MallocHost(size_t mem_size) override {
-    auto &mem_que = cached_host_mem_[mem_size];
-    if (!mem_que.empty()) {
-      auto ret = mem_que.front();
-      mem_que.pop();
-      return ret;
-    }
-    auto block = std::make_shared<std::vector<uint8_t>>();
-    try {
-      block->resize(mem_size, 0);
-      auto ptr = block->data();
-      host_mem_block_map_[ptr] = block;
-      return ptr;
-    } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
-    }
-  }
-  void FreeHost(void *ptr) override {
-    MS_EXCEPTION_IF_NULL(ptr);
-    auto iter = host_mem_block_map_.find(ptr);
-    if (iter == host_mem_block_map_.end()) {
-      MS_LOG(ERROR) << "Free ptr not be created from manager!";
-    }
-    auto mem_size = iter->second->size();
-    cached_host_mem_[mem_size].emplace(iter->first);
-  }
-  void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {
+  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
     MS_LOG(INFO) << "Call default swap in " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
   }
-  void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {
+  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
     MS_LOG(INFO) << "Call default swap out " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
   }
-  size_t GetAvailableMemSize() override {
+  virtual size_t GetAvailableMemSize() {
     MS_LOG(ERROR) << "Return default 0 mem size!";
     return 0;
   }
@@ -115,8 +81,6 @@ class MemoryManager : public MemHandler {
   }
   virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
   SomasPtr somas_reuse_util_ptr_{nullptr};
-  std::map<size_t, std::queue<void *>> cached_host_mem_;
-  std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_;
 };
 }  // namespace device
 }  // namespace mindspore
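Note: with this change MemoryManager no longer derives from MemHandler; the handler becomes a small concrete adapter that owns a MemoryManager (see memory_scheduler.h below), and the host-side cache moves with it. A compilable sketch of the new ownership direction, with illustrative stand-in names rather than the real classes:

#include <cstdio>
#include <memory>

class PoolManager {  // role of MemoryManager: owns the device pool
 public:
  void *MallocFromPool(size_t size) { printf("pool malloc %zu\n", size); return &slot_; }
  void FreeToPool(void *ptr) { printf("pool free %p\n", ptr); }
 private:
  int slot_ = 0;
};

class PoolHandler {  // role of MemHandler: thin adapter the scheduler talks to
 public:
  explicit PoolHandler(std::shared_ptr<PoolManager> m) : manager_(std::move(m)) {}
  void *MallocDevice(size_t size) { return manager_->MallocFromPool(size); }
  void FreeDevice(void *ptr) { manager_->FreeToPool(ptr); }
 private:
  std::shared_ptr<PoolManager> manager_;
};

int main() {
  PoolHandler handler(std::make_shared<PoolManager>());
  handler.FreeDevice(handler.MallocDevice(512));
}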

View File

@@ -18,6 +18,7 @@
 #include <map>
 #include <memory>
 #include <utility>
+#include <algorithm>
 #include "utils/log_adapter.h"
 
 namespace mindspore {
@@ -25,14 +26,14 @@ namespace device {
 constexpr size_t kFirstGetMemEventIndex = 1;
 constexpr size_t kInitOrMallocMemEventIndex = 0;
 
-std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
+MemEventPtrList &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
   if (pre_compute_events_.size() <= step) {
     MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
   }
   return pre_compute_events_[step];
 }
 
-std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
+MemEventPtrList &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
   if (post_compute_events_.size() <= step) {
     MS_LOG_EXCEPTION << "Index out of post event range, index:" << step
                      << ", event size:" << post_compute_events_.size();
@@ -46,6 +47,8 @@ void MemOffloadStrategy::Execute() {
   if (need_swap_) {
     GenEventSpan();
     GenSwapEventSet();
+  } else {
+    GenContinuousMemAllocSteps();
   }
   GenComputeMemEvents();
 }
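Note: Execute() now has to place continuous blocks on both paths: when swapping is needed, GenSwapEventSet() derives each block's allocation step while building the swap set; when everything fits in device memory, GenContinuousMemAllocSteps() still must pick the first-malloc step for every continuous block. A reduced control-flow sketch, with the strategy calls shown as comments:

// Reduced control-flow sketch of MemOffloadStrategy::Execute() after this change.
void ExecuteFlow(bool need_swap) {
  // CountMemUsage(); CheckMemSize();      -- unchanged
  if (need_swap) {
    // GenEventSpan(); GenSwapEventSet();  -- swap set + continuous malloc steps
  } else {
    // GenContinuousMemAllocSteps();       -- new: malloc steps only
  }
  // GenComputeMemEvents();
}

int main() { ExecuteFlow(true); ExecuteFlow(false); }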
@@ -60,14 +63,17 @@ void MemOffloadStrategy::CountMemUsage() {
   min_mem_used_.resize(total_step_, 0);
   std::vector<size_t> total_mem_used(total_step_, 0);
   size_t high_priority_mem_size = 0;
+  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
   for (auto &item : mem_events_) {
     auto &mem_events = item.second;
     if (mem_events.empty()) {
       continue;
     }
     auto first_event = mem_events[kInitOrMallocMemEventIndex];
-    const bool is_high_priority = IsHighPriorityMem(first_event->key);
-    if (is_high_priority) {
+    const bool is_high_priority = IsHighPriorityMem(item.first);
+    if (continuous_mem_info_helper_->IsContinuousInputMem(item.first)) {
+      continue;
+    } else if (is_high_priority) {
       high_priority_mem_size += first_event->mem_size;
     } else {
       auto last_event = mem_events[mem_events.size() - 1];
@@ -75,6 +81,7 @@ void MemOffloadStrategy::CountMemUsage() {
       total_mem_used[start_index] += first_event->mem_size;
     }
+
     // Calculate the minimum memory size for kernel execution.
     for (const auto &event : mem_events) {
       MS_EXCEPTION_IF_NULL(event);
@@ -84,6 +91,7 @@ void MemOffloadStrategy::CountMemUsage() {
       min_mem_used_[event->index] += first_event->mem_size;
     }
   }
+  CountContinuousMemUsage(&total_mem_used);
   min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
   mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
   if (mem_size_ < min_mem_needed_) {
@@ -118,24 +126,25 @@ void MemOffloadStrategy::GenEventSpan() {
     if (tensor_events.size() <= 1) {
       continue;
     }
-    const bool is_high_priority = IsHighPriorityMem(tensor_events[kInitOrMallocMemEventIndex]->key);
+    const bool is_high_priority = IsHighPriorityMem(item.first);
     for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) {
       auto &event = tensor_events[i];
       MS_EXCEPTION_IF_NULL(event);
       if (event->type != kGet) {
         MS_LOG(EXCEPTION) << "Event should be Get except first event.";
       }
-      auto latest_event = tensor_events[i - 1];
+      auto latest_get_event = tensor_events[i - 1];
       if (i == kFirstGetMemEventIndex && is_high_priority) {
-        latest_event = tensor_events[tensor_events.size() - 1];
+        latest_get_event = tensor_events[tensor_events.size() - 1];
      }
-      auto span = GetSpanBetweenMemEvents(latest_event->index, event->index);
-      if (is_high_priority && span == 0 && latest_event == event) {
+      auto span = GetSpanBetweenMemEvents(latest_get_event->index, event->index);
+      // High priority memory that is only used once in a total step
+      if (is_high_priority && span == 0 && latest_get_event == event) {
        span = total_step_;
      }
      if (span > 1) {
        const size_t span_mul_size = (span - 1) * event->mem_size;
-        (void)event_span_.emplace(std::make_pair(span_mul_size, std::make_pair(event, span)));
+        (void)event_span_.emplace(span_mul_size, std::make_pair(event, span));
      }
    }
  }
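Note: event_span_ is a multimap keyed by (span - 1) * mem_size, the cost of keeping a tensor resident between two uses, so the greedy filter in GenSwapEventSet() visits events in ascending cost: cheap residencies are admitted first and expensive long-lived tensors become the first swap candidates. A minimal demonstration of that ordering (names are illustrative):

#include <cstdio>
#include <map>
#include <utility>

int main() {
  // key: residency cost; value: (tensor name, span in steps)
  std::multimap<size_t, std::pair<const char *, size_t>> event_span;
  auto add = [&](const char *name, size_t span, size_t mem_size) {
    event_span.emplace((span - 1) * mem_size, std::make_pair(name, span));
  };
  add("small-short", 2, 512);   // cost 512
  add("big-long", 10, 4096);    // cost 36864
  add("small-long", 10, 512);   // cost 4608
  for (const auto &it : event_span) {  // ascending: small-short, small-long, big-long
    printf("cost=%zu tensor=%s span=%zu\n", it.first, it.second.first, it.second.second);
  }
}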
@@ -153,36 +162,159 @@ void MemOffloadStrategy::GenSwapEventSet() {
     }
     return;
   }
 
   // greedy span filter
+  continuous_mem_info_helper_->ClearContinuousMallocIndex();
   std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
+  auto compare_total_size = [](ContinuousMemInfoPtr l, ContinuousMemInfoPtr r) -> bool {
+    return l->total_size_ < r->total_size_;
+  };
+  auto all_continuous_mem_info = continuous_mem_info_helper_->GetAllContinuousMemInfo();
+  std::sort(all_continuous_mem_info.begin(), all_continuous_mem_info.end(), compare_total_size);
+  std::set<std::shared_ptr<MemEvent>> events_no_need_swap;
+  for (const auto &continuous_mem_info : all_continuous_mem_info) {
+    GenContinuousMemSwapEvent(continuous_mem_info, &cur_mem_used, &events_no_need_swap);
+  }
   for (const auto &iter : event_span_) {
+    const auto &event = iter.second.first;
+    if (events_no_need_swap.count(event) > 0) {
+      continue;
+    }
     auto span = iter.second.second;
-    auto &event = iter.second.first;
-    auto start_index = ((event->index + total_step_ - span + 1) % total_step_);
-    bool revert = false;
-    size_t cur_index = start_index;
-    while (cur_index != event->index) {
-      cur_mem_used[cur_index] += event->mem_size;
-      if (cur_mem_used[cur_index] > mem_size_) {
-        revert = true;
-      }
-      cur_index += 1;
-      if (cur_index >= total_step_) {
-        cur_index = 0;
-      }
-    }
-    if (revert) {
-      cur_index = start_index;
-      while (cur_index != event->index) {
-        cur_mem_used[cur_index] -= event->mem_size;
-        cur_index += 1;
-        if (cur_index >= total_step_) {
-          cur_index = 0;
-        }
-      }
-      (void)swap_events_.emplace(event);
-    }
+    AddToSwapEventSetIfOutOfMem(event, span, &cur_mem_used);
+  }
+}
+
+void MemOffloadStrategy::AddToSwapEventSetIfOutOfMem(const std::shared_ptr<MemEvent> &event, size_t span,
+                                                     std::vector<size_t> *mem_used) {
+  const auto start_index = (GetPreMemEventIndex(event->index, span) + 1) % total_step_;
+  bool revert = false;
+  size_t cur_index = start_index;
+  while (cur_index != event->index) {
+    (*mem_used)[cur_index] += event->mem_size;
+    if (mem_used->at(cur_index) > mem_size_) {
+      revert = true;
+    }
+    cur_index += 1;
+    if (cur_index >= total_step_) {
+      cur_index = 0;
+    }
+  }
+  if (revert) {
+    cur_index = start_index;
+    while (cur_index != event->index) {
+      (*mem_used)[cur_index] -= event->mem_size;
+      cur_index += 1;
+      if (cur_index >= total_step_) {
+        cur_index = 0;
+      }
+    }
+    (void)swap_events_.emplace(event);
+  }
+}
+
+void MemOffloadStrategy::GenContinuousMemSwapEvent(const ContinuousMemInfoPtr &continuous_mem_info,
+                                                   std::vector<size_t> *mem_used,
+                                                   std::set<std::shared_ptr<MemEvent>> *events_no_need_swap) {
+  MS_EXCEPTION_IF_NULL(continuous_mem_info);
+  if (continuous_mem_info->key_index_map_.empty()) {
+    return;
+  }
+  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
+  if (!continuous_mem_info->is_input_) {
+    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
+    return;
+  }
+  const auto max_span_mem_in_device = GetMaxSpanForContinuousMem(continuous_mem_info, *mem_used);
+  size_t first_malloc_span = 0;
+  size_t first_malloc_size_dup = 0;
+  for (const auto &key_index : continuous_mem_info->key_index_map_) {
+    const auto &events_iter = mem_events_.find(key_index.first);
+    if (events_iter == mem_events_.end() || events_iter->second.empty()) {
+      MS_LOG(EXCEPTION) << "Can not find events for continuous input memory, device address key: " << key_index.first;
+    }
+    size_t swap_in_event_index = kFirstGetMemEventIndex;
+    size_t swap_in_span = 0;
+    const bool is_high_priority = IsHighPriorityMem(key_index.first);
+    for (size_t i = kFirstGetMemEventIndex; i < events_iter->second.size(); ++i) {
+      const auto &mem_event = events_iter->second[i];
+      if (!is_high_priority && mem_event->index > continuous_mem_used_index) {
+        continue;
+      }
+      const size_t span = GetSpanBetweenMemEvents(mem_event->index, continuous_mem_used_index);
+      // Find the max span that is less than or equal to max_span_mem_in_device.
+      if (span <= max_span_mem_in_device) {
+        if (span >= swap_in_span) {
+          swap_in_span = span;
+          swap_in_event_index = i;
+        }
+        events_no_need_swap->insert(mem_event);
+      }
+    }
+    if (swap_in_event_index != kFirstGetMemEventIndex || is_high_priority) {
+      swap_events_.insert(events_iter->second[swap_in_event_index]);
+    }
+    // Find the earliest index at which the continuous memory should be allocated.
+    if (swap_in_span > first_malloc_span) {
+      first_malloc_span = swap_in_span;
+      first_malloc_size_dup = events_iter->second[swap_in_event_index]->mem_size;
+    } else if (swap_in_span == first_malloc_span) {
+      // Accumulate the memory size already added to mem_used.
+      first_malloc_size_dup += events_iter->second[swap_in_event_index]->mem_size;
+    }
+  }
+  for (size_t span = 1; span <= first_malloc_span; ++span) {
+    size_t index = GetPreMemEventIndex(continuous_mem_used_index, span);
+    (*mem_used)[index] += continuous_mem_info->total_size_;
+  }
+  size_t index = GetPreMemEventIndex(continuous_mem_used_index, first_malloc_span);
+  (*mem_used)[index] -= first_malloc_size_dup;
+  continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, index);
+}
+
+size_t MemOffloadStrategy::GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr &continuous_mem_info,
+                                                      const std::vector<size_t> &mem_used) {
+  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
+  size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
+  size_t max_span_mem_in_device = GetSpanBetweenMemEvents(earliest_malloc_index, continuous_mem_used_index);
+  for (size_t span = 1; span <= max_span_mem_in_device; ++span) {
+    size_t cur_index = GetPreMemEventIndex(continuous_mem_used_index, span);
+    if (mem_used[cur_index] + continuous_mem_info->total_size_ > mem_size_) {
+      max_span_mem_in_device = span - 1;
+      break;
+    }
+  }
+  return max_span_mem_in_device;
+}
+
+size_t MemOffloadStrategy::GetFirstMallocIndex(const ContinuousMemInfoPtr &continuous_mem_info) {
+  size_t earliest_malloc_index = continuous_mem_info->compute_index_;
+  for (const auto &key_index : continuous_mem_info->key_index_map_) {
+    const auto &events_iter = mem_events_.find(key_index.first);
+    if (events_iter == mem_events_.end() || events_iter->second.empty()) {
+      MS_LOG(EXCEPTION) << "Can not find events for continuous input memory, device address key: " << key_index.first;
+    }
+    const auto &first_event = events_iter->second[kInitOrMallocMemEventIndex];
+    if (first_event->index < earliest_malloc_index) {
+      earliest_malloc_index = first_event->index;
+    }
+  }
+  return earliest_malloc_index;
+}
+
+void MemOffloadStrategy::GenContinuousMemAllocSteps() {
+  for (const auto &continuous_mem_info : continuous_mem_info_helper_->GetAllContinuousMemInfo()) {
+    GenContinuousMemAllocStep(continuous_mem_info);
+  }
+}
+
+void MemOffloadStrategy::GenContinuousMemAllocStep(const ContinuousMemInfoPtr &continuous_mem_info) {
+  if (!continuous_mem_info->is_input_) {
+    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
+  } else {
+    const size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
+    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, earliest_malloc_index);
+  }
 }
@@ -288,5 +420,87 @@ std::set<size_t> MemOffloadStrategy::GetSwapOutEventIndex(const void *key,
   }
   return swap_out_event_index;
 }
 
+std::shared_ptr<ContinuousMemInfo> ContinuousMemInfoHelper::GetContinuousMemInfo(const void *address_key) {
+  const auto &key_compute_index_iter = key_continuous_info_map_.find(address_key);
+  if (key_compute_index_iter == key_continuous_info_map_.end()) {
+    return nullptr;
+  }
+  return key_compute_index_iter->second;
+}
+
+std::vector<ContinuousMemInfoPtr> ContinuousMemInfoHelper::GetAllContinuousMemInfo() {
+  std::vector<ContinuousMemInfoPtr> all_continuous_mem_info(input_continuous_mem_info_.size() +
+                                                            output_continuous_mem_info_.size());
+  std::copy(input_continuous_mem_info_.begin(), input_continuous_mem_info_.end(), all_continuous_mem_info.begin());
+  std::copy_backward(output_continuous_mem_info_.begin(), output_continuous_mem_info_.end(),
+                     all_continuous_mem_info.end());
+  return all_continuous_mem_info;
+}
+
+bool ContinuousMemInfoHelper::IsContinuousMem(const void *address_key) {
+  const auto continuous_mem_info = GetContinuousMemInfo(address_key);
+  return (continuous_mem_info != nullptr);
+}
+
+bool ContinuousMemInfoHelper::IsContinuousInputMem(const void *address_key) {
+  const auto continuous_mem_info = GetContinuousMemInfo(address_key);
+  return (continuous_mem_info != nullptr && continuous_mem_info->is_input_);
+}
+
+void ContinuousMemInfoHelper::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                                                   const std::vector<size_t> &align_size_list,
+                                                   const std::vector<const void *> &address_key_list) {
+  if (align_size_list.size() != address_key_list.size()) {
+    MS_LOG(EXCEPTION) << "Number of align size[" << align_size_list.size()
+                      << "] is supposed to be equal to number of address[" << address_key_list.size() << "]";
+  }
+  ContinuousMemInfoPtr continuous_mem_info =
+      std::make_shared<ContinuousMemInfo>(is_input, total_size, compute_index, align_size_list);
+  for (size_t i = 0; i < address_key_list.size(); i += 1) {
+    auto key = address_key_list[i];
+    MS_EXCEPTION_IF_NULL(key);
+    (void)continuous_mem_info->key_index_map_.emplace(key, i);
+    key_continuous_info_map_.emplace(key, continuous_mem_info);
+  }
+  if (is_input) {
+    input_continuous_mem_info_.insert(continuous_mem_info);
+  } else {
+    output_continuous_mem_info_.insert(continuous_mem_info);
+  }
+  index_continuous_info_map_[compute_index].emplace_back(continuous_mem_info);
+}
+
+void MemOffloadStrategy::CountContinuousMemUsage(std::vector<size_t> *total_mem_used) {
+  const auto &input_continuous_mem_info_ = continuous_mem_info_helper_->GetAllContinuousMemInfo();
+  for (const auto &continuous_mem_info : input_continuous_mem_info_) {
+    if (!continuous_mem_info->is_input_ || continuous_mem_info->key_index_map_.empty()) {
+      continue;
+    }
+    const auto &compute_index = continuous_mem_info->compute_index_;
+    size_t earliest_malloc_index = SIZE_MAX;
+    for (const auto &key_index : continuous_mem_info->key_index_map_) {
+      const auto &key = key_index.first;
+      const auto &events_iter = mem_events_.find(key);
+      if (events_iter == mem_events_.end() || events_iter->second.empty()) {
+        MS_LOG(EXCEPTION) << "Can not find memory events of continuous input memory, device address key: " << key;
+      }
+      const auto &mem_events = events_iter->second;
+      const auto &first_event = mem_events[kInitOrMallocMemEventIndex];
+      if (first_event->index < earliest_malloc_index) {
+        earliest_malloc_index = first_event->index;
+      }
+      const auto &last_events = mem_events[mem_events.size() - 1];
+      const auto end_step = IsHighPriorityMem(key) ? total_step_ - 1 : last_events->index;
+      const auto mem_size = last_events->mem_size;
+      for (size_t start_index = compute_index + 1; start_index <= end_step; start_index += 1) {
+        (*total_mem_used)[start_index] += mem_size;
+      }
+    }
+    for (size_t start_index = earliest_malloc_index; start_index <= compute_index; ++start_index) {
+      (*total_mem_used)[start_index] += continuous_mem_info->total_size_;
+    }
+  }
+}
 }  // namespace device
 }  // namespace mindspore
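Note: ContinuousMemInfoHelper maps every member address of a block to the same ContinuousMemInfo record, so membership queries are plain map lookups. Usage sketch against the declarations added in memory_offload_strategy.h (stack ints stand in for the opaque DeviceAddress keys; needs that header to compile):

#include <vector>
#include "runtime/device/memory_offload_strategy.h"

int main() {
  mindspore::device::ContinuousMemInfoHelper helper;
  int a = 0, b = 0;  // stand-in address keys
  const std::vector<const void *> keys = {&a, &b};
  const std::vector<size_t> align_size_list = {512, 1024};
  helper.AddContinuousMemInfo(/*is_input=*/true, /*compute_index=*/7,
                              /*total_size=*/1536, align_size_list, keys);
  // Both members resolve to the same record:
  const bool same = helper.GetContinuousMemInfo(&a) == helper.GetContinuousMemInfo(&b);
  return (helper.IsContinuousInputMem(&b) && same) ? 0 : 1;
}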

View File

@@ -21,6 +21,7 @@
 #include <set>
 #include <memory>
 #include <utility>
+#include <algorithm>
 
 namespace mindspore {
 namespace device {
@@ -37,17 +38,70 @@ struct MemEvent {
   const void *key{nullptr};
 };
 
+using MemEventPtr = std::shared_ptr<MemEvent>;
+using MemEventPtrList = std::vector<MemEventPtr>;
+
+struct ContinuousMemInfo {
+  ContinuousMemInfo(bool is_input, size_t total_size, size_t compute_index, std::vector<size_t> align_size_list)
+      : is_input_(is_input),
+        total_size_(total_size),
+        compute_index_(compute_index),
+        align_size_list_(std::move(align_size_list)) {}
+  bool is_input_;
+  size_t total_size_;
+  size_t compute_index_;
+  const std::vector<size_t> align_size_list_;
+  std::map<const void *, size_t> key_index_map_;
+};
+
+using ContinuousMemInfoPtr = std::shared_ptr<ContinuousMemInfo>;
+
+class ContinuousMemInfoHelper {
+ public:
+  void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                            const std::vector<size_t> &align_size_list,
+                            const std::vector<const void *> &address_key_list);
+  std::shared_ptr<ContinuousMemInfo> GetContinuousMemInfo(const void *address_key);
+  std::vector<ContinuousMemInfoPtr> GetAllContinuousMemInfo();
+  bool IsContinuousMem(const void *address_key);
+  bool IsContinuousInputMem(const void *address_key);
+
+  void AddContinuousMallocIndex(const ContinuousMemInfoPtr &mem_info, size_t index) {
+    first_malloc_index_.emplace(mem_info, index);
+  }
+
+  bool NeedMallocContinuousMem(const ContinuousMemInfoPtr &mem_info, size_t index) {
+    const auto &iter = first_malloc_index_.find(mem_info);
+    return iter != first_malloc_index_.end() && iter->second == index;
+  }
+
+  void ClearContinuousMallocIndex() { first_malloc_index_.clear(); }
+
+  const std::vector<ContinuousMemInfoPtr> &GetIndexContinuousMemInfo(size_t step) {
+    return index_continuous_info_map_[step];
+  }
+
+ private:
+  std::set<ContinuousMemInfoPtr> input_continuous_mem_info_;
+  std::set<ContinuousMemInfoPtr> output_continuous_mem_info_;
+  std::map<const void *, ContinuousMemInfoPtr> key_continuous_info_map_;
+  std::map<ContinuousMemInfoPtr, size_t> first_malloc_index_;
+  std::map<size_t, std::vector<ContinuousMemInfoPtr>> index_continuous_info_map_;
+};
+
 class MemOffloadStrategy {
  public:
   MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
-                     const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
+                     const std::map<const void *, MemEventPtrList> &mem_events,
                      const std::set<const void *> &manual_offload_keys,
-                     const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step)
+                     const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step,
+                     std::shared_ptr<ContinuousMemInfoHelper> continuous_mem_info_manager)
       : mem_priority_(mem_priority),
         mem_events_(mem_events),
         manual_offload_keys_(manual_offload_keys),
         high_priority_updated_step_(high_priority_updated_step),
-        total_step_(total_step) {}
+        total_step_(total_step),
+        continuous_mem_info_helper_(std::move(continuous_mem_info_manager)) {}
 
   virtual ~MemOffloadStrategy() = default;
@@ -55,9 +109,9 @@ class MemOffloadStrategy {
   void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; }
 
-  std::vector<std::shared_ptr<MemEvent>> &GetPreComputeEvents(size_t step);
+  MemEventPtrList &GetPreComputeEvents(size_t step);
 
-  std::vector<std::shared_ptr<MemEvent>> &GetPostComputeEvents(size_t step);
+  MemEventPtrList &GetPostComputeEvents(size_t step);
 
   void set_mem_size(size_t mem_size) { mem_size_ = mem_size; }
@@ -76,29 +130,52 @@ class MemOffloadStrategy {
   void GenComputeMemEvents();
 
-  void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
+  void GenFreeEvent(const MemEventPtr &last_event);
 
   std::set<size_t> GetSwapOutEventIndex(const void *key, const std::vector<std::shared_ptr<MemEvent>> &mem_events);
 
-  size_t GetSpanBetweenMemEvents(size_t pre_step, size_t post_step) const {
-    return (post_step + total_step_ - pre_step) % total_step_;
+  void AddToSwapEventSetIfOutOfMem(const MemEventPtr &mem_event, size_t span, std::vector<size_t> *mem_used);
+
+  void GenContinuousMemSwapEvent(const ContinuousMemInfoPtr &continuous_mem_info, std::vector<size_t> *mem_used,
+                                 std::set<MemEventPtr> *events_no_need_swap);
+
+  size_t GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr &continuous_mem_info,
+                                    const std::vector<size_t> &mem_used);
+
+  size_t GetFirstMallocIndex(const ContinuousMemInfoPtr &continuous_mem_info);
+
+  void GenContinuousMemAllocSteps();
+
+  void GenContinuousMemAllocStep(const ContinuousMemInfoPtr &continuous_mem_info);
+
+  void CountContinuousMemUsage(std::vector<size_t> *total_mem_used);
+
+  size_t GetSpanBetweenMemEvents(size_t pre_index, size_t post_index) const {
+    return (post_index + total_step_ - pre_index) % total_step_;
+  }
+
+  size_t GetPreMemEventIndex(size_t cur_index, size_t span) const {
+    return (cur_index + total_step_ - span) % total_step_;
   }
 
   const std::map<const void *, MemPriority> &mem_priority_;
-  const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
+  const std::map<const void *, MemEventPtrList> &mem_events_;
   const std::set<const void *> &manual_offload_keys_;
   std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
   const size_t total_step_;
-  std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
-  std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
+  std::vector<MemEventPtrList> pre_compute_events_;
+  std::vector<MemEventPtrList> post_compute_events_;
   size_t mem_size_{0};
   std::vector<double> compute_time_;
   bool need_swap_{false};
-  std::multimap<size_t, std::pair<std::shared_ptr<MemEvent>, size_t>> event_span_;
-  std::set<std::shared_ptr<MemEvent>> swap_events_;
+  std::multimap<size_t, std::pair<MemEventPtr, size_t>> event_span_;
+  std::multimap<size_t, std::pair<MemEventPtr, size_t>> continuous_input_event_span_;
+  std::set<MemEventPtr> swap_events_;
   std::vector<size_t> min_mem_used_;
   size_t mem_used_without_swap_{0};
   size_t min_mem_needed_{0};
+  std::shared_ptr<ContinuousMemInfoHelper> continuous_mem_info_helper_{nullptr};
 };
 }  // namespace device
 }  // namespace mindspore
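Note: the two index helpers declared above (GetSpanBetweenMemEvents, GetPreMemEventIndex) treat the step axis as circular, because a high-priority tensor allocated late in one iteration can be consumed early in the next. A worked example of the modular arithmetic with total_step_ = 10:

#include <cassert>
#include <cstddef>

int main() {
  const size_t total_step = 10;
  auto span = [&](size_t pre, size_t post) { return (post + total_step - pre) % total_step; };
  auto pre_index = [&](size_t cur, size_t s) { return (cur + total_step - s) % total_step; };
  assert(span(8, 2) == 4);                  // wraps over step 0: 8 -> 9 -> 0 -> 1 -> 2
  assert(pre_index(2, span(8, 2)) == 8);    // GetPreMemEventIndex inverts the span
  assert(span(3, 7) == 4);                  // plain case, no wrap
  return 0;
}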

View File

@@ -17,6 +17,7 @@
 #include "runtime/device/memory_scheduler.h"
 #include <algorithm>
 #include <queue>
+#include <set>
 #ifdef _MSC_VER
 #include <time.h>
 #else
@@ -44,6 +45,34 @@ double GetCurrentTime() {
 }
 }  // namespace
 
+void *MemHandler::MallocHost(size_t mem_size) {
+  auto &mem_que = cached_host_mem_[mem_size];
+  if (!mem_que.empty()) {
+    auto ret = mem_que.front();
+    mem_que.pop();
+    return ret;
+  }
+  auto block = std::make_shared<std::vector<uint8_t>>();
+  try {
+    block->resize(mem_size, 0);
+    auto ptr = block->data();
+    host_mem_block_map_[ptr] = block;
+    return ptr;
+  } catch (const std::exception &e) {
+    MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
+  }
+}
+
+void MemHandler::FreeHost(void *ptr) {
+  MS_EXCEPTION_IF_NULL(ptr);
+  auto iter = host_mem_block_map_.find(ptr);
+  if (iter == host_mem_block_map_.end()) {
+    MS_LOG(ERROR) << "Free ptr not be created from manager!";
+  }
+  auto mem_size = iter->second->size();
+  cached_host_mem_[mem_size].emplace(iter->first);
+}
+
 void MemScheduler::Clear() {
   if (mem_handler_ == nullptr) {
     return;
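Note: MallocHost/FreeHost above (moved here from MemoryManager) implement a size-bucketed cache: freed host blocks are parked in a per-size queue and reused on the next same-size request, and the backing std::vector blocks are never returned to the OS while the handler lives. A self-contained model of that policy:

#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <queue>
#include <vector>

std::map<size_t, std::queue<void *>> cached_host_mem;
std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map;

void *MallocHost(size_t mem_size) {
  auto &q = cached_host_mem[mem_size];
  if (!q.empty()) {  // reuse a previously freed block of the same size
    void *p = q.front();
    q.pop();
    return p;
  }
  auto block = std::make_shared<std::vector<uint8_t>>(mem_size, 0);
  host_mem_block_map[block->data()] = block;  // keep the block alive
  return block->data();
}

void FreeHost(void *ptr) { cached_host_mem[host_mem_block_map.at(ptr)->size()].push(ptr); }

int main() {
  void *a = MallocHost(1024);
  FreeHost(a);
  assert(MallocHost(1024) == a);  // same block comes back for the same size
}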
@@ -72,6 +101,15 @@ void MemScheduler::ClearAllocatedMem() {
     }
   }
   swap_host_ptr_.clear();
+  continuous_mem_key_.clear();
+}
+
+void MemScheduler::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                                        const std::vector<size_t> &align_size_list,
+                                        const std::vector<const void *> &address_key_list) {
+  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
+  continuous_mem_info_helper_->AddContinuousMemInfo(is_input, compute_index, total_size, align_size_list,
+                                                    address_key_list);
 }
 
 void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
@@ -119,18 +157,51 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
   return nullptr;
 }
 
+void *MemScheduler::MallocContinuousMem(const std::shared_ptr<MemEvent> &event, void *stream) {
+  const auto &continuous_mem_info = continuous_mem_info_helper_->GetContinuousMemInfo(event->key);
+  void *device_ptr = nullptr;
+  if (cur_step_allocated_continuous_mem_.count(continuous_mem_info) == 0 &&
+      continuous_mem_info_helper_->NeedMallocContinuousMem(continuous_mem_info, current_step_)) {
+    if (mem_result_.find(event->key) != mem_result_.end()) {
+      MS_LOG(EXCEPTION) << "Device memory is allocated before first continuous memory alloc event, event key: "
+                        << event->key << ", continuous memory used index: " << continuous_mem_info->compute_index_;
+    }
+    const auto &device_ptr_list =
+        MallocContinuousMem(continuous_mem_info->total_size_, continuous_mem_info->align_size_list_, stream);
+    if (device_ptr_list.empty()) {
+      MS_LOG(WARNING) << "MallocContinuousMemFromMemPool failed, size: " << continuous_mem_info->total_size_;
+      return nullptr;
+    }
+    for (const auto &key_index : continuous_mem_info->key_index_map_) {
+      MS_EXCEPTION_IF_NULL(device_ptr_list[key_index.second]);
+      mem_result_[key_index.first] = device_ptr_list[key_index.second];
+      continuous_mem_key_.insert(key_index.first);
+    }
+    device_ptr = mem_result_[event->key];
+    MS_EXCEPTION_IF_NULL(device_ptr);
+    cur_step_allocated_continuous_mem_.insert(continuous_mem_info);
+  } else {
+    device_ptr = MallocDevice(event->mem_size, stream);
+  }
+  return device_ptr;
+}
+
 bool MemScheduler::PreComputeInit(const std::shared_ptr<MemEvent> &event, void *stream) {
+  const bool is_continuous_mem = continuous_mem_info_helper_->IsContinuousMem(event->key);
   const auto &iter = mem_result_.find(event->key);
   const bool new_malloc = iter == mem_result_.end();
   void *device_ptr = nullptr;
-  if (new_malloc) {
-    device_ptr = MallocDevice(event->mem_size, stream);
-    if (device_ptr == nullptr) {
-      return false;
-    }
-  } else {
+  if (!new_malloc) {
     device_ptr = iter->second;
+  } else if (is_continuous_mem) {
+    device_ptr = MallocContinuousMem(event, stream);
+  } else {
+    device_ptr = MallocDevice(event->mem_size, stream);
   }
+  if (device_ptr == nullptr) {
+    return false;
+  }
   if (new_malloc || high_priority_mem_need_init_.count(event->key) != 0) {
     MS_LOG(DEBUG) << "Init input data from host, key: " << event->key;
     auto host_ptr = init_host_ptr_[event->key];
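Note: MallocContinuousMem(event, stream) allocates the whole fused block exactly once, at the step the strategy recorded (NeedMallocContinuousMem), and fills mem_result_ for every member key; later events of other members simply find their slice cached. A stand-alone model of that first-touch behaviour, with a plain array standing in for the device pool:

#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

int main() {
  std::map<int, uint8_t *> mem_result;              // member key -> device ptr
  static uint8_t pool[4096];                        // fake contiguous device block
  const std::vector<size_t> align_size_list = {1024, 2048};
  // First member event triggers one allocation for the whole block;
  // every member slice starts at the running offset:
  size_t offset = 0;
  for (int key = 0; key < 2; ++key) {
    mem_result[key] = pool + offset;
    offset += align_size_list[key];
  }
  // A later event for member 1 reuses the cached slice instead of mallocing:
  assert(mem_result.at(1) == pool + 1024);
}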
@@ -142,29 +213,33 @@ bool MemScheduler::PreComputeInit(const std::shared_ptr<MemEvent> &event, void *
 }
 
 bool MemScheduler::PreComputeMalloc(const std::shared_ptr<MemEvent> &event, void *stream) {
-  const auto &iter = mem_result_.find(event->key);
-  const bool new_malloc = iter == mem_result_.end();
+  const bool is_continuous_mem = continuous_mem_info_helper_->IsContinuousMem(event->key);
   void *device_ptr = nullptr;
-  if (new_malloc) {
-    device_ptr = MallocDevice(event->mem_size, stream);
-    if (device_ptr == nullptr) {
-      return false;
-    }
-  } else {
-    device_ptr = iter->second;
+  const auto &iter = mem_result_.find(event->key);
+  if (iter != mem_result_.end()) {
+    return true;
+  } else if (is_continuous_mem) {
+    device_ptr = MallocContinuousMem(event, stream);
+  } else {
+    device_ptr = MallocDevice(event->mem_size, stream);
+  }
+  if (device_ptr == nullptr) {
+    return false;
   }
   mem_result_[event->key] = device_ptr;
   return true;
 }
 
 bool MemScheduler::PreComputeSwapIn(const std::shared_ptr<MemEvent> &event, void *stream) {
+  if (!PreComputeMalloc(event, stream)) {
+    return false;
+  }
+  const auto device_ptr = mem_result_[event->key];
+  MS_EXCEPTION_IF_NULL(device_ptr);
   bool from_init = true;
   void *host_ptr = nullptr;
   GetHostPtr(event->key, &host_ptr, &from_init);
-  auto device_ptr = MallocDevice(event->mem_size, stream);
-  if (device_ptr == nullptr) {
-    return false;
-  }
   MS_EXCEPTION_IF_NULL(host_ptr);
   mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
   mem_result_[event->key] = device_ptr;
@@ -196,7 +271,7 @@ bool MemScheduler::PreComputeGet(const std::shared_ptr<MemEvent> &event, void *s
   auto device_ptr = MallocDevice(mem_size, stream);
   mem_handler_->SwapIn(host_ptr, device_ptr, mem_size, stream);
   if (!from_init) {
-    (void)swap_host_ptr_.erase(host_ptr);
+    (void)swap_host_ptr_.erase(key);
     mem_handler_->FreeHost(host_ptr);
   }
   mem_result_[key] = device_ptr;
@@ -223,12 +298,14 @@ bool MemScheduler::PreCompute(void *stream) {
       ret = PreComputeGet(event, stream);
     }
     if (!ret) {
+      cur_step_allocated_continuous_mem_.clear();
       return false;
     }
   }
   if (record_compute_time_ && !updated_) {
     compute_start_time_ = GetCurrentTime();
   }
+  cur_step_allocated_continuous_mem_.clear();
   return true;
 }
@@ -253,6 +330,7 @@ bool MemScheduler::PostCompute(void *stream) {
       }
       mem_handler_->FreeDevice(ptr);
       (void)mem_result_.erase(event->key);
+      continuous_mem_key_.erase(event->key);
     } else if (event->type == kSwapOut) {
       auto device_ptr = mem_result_[event->key];
       if (device_ptr == nullptr) {
@@ -261,6 +339,11 @@ bool MemScheduler::PostCompute(void *stream) {
       SwapOutAndFreeDevice(event->key, device_ptr, event->mem_size, stream);
     }
   }
+  for (const auto &info : continuous_mem_info_helper_->GetIndexContinuousMemInfo(current_step_)) {
+    for (const auto &key_index : info->key_index_map_) {
+      continuous_mem_key_.erase(key_index.first);
+    }
+  }
   ++current_step_;
   return true;
 }
@@ -269,8 +352,9 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
   MS_EXCEPTION_IF_NULL(mem_handler_);
 
   if (strategy_ == nullptr) {
-    strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
-                                                     high_priority_updated_step_, total_step_);
+    strategy_ =
+        std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
+                                             high_priority_updated_step_, total_step_, continuous_mem_info_helper_);
     if (manual_offload_keys_.empty()) {
       compute_time_.resize(total_step_);
     } else {
@@ -352,25 +436,78 @@ void *MemScheduler::MallocDevice(size_t mem_size, void *stream) {
   if (device_ptr != nullptr || !optimized_) {
     return device_ptr;
   }
+  // Find memory block big enough in mem_result_, except continuous mem and memory blocks used in this step.
   auto iter = mem_result_.begin();
   using KeySizePair = std::pair<const void *, size_t>;
   auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
   std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_swap(less);
   while (iter != mem_result_.end()) {
     const auto key = iter->first;
-    if (no_reuse_key.count(key) != 0) {
+    if (no_reuse_key.count(key) != 0 || continuous_mem_key_.count(key) != 0) {
       ++iter;
       continue;
     }
     const auto device_mem_size = GetMemSize(key);
-    mem_can_swap.push({key, device_mem_size});
     if (device_mem_size >= mem_size) {
       SwapOutAndFreeDevice(key, iter->second, device_mem_size, stream);
       device_ptr = mem_handler_->MallocDevice(mem_size);
+      MS_EXCEPTION_IF_NULL(device_ptr);
       return device_ptr;
     }
+    mem_can_swap.push({key, device_mem_size});
     ++iter;
   }
+  // Try swap out memory block from big to small
+  while (!mem_can_swap.empty()) {
+    const auto &max_mem_in_device = mem_can_swap.top();
+    const auto key = max_mem_in_device.first;
+    const auto swap_mem_size = max_mem_in_device.second;
+    auto swap_device_ptr = mem_result_[key];
+    MS_EXCEPTION_IF_NULL(swap_device_ptr);
+    mem_can_swap.pop();
+    SwapOutAndFreeDevice(key, swap_device_ptr, swap_mem_size, stream);
+    device_ptr = mem_handler_->MallocDevice(mem_size);
+    if (device_ptr != nullptr) {
+      return device_ptr;
+    }
+  }
+  return nullptr;
+}
+
+std::vector<void *> MemScheduler::MallocContinuousMem(size_t total_size, const std::vector<size_t> &size_list,
+                                                      void *stream) {
+  const auto &no_reuse_key = step_keys_[current_step_];
+  auto device_ptr_list = mem_handler_->MallocContinuousMemFromMemPool(size_list);
+  if (!device_ptr_list.empty() || !optimized_) {
+    return device_ptr_list;
+  }
+  // Find memory block big enough in mem_result_, except continuous mem and memory blocks used in this step.
+  auto iter = mem_result_.begin();
+  using KeySizePair = std::pair<const void *, size_t>;
+  auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
+  std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_swap(less);
+  while (iter != mem_result_.end()) {
+    const auto key = iter->first;
+    if (no_reuse_key.count(key) != 0 || continuous_mem_key_.count(key) != 0) {
+      ++iter;
+      continue;
+    }
+    const auto device_mem_size = GetMemSize(key);
+    if (device_mem_size >= total_size) {
+      SwapOutAndFreeDevice(key, iter->second, device_mem_size, stream);
+      device_ptr_list = mem_handler_->MallocContinuousMemFromMemPool(size_list);
+      if (device_ptr_list.empty()) {
+        MS_LOG(EXCEPTION) << "device_ptr_list empty";
+      }
+      return device_ptr_list;
+    }
+    mem_can_swap.push({key, device_mem_size});
+    ++iter;
+  }
+  // Try swap out memory block from big to small
   while (!mem_can_swap.empty()) {
     const auto &max_mem_in_device = mem_can_swap.top();
     mem_can_swap.pop();
@@ -379,12 +516,13 @@ void *MemScheduler::MallocDevice(size_t mem_size, void *stream) {
     auto swap_device_ptr = mem_result_[key];
     MS_EXCEPTION_IF_NULL(swap_device_ptr);
     SwapOutAndFreeDevice(key, swap_device_ptr, swap_mem_size, stream);
-    device_ptr = mem_handler_->MallocDevice(mem_size);
-    if (device_ptr != nullptr) {
-      return device_ptr;
+    device_ptr_list = mem_handler_->MallocContinuousMemFromMemPool(size_list);
+    if (!device_ptr_list.empty()) {
+      return device_ptr_list;
     }
   }
-  return nullptr;
+
+  return device_ptr_list;
 }
 
 void MemScheduler::SwapOutAndFreeDevice(const void *key, void *device_ptr, size_t mem_size, void *stream) {
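Note: both fallback paths above share one eviction policy: swappable blocks (excluding continuous members and keys used in the current step) go into a max-heap keyed by size, the largest idle block is swapped out first, and the retry loop stops as soon as the pool can satisfy the request. A stand-alone model of the ordering:

#include <cstdio>
#include <queue>
#include <utility>
#include <vector>

int main() {
  using KeySizePair = std::pair<const char *, size_t>;
  auto less = [](const KeySizePair &a, const KeySizePair &b) { return a.second < b.second; };
  std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_swap(less);
  mem_can_swap.push({"t1", 512});
  mem_can_swap.push({"t2", 8192});
  mem_can_swap.push({"t3", 2048});
  while (!mem_can_swap.empty()) {  // evicts t2, then t3, then t1
    printf("evict %s (%zu bytes)\n", mem_can_swap.top().first, mem_can_swap.top().second);
    mem_can_swap.pop();
  }
}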

View File

@@ -20,22 +20,36 @@
 #include <map>
 #include <set>
 #include <memory>
+#include <queue>
 #include <utility>
 #include "runtime/device/memory_offload_strategy.h"
+#include "runtime/device/memory_manager.h"
 
 namespace mindspore {
 namespace device {
 class MemHandler {
  public:
-  MemHandler() = default;
-  virtual ~MemHandler() = default;
-  virtual size_t GetAvailableMemSize() = 0;
-  virtual void *MallocDevice(size_t mem_size) = 0;
-  virtual void FreeDevice(void *ptr) = 0;
-  virtual void *MallocHost(size_t mem_size) = 0;
-  virtual void FreeHost(void *ptr) = 0;
-  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
-  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
+  explicit MemHandler(std::shared_ptr<MemoryManager> memory_manager) : memory_manager_(memory_manager) {}
+  ~MemHandler() = default;
+  size_t GetAvailableMemSize() { return memory_manager_->GetAvailableMemSize(); }
+  void *MallocDevice(size_t mem_size) { return memory_manager_->MallocMemFromMemPool(mem_size, false); }
+  void FreeDevice(void *ptr) { memory_manager_->FreeMemFromMemPool(ptr); }
+  void *MallocHost(size_t mem_size);
+  void FreeHost(void *ptr);
+  void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
+    memory_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream);
+  }
+  void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
+    memory_manager_->SwapOut(device_ptr, host_ptr, mem_size, stream);
+  }
+  std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list) {
+    return memory_manager_->MallocContinuousMemFromMemPool(size_list);
+  }
+
+ private:
+  std::shared_ptr<MemoryManager> memory_manager_;
+  std::map<size_t, std::queue<void *>> cached_host_mem_;
+  std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_;
 };
 
 class MemScheduler {
@@ -88,6 +102,10 @@ class MemScheduler {
   void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }
 
+  void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                            const std::vector<size_t> &align_size_list,
+                            const std::vector<const void *> &address_key_list);
+
  private:
   void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
@@ -99,6 +117,8 @@ class MemScheduler {
   void *MallocDevice(size_t mem_size, void *stream);
 
+  std::vector<void *> MallocContinuousMem(size_t total_size, const std::vector<size_t> &size_list, void *stream);
+
   void SwapOutAndFreeDevice(const void *key, void *device_ptr, size_t mem_size, void *stream);
 
   size_t GetMemSize(const void *key);
@@ -115,8 +135,10 @@ class MemScheduler {
   bool PreComputeGet(const std::shared_ptr<MemEvent> &event, void *stream);
+  void *MallocContinuousMem(const std::shared_ptr<MemEvent> &event, void *stream);
   std::map<const void *, MemPriority> mem_priority_;
-  std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
+  std::map<const void *, MemEventPtrList> mem_events_;
   std::set<const void *> manual_offload_keys_;
   std::vector<std::set<const void *>> step_keys_;
   std::map<const void *, void *> mem_result_;
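MemEventPtrList arrives without a definition in this view; judging from the spelled-out type it replaces on the same line, it is presumably an alias introduced alongside these changes in memory_offload_strategy.h (an assumption, shown only for readability):

#include <memory>
#include <vector>

struct MemEvent;  // full definition lives in memory_offload_strategy.h
using MemEventPtrList = std::vector<std::shared_ptr<MemEvent>>;  // assumed alias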
@@ -134,6 +156,9 @@ class MemScheduler {
   bool updated_{false};
   std::shared_ptr<MemHandler> mem_handler_{nullptr};
   std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
+  std::shared_ptr<ContinuousMemInfoHelper> continuous_mem_info_helper_{std::make_shared<ContinuousMemInfoHelper>()};
+  std::set<std::shared_ptr<ContinuousMemInfo>> cur_step_allocated_continuous_mem_;
+  std::set<const void *> continuous_mem_key_;
 };
 class MemSchedulerManager {
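Taken together, the three new members track the registered continuous blocks (continuous_mem_info_helper_), the blocks already allocated in the current step (cur_step_allocated_continuous_mem_), and the keys that live inside some block (continuous_mem_key_). The payoff is the commit's theme: a contiguous block backing a communication op is carved once and reused, instead of being re-allocated per step. A self-contained sketch of the carving idea, under the assumption that members sit back to back at their aligned sizes; this is toy code, not the commit's implementation (the real allocation goes through MemHandler::MallocContinuousMemFromMemPool above):

// Carve one contiguous arena into per-tensor addresses at aligned offsets.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<void *> CarveContinuousMem(uint8_t *arena, const std::vector<size_t> &align_size_list) {
  std::vector<void *> ret;
  size_t offset = 0;
  for (size_t size : align_size_list) {
    ret.push_back(arena + offset);  // each member starts where the previous one ends
    offset += size;
  }
  return ret;
}

int main() {
  std::vector<size_t> align_size_list = {512, 512, 1024};
  const size_t total = std::accumulate(align_size_list.begin(), align_size_list.end(), size_t{0});
  std::vector<uint8_t> arena(total);  // stands in for one device allocation
  auto addrs = CarveContinuousMem(arena.data(), align_size_list);
  // Members are adjacent, so a fused kernel can treat the group as one buffer.
  std::cout << (static_cast<uint8_t *>(addrs[1]) - static_cast<uint8_t *>(addrs[0])) << std::endl;  // 512
  return 0;
}

Adjacency is the property communication kernels need: once the member addresses are contiguous, the whole group can be handed to the op as a single buffer.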
@@ -21,16 +21,17 @@
 namespace mindspore::device {
 constexpr size_t kDeviceMemSize = 5;
 constexpr size_t kMaxVirtualCount = 1024;
-class MemHandlerImpl : public MemHandler {
+class MemoryManagerStub : public MemoryManager {
  public:
-  MemHandlerImpl() {
+  MemoryManagerStub() {
     device_mem_.resize(kMaxVirtualCount, 0);
-    host_mem_.resize(kMaxVirtualCount, 1);
   }
+  void Initialize() override {}
+  void Finalize() override {}
   size_t GetAvailableMemSize() override { return kDeviceMemSize; }
-  void *MallocDevice(size_t mem_size) override {
+  void *MallocMemFromMemPool(size_t mem_size, bool useless = false) override {
     if (device_virtual_count_ >= kDeviceMemSize) {
       return nullptr;
     }
@@ -40,7 +41,7 @@ class MemHandlerImpl : public MemHandler {
     return ret;
   }
-  void FreeDevice(void *ptr) override {
+  void FreeMemFromMemPool(void *ptr) override {
     --device_virtual_count_;
     auto iter = device_mem_size_.find(ptr);
     if (iter != device_mem_size_.end()) {
@@ -48,31 +49,32 @@ class MemHandlerImpl : public MemHandler {
     }
   }
-  void *MallocHost(size_t mem_size) override {
-    auto ret = host_mem_.data() + host_virtual_count_;
-    ++host_virtual_count_;
-    host_mem_size_.emplace(ret, mem_size);
-    return ret;
-  }
-  void FreeHost(void *ptr) override {
-    auto iter = host_mem_size_.find(ptr);
-    if (iter != host_mem_size_.end()) {
-      host_mem_size_.erase(iter);
-    }
+  std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list) override {
+    const size_t total_size = std::accumulate(size_list.begin(), size_list.end(), 0);
+    std::vector<void *> ret;
+    if (device_virtual_count_ + total_size > kDeviceMemSize) {
+      return ret;
+    }
+    for (const auto &size : size_list) {
+      auto ptr = device_mem_.data() + device_virtual_count_;
+      device_mem_size_.emplace(ptr, size);
+      ret.emplace_back(ptr);
+      ++device_virtual_count_;
+    }
+    return ret;
   }
   void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
   void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
+ protected:
+  uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) { return nullptr; }
  private:
   std::vector<uint8_t> device_mem_;
-  std::vector<uint8_t> host_mem_;
   size_t device_virtual_count_{0};
-  size_t host_virtual_count_{0};
   std::map<void *, size_t> device_mem_size_;
-  std::map<void *, size_t> host_mem_size_;
 };
 class TestMemScheduler : public UT::Common {
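One nit worth flagging in the stub's MallocContinuousMemFromMemPool: std::accumulate(size_list.begin(), size_list.end(), 0) deduces int as the accumulator type because the initial value is the int literal 0, so the summation runs in int. That is harmless in this stub, where "sizes" are tiny virtual counts, but with real byte sizes the same pattern can overflow; passing size_t{0} keeps the accumulation in size_t. A minimal demonstration:

#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<size_t> sizes = {1, 2, 3};
  auto as_int = std::accumulate(sizes.begin(), sizes.end(), 0);            // accumulates in int
  auto as_size_t = std::accumulate(sizes.begin(), sizes.end(), size_t{0}); // accumulates in size_t
  std::cout << as_int << " " << as_size_t << std::endl;  // both 6 here, but large sizes overflow int
  return 0;
}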
@@ -136,14 +138,14 @@ TEST_F(TestMemScheduler, test_mem_scheduler_manager) {
 /// Feature: MemScheduler
 /// Description: Test MemScheduler interface
-/// Expectation: MemScheduler GetOrMalloc return valid ptr
-TEST_F(TestMemScheduler, test_mem_scheduler) {
+/// Expectation: MemScheduler GetOrMalloc return valid ptr for continuous mem
+TEST_F(TestMemScheduler, test_mem_scheduler_with_continuous_mem) {
   MemSchedulerManager mem_scheduler_manager;
   auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
   ASSERT_NE(scheduler, nullptr);
   auto need_record = scheduler->need_record_event();
   ASSERT_EQ(need_record, true);
-  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
+  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandler>(std::make_shared<MemoryManagerStub>());
   ASSERT_NE(mem_handler, nullptr);
   scheduler->SetMemHandler(mem_handler);
@@ -190,7 +192,7 @@ TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
   ASSERT_NE(scheduler, nullptr);
   auto need_record = scheduler->need_record_event();
   ASSERT_EQ(need_record, true);
-  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
+  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandler>(std::make_shared<MemoryManagerStub>());
   ASSERT_NE(mem_handler, nullptr);
   scheduler->SetMemHandler(mem_handler);
@@ -232,4 +234,53 @@ TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
   // run
   Run(scheduler);
 }
+
+/// Feature: MemScheduler
+/// Description: Test MemScheduler interface
+/// Expectation: MemScheduler GetOrMalloc return valid ptr
+TEST_F(TestMemScheduler, test_mem_scheduler) {
+  MemSchedulerManager mem_scheduler_manager;
+  auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
+  ASSERT_NE(scheduler, nullptr);
+  auto need_record = scheduler->need_record_event();
+  ASSERT_EQ(need_record, true);
+  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandler>(std::make_shared<MemoryManagerStub>());
+  ASSERT_NE(mem_handler, nullptr);
+  scheduler->SetMemHandler(mem_handler);
+
+  // input data
+  used_tensor_num_ = 8;
+  total_step_ = 8;
+  std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
+  std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
+  std::vector<size_t> init_tensors = {0, 2, 4};
+  // 8 step tensor usage
+  //
+  // 0-----0-----0
+  // 1--1--1
+  // 2-----2--------2
+  // 3-----------3
+  // 4--------------4
+  // 6--------------6
+  // 7--------7
+  //
+  std::vector<std::vector<size_t>> step_used_tensors = {{0, 2, 6}, {1, 7}, {0, 1, 2, 3, 4}, {1},
+                                                        {0, 7},    {2, 6}, {3},             {4}};
+  tensor_keys_.swap(tensor_keys);
+  tensor_datas_.swap(tensor_datas);
+  init_tensors_.swap(init_tensors);
+  step_used_tensors_.swap(step_used_tensors);
+  scheduler->SetTotalStep(total_step_);
+
+  // record
+  Record(scheduler);
+  // Add continuous memory info
+  scheduler->AddContinuousMemInfo(true, 2, 3, {1, 1, 1},
+                                  {tensor_keys_.data(), tensor_keys_.data() + 1, tensor_keys_.data() + 2});
+  scheduler->AddContinuousMemInfo(false, 2, 2, {1, 1}, {tensor_keys_.data() + 3, tensor_keys_.data() + 4});
+  // optimize
+  scheduler->Optimize();
+  // run
+  Run(scheduler);
+}
 } // namespace mindspore::device
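Reading aid for the new test: the two AddContinuousMemInfo calls register tensors 0..2 as a continuous input group and tensors 3..4 as a continuous output group, both at compute index 2, which is exactly the step where all five are used (step_used_tensors[2] = {0, 1, 2, 3, 4}). A short standalone check of that correspondence (not part of the commit):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::vector<size_t>> step_used_tensors = {{0, 2, 6}, {1, 7}, {0, 1, 2, 3, 4}, {1},
                                                        {0, 7},    {2, 6}, {3},             {4}};
  const auto &step2 = step_used_tensors[2];
  for (size_t tensor : {0, 1, 2, 3, 4}) {  // the two continuous groups registered at index 2
    bool used = std::find(step2.begin(), step2.end(), tensor) != step2.end();
    std::cout << "tensor " << tensor << (used ? " used" : " NOT used") << " at step 2\n";
  }
  return 0;
}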