set offload node

kswang 2021-12-02 20:55:22 +08:00
parent 6924866a87
commit 391a06aad1
8 changed files with 226 additions and 71 deletions

View File

@@ -1449,6 +1449,13 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
return ret;
}
AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
auto cnode = kernel->cast<CNodePtr>();
if (mock && AnfAlgo::HasNodeAttr(kAttrOffload, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
mem_scheduler->SetOffload(device_address);
}
}
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
@@ -1462,15 +1469,15 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
} else {
ret = kernel_mod->Launch(kernel_launch_info, stream);
}
if (!ret) {
return ret;
}
}
if (mem_scheduler != nullptr) {
if (!mock) {
SyncNodeOutputTensors(mem_scheduler, graph, kernel);
}
ret = mem_scheduler->PostCompute(stream);
if (!ret) {
return ret;
}
}
return ret;
}
@@ -1483,7 +1490,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
if (UseMemScheduler()) {
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->ResetCurrentStep();
mem_scheduler->Reset();
mem_scheduler->Update();
InitGraphInputTensors(mem_scheduler, graph);
}
@@ -1594,8 +1601,8 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
(void)LaunchKernelMod(graph, true);
mem_scheduler->set_need_record_event(false);
}
mem_scheduler->Optimize();
if (!mem_scheduler->optimized()) {
auto ret = mem_scheduler->Optimize();
if (!ret) {
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " due to the memory limit.";
}
}
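
Note: the hunks above are the consumer side of this commit. During the mock launch (mock == true), every output device address of a kernel whose CNode carries the boolean kAttrOffload attribute is handed to the scheduler through SetOffload(), and Optimize() now reports whether a feasible strategy was found instead of returning void. A condensed sketch of the marking path follows; MarkOffloadOutputs is a hypothetical helper name, since the diff inlines this logic directly in LaunchKernel:

void MarkOffloadOutputs(const AnfNodePtr &kernel, const kernel::KernelMod *kernel_mod,
                        const std::shared_ptr<MemScheduler> &mem_scheduler) {
  auto cnode = kernel->cast<CNodePtr>();
  // Only kernels explicitly tagged with the boolean "offload" attribute qualify.
  if (!AnfAlgo::HasNodeAttr(kAttrOffload, cnode) ||
      !AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
    return;
  }
  // Every output buffer of the tagged kernel becomes a manual offload key.
  for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
    mem_scheduler->SetOffload(AnfAlgo::GetOutputAddr(kernel, i, true));
  }
}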

View File

@@ -16,7 +16,6 @@
#include "runtime/device/memory_offload_strategy.h"
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>
#include "utils/log_adapter.h"
@@ -100,7 +99,7 @@ void MemOffloadStrategy::CheckMemSize() {
<< min_mem_needed_;
}
if (mem_size_ < mem_used_without_swap_) {
if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
need_swap_ = true;
}
@@ -141,6 +140,18 @@ void MemOffloadStrategy::GenEventSpan() {
void MemOffloadStrategy::GenSwapEventSet() {
swap_events_.clear();
// manual offload strategy
if (!manual_offload_keys_.empty()) {
for (const auto &iter : event_span_) {
auto &event = iter.second.first;
if (manual_offload_keys_.find(event->key) != manual_offload_keys_.end()) {
(void)swap_events_.emplace(event);
}
}
return;
}
// greedy span filter
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
for (const auto &iter : event_span_) {
auto span = iter.second.second;
@@ -179,9 +190,6 @@ void MemOffloadStrategy::GenComputeMemEvents() {
post_compute_events_.resize(total_step_);
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
continue;
}
// No need to generate events for memory that has only one event, which means it is never used by any kernel.
if (mem_events.size() <= 1) {
continue;
@@ -211,10 +219,13 @@
swap_out_event->key = item.first;
swap_out_event->mem_size = first_event->mem_size;
post_compute_events_[pre_index].emplace_back(swap_out_event);
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
// Avoid generating a swap-in event that directly follows the init event.
if (first_event->type != kInit || i != 1) {
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
}
}
if (event->index < pre_compute_events_.size()) {
(void)pre_compute_events_[event->index].emplace_back(event);
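
Note: the manual branch added to GenSwapEventSet() short-circuits the greedy span filter: when any key was registered through SetOffload(), exactly the events touching those keys become swap events, independent of the memory budget. A self-contained sketch of that selection (MemEvent is reduced to the one field the loop reads, and the span container's exact type is an assumption inferred from the loop above):

#include <map>
#include <memory>
#include <set>
#include <utility>

struct MemEvent {             // reduced: only the field the selection reads
  const void *key = nullptr;  // identifies the tensor/buffer
};
using MemEventPtr = std::shared_ptr<MemEvent>;
// Assumed layout of event_span_: span length -> (event, span) pairs.
using EventSpanMap = std::multimap<size_t, std::pair<MemEventPtr, size_t>>;

std::set<MemEventPtr> PickManualSwapEvents(const EventSpanMap &event_span,
                                           const std::set<const void *> &manual_keys) {
  std::set<MemEventPtr> swap_events;
  for (const auto &iter : event_span) {
    const auto &event = iter.second.first;
    if (manual_keys.find(event->key) != manual_keys.end()) {
      (void)swap_events.emplace(event);  // every use of a manual key gets swapped
    }
  }
  return swap_events;
}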

View File

@@ -41,8 +41,11 @@ class MemOffloadStrategy {
public:
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
size_t total_step)
: mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {}
const std::set<const void *> &manual_offload_keys, size_t total_step)
: mem_priority_(mem_priority),
mem_events_(mem_events),
manual_offload_keys_(manual_offload_keys),
total_step_(total_step) {}
virtual ~MemOffloadStrategy() = default;
@@ -58,18 +61,24 @@
bool need_swap() const { return need_swap_; }
private:
bool IsHighPriorityMem(const void *key);
private:
void CountMemUsage();
void CheckMemSize();
void GenEventSpan();
void GenSwapEventSet();
void GenComputeMemEvents();
void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
const std::map<const void *, MemPriority> &mem_priority_;
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
const std::set<const void *> &manual_offload_keys_;
const size_t total_step_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
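
Note: manual_offload_keys_, like mem_priority_ and mem_events_, is held by const reference rather than by value, so the owning MemScheduler must outlive the strategy, and later insertions into the set remain visible through the reference. A minimal construction sketch (variable names are placeholders):

#include "runtime/device/memory_offload_strategy.h"

std::map<const void *, MemPriority> mem_priority;
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events;
std::set<const void *> manual_offload_keys;
const size_t total_step = 8;
auto strategy = std::make_shared<MemOffloadStrategy>(mem_priority, mem_events,
                                                     manual_offload_keys, total_step);
// The strategy reads through the references, not snapshots: keys added to
// manual_offload_keys after construction still influence the generated events.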

View File

@@ -16,12 +16,12 @@
#include "runtime/device/memory_scheduler.h"
#include <algorithm>
#include "utils/log_adapter.h"
#ifdef _MSC_VER
#include <time.h>
#else
#include <sys/time.h>
#endif
#include "utils/log_adapter.h"
namespace mindspore {
namespace device {
@@ -51,7 +51,7 @@ void MemScheduler::Clear() {
high_priority_device_ptr_.clear();
}
void MemScheduler::ClearTempMem() {
void MemScheduler::ClearAllocatedMem() {
if (mem_handler_ == nullptr) {
return;
}
@@ -72,8 +72,6 @@ void MemScheduler::ClearTempMem() {
swap_host_ptr_.clear();
}
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
if (key == nullptr) {
return;
@@ -184,7 +182,7 @@ bool MemScheduler::PostCompute(void *stream) {
return true;
}
if (record_compute_time_ && !updated_) {
if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
}
@@ -227,8 +225,12 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
MS_EXCEPTION_IF_NULL(mem_handler_);
if (strategy_ == nullptr) {
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
compute_time_.resize(total_step_);
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_);
if (manual_offload_keys_.empty()) {
compute_time_.resize(total_step_);
} else {
updated_ = true;
}
}
auto available_mem_size = mem_handler_->GetAvailableMemSize();
@@ -237,7 +239,7 @@
strategy_->Execute();
}
void MemScheduler::Optimize() {
bool MemScheduler::Optimize() {
AdjustFirstEventIndex();
float mem_used_factor = kMaxMemReuseFactor;
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
@@ -265,10 +267,11 @@ void MemScheduler::Optimize() {
if (ret) {
optimized_ = true;
} else {
ClearTempMem();
ClearAllocatedMem();
mem_used_factor -= kRetryFactor;
}
}
return optimized_;
}
void MemScheduler::AdjustFirstEventIndex() {
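
Note: Optimize() now returns whether any memory-reuse factor yielded a workable strategy. The middle of the loop is elided between the two hunks above, so the sketch below only restates its shape: MockRun() is a stand-in for whatever produces ret there, and the factor constants are assumptions:

#include <functional>

constexpr float kMaxMemReuseFactor = 0.9f;  // assumed starting budget factor
constexpr float kMinMemReuseFactor = 0.5f;  // assumed lower bound
constexpr float kRetryFactor = 0.1f;        // assumed shrink step per retry

bool OptimizeWithShrinkingBudget(const std::function<void(float)> &opt_mem_usage,
                                 const std::function<bool()> &mock_run,
                                 const std::function<void()> &clear_allocated_mem) {
  bool optimized = false;
  float mem_used_factor = kMaxMemReuseFactor;
  while (!optimized && mem_used_factor >= kMinMemReuseFactor) {
    opt_mem_usage(mem_used_factor);  // rebuild the strategy under this budget
    if (mock_run()) {                // dry-run of the recorded steps
      optimized = true;
    } else {
      clear_allocated_mem();         // release partial allocations before retrying
      mem_used_factor -= kRetryFactor;
    }
  }
  return optimized;  // false => "Can't run graph ... due to the memory limit"
}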

View File

@@ -45,8 +45,6 @@ class MemScheduler {
void set_need_record_event(bool flag) { need_record_event_ = flag; }
bool optimized() const { return optimized_; }
void Update();
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
@@ -60,19 +58,19 @@
step_events_.resize(total_step_);
}
void ResetCurrentStep() { current_step_ = 0; }
void Reset() { current_step_ = 0; }
bool PreCompute(void *stream);
bool PostCompute(void *stream);
void Optimize();
bool Optimize();
void Clear();
void ClearTempMem();
void ClearAllocatedMem();
void SetMemPriority(const void *key, MemPriority priority);
void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
private:
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
@@ -83,6 +81,7 @@ class MemScheduler {
std::map<const void *, MemPriority> mem_priority_;
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
std::set<const void *> manual_offload_keys_;
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
std::map<const void *, void *> mem_result_;
std::map<const void *, void *> init_host_ptr_;
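
Note: the header changes collect the new public surface in one place: Reset() (renamed from ResetCurrentStep()), ClearAllocatedMem() (renamed from ClearTempMem()), a bool-returning Optimize(), and the new SetOffload() entry point that feeds manual_offload_keys_. A usage sketch condensed from the unit test at the bottom of this commit (function and parameter names are placeholders):

#include <memory>
#include "runtime/device/memory_scheduler.h"

bool PrepareManualOffload(const std::shared_ptr<MemScheduler> &scheduler,
                          const std::shared_ptr<MemHandler> &handler,
                          const void *offload_key, size_t total_step) {
  scheduler->SetMemHandler(handler);
  scheduler->SetTotalStep(total_step);
  scheduler->SetOffload(offload_key);  // force this buffer into the swap set
  // ... one recording pass over the graph goes here (see Record() below) ...
  scheduler->set_need_record_event(false);
  if (!scheduler->Optimize()) {  // Optimize() now reports success or failure
    return false;                // caller raises the memory-limit exception
  }
  scheduler->Reset();            // renamed from ResetCurrentStep()
  scheduler->Update();
  return true;
}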

View File

@@ -349,6 +349,7 @@ constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
constexpr auto kIsBackendCast = "is_backed_cast";
constexpr auto kAttrOutputNames = "output_names";
constexpr auto kAttrAsync = "async";
constexpr auto kAttrOffload = "offload";
constexpr auto kAttrVisited = "visited";
constexpr auto kAttrShape = "shape";
constexpr auto kAttrMomentum = "momentum";

View File

@@ -123,3 +123,32 @@ def test_lenet():
diff = res.asnumpy()[0] - 2.3025851
assert np.all(diff < 1.e-6)
os.environ['ENABLE_MEM_SCHEDULER'] = '0'
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet_manual_offload():
'''
Feature: MemScheduler
Description: Test setting a manual offload strategy
Expectation: Run LeNet successfully
'''
os.environ['ENABLE_MEM_SCHEDULER'] = '1'
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
net.relu.add_prim_attr("Offload", True)
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
diff = res.asnumpy()[0] - 2.3025851
assert np.all(diff < 1.e-6)
os.environ['ENABLE_MEM_SCHEDULER'] = '0'
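
Note: the expected value 2.3025851 in both tests is ln(10), the softmax cross-entropy loss of an effectively untrained 10-class LeNet whose prediction is still near uniform:

-log(1/10) = ln 10 ≈ 2.3025851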

View File

@@ -19,47 +19,58 @@
#include "common/common_test.h"
#include "runtime/device/memory_scheduler.h"
namespace mindspore::device {
constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024;
constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024;
constexpr size_t kDeviceMemSize = 5;
constexpr size_t kMaxVirtualCount = 1024;
class MemHandlerImpl : public MemHandler {
public:
MemHandlerImpl() {
device_mem_.resize(kMaxVirtualCount, 0);
host_mem_.resize(kMaxVirtualCount, 1);
}
size_t GetAvailableMemSize() override { return kDeviceMemSize; }
void *MallocDevice(size_t mem_size) override {
if (device_virtual_count_ >= kDeviceMemSize) {
return nullptr;
}
auto ret = device_mem_.data() + device_virtual_count_;
++device_virtual_count_;
device_mem_size_.emplace(ret, mem_size);
return ret;
}
void FreeDevice(void *ptr) override {
--device_virtual_count_;
auto iter = device_mem_size_.find(ptr);
if (iter != device_mem_size_.end()) {
device_mem_size_.erase(iter);
}
}
void *MallocHost(size_t mem_size) override {
auto ret = host_mem_.data() + host_virtual_count_;
++host_virtual_count_;
host_mem_size_.emplace(ret, mem_size);
return ret;
}
void FreeHost(void *ptr) override {
auto iter = host_mem_size_.find(ptr);
if (iter != host_mem_size_.end()) {
host_mem_size_.erase(iter);
}
}
void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
private:
std::vector<uint8_t> device_mem_;
std::vector<uint8_t> host_mem_;
size_t device_virtual_count_;
size_t host_virtual_count_;
size_t device_virtual_count_{0};
size_t host_virtual_count_{0};
std::map<void *, size_t> device_mem_size_;
std::map<void *, size_t> host_mem_size_;
};
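
Note: shrinking kDeviceMemSize from 1 GiB to 5 virtual slots is what makes the test exercise swapping: once five device buffers are live at the same time, MallocDevice() returns nullptr. A hypothetical standalone check of that accounting (not part of the test file, which only drives the handler through MemScheduler):

#include <cassert>
#include <cstddef>

int main() {
  MemHandlerImpl handler;  // the mock handler defined above
  void *ptrs[5];
  for (size_t i = 0; i < 5; ++i) {
    ptrs[i] = handler.MallocDevice(1);  // slots 0..4 succeed
    assert(ptrs[i] != nullptr);
  }
  assert(handler.MallocDevice(1) == nullptr);  // a sixth live buffer exceeds the budget
  handler.FreeDevice(ptrs[0]);                 // freeing a slot...
  assert(handler.MallocDevice(1) != nullptr);  // ...makes room again
  return 0;
}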
@@ -67,6 +78,47 @@ class MemHandlerImpl : public MemHandler {
class TestMemScheduler : public UT::Common {
public:
TestMemScheduler() {}
protected:
size_t used_tensor_num_{1};
size_t total_step_{1};
std::vector<uint8_t> tensor_keys_;
std::vector<uint8_t> tensor_datas_;
std::vector<size_t> init_tensors_;
std::vector<std::vector<size_t>> step_used_tensors_;
void Record(const std::shared_ptr<MemScheduler> &scheduler) {
void *stream = nullptr;
for (auto index : init_tensors_) {
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < total_step_; ++i) {
auto &tensors = step_used_tensors_[i];
for (auto j : tensors) {
scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
}
scheduler->PostCompute(stream);
}
scheduler->set_need_record_event(false);
}
void Run(const std::shared_ptr<MemScheduler> &scheduler) {
void *stream = nullptr;
scheduler->Reset();
scheduler->Update();
for (auto index : init_tensors_) {
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < total_step_; ++i) {
scheduler->PreCompute(stream);
auto &tensors = step_used_tensors_[i];
for (auto j : tensors) {
auto addr = scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
ASSERT_NE(addr, nullptr);
}
scheduler->PostCompute(stream);
}
}
};
/// Feature: MemSchedulerManager
@@ -91,49 +143,93 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
ASSERT_NE(scheduler, nullptr);
auto need_record = scheduler->need_record_event();
ASSERT_EQ(need_record, true);
auto optimized = scheduler->optimized();
ASSERT_EQ(optimized, false);
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
ASSERT_NE(mem_handler, nullptr);
scheduler->SetMemHandler(mem_handler);
constexpr size_t kUsedTensors = 10;
constexpr size_t kTimeSlice = 7;
std::vector<uint8_t> tensor_keys(kUsedTensors, 0);
std::vector<uint8_t> tensor_datas(kUsedTensors, 0);
// input data
used_tensor_num_ = 10;
total_step_ = 8;
std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
void *stream = nullptr;
scheduler->SetTotalStep(kTimeSlice);
// record
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < kTimeSlice; ++i) {
auto &tensors = step_tensors[i];
for (auto j : tensors) {
scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
}
scheduler->PostCompute(stream);
}
scheduler->set_need_record_event(false);
// Tensor usage across the 8 steps (step indices in which each tensor is used):
//   tensor 0: step 0
//   tensor 1: steps 0, 1, 7
//   tensor 2: steps 1, 6
//   tensor 3: steps 1, 2, 5
//   tensor 4: steps 2, 4
//   tensor 5: steps 2, 3
//   tensor 6: steps 3, 4
//   tensor 7: steps 4, 5
//   tensor 8: steps 5, 6
//   tensor 9: steps 6, 7
std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
{4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
tensor_keys_.swap(tensor_keys);
tensor_datas_.swap(tensor_datas);
init_tensors_.swap(init_tensors);
step_used_tensors_.swap(step_used_tensors);
scheduler->SetTotalStep(total_step_);
// record
Record(scheduler);
// optimize
scheduler->Optimize();
// run
scheduler->ResetCurrentStep();
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < kTimeSlice; ++i) {
scheduler->PreCompute(stream);
auto &tensors = step_tensors[i];
for (auto j : tensors) {
auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
ASSERT_NE(addr, nullptr);
}
scheduler->PostCompute(stream);
Run(scheduler);
}
/// Feature: MemScheduler
/// Description: Test MemScheduler interface
/// Expectation: MemScheduler GetOrMalloc return valid ptr
TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
MemSchedulerManager mem_scheduler_manager;
auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
ASSERT_NE(scheduler, nullptr);
auto need_record = scheduler->need_record_event();
ASSERT_EQ(need_record, true);
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
ASSERT_NE(mem_handler, nullptr);
scheduler->SetMemHandler(mem_handler);
// input data
used_tensor_num_ = 10;
total_step_ = 8;
std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<size_t> offload_tensor = {1, 2, 3};
// Tensor usage across the 8 steps (step indices in which each tensor is used):
//   tensor 0: step 0
//   tensor 1: steps 0, 1, 7
//   tensor 2: steps 1, 6
//   tensor 3: steps 1, 2, 5
//   tensor 4: steps 2, 4
//   tensor 5: steps 2, 3
//   tensor 6: steps 3, 4
//   tensor 7: steps 4, 5
//   tensor 8: steps 5, 6
//   tensor 9: steps 6, 7
std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
{4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
tensor_keys_.swap(tensor_keys);
tensor_datas_.swap(tensor_datas);
init_tensors_.swap(init_tensors);
step_used_tensors_.swap(step_used_tensors);
scheduler->SetTotalStep(total_step_);
// set offload key
for (auto index : offload_tensor) {
scheduler->SetOffload(tensor_keys_.data() + index);
}
// record
Record(scheduler);
// optimize
scheduler->Optimize();
// run
Run(scheduler);
}
} // namespace mindspore::device