forked from mindspore-Ecosystem/mindspore

commit 7522bb2c9b — "Memoffload reuse continuous memory"
parent 508f3d7b5f
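At a high level, the change has three parts: the swap/host-cache interface moves out of MemoryManager into a standalone MemHandler; communication kernels register their contiguous input/output groups with the MemScheduler; and MemOffloadStrategy plans one allocation step per contiguous block so that every member address reuses the same block instead of being allocated separately. A condensed sketch of the resulting flow (names taken from the hunks below; simplified, not verbatim source):

    // Graph preparation: record contiguous groups for communication ops.
    //   AddCommunicationMemInfo(graph) -> mem_scheduler->AddContinuousMemInfo(...)
    // Planning: pick the step at which each whole block is allocated.
    //   MemOffloadStrategy::Execute() -> GenContinuousMemSwapEvent() / GenContinuousMemAllocSteps()
    // Execution: the first event of a group allocates the block, later events reuse it.
    //   MemScheduler::PreCompute(stream) -> MallocContinuousMem(event, stream)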
@@ -48,11 +48,6 @@ uint64_t AscendMemoryManager::GetMsMaxMemSize() { return AscendMemAdapter::GetIn
 uint64_t AscendMemoryManager::GetMsUsedHbmSize() { return AscendMemAdapter::GetInstance().GetMsUsedHbmSize(); }
 
-void *AscendMemoryManager::MallocDevice(size_t size) {
-  auto align_size = GetCommonAlignSize(size);
-  return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
-}
-
 void *AscendMemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem) {
   auto align_size = GetCommonAlignSize(size);
   const auto device_addr = AscendMemoryPool::GetInstance().AllocTensorMem(align_size, from_persistent_mem);
@@ -33,7 +33,6 @@ class AscendMemoryManager : public MemoryManager {
   void ResetDynamicMemory() override;
   void ClearGlobalIdleMem() override;
   void *MallocMemFromMemPool(size_t size, bool from_persistent_mem) override;
-  void *MallocDevice(size_t size) override;
   void FreeMemFromMemPool(void *device_ptr) override;
   uint64_t GetMsMaxMemSize();
   void MallocSomasDynamicMem(const session::KernelGraph &graph) override;
@@ -15,6 +15,7 @@
  */
 
 #include "runtime/device/kernel_runtime.h"
+#include <algorithm>
 #include <functional>
 #include <utility>
 #include <vector>
@@ -110,7 +111,7 @@ void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
   if (UseMemScheduler()) {
     AssignStaticMemoryValueNode(graph);
     ResetNodeAddress(graph);
-    AssignCommunicationMem(graph);
+    AddCommunicationMemInfo(graph);
   } else {
     MS_EXCEPTION_IF_NULL(mem_manager_);
     mem_manager_->ResetDynamicMemory();
@@ -1088,6 +1089,53 @@ DeviceAddressPtr KernelRuntime::CreateDeviceAddressForStringValue(const ValuePtr
   return address;
 }
 
+bool KernelRuntime::MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
+                                           void *stream, bool mock, KernelLaunchInfo *kernel_launch_info) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(mem_scheduler);
+  MS_EXCEPTION_IF_NULL(stream);
+  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
+  MS_EXCEPTION_IF_NULL(kernel_mod);
+  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
+    MS_LOG(ERROR) << "SyncStream failed";
+    return false;
+  }
+  bool ret = mem_scheduler->PreCompute(stream);
+  if (!ret) {
+    return ret;
+  }
+  AssignKernelAddress(mem_scheduler, kernel, kernel_launch_info);
+  auto cnode = kernel->cast<CNodePtr>();
+  if (mock && common::AnfAlgo::HasNodeAttr(kAttrOffload, cnode) &&
+      common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
+    for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
+      auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
+      mem_scheduler->SetOffload(device_address);
+    }
+  }
+  return true;
+}
+
+bool KernelRuntime::MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
+                                            const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream,
+                                            bool mock) {
+  MS_EXCEPTION_IF_NULL(kernel);
+  MS_EXCEPTION_IF_NULL(mem_scheduler);
+  MS_EXCEPTION_IF_NULL(stream);
+  if (!mock) {
+    SyncNodeOutputTensors(mem_scheduler, graph, kernel);
+  }
+  bool ret = mem_scheduler->PostCompute(stream);
+  if (!ret) {
+    return ret;
+  }
+  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
+    MS_LOG(ERROR) << "SyncStream failed";
+    return false;
+  }
+  return true;
+}
+
 void KernelRuntime::AssignDynamicMemory(const session::KernelGraph &graph) {
   MS_EXCEPTION_IF_NULL(mem_manager_);
   auto context_ptr = MsContext::GetInstance();
@@ -1524,13 +1572,37 @@ void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &m
   }
 }
 
-void KernelRuntime::AssignCommunicationMem(const session::KernelGraph &graph) {
-  for (const auto &kernel : graph.execution_order()) {
+void KernelRuntime::AddCommunicationMemInfo(const session::KernelGraph &graph) {
+  const auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
+  for (size_t compute_index = 0; compute_index < graph.execution_order().size(); ++compute_index) {
+    const auto &kernel = graph.execution_order()[compute_index];
+    MS_EXCEPTION_IF_NULL(kernel);
     if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
       continue;
     }
-    AssignCommunicationInputFromMemoryPool(kernel);
-    AssignCommunicationOutputFromMemoryPool(kernel);
+    auto device_address_to_key = [](const DeviceAddressPtr &device_address) -> void * { return device_address.get(); };
+    size_t input_total_size = 0;
+    DeviceAddressPtrList input_address_list;
+    std::vector<size_t> input_align_size_list;
+    GetCommunicationInputInfo(kernel, &input_total_size, &input_address_list, &input_align_size_list);
+    if (input_address_list.size() > 1) {
+      std::vector<const void *> input_address_key_list;
+      std::transform(input_address_list.begin(), input_address_list.end(), std::back_inserter(input_address_key_list),
+                     device_address_to_key);
+      mem_scheduler->AddContinuousMemInfo(true, compute_index, input_total_size, input_align_size_list,
+                                          input_address_key_list);
+    }
+    size_t output_total_size = 0;
+    DeviceAddressPtrList output_address_list;
+    std::vector<size_t> output_align_size_list;
+    GetCommunicationOutputInfo(kernel, &output_total_size, &output_address_list, &output_align_size_list);
+    if (output_address_list.size() > 1) {
+      std::vector<const void *> output_address_key_list;
+      std::transform(output_address_list.begin(), output_address_list.end(),
+                     std::back_inserter(output_address_key_list), device_address_to_key);
+      mem_scheduler->AddContinuousMemInfo(false, compute_index, output_total_size, output_align_size_list,
+                                          output_address_key_list);
+    }
   }
 }
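For illustration: a communication kernel at execution-order position 3 with three inputs of aligned sizes 512, 1024, and 512 bytes would be recorded as one input-side group (hypothetical keys; the keys are the raw DeviceAddress pointers produced by device_address_to_key above):

    mem_scheduler->AddContinuousMemInfo(/*is_input=*/true, /*compute_index=*/3,
                                        /*total_size=*/2048, {512, 1024, 512},
                                        {key_a, key_b, key_c});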
@@ -1550,19 +1622,10 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
   }
   bool ret = true;
   if (mem_scheduler != nullptr) {
-    ret = mem_scheduler->PreCompute(stream);
+    ret = MemSchedulerPreCompute(kernel, mem_scheduler, stream, mock, &kernel_launch_info);
     if (!ret) {
       return ret;
     }
-    AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
-    auto cnode = kernel->cast<CNodePtr>();
-    if (mock && common::AnfAlgo::HasNodeAttr(kAttrOffload, cnode) &&
-        common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
-      for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
-        auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
-        mem_scheduler->SetOffload(device_address);
-      }
-    }
   } else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
     kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
     kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
@@ -1581,10 +1644,7 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
     }
   }
   if (mem_scheduler != nullptr) {
-    if (!mock) {
-      SyncNodeOutputTensors(mem_scheduler, graph, kernel);
-    }
-    ret = mem_scheduler->PostCompute(stream);
+    ret = MemSchedulerPostCompute(graph, kernel, mem_scheduler, stream, mock);
   }
   return ret;
 }
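With both helpers in place, the scheduler-specific logic in LaunchKernel reduces to a symmetric pre/post pair around the launch itself; roughly (condensed from the two hunks above, not verbatim):

    if (mem_scheduler != nullptr) {
      ret = MemSchedulerPreCompute(kernel, mem_scheduler, stream, mock, &kernel_launch_info);
      if (!ret) {
        return ret;
      }
    }
    // ... launch kernel_mod with kernel_launch_info ...
    if (mem_scheduler != nullptr) {
      ret = MemSchedulerPostCompute(graph, kernel, mem_scheduler, stream, mock);
    }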
@@ -1739,7 +1799,7 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
   if (mem_scheduler->optimized()) {
     return;
   }
-  mem_scheduler->SetMemHandler(mem_manager_);
+  mem_scheduler->SetMemHandler(std::make_shared<MemHandler>(mem_manager_));
   mem_scheduler->SetTotalStep(graph.execution_order().size());
 
   if (mem_scheduler->need_record_event()) {
@@ -178,7 +178,7 @@ class KernelRuntime {
   void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
                             const session::KernelGraph &graph);
 
-  void AssignCommunicationMem(const session::KernelGraph &graph);
+  void AddCommunicationMemInfo(const session::KernelGraph &graph);
   bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
   void LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &run_events,
                          const AnfNodePtr &node) const;
@@ -204,6 +204,10 @@ class KernelRuntime {
   void GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
                                   std::vector<size_t> *align_size_list) const;
   DeviceAddressPtr CreateDeviceAddressForStringValue(const ValuePtr &value, bool use_mem_pool, uint32_t graph_id);
+  bool MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
+                              void *stream, bool mock, KernelLaunchInfo *kernel_launch_info);
+  bool MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
+                               const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream, bool mock);
 
  protected:
   uint32_t device_id_{0};
@@ -23,7 +23,6 @@
 #include <queue>
 #include "common/mem_reuse/mem_reuse.h"
 #include "backend/common/somas/somas.h"
-#include "runtime/device/memory_scheduler.h"
 namespace mindspore {
 namespace device {
 enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
@@ -32,7 +31,7 @@ constexpr uint64_t kMemAlignSize = 512;
 constexpr uint64_t kTwiceMemAlignSize = kMemAlignSize << 1;
 using SomasPtr = mindspore::somas::SomasPtr;
 
-class MemoryManager : public MemHandler {
+class MemoryManager {
  public:
  MemoryManager() = default;
  virtual ~MemoryManager() = default;
@@ -50,7 +49,6 @@ class MemoryManager : public MemHandler {
   virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
     return MallocMem(type, size, address, kInvalidGraphId);
   }
-
   // param address is the address type of each device
   // param from_persistent_mem shows whether the tensor is a parameter in Pynative mode
   virtual bool MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size);
@@ -65,45 +63,13 @@ class MemoryManager : public MemHandler {
   static size_t GetCommonAlignSize(size_t input_size);
   static size_t GetCommunicationAlignSize(size_t input_size);
 
-  // swap manager interface
-  void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size, false); }
-  void FreeDevice(void *ptr) override {
-    MS_EXCEPTION_IF_NULL(ptr);
-    FreeMemFromMemPool(ptr);
-  }
-  void *MallocHost(size_t mem_size) override {
-    auto &mem_que = cached_host_mem_[mem_size];
-    if (!mem_que.empty()) {
-      auto ret = mem_que.front();
-      mem_que.pop();
-      return ret;
-    }
-    auto block = std::make_shared<std::vector<uint8_t>>();
-    try {
-      block->resize(mem_size, 0);
-      auto ptr = block->data();
-      host_mem_block_map_[ptr] = block;
-      return ptr;
-    } catch (const std::exception &e) {
-      MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
-    }
-  }
-  void FreeHost(void *ptr) override {
-    MS_EXCEPTION_IF_NULL(ptr);
-    auto iter = host_mem_block_map_.find(ptr);
-    if (iter == host_mem_block_map_.end()) {
-      MS_LOG(ERROR) << "Free ptr not be created from manager!";
-    }
-    auto mem_size = iter->second->size();
-    cached_host_mem_[mem_size].emplace(iter->first);
-  }
-  void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {
+  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
     MS_LOG(INFO) << "Call default swap in " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
   }
-  void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {
+  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
     MS_LOG(INFO) << "Call default swap out " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
   }
-  size_t GetAvailableMemSize() override {
+  virtual size_t GetAvailableMemSize() {
     MS_LOG(ERROR) << "Return default 0 mem size!";
     return 0;
   }
@@ -115,8 +81,6 @@ class MemoryManager : public MemHandler {
   }
   virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
   SomasPtr somas_reuse_util_ptr_{nullptr};
-  std::map<size_t, std::queue<void *>> cached_host_mem_;
-  std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_;
 };
 } // namespace device
 } // namespace mindspore
@@ -18,6 +18,7 @@
 #include <map>
 #include <memory>
 #include <utility>
+#include <algorithm>
 #include "utils/log_adapter.h"
 
 namespace mindspore {
@@ -25,14 +26,14 @@ namespace device {
 constexpr size_t kFirstGetMemEventIndex = 1;
 constexpr size_t kInitOrMallocMemEventIndex = 0;
 
-std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
+MemEventPtrList &MemOffloadStrategy::GetPreComputeEvents(size_t step) {
   if (pre_compute_events_.size() <= step) {
     MS_LOG_EXCEPTION << "Index out of pre event range, index:" << step << ", event size:" << pre_compute_events_.size();
   }
   return pre_compute_events_[step];
 }
 
-std::vector<std::shared_ptr<MemEvent>> &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
+MemEventPtrList &MemOffloadStrategy::GetPostComputeEvents(size_t step) {
   if (post_compute_events_.size() <= step) {
     MS_LOG_EXCEPTION << "Index out of post event range, index:" << step
                      << ", event size:" << post_compute_events_.size();
@@ -46,6 +47,8 @@ void MemOffloadStrategy::Execute() {
   if (need_swap_) {
     GenEventSpan();
     GenSwapEventSet();
+  } else {
+    GenContinuousMemAllocSteps();
   }
   GenComputeMemEvents();
 }
@@ -60,14 +63,17 @@ void MemOffloadStrategy::CountMemUsage() {
   min_mem_used_.resize(total_step_, 0);
   std::vector<size_t> total_mem_used(total_step_, 0);
   size_t high_priority_mem_size = 0;
+  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
   for (auto &item : mem_events_) {
     auto &mem_events = item.second;
     if (mem_events.empty()) {
       continue;
     }
     auto first_event = mem_events[kInitOrMallocMemEventIndex];
-    const bool is_high_priority = IsHighPriorityMem(first_event->key);
-    if (is_high_priority) {
+    const bool is_high_priority = IsHighPriorityMem(item.first);
+    if (continuous_mem_info_helper_->IsContinuousInputMem(item.first)) {
+      continue;
+    } else if (is_high_priority) {
       high_priority_mem_size += first_event->mem_size;
     } else {
       auto last_event = mem_events[mem_events.size() - 1];
@@ -75,6 +81,7 @@ void MemOffloadStrategy::CountMemUsage() {
         total_mem_used[start_index] += first_event->mem_size;
       }
     }
+
     // Calculate the minimum memory size for kernel execution.
     for (const auto &event : mem_events) {
       MS_EXCEPTION_IF_NULL(event);
@@ -84,6 +91,7 @@ void MemOffloadStrategy::CountMemUsage() {
       min_mem_used_[event->index] += first_event->mem_size;
     }
   }
+  CountContinuousMemUsage(&total_mem_used);
   min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
   mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
   if (mem_size_ < min_mem_needed_) {
@@ -118,24 +126,25 @@ void MemOffloadStrategy::GenEventSpan() {
     if (tensor_events.size() <= 1) {
       continue;
     }
-    const bool is_high_priority = IsHighPriorityMem(tensor_events[kInitOrMallocMemEventIndex]->key);
+    const bool is_high_priority = IsHighPriorityMem(item.first);
     for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) {
       auto &event = tensor_events[i];
       MS_EXCEPTION_IF_NULL(event);
       if (event->type != kGet) {
         MS_LOG(EXCEPTION) << "Event should be Get except fist event.";
       }
-      auto latest_event = tensor_events[i - 1];
+      auto latest_get_event = tensor_events[i - 1];
       if (i == kFirstGetMemEventIndex && is_high_priority) {
-        latest_event = tensor_events[tensor_events.size() - 1];
+        latest_get_event = tensor_events[tensor_events.size() - 1];
       }
-      auto span = GetSpanBetweenMemEvents(latest_event->index, event->index);
-      if (is_high_priority && span == 0 && latest_event == event) {
+      auto span = GetSpanBetweenMemEvents(latest_get_event->index, event->index);
+      // High priority memory that is only used once in a total step
+      if (is_high_priority && span == 0 && latest_get_event == event) {
        span = total_step_;
       }
       if (span > 1) {
         const size_t span_mul_size = (span - 1) * event->mem_size;
-        (void)event_span_.emplace(std::make_pair(span_mul_size, std::make_pair(event, span)));
+        (void)event_span_.emplace(span_mul_size, std::make_pair(event, span));
       }
     }
   }
 }
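Note that event_span_ is a multimap keyed by (span - 1) * event->mem_size, and std::multimap iterates in ascending key order, so GenSwapEventSet grants device residency to candidates with a small memory-times-idle-span footprint first; large, long-idle tensors are the likeliest to overflow mem_size_ and land in swap_events_. Illustrative arithmetic:

    // 4 MB tensor idle for 6 steps:  (6 - 1) * 4 MB = 20 MB-steps
    // 16 MB tensor idle for 2 steps: (2 - 1) * 16 MB = 16 MB-steps
    // -> the 16 MB tensor is visited first and is more likely to stay resident.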
@@ -153,36 +162,159 @@ void MemOffloadStrategy::GenSwapEventSet() {
     }
     return;
   }
 
   // greedy span filter
+  continuous_mem_info_helper_->ClearContinuousMallocIndex();
   std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
 
+  auto compare_total_size = [](ContinuousMemInfoPtr l, ContinuousMemInfoPtr r) -> bool {
+    return l->total_size_ < r->total_size_;
+  };
+  auto all_continuous_mem_info = continuous_mem_info_helper_->GetAllContinuousMemInfo();
+  std::sort(all_continuous_mem_info.begin(), all_continuous_mem_info.end(), compare_total_size);
+  std::set<std::shared_ptr<MemEvent>> events_no_need_swap;
+  for (const auto &continuous_mem_info : all_continuous_mem_info) {
+    GenContinuousMemSwapEvent(continuous_mem_info, &cur_mem_used, &events_no_need_swap);
+  }
   for (const auto &iter : event_span_) {
+    const auto &event = iter.second.first;
+    if (events_no_need_swap.count(event) > 0) {
+      continue;
+    }
     auto span = iter.second.second;
-    auto &event = iter.second.first;
-    auto start_index = ((event->index + total_step_ - span + 1) % total_step_);
-    bool revert = false;
-    size_t cur_index = start_index;
-    while (cur_index != event->index) {
-      cur_mem_used[cur_index] += event->mem_size;
-      if (cur_mem_used[cur_index] > mem_size_) {
-        revert = true;
-      }
-      cur_index += 1;
-      if (cur_index >= total_step_) {
-        cur_index = 0;
-      }
-    }
-    if (revert) {
-      cur_index = start_index;
-      while (cur_index != event->index) {
-        cur_mem_used[cur_index] -= event->mem_size;
-        cur_index += 1;
-        if (cur_index >= total_step_) {
-          cur_index = 0;
-        }
-      }
-      (void)swap_events_.emplace(event);
-    }
+    AddToSwapEventSetIfOutOfMem(event, span, &cur_mem_used);
   }
 }
+
+void MemOffloadStrategy::AddToSwapEventSetIfOutOfMem(const std::shared_ptr<MemEvent> &event, size_t span,
+                                                     std::vector<size_t> *mem_used) {
+  const auto start_index = (GetPreMemEventIndex(event->index, span) + 1) % total_step_;
+  bool revert = false;
+  size_t cur_index = start_index;
+  while (cur_index != event->index) {
+    (*mem_used)[cur_index] += event->mem_size;
+    if (mem_used->at(cur_index) > mem_size_) {
+      revert = true;
+    }
+    cur_index += 1;
+    if (cur_index >= total_step_) {
+      cur_index = 0;
+    }
+  }
+  if (revert) {
+    cur_index = start_index;
+    while (cur_index != event->index) {
+      (*mem_used)[cur_index] -= event->mem_size;
+      cur_index += 1;
+      if (cur_index >= total_step_) {
+        cur_index = 0;
+      }
+    }
+    (void)swap_events_.emplace(event);
+  }
+}
+
+void MemOffloadStrategy::GenContinuousMemSwapEvent(const ContinuousMemInfoPtr &continuous_mem_info,
+                                                   std::vector<size_t> *mem_used,
+                                                   std::set<std::shared_ptr<MemEvent>> *events_no_need_swap) {
+  MS_EXCEPTION_IF_NULL(continuous_mem_info);
+  if (continuous_mem_info->key_index_map_.empty()) {
+    return;
+  }
+  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
+  if (!continuous_mem_info->is_input_) {
+    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
+    return;
+  }
+  const auto max_span_mem_in_device = GetMaxSpanForContinuousMem(continuous_mem_info, *mem_used);
+  size_t first_malloc_span = 0;
+  size_t first_malloc_size_dup = 0;
+  for (const auto &key_index : continuous_mem_info->key_index_map_) {
+    const auto &events_iter = mem_events_.find(key_index.first);
+    if (events_iter == mem_events_.end() || events_iter->second.empty()) {
+      MS_LOG(EXCEPTION) << "Can not find events for continuous input memory, device address key: " << key_index.first;
+    }
+    size_t swap_in_event_index = kFirstGetMemEventIndex;
+    size_t swap_in_span = 0;
+    const bool is_high_priority = IsHighPriorityMem(key_index.first);
+    for (size_t i = kFirstGetMemEventIndex; i < events_iter->second.size(); ++i) {
+      const auto &mem_event = events_iter->second[i];
+      if (!is_high_priority && mem_event->index > continuous_mem_used_index) {
+        continue;
+      }
+      const size_t span = GetSpanBetweenMemEvents(mem_event->index, continuous_mem_used_index);
+      // Find the max span than less than or equal to max_span_mem_in_device.
+      if (span <= max_span_mem_in_device) {
+        if (span >= swap_in_span) {
+          swap_in_span = span;
+          swap_in_event_index = i;
+        }
+        events_no_need_swap->insert(mem_event);
+      }
+    }
+    if (swap_in_event_index != kFirstGetMemEventIndex || is_high_priority) {
+      swap_events_.insert(events_iter->second[swap_in_event_index]);
+    }
+    // Find the earliest index that continuous memory should be allocated
+    if (swap_in_span > first_malloc_span) {
+      first_malloc_span = swap_in_span;
+      first_malloc_size_dup = events_iter->second[swap_in_event_index]->mem_size;
+    } else if (swap_in_span == first_malloc_span) {
+      // Accumulate the memory size that already added to mem_used.
+      first_malloc_size_dup += events_iter->second[swap_in_event_index]->mem_size;
+    }
+  }
+  for (size_t span = 1; span <= first_malloc_span; ++span) {
+    size_t index = GetPreMemEventIndex(continuous_mem_used_index, span);
+    (*mem_used)[index] += continuous_mem_info->total_size_;
+  }
+  size_t index = GetPreMemEventIndex(continuous_mem_used_index, first_malloc_span);
+  (*mem_used)[index] -= first_malloc_size_dup;
+  continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, index);
+}
+
+size_t MemOffloadStrategy::GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr &continuous_mem_info,
+                                                      const std::vector<size_t> &mem_used) {
+  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
+  size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
+  size_t max_span_mem_in_device = GetSpanBetweenMemEvents(earliest_malloc_index, continuous_mem_used_index);
+
+  for (size_t span = 1; span <= max_span_mem_in_device; ++span) {
+    size_t cur_index = GetPreMemEventIndex(continuous_mem_used_index, span);
+    if (mem_used[cur_index] + continuous_mem_info->total_size_ > mem_size_) {
+      max_span_mem_in_device = span - 1;
+      break;
+    }
+  }
+  return max_span_mem_in_device;
+}
+
+size_t MemOffloadStrategy::GetFirstMallocIndex(const ContinuousMemInfoPtr &continuous_mem_info) {
+  size_t earliest_malloc_index = continuous_mem_info->compute_index_;
+  for (const auto &key_index : continuous_mem_info->key_index_map_) {
+    const auto &events_iter = mem_events_.find(key_index.first);
+    if (events_iter == mem_events_.end() || events_iter->second.empty()) {
+      MS_LOG(EXCEPTION) << "Can not find events for continuous input memory, device address key: " << key_index.first;
+    }
+    const auto &first_event = events_iter->second[kInitOrMallocMemEventIndex];
+    if (first_event->index < earliest_malloc_index) {
+      earliest_malloc_index = first_event->index;
+    }
+  }
+  return earliest_malloc_index;
+}
+
+void MemOffloadStrategy::GenContinuousMemAllocSteps() {
+  for (const auto &continuous_mem_info : continuous_mem_info_helper_->GetAllContinuousMemInfo()) {
+    GenContinuousMemAllocStep(continuous_mem_info);
+  }
+}
+
+void MemOffloadStrategy::GenContinuousMemAllocStep(const ContinuousMemInfoPtr &continuous_mem_info) {
+  if (!continuous_mem_info->is_input_) {
+    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
+  } else {
+    const size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
+    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, earliest_malloc_index);
+  }
+}
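One subtlety in GenContinuousMemSwapEvent, as we read it: for every span up to first_malloc_span the whole block size is charged to mem_used, but the member tensors swapped in at exactly first_malloc_span were already counted there by their own events, so their sizes are accumulated in first_malloc_size_dup and subtracted once at the first-malloc step to avoid double counting. A numeric illustration (made-up values):

    // block total_size_ = 2048, first_malloc_span = 2,
    // members swapped in at span 2 with sizes 512 + 512 = 1024 (first_malloc_size_dup)
    // mem_used[idx(span = 1)] += 2048
    // mem_used[idx(span = 2)] += 2048;  mem_used[idx(span = 2)] -= 1024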
@@ -288,5 +420,87 @@ std::set<size_t> MemOffloadStrategy::GetSwapOutEventIndex(const void *key,
   }
   return swap_out_event_index;
 }
+
+std::shared_ptr<ContinuousMemInfo> ContinuousMemInfoHelper::GetContinuousMemInfo(const void *address_key) {
+  const auto &key_compute_index_iter = key_continuous_info_map_.find(address_key);
+  if (key_compute_index_iter == key_continuous_info_map_.end()) {
+    return nullptr;
+  }
+  return key_compute_index_iter->second;
+}
+
+std::vector<ContinuousMemInfoPtr> ContinuousMemInfoHelper::GetAllContinuousMemInfo() {
+  std::vector<ContinuousMemInfoPtr> all_continuous_mem_info(input_continuous_mem_info_.size() +
+                                                            output_continuous_mem_info_.size());
+  std::copy(input_continuous_mem_info_.begin(), input_continuous_mem_info_.end(), all_continuous_mem_info.begin());
+  std::copy_backward(output_continuous_mem_info_.begin(), output_continuous_mem_info_.end(),
+                     all_continuous_mem_info.end());
+  return all_continuous_mem_info;
+}
+
+bool ContinuousMemInfoHelper::IsContinuousMem(const void *address_key) {
+  const auto continuous_mem_info = GetContinuousMemInfo(address_key);
+  return (continuous_mem_info != nullptr);
+}
+
+bool ContinuousMemInfoHelper::IsContinuousInputMem(const void *address_key) {
+  const auto continuous_mem_info = GetContinuousMemInfo(address_key);
+  return (continuous_mem_info != nullptr && continuous_mem_info->is_input_);
+}
+
+void ContinuousMemInfoHelper::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                                                   const std::vector<size_t> &align_size_list,
+                                                   const std::vector<const void *> &address_key_list) {
+  if (align_size_list.size() != address_key_list.size()) {
+    MS_LOG(EXCEPTION) << "Number of align size[" << align_size_list.size()
+                      << "] is supposed to be equal to number of address[" << address_key_list.size() << "]";
+  }
+  ContinuousMemInfoPtr continuous_mem_info =
+    std::make_shared<ContinuousMemInfo>(is_input, total_size, compute_index, align_size_list);
+  for (size_t i = 0; i < address_key_list.size(); i += 1) {
+    auto key = address_key_list[i];
+    MS_EXCEPTION_IF_NULL(key);
+    (void)continuous_mem_info->key_index_map_.emplace(key, i);
+    key_continuous_info_map_.emplace(key, continuous_mem_info);
+  }
+  if (is_input) {
+    input_continuous_mem_info_.insert(continuous_mem_info);
+  } else {
+    output_continuous_mem_info_.insert(continuous_mem_info);
+  }
+  index_continuous_info_map_[compute_index].emplace_back(continuous_mem_info);
+}
+
+void MemOffloadStrategy::CountContinuousMemUsage(std::vector<size_t> *total_mem_used) {
+  const auto &input_continuous_mem_info_ = continuous_mem_info_helper_->GetAllContinuousMemInfo();
+  for (const auto &continuous_mem_info : input_continuous_mem_info_) {
+    if (!continuous_mem_info->is_input_ || continuous_mem_info->key_index_map_.empty()) {
+      continue;
+    }
+    const auto &compute_index = continuous_mem_info->compute_index_;
+    size_t earliest_malloc_index = SIZE_MAX;
+    for (const auto &key_index : continuous_mem_info->key_index_map_) {
+      const auto &key = key_index.first;
+      const auto &events_iter = mem_events_.find(key);
+      if (events_iter == mem_events_.end() || events_iter->second.empty()) {
+        MS_LOG(EXCEPTION) << "Can not find memory events of continuous input memory, device address key: " << key;
+      }
+      const auto &mem_events = events_iter->second;
+      const auto &first_event = mem_events[kInitOrMallocMemEventIndex];
+      if (first_event->index < earliest_malloc_index) {
+        earliest_malloc_index = first_event->index;
+      }
+      const auto &last_events = mem_events[mem_events.size() - 1];
+      const auto end_step = IsHighPriorityMem(key) ? total_step_ - 1 : last_events->index;
+      const auto mem_size = last_events->mem_size;
+      for (size_t start_index = compute_index + 1; start_index <= end_step; start_index += 1) {
+        (*total_mem_used)[start_index] += mem_size;
+      }
+    }
+    for (size_t start_index = earliest_malloc_index; start_index <= compute_index; ++start_index) {
+      (*total_mem_used)[start_index] += continuous_mem_info->total_size_;
+    }
+  }
+}
 } // namespace device
 } // namespace mindspore
@@ -21,6 +21,7 @@
 #include <set>
 #include <memory>
 #include <utility>
+#include <algorithm>
 
 namespace mindspore {
 namespace device {
@@ -37,17 +38,70 @@ struct MemEvent {
   const void *key{nullptr};
 };
 
+using MemEventPtr = std::shared_ptr<MemEvent>;
+using MemEventPtrList = std::vector<MemEventPtr>;
+
+struct ContinuousMemInfo {
+  ContinuousMemInfo(bool is_input, size_t total_size, size_t compute_index, std::vector<size_t> align_size_list)
+      : is_input_(is_input),
+        total_size_(total_size),
+        compute_index_(compute_index),
+        align_size_list_(std::move(align_size_list)) {}
+  bool is_input_;
+  size_t total_size_;
+  size_t compute_index_;
+  const std::vector<size_t> align_size_list_;
+  std::map<const void *, size_t> key_index_map_;
+};
+
+using ContinuousMemInfoPtr = std::shared_ptr<ContinuousMemInfo>;
+
+class ContinuousMemInfoHelper {
+ public:
+  void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                            const std::vector<size_t> &align_size_list,
+                            const std::vector<const void *> &address_key_list);
+  std::shared_ptr<ContinuousMemInfo> GetContinuousMemInfo(const void *address_key);
+  std::vector<ContinuousMemInfoPtr> GetAllContinuousMemInfo();
+  bool IsContinuousMem(const void *address_key);
+  bool IsContinuousInputMem(const void *address_key);
+
+  void AddContinuousMallocIndex(const ContinuousMemInfoPtr &mem_info, size_t index) {
+    first_malloc_index_.emplace(mem_info, index);
+  }
+
+  bool NeedMallocContinuousMem(const ContinuousMemInfoPtr &mem_info, size_t index) {
+    const auto &iter = first_malloc_index_.find(mem_info);
+    return iter != first_malloc_index_.end() && iter->second == index;
+  }
+
+  void ClearContinuousMallocIndex() { first_malloc_index_.clear(); }
+
+  const std::vector<ContinuousMemInfoPtr> &GetIndexContinuousMemInfo(size_t step) {
+    return index_continuous_info_map_[step];
+  }
+
+ private:
+  std::set<ContinuousMemInfoPtr> input_continuous_mem_info_;
+  std::set<ContinuousMemInfoPtr> output_continuous_mem_info_;
+  std::map<const void *, ContinuousMemInfoPtr> key_continuous_info_map_;
+  std::map<ContinuousMemInfoPtr, size_t> first_malloc_index_;
+  std::map<size_t, std::vector<ContinuousMemInfoPtr>> index_continuous_info_map_;
+};
+
 class MemOffloadStrategy {
  public:
   MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
-                     const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
+                     const std::map<const void *, MemEventPtrList> &mem_events,
                      const std::set<const void *> &manual_offload_keys,
-                     const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step)
+                     const std::map<const void *, std::vector<size_t>> &high_priority_updated_step, size_t total_step,
+                     std::shared_ptr<ContinuousMemInfoHelper> continuous_mem_info_manager)
       : mem_priority_(mem_priority),
         mem_events_(mem_events),
         manual_offload_keys_(manual_offload_keys),
         high_priority_updated_step_(high_priority_updated_step),
-        total_step_(total_step) {}
+        total_step_(total_step),
+        continuous_mem_info_helper_(std::move(continuous_mem_info_manager)) {}
 
   virtual ~MemOffloadStrategy() = default;
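A minimal usage sketch of the new helper (hypothetical keys k0..k2; mirrors how the strategy and scheduler drive it in the surrounding hunks):

    ContinuousMemInfoHelper helper;
    helper.AddContinuousMemInfo(true, /*compute_index=*/3, 2048, {512, 1024, 512}, {k0, k1, k2});
    auto info = helper.GetContinuousMemInfo(k1);         // same ContinuousMemInfo for k0, k1 and k2
    helper.AddContinuousMallocIndex(info, /*index=*/1);  // strategy decided to allocate at step 1
    helper.NeedMallocContinuousMem(info, 1);             // true: the first event at step 1 allocates the block
    helper.NeedMallocContinuousMem(info, 3);             // false: step 3 reuses the already-allocated block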
@@ -55,9 +109,9 @@ class MemOffloadStrategy {
 
   void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; }
 
-  std::vector<std::shared_ptr<MemEvent>> &GetPreComputeEvents(size_t step);
+  MemEventPtrList &GetPreComputeEvents(size_t step);
 
-  std::vector<std::shared_ptr<MemEvent>> &GetPostComputeEvents(size_t step);
+  MemEventPtrList &GetPostComputeEvents(size_t step);
 
   void set_mem_size(size_t mem_size) { mem_size_ = mem_size; }
 
@@ -76,29 +130,52 @@ class MemOffloadStrategy {
 
   void GenComputeMemEvents();
 
-  void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
+  void GenFreeEvent(const MemEventPtr &last_event);
 
   std::set<size_t> GetSwapOutEventIndex(const void *key, const std::vector<std::shared_ptr<MemEvent>> &mem_events);
 
-  size_t GetSpanBetweenMemEvents(size_t pre_step, size_t post_step) const {
-    return (post_step + total_step_ - pre_step) % total_step_;
+  void AddToSwapEventSetIfOutOfMem(const MemEventPtr &mem_event, size_t span, std::vector<size_t> *mem_used);
+
+  void GenContinuousMemSwapEvent(const ContinuousMemInfoPtr &continuous_mem_info, std::vector<size_t> *mem_used,
+                                 std::set<MemEventPtr> *events_no_need_swap);
+
+  size_t GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr &continuous_mem_info,
+                                    const std::vector<size_t> &mem_used);
+
+  size_t GetFirstMallocIndex(const ContinuousMemInfoPtr &continuous_mem_info);
+
+  void GenContinuousMemAllocSteps();
+
+  void GenContinuousMemAllocStep(const ContinuousMemInfoPtr &continuous_mem_info);
+
+  void CountContinuousMemUsage(std::vector<size_t> *total_mem_used);
+
+  size_t GetSpanBetweenMemEvents(size_t pre_index, size_t post_index) const {
+    return (post_index + total_step_ - pre_index) % total_step_;
+  }
+
+  size_t GetPreMemEventIndex(size_t cur_index, size_t span) const {
+    return (cur_index + total_step_ - span) % total_step_;
   }
 
   const std::map<const void *, MemPriority> &mem_priority_;
-  const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
+  const std::map<const void *, MemEventPtrList> &mem_events_;
   const std::set<const void *> &manual_offload_keys_;
   std::map<const void *, std::vector<size_t>> high_priority_updated_step_;
   const size_t total_step_;
-  std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
-  std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
+  std::vector<MemEventPtrList> pre_compute_events_;
+  std::vector<MemEventPtrList> post_compute_events_;
 
   size_t mem_size_{0};
   std::vector<double> compute_time_;
   bool need_swap_{false};
-  std::multimap<size_t, std::pair<std::shared_ptr<MemEvent>, size_t>> event_span_;
-  std::set<std::shared_ptr<MemEvent>> swap_events_;
+  std::multimap<size_t, std::pair<MemEventPtr, size_t>> event_span_;
+  std::multimap<size_t, std::pair<MemEventPtr, size_t>> continuous_input_event_span_;
+  std::set<MemEventPtr> swap_events_;
   std::vector<size_t> min_mem_used_;
   size_t mem_used_without_swap_{0};
   size_t min_mem_needed_{0};
+  std::shared_ptr<ContinuousMemInfoHelper> continuous_mem_info_helper_{nullptr};
 };
 } // namespace device
 } // namespace mindspore
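The two index helpers implement circular step arithmetic, since event indices repeat every total_step_ steps; a quick check with total_step_ = 10:

    // GetSpanBetweenMemEvents(pre, post) = (post + total - pre) % total
    //   pre = 8, post = 2  ->  (2 + 10 - 8) % 10 = 4   (the span wraps across step 0)
    // GetPreMemEventIndex(cur, span) = (cur + total - span) % total
    //   cur = 2, span = 4  ->  (2 + 10 - 4) % 10 = 8   (inverse of the above)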
@@ -17,6 +17,7 @@
 #include "runtime/device/memory_scheduler.h"
 #include <algorithm>
 #include <queue>
+#include <set>
 #ifdef _MSC_VER
 #include <time.h>
 #else
@@ -44,6 +45,34 @@ double GetCurrentTime() {
 }
 } // namespace
 
+void *MemHandler::MallocHost(size_t mem_size) {
+  auto &mem_que = cached_host_mem_[mem_size];
+  if (!mem_que.empty()) {
+    auto ret = mem_que.front();
+    mem_que.pop();
+    return ret;
+  }
+  auto block = std::make_shared<std::vector<uint8_t>>();
+  try {
+    block->resize(mem_size, 0);
+    auto ptr = block->data();
+    host_mem_block_map_[ptr] = block;
+    return ptr;
+  } catch (const std::exception &e) {
+    MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
+  }
+}
+
+void MemHandler::FreeHost(void *ptr) {
+  MS_EXCEPTION_IF_NULL(ptr);
+  auto iter = host_mem_block_map_.find(ptr);
+  if (iter == host_mem_block_map_.end()) {
+    MS_LOG(ERROR) << "Free ptr not be created from manager!";
+  }
+  auto mem_size = iter->second->size();
+  cached_host_mem_[mem_size].emplace(iter->first);
+}
+
 void MemScheduler::Clear() {
   if (mem_handler_ == nullptr) {
     return;
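The relocated MemHandler host allocator is a size-bucketed cache: FreeHost does not release a block but parks it in cached_host_mem_[size], so repeated swap-outs of same-sized tensors reuse host buffers instead of reallocating. For example (assuming a MemHandler instance named handler):

    void *p = handler.MallocHost(1 << 20);    // allocates a fresh 1 MB block
    handler.FreeHost(p);                      // parked in cached_host_mem_[1 << 20]
    void *q = handler.MallocHost(1 << 20);    // returns the cached block: q == p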
@@ -72,6 +101,15 @@ void MemScheduler::ClearAllocatedMem() {
     }
   }
   swap_host_ptr_.clear();
+  continuous_mem_key_.clear();
+}
+
+void MemScheduler::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                                        const std::vector<size_t> &align_size_list,
+                                        const std::vector<const void *> &address_key_list) {
+  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
+  continuous_mem_info_helper_->AddContinuousMemInfo(is_input, compute_index, total_size, align_size_list,
+                                                    address_key_list);
 }
 
 void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
@@ -119,18 +157,51 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
   return nullptr;
 }
 
+void *MemScheduler::MallocContinuousMem(const std::shared_ptr<MemEvent> &event, void *stream) {
+  const auto &continuous_mem_info = continuous_mem_info_helper_->GetContinuousMemInfo(event->key);
+  void *device_ptr = nullptr;
+  if (cur_step_allocated_continuous_mem_.count(continuous_mem_info) == 0 &&
+      continuous_mem_info_helper_->NeedMallocContinuousMem(continuous_mem_info, current_step_)) {
+    if (mem_result_.find(event->key) != mem_result_.end()) {
+      MS_LOG(EXCEPTION) << "Device memory is allocated before first continuous memory alloc event, event key: "
+                        << event->key << ", continuous memory used index: " << continuous_mem_info->compute_index_;
+    }
+    const auto &device_ptr_list =
+      MallocContinuousMem(continuous_mem_info->total_size_, continuous_mem_info->align_size_list_, stream);
+    if (device_ptr_list.empty()) {
+      MS_LOG(WARNING) << "MallocContinuousMemFromMemPool failed, size: " << continuous_mem_info->total_size_;
+      return nullptr;
+    }
+    for (const auto &key_index : continuous_mem_info->key_index_map_) {
+      MS_EXCEPTION_IF_NULL(device_ptr_list[key_index.second]);
+      mem_result_[key_index.first] = device_ptr_list[key_index.second];
+      continuous_mem_key_.insert(key_index.first);
+    }
+    device_ptr = mem_result_[event->key];
+    MS_EXCEPTION_IF_NULL(device_ptr);
+    cur_step_allocated_continuous_mem_.insert(continuous_mem_info);
+  } else {
+    device_ptr = MallocDevice(event->mem_size, stream);
+  }
+  return device_ptr;
+}
+
 bool MemScheduler::PreComputeInit(const std::shared_ptr<MemEvent> &event, void *stream) {
+  const bool is_continuous_mem = continuous_mem_info_helper_->IsContinuousMem(event->key);
   const auto &iter = mem_result_.find(event->key);
   const bool new_malloc = iter == mem_result_.end();
   void *device_ptr = nullptr;
-  if (new_malloc) {
-    device_ptr = MallocDevice(event->mem_size, stream);
-    if (device_ptr == nullptr) {
-      return false;
-    }
-  } else {
+  if (!new_malloc) {
     device_ptr = iter->second;
+  } else if (is_continuous_mem) {
+    device_ptr = MallocContinuousMem(event, stream);
+  } else {
+    device_ptr = MallocDevice(event->mem_size, stream);
   }
+  if (device_ptr == nullptr) {
+    return false;
+  }
+
   if (new_malloc || high_priority_mem_need_init_.count(event->key) != 0) {
     MS_LOG(DEBUG) << "Init input data from host, key: " << event->key;
     auto host_ptr = init_host_ptr_[event->key];
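MallocContinuousMem(event, stream) delegates the actual block allocation to an overload taking (total_size, align_size_list, stream) that is not part of this excerpt; judging by the warning text, it presumably forwards to a contiguous pool allocation such as MallocContinuousMemFromMemPool and returns one device pointer per member. A hedged sketch of that assumed overload:

    // Assumption: inferred from the call site and the warning message, not shown in this diff.
    std::vector<void *> MemScheduler::MallocContinuousMem(size_t total_size,
                                                          const std::vector<size_t> &align_size_list, void *stream) {
      // Expected contract: empty vector on failure, otherwise one pointer per align_size_list entry.
      return mem_handler_->MallocContinuousMemFromMemPool(total_size, align_size_list, stream);
    }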
@@ -142,29 +213,33 @@ bool MemScheduler::PreComputeInit(const std::shared_ptr<MemEvent> &event, void *
 }
 
 bool MemScheduler::PreComputeMalloc(const std::shared_ptr<MemEvent> &event, void *stream) {
-  const auto &iter = mem_result_.find(event->key);
-  const bool new_malloc = iter == mem_result_.end();
+  const bool is_continuous_mem = continuous_mem_info_helper_->IsContinuousMem(event->key);
   void *device_ptr = nullptr;
-  if (new_malloc) {
-    device_ptr = MallocDevice(event->mem_size, stream);
-    if (device_ptr == nullptr) {
-      return false;
-    }
+  const auto &iter = mem_result_.find(event->key);
+  if (iter != mem_result_.end()) {
+    return true;
+  } else if (is_continuous_mem) {
+    device_ptr = MallocContinuousMem(event, stream);
   } else {
-    device_ptr = iter->second;
+    device_ptr = MallocDevice(event->mem_size, stream);
+  }
+  if (device_ptr == nullptr) {
+    return false;
   }
   mem_result_[event->key] = device_ptr;
   return true;
 }
 
 bool MemScheduler::PreComputeSwapIn(const std::shared_ptr<MemEvent> &event, void *stream) {
+  if (!PreComputeMalloc(event, stream)) {
+    return false;
+  }
+  PreComputeMalloc(event, stream);
+  const auto device_ptr = mem_result_[event->key];
+  MS_EXCEPTION_IF_NULL(device_ptr);
   bool from_init = true;
   void *host_ptr = nullptr;
   GetHostPtr(event->key, &host_ptr, &from_init);
-  auto device_ptr = MallocDevice(event->mem_size, stream);
-  if (device_ptr == nullptr) {
-    return false;
-  }
   MS_EXCEPTION_IF_NULL(host_ptr);
   mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
   mem_result_[event->key] = device_ptr;
@@ -196,7 +271,7 @@ bool MemScheduler::PreComputeGet(const std::shared_ptr<MemEvent> &event, void *s
   auto device_ptr = MallocDevice(mem_size, stream);
   mem_handler_->SwapIn(host_ptr, device_ptr, mem_size, stream);
   if (!from_init) {
-    (void)swap_host_ptr_.erase(host_ptr);
+    (void)swap_host_ptr_.erase(key);
     mem_handler_->FreeHost(host_ptr);
   }
   mem_result_[key] = device_ptr;
@@ -223,12 +298,14 @@ bool MemScheduler::PreCompute(void *stream)
       ret = PreComputeGet(event, stream);
     }
     if (!ret) {
+      cur_step_allocated_continuous_mem_.clear();
       return false;
     }
   }
   if (record_compute_time_ && !updated_) {
     compute_start_time_ = GetCurrentTime();
   }
+  cur_step_allocated_continuous_mem_.clear();
   return true;
 }
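Both exits of PreCompute now clear cur_step_allocated_continuous_mem_, so continuous blocks grabbed during a failed step cannot leak into the next one. A scope guard expresses "clear on every exit" once instead of twice; a generic C++17 sketch of that alternative idiom, not a proposed patch:

#include <set>
#include <utility>

// Minimal scope guard: runs its callback on every path out of the scope.
template <typename F>
class ScopeGuard {
 public:
  explicit ScopeGuard(F f) : f_(std::move(f)) {}
  ~ScopeGuard() { f_(); }

 private:
  F f_;
};

bool PreComputeStepSketch(std::set<const void *> *cur_step_continuous_mem) {
  ScopeGuard guard([cur_step_continuous_mem] { cur_step_continuous_mem->clear(); });
  // ... per-event work; any early `return false` still triggers the clear
  return true;
}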
@@ -253,6 +330,7 @@ bool MemScheduler::PostCompute(void *stream)
       }
       mem_handler_->FreeDevice(ptr);
       (void)mem_result_.erase(event->key);
+      continuous_mem_key_.erase(event->key);
     } else if (event->type == kSwapOut) {
       auto device_ptr = mem_result_[event->key];
       if (device_ptr == nullptr) {
@@ -261,6 +339,11 @@ bool MemScheduler::PostCompute(void *stream)
       SwapOutAndFreeDevice(event->key, device_ptr, event->mem_size, stream);
     }
   }
+  for (const auto &info : continuous_mem_info_helper_->GetIndexContinuousMemInfo(current_step_)) {
+    for (const auto &key_index : info->key_index_map_) {
+      continuous_mem_key_.erase(key_index.first);
+    }
+  }
   ++current_step_;
   return true;
 }
@@ -269,8 +352,9 @@ void MemScheduler::OptMemUsage(float mem_used_factor)
   MS_EXCEPTION_IF_NULL(mem_handler_);

   if (strategy_ == nullptr) {
-    strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
-                                                     high_priority_updated_step_, total_step_);
+    strategy_ =
+      std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_,
+                                           high_priority_updated_step_, total_step_, continuous_mem_info_helper_);
     if (manual_offload_keys_.empty()) {
       compute_time_.resize(total_step_);
     } else {
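The strategy construction now passes continuous_mem_info_helper_ through, so offload planning can see which keys must stay grouped in one block. A toy illustration of this constructor-injection pattern, with stand-in types rather than MindSpore's:

#include <memory>
#include <utility>

// Stand-ins: Helper for ContinuousMemInfoHelper, Strategy for MemOffloadStrategy.
struct Helper {};
struct Strategy {
  explicit Strategy(std::shared_ptr<Helper> helper) : helper_(std::move(helper)) {}
  std::shared_ptr<Helper> helper_;
};

int main() {
  auto helper = std::make_shared<Helper>();
  Strategy strategy(helper);  // scheduler and strategy now share one helper
  return helper.use_count() == 2 ? 0 : 1;
}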
@@ -352,25 +436,78 @@ void *MemScheduler::MallocDevice(size_t mem_size, void *stream)
   if (device_ptr != nullptr || !optimized_) {
     return device_ptr;
   }
+  // Find memory block big enough in mem_result_, except continuous mem and memory blocks used in this step.
   auto iter = mem_result_.begin();
   using KeySizePair = std::pair<const void *, size_t>;
   auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
   std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_swap(less);
   while (iter != mem_result_.end()) {
     const auto key = iter->first;
-    if (no_reuse_key.count(key) != 0) {
+    if (no_reuse_key.count(key) != 0 || continuous_mem_key_.count(key) != 0) {
       ++iter;
       continue;
     }
     const auto device_mem_size = GetMemSize(key);
-    mem_can_swap.push({key, device_mem_size});
     if (device_mem_size >= mem_size) {
       SwapOutAndFreeDevice(key, iter->second, device_mem_size, stream);
       device_ptr = mem_handler_->MallocDevice(mem_size);
+      MS_EXCEPTION_IF_NULL(device_ptr);
       return device_ptr;
     }
+    mem_can_swap.push({key, device_mem_size});
     ++iter;
   }

+  // Try swap out memory block from big to small
+  while (!mem_can_swap.empty()) {
+    const auto &max_mem_in_device = mem_can_swap.top();
+    const auto key = max_mem_in_device.first;
+    const auto swap_mem_size = max_mem_in_device.second;
+    auto swap_device_ptr = mem_result_[key];
+    MS_EXCEPTION_IF_NULL(swap_device_ptr);
+    mem_can_swap.pop();
+    SwapOutAndFreeDevice(key, swap_device_ptr, swap_mem_size, stream);
+    device_ptr = mem_handler_->MallocDevice(mem_size);
+    if (device_ptr != nullptr) {
+      return device_ptr;
+    }
+  }
+
+  return nullptr;
+}
+
+std::vector<void *> MemScheduler::MallocContinuousMem(size_t total_size, const std::vector<size_t> &size_list,
+                                                      void *stream) {
+  const auto &no_reuse_key = step_keys_[current_step_];
+  auto device_ptr_list = mem_handler_->MallocContinuousMemFromMemPool(size_list);
+  if (!device_ptr_list.empty() || !optimized_) {
+    return device_ptr_list;
+  }
+  // Find memory block big enough in mem_result_, except continuous mem and memory blocks used in this step.
+  auto iter = mem_result_.begin();
+  using KeySizePair = std::pair<const void *, size_t>;
+  auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
+  std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_swap(less);
+  while (iter != mem_result_.end()) {
+    const auto key = iter->first;
+    if (no_reuse_key.count(key) != 0 || continuous_mem_key_.count(key) != 0) {
+      ++iter;
+      continue;
+    }
+    const auto device_mem_size = GetMemSize(key);
+    if (device_mem_size >= total_size) {
+      SwapOutAndFreeDevice(key, iter->second, device_mem_size, stream);
+      device_ptr_list = mem_handler_->MallocContinuousMemFromMemPool(size_list);
+      if (device_ptr_list.empty()) {
+        MS_LOG(EXCEPTION) << "device_ptr_list empty";
+      }
+      return device_ptr_list;
+    }
+    mem_can_swap.push({key, device_mem_size});
+    ++iter;
+  }
+
+  // Try swap out memory block from big to small
   while (!mem_can_swap.empty()) {
     const auto &max_mem_in_device = mem_can_swap.top();
     mem_can_swap.pop();
@@ -379,12 +516,13 @@ void *MemScheduler::MallocDevice(size_t mem_size, void *stream)
     auto swap_device_ptr = mem_result_[key];
     MS_EXCEPTION_IF_NULL(swap_device_ptr);
     SwapOutAndFreeDevice(key, swap_device_ptr, swap_mem_size, stream);
-    device_ptr = mem_handler_->MallocDevice(mem_size);
-    if (device_ptr != nullptr) {
-      return device_ptr;
+    device_ptr_list = mem_handler_->MallocContinuousMemFromMemPool(size_list);
+    if (!device_ptr_list.empty()) {
+      return device_ptr_list;
     }
   }
-  return nullptr;
+
+  return device_ptr_list;
 }

 void MemScheduler::SwapOutAndFreeDevice(const void *key, void *device_ptr, size_t mem_size, void *stream) {
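The victim-selection idiom above, a max-heap of (key, size) pairs popped from big to small until the allocation fits, is the heart of both MallocDevice and the new MallocContinuousMem. A self-contained sketch of just that selection step:

#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

using KeySizePair = std::pair<const void *, std::size_t>;

// Returns the key of the largest candidate block, or nullptr if none:
// the same big-to-small order the scheduler uses when hunting for swap victims.
const void *PickLargestVictim(const std::vector<KeySizePair> &candidates) {
  auto less = [](const KeySizePair &a, const KeySizePair &b) { return a.second < b.second; };
  std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> heap(less);
  for (const auto &candidate : candidates) {
    heap.push(candidate);
  }
  return heap.empty() ? nullptr : heap.top().first;
}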
@@ -20,22 +20,36 @@
 #include <map>
 #include <set>
 #include <memory>
+#include <queue>
 #include <utility>
 #include "runtime/device/memory_offload_strategy.h"
+#include "runtime/device/memory_manager.h"

 namespace mindspore {
 namespace device {
 class MemHandler {
  public:
-  MemHandler() = default;
-  virtual ~MemHandler() = default;
-  virtual size_t GetAvailableMemSize() = 0;
-  virtual void *MallocDevice(size_t mem_size) = 0;
-  virtual void FreeDevice(void *ptr) = 0;
-  virtual void *MallocHost(size_t mem_size) = 0;
-  virtual void FreeHost(void *ptr) = 0;
-  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
-  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
+  explicit MemHandler(std::shared_ptr<MemoryManager> memory_manager) : memory_manager_(memory_manager) {}
+  ~MemHandler() = default;
+  size_t GetAvailableMemSize() { return memory_manager_->GetAvailableMemSize(); }
+  void *MallocDevice(size_t mem_size) { return memory_manager_->MallocMemFromMemPool(mem_size, false); }
+  void FreeDevice(void *ptr) { memory_manager_->FreeMemFromMemPool(ptr); }
+  void *MallocHost(size_t mem_size);
+  void FreeHost(void *ptr);
+  void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
+    memory_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream);
+  }
+  void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
+    memory_manager_->SwapOut(device_ptr, host_ptr, mem_size, stream);
+  }
+  std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list) {
+    return memory_manager_->MallocContinuousMemFromMemPool(size_list);
+  }
+
+ private:
+  std::shared_ptr<MemoryManager> memory_manager_;
+  std::map<size_t, std::queue<void *>> cached_host_mem_;
+  std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_;
 };

 class MemScheduler {
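MemHandler changes here from an abstract interface into a concrete wrapper that delegates to an injected MemoryManager, which is why the unit tests below now stub the manager instead of subclassing the handler. A toy version of the same delegation pattern, using stand-in types:

#include <cstddef>
#include <cstdlib>
#include <memory>
#include <utility>

// Toy illustration: a concrete handler delegating to an injected backend,
// rather than forcing every user to subclass the handler itself.
struct Backend {  // stand-in for MemoryManager
  void *Alloc(std::size_t n) { return std::malloc(n); }
  void Free(void *p) { std::free(p); }
};

class Handler {  // stand-in for the new concrete MemHandler
 public:
  explicit Handler(std::shared_ptr<Backend> b) : backend_(std::move(b)) {}
  void *MallocDevice(std::size_t n) { return backend_->Alloc(n); }
  void FreeDevice(void *p) { backend_->Free(p); }

 private:
  std::shared_ptr<Backend> backend_;
};

int main() {
  Handler handler(std::make_shared<Backend>());
  void *dev = handler.MallocDevice(1024);
  handler.FreeDevice(dev);
  return 0;
}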
@@ -88,6 +102,10 @@ class MemScheduler {
   void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }

+  void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
+                            const std::vector<size_t> &align_size_list,
+                            const std::vector<const void *> &address_key_list);
+
  private:
   void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);

@@ -99,6 +117,8 @@ class MemScheduler {
   void *MallocDevice(size_t mem_size, void *stream);

+  std::vector<void *> MallocContinuousMem(size_t total_size, const std::vector<size_t> &size_list, void *stream);
+
   void SwapOutAndFreeDevice(const void *key, void *device_ptr, size_t mem_size, void *stream);

   size_t GetMemSize(const void *key);

@@ -115,8 +135,10 @@ class MemScheduler {
   bool PreComputeGet(const std::shared_ptr<MemEvent> &event, void *stream);

+  void *MallocContinuousMem(const std::shared_ptr<MemEvent> &event, void *stream);
+
   std::map<const void *, MemPriority> mem_priority_;
-  std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
+  std::map<const void *, MemEventPtrList> mem_events_;
   std::set<const void *> manual_offload_keys_;
   std::vector<std::set<const void *>> step_keys_;
   std::map<const void *, void *> mem_result_;

@@ -134,6 +156,9 @@ class MemScheduler {
   bool updated_{false};
   std::shared_ptr<MemHandler> mem_handler_{nullptr};
   std::shared_ptr<MemOffloadStrategy> strategy_{nullptr};
+  std::shared_ptr<ContinuousMemInfoHelper> continuous_mem_info_helper_{std::make_shared<ContinuousMemInfoHelper>()};
+  std::set<std::shared_ptr<ContinuousMemInfo>> cur_step_allocated_continuous_mem_;
+  std::set<const void *> continuous_mem_key_;
 };

 class MemSchedulerManager {
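The new AddContinuousMemInfo entry point registers a block's layout up front: input/output direction, owning compute step, total size, per-key aligned sizes, and the keys themselves. A toy mirror of that record (field names are assumptions read off the declaration above), checking the invariant the new test at the end of the diff relies on, namely the aligned sizes summing to the total:

#include <cstddef>
#include <numeric>
#include <vector>

// Field names are assumptions based on the declaration; not MindSpore's definition.
struct ContinuousMemInfoSketch {
  bool is_input;
  std::size_t compute_index;
  std::size_t total_size;
  std::vector<std::size_t> align_size_list;
  std::vector<const void *> address_key_list;
};

int main() {
  std::vector<unsigned char> keys(3, 0);
  // Mirrors scheduler->AddContinuousMemInfo(true, 2, 3, {1, 1, 1}, {...}) in the new test.
  ContinuousMemInfoSketch info{true, 2, 3, {1, 1, 1}, {&keys[0], &keys[1], &keys[2]}};
  const auto sum = std::accumulate(info.align_size_list.begin(), info.align_size_list.end(), std::size_t{0});
  return sum == info.total_size ? 0 : 1;  // consistent when the aligned sizes sum to the total
}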
@@ -21,16 +21,17 @@
 namespace mindspore::device {
 constexpr size_t kDeviceMemSize = 5;
 constexpr size_t kMaxVirtualCount = 1024;
-class MemHandlerImpl : public MemHandler {
+class MemoryManagerStub : public MemoryManager {
  public:
-  MemHandlerImpl() {
+  MemoryManagerStub() {
     device_mem_.resize(kMaxVirtualCount, 0);
-    host_mem_.resize(kMaxVirtualCount, 1);
   }
+  void Initialize() override {}
+  void Finalize() override {}
+
   size_t GetAvailableMemSize() override { return kDeviceMemSize; }

-  void *MallocDevice(size_t mem_size) override {
+  void *MallocMemFromMemPool(size_t mem_size, bool useless = false) override {
     if (device_virtual_count_ >= kDeviceMemSize) {
       return nullptr;
     }
@@ -40,7 +41,7 @@ class MemHandlerImpl : public MemHandler {
     return ret;
   }

-  void FreeDevice(void *ptr) override {
+  void FreeMemFromMemPool(void *ptr) override {
     --device_virtual_count_;
     auto iter = device_mem_size_.find(ptr);
     if (iter != device_mem_size_.end()) {
@@ -48,31 +49,32 @@ class MemHandlerImpl : public MemHandler {
     }
   }

-  void *MallocHost(size_t mem_size) override {
-    auto ret = host_mem_.data() + host_virtual_count_;
-    ++host_virtual_count_;
-    host_mem_size_.emplace(ret, mem_size);
-    return ret;
-  }
-
-  void FreeHost(void *ptr) override {
-    auto iter = host_mem_size_.find(ptr);
-    if (iter != host_mem_size_.end()) {
-      host_mem_size_.erase(iter);
+  std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list) override {
+    const size_t total_size = std::accumulate(size_list.begin(), size_list.end(), 0);
+    std::vector<void *> ret;
+    if (device_virtual_count_ + total_size > kDeviceMemSize) {
+      return ret;
     }
+    for (const auto &size : size_list) {
+      auto ptr = device_mem_.data() + device_virtual_count_;
+      device_mem_size_.emplace(ptr, size);
+      ret.emplace_back(ptr);
+      ++device_virtual_count_;
+    }
+    return ret;
   }

   void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}

   void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}

+ protected:
+  uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) { return nullptr; }
+
  private:
   std::vector<uint8_t> device_mem_;
-  std::vector<uint8_t> host_mem_;
   size_t device_virtual_count_{0};
-  size_t host_virtual_count_{0};
   std::map<void *, size_t> device_mem_size_;
-  std::map<void *, size_t> host_mem_size_;
 };

 class TestMemScheduler : public UT::Common {
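One nit in the stub's MallocContinuousMemFromMemPool above: std::accumulate is declared in <numeric>, and an int literal as the initial value makes the fold accumulate in int before converting back to size_t. Harmless at these toy sizes; the size_t-safe form would be:

#include <cstddef>
#include <numeric>
#include <vector>

std::size_t TotalSize(const std::vector<std::size_t> &size_list) {
  // A size_t initial value keeps the whole fold in size_t, with no narrowing.
  return std::accumulate(size_list.begin(), size_list.end(), std::size_t{0});
}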
@@ -136,14 +138,14 @@ TEST_F(TestMemScheduler, test_mem_scheduler_manager) {

 /// Feature: MemScheduler
 /// Description: Test MemScheduler interface
-/// Expectation: MemScheduler GetOrMalloc return valid ptr
-TEST_F(TestMemScheduler, test_mem_scheduler) {
+/// Expectation: MemScheduler GetOrMalloc return valid ptr for continuous mem
+TEST_F(TestMemScheduler, test_mem_scheduler_with_continuous_mem) {
   MemSchedulerManager mem_scheduler_manager;
   auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
   ASSERT_NE(scheduler, nullptr);
   auto need_record = scheduler->need_record_event();
   ASSERT_EQ(need_record, true);
-  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
+  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandler>(std::make_shared<MemoryManagerStub>());
   ASSERT_NE(mem_handler, nullptr);
   scheduler->SetMemHandler(mem_handler);

@@ -190,7 +192,7 @@ TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
   ASSERT_NE(scheduler, nullptr);
   auto need_record = scheduler->need_record_event();
   ASSERT_EQ(need_record, true);
-  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
+  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandler>(std::make_shared<MemoryManagerStub>());
   ASSERT_NE(mem_handler, nullptr);
   scheduler->SetMemHandler(mem_handler);

@@ -232,4 +234,53 @@ TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
   // run
   Run(scheduler);
 }
+
+/// Feature: MemScheduler
+/// Description: Test MemScheduler interface
+/// Expectation: MemScheduler GetOrMalloc return valid ptr
+TEST_F(TestMemScheduler, test_mem_scheduler) {
+  MemSchedulerManager mem_scheduler_manager;
+  auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
+  ASSERT_NE(scheduler, nullptr);
+  auto need_record = scheduler->need_record_event();
+  ASSERT_EQ(need_record, true);
+  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandler>(std::make_shared<MemoryManagerStub>());
+  ASSERT_NE(mem_handler, nullptr);
+  scheduler->SetMemHandler(mem_handler);
+
+  // input data
+  used_tensor_num_ = 8;
+  total_step_ = 8;
+  std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
+  std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
+  std::vector<size_t> init_tensors = {0, 2, 4};
+  // 8 step tensor usage:
+  //
+  // 0-----0-----0
+  // 1--1--1
+  // 2-----2--------2
+  // 3-----------3
+  // 4--------------4
+  // 6--------------6
+  // 7--------7
+  //
+  std::vector<std::vector<size_t>> step_used_tensors = {{0, 2, 6}, {1, 7}, {0, 1, 2, 3, 4}, {1},
+                                                        {0, 7},    {2, 6}, {3},             {4}};
+  tensor_keys_.swap(tensor_keys);
+  tensor_datas_.swap(tensor_datas);
+  init_tensors_.swap(init_tensors);
+  step_used_tensors_.swap(step_used_tensors);
+  scheduler->SetTotalStep(total_step_);
+
+  // record
+  Record(scheduler);
+  // Add continuous memory info
+  scheduler->AddContinuousMemInfo(true, 2, 3, {1, 1, 1},
+                                  {tensor_keys_.data(), tensor_keys_.data() + 1, tensor_keys_.data() + 2});
+  scheduler->AddContinuousMemInfo(false, 2, 2, {1, 1}, {tensor_keys_.data() + 3, tensor_keys_.data() + 4});
+  // optimize
+  scheduler->Optimize();
+  // run
+  Run(scheduler);
+}
 } // namespace mindspore::device