forked from mindspore-Ecosystem/mindspore

commit 391a06aad1 (parent 6924866a87)

set offload node

Kernels whose CNode carries the `offload` attribute now register their output device
addresses with the MemScheduler via SetOffload(); MemOffloadStrategy generates swap
events for exactly those manually offloaded keys, and MemScheduler::Optimize() reports
failure so the runtime can raise an error when the graph does not fit in device memory.
@@ -1449,6 +1449,13 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
return ret;
}
AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
auto cnode = kernel->cast<CNodePtr>();
if (mock && AnfAlgo::HasNodeAttr(kAttrOffload, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
mem_scheduler->SetOffload(device_address);
}
}
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();

@@ -1462,15 +1469,15 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
} else {
ret = kernel_mod->Launch(kernel_launch_info, stream);
}
if (!ret) {
return ret;
}
}
if (mem_scheduler != nullptr) {
if (!mock) {
SyncNodeOutputTensors(mem_scheduler, graph, kernel);
}
ret = mem_scheduler->PostCompute(stream);
if (!ret) {
return ret;
}
}
return ret;
}

@@ -1483,7 +1490,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
if (UseMemScheduler()) {
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->ResetCurrentStep();
mem_scheduler->Reset();
mem_scheduler->Update();
InitGraphInputTensors(mem_scheduler, graph);
}

@@ -1594,8 +1601,8 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
(void)LaunchKernelMod(graph, true);
mem_scheduler->set_need_record_event(false);
}
mem_scheduler->Optimize();
if (!mem_scheduler->optimized()) {
auto ret = mem_scheduler->Optimize();
if (!ret) {
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
}
}

@@ -16,7 +16,6 @@
#include "runtime/device/memory_offload_strategy.h"
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>
#include "utils/log_adapter.h"

@@ -100,7 +99,7 @@ void MemOffloadStrategy::CheckMemSize() {
<< min_mem_needed_;
}
if (mem_size_ < mem_used_without_swap_) {
if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
need_swap_ = true;
}

@@ -141,6 +140,18 @@ void MemOffloadStrategy::GenEventSpan() {
void MemOffloadStrategy::GenSwapEventSet() {
swap_events_.clear();
// manual offload strategy
if (!manual_offload_keys_.empty()) {
for (const auto &iter : event_span_) {
auto &event = iter.second.first;
if (manual_offload_keys_.find(event->key) != manual_offload_keys_.end()) {
(void)swap_events_.emplace(event);
}
}
return;
}
// greedy span filter
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (const auto &iter : event_span_) {
auto span = iter.second.second;

@@ -179,9 +190,6 @@ void MemOffloadStrategy::GenComputeMemEvents() {
post_compute_events_.resize(total_step_);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
// No need to generate events for memory that has only one event, which means it is never used by any kernel.
if (mem_events.size() <= 1) {
continue;

@@ -211,10 +219,13 @@ void MemOffloadStrategy::GenComputeMemEvents() {
swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size;
post_compute_events_[pre_index].emplace_back(swap_out_event);
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
// avoid swap-in-event follow init-event
if (first_event->type != kInit || i != 1) {
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
}
}
if (event->index < pre_compute_events_.size()) {
(void)pre_compute_events_[event->index].emplace_back(event);

@@ -41,8 +41,11 @@ class MemOffloadStrategy {
public:
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
size_t total_step)
: mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {}
const std::set<const void *> &manual_offload_keys, size_t total_step)
: mem_priority_(mem_priority),
mem_events_(mem_events),
manual_offload_keys_(manual_offload_keys),
total_step_(total_step) {}
virtual ~MemOffloadStrategy() = default;

@@ -58,18 +61,24 @@ class MemOffloadStrategy {
bool need_swap() const { return need_swap_; }
private:
bool IsHighPriorityMem(const void *key);
private:
void CountMemUsage();
void CheckMemSize();
void GenEventSpan();
void GenSwapEventSet();
void GenComputeMemEvents();
void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
const std::map<const void *, MemPriority> &mem_priority_;
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
const std::set<const void *> &manual_offload_keys_;
const size_t total_step_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
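
Aside (not part of the diff): a minimal sketch of how the extended constructor above is
intended to be used, relying only on the types and signatures shown in this header. The
function name is illustrative, and the sketch assumes MemOffloadStrategy lives in
mindspore::device alongside MemScheduler; note that the strategy stores the containers by
const reference, so the owner must keep them alive for the strategy's lifetime.

// Sketch only, not repo code: constructing MemOffloadStrategy with manual offload keys.
#include <map>
#include <memory>
#include <set>
#include <vector>
#include "runtime/device/memory_offload_strategy.h"

namespace mindspore::device {
void BuildManualOffloadStrategy() {
  std::map<const void *, MemPriority> mem_priority;
  std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events;
  std::set<const void *> manual_offload_keys;  // keys registered through MemScheduler::SetOffload()
  constexpr size_t total_step = 8;
  // The new fourth parameter threads the manually offloaded keys into the strategy.
  auto strategy = std::make_shared<MemOffloadStrategy>(mem_priority, mem_events, manual_offload_keys, total_step);
  // With manual_offload_keys non-empty, CheckMemSize() forces need_swap_ and GenSwapEventSet()
  // emits swap events only for those keys, skipping the greedy span filter.
  (void)strategy;
}
}  // namespace mindspore::device

In the repo the caller is MemScheduler::OptMemUsage() (next file), which passes its own
mem_priority_, mem_events_, and manual_offload_keys_ members.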

@@ -16,12 +16,12 @@
#include "runtime/device/memory_scheduler.h"
#include <algorithm>
#include "utils/log_adapter.h"
#ifdef _MSC_VER
#include <time.h>
#else
#include <sys/time.h>
#endif
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {

@@ -51,7 +51,7 @@ void MemScheduler::Clear() {
high_priority_device_ptr_.clear();
}
void MemScheduler::ClearTempMem() {
void MemScheduler::ClearAllocatedMem() {
if (mem_handler_ == nullptr) {
return;
}

@@ -72,8 +72,6 @@ void MemScheduler::ClearTempMem() {
swap_host_ptr_.clear();
}
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
if (key == nullptr) {
return;

@@ -184,7 +182,7 @@ bool MemScheduler::PostCompute(void *stream) {
return true;
}
if (record_compute_time_ && !updated_) {
if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
}

@@ -227,8 +225,12 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
MS_EXCEPTION_IF_NULL(mem_handler_);
if (strategy_ == nullptr) {
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
compute_time_.resize(total_step_);
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_);
if (manual_offload_keys_.empty()) {
compute_time_.resize(total_step_);
} else {
updated_ = true;
}
}
auto available_mem_size = mem_handler_->GetAvailableMemSize();

@@ -237,7 +239,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
strategy_->Execute();
}
void MemScheduler::Optimize() {
bool MemScheduler::Optimize() {
AdjustFirstEventIndex();
float mem_used_factor = kMaxMemReuseFactor;
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {

@@ -265,10 +267,11 @@ void MemScheduler::Optimize() {
if (ret) {
optimized_ = true;
} else {
ClearTempMem();
ClearAllocatedMem();
mem_used_factor -= kRetryFactor;
}
}
return optimized_;
}
void MemScheduler::AdjustFirstEventIndex() {

@@ -45,8 +45,6 @@ class MemScheduler {
void set_need_record_event(bool flag) { need_record_event_ = flag; }
bool optimized() const { return optimized_; }
void Update();
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }

@@ -60,19 +58,19 @@ class MemScheduler {
step_events_.resize(total_step_);
}
void ResetCurrentStep() { current_step_ = 0; }
void Reset() { current_step_ = 0; }
bool PreCompute(void *stream);
bool PostCompute(void *stream);
void Optimize();
bool Optimize();
void Clear();
void ClearTempMem();
void ClearAllocatedMem();
void SetMemPriority(const void *key, MemPriority priority);
void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
private:
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);

@@ -83,6 +81,7 @@ class MemScheduler {
std::map<const void *, MemPriority> mem_priority_;
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
std::set<const void *> manual_offload_keys_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
std::map<const void *, void *> mem_result_;
std::map<const void *, void *> init_host_ptr_;
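
Aside (not part of the diff): a minimal sketch of the scheduler-side call order this API
implies, mirroring the new test_manual_mem_scheduler unit test further down. The function
name and the way the handler and key are passed in are illustrative only; every method used
here appears elsewhere in this diff (the header above or the unit tests below).

// Sketch only, not repo code: how a caller drives manual offload through MemScheduler.
#include <memory>
#include "runtime/device/memory_scheduler.h"

namespace mindspore::device {
void ManualOffloadFlow(const std::shared_ptr<MemScheduler> &scheduler,
                       const std::shared_ptr<MemHandler> &handler, const void *offload_key) {
  scheduler->SetMemHandler(handler);
  scheduler->SetTotalStep(8);
  // Register the key before the record phase; OptMemUsage() later hands
  // manual_offload_keys_ to MemOffloadStrategy and marks the scheduler as updated_.
  scheduler->SetOffload(offload_key);
  // ... record phase: GetOrMalloc()/PostCompute() per step, then stop recording:
  scheduler->set_need_record_event(false);
  if (!scheduler->Optimize()) {
    // Optimize() now returns bool; KernelRuntime::UseMemSchedulerIfNeeded() raises
    // MS_LOG_EXCEPTION when this fails ("Can't run graph ... for memory limit.").
  }
  // ... run phase: Reset(), Update(), then PreCompute()/GetOrMalloc()/PostCompute() per step.
}
}  // namespace mindspore::device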

@@ -349,6 +349,7 @@ constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";
constexpr auto kAttrAsync = "async";
constexpr auto kAttrOffload = "offload";
constexpr auto kAttrVisited = "visited";
constexpr auto kAttrShape = "shape";
constexpr auto kAttrMomentum = "momentum";

@@ -123,3 +123,32 @@ def test_lenet():
    diff = res.asnumpy()[0] - 2.3025851
    assert np.all(diff < 1.e-6)
    os.environ['ENABLE_MEM_SCHEDULER'] = '0'


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_manual_offload():
    '''
    Feature: MemScheduler
    Description: Test set offload strategy
    Expectation: Run lenet success
    '''
    os.environ['ENABLE_MEM_SCHEDULER'] = '1'
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    net.relu.add_prim_attr("Offload", True)
    learning_rate = 0.01
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    res = train_network(data, label)
    diff = res.asnumpy()[0] - 2.3025851
    assert np.all(diff < 1.e-6)
    os.environ['ENABLE_MEM_SCHEDULER'] = '0'

@@ -19,47 +19,58 @@
#include "common/common_test.h"
#include "runtime/device/memory_scheduler.h"
namespace mindspore::device {
constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024;
constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024;
constexpr size_t kDeviceMemSize = 5;
constexpr size_t kMaxVirtualCount = 1024;
class MemHandlerImpl : public MemHandler {
public:
MemHandlerImpl() {
device_mem_.resize(kMaxVirtualCount, 0);
host_mem_.resize(kMaxVirtualCount, 1);
}
size_t GetAvailableMemSize() override { return kDeviceMemSize; }
void *MallocDevice(size_t mem_size) override {
if (device_virtual_count_ >= kDeviceMemSize) {
return nullptr;
}
auto ret = device_mem_.data() + device_virtual_count_;
++device_virtual_count_;
device_mem_size_.emplace(ret, mem_size);
return ret;
}
void FreeDevice(void *ptr) override {
--device_virtual_count_;
auto iter = device_mem_size_.find(ptr);
if (iter != device_mem_size_.end()) {
device_mem_size_.erase(iter);
}
}
void *MallocHost(size_t mem_size) override {
auto ret = host_mem_.data() + host_virtual_count_;
++host_virtual_count_;
host_mem_size_.emplace(ret, mem_size);
return ret;
}
void FreeHost(void *ptr) override {
auto iter = host_mem_size_.find(ptr);
if (iter != host_mem_size_.end()) {
host_mem_size_.erase(iter);
}
}
void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
private:
std::vector<uint8_t> device_mem_;
std::vector<uint8_t> host_mem_;
size_t device_virtual_count_;
size_t host_virtual_count_;
size_t device_virtual_count_{0};
size_t host_virtual_count_{0};
std::map<void *, size_t> device_mem_size_;
std::map<void *, size_t> host_mem_size_;
};

@@ -67,6 +78,47 @@ class MemHandlerImpl : public MemHandler {
class TestMemScheduler : public UT::Common {
public:
TestMemScheduler() {}
protected:
size_t used_tensor_num_{1};
size_t total_step_{1};
std::vector<uint8_t> tensor_keys_;
std::vector<uint8_t> tensor_datas_;
std::vector<size_t> init_tensors_;
std::vector<std::vector<size_t>> step_used_tensors_;
void Record(const std::shared_ptr<MemScheduler> &scheduler) {
void *stream = nullptr;
for (auto index : init_tensors_) {
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < total_step_; ++i) {
auto &tensors = step_used_tensors_[i];
for (auto j : tensors) {
scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
}
scheduler->PostCompute(stream);
}
scheduler->set_need_record_event(false);
}
void Run(const std::shared_ptr<MemScheduler> &scheduler) {
void *stream = nullptr;
scheduler->Reset();
scheduler->Update();
for (auto index : init_tensors_) {
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < total_step_; ++i) {
scheduler->PreCompute(stream);
auto &tensors = step_used_tensors_[i];
for (auto j : tensors) {
auto addr = scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
ASSERT_NE(addr, nullptr);
}
scheduler->PostCompute(stream);
}
}
};
/// Feature: MemSchedulerManager

@@ -91,49 +143,93 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
ASSERT_NE(scheduler, nullptr);
auto need_record = scheduler->need_record_event();
ASSERT_EQ(need_record, true);
auto optimized = scheduler->optimized();
ASSERT_EQ(optimized, false);
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
ASSERT_NE(mem_handler, nullptr);
scheduler->SetMemHandler(mem_handler);
constexpr size_t kUsedTensors = 10;
constexpr size_t kTimeSlice = 7;
std::vector<uint8_t> tensor_keys(kUsedTensors, 0);
std::vector<uint8_t> tensor_datas(kUsedTensors, 0);
// input data
used_tensor_num_ = 10;
total_step_ = 8;
std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
void *stream = nullptr;
scheduler->SetTotalStep(kTimeSlice);
// record
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < kTimeSlice; ++i) {
auto &tensors = step_tensors[i];
for (auto j : tensors) {
scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
}
scheduler->PostCompute(stream);
}
scheduler->set_need_record_event(false);
// 8 step tensor usage
//
// 0
// 1 1-----------------1
// 2--------------2
// 3 3--------3
// 4-----4
// 5 5
// 6 6
// 7 7
// 8 8
// 9 9
std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
{4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
tensor_keys_.swap(tensor_keys);
tensor_datas_.swap(tensor_datas);
init_tensors_.swap(init_tensors);
step_used_tensors_.swap(step_used_tensors);
scheduler->SetTotalStep(total_step_);
// record
Record(scheduler);
// optimize
scheduler->Optimize();
// run
scheduler->ResetCurrentStep();
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < kTimeSlice; ++i) {
scheduler->PreCompute(stream);
auto &tensors = step_tensors[i];
for (auto j : tensors) {
auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
ASSERT_NE(addr, nullptr);
}
scheduler->PostCompute(stream);
Run(scheduler);
}

/// Feature: MemScheduler
/// Description: Test MemScheduler interface
/// Expectation: MemScheduler GetOrMalloc return valid ptr
TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
MemSchedulerManager mem_scheduler_manager;
auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
ASSERT_NE(scheduler, nullptr);
auto need_record = scheduler->need_record_event();
ASSERT_EQ(need_record, true);
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
ASSERT_NE(mem_handler, nullptr);
scheduler->SetMemHandler(mem_handler);
// input data
used_tensor_num_ = 10;
total_step_ = 8;
std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<size_t> offload_tensor = {1, 2, 3};
// 8 step tensor usage
//
// 0
// 1 1-----------------1
// 2--------------2
// 3 3--------3
// 4-----4
// 5 5
// 6 6
// 7 7
// 8 8
// 9 9
std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
{4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
tensor_keys_.swap(tensor_keys);
tensor_datas_.swap(tensor_datas);
init_tensors_.swap(init_tensors);
step_used_tensors_.swap(step_used_tensors);
scheduler->SetTotalStep(total_step_);
// set offload key
for (auto index : offload_tensor) {
scheduler->SetOffload(tensor_keys_.data() + index);
}
// record
Record(scheduler);
// optimize
scheduler->Optimize();
// run
Run(scheduler);
}
} // namespace mindspore::device