forked from mindspore-Ecosystem/mindspore
add utst for mem scheduler
parent ac4852b8af
commit 6582744bf6
@@ -1456,7 +1456,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
    mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    MS_EXCEPTION_IF_NULL(mem_scheduler);
    mem_scheduler->SetMemHandler(mem_manager_);
-    mem_scheduler->RecordMemUsage();
+    mem_scheduler->Reset();
    InitGraphInputTensors(mem_scheduler, graph);
  }
  const auto &kernels = graph.execution_order();
@@ -1513,9 +1513,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
      }
      LaunchKernelEvent(kernel_post_run_events, i);
    }
-    if (mem_scheduler != nullptr) {
-      mem_scheduler->OptMemUsage();
-    }
+
    return true;
  }

@@ -1527,17 +1525,22 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
    auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    if (mem_scheduler->need_record_event()) {
      (void)LaunchKernelMod(graph, true);
      mem_scheduler->set_need_record_event(false);
    }
+    float mem_used_factor = kMaxMemReuseFactor;
+    while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
+      mem_scheduler->SetMemUsedFactor(mem_used_factor);
      mem_scheduler->OptMemUsage();
      bool ret = LaunchKernelMod(graph, true);
      if (ret) {
-        mem_scheduler->SetOptimized(true);
+        mem_scheduler->set_optimized(true);
+      } else {
+        mem_used_factor -= kRetryFactor;
      }
    }
    if (!mem_scheduler->optimized()) {
      MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " for memory limit.";
    }
  }
}

@@ -69,6 +69,9 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
    }
    return nullptr;
  }
+  if (!has_compute_mem_events_) {
+    return nullptr;
+  }
  auto iter = mem_result_.find(key);
  if (iter != mem_result_.end()) {
    auto ptr = iter->second;
@@ -80,7 +83,7 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
}

