forked from mindspore-Ecosystem/mindspore
add swap strategy builder
parent f65bb0d92d
commit 556b5446fa
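For orientation (not part of the change itself): the new `SwapStrategyBuilder` is exercised by the unit test at the end of this diff, and its entry point is `Build(graph, context)`. A minimal usage sketch under the same assumptions as that test; `kernel_graph` stands for an already compiled `session::KernelGraph`, and the byte budgets are illustrative only:

```cpp
#include <memory>
#include "backend/common/session/kernel_graph.h"
#include "runtime/device/gsm/swap_strategy_builder.h"

// Sketch only: mirrors TestSwapStrategyBuilder in the test file added below.
std::shared_ptr<mindspore::device::SwapStrategy> BuildSwapStrategy(
    const std::shared_ptr<mindspore::session::KernelGraph> &kernel_graph) {
  auto builder = std::make_shared<mindspore::device::SwapStrategyBuilder>();
  auto context = std::make_shared<mindspore::device::SwapContext>();
  context->hbm_mem_size_ = 10000;  // device HBM (level 0) budget, bytes
  context->ddr_mem_size_ = 10000;  // host DDR (level 1) budget, bytes
  return builder->Build(kernel_graph, context);
}
```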
@@ -18,6 +18,10 @@
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
namespace mindspore {
namespace {
auto constexpr kInplaceNodeTypeSkip = "skip";
auto constexpr kInplaceNodeTypeAlgo = "inplace_algo";
} // namespace
namespace device {
size_t MemUsageAnalyzer::AddTensorInfo(const AnfNodePtr &node, size_t index, bool is_workspace) {
  auto add_to_container = [this](const AnfNodePtr &node, size_t index,
@@ -149,8 +153,10 @@ void MemUsageAnalyzer::AddKernelAndTensorInfo(const KernelGraphPtr &graph) {
  auto real_kernel_num = exec_order.size();
  kernel_infos_.resize(real_kernel_num);

  auto add_tensor_usage = [this](size_t tensor_id, size_t kernel_id, size_t *kernel_mem) {
  auto add_tensor_usage = [this](size_t tensor_id, size_t kernel_id, size_t *kernel_mem, bool inplace) {
    auto tensor_info = GetMemUsageTensorInfo(tensor_id);
    MS_EXCEPTION_IF_NULL(tensor_info);
    tensor_info->is_inplace_tensor_ = inplace;
    (void)tensor_info->used_by_kernels_.emplace_back(kernel_id);
    *kernel_mem += tensor_info->tensor_size_;
  };

@@ -162,6 +168,8 @@ void MemUsageAnalyzer::AddKernelAndTensorInfo(const KernelGraphPtr &graph) {
    auto kernel_info = std::make_shared<MemUsageKernelInfo>();
    kernel_info->is_comm_ = common::AnfAlgo::IsCommunicationOp(node);
    kernel_info->update_input_ = common::AnfAlgo::IsUpdateParameterKernel(node);
    bool inplace_node = common::AnfAlgo::IsInplaceNode(node, kInplaceNodeTypeSkip) ||
                        common::AnfAlgo::IsInplaceNode(node, kInplaceNodeTypeAlgo);

    // Memory used by this kernel
    size_t kernel_mem = 0;

@@ -169,18 +177,30 @@ void MemUsageAnalyzer::AddKernelAndTensorInfo(const KernelGraphPtr &graph) {
    // Add input tensors
    const auto input_num = kernel_mod->GetInputSizeList().size();
    for (size_t index = 0; index < input_num; ++index) {
      const auto &prev_node_output = common::AnfAlgo::GetPrevNodeOutput(node, index, true);
      auto prev_node_output = common::AnfAlgo::GetPrevNodeOutput(node, index, true);
      if (graph->IsInRefOutputMap(prev_node_output)) {
        prev_node_output = graph->GetRefCorrespondOutput(prev_node_output);
      }
      auto tensor_id = AddTensorInfo(prev_node_output.first, prev_node_output.second);
      (void)kernel_info->input_tensors_.emplace_back(tensor_id);
      add_tensor_usage(tensor_id, i, &kernel_mem);
      add_tensor_usage(tensor_id, i, &kernel_mem, false);
    }

    // Add output tensors
    const auto output_num = kernel_mod->GetOutputSizeList().size();
    for (size_t index = 0; index < output_num; ++index) {
      auto tensor_id = AddTensorInfo(node, index);
      (void)kernel_info->output_tensors_.emplace_back(tensor_id);
      add_tensor_usage(tensor_id, i, &kernel_mem);
      if (graph->IsInRefOutputMap({node, index})) {
        auto real_node_pair = graph->GetRefCorrespondOutput({node, index});
        if (real_node_pair.first != node) {
          auto tensor_id = AddTensorInfo(real_node_pair.first, real_node_pair.second);
          (void)kernel_info->input_tensors_.emplace_back(tensor_id);
          add_tensor_usage(tensor_id, i, &kernel_mem, inplace_node);
        }
      } else {
        auto tensor_id = AddTensorInfo(node, index);
        (void)kernel_info->output_tensors_.emplace_back(tensor_id);
        add_tensor_usage(tensor_id, i, &kernel_mem, inplace_node);
      }
    }

    // Add workspace tensors

@@ -188,7 +208,7 @@ void MemUsageAnalyzer::AddKernelAndTensorInfo(const KernelGraphPtr &graph) {
    for (size_t index = 0; index < workspace_num; ++index) {
      auto tensor_id = AddTensorInfo(node, index, true);
      (void)kernel_info->workspace_tensors_.emplace_back(tensor_id);
      add_tensor_usage(tensor_id, i, &kernel_mem);
      add_tensor_usage(tensor_id, i, &kernel_mem, false);
    }

    if (kernel_mem > least_mem_) {
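Aside (not part of the diff): the new `inplace` argument threaded through `add_tensor_usage` marks tensors touched by inplace/ref kernels; later in this commit, `SwapStrategyBuilder::BuildSpans` skips span creation for such tensors and simply charges their size at every kernel that uses them, so they are never selected for offload. A standalone sketch of that accounting, with hypothetical names:

```cpp
#include <cstddef>
#include <vector>

// Charge an inplace tensor's size at every kernel that touches it; unlike
// ordinary tensors, no offloadable spans are recorded between the uses.
void ChargeInplaceTensor(std::vector<size_t> *mem_used_per_kernel,
                         const std::vector<size_t> &used_by_kernels, size_t tensor_size) {
  for (size_t kernel_id : used_by_kernels) {
    (*mem_used_per_kernel)[kernel_id] += tensor_size;
  }
}
```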
@@ -31,6 +31,7 @@ struct MemUsageTensorInfo
  bool is_workspace_{false};
  bool is_graph_output_{false};
  bool is_graph_input_{false};
  bool is_inplace_tensor_{false};
  std::vector<size_t> used_by_kernels_;
  std::vector<size_t> fused_tensor_ids_;
};

@@ -81,6 +82,17 @@ struct SwapStrategy
  std::vector<std::shared_ptr<MemUsageTensorInfo>> tensor_infos_;
  std::vector<std::shared_ptr<MemUsageKernelInfo>> kernel_infos_;
};

class SwapContext {
 public:
  size_t hbm_mem_size_{0};
  size_t ddr_mem_size_{0};
  size_t disk_mem_size_{0};
  bool offload_param_to_ddr_{false};
  bool offload_param_to_disk_{false};
  bool offload_checkpoint_to_ddr_{false};
  bool offload_checkpoint_to_disk_{false};
};
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_STRATEGY_H_
@@ -0,0 +1,403 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "runtime/device/gsm/swap_strategy_builder.h"
#include <memory>
#include <queue>
#include <set>
#include <functional>
#include "runtime/device/gsm/swap_strategy.h"
#include "runtime/device/gsm/mem_usage_analyzer.h"
#include "backend/common/session/kernel_graph.h"

namespace mindspore {
namespace device {
namespace {
template <typename T>
void CheckVectorIndex(const std::vector<T> &input, size_t index) {
  if (input.size() <= index) {
    MS_LOG_EXCEPTION << "Invalid vector index " << index << ", vector size is " << input.size();
  }
}
} // namespace
const size_t kSwapVirtualNodeNum = 2; // Mark graph start and end node as virtual node
void SwapStrategyBuilder::ResetState(const KernelGraphPtr &graph, const std::shared_ptr<SwapContext> &context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(context);

  context_ = context;
  total_mem_level0_ = context->hbm_mem_size_;
  total_mem_level1_ = context->ddr_mem_size_;

  kernel_num_ = graph->execution_order().size();

  mem_used_level0_.clear();
  mem_used_level0_.resize(kernel_num_, 0);
  mem_used_level1_.clear();
  mem_used_level1_.resize(kernel_num_, 0);

  span_level1_.clear();
  span_level2_.clear();
  auto tmp_queue = std::priority_queue<std::shared_ptr<Span>, std::vector<std::shared_ptr<Span>>, SpanCmp>();
  span_queue_.swap(tmp_queue);

  analyzer_ = std::make_shared<MemUsageAnalyzer>();

  kernel_actions_.clear();
  kernel_actions_.resize(kernel_num_ + kSwapVirtualNodeNum);
}

void SwapStrategyBuilder::AnalyzeGraph(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(analyzer_);
  analyzer_->Analyze(graph);
  if (analyzer_->LeastMemNeeded() > total_mem_level0_) {
    MS_LOG(EXCEPTION) << "Need " << analyzer_->LeastMemNeeded() << " at least, but total mem is " << total_mem_level0_;
  }
}

void SwapStrategyBuilder::RecordSpan(const std::shared_ptr<MemUsageTensorInfo> &info, size_t last_index,
                                     size_t current_index, bool output_span) {
  auto dist = current_index - last_index;
  if (dist <= 1) {
    return;
  }

  MS_EXCEPTION_IF_NULL(info);
  MS_EXCEPTION_IF_NULL(context_);

  auto span = std::make_shared<Span>();
  span->tensor_id_ = info->tensor_id_;
  span->tensor_size_ = info->tensor_size_;
  span->last_index_ = last_index;
  span->current_index_ = current_index;
  span->weight_ = (dist - 1) * info->tensor_size_;
  span->output_span_ = output_span;

  bool offload_param = context_->offload_param_to_ddr_ || context_->offload_param_to_disk_;
  bool offload_checkpoint = context_->offload_checkpoint_to_ddr_ || context_->offload_checkpoint_to_disk_;
  if (offload_param && info->node_ != nullptr && !info->node_->isa<CNode>()) {
    (void)offload_param_spans_.emplace_back(span);
  } else if (offload_checkpoint && info->node_ != nullptr && info->node_->isa<CNode>()) {
    auto cnode = info->node_->cast<CNodePtr>();
    if (cnode != nullptr && cnode->HasAttr("checkpoint")) {
      (void)offload_checkpoint_spans_.emplace_back(span);
    } else {
      span_queue_.emplace(span);
    }
  } else {
    span_queue_.emplace(span);
  }
}

void SwapStrategyBuilder::BuildSpans() {
  MS_EXCEPTION_IF_NULL(analyzer_);
  auto &tensor_infos = analyzer_->GetMemUsageTensorInfos();
  for (auto info : tensor_infos) {
    MS_EXCEPTION_IF_NULL(info);
    if (info->tensor_id_ != info->real_tensor_id_) {
      continue;
    }

    auto &used_by_kernels = info->used_by_kernels_;
    if (used_by_kernels.empty()) {
      continue;
    }

    if (info->is_inplace_tensor_) {
      for (size_t i = 0; i < used_by_kernels.size(); ++i) {
        size_t current_index = used_by_kernels[i];
        mem_used_level0_[current_index] += info->tensor_size_;
      }
      continue;
    }

    size_t last_index = used_by_kernels[0];
    CheckVectorIndex(mem_used_level0_, last_index);
    mem_used_level0_[last_index] += info->tensor_size_;

    for (size_t i = 1; i < used_by_kernels.size(); ++i) {
      size_t current_index = used_by_kernels[i];
      CheckVectorIndex(mem_used_level0_, current_index);
      mem_used_level0_[current_index] += info->tensor_size_;
      RecordSpan(info, last_index, current_index);
      last_index = current_index;
    }

    if (info->is_graph_output_) {
      RecordSpan(info, last_index, kernel_num_, true);
    }

    if (info->is_graph_input_) {
      RecordSpan(info, last_index, used_by_kernels[0] + kernel_num_);
    }
  }
}

bool SwapStrategyBuilder::EnoughSpaceForSpan(const std::shared_ptr<Span> &span, std::vector<size_t> *mem_used,
                                             size_t total_mem_size) {
  MS_EXCEPTION_IF_NULL(span);
  MS_EXCEPTION_IF_NULL(mem_used);
  CheckVectorIndex(*mem_used, kernel_num_ - 1);
  for (size_t index = span->last_index_ + 1; index < span->current_index_; ++index) {
    (*mem_used)[index % kernel_num_] += span->tensor_size_;
    if ((*mem_used)[index % kernel_num_] > total_mem_size) {
      for (size_t r_index = span->last_index_ + 1; r_index <= index; ++r_index) {
        (*mem_used)[r_index % kernel_num_] -= span->tensor_size_;
      }
      return false;
    }
  }
  return true;
}

void SwapStrategyBuilder::ClassifyOffloadSpanLevel(const std::vector<std::shared_ptr<Span>> &spans,
                                                   bool offload_to_ddr) {
  for (auto const &span : spans) {
    bool offload_to_mem_level1 = false;
    if (offload_to_ddr) {
      offload_to_mem_level1 = EnoughSpaceForSpan(span, &mem_used_level1_, total_mem_level1_);
    }

    if (offload_to_mem_level1) {
      (void)span_level1_.emplace_back(span);
    } else {
      (void)span_level2_.emplace_back(span);
    }
  }
}

void SwapStrategyBuilder::ClassifySpanLevel() {
  MS_EXCEPTION_IF_NULL(context_);
  ClassifyOffloadSpanLevel(offload_param_spans_, context_->offload_param_to_ddr_);
  offload_param_spans_.clear();
  ClassifyOffloadSpanLevel(offload_checkpoint_spans_, context_->offload_checkpoint_to_ddr_);
  offload_checkpoint_spans_.clear();

  while (!span_queue_.empty()) {
    auto span = span_queue_.top();
    bool enough = EnoughSpaceForSpan(span, &mem_used_level0_, total_mem_level0_);
    if (!enough) {
      enough = EnoughSpaceForSpan(span, &mem_used_level1_, total_mem_level1_);
      if (enough) {
        (void)span_level1_.emplace_back(span);
      } else {
        (void)span_level2_.emplace_back(span);
      }
    }
    span_queue_.pop();
  }
}

void SwapStrategyBuilder::AddTensorAction(SwapActionType action_type, size_t tensor_id, size_t kernel_id) {
  MS_EXCEPTION_IF_NULL(analyzer_);
  auto action = std::make_shared<TensorAction>();
  action->action_ = action_type;
  action->tensor_id_ = tensor_id;

  if (kernel_id > 0 && (action_type == SwapActionType::kHBM2DDR || action_type == SwapActionType::kHBM2DISK)) {
    auto kernel_info = analyzer_->GetMemUsageKernelInfo(kernel_id - 1);
    MS_EXCEPTION_IF_NULL(kernel_info);
    if (!kernel_info->update_input_) {
      action->avoid_copy_ = true;
    }
  }

  CheckVectorIndex(kernel_actions_, kernel_id);
  (void)kernel_actions_[kernel_id].emplace_back(action);
}

void SwapStrategyBuilder::AddFusedTensorSpan(const std::shared_ptr<MemUsageTensorInfo> &info, size_t start_index,
                                             size_t current_kernel_id) {
  MS_EXCEPTION_IF_NULL(analyzer_);
  MS_EXCEPTION_IF_NULL(info);
  for (auto sub_tensor_id : info->fused_tensor_ids_) {
    auto sub_tensor_info = analyzer_->GetMemUsageTensorInfo(sub_tensor_id);
    std::vector<size_t> used_before;
    std::vector<size_t> used_after;
    for (auto kid : sub_tensor_info->used_by_kernels_) {
      if (kid < start_index) {
        (void)used_before.emplace_back(kid);
      }

      if (kid >= current_kernel_id) {
        (void)used_after.emplace_back(kid);
      }
    }
    if (!used_before.empty()) {
      size_t last = used_before.size() - 1;
      CheckVectorIndex(mem_used_level0_, used_before[last]);
      mem_used_level0_[used_before[last]] += sub_tensor_info->tensor_size_;
      for (size_t i = 0; i < last; ++i) {
        CheckVectorIndex(mem_used_level0_, used_before[i]);
        mem_used_level0_[used_before[i]] += sub_tensor_info->tensor_size_;
        RecordSpan(sub_tensor_info, used_before[i], used_before[i + 1]);
      }

      auto span = std::make_shared<Span>();
      span->tensor_id_ = sub_tensor_info->tensor_id_;
      span->tensor_size_ = sub_tensor_info->tensor_size_;
      span->last_index_ = used_before[last];
      span->current_index_ = start_index;
      bool enough_space = EnoughSpaceForSpan(span, &mem_used_level1_, total_mem_level1_);
      if (enough_space) {
        AddTensorAction(SwapActionType::kHBM2DDR, span->tensor_id_, span->last_index_ + 1);
        AddTensorAction(SwapActionType::kDDR2HBM, span->tensor_id_, span->current_index_);
      } else {
        AddTensorAction(SwapActionType::kHBM2DISK, span->tensor_id_, span->last_index_ + 1);
        AddTensorAction(SwapActionType::kDISK2HBM, span->tensor_id_, span->current_index_);
      }
    }

    if (!used_after.empty()) {
      for (size_t i = 1; i < used_after.size(); ++i) {
        CheckVectorIndex(mem_used_level0_, used_after[i]);
        mem_used_level0_[used_after[i]] += sub_tensor_info->tensor_size_;
        RecordSpan(sub_tensor_info, used_after[i - 1], used_after[i]);
      }

      if (sub_tensor_info->is_graph_output_) {
        RecordSpan(sub_tensor_info, used_after[used_after.size() - 1], kernel_num_, true);
      }
    }
  }
}

void SwapStrategyBuilder::HandleFusedTensor() {
  MS_EXCEPTION_IF_NULL(analyzer_);
  auto &tensor_infos = analyzer_->GetMemUsageTensorInfos();
  for (const auto &info : tensor_infos) {
    MS_EXCEPTION_IF_NULL(info);
    if (info->node_ != nullptr) {
      continue;
    }

    auto &used_by_kernels = info->used_by_kernels_;
    if (used_by_kernels.empty()) {
      continue;
    }

    size_t current_kernel_id = used_by_kernels[0];
    std::set<size_t, std::greater<size_t>> reference_kernels_before;
    for (auto sub_tensor_id : info->fused_tensor_ids_) {
      auto sub_tensor_info = analyzer_->GetMemUsageTensorInfo(sub_tensor_id);
      for (auto kid : sub_tensor_info->used_by_kernels_) {
        if (kid >= current_kernel_id) {
          continue;
        }
        auto iter = reference_kernels_before.find(kid);
        if (iter == reference_kernels_before.end()) {
          (void)reference_kernels_before.insert(kid);
        }
      }
    }

    size_t last_index = current_kernel_id;
    size_t start_index = current_kernel_id;
    for (const auto &kid : reference_kernels_before) {
      auto span = std::make_shared<Span>();
      span->tensor_id_ = info->tensor_id_;
      span->tensor_size_ = info->tensor_size_;
      span->last_index_ = kid - 1;
      span->current_index_ = last_index;
      bool enough_space = EnoughSpaceForSpan(span, &mem_used_level0_, total_mem_level0_);
      if (!enough_space) {
        start_index = last_index;
        break;
      }
      last_index = kid;
      start_index = last_index;
    }

    AddTensorAction(SwapActionType::kAllocHBM, info->tensor_id_, start_index);

    AddFusedTensorSpan(info, start_index, current_kernel_id);
  }
}

void SwapStrategyBuilder::SpanToTensorAction() {
  for (auto span : span_level1_) {
    MS_EXCEPTION_IF_NULL(span);
    AddTensorAction(SwapActionType::kHBM2DDR, span->tensor_id_, span->last_index_ + 1);
    if (!span->output_span_) {
      AddTensorAction(SwapActionType::kDDR2HBM, span->tensor_id_, span->current_index_ % kernel_num_);
    }
  }

  for (auto span : span_level2_) {
    MS_EXCEPTION_IF_NULL(span);
    AddTensorAction(SwapActionType::kHBM2DISK, span->tensor_id_, span->last_index_ + 1);
    if (!span->output_span_) {
      AddTensorAction(SwapActionType::kDISK2HBM, span->tensor_id_, span->current_index_ % kernel_num_);
    }
  }
}

std::shared_ptr<SwapStrategy> SwapStrategyBuilder::BuildStrategy(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(analyzer_);
  auto &exec_order = graph->execution_order();
  if (exec_order.size() != kernel_num_) {
    MS_LOG(EXCEPTION) << "Kernel num error !!!";
  }

  auto strategy = std::make_shared<SwapStrategy>();
  strategy->kernel_num_ = kernel_num_;
  strategy->virtual_node_num_ = kSwapVirtualNodeNum;
  for (size_t i = 0; i < kernel_num_; ++i) {
    strategy->nodes_[i + 1] = exec_order[i];
    (void)strategy->links_.emplace_back(std::make_shared<SwapLink>(i, i + 1));
  }

  size_t logic_kernel_num = kernel_actions_.size();
  size_t action_id = logic_kernel_num;
  for (size_t i = 0; i < logic_kernel_num; ++i) {
    auto &actions = kernel_actions_[i];
    if (actions.empty()) {
      continue;
    }
    auto swap_action = std::make_shared<SwapAction>();
    swap_action->actions_ = actions;
    strategy->actions_[action_id] = swap_action;
    (void)strategy->links_.emplace_back(std::make_shared<SwapLink>(i, action_id));
    (void)strategy->links_.emplace_back(std::make_shared<SwapLink>(action_id, i + 1));
    ++action_id;
  }

  strategy->kernel_infos_ = analyzer_->GetMemUsageKernelInfos();
  strategy->tensor_infos_ = analyzer_->GetMemUsageTensorInfos();
  return strategy;
}

std::shared_ptr<SwapStrategy> SwapStrategyBuilder::Build(const KernelGraphPtr &graph,
                                                         const std::shared_ptr<SwapContext> &context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(context);
  ResetState(graph, context);

  AnalyzeGraph(graph);

  BuildSpans();

  HandleFusedTensor();

  ClassifySpanLevel();

  SpanToTensorAction();

  return BuildStrategy(graph);
}
} // namespace device
} // namespace mindspore
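Aside (not part of the diff): `EnoughSpaceForSpan` above implements a reserve-or-rollback check. It charges the span's tensor size to every kernel step the span covers and undoes all charges as soon as one step would exceed the budget; the `% kernel_num_` wrap lets spans that run past the graph end (graph inputs needed again in the next iteration) fold back onto the beginning. A self-contained sketch of the same idea, without the wrap-around and with hypothetical names:

```cpp
#include <cstddef>
#include <vector>

// Try to reserve `size` bytes at every step in [first, last]; if any step would
// exceed `budget`, undo all charges made so far and report failure.
bool ReserveSpan(std::vector<size_t> *mem_used, size_t first, size_t last, size_t size, size_t budget) {
  for (size_t i = first; i <= last; ++i) {
    (*mem_used)[i] += size;
    if ((*mem_used)[i] > budget) {
      for (size_t j = first; j <= i; ++j) {
        (*mem_used)[j] -= size;  // roll back every charge made so far
      }
      return false;
    }
  }
  return true;
}
```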
@@ -0,0 +1,81 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_STRATEGY_BUILDER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_STRATEGY_BUILDER_H_
#include <memory>
#include <vector>
#include <queue>
#include "runtime/device/gsm/swap_strategy.h"
#include "runtime/device/gsm/mem_usage_analyzer.h"
namespace mindspore {
namespace device {
class SwapStrategyBuilder {
 public:
  SwapStrategyBuilder() = default;
  ~SwapStrategyBuilder() = default;
  std::shared_ptr<SwapStrategy> Build(const KernelGraphPtr &graph, const std::shared_ptr<SwapContext> &context);

 private:
  struct Span {
    size_t tensor_id_{0};
    size_t tensor_size_{0};
    size_t last_index_{0};
    size_t current_index_{0};
    size_t weight_{0};
    bool output_span_{false};
  };

  struct SpanCmp {
    bool operator()(const std::shared_ptr<Span> &left, const std::shared_ptr<Span> &right) {
      if (left == nullptr || right == nullptr) {
        return true;
      }
      return left->weight_ > right->weight_;
    }
  };

  std::shared_ptr<MemUsageAnalyzer> analyzer_{nullptr};
  std::shared_ptr<SwapContext> context_{nullptr};
  size_t kernel_num_{0};
  std::priority_queue<std::shared_ptr<Span>, std::vector<std::shared_ptr<Span>>, SpanCmp> span_queue_;
  std::vector<std::shared_ptr<Span>> offload_param_spans_;
  std::vector<std::shared_ptr<Span>> offload_checkpoint_spans_;
  std::vector<std::shared_ptr<Span>> span_level1_;
  std::vector<std::shared_ptr<Span>> span_level2_;
  std::vector<size_t> mem_used_level0_;
  std::vector<size_t> mem_used_level1_;
  size_t total_mem_level0_{0};
  size_t total_mem_level1_{0};
  std::vector<std::vector<std::shared_ptr<TensorAction>>> kernel_actions_;

  void ResetState(const KernelGraphPtr &graph, const std::shared_ptr<SwapContext> &context);
  void AnalyzeGraph(const KernelGraphPtr &graph);
  void BuildSpans();
  void ClassifyOffloadSpanLevel(const std::vector<std::shared_ptr<Span>> &spans, bool offload_to_ddr);
  void ClassifySpanLevel();
  void AddFusedTensorSpan(const std::shared_ptr<MemUsageTensorInfo> &info, size_t start_index,
                          size_t current_kernel_id);
  void HandleFusedTensor();
  void SpanToTensorAction();
  void RecordSpan(const std::shared_ptr<MemUsageTensorInfo> &info, size_t last_index, size_t current_index,
                  bool output_span = false);
  bool EnoughSpaceForSpan(const std::shared_ptr<Span> &span, std::vector<size_t> *mem_used, size_t total_mem_size);
  void AddTensorAction(SwapActionType action_type, size_t tensor_id, size_t kernel_id);
  std::shared_ptr<SwapStrategy> BuildStrategy(const KernelGraphPtr &graph);
};
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_STRATEGY_BUILDER_H_
@@ -97,6 +97,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
        ${CCSRC_DIR}/runtime/device/memory_manager.cc
        ${CCSRC_DIR}/runtime/device/auto_mem_offload.cc
        ${CCSRC_DIR}/runtime/device/gsm/mem_usage_analyzer.cc
        ${CCSRC_DIR}/runtime/device/gsm/swap_strategy_builder.cc
        ${CCSRC_DIR}/runtime/device/common_somas_allocator.cc
        ${CCSRC_DIR}/runtime/pynative/op_runtime_info.cc
        ${CCSRC_DIR}/runtime/hardware/device_type.cc
@@ -16,6 +16,7 @@ from mindspore.ops import operations as P

add = P.Add()
mul = P.Mul()
all_reduce = P.AllReduce().add_prim_attr("fusion", 1)


def add_net(x1, x2, x3, x4, x5):

@@ -27,10 +28,6 @@ def add_net(x1, x2, x3, x4, x5):
    return ret


all_reduce = P.AllReduce().add_prim_attr("fusion", 1)
mul = P.Mul()


def all_reduce_net(x1, x2, x3):
    product = mul(x1, x2)
    sum1 = add(x2, x3)

@@ -38,3 +35,15 @@ def all_reduce_net(x1, x2, x3):
    reduce2 = all_reduce(sum1)
    res = add(reduce1, reduce2)
    return res


def add_with_all_reduce_net(x1, x2, x3, x4, x5):
    a1 = all_reduce(add(x1, x2))
    a2 = all_reduce(add(x2, x3))
    a3 = all_reduce(add(x3, x4))
    sum1 = add(a1, x3)
    sum2 = add(a2, x4)
    sum3 = add(a3, x5)
    sum4 = add(sum1, sum2)
    ret = mul(sum3, sum4)
    return ret
@@ -0,0 +1,138 @@
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <vector>
#include <map>
#include "common/common_test.h"
#include "common/backend_common_test.h"
#include "common/py_func_graph_fetcher.h"
#include "runtime/device/gsm/swap_strategy_builder.h"

namespace mindspore::device {
class TestSwapStrategyBuilder : public BackendCommon {
 public:
  TestSwapStrategyBuilder() : get_py_func_("gtest_input.runtime.device.gsm.mem_usage_analyzer_test", true) {}

  void SetUp() override {
    auto net = get_py_func_("add_net");
    EXPECT_NE(net, nullptr);
    std::vector<int64_t> shp_x{1, 2, 2, 2};
    auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
    AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
    auto func_graph = GetFuncGraph(net, args_spec_list);
    kernel_graph_add_net_ = Compile(func_graph);

    net = get_py_func_("add_with_all_reduce_net");
    EXPECT_NE(net, nullptr);
    func_graph = GetFuncGraph(net, args_spec_list);
    kernel_graph_add_with_all_reduce_net_ = Compile(func_graph);
  }

  UT::PyFuncGraphFetcher get_py_func_;
  std::shared_ptr<session::KernelGraph> kernel_graph_add_net_;
  std::shared_ptr<session::KernelGraph> kernel_graph_add_with_all_reduce_net_;
};

/// Feature: SwapStrategyBuilder
/// Description: Test SwapStrategyBuilder with variable mem size
/// Expectation: Pass all test cases
TEST_F(TestSwapStrategyBuilder, test_swap_strategy_with_variable_mem_size) {
  auto builder = std::make_shared<SwapStrategyBuilder>();
  auto context = std::make_shared<SwapContext>();
  auto kernel_graph = kernel_graph_add_net_;
  EXPECT_NE(kernel_graph, nullptr);
  std::vector<std::vector<size_t>> inputs = {{10000, 10000}, {10000, 136}, {100, 136}};
  std::vector<std::vector<size_t>> expects = {{5, 2, 5, 0, 5, 0}, {5, 2, 5, 5, 15, 10}, {5, 2, 5, 5, 15, 10}};
  for (size_t i = 0; i < 3; ++i) {
    context->ddr_mem_size_ = inputs[i][0];
    context->hbm_mem_size_ = inputs[i][1];
    auto strategy = builder->Build(kernel_graph, context);
    EXPECT_NE(strategy, nullptr);
    EXPECT_EQ(strategy->kernel_num_, expects[i][0]);
    EXPECT_EQ(strategy->virtual_node_num_, expects[i][1]);
    EXPECT_EQ(strategy->nodes_.size(), expects[i][2]);
    EXPECT_EQ(strategy->actions_.size(), expects[i][3]);
    EXPECT_EQ(strategy->links_.size(), expects[i][4]);
    std::vector<std::shared_ptr<TensorAction>> all_actions;
    for (auto const &item : strategy->actions_) {
      for (auto const &action : item.second->actions_) {
        (void)all_actions.emplace_back(action);
      }
    }
    EXPECT_EQ(all_actions.size(), expects[i][5]);
  }
}

/// Feature: SwapStrategyBuilder
/// Description: Test SwapStrategyBuilder with offload param
/// Expectation: Pass all test cases
TEST_F(TestSwapStrategyBuilder, test_swap_strategy_with_offload_param) {
  auto builder = std::make_shared<SwapStrategyBuilder>();
  auto context = std::make_shared<SwapContext>();
  auto kernel_graph = kernel_graph_add_net_;
  EXPECT_NE(kernel_graph, nullptr);

  context->ddr_mem_size_ = 10000;
  context->hbm_mem_size_ = 10000;
  std::vector<std::vector<size_t>> inputs = {{true, false}, {false, true}, {true, true}};
  std::vector<std::vector<size_t>> expects = {{5, 2, 5, 5, 15, 10}, {5, 2, 5, 5, 15, 10}, {5, 2, 5, 5, 15, 10}};
  for (size_t i = 0; i < 3; ++i) {
    context->offload_param_to_ddr_ = inputs[i][0];
    context->offload_param_to_disk_ = inputs[i][1];
    auto strategy = builder->Build(kernel_graph, context);
    EXPECT_NE(strategy, nullptr);
    EXPECT_EQ(strategy->kernel_num_, expects[i][0]);
    EXPECT_EQ(strategy->virtual_node_num_, expects[i][1]);
    EXPECT_EQ(strategy->nodes_.size(), expects[i][2]);
    EXPECT_EQ(strategy->actions_.size(), expects[i][3]);
    EXPECT_EQ(strategy->links_.size(), expects[i][4]);
    std::vector<std::shared_ptr<TensorAction>> all_actions;
    for (auto const &item : strategy->actions_) {
      for (auto const &action : item.second->actions_) {
        (void)all_actions.emplace_back(action);
      }
    }
    EXPECT_EQ(all_actions.size(), expects[i][5]);
  }
}

/// Feature: SwapStrategyBuilder
/// Description: Test SwapStrategyBuilder with all reduce nodes
/// Expectation: Pass all test cases
TEST_F(TestSwapStrategyBuilder, test_swap_strategy_with_all_reduce_nodes) {
  auto builder = std::make_shared<SwapStrategyBuilder>();
  auto context = std::make_shared<SwapContext>();
  auto kernel_graph = kernel_graph_add_with_all_reduce_net_;
  EXPECT_NE(kernel_graph, nullptr);

  context->ddr_mem_size_ = 100;
  context->hbm_mem_size_ = 250;
  auto strategy = builder->Build(kernel_graph, context);
  EXPECT_NE(strategy, nullptr);
  EXPECT_EQ(strategy->kernel_num_, 9);
  EXPECT_EQ(strategy->virtual_node_num_, 2);
  EXPECT_EQ(strategy->nodes_.size(), 9);
  EXPECT_EQ(strategy->actions_.size(), 8);
  EXPECT_EQ(strategy->links_.size(), 25);
  std::vector<std::shared_ptr<TensorAction>> all_actions;
  for (auto const &item : strategy->actions_) {
    for (auto const &action : item.second->actions_) {
      (void)all_actions.emplace_back(action);
    }
  }
  EXPECT_EQ(all_actions.size(), 12);
}
} // namespace mindspore::device