set offload node

This commit is contained in:
kswang 2021-12-02 20:55:22 +08:00
parent 6924866a87
commit 391a06aad1
8 changed files with 226 additions and 71 deletions

View File

@ -1449,6 +1449,13 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
return ret; return ret;
} }
AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info); AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
auto cnode = kernel->cast<CNodePtr>();
if (mock && AnfAlgo::HasNodeAttr(kAttrOffload, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
mem_scheduler->SetOffload(device_address);
}
}
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) { } else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr(); kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr(); kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
@ -1462,15 +1469,15 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
} else { } else {
ret = kernel_mod->Launch(kernel_launch_info, stream); ret = kernel_mod->Launch(kernel_launch_info, stream);
} }
if (!ret) {
return ret;
}
} }
if (mem_scheduler != nullptr) { if (mem_scheduler != nullptr) {
if (!mock) { if (!mock) {
SyncNodeOutputTensors(mem_scheduler, graph, kernel); SyncNodeOutputTensors(mem_scheduler, graph, kernel);
} }
ret = mem_scheduler->PostCompute(stream); ret = mem_scheduler->PostCompute(stream);
if (!ret) {
return ret;
}
} }
return ret; return ret;
} }
@ -1483,7 +1490,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
if (UseMemScheduler()) { if (UseMemScheduler()) {
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id()); mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler); MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->ResetCurrentStep(); mem_scheduler->Reset();
mem_scheduler->Update(); mem_scheduler->Update();
InitGraphInputTensors(mem_scheduler, graph); InitGraphInputTensors(mem_scheduler, graph);
} }
@ -1594,8 +1601,8 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
(void)LaunchKernelMod(graph, true); (void)LaunchKernelMod(graph, true);
mem_scheduler->set_need_record_event(false); mem_scheduler->set_need_record_event(false);
} }
mem_scheduler->Optimize(); auto ret = mem_scheduler->Optimize();
if (!mem_scheduler->optimized()) { if (!ret) {
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit."; MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
} }
} }

View File

@ -16,7 +16,6 @@
#include "runtime/device/memory_offload_strategy.h" #include "runtime/device/memory_offload_strategy.h"
#include <vector> #include <vector>
#include <map> #include <map>
#include <set>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -100,7 +99,7 @@ void MemOffloadStrategy::CheckMemSize() {
<< min_mem_needed_; << min_mem_needed_;
} }
if (mem_size_ < mem_used_without_swap_) { if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
need_swap_ = true; need_swap_ = true;
} }
@ -141,6 +140,18 @@ void MemOffloadStrategy::GenEventSpan() {
void MemOffloadStrategy::GenSwapEventSet() { void MemOffloadStrategy::GenSwapEventSet() {
swap_events_.clear(); swap_events_.clear();
// manual offload strategy
if (!manual_offload_keys_.empty()) {
for (const auto &iter : event_span_) {
auto &event = iter.second.first;
if (manual_offload_keys_.find(event->key) != manual_offload_keys_.end()) {
(void)swap_events_.emplace(event);
}
}
return;
}
// greedy span filter
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end()); std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (const auto &iter : event_span_) { for (const auto &iter : event_span_) {
auto span = iter.second.second; auto span = iter.second.second;
@ -179,9 +190,6 @@ void MemOffloadStrategy::GenComputeMemEvents() {
post_compute_events_.resize(total_step_); post_compute_events_.resize(total_step_);
for (auto &item : mem_events_) { for (auto &item : mem_events_) {
auto &mem_events = item.second; auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
// No need to generate events for memory that has only one event, which means it is never used by any kernel. // No need to generate events for memory that has only one event, which means it is never used by any kernel.
if (mem_events.size() <= 1) { if (mem_events.size() <= 1) {
continue; continue;
@ -211,11 +219,14 @@ void MemOffloadStrategy::GenComputeMemEvents() {
swap_out_event->key = item.first; swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size; swap_out_event->mem_size = first_event->mem_size;
post_compute_events_[pre_index].emplace_back(swap_out_event); post_compute_events_[pre_index].emplace_back(swap_out_event);
// avoid swap-in-event follow init-event
if (first_event->type != kInit || i != 1) {
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index); auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
swap_in_event->key = item.first; swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size; swap_in_event->mem_size = first_event->mem_size;
(void)pre_compute_events_[event->index].emplace_back(swap_in_event); (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
} }
}
if (event->index < pre_compute_events_.size()) { if (event->index < pre_compute_events_.size()) {
(void)pre_compute_events_[event->index].emplace_back(event); (void)pre_compute_events_[event->index].emplace_back(event);
} }

View File

@ -41,8 +41,11 @@ class MemOffloadStrategy {
public: public:
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority, MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events, const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
size_t total_step) const std::set<const void *> &manual_offload_keys, size_t total_step)
: mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {} : mem_priority_(mem_priority),
mem_events_(mem_events),
manual_offload_keys_(manual_offload_keys),
total_step_(total_step) {}
virtual ~MemOffloadStrategy() = default; virtual ~MemOffloadStrategy() = default;
@ -58,18 +61,24 @@ class MemOffloadStrategy {
bool need_swap() const { return need_swap_; } bool need_swap() const { return need_swap_; }
private:
bool IsHighPriorityMem(const void *key); bool IsHighPriorityMem(const void *key);
private:
void CountMemUsage(); void CountMemUsage();
void CheckMemSize(); void CheckMemSize();
void GenEventSpan(); void GenEventSpan();
void GenSwapEventSet(); void GenSwapEventSet();
void GenComputeMemEvents(); void GenComputeMemEvents();
void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event); void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
const std::map<const void *, MemPriority> &mem_priority_; const std::map<const void *, MemPriority> &mem_priority_;
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_; const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
const std::set<const void *> &manual_offload_keys_;
const size_t total_step_; const size_t total_step_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_; std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_; std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;

View File

@ -16,12 +16,12 @@
#include "runtime/device/memory_scheduler.h" #include "runtime/device/memory_scheduler.h"
#include <algorithm> #include <algorithm>
#include "utils/log_adapter.h"
#ifdef _MSC_VER #ifdef _MSC_VER
#include <time.h> #include <time.h>
#else #else
#include <sys/time.h> #include <sys/time.h>
#endif #endif
#include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
@ -51,7 +51,7 @@ void MemScheduler::Clear() {
high_priority_device_ptr_.clear(); high_priority_device_ptr_.clear();
} }
void MemScheduler::ClearTempMem() { void MemScheduler::ClearAllocatedMem() {
if (mem_handler_ == nullptr) { if (mem_handler_ == nullptr) {
return; return;
} }
@ -72,8 +72,6 @@ void MemScheduler::ClearTempMem() {
swap_host_ptr_.clear(); swap_host_ptr_.clear();
} }
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) { void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
if (key == nullptr) { if (key == nullptr) {
return; return;
@ -184,7 +182,7 @@ bool MemScheduler::PostCompute(void *stream) {
return true; return true;
} }
if (record_compute_time_ && !updated_) { if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_; compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
} }
@ -227,8 +225,12 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
MS_EXCEPTION_IF_NULL(mem_handler_); MS_EXCEPTION_IF_NULL(mem_handler_);
if (strategy_ == nullptr) { if (strategy_ == nullptr) {
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_); strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_);
if (manual_offload_keys_.empty()) {
compute_time_.resize(total_step_); compute_time_.resize(total_step_);
} else {
updated_ = true;
}
} }
auto available_mem_size = mem_handler_->GetAvailableMemSize(); auto available_mem_size = mem_handler_->GetAvailableMemSize();
@ -237,7 +239,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
strategy_->Execute(); strategy_->Execute();
} }
void MemScheduler::Optimize() { bool MemScheduler::Optimize() {
AdjustFirstEventIndex(); AdjustFirstEventIndex();
float mem_used_factor = kMaxMemReuseFactor; float mem_used_factor = kMaxMemReuseFactor;
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) { while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
@ -265,10 +267,11 @@ void MemScheduler::Optimize() {
if (ret) { if (ret) {
optimized_ = true; optimized_ = true;
} else { } else {
ClearTempMem(); ClearAllocatedMem();
mem_used_factor -= kRetryFactor; mem_used_factor -= kRetryFactor;
} }
} }
return optimized_;
} }
void MemScheduler::AdjustFirstEventIndex() { void MemScheduler::AdjustFirstEventIndex() {

View File

@ -45,8 +45,6 @@ class MemScheduler {
void set_need_record_event(bool flag) { need_record_event_ = flag; } void set_need_record_event(bool flag) { need_record_event_ = flag; }
bool optimized() const { return optimized_; }
void Update(); void Update();
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; } void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
@ -60,19 +58,19 @@ class MemScheduler {
step_events_.resize(total_step_); step_events_.resize(total_step_);
} }
void ResetCurrentStep() { current_step_ = 0; } void Reset() { current_step_ = 0; }
bool PreCompute(void *stream); bool PreCompute(void *stream);
bool PostCompute(void *stream); bool PostCompute(void *stream);
void Optimize(); bool Optimize();
void Clear(); void Clear();
void ClearTempMem(); void ClearAllocatedMem();
void SetMemPriority(const void *key, MemPriority priority); void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
private: private:
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0); void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
@ -83,6 +81,7 @@ class MemScheduler {
std::map<const void *, MemPriority> mem_priority_; std::map<const void *, MemPriority> mem_priority_;
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_; std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
std::set<const void *> manual_offload_keys_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_; std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
std::map<const void *, void *> mem_result_; std::map<const void *, void *> mem_result_;
std::map<const void *, void *> init_host_ptr_; std::map<const void *, void *> init_host_ptr_;

View File

@ -349,6 +349,7 @@ constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
constexpr auto kIsBackendCast = "is_backed_cast"; constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names"; constexpr auto kAttrOutputNames = "output_names";
constexpr auto kAttrAsync = "async"; constexpr auto kAttrAsync = "async";
constexpr auto kAttrOffload = "offload";
constexpr auto kAttrVisited = "visited"; constexpr auto kAttrVisited = "visited";
constexpr auto kAttrShape = "shape"; constexpr auto kAttrShape = "shape";
constexpr auto kAttrMomentum = "momentum"; constexpr auto kAttrMomentum = "momentum";

View File

@ -123,3 +123,32 @@ def test_lenet():
diff = res.asnumpy()[0] - 2.3025851 diff = res.asnumpy()[0] - 2.3025851
assert np.all(diff < 1.e-6) assert np.all(diff < 1.e-6)
os.environ['ENABLE_MEM_SCHEDULER'] = '0' os.environ['ENABLE_MEM_SCHEDULER'] = '0'
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_manual_offload():
'''
Feature: MemScheduler
Description: Test set offload strategy
Expectation: Run lenet success
'''
os.environ['ENABLE_MEM_SCHEDULER'] = '1'
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
net.relu.add_prim_attr("Offload", True)
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
diff = res.asnumpy()[0] - 2.3025851
assert np.all(diff < 1.e-6)
os.environ['ENABLE_MEM_SCHEDULER'] = '0'

View File

@ -19,47 +19,58 @@
#include "common/common_test.h" #include "common/common_test.h"
#include "runtime/device/memory_scheduler.h" #include "runtime/device/memory_scheduler.h"
namespace mindspore::device { namespace mindspore::device {
constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024; constexpr size_t kDeviceMemSize = 5;
constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024; constexpr size_t kMaxVirtualCount = 1024;
class MemHandlerImpl : public MemHandler { class MemHandlerImpl : public MemHandler {
public: public:
MemHandlerImpl() { MemHandlerImpl() {
device_mem_.resize(kMaxVirtualCount, 0); device_mem_.resize(kMaxVirtualCount, 0);
host_mem_.resize(kMaxVirtualCount, 1); host_mem_.resize(kMaxVirtualCount, 1);
} }
size_t GetAvailableMemSize() override { return kDeviceMemSize; } size_t GetAvailableMemSize() override { return kDeviceMemSize; }
void *MallocDevice(size_t mem_size) override { void *MallocDevice(size_t mem_size) override {
if (device_virtual_count_ >= kDeviceMemSize) {
return nullptr;
}
auto ret = device_mem_.data() + device_virtual_count_; auto ret = device_mem_.data() + device_virtual_count_;
++device_virtual_count_; ++device_virtual_count_;
device_mem_size_.emplace(ret, mem_size); device_mem_size_.emplace(ret, mem_size);
return ret; return ret;
} }
void FreeDevice(void *ptr) override { void FreeDevice(void *ptr) override {
--device_virtual_count_;
auto iter = device_mem_size_.find(ptr); auto iter = device_mem_size_.find(ptr);
if (iter != device_mem_size_.end()) { if (iter != device_mem_size_.end()) {
device_mem_size_.erase(iter); device_mem_size_.erase(iter);
} }
} }
void *MallocHost(size_t mem_size) override { void *MallocHost(size_t mem_size) override {
auto ret = host_mem_.data() + host_virtual_count_; auto ret = host_mem_.data() + host_virtual_count_;
++host_virtual_count_; ++host_virtual_count_;
host_mem_size_.emplace(ret, mem_size); host_mem_size_.emplace(ret, mem_size);
return ret; return ret;
} }
void FreeHost(void *ptr) override { void FreeHost(void *ptr) override {
auto iter = host_mem_size_.find(ptr); auto iter = host_mem_size_.find(ptr);
if (iter != host_mem_size_.end()) { if (iter != host_mem_size_.end()) {
host_mem_size_.erase(iter); host_mem_size_.erase(iter);
} }
} }
void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {} void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {} void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
private: private:
std::vector<uint8_t> device_mem_; std::vector<uint8_t> device_mem_;
std::vector<uint8_t> host_mem_; std::vector<uint8_t> host_mem_;
size_t device_virtual_count_; size_t device_virtual_count_{0};
size_t host_virtual_count_; size_t host_virtual_count_{0};
std::map<void *, size_t> device_mem_size_; std::map<void *, size_t> device_mem_size_;
std::map<void *, size_t> host_mem_size_; std::map<void *, size_t> host_mem_size_;
}; };
@ -67,6 +78,47 @@ class MemHandlerImpl : public MemHandler {
class TestMemScheduler : public UT::Common { class TestMemScheduler : public UT::Common {
public: public:
TestMemScheduler() {} TestMemScheduler() {}
protected:
size_t used_tensor_num_{1};
size_t total_step_{1};
std::vector<uint8_t> tensor_keys_;
std::vector<uint8_t> tensor_datas_;
std::vector<size_t> init_tensors_;
std::vector<std::vector<size_t>> step_used_tensors_;
void Record(const std::shared_ptr<MemScheduler> &scheduler) {
void *stream = nullptr;
for (auto index : init_tensors_) {
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < total_step_; ++i) {
auto &tensors = step_used_tensors_[i];
for (auto j : tensors) {
scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
}
scheduler->PostCompute(stream);
}
scheduler->set_need_record_event(false);
}
void Run(const std::shared_ptr<MemScheduler> &scheduler) {
void *stream = nullptr;
scheduler->Reset();
scheduler->Update();
for (auto index : init_tensors_) {
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < total_step_; ++i) {
scheduler->PreCompute(stream);
auto &tensors = step_used_tensors_[i];
for (auto j : tensors) {
auto addr = scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
ASSERT_NE(addr, nullptr);
}
scheduler->PostCompute(stream);
}
}
}; };
/// Feature: MemSchedulerManager /// Feature: MemSchedulerManager
@ -91,49 +143,93 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
ASSERT_NE(scheduler, nullptr); ASSERT_NE(scheduler, nullptr);
auto need_record = scheduler->need_record_event(); auto need_record = scheduler->need_record_event();
ASSERT_EQ(need_record, true); ASSERT_EQ(need_record, true);
auto optimized = scheduler->optimized();
ASSERT_EQ(optimized, false);
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>(); std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
ASSERT_NE(mem_handler, nullptr); ASSERT_NE(mem_handler, nullptr);
scheduler->SetMemHandler(mem_handler); scheduler->SetMemHandler(mem_handler);
constexpr size_t kUsedTensors = 10; // input data
constexpr size_t kTimeSlice = 7; used_tensor_num_ = 10;
std::vector<uint8_t> tensor_keys(kUsedTensors, 0); total_step_ = 8;
std::vector<uint8_t> tensor_datas(kUsedTensors, 0); std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
std::vector<size_t> init_tensors = {0, 2, 4}; std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}}; // 8 step tensor usage
void *stream = nullptr; //
scheduler->SetTotalStep(kTimeSlice); // 0
// record // 1 1-----------------1
for (auto index : init_tensors) { // 2--------------2
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh); // 3 3--------3
} // 4-----4
for (size_t i = 0; i < kTimeSlice; ++i) { // 5 5
auto &tensors = step_tensors[i]; // 6 6
for (auto j : tensors) { // 7 7
scheduler->GetOrMalloc(tensor_keys.data() + j, 1); // 8 8
} // 9 9
scheduler->PostCompute(stream); std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
} {4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
scheduler->set_need_record_event(false); tensor_keys_.swap(tensor_keys);
tensor_datas_.swap(tensor_datas);
init_tensors_.swap(init_tensors);
step_used_tensors_.swap(step_used_tensors);
scheduler->SetTotalStep(total_step_);
// record
Record(scheduler);
// optimize // optimize
scheduler->Optimize(); scheduler->Optimize();
// run // run
scheduler->ResetCurrentStep(); Run(scheduler);
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
} }
for (size_t i = 0; i < kTimeSlice; ++i) {
scheduler->PreCompute(stream); /// Feature: MemScheduler
auto &tensors = step_tensors[i]; /// Description: Test MemScheduler interface
for (auto j : tensors) { /// Expectation: MemScheduler GetOrMalloc return valid ptr
auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1); TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
ASSERT_NE(addr, nullptr); MemSchedulerManager mem_scheduler_manager;
} auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
scheduler->PostCompute(stream); ASSERT_NE(scheduler, nullptr);
auto need_record = scheduler->need_record_event();
ASSERT_EQ(need_record, true);
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
ASSERT_NE(mem_handler, nullptr);
scheduler->SetMemHandler(mem_handler);
// input data
used_tensor_num_ = 10;
total_step_ = 8;
std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<size_t> offload_tensor = {1, 2, 3};
// 8 step tensor usage
//
// 0
// 1 1-----------------1
// 2--------------2
// 3 3--------3
// 4-----4
// 5 5
// 6 6
// 7 7
// 8 8
// 9 9
std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
{4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
tensor_keys_.swap(tensor_keys);
tensor_datas_.swap(tensor_datas);
init_tensors_.swap(init_tensors);
step_used_tensors_.swap(step_used_tensors);
scheduler->SetTotalStep(total_step_);
// set offload key
for (auto index : offload_tensor) {
scheduler->SetOffload(tensor_keys_.data() + index);
} }
// record
Record(scheduler);
// optimize
scheduler->Optimize();
// run
Run(scheduler);
} }
} // namespace mindspore::device } // namespace mindspore::device