add ut and st for mem scheduler

kswang 2021-10-09 12:20:31 +08:00
parent ac4852b8af
commit 6582744bf6
8 changed files with 313 additions and 42 deletions

View File

@@ -1456,7 +1456,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
MS_EXCEPTION_IF_NULL(mem_scheduler);
mem_scheduler->SetMemHandler(mem_manager_);
mem_scheduler->RecordMemUsage();
mem_scheduler->Reset();
InitGraphInputTensors(mem_scheduler, graph);
}
const auto &kernels = graph.execution_order();
@@ -1513,9 +1513,7 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock
}
LaunchKernelEvent(kernel_post_run_events, i);
}
if (mem_scheduler != nullptr) {
mem_scheduler->OptMemUsage();
}
return true;
}
@@ -1527,17 +1525,22 @@ void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
if (mem_scheduler->need_record_event()) {
(void)LaunchKernelMod(graph, true);
mem_scheduler->set_need_record_event(false);
}
float mem_used_factor = kMaxMemReuseFactor;
while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
mem_scheduler->SetMemUsedFactor(mem_used_factor);
mem_scheduler->OptMemUsage();
bool ret = LaunchKernelMod(graph, true);
if (ret) {
mem_scheduler->SetOptimized(true);
mem_scheduler->set_optimized(true);
} else {
mem_used_factor -= kRetryFactor;
}
}
if (!mem_scheduler->optimized()) {
MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " because of memory limit.";
}
}
}
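The retry loop above steps the reuse factor down until a mock launch fits into the available device memory. A standalone sketch of that schedule, using assumed constant values (the real kMaxMemReuseFactor, kMinMemReuseFactor and kRetryFactor are defined elsewhere in the runtime and may differ):

#include <cstdio>

int main() {
  const float kMaxMemReuseFactor = 0.9f;  // assumed value, for illustration only
  const float kMinMemReuseFactor = 0.5f;  // assumed value, for illustration only
  const float kRetryFactor = 0.1f;        // assumed value, for illustration only
  float mem_used_factor = kMaxMemReuseFactor;
  bool optimized = false;
  while (!optimized && mem_used_factor >= kMinMemReuseFactor) {
    // Stand-in for OptMemUsage() plus a mock LaunchKernelMod(graph, true):
    // pretend only plans generated with a factor below 0.75 fit in device memory.
    const bool ret = mem_used_factor < 0.75f;
    std::printf("factor %.1f -> %s\n", mem_used_factor, ret ? "fits" : "retry");
    if (ret) {
      optimized = true;
    } else {
      mem_used_factor -= kRetryFactor;
    }
  }
  return optimized ? 0 : 1;  // the failure path mirrors the MS_LOG_EXCEPTION above
}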

View File

@@ -69,6 +69,9 @@ void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority pr
}
return nullptr;
}
if (!has_compute_mem_events_) {
return nullptr;
}
auto iter = mem_result_.find(key);
if (iter != mem_result_.end()) {
auto ptr = iter->second;
@@ -80,7 +83,7 @@
}
bool MemScheduler::PreCompute(void *stream) {
if (need_record_event_) {
if (!has_compute_mem_events_) {
return true;
}
MS_EXCEPTION_IF_NULL(mem_handler_);
@@ -100,9 +103,6 @@ bool MemScheduler::PreCompute(void *stream) {
if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
mem_result_[event->key] = iter->second;
if (priority == kMemPriorityMedium) {
mem_handler_->SwapIn(host_ptr, iter->second, event->mem_size, stream);
}
continue;
}
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
@@ -136,7 +136,7 @@ bool MemScheduler::PreCompute(void *stream) {
mem_result_[event->key] = device_ptr;
if (!from_init) {
mem_handler_->FreeHost(host_ptr);
swap_host_ptr_.erase(event->key);
(void)swap_host_ptr_.erase(event->key);
}
}
}
@@ -144,7 +144,7 @@
}
bool MemScheduler::PostCompute(void *stream) {
if (need_record_event_) {
if (!has_compute_mem_events_) {
++compute_index_;
return true;
}
@@ -177,7 +177,7 @@ bool MemScheduler::PostCompute(void *stream) {
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
mem_handler_->FreeDevice(device_ptr);
mem_result_.erase(device_ptr);
(void)mem_result_.erase(device_ptr);
}
}
++compute_index_;
@@ -185,7 +185,6 @@
}
void MemScheduler::OptMemUsage() {
need_record_event_ = false;
if (optimized_) {
return;
}
@@ -195,7 +194,8 @@
GenEventSpan();
GenNoSwapEventSet();
}
GenEvents();
GenComputeMemEvents();
has_compute_mem_events_ = true;
}
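Note the division of labor after this change: OptMemUsage() no longer clears need_record_event_ (the caller does that through set_need_record_event(false) once the recording mock-run finishes), and it now sets has_compute_mem_events_ at the end, which is the flag GetOrMalloc(), PreCompute() and PostCompute() test instead of need_record_event_.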
void MemScheduler::CountMemUsage() {
@@ -210,11 +210,10 @@
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
size_t i = 0;
if (first_event->type == kInit && mem_events.size() > 1) {
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
i = 1;
cur_index = 1;
}
auto last_event = mem_events[mem_events.size() - 1];
for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
@@ -224,8 +223,8 @@
MS_LOG(ERROR) << "Error mem event index " << start_index;
}
}
for (; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
for (; cur_index < mem_events.size(); ++cur_index) {
auto &event = mem_events[cur_index];
MS_EXCEPTION_IF_NULL(event);
if (event->index < compute_index_) {
min_mem_used_[event->index] += first_event->mem_size;
@@ -248,8 +247,8 @@ void MemScheduler::CheckMemSize() {
if (mem_used_without_swap_ > available_mem_size) {
need_swap_ = true;
}
MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size:" << mem_used_without_swap_
<< "without swap, and needs at least " << min_mem_needed_ << " with swap.";
MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size: " << mem_used_without_swap_
<< " without swap, and needs at least " << min_mem_needed_ << " with swap.";
}
void MemScheduler::GenEventSpan() {
@@ -257,20 +256,19 @@
return;
}
for (auto &item : mem_events_) {
auto &mem_events = item.second;
if (mem_events.empty()) {
auto &tensor_events = item.second;
if (tensor_events.empty()) {
continue;
}
auto first_event = mem_events[0];
MS_EXCEPTION_IF_NULL(first_event);
size_t i = 0;
if (first_event->type == kInit && mem_events.size() > 1) {
first_event = mem_events[1];
i = 1;
auto first_event = tensor_events[0];
size_t cur_index = 0;
if (first_event != nullptr && first_event->type == kInit && tensor_events.size() > 1) {
first_event = tensor_events[1];
cur_index = 1;
}
size_t last_index = first_event->index;
for (; i < mem_events.size(); ++i) {
auto &event = mem_events[i];
for (; cur_index < tensor_events.size(); ++cur_index) {
auto &event = tensor_events[cur_index];
MS_EXCEPTION_IF_NULL(event);
auto span = event->index - last_index;
if (span > 1) {
@@ -303,12 +301,14 @@ void MemScheduler::GenNoSwapEventSet() {
cur_mem_used[i] -= event->mem_size;
}
} else {
no_swap_events_.emplace(event);
(void)no_swap_events_.emplace(event);
}
}
}
void MemScheduler::GenEvents() {
void MemScheduler::GenComputeMemEvents() {
pre_compute_events_.clear();
post_compute_events_.clear();
pre_compute_events_.resize(compute_index_);
post_compute_events_.resize(compute_index_);
for (auto &item : mem_events_) {
@@ -351,10 +351,10 @@ void MemScheduler::GenEvents() {
auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
swap_in_event->key = item.first;
swap_in_event->mem_size = first_event->mem_size;
pre_compute_events_[event->index].emplace_back(swap_in_event);
(void)pre_compute_events_[event->index].emplace_back(swap_in_event);
}
if (event->index < pre_compute_events_.size()) {
pre_compute_events_[event->index].emplace_back(event);
(void)pre_compute_events_[event->index].emplace_back(event);
}
pre_index = event->index;
}
@@ -366,7 +366,7 @@ void MemScheduler::GenEvents() {
auto free_event = std::make_shared<Event>(kFree, last_event->index);
free_event->key = item.first;
if (last_event->index < post_compute_events_.size()) {
post_compute_events_[last_event->index].emplace_back(free_event);
(void)post_compute_events_[last_event->index].emplace_back(free_event);
}
}
}

View File

@@ -35,7 +35,7 @@ class MemHandler {
virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};
enum MemPriority { kMemPriorityLow, kMemPriorityMedium, kMemPriorityHigh };
enum MemPriority { kMemPriorityLow, kMemPriorityHigh };
class MemScheduler {
enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
@@ -58,9 +58,11 @@ class MemScheduler {
bool need_record_event() const { return need_record_event_; }
void set_need_record_event(bool flag) { need_record_event_ = flag; }
bool optimized() const { return optimized_; }
void SetOptimized(bool flag) { optimized_ = flag; }
void set_optimized(bool flag) { optimized_ = flag; }
void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
@@ -68,7 +70,7 @@
void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
void RecordMemUsage() { compute_index_ = 0; }
void Reset() { compute_index_ = 0; }
bool PreCompute(void *stream);
@@ -88,7 +90,7 @@
private:
void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
void GenEvents();
void GenComputeMemEvents();
void CheckMemSize();
void CountMemUsage();
void GenEventSpan();
@@ -104,6 +106,7 @@
size_t compute_index_{0};
bool need_record_event_{true};
bool optimized_{false};
bool has_compute_mem_events_{false};
std::shared_ptr<MemHandler> mem_handler_{nullptr};
bool need_swap_{false};
std::multimap<size_t, std::shared_ptr<Event>> event_span_;
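The renamed accessors spell out the intended lifecycle: record usage events, optimize, then run. A minimal caller sketch distilled from the unit test at the end of this commit; DemoLifecycle and its parameters are illustrative only, while the member functions are the API shown above:

#include <memory>
#include "runtime/device/memory_scheduler.h"

using mindspore::device::kMemPriorityHigh;
using mindspore::device::MemHandler;
using mindspore::device::MemScheduler;

void DemoLifecycle(const std::shared_ptr<MemScheduler> &scheduler,
                   const std::shared_ptr<MemHandler> &handler, void *key, void *host_ptr,
                   void *stream) {
  scheduler->SetMemHandler(handler);
  // Phase 1: record. GetOrMalloc() returns nullptr while usage events are recorded.
  scheduler->Init(key, host_ptr, 1, kMemPriorityHigh);
  (void)scheduler->GetOrMalloc(key, 1);
  scheduler->PostCompute(stream);  // advances the compute step
  scheduler->set_need_record_event(false);
  // Phase 2: optimize. Turns the recorded events into a (possibly swapping) memory plan.
  scheduler->OptMemUsage();
  scheduler->set_optimized(true);
  // Phase 3: run. Reset() rewinds the compute step; memory is now actually served.
  scheduler->Reset();
  scheduler->Init(key, host_ptr, 1, kMemPriorityHigh);
  scheduler->PreCompute(stream);
  void *addr = scheduler->GetOrMalloc(key, 1);  // non-null from here on
  scheduler->PostCompute(stream);
  (void)addr;
}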

View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
from tests.st.tbe_networks.resnet import resnet50
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_resnet():
'''
Feature: MemScheduler
Description: Test MemScheduler
Expectation: Run resnet success
'''
os.environ['ENABLE_MEM_SCHEDULER'] = '1'
num_classes = 10
epoch = 8
batch_size = 1
net = resnet50(batch_size, num_classes)
lr = 0.1
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad,
net.get_parameters()), lr, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(
net_with_criterion, optimizer) # optimizer
train_network.set_train()
losses = []
for _ in range(0, epoch):
data = Tensor(np.ones([batch_size, 3, 224, 224]
).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
loss = train_network(data, label)
losses.append(loss)
assert losses[-1].asnumpy() < 1
os.environ['ENABLE_MEM_SCHEDULER'] = '0'
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet():
'''
Feature: MemScheduler
Description: Test MemScheduler
Expectation: Run lenet success
'''
os.environ['ENABLE_MEM_SCHEDULER'] = '1'
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
diff = np.abs(res.asnumpy()[0] - 2.3025851)
assert np.all(diff < 1.e-6)
os.environ['ENABLE_MEM_SCHEDULER'] = '0'
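Both ST cases gate the feature purely through the ENABLE_MEM_SCHEDULER environment variable, setting it to '1' before building the network and back to '0' at the end so that later cases fall back to the default allocator.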

View File

@@ -113,6 +113,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/runtime/hccl_adapter/all_to_all_v_calc_param.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_runtime.cc"
"../../../mindspore/ccsrc/runtime/device/memory_manager.cc"
"../../../mindspore/ccsrc/runtime/device/memory_scheduler.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc"
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
"../../../mindspore/ccsrc/runtime/device/bucket.cc"

View File

@@ -0,0 +1,139 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <map>
#include "common/common_test.h"
#include "runtime/device/memory_scheduler.h"
namespace mindspore::device {
constexpr size_t kDeviceMemSize = 1 * 1024 * 1024 * 1024;
constexpr size_t kMaxVirtualCount = 1 * 1024 * 1024;
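// Mock MemHandler for the tests below: it performs no real allocation, handing out
// unique addresses by indexing into two preallocated byte vectors and tracking the
// requested sizes, which is enough to drive the scheduler's malloc/free/swap paths.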
class MemHandlerImpl : public MemHandler {
public:
MemHandlerImpl() {
device_mem_.resize(kMaxVirtualCount, 0);
host_mem_.resize(kMaxVirtualCount, 1);
}
size_t GetAvailableMemSize() override { return kDeviceMemSize; }
void *MallocDevice(size_t mem_size) override {
auto ret = device_mem_.data() + device_virtual_count_;
++device_virtual_count_;
device_mem_size_.emplace(ret, mem_size);
return ret;
}
void FreeDevice(void *ptr) override {
auto iter = device_mem_size_.find(ptr);
if (iter != device_mem_size_.end()) {
device_mem_size_.erase(iter);
}
}
void *MallocHost(size_t mem_size) override {
auto ret = host_mem_.data() + host_virtual_count_;
++host_virtual_count_;
host_mem_size_.emplace(ret, mem_size);
return ret;
}
void FreeHost(void *ptr) override {
auto iter = host_mem_size_.find(ptr);
if (iter != host_mem_size_.end()) {
host_mem_size_.erase(iter);
}
}
void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {}
void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {}
private:
std::vector<uint8_t> device_mem_;
std::vector<uint8_t> host_mem_;
size_t device_virtual_count_{0};
size_t host_virtual_count_{0};
std::map<void *, size_t> device_mem_size_;
std::map<void *, size_t> host_mem_size_;
};
class TestMemScheduler : public UT::Common {
public:
TestMemScheduler() {}
};
/// Feature: MemSchedulerManager
/// Description: Test MemSchedulerManager GetOrCreateMemScheduler interface
/// Expectation: Create MemScheduler
TEST_F(TestMemScheduler, test_mem_scheduler_manager) {
MemSchedulerManager mem_scheduler_manager;
auto ret = mem_scheduler_manager.GetMemScheduler(0);
ASSERT_EQ(ret, nullptr);
ret = mem_scheduler_manager.GetOrCreateMemScheduler(0);
ASSERT_NE(ret, nullptr);
ret = mem_scheduler_manager.GetMemScheduler(0);
ASSERT_NE(ret, nullptr);
}
/// Feature: MemScheduler
/// Description: Test MemScheduler interface
/// Expectation: MemScheduler GetOrMalloc return valid ptr
TEST_F(TestMemScheduler, test_mem_scheduler) {
MemSchedulerManager mem_scheduler_manager;
auto scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(0);
ASSERT_NE(scheduler, nullptr);
auto need_record = scheduler->need_record_event();
ASSERT_EQ(need_record, true);
auto optimized = scheduler->optimized();
ASSERT_EQ(optimized, false);
std::shared_ptr<MemHandler> mem_handler = std::make_shared<MemHandlerImpl>();
ASSERT_NE(mem_handler, nullptr);
scheduler->SetMemHandler(mem_handler);
constexpr size_t kUsedTensors = 10;
constexpr size_t kTimeSlice = 7;
std::vector<uint8_t> tensor_keys(kUsedTensors, 0);
std::vector<uint8_t> tensor_datas(kUsedTensors, 0);
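// Access pattern for the mock run: init_tensors are registered up front with high
// priority, and step_tensors[i] lists the tensor keys touched at compute step i.
// Keys 2 and 4 are not reused until steps 5 and 6, leaving a window in which the
// scheduler may swap them out.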
std::vector<size_t> init_tensors = {0, 2, 4};
std::vector<std::vector<size_t>> step_tensors = {{0, 1}, {1, 2, 3}, {3, 4, 5}, {5, 6}, {6, 7}, {2, 7, 8}, {4, 8, 9}};
void *stream = nullptr;
// record
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < kTimeSlice; ++i) {
auto &tensors = step_tensors[i];
for (auto j : tensors) {
scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
}
scheduler->PostCompute(stream);
}
scheduler->set_need_record_event(false);
// optimize
scheduler->OptMemUsage();
scheduler->set_optimized(true);
// run
scheduler->Reset();
for (auto index : init_tensors) {
scheduler->Init(tensor_keys.data() + index, tensor_datas.data() + index, 1, kMemPriorityHigh);
}
for (size_t i = 0; i < kTimeSlice; ++i) {
scheduler->PreCompute(stream);
auto &tensors = step_tensors[i];
for (auto j : tensors) {
auto addr = scheduler->GetOrMalloc(tensor_keys.data() + j, 1);
ASSERT_NE(addr, nullptr);
}
scheduler->PostCompute(stream);
}
}
} // namespace mindspore::device