forked from mindspore-Ecosystem/mindspore
set offload node
This commit is contained in:
parent
6924866a87
commit
391a06aad1
|
@ -1449,6 +1449,13 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
|
AssignKernelAddress(mem_scheduler, kernel, &kernel_launch_info);
|
||||||
|
auto cnode = kernel->cast<CNodePtr>();
|
||||||
|
if (mock && AnfAlgo::HasNodeAttr(kAttrOffload, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
|
||||||
|
for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
|
||||||
|
auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
|
||||||
|
mem_scheduler->SetOffload(device_address);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
|
} else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
|
||||||
kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
|
kernel_launch_info.inputs_ = kernel_mod->GetInputsAddr();
|
||||||
kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
|
kernel_launch_info.outputs_ = kernel_mod->GetOutputsAddr();
|
||||||
|
@ -1462,15 +1469,15 @@ bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNod
|
||||||
} else {
|
} else {
|
||||||
ret = kernel_mod->Launch(kernel_launch_info, stream);
|
ret = kernel_mod->Launch(kernel_launch_info, stream);
|
||||||
}
|
}
|
||||||
|
if (!ret) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (mem_scheduler != nullptr) {
|
if (mem_scheduler != nullptr) {
|
||||||
if (!mock) {
|
if (!mock) {
|
||||||
SyncNodeOutputTensors(mem_scheduler, graph, kernel);
|
SyncNodeOutputTensors(mem_scheduler, graph, kernel);
|
||||||
}
|
}
|
||||||
ret = mem_scheduler->PostCompute(stream);
|
ret = mem_scheduler->PostCompute(stream);
|
||||||
if (!ret) {
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
@ -1483,7 +1490,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
|
||||||
if (UseMemScheduler()) {
|
if (UseMemScheduler()) {
|
||||||
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
|
||||||
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
MS_EXCEPTION_IF_NULL(mem_scheduler);
|
||||||
mem_scheduler->ResetCurrentStep();
|
mem_scheduler->Reset();
|
||||||
mem_scheduler->Update();
|
mem_scheduler->Update();
|
||||||
InitGraphInputTensors(mem_scheduler, graph);
|
InitGraphInputTensors(mem_scheduler, graph);
|
||||||
}
|
}
|
||||||
|
@ -1594,8 +1601,8 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
|
||||||
(void)LaunchKernelMod(graph, true);
|
(void)LaunchKernelMod(graph, true);
|
||||||
mem_scheduler->set_need_record_event(false);
|
mem_scheduler->set_need_record_event(false);
|
||||||
}
|
}
|
||||||
mem_scheduler->Optimize();
|
auto ret = mem_scheduler->Optimize();
|
||||||
if (!mem_scheduler->optimized()) {
|
if (!ret) {
|
||||||
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
|
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
#include "runtime/device/memory_offload_strategy.h"
|
#include "runtime/device/memory_offload_strategy.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
@ -100,7 +99,7 @@ void MemOffloadStrategy::CheckMemSize() {
|
||||||
<< min_mem_needed_;
|
<< min_mem_needed_;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mem_size_ < mem_used_without_swap_) {
|
if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
|
||||||
need_swap_ = true;
|
need_swap_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,6 +140,18 @@ void MemOffloadStrategy::GenEventSpan() {
|
||||||
|
|
||||||
void MemOffloadStrategy::GenSwapEventSet() {
|
void MemOffloadStrategy::GenSwapEventSet() {
|
||||||
swap_events_.clear();
|
swap_events_.clear();
|
||||||
|
// manual offload strategy
|
||||||
|
if (!manual_offload_keys_.empty()) {
|
||||||
|
for (const auto &iter : event_span_) {
|
||||||
|
auto &event = iter.second.first;
|
||||||
|
if (manual_offload_keys_.find(event->key) != manual_offload_keys_.end()) {
|
||||||
|
(void)swap_events_.emplace(event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// greedy span filter
|
||||||
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
|
std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
|
||||||
for (const auto &iter : event_span_) {
|
for (const auto &iter : event_span_) {
|
||||||
auto span = iter.second.second;
|
auto span = iter.second.second;
|
||||||
|
@ -179,9 +190,6 @@ void MemOffloadStrategy::GenComputeMemEvents() {
|
||||||
post_compute_events_.resize(total_step_);
|
post_compute_events_.resize(total_step_);
|
||||||
for (auto &item : mem_events_) {
|
for (auto &item : mem_events_) {
|
||||||
auto &mem_events = item.second;
|
auto &mem_events = item.second;
|
||||||
if (mem_events.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// No need to generate events for memory that has only one event, which means it is never used by any kernel.
|
// No need to generate events for memory that has only one event, which means it is never used by any kernel.
|
||||||
if (mem_events.size() <= 1) {
|
if (mem_events.size() <= 1) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -211,10 +219,13 @@ void MemOffloadStrategy::GenComputeMemEvents() {
|
||||||
swap_out_event->key = item.first;
|
swap_out_event->key = item.first;
|
||||||
swap_out_event->mem_size = first_event->mem_size;
|
swap_out_event->mem_size = first_event->mem_size;
|
||||||
post_compute_events_[pre_index].emplace_back(swap_out_event);
|
post_compute_events_[pre_index].emplace_back(swap_out_event);
|
||||||
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
|
// avoid swap-in-event follow init-event
|
||||||
swap_in_event->key = item.first;
|
if (first_event->type != kInit || i != 1) {
|
||||||
swap_in_event->mem_size = first_event->mem_size;
|
auto swap_in_event = std::make_shared<MemEvent>(kSwapIn, event->index);
|
||||||
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
|
swap_in_event->key = item.first;
|
||||||
|
swap_in_event->mem_size = first_event->mem_size;
|
||||||
|
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (event->index < pre_compute_events_.size()) {
|
if (event->index < pre_compute_events_.size()) {
|
||||||
(void)pre_compute_events_[event->index].emplace_back(event);
|
(void)pre_compute_events_[event->index].emplace_back(event);
|
||||||
|
|
|
@ -41,8 +41,11 @@ class MemOffloadStrategy {
|
||||||
public:
|
public:
|
||||||
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
|
MemOffloadStrategy(const std::map<const void *, MemPriority> &mem_priority,
|
||||||
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
|
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events,
|
||||||
size_t total_step)
|
const std::set<const void *> &manual_offload_keys, size_t total_step)
|
||||||
: mem_priority_(mem_priority), mem_events_(mem_events), total_step_(total_step) {}
|
: mem_priority_(mem_priority),
|
||||||
|
mem_events_(mem_events),
|
||||||
|
manual_offload_keys_(manual_offload_keys),
|
||||||
|
total_step_(total_step) {}
|
||||||
|
|
||||||
virtual ~MemOffloadStrategy() = default;
|
virtual ~MemOffloadStrategy() = default;
|
||||||
|
|
||||||
|
@ -58,18 +61,24 @@ class MemOffloadStrategy {
|
||||||
|
|
||||||
bool need_swap() const { return need_swap_; }
|
bool need_swap() const { return need_swap_; }
|
||||||
|
|
||||||
|
private:
|
||||||
bool IsHighPriorityMem(const void *key);
|
bool IsHighPriorityMem(const void *key);
|
||||||
|
|
||||||
private:
|
|
||||||
void CountMemUsage();
|
void CountMemUsage();
|
||||||
|
|
||||||
void CheckMemSize();
|
void CheckMemSize();
|
||||||
|
|
||||||
void GenEventSpan();
|
void GenEventSpan();
|
||||||
|
|
||||||
void GenSwapEventSet();
|
void GenSwapEventSet();
|
||||||
|
|
||||||
void GenComputeMemEvents();
|
void GenComputeMemEvents();
|
||||||
|
|
||||||
void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
|
void GenFreeEvent(const std::shared_ptr<MemEvent> &last_event);
|
||||||
|
|
||||||
const std::map<const void *, MemPriority> &mem_priority_;
|
const std::map<const void *, MemPriority> &mem_priority_;
|
||||||
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
|
const std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> &mem_events_;
|
||||||
|
const std::set<const void *> &manual_offload_keys_;
|
||||||
const size_t total_step_;
|
const size_t total_step_;
|
||||||
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
|
std::vector<std::vector<std::shared_ptr<MemEvent>>> pre_compute_events_;
|
||||||
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
|
std::vector<std::vector<std::shared_ptr<MemEvent>>> post_compute_events_;
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
|
|
||||||
#include "runtime/device/memory_scheduler.h"
|
#include "runtime/device/memory_scheduler.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include "utils/log_adapter.h"
|
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
#else
|
#else
|
||||||
#include <sys/time.h>
|
#include <sys/time.h>
|
||||||
#endif
|
#endif
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace device {
|
namespace device {
|
||||||
|
@ -51,7 +51,7 @@ void MemScheduler::Clear() {
|
||||||
high_priority_device_ptr_.clear();
|
high_priority_device_ptr_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemScheduler::ClearTempMem() {
|
void MemScheduler::ClearAllocatedMem() {
|
||||||
if (mem_handler_ == nullptr) {
|
if (mem_handler_ == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -72,8 +72,6 @@ void MemScheduler::ClearTempMem() {
|
||||||
swap_host_ptr_.clear();
|
swap_host_ptr_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
|
|
||||||
|
|
||||||
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
|
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
|
||||||
if (key == nullptr) {
|
if (key == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
@ -184,7 +182,7 @@ bool MemScheduler::PostCompute(void *stream) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (record_compute_time_ && !updated_) {
|
if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
|
||||||
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
|
compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,8 +225,12 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
|
||||||
MS_EXCEPTION_IF_NULL(mem_handler_);
|
MS_EXCEPTION_IF_NULL(mem_handler_);
|
||||||
|
|
||||||
if (strategy_ == nullptr) {
|
if (strategy_ == nullptr) {
|
||||||
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, total_step_);
|
strategy_ = std::make_shared<MemOffloadStrategy>(mem_priority_, mem_events_, manual_offload_keys_, total_step_);
|
||||||
compute_time_.resize(total_step_);
|
if (manual_offload_keys_.empty()) {
|
||||||
|
compute_time_.resize(total_step_);
|
||||||
|
} else {
|
||||||
|
updated_ = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto available_mem_size = mem_handler_->GetAvailableMemSize();
|
auto available_mem_size = mem_handler_->GetAvailableMemSize();
|
||||||
|
@ -237,7 +239,7 @@ void MemScheduler::OptMemUsage(float mem_used_factor) {
|
||||||
strategy_->Execute();
|
strategy_->Execute();
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemScheduler::Optimize() {
|
bool MemScheduler::Optimize() {
|
||||||
AdjustFirstEventIndex();
|
AdjustFirstEventIndex();
|
||||||
float mem_used_factor = kMaxMemReuseFactor;
|
float mem_used_factor = kMaxMemReuseFactor;
|
||||||
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
|
while (!optimized_ && mem_used_factor >= kMinMemReuseFactor) {
|
||||||
|
@ -265,10 +267,11 @@ void MemScheduler::Optimize() {
|
||||||
if (ret) {
|
if (ret) {
|
||||||
optimized_ = true;
|
optimized_ = true;
|
||||||
} else {
|
} else {
|
||||||
ClearTempMem();
|
ClearAllocatedMem();
|
||||||
mem_used_factor -= kRetryFactor;
|
mem_used_factor -= kRetryFactor;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return optimized_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MemScheduler::AdjustFirstEventIndex() {
|
void MemScheduler::AdjustFirstEventIndex() {
|
||||||
|
|
|
@ -45,8 +45,6 @@ class MemScheduler {
|
||||||
|
|
||||||
void set_need_record_event(bool flag) { need_record_event_ = flag; }
|
void set_need_record_event(bool flag) { need_record_event_ = flag; }
|
||||||
|
|
||||||
bool optimized() const { return optimized_; }
|
|
||||||
|
|
||||||
void Update();
|
void Update();
|
||||||
|
|
||||||
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
|
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
|
||||||
|
@ -60,19 +58,19 @@ class MemScheduler {
|
||||||
step_events_.resize(total_step_);
|
step_events_.resize(total_step_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ResetCurrentStep() { current_step_ = 0; }
|
void Reset() { current_step_ = 0; }
|
||||||
|
|
||||||
bool PreCompute(void *stream);
|
bool PreCompute(void *stream);
|
||||||
|
|
||||||
bool PostCompute(void *stream);
|
bool PostCompute(void *stream);
|
||||||
|
|
||||||
void Optimize();
|
bool Optimize();
|
||||||
|
|
||||||
void Clear();
|
void Clear();
|
||||||
|
|
||||||
void ClearTempMem();
|
void ClearAllocatedMem();
|
||||||
|
|
||||||
void SetMemPriority(const void *key, MemPriority priority);
|
void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
|
void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
|
||||||
|
@ -83,6 +81,7 @@ class MemScheduler {
|
||||||
|
|
||||||
std::map<const void *, MemPriority> mem_priority_;
|
std::map<const void *, MemPriority> mem_priority_;
|
||||||
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
|
std::map<const void *, std::vector<std::shared_ptr<MemEvent>>> mem_events_;
|
||||||
|
std::set<const void *> manual_offload_keys_;
|
||||||
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
|
std::vector<std::vector<std::shared_ptr<MemEvent>>> step_events_;
|
||||||
std::map<const void *, void *> mem_result_;
|
std::map<const void *, void *> mem_result_;
|
||||||
std::map<const void *, void *> init_host_ptr_;
|
std::map<const void *, void *> init_host_ptr_;
|
||||||
|
|
|
@ -349,6 +349,7 @@ constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
|
||||||
constexpr auto kIsBackendCast = "is_backed_cast";
|
constexpr auto kIsBackendCast = "is_backed_cast";
|
||||||
constexpr auto kAttrOutputNames = "output_names";
|
constexpr auto kAttrOutputNames = "output_names";
|
||||||
constexpr auto kAttrAsync = "async";
|
constexpr auto kAttrAsync = "async";
|
||||||
|
constexpr auto kAttrOffload = "offload";
|
||||||
constexpr auto kAttrVisited = "visited";
|
constexpr auto kAttrVisited = "visited";
|
||||||
constexpr auto kAttrShape = "shape";
|
constexpr auto kAttrShape = "shape";
|
||||||
constexpr auto kAttrMomentum = "momentum";
|
constexpr auto kAttrMomentum = "momentum";
|
||||||
|
|
|
@ -123,3 +123,32 @@ def test_lenet():
|
||||||
diff = res.asnumpy()[0] - 2.3025851
|
diff = res.asnumpy()[0] - 2.3025851
|
||||||
assert np.all(diff < 1.e-6)
|
assert np.all(diff < 1.e-6)
|
||||||
os.environ['ENABLE_MEM_SCHEDULER'] = '0'
|
os.environ['ENABLE_MEM_SCHEDULER'] = '0'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_arm_ascend_training
|
||||||
|
@pytest.mark.platform_x86_ascend_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_lenet_manual_offload():
|
||||||
|
'''
|
||||||
|
Feature: MemScheduler
|
||||||
|
Description: Test set offload strategy
|
||||||
|
Expectation: Run lenet success
|
||||||
|
'''
|
||||||
|
os.environ['ENABLE_MEM_SCHEDULER'] = '1'
|
||||||
|
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
|
||||||
|
label = Tensor(np.ones([32]).astype(np.int32))
|
||||||
|
net = LeNet()
|
||||||
|
net.relu.add_prim_attr("Offload", True)
|
||||||
|
learning_rate = 0.01
|
||||||
|
momentum = 0.9
|
||||||
|
|
||||||
|
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
|
||||||
|
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
||||||
|
net_with_criterion = WithLossCell(net, criterion)
|
||||||
|
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
|
||||||
|
train_network.set_train()
|
||||||
|
res = train_network(data, label)
|
||||||
|
diff = res.asnumpy()[0] - 2.3025851
|
||||||
|
assert np.all(diff < 1.e-6)
|
||||||
|
os.environ['ENABLE_MEM_SCHEDULER'] = '0'
|
||||||
|
|
|
@ -19,47 +19,58 @@
|
||||||
#include "common/common_test.h"
|
#include "common/common_test.h"
|
||||||
#include "runtime/device/memory_scheduler.h"
|
#include "runtime/device/memory_scheduler.h"
|
||||||
namespace mindspore::device {
|
namespace mindspore::device {
|
||||||
constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024;
|
constexpr size_t kDeviceMemSize = 5;
|
||||||
constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024;
|
constexpr size_t kMaxVirtualCount = 1024;
|
||||||
class MemHandlerImpl : public MemHandler {
|
class MemHandlerImpl : public MemHandler {
|
||||||
public:
|
public:
|
||||||
MemHandlerImpl() {
|
MemHandlerImpl() {
|
||||||
device_mem_.resize(kMaxVirtualCount, 0);
|
device_mem_.resize(kMaxVirtualCount, 0);
|
||||||
host_mem_.resize(kMaxVirtualCount, 1);
|
host_mem_.resize(kMaxVirtualCount, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GetAvailableMemSize() override { return kDeviceMemSize; }
|
size_t GetAvailableMemSize() override { return kDeviceMemSize; }
|
||||||
|
|
||||||
void *MallocDevice(size_t mem_size) override {
|
void *MallocDevice(size_t mem_size) override {
|
||||||
|
if (device_virtual_count_ >= kDeviceMemSize) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
auto ret = device_mem_.data() + device_virtual_count_;
|
auto ret = device_mem_.data() + device_virtual_count_;
|
||||||
++device_virtual_count_;
|
++device_virtual_count_;
|
||||||
device_mem_size_.emplace(ret, mem_size);
|
device_mem_size_.emplace(ret, mem_size);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeDevice(void *ptr) override {
|
void FreeDevice(void *ptr) override {
|
||||||
|
--device_virtual_count_;
|
||||||
auto iter = device_mem_size_.find(ptr);
|
auto iter = device_mem_size_.find(ptr);
|
||||||
if (iter != device_mem_size_.end()) {
|
if (iter != device_mem_size_.end()) {
|
||||||
device_mem_size_.erase(iter);
|
device_mem_size_.erase(iter);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void *MallocHost(size_t mem_size) override {
|
void *MallocHost(size_t mem_size) override {
|
||||||
auto ret = host_mem_.data() + host_virtual_count_;
|
auto ret = host_mem_.data() + host_virtual_count_;
|
||||||
++host_virtual_count_;
|
++host_virtual_count_;
|
||||||
host_mem_size_.emplace(ret, mem_size);
|
host_mem_size_.emplace(ret, mem_size);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeHost(void *ptr) override {
|
void FreeHost(void *ptr) override {
|
||||||
auto iter = host_mem_size_.find(ptr);
|
auto iter = host_mem_size_.find(ptr);
|
||||||
if (iter != host_mem_size_.end()) {
|
if (iter != host_mem_size_.end()) {
|
||||||
host_mem_size_.erase(iter);
|
host_mem_size_.erase(iter);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
|
void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
|
||||||
|
|
||||||
void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
|
void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<uint8_t> device_mem_;
|
std::vector<uint8_t> device_mem_;
|
||||||
std::vector<uint8_t> host_mem_;
|
std::vector<uint8_t> host_mem_;
|
||||||
size_t device_virtual_count_;
|
size_t device_virtual_count_{0};
|
||||||
size_t host_virtual_count_;
|
size_t host_virtual_count_{0};
|
||||||
std::map<void *, size_t> device_mem_size_;
|
std::map<void *, size_t> device_mem_size_;
|
||||||
std::map<void *, size_t> host_mem_size_;
|
std::map<void *, size_t> host_mem_size_;
|
||||||
};
|
};
|
||||||
|
@ -67,6 +78,47 @@ class MemHandlerImpl : public MemHandler {
|
||||||
class TestMemScheduler : public UT::Common {
|
class TestMemScheduler : public UT::Common {
|
||||||
public:
|
public:
|
||||||
TestMemScheduler() {}
|
TestMemScheduler() {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
size_t used_tensor_num_{1};
|
||||||
|
size_t total_step_{1};
|
||||||
|
std::vector<uint8_t> tensor_keys_;
|
||||||
|
std::vector<uint8_t> tensor_datas_;
|
||||||
|
std::vector<size_t> init_tensors_;
|
||||||
|
std::vector<std::vector<size_t>> step_used_tensors_;
|
||||||
|
|
||||||
|
void Record(const std::shared_ptr<MemScheduler> &scheduler) {
|
||||||
|
void *stream = nullptr;
|
||||||
|
for (auto index : init_tensors_) {
|
||||||
|
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < total_step_; ++i) {
|
||||||
|
auto &tensors = step_used_tensors_[i];
|
||||||
|
for (auto j : tensors) {
|
||||||
|
scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
|
||||||
|
}
|
||||||
|
scheduler->PostCompute(stream);
|
||||||
|
}
|
||||||
|
scheduler->set_need_record_event(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Run(const std::shared_ptr<MemScheduler> &scheduler) {
|
||||||
|
void *stream = nullptr;
|
||||||
|
scheduler->Reset();
|
||||||
|
scheduler->Update();
|
||||||
|
for (auto index : init_tensors_) {
|
||||||
|
scheduler->Init(tensor_keys_.data() + index, tensor_datas_.data() + index, 1, kMemPriorityHigh);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < total_step_; ++i) {
|
||||||
|
scheduler->PreCompute(stream);
|
||||||
|
auto &tensors = step_used_tensors_[i];
|
||||||
|
for (auto j : tensors) {
|
||||||
|
auto addr = scheduler->GetOrMalloc(tensor_keys_.data() + j, 1);
|
||||||
|
ASSERT_NE(addr, nullptr);
|
||||||
|
}
|
||||||
|
scheduler->PostCompute(stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Feature: MemSchedulerManager
|
/// Feature: MemSchedulerManager
|
||||||
|
@ -91,49 +143,93 @@ TEST_F(TestMemScheduler, test_mem_scheduler) {
|
||||||
ASSERT_NE(scheduler, nullptr);
|
ASSERT_NE(scheduler, nullptr);
|
||||||
auto need_record = scheduler->need_record_event();
|
auto need_record = scheduler->need_record_event();
|
||||||
ASSERT_EQ(need_record, true);
|
ASSERT_EQ(need_record, true);
|
||||||
auto optimized = scheduler->optimized();
|
|
||||||
ASSERT_EQ(optimized, false);
|
|
||||||
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
|
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
|
||||||
ASSERT_NE(mem_handler, nullptr);
|
ASSERT_NE(mem_handler, nullptr);
|
||||||
scheduler->SetMemHandler(mem_handler);
|
scheduler->SetMemHandler(mem_handler);
|
||||||
|
|
||||||
constexpr size_t kUsedTensors = 10;
|
// input data
|
||||||
constexpr size_t kTimeSlice = 7;
|
used_tensor_num_ = 10;
|
||||||
std::vector<uint8_t> tensor_keys(kUsedTensors, 0);
|
total_step_ = 8;
|
||||||
std::vector<uint8_t> tensor_datas(kUsedTensors, 0);
|
std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
|
||||||
|
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
|
||||||
std::vector<size_t> init_tensors = {0, 2, 4};
|
std::vector<size_t> init_tensors = {0, 2, 4};
|
||||||
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
|
// 8 step tensor usage
|
||||||
void *stream = nullptr;
|
//
|
||||||
scheduler->SetTotalStep(kTimeSlice);
|
// 0
|
||||||
// record
|
// 1 1-----------------1
|
||||||
for (auto index : init_tensors) {
|
// 2--------------2
|
||||||
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
// 3 3--------3
|
||||||
}
|
// 4-----4
|
||||||
for (size_t i = 0; i < kTimeSlice; ++i) {
|
// 5 5
|
||||||
auto &tensors = step_tensors[i];
|
// 6 6
|
||||||
for (auto j : tensors) {
|
// 7 7
|
||||||
scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
|
// 8 8
|
||||||
}
|
// 9 9
|
||||||
scheduler->PostCompute(stream);
|
std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
|
||||||
}
|
{4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
|
||||||
scheduler->set_need_record_event(false);
|
tensor_keys_.swap(tensor_keys);
|
||||||
|
tensor_datas_.swap(tensor_datas);
|
||||||
|
init_tensors_.swap(init_tensors);
|
||||||
|
step_used_tensors_.swap(step_used_tensors);
|
||||||
|
scheduler->SetTotalStep(total_step_);
|
||||||
|
|
||||||
|
// record
|
||||||
|
Record(scheduler);
|
||||||
// optimize
|
// optimize
|
||||||
scheduler->Optimize();
|
scheduler->Optimize();
|
||||||
|
|
||||||
// run
|
// run
|
||||||
scheduler->ResetCurrentStep();
|
Run(scheduler);
|
||||||
for (auto index : init_tensors) {
|
}
|
||||||
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
|
|
||||||
}
|
/// Feature: MemScheduler
|
||||||
for (size_t i = 0; i < kTimeSlice; ++i) {
|
/// Description: Test MemScheduler interface
|
||||||
scheduler->PreCompute(stream);
|
/// Expectation: MemScheduler GetOrMalloc return valid ptr
|
||||||
auto &tensors = step_tensors[i];
|
TEST_F(TestMemScheduler, test_manual_mem_scheduler) {
|
||||||
for (auto j : tensors) {
|
MemSchedulerManager mem_scheduler_manager;
|
||||||
auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
|
auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
|
||||||
ASSERT_NE(addr, nullptr);
|
ASSERT_NE(scheduler, nullptr);
|
||||||
}
|
auto need_record = scheduler->need_record_event();
|
||||||
scheduler->PostCompute(stream);
|
ASSERT_EQ(need_record, true);
|
||||||
|
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
|
||||||
|
ASSERT_NE(mem_handler, nullptr);
|
||||||
|
scheduler->SetMemHandler(mem_handler);
|
||||||
|
|
||||||
|
// input data
|
||||||
|
used_tensor_num_ = 10;
|
||||||
|
total_step_ = 8;
|
||||||
|
std::vector<uint8_t> tensor_keys(used_tensor_num_, 0);
|
||||||
|
std::vector<uint8_t> tensor_datas(used_tensor_num_, 0);
|
||||||
|
std::vector<size_t> init_tensors = {0, 2, 4};
|
||||||
|
std::vector<size_t> offload_tensor = {1, 2, 3};
|
||||||
|
// 8 step tensor usage
|
||||||
|
//
|
||||||
|
// 0
|
||||||
|
// 1 1-----------------1
|
||||||
|
// 2--------------2
|
||||||
|
// 3 3--------3
|
||||||
|
// 4-----4
|
||||||
|
// 5 5
|
||||||
|
// 6 6
|
||||||
|
// 7 7
|
||||||
|
// 8 8
|
||||||
|
// 9 9
|
||||||
|
std::vector<std::vector<size_t>> step_used_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6},
|
||||||
|
{4, 6, 7}, {3, 7, 8}, {2, 8, 9}, {1, 9}};
|
||||||
|
tensor_keys_.swap(tensor_keys);
|
||||||
|
tensor_datas_.swap(tensor_datas);
|
||||||
|
init_tensors_.swap(init_tensors);
|
||||||
|
step_used_tensors_.swap(step_used_tensors);
|
||||||
|
scheduler->SetTotalStep(total_step_);
|
||||||
|
|
||||||
|
// set offload key
|
||||||
|
for (auto index : offload_tensor) {
|
||||||
|
scheduler->SetOffload(tensor_keys_.data() + index);
|
||||||
}
|
}
|
||||||
|
// record
|
||||||
|
Record(scheduler);
|
||||||
|
// optimize
|
||||||
|
scheduler->Optimize();
|
||||||
|
// run
|
||||||
|
Run(scheduler);
|
||||||
}
|
}
|
||||||
} // namespace mindspore::device
|
} // namespace mindspore::device
|
Loading…
Reference in New Issue