bool MemScheduler::PreCompute(void *stream) {
  if (need_record_event_) {
    if (!has_compute_mem_events_) {
      return true;
    }
    MS_EXCEPTION_IF_NULL(mem_handler_);
@@ -100,9 +103,6 @@ bool MemScheduler::PreCompute(void *stream) {
      if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
        MS_EXCEPTION_IF_NULL(iter->second);
        mem_result_[event->key] = iter->second;
-        if (priority == kMemPriorityMedium) {
-          mem_handler_->SwapIn(host_ptr, iter->second, event->mem_size, stream);
-        }
        continue;
      }
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
@@ -136,7 +136,7 @@ bool MemScheduler::PreCompute(void *stream) {
      mem_result_[event->key] = device_ptr;
      if (!from_init) {
        mem_handler_->FreeHost(host_ptr);
-        swap_host_ptr_.erase(event->key);
+        (void)swap_host_ptr_.erase(event->key);
      }
    }
  }
@@ -144,7 +144,7 @@ bool MemScheduler::PreCompute(void *stream) {
}

bool MemScheduler::PostCompute(void *stream) {
  if (need_record_event_) {
    if (!has_compute_mem_events_) {
      ++compute_index_;
      return true;
    }
@@ -177,7 +177,7 @@ bool MemScheduler::PostCompute(void *stream) {
      MS_EXCEPTION_IF_NULL(host_ptr);
      mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
      mem_handler_->FreeDevice(device_ptr);
-      mem_result_.erase(device_ptr);
+      (void)mem_result_.erase(device_ptr);
    }
  }
  ++compute_index_;
@@ -185,7 +185,6 @@ bool MemScheduler::PostCompute(void *stream) {
}

void MemScheduler::OptMemUsage() {
-  need_record_event_ = false;
  if (optimized_) {
    return;
  }
@@ -195,7 +194,8 @@ void MemScheduler::OptMemUsage() {
    GenEventSpan();
    GenNoSwapEventSet();
  }
-  GenEvents();
+  GenComputeMemEvents();
+  has_compute_mem_events_ = true;
}

void MemScheduler::CountMemUsage() {
@@ -210,11 +210,10 @@ void MemScheduler::CountMemUsage() {
      continue;
    }
    auto first_event = mem_events[0];
-    MS_EXCEPTION_IF_NULL(first_event);
-    size_t i = 0;
-    if (first_event->type == kInit && mem_events.size() > 1) {
+    size_t cur_index = 0;
+    if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
      first_event = mem_events[1];
-      i = 1;
+      cur_index = 1;
    }
    auto last_event = mem_events[mem_events.size() - 1];
    for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
@@ -224,8 +223,8 @@ void MemScheduler::CountMemUsage() {
        MS_LOG(ERROR) << "Error mem event index " << start_index;
      }
    }
-    for (; i < mem_events.size(); ++i) {
-      auto &event = mem_events[i];
+    for (; cur_index < mem_events.size(); ++cur_index) {
+      auto &event = mem_events[cur_index];
      MS_EXCEPTION_IF_NULL(event);
      if (event->index < compute_index_) {
        min_mem_used_[event->index] += first_event->mem_size;
@@ -248,8 +247,8 @@ void MemScheduler::CheckMemSize() {
  if (mem_used_without_swap_ > available_mem_size) {
    need_swap_ = true;
  }
-  MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size:" << mem_used_without_swap_
-               << "without swap, and needs at least " << min_mem_needed_ << " with swap.";
+  MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size: " << mem_used_without_swap_
+               << " without swap, and needs at least " << min_mem_needed_ << " with swap.";
}

void MemScheduler::GenEventSpan() {
@@ -257,20 +256,19 @@ void MemScheduler::GenEventSpan() {
    return;
  }
  for (auto &item : mem_events_) {
-    auto &mem_events = item.second;
-    if (mem_events.empty()) {
+    auto &tensor_events = item.second;
+    if (tensor_events.empty()) {
      continue;
    }
-    auto first_event = mem_events[0];
-    MS_EXCEPTION_IF_NULL(first_event);
-    size_t i = 0;
-    if (first_event->type == kInit && mem_events.size() > 1) {
-      first_event = mem_events[1];
-      i = 1;
+    auto first_event = tensor_events[0];
+    size_t cur_index = 0;
+    if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
+      first_event = tensor_events[1];
+      cur_index = 1;
    }
    size_t last_index = first_event->index;
-    for (; i < mem_events.size(); ++i) {
-      auto &event = mem_events[i];
+    for (; cur_index < tensor_events.size(); ++cur_index) {
+      auto &event = tensor_events[cur_index];
      MS_EXCEPTION_IF_NULL(event);
      auto span = event->index - last_index;
      if (span > 1) {
@@ -303,12 +301,14 @@ void MemScheduler::GenNoSwapEventSet() {
        cur_mem_used[i] -= event->mem_size;
      }
    } else {
-      no_swap_events_.emplace(event);
+      (void)no_swap_events_.emplace(event);
    }
  }
}

- void MemScheduler::GenEvents() {
+ void MemScheduler::GenComputeMemEvents() {
  pre_compute_events_.clear();
  post_compute_events_.clear();
  pre_compute_events_.resize(compute_index_);
  post_compute_events_.resize(compute_index_);
  for (auto &item : mem_events_) {
@@ -351,10 +351,10 @@ void MemScheduler::GenEvents() {
      auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
      swap_in_event->key = item.first;
      swap_in_event->mem_size = first_event->mem_size;
-      pre_compute_events_[event->index].emplace_back(swap_in_event);
+      (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
    }
    if (event->index < pre_compute_events_.size()) {
-      pre_compute_events_[event->index].emplace_back(event);
+      (void)pre_compute_events_[event->index].emplace_back(event);
    }
    pre_index = event->index;
  }
@@ -366,7 +366,7 @@ void MemScheduler::GenEvents() {
      auto free_event = std::make_shared<Event>(kFree, last_event->index);
      free_event->key = item.first;
      if (last_event->index < post_compute_events_.size()) {
-        post_compute_events_[last_event->index].emplace_back(free_event);
+        (void)post_compute_events_[last_event->index].emplace_back(free_event);
      }
    }
  }
@@ -35,7 +35,7 @@ class MemHandler {
  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};

- enum MemPriority { kMemPriorityLow, kMemPriorityMedium, kMemPriorityHigh };
+ enum MemPriority { kMemPriorityLow, kMemPriorityHigh };

class MemScheduler {
  enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
@@ -58,9 +58,11 @@ class MemScheduler {

  bool need_record_event() const { return need_record_event_; }

  void set_need_record_event(bool flag) { need_record_event_ = flag; }

  bool optimized() const { return optimized_; }

-  void SetOptimized(bool flag) { optimized_ = flag; }
+  void set_optimized(bool flag) { optimized_ = flag; }

  void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
@@ -68,7 +70,7 @@ class MemScheduler {

  void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);

-  void RecordMemUsage() { compute_index_ = 0; }
+  void Reset() { compute_index_ = 0; }

  bool PreCompute(void *stream);
@@ -88,7 +90,7 @@ class MemScheduler {

 private:
  void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
-  void GenEvents();
+  void GenComputeMemEvents();
  void CheckMemSize();
  void CountMemUsage();
  void GenEventSpan();
@@ -104,6 +106,7 @@ class MemScheduler {
  size_t compute_index_{0};
  bool need_record_event_{true};
  bool optimized_{false};
+  bool has_compute_mem_events_{false};
  std::shared_ptr<MemHandler> mem_handler_{nullptr};
  bool need_swap_{false};
  std::multimap<size_t, std::shared_ptr<Event>> event_span_;
@@ -1,4 +1,4 @@
- # Copyright 2020 Huawei Technologies Co., Ltd
+ # Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from tests.st.tbe_networks.resnet import resnet50

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32

        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resnet():
    '''
    Feature: MemScheduler
    Description: Test MemScheduler
    Expectation: Run resnet success
    '''
    os.environ['ENABLE_MEM_SCHEDULER'] = '1'
    num_classes = 10
    epoch = 8
    batch_size = 1
    net = resnet50(batch_size, num_classes)
    lr = 0.1
    momentum = 0.9
    optimizer = Momentum(filter(lambda x: x.requires_grad,
                                net.get_parameters()), lr, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(
        net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    losses = []
    for _ in range(0, epoch):
        data = Tensor(np.ones([batch_size, 3, 224, 224]
                              ).astype(np.float32) * 0.01)
        label = Tensor(np.ones([batch_size]).astype(np.int32))
        loss = train_network(data, label)
        losses.append(loss)
    assert losses[-1].asnumpy() < 1
    os.environ['ENABLE_MEM_SCHEDULER'] = '0'


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet():
    '''
    Feature: MemScheduler
    Description: Test MemScheduler
    Expectation: Run lenet success
    '''
    os.environ['ENABLE_MEM_SCHEDULER'] = '1'
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    learning_rate = 0.01
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    res = train_network(data, label)
    diff = res.asnumpy()[0] - 2.3025851
    assert np.all(diff < 1.e-6)
    os.environ['ENABLE_MEM_SCHEDULER'] = '0'
@@ -113,6 +113,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
    "../../../mindspore/ccsrc/runtime/hccl_adapter/all_to_all_v_calc_param.cc"
    "../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc"
    "../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
+    "../../../mindspore/ccsrc/runtime/device/memory_scheduler.cc"
    "../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
    "../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
    "../../../mindspore/ccsrc/runtime/device/bucket.cc"
@@ -0,0 +1,139 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <map>
#include "common/common_test.h"
#include "runtime/device/memory_scheduler.h"
namespace mindspore::device {
constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024;
constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024;
class MemHandlerImpl : public MemHandler {
 public:
  MemHandlerImpl() {
    device_mem_.resize(kMaxVirtualCount, 0);
    host_mem_.resize(kMaxVirtualCount, 1);
  }
  size_t GetAvailableMemSize() override { return kDeviceMemSize; }
  void *MallocDevice(size_t mem_size) override {
    auto ret = device_mem_.data() + device_virtual_count_;
    ++device_virtual_count_;
    device_mem_size_.emplace(ret, mem_size);
    return ret;
  }
  void FreeDevice(void *ptr) override {
    auto iter = device_mem_size_.find(ptr);
    if (iter != device_mem_size_.end()) {
      device_mem_size_.erase(iter);
    }
  }
  void *MallocHost(size_t mem_size) override {
    auto ret = host_mem_.data() + host_virtual_count_;
    ++host_virtual_count_;
    host_mem_size_.emplace(ret, mem_size);
    return ret;
  }
  void FreeHost(void *ptr) override {
    auto iter = host_mem_size_.find(ptr);
    if (iter != host_mem_size_.end()) {
      host_mem_size_.erase(iter);
    }
  }
  void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
  void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}

 private:
  std::vector<uint8_t> device_mem_;
  std::vector<uint8_t> host_mem_;
  size_t device_virtual_count_;
  size_t host_virtual_count_;
  std::map<void *, size_t> device_mem_size_;
  std::map<void *, size_t> host_mem_size_;
};

class TestMemScheduler : public UT::Common {
 public:
  TestMemScheduler() {}
};

/// Feature: MemSchedulerManager
/// Description: Test MemSchedulerManager GetOrCreateMemScheduler interface
/// Expectation: Create MemScheduler
TEST_F(TestMemScheduler, test_mem_scheduler_manager) {
  MemSchedulerManager mem_scheduler_manager;
  auto ret = mem_scheduler_manager.GetMemScheduler(0);
  ASSERT_EQ(ret, nullptr);
  ret = mem_scheduler_manager.GetOrCreateMemScheduler(0);
  ASSERT_NE(ret, nullptr);
  ret = mem_scheduler_manager.GetMemScheduler(0);
  ASSERT_NE(ret, nullptr);
}

/// Feature: MemScheduler
/// Description: Test MemScheduler interface
/// Expectation: MemScheduler GetOrMalloc return valid ptr
TEST_F(TestMemScheduler, test_mem_scheduler) {
  MemSchedulerManager mem_scheduler_manager;
  auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
  ASSERT_NE(scheduler, nullptr);
  auto need_record = scheduler->need_record_event();
  ASSERT_EQ(need_record, true);
  auto optimized = scheduler->optimized();
  ASSERT_EQ(optimized, false);
  std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
  ASSERT_NE(mem_handler, nullptr);
  scheduler->SetMemHandler(mem_handler);

  constexpr size_t kUsedTensors = 10;
  constexpr size_t kTimeSlice = 7;
  std::vector<uint8_t> tensor_keys(kUsedTensors, 0);
  std::vector<uint8_t> tensor_datas(kUsedTensors, 0);
  std::vector<size_t> init_tensors = {0, 2, 4};
  std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
  void *stream = nullptr;
  // record
  for (auto index : init_tensors) {
    scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
  }
  for (size_t i = 0; i < kTimeSlice; ++i) {
    auto &tensors = step_tensors[i];
    for (auto j : tensors) {
      scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
    }
    scheduler->PostCompute(stream);
  }
  scheduler->set_need_record_event(false);

  // optimize
  scheduler->OptMemUsage();
  scheduler->set_optimized(true);

  // run
  scheduler->Reset();
  for (auto index : init_tensors) {
    scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
  }
  for (size_t i = 0; i < kTimeSlice; ++i) {
    scheduler->PreCompute(stream);
    auto &tensors = step_tensors[i];
    for (auto j : tensors) {
      auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
      ASSERT_NE(addr, nullptr);
    }
    scheduler->PostCompute(stream);
  }
}
}  // namespace mindspore::device