support parallel fusion

parent b2cd022c5f
commit d078cbfa99

akg
@@ -1 +1 @@
-Subproject commit 20ecddee01cd07d0945240672597d7a36499e537
+Subproject commit c63b2e6f7e7704f18b217e42c8c5c0b95e04b9fb

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,3 +15,4 @@
 """init"""
 from .splitter import split_with_json
 from .expander import get_op_expander
+from .parallel_estimate import estimate_calulation_amount, estimate_ops

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,3 +16,4 @@
 
 from .graph_split import split
 from .model_builder import GraphBuilder, load_composite
+from .graph_parallel import parallel_estimate

@@ -0,0 +1,153 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""Cost model for parallel fusion"""
+from .model import PrimLib
+
+
+class ParalGain:
+    def __init__(self, fusion_type, bottleneck, gain, block_assign):
+        self.fusion_type = fusion_type
+        self.bottleneck = bottleneck
+        self.gain = gain
+        self.block_assign = block_assign
+
+
+class ScheduleAnalyzer:
+    """schedule analyzer"""
+    WRAP_SIZE = 32
+    MAX_SM = 80  # Volta
+    MAX_NUM_THREADS = 1024
+    MAX_BLOCK = 256
+
+    def __init__(self, graph):
+        self.graph = graph
+        self.block_num = 0
+        self.block_weight = 0
+        _, outputs = graph.deduce_parameters()
+        self.ops = graph.ops
+        self.dom_op = [out.op for out in outputs]
+
+    def prod(self, shape):
+        res = shape[0]
+        for i in range(1, len(shape)):
+            res = res * shape[i]
+        return res
+
+    def _cal_weight(self, ops):
+        weight = 0
+        for op in ops:
+            weight += self.prod(op.output.shape) * \
+                PrimLib.dtype_bytes(op.output.dtype)
+        return weight
+
+    def injective_analyze(self):
+        """analyze injective case"""
+        const_size = max([self.prod(op.output.shape) for op in self.dom_op])
+        const_size = (const_size + self.MAX_NUM_THREADS -
+                      1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS
+
+        total_weight = self._cal_weight(self.ops)
+        total_block = (const_size + self.MAX_NUM_THREADS -
+                       1) // self.MAX_NUM_THREADS
+        need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS
+        if need_block_split:
+            self.block_num = self.MAX_BLOCK
+            waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
+            self.block_weight = total_weight // total_block * waves
+        else:
+            self.block_num = total_block
+            self.block_weight = total_weight // self.block_num
+
+    def reduce_analyze(self):
+        """analyze reduce case"""
+        thread_x, thread_y = 32, 32
+        reduce_op = None
+        for op in self.ops:
+            if PrimLib.iter_type(op) == PrimLib.REDUCE:
+                if reduce_op:
+                    raise RuntimeError(
+                        "Multiple reduce ops in one graph are not supported now.")
+                reduce_op = op
+        if not reduce_op:
+            raise RuntimeError("Wrong analyze for reduce!")
+        shape = reduce_op.inputs[0].shape
+        reduce_axis = reduce_op.attrs['reduce_axis']
+        total_space = self.prod(shape)
+        red_space = shape[reduce_axis[0]]
+        for i in range(1, len(reduce_axis)):
+            red_space *= shape[reduce_axis[i]]
+        dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)
+
+        weight = self._cal_weight(self.ops)  # reduce + injective
+        block_x = (total_space // red_space + thread_y - 1) // thread_y
+        block_w = (weight + block_x - 1) // block_x
+        waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
+        self.block_num = min(self.MAX_BLOCK, block_x)
+        all_reduce = 10  # 1 reduce init + 3 sync + 5 bin + 1 write
+        self.block_weight = (block_w + all_reduce *
+                             dtype_size * thread_x * thread_y) * waves
+
+    def default_analyze(self):
+        """analyze default case"""
+        def _cal_default_space(op):
+            space = self.prod(op.output.shape)
+            for t in op.inputs:
+                size = self.prod(t.shape)
+                if size > space:
+                    space = size
+            return space
+        space = max([_cal_default_space(op) for op in self.dom_op])
+
+        # at least 4 warps for each SM
+        block = (space + (self.WRAP_SIZE * 4) - 1) // (self.WRAP_SIZE * 4)
+        self.block_num = min(self.MAX_BLOCK, block)
+        self.block_weight = self._cal_weight(self.ops) // self.block_num
+
+    def analyze(self):
+        """analyze ops"""
+        def _ops_type(ops, dom_op):
+            have_reduce = any(
+                [PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops])
+            if have_reduce:
+                return True
+            return PrimLib.iter_type(dom_op[0])
+
+        dom_type = _ops_type(self.ops, self.dom_op)
+        if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
+            self.injective_analyze()
+        elif dom_type == PrimLib.REDUCE:
+            self.reduce_analyze()
+        else:
+            self.default_analyze()
+
+
+def block_parallel_estimate(graphs):
+    """estimate block parallel gain"""
+    sum_block, max_weight, sum_weight, blocks = 0, 0, 0, []
+    for g in graphs:
+        s = ScheduleAnalyzer(g)
+        s.analyze()
+        sum_block += s.block_num
+        if s.block_weight > max_weight:
+            max_weight = s.block_weight
+        sum_weight += s.block_weight
+        blocks.append(s.block_num)
+    if sum_block > ScheduleAnalyzer.MAX_SM * 32:
+        return ParalGain("none", sum_weight, 0, [])
+    return ParalGain("block_fusion", max_weight, sum_weight - max_weight, blocks)
+
+
+def parallel_estimate(graphs):
+    return block_parallel_estimate(graphs)

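A minimal usage sketch of the cost model above (an editor's illustration, not part of the patch; `should_block_fuse` is a hypothetical helper, and the graphs are assumed to come from the existing model_builder):

    # Sketch: drive the cost model over candidate kernel graphs and fuse only
    # when it reports a positive block-fusion gain, mirroring estimate_ops below.
    from mindspore._extends.graph_kernel.model import parallel_estimate

    def should_block_fuse(graphs):
        est = parallel_estimate(graphs)  # returns a ParalGain
        return est.fusion_type == "block_fusion" and est.gain > 0
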
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""estimate parallel case"""
+import json
+import json.decoder as jd
+import traceback
+from mindspore import log as logger
+from . import model
+
+
+def estimate_ops(json_str: str):
+    """Call costmodel to estimate ops."""
+    try:
+        json_obj = json.loads(json_str)
+        graph_descs = json_obj["graph_desc"]
+        graphs = []
+        for gd in graph_descs:
+            graphs.append(model.load_composite(gd).graph)
+        estimation = model.parallel_estimate(graphs)
+        if estimation.fusion_type == "block_fusion" and estimation.gain > 0:
+            res = (estimation.block_assign, estimation.gain)
+        else:
+            res = ([0 for g in graphs], 0)
+        return res
+    except jd.JSONDecodeError:
+        logger.error(traceback.format_exc())
+        return None
+
+
+def estimate_calulation_amount(json_str: str):
+    """Call costmodel to estimate calculation amount of op."""
+    try:
+        graph_desc = json.loads(json_str)
+        comp = model.load_composite(graph_desc)
+        estimation = model.parallel_estimate([comp.graph])
+        return estimation.bottleneck
+    except jd.JSONDecodeError:
+        logger.error(traceback.format_exc())
+        return None

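A hedged sketch of the contract above: `estimate_ops` expects a JSON string whose `graph_desc` field lists composite-graph descriptions, each consumable by `model.load_composite`; `desc_a` and `desc_b` below are placeholders for the real kernel JSONs produced by the C++ side (AnfToJsonDesc):

    # Illustrative only: show the envelope estimate_ops parses, not real payloads.
    import json

    request = json.dumps({"graph_desc": [desc_a, desc_b]})
    result = estimate_ops(request)  # (block_assign, gain), ([0, 0], 0), or None on bad JSON
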
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -120,7 +120,7 @@ class OpInfoExtractor {
       }
     }
     if (op_attr->type().empty()) {
-      MS_LOG(DEBUG) << "Unknow type, ignore attr " << name;
+      MS_LOG(DEBUG) << "Unknown type, ignore attr " << name;
       continue;
     }
     op_info->add_attrs_ptr(op_attr);
@@ -174,7 +174,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
   // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
   auto inputs_ptr = op_info->inputs_ptr();
   if (inputs_ptr.empty()) {
-    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist info has no input info";
+    MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] info has no input info";
     return false;
   }
 
@@ -184,7 +184,7 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
   for (size_t i = 0; i < inputs_ptr.size(); i++) {
     auto input_ptr = inputs_ptr[i];
     if (input_ptr == nullptr) {
-      MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] regist input[" << i << "] is nullptr";
+      MS_LOG(ERROR) << "Kernel [" << anf_node->fullname_with_scope() << "] input[" << i << "] is nullptr";
       return false;
     }
 
@@ -204,7 +204,8 @@ bool AkgKernelJsonGenerator::CreateInputDescJson(const AnfNodePtr &anf_node, con
     input_desc_json[kJsonKeyName] = input_ptr->name();
     input_desc_json[kJsonKeyTensorName] = "input_" + std::to_string(GetInputTensorIdxInc(anf_node, real_input_index));
     auto input_shape = this->GetInputShape(anf_node, real_input_index);
-    if (AnfAlgo::IsNodeInGraphKernel(anf_node) && GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
+    if (dump_option_.extract_opinfo_from_anfnode &&
+        GetInputTensorValue(anf_node, real_input_index, &input_desc_json)) {
       MS_LOG(DEBUG) << "Take input[" << real_input_index << "] of [" << anf_node->DebugString(2)
                     << "] as const tensor, shape: [" << Vector2Str(input_shape)
                     << "], value: " << input_desc_json[kJsonKeyValue];
@@ -555,6 +556,30 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
   return true;
 }
 
+void AkgKernelJsonGenerator::SetParallelValueToJson(const std::string &processor,
+                                                    const std::map<size_t, size_t> &dim_infos,
+                                                    nlohmann::json *sub_fusion_json) {
+  if (processor == kProcessorCuda) {
+    std::vector<size_t> cnums;
+    std::transform(dim_infos.cbegin(), dim_infos.cend(), std::back_insert_iterator(cnums),
+                   [](const std::pair<size_t, size_t> &dim) { return dim.second; });
+    (*sub_fusion_json)[kJsonKeyCoreNum] = cnums;
+  } else {
+    MS_LOG(EXCEPTION) << "Parallel fusion not support " << processor << " now.";
+  }
+}
+
+void AkgKernelJsonGenerator::AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json) {
+  nlohmann::json parallel_fusion_json;
+  parallel_fusion_json[kJsonKeyFusionType] = "block_fusion";
+  std::vector<std::vector<std::string>> sgraphs;
+  std::transform(sub_graphs_.cbegin(), sub_graphs_.cend(), std::back_insert_iterator(sgraphs),
+                 [](const std::pair<int, std::vector<std::string>> &sg) { return sg.second; });
+  parallel_fusion_json[kJsonKeySubGraph] = sgraphs;
+  SetParallelValueToJson(processor, dim_infos_, &parallel_fusion_json);
+  (*kernel_json)[kJsonKeyParallelFusion] = parallel_fusion_json;
+}
+
 bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf_nodes,
                                               const std::vector<AnfNodePtr> &input_list,
                                               const std::vector<AnfNodePtr> &output_list, nlohmann::json *kernel_json) {
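For orientation, an editor's sketch of the `parallel_fusion` entry the two helpers above attach to the fused-kernel JSON, written as a Python literal; the tensor names and counts are hypothetical, and only the key names and nesting follow this patch:

    # Hypothetical payload assembled by AddParalleFusionJsonInfo / SetParallelValueToJson
    parallel_fusion = {
        "fusion_type": "block_fusion",                 # kJsonKeyFusionType
        "sub_graph": [["output_0_0", "output_0_1"],    # kJsonKeySubGraph: tensor names
                      ["output_1_0"]],                 # grouped per sub-graph index
        "core_num": [2, 1],                            # kJsonKeyCoreNum: dim info per sub-graph (CUDA only)
    }
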
@@ -581,6 +606,13 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyOutputDesc] =
     CreateOutputsJson(anf_nodes, input_list, output_list, inputs_json, node_json_map);
 
+  auto processor = GetProcessorStr(anf_nodes[0]);
+
+  // Add parallel fusion information.
+  if (!sub_graphs_.empty()) {
+    AddParalleFusionJsonInfo(processor, kernel_json);
+  }
+
   size_t hash_id = std::hash<std::string>()(kernel_json->dump());
   kernel_name_ = "Fused_";
   auto fg = anf_nodes[0]->func_graph();
@@ -601,7 +633,7 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
   (*kernel_json)[kJsonKeyId] = GetOpCntInc();
   (*kernel_json)[kJsonKeyOp] = kernel_name_;
   (*kernel_json)[kJsonKeyPlatform] = "AKG";
-  (*kernel_json)[kJsonKeyProcess] = GetProcessorStr(anf_nodes[0]);
+  (*kernel_json)[kJsonKeyProcess] = processor;
   (*kernel_json)[kJsonKeyComposite] = true;
   (*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString() + "." + fg->debug_info()->get_id();
 
@@ -724,6 +756,17 @@ nlohmann::json AkgKernelJsonGenerator::CreateOutputsJson(const std::vector<AnfNo
       output_shape.push_back(1);
     }
     output_desc_json[kJsonKeyShape] = output_shape;
+    if (auto tcnode = tmp_output.first->cast<CNodePtr>();
+        tcnode && AnfAlgo::HasNodeAttr(kAttrParallelDimInfo, tcnode)) {
+      auto info = AnfAlgo::GetNodeAttr<std::vector<size_t>>(tcnode, kAttrParallelDimInfo);
+      if (info.size() != 2) {
+        MS_LOG(EXCEPTION) << "Parallel dim info is invalid!";
+      }
+      sub_graphs_[info[0]].push_back(output_desc_json[kJsonKeyTensorName]);
+      if (dim_infos_.find(info[0]) == dim_infos_.end()) {
+        dim_infos_[info[0]] = info[1];
+      }
+    }
   }
   outputs_json.emplace_back(output_desc_json);
 }

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,6 +49,11 @@ constexpr auto kJsonKeyPtrAddress = "ptr_address";
 constexpr auto kJsonKeyCompositeGraph = "composite_graph";
 constexpr auto kJsonKeyPlatform = "platform";
 constexpr auto kJsonKeyOpFullName = "op_full_name";
 constexpr auto kJsonKeyFusion = "fusion";
+constexpr auto kJsonKeyParallelFusion = "parallel_fusion";
+constexpr auto kJsonKeyFusionType = "fusion_type";
+constexpr auto kJsonKeySubGraph = "sub_graph";
+constexpr auto kJsonKeyCoreNum = "core_num";
+
 constexpr auto kAttrInputNames = "input_names";
 
@@ -81,6 +86,8 @@ class AkgKernelJsonGenerator {
     input_tensor_idx_.clear();
     address_node_map_.clear();
     output_tensor_idx_ = 0;
+    sub_graphs_.clear();
+    dim_infos_.clear();
   }
   void set_dump_option(DumpOption dump_option) { dump_option_ = dump_option; }
   std::map<std::string, AnfNodePtr> address_node_map() { return address_node_map_; }
@@ -115,6 +122,9 @@ class AkgKernelJsonGenerator {
   std::string GetOutputFormat(const AnfNodePtr &anf_node, size_t index);
   void SaveNodeAddress(const AnfNodePtr &anf_node, nlohmann::json *node_json);
   OpInfoPtr ExtractOpInfo(const AnfNodePtr &anf_node);
+  void SetParallelValueToJson(const std::string &processor, const std::map<size_t, size_t> &dim_infos,
+                              nlohmann::json *sub_fusion_json);
+  void AddParalleFusionJsonInfo(const std::string &processor, nlohmann::json *kernel_json);
 
   DumpOption dump_option_;
   static int op_cnt_;
@@ -127,6 +137,8 @@ class AkgKernelJsonGenerator {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::map<std::string, AnfNodePtr> address_node_map_;
+  std::map<size_t, std::vector<std::string>> sub_graphs_;
+  std::map<size_t, size_t> dim_infos_;
 };
 }  // namespace kernel
 }  // namespace mindspore

@@ -133,8 +133,10 @@ bool AtomicCleanInsertter::CanActivateAtomicAdd(const AnfNodePtr &anf_node) {
     if (reduce_cnt != 1) {
       return false;
     }
+    real_output_num_ = inputs.size() - 1;
   } else if (IsPrimitiveCNode(real_return_node, prim::kPrimReduceSum)) {
     atomic_add_node_ = real_return_node->cast<CNodePtr>();
+    real_output_num_ = 1;
   } else {
     return false;
   }
@@ -200,7 +202,6 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
   auto retrun_node = sub_graph->get_return()->input(kFirstDataInputIndex);
   if (IsPrimitiveCNode(retrun_node, prim::kPrimMakeTuple)) {
     const auto &outs = retrun_node->cast<CNodePtr>()->inputs();
-    real_output_num_ = outs.size() - 1;
     for (size_t i = 1; i < outs.size(); ++i) {
       if (i != reduce_real_output_index_ + 1) {
         out_node = outs[i];
@@ -209,7 +210,6 @@ void AtomicCleanInsertter::CreateInplaceAssignNodeAndCorrectReturn(const FuncGra
       }
     }
   } else {
-    real_output_num_ = 1;
     out_node = atomic_add_node_;  // Use result data itself, and set attr "fake_out" true.
     fake_out = true;
   }
@@ -456,7 +456,7 @@ std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUs
     }
   }
   for (auto &pair : getitem_user_nodes) {
-    // dirctory to find real user.
+    // Directly find the real user.
    auto real_users = mng->node_users()[pair.first];
    reduce_user_nodes.insert(reduce_user_nodes.end(), real_users.begin(), real_users.end());
  }

@@ -0,0 +1,155 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/depend_formater.h"
+#include <tuple>
+#include <utility>
+#include <vector>
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool RemoveRedundantDepend(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
+  const auto &users = mng->node_users()[node];
+  std::vector<std::pair<AnfNodePtr, int>> sons;
+  for (const auto &[user, index] : users) {
+    if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
+      sons.emplace_back(user, index);
+      continue;
+    }
+    auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
+    sons.emplace_back(fake_first_grad_son, grad_index);
+  }
+
+  AnfNodePtrList latter_to_delete;
+  for (const auto &[son, index] : sons) {
+    if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
+      continue;
+    }
+
+    latter_to_delete.push_back(son);
+  }
+
+  if (latter_to_delete.empty()) {
+    return false;
+  }
+
+  std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
+  if (latter_to_delete.size() == sons.size()) {
+    // Keep one Depend node relation and delete the others.
+    ++delete_begin;
+  }
+  for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
+    auto depend_anfnode = *delete_begin;
+    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(depend_cnode);
+    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
+    mng->Replace(depend_anfnode, depend_prior_node);
+  }
+  return true;
+}
+
+AnfNodePtr FindPatronNode(const FuncGraphPtr &main_graph, const FuncGraphManagerPtr &mng) {
+  AnfNodePtr patron_node;
+
+  auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(return_cnode);
+  auto output_node = return_cnode->input(kFirstDataInputIndex);
+  if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
+    auto output_cnode = output_node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(output_cnode);
+    patron_node = output_cnode->input(kFirstDataInputIndex);
+  } else {
+    patron_node = output_node;
+  }
+
+  return patron_node;
+}
+
+void AddDepends(const AnfNodePtr &stable_node, const AnfNodePtrList &free_nodes, const FuncGraphPtr &main_graph,
+                const FuncGraphManagerPtr &mng) {
+  AnfNodePtr modified_node = stable_node;
+  for (const auto &free_node : free_nodes) {
+    AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), modified_node, free_node};
+    auto depend_cnode = main_graph->NewCNode(d_inputs);
+    depend_cnode->set_abstract(modified_node->abstract());
+    main_graph->AddNode(depend_cnode);
+    modified_node = depend_cnode;
+  }
+
+  if (!free_nodes.empty()) {
+    mng->Replace(stable_node, modified_node);
+  }
+}
+}  // namespace
+
+bool DependFormater::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+
+  // 1. Try to remove redundant depend.
+  bool changed = false;
+  auto nodes = TopoSort(func_graph->get_return());
+  std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) {
+    if (RemoveRedundantDepend(node, mng)) {
+      changed = true;
+    }
+  });
+
+  // Re-toposort the changed graph.
+  if (changed) {
+    nodes = TopoSort(func_graph->get_return());
+  }
+
+  // 2. Move depend to the tail of the graph.
+  AnfNodePtrList old_depends;
+  AnfNodePtrList free_nodes;
+
+  // Find depend and its free nodes.
+  for (const auto &node : nodes) {
+    if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
+      continue;
+    }
+
+    old_depends.push_back(node);
+    free_nodes.push_back(node->cast<CNodePtr>()->input(kDependAttachNodeIndex));
+  }
+
+  if (old_depends.empty()) {
+    return changed;
+  }
+
+  // Delete old depend.
+  for (const auto &depend_anfnode : old_depends) {
+    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(depend_cnode);
+    auto depend_prior_node = depend_cnode->input(kControlDependPriorIndex);
+    mng->Replace(depend_anfnode, depend_prior_node);
+  }
+
+  // Add new depend node in tail.
+  AnfNodePtr patron_node = FindPatronNode(func_graph, mng);
+  AddDepends(patron_node, free_nodes, func_graph, mng);
+  return true;
+}
+}  // namespace opt
+}  // namespace mindspore

@@ -0,0 +1,37 @@
+
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
+
+#include <map>
+#include <memory>
+
+#include "backend/optimizer/common/pass.h"
+#include "ir/func_graph.h"
+
+namespace mindspore {
+namespace opt {
+class DependFormater : public Pass {
+ public:
+  DependFormater() : Pass("depend_formater") {}
+  ~DependFormater() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+};
+using DependFormaterPtr = std::shared_ptr<DependFormater>;
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_

@@ -274,7 +274,7 @@ bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *i
   MS_EXCEPTION_IF_NULL(inputs_ptr);
   auto nodes = TopoSort(fg->get_return());
 
-  std::map<ValuePtr, AnfNodePtrList> vmap;
+  OrderedMap<ValuePtr, AnfNodePtrList> vmap;
   for (const auto &node : nodes) {
     if (!node->isa<CNode>()) {
       continue;
@@ -590,7 +590,7 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
     op_nodes = nodes;
   } else {
     // When there are basic and composite ops, the composite ops should be inline to the basic ones' graph,
-    // so a new graph generation should be done (beacuse they may in the main graph!).
+    // so a new graph generation should be done (because they may be in the main graph!).
     // If address_node_map is wanted, we should map the new node in new graph to the old nodes. But... not support now.
     MS_LOG(EXCEPTION) << "No support mixed with basic and composite ops now!";
   }
@@ -1016,5 +1016,16 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
   func_graph->AddNode(cnode);
   return cnode;
 }
+
+void MakeCNodeSafeForAttr(const AnfNodePtr &node) {
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    return;
+  }
+  AnfNodePtrList new_inputs = {NewValueNode(AnfAlgo::GetCNodePrimitive(cnode)->Clone())};
+  auto inputs = cnode->inputs();
+  new_inputs.insert(new_inputs.end(), inputs.begin() + 1, inputs.end());
+  cnode->set_inputs(new_inputs);
+}
 }  // namespace opt
 }  // namespace mindspore

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -42,6 +42,8 @@ using kernel::DumpOption;
 constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
 constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
 constexpr auto kGraphKernelModule = "mindspore._extends.graph_kernel";
+constexpr auto kGraphKernelEstimateOps = "estimate_ops";
+constexpr auto kGraphKernelGetNodeCalAmount = "estimate_calulation_amount";
 constexpr auto kGraphKernelSplitFunc = "split_with_json";
 constexpr auto kGetGraphKernelOpExpander = "get_op_expander";
 constexpr auto kJsonKeyMultiGraph = "multi_graph";
@@ -88,6 +90,7 @@ ShapeVector GetShape(const AnfNodePtr &node);
 std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
 
 CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
+void MakeCNodeSafeForAttr(const AnfNodePtr &node);
 
 template <typename T>
 ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {

@@ -0,0 +1,89 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
+
+#include <algorithm>
+
+#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "pipeline/jit/parse/python_adapter.h"
+
+namespace mindspore {
+namespace opt {
+std::string CommonDimInfo::ToString() {
+  std::ostringstream buffer;
+  buffer << "Dim(" << dim_info_ << ")";
+  return buffer.str();
+}
+
+int ParallelCostModel::GetNodeCalAmount(const AnfNodePtr &node) {
+  nlohmann::json json_desc;
+  AnfNodePtrList nodes = {node};
+  DumpOption dump_option;
+  if (!AnfToJsonDesc(nodes, dump_option, &json_desc)) {
+    MS_LOG(EXCEPTION) << "Collect json desc failed.";
+  }
+
+  auto json_desc_str = json_desc.dump();
+  auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelGetNodeCalAmount, json_desc_str);
+  if (py::isinstance<py::none>(ret)) {
+    MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelGetNodeCalAmount << "] return invalid result. input json:\n"
+                      << json_desc_str;
+  }
+  return py::cast<int>(ret);
+}
+
+std::tuple<std::vector<DimInfoPtr>, int> ParallelCostModel::CalFuseInfo(const AnfNodePtrList &nodes) {
+  nlohmann::json json_desc;
+  std::vector<AnfNodePtrList> graphs;
+  std::transform(nodes.begin(), nodes.end(), std::back_inserter(graphs),
+                 [](const AnfNodePtr &node) -> AnfNodePtrList { return {node}; });
+  DumpOption dump_option;
+  if (!AnfToJsonDesc(graphs, dump_option, &json_desc)) {
+    MS_LOG(EXCEPTION) << "Collect json desc failed.";
+  }
+
+  auto json_desc_str = json_desc.dump();
+  auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelEstimateOps, json_desc_str);
+  if (py::isinstance<py::none>(ret)) {
+    MS_LOG(EXCEPTION) << "CallPyFn: [" << kGraphKernelEstimateOps << "] return invalid result. input json:\n"
+                      << json_desc_str;
+  }
+
+  py::tuple ret_tuple = py::cast<py::tuple>(ret);
+  if (!py::isinstance<py::tuple>(ret_tuple) || ret_tuple.size() != 2) {
+    MS_LOG(EXCEPTION) << "Parallel cost model should return a tuple with two elements!";
+  }
+
+  std::vector<DimInfoPtr> dim_infos;
+  py::list dim_list = py::cast<py::list>(ret_tuple[0]);
+  for (size_t i = 0; i < dim_list.size(); ++i) {
+    dim_infos.push_back(std::make_shared<CommonDimInfo>(py::cast<int>(dim_list[i])));
+  }
+  int benefit = py::cast<int>(ret_tuple[1]);
+
+  return std::make_tuple(dim_infos, benefit);
+}
+
+ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) {
+  if (target != kGPUDevice) {
+    MS_LOG(EXCEPTION) << "Parallel cost model only support " << kGPUDevice << " now.";
+  }
+  return cost_model_;
+}
+}  // namespace opt
+}  // namespace mindspore

@@ -0,0 +1,82 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "base/base.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
+#include "backend/session/kernel_graph.h"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace opt {
+class DimInfo {
+ public:
+  DimInfo() = default;
+  ~DimInfo() {}
+  virtual std::string ToString() = 0;
+};
+
+class CommonDimInfo : public DimInfo {
+ public:
+  explicit CommonDimInfo(size_t dim) : dim_info_(dim) {}
+  ~CommonDimInfo() {}
+  void set_dim_info(size_t d) { dim_info_ = d; }
+  size_t dim_info() const { return dim_info_; }
+  std::string ToString() override;
+
+ private:
+  size_t dim_info_;
+};
+
+using DimInfoPtr = std::shared_ptr<DimInfo>;
+using CommonDimInfoPtr = std::shared_ptr<CommonDimInfo>;
+
+class ParallelCostModel {
+ public:
+  ParallelCostModel() {}
+  ~ParallelCostModel() {}
+  int GetNodeCalAmount(const AnfNodePtr &node);
+  std::tuple<std::vector<DimInfoPtr>, int> CalFuseInfo(const AnfNodePtrList &nodes);
+};
+
+using ParallelCostModelPtr = std::shared_ptr<ParallelCostModel>;
+
+class ParellelCostModelWarehouse {
+ public:
+  static ParellelCostModelWarehouse &Instance() {
+    static ParellelCostModelWarehouse instance;
+    return instance;
+  }
+  ParallelCostModelPtr GetParallelCostModel(const std::string &target);
+
+ private:
+  ParellelCostModelWarehouse() { cost_model_ = std::make_shared<ParallelCostModel>(); }
+  ParallelCostModelPtr cost_model_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_COST_MODEL_H_

@ -0,0 +1,876 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/graph_kernel/parallel_fusion.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "vm/segment_runner.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
|
||||
return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
||||
void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
|
||||
OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
std::set<AnfNodePtr> latter_to_be_erased;
|
||||
for (const auto &[node, node_rel] : (*node_rels)) {
|
||||
if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto nexts = node_rel.nexts;
|
||||
std::vector<AnfNodePtr> pre_nodes;
|
||||
std::queue<AnfNodePtr> node_que;
|
||||
node_que.push(node);
|
||||
|
||||
// Find until all pre nodes get false from pass_fn, and collect all these predecessor nodes.
|
||||
while (!node_que.empty()) {
|
||||
auto cur_node = node_que.front();
|
||||
node_que.pop();
|
||||
|
||||
if (!pass_fn(cur_node)) {
|
||||
pre_nodes.push_back(cur_node);
|
||||
continue;
|
||||
}
|
||||
|
||||
latter_to_be_erased.insert(cur_node);
|
||||
auto predecessors = (*node_rels)[cur_node].pres;
|
||||
if (predecessors.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const auto &pre_node : predecessors) {
|
||||
(*node_rels)[cur_node].pres.erase(pre_node);
|
||||
(*node_rels)[pre_node].nexts.erase(cur_node);
|
||||
node_que.push(pre_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Modify the relation: delete node <-> next_node, add pre node <-> next_node.
|
||||
for (const auto &next_node : nexts) {
|
||||
(*node_rels)[next_node].pres.erase(node);
|
||||
for (const auto &cur_node : pre_nodes) {
|
||||
(*node_rels)[next_node].pres.insert(cur_node);
|
||||
(*node_rels)[cur_node].nexts.insert(next_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &node : latter_to_be_erased) {
|
||||
node_rels->erase(node);
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
for (auto &[node, node_rel] : (*node_rels)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Make attached nodes deattach with node.
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) {
|
||||
auto attach_node = cnode->input(id);
|
||||
if (auto iter = node_rels->find(attach_node); iter != node_rels->end()) {
|
||||
iter->second.nexts.erase(node);
|
||||
}
|
||||
if (auto &cnode_pres = node_rel.pres; cnode_pres.count(attach_node) != 0) {
|
||||
cnode_pres.erase(attach_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Eliminate depend node of node relations.
|
||||
ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels);
|
||||
}
|
||||
|
||||
std::tuple<std::pair<AnfNodePtr, AnfNodePtr>, std::pair<AnfNodePtrList, AnfNodePtrList>> FindRelationOfControlDepend(
|
||||
const AnfNodePtr &node, OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto prior_node = cnode->input(kControlDependPriorIndex);
|
||||
auto behind_node = cnode->input(kControlDependBehindIndex);
|
||||
MS_EXCEPTION_IF_NULL(prior_node);
|
||||
MS_EXCEPTION_IF_NULL(behind_node);
|
||||
|
||||
OrderedSet<AnfNodePtr> prior_nodes;
|
||||
prior_nodes.insert(prior_node);
|
||||
OrderedSet<AnfNodePtr> behind_nodes;
|
||||
behind_nodes.insert(behind_node);
|
||||
|
||||
int64_t depend_mode = 0;
|
||||
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
||||
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
|
||||
}
|
||||
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
||||
prior_nodes = (*node_rels)[prior_node].nexts;
|
||||
}
|
||||
if (behind_node->isa<Parameter>()) {
|
||||
behind_nodes = depend_mode == 1 ? (*node_rels)[behind_node].nexts : OrderedSet<AnfNodePtr>();
|
||||
}
|
||||
|
||||
// Get real nodes.
|
||||
AnfNodePtrList real_prior_nodes;
|
||||
std::set<AnfNodePtr> prior_visited;
|
||||
for (const auto &tmp : prior_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
}
|
||||
AnfNodePtrList real_behind_nodes;
|
||||
std::set<AnfNodePtr> behind_visited;
|
||||
for (const auto &tmp : behind_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_behind_nodes, &behind_visited);
|
||||
}
|
||||
|
||||
return std::make_tuple(std::make_pair(prior_node, behind_node), std::make_pair(real_prior_nodes, real_behind_nodes));
|
||||
}
|
||||
|
||||
void ReLinkNodesOfControlDependByRelation(const std::unordered_map<AnfNodePtr, AnfNodePtrList> &control_depend_info,
|
||||
OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
// Relink and its log.
|
||||
for (const auto &m : control_depend_info) {
|
||||
const auto &prior = m.second[0];
|
||||
const auto &behind = m.second[1];
|
||||
(*node_rels)[prior].nexts.insert(behind);
|
||||
(*node_rels)[behind].pres.insert(prior);
|
||||
MS_LOG(DEBUG) << "Relink relation of " << m.first->fullname_with_scope() << ": " << prior->fullname_with_scope()
|
||||
<< " -> " << behind->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessControlDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtrList> control_depend_info;
|
||||
AnfNodePtrList latter_to_be_erased;
|
||||
|
||||
// Collect ControlDepend node and its input and output nodes.
|
||||
for (auto &[node, node_rel] : (*node_rels)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto [direct_relation, real_relations] = FindRelationOfControlDepend(node, node_rels);
|
||||
auto &[prior_node, behind_node] = direct_relation;
|
||||
auto &[real_prior_nodes, real_behind_nodes] = real_relations;
|
||||
|
||||
(*node_rels)[prior_node].nexts.erase(node);
|
||||
(*node_rels)[behind_node].nexts.erase(node);
|
||||
node_rel.pres.erase(prior_node);
|
||||
node_rel.pres.erase(behind_node);
|
||||
|
||||
for (auto &first_node : real_prior_nodes) {
|
||||
for (auto &second_node : real_behind_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(first_node);
|
||||
MS_EXCEPTION_IF_NULL(second_node);
|
||||
control_depend_info.insert({node, {first_node, second_node}});
|
||||
}
|
||||
}
|
||||
latter_to_be_erased.push_back(node);
|
||||
}
|
||||
|
||||
// Delete ControlDepend node before relink its relation.
|
||||
for (const auto &node : latter_to_be_erased) {
|
||||
node_rels->erase(node);
|
||||
}
|
||||
|
||||
// Rebuild relation between prior and behind node.
|
||||
ReLinkNodesOfControlDependByRelation(control_depend_info, node_rels);
|
||||
}
|
||||
|
||||
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
AnfNodePtrList latter_to_be_erased;
|
||||
for (auto &[node, node_rel] : (*node_rels)) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
AnfNodePtrList check_next_list;
|
||||
check_next_list.push_back(node);
|
||||
|
||||
bool disinterested = false;
|
||||
for (auto &successor : node_rel.nexts) {
|
||||
if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
|
||||
disinterested = true;
|
||||
break;
|
||||
}
|
||||
check_next_list.push_back(successor);
|
||||
}
|
||||
if (disinterested) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
|
||||
[&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
|
||||
continue;
|
||||
}
|
||||
|
||||
latter_to_be_erased.push_back(node);
|
||||
}
|
||||
|
||||
// Delete Tail MakeTuple(including its getitem nodes).
|
||||
for (const auto &node : latter_to_be_erased) {
|
||||
for (auto &pre : (*node_rels)[node].pres) {
|
||||
(*node_rels)[pre].nexts.erase(node);
|
||||
}
|
||||
|
||||
// Tail MakeTuple is just be consumed by nothing or invalid getitem node.
|
||||
for (auto &getitem : (*node_rels)[node].nexts) {
|
||||
node_rels->erase(getitem);
|
||||
}
|
||||
|
||||
node_rels->erase(node);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
|
||||
if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
|
||||
if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
|
||||
if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
|
||||
if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
|
||||
if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
|
||||
if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels, std::set<AnfNodePtr> *virtual_noout_nodes,
|
||||
std::set<AnfNodePtr> *ignore_noin_nodes) {
|
||||
// 1. Local relation
|
||||
// Graph as following left part, relation D->B and D->E(D is a no input node)
|
||||
// will make B and E to be multiply inputs node.
|
||||
// But for parallel, this local relation can ignore for B and E, which make
|
||||
// them be able to be paralleled.
|
||||
//
|
||||
// ************************************
|
||||
// * *
|
||||
// * | | *
|
||||
// * A D A D *
|
||||
// * | /| | / \ *
|
||||
// * | C | | C F *
|
||||
// * |/ / | | | *
|
||||
// * B F ====> B x x *
|
||||
// * | / | *
|
||||
// * |/ | *
|
||||
// * E E *
|
||||
// * | | *
|
||||
// * *
|
||||
// ************************************
|
||||
AnfNodePtrList no_input_nodes;
|
||||
for (const auto &node_rel : *node_rels) {
|
||||
auto &node = node_rel.first;
|
||||
if (IsNoInputsNode(*node_rels, node)) {
|
||||
no_input_nodes.push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;
|
||||
|
||||
for (const auto &ninode : no_input_nodes) {
|
||||
AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end());
|
||||
for (const auto &n : cnexts) {
|
||||
AnfNodePtr serial_tail = ninode;
|
||||
AnfNodePtr cur_node = n;
|
||||
while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
|
||||
serial_tail = cur_node;
|
||||
cur_node = *((*node_rels)[cur_node].nexts.begin());
|
||||
}
|
||||
latter_delete.emplace_back(serial_tail, cur_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Delete relation.
|
||||
for (const auto &[serial_tail, cur_node] : latter_delete) {
|
||||
virtual_noout_nodes->insert(serial_tail);
|
||||
ignore_noin_nodes->insert(cur_node);
|
||||
(*node_rels)[serial_tail].nexts.erase(cur_node);
|
||||
(*node_rels)[cur_node].pres.erase(serial_tail);
|
||||
MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
|
||||
<< cur_node->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
|
||||
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
|
||||
const std::set<AnfNodePtr> &ignore_noin_nodes) {
|
||||
AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
|
||||
std::list<std::function<void(const AnfNodePtr &)>> func_list = {
|
||||
[&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
|
||||
if (IsMultiInputsNode(node_rels, node)) {
|
||||
multi_inputs_nodes.push_back(node);
|
||||
}
|
||||
},
|
||||
[&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
|
||||
if (IsMultiOutputsNode(node_rels, node)) {
|
||||
multi_outputs_nodes.push_back(node);
|
||||
}
|
||||
},
|
||||
[&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
|
||||
if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
|
||||
no_input_nodes.push_back(node);
|
||||
}
|
||||
},
|
||||
[&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
|
||||
if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
|
||||
no_output_nodes.push_back(node);
|
||||
}
|
||||
}};
|
||||
|
||||
for (const auto &node_rel : node_rels) {
|
||||
for (const auto &func : func_list) {
|
||||
func(node_rel.first);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
|
||||
}
|
||||
|
||||
bool WhiteOpsFilter(const AnfNodePtr &node) {
|
||||
std::vector<PrimitivePtr> whiteable_ops = {}; // Not special for now.
|
||||
return session::AnfRuntimeAlgorithm::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
|
||||
const std::function<bool(const AnfNodePtr &)> &filter_func,
|
||||
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
|
||||
std::set<AnfNodePtr> *seen) {
|
||||
// Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
|
||||
// For backward search, the other multi-inputs node can be contained in.
|
||||
// For forward search, the other multi-outputs node can be contained in.
|
||||
auto get_contain_node_set = is_backward ? [](const NodeRelation &info) { return info.pres; }
|
||||
: [](const NodeRelation &info) { return info.nexts; };
|
||||
auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) { return info.nexts; }
|
||||
: [](const NodeRelation &info) { return info.pres; };
|
||||
std::vector<AnfNodePtrList> group;
|
||||
for (const auto &node : nodes) {
|
||||
AnfNodePtrList stream;
|
||||
AnfNodePtr n = node;
|
||||
for (auto iter = node_rels.find(n);
|
||||
seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
|
||||
iter = node_rels.find(n)) {
|
||||
if (filter_func(n)) {
|
||||
stream.push_back(n);
|
||||
seen->insert(n);
|
||||
}
|
||||
if (get_contain_node_set(iter->second).size() != 1) {
|
||||
break;
|
||||
}
|
||||
n = *(get_contain_node_set(iter->second).begin());
|
||||
}
|
||||
if (stream.size() > 0) {
|
||||
group.push_back(stream);
|
||||
}
|
||||
}
|
||||
|
||||
if (group.size() == 1) {
|
||||
for (const auto &drop : group[0]) {
|
||||
seen->erase(drop);
|
||||
}
|
||||
group.clear();
|
||||
}
|
||||
|
||||
return group;
|
||||
}
|
||||
|
||||
void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
|
||||
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
|
||||
std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
|
||||
auto get_related_nodes = is_backward ? [](const NodeRelation &info) { return info.pres; }
|
||||
: [](const NodeRelation &info) { return info.nexts; };
|
||||
for (const auto &node : multi_nodes) {
|
||||
if (auto iter = node_rels.find(node); iter != node_rels.end()) {
|
||||
const auto &pre_nodes = get_related_nodes(iter->second);
|
||||
AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
|
||||
groups->push_back(SearchFromNodes(related_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
|
||||
}
|
||||
}
|
||||
|
||||
// Erase empty groups.
|
||||
for (auto iter = groups->begin(); iter != groups->end();) {
|
||||
if (iter->size() == 0) {
|
||||
iter = groups->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
|
||||
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
|
||||
std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
|
||||
groups->push_back(SearchFromNodes(ud_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
|
||||
|
||||
// Erase empty groups.
|
||||
for (auto iter = groups->begin(); iter != groups->end();) {
|
||||
if (iter->size() == 0) {
|
||||
iter = groups->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string DumpNode(const AnfNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::stringstream buf;
|
||||
buf << (AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
|
||||
<< cnode->ToString();
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups) {
|
||||
MS_LOG(INFO) << "There are " << groups.size() << " parallel groups, their detail is: ";
|
||||
int i = 0;
|
||||
for (const auto group : groups) {
|
||||
std::stringstream buf;
|
||||
buf << "[" << i << " group] " << group.size() << ":\n";
|
||||
for (const auto nodes : group) {
|
||||
buf << " " << nodes.size() << ": [<";
|
||||
for (const auto node : nodes) {
|
||||
buf << "(" << DumpNode(node) << ") -> ";
|
||||
}
|
||||
buf << ">]\n";
|
||||
}
|
||||
i++;
|
||||
MS_LOG(INFO) << buf.str();
|
||||
}
|
||||
}
|
||||
|
||||
void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
|
||||
std::stringstream buf;
|
||||
buf << "Parallel fusion detail: ";
|
||||
for (const auto &node : source) {
|
||||
buf << "(" << DumpNode(node) << ") + ";
|
||||
}
|
||||
buf << "==>"
|
||||
<< "(" << DumpNode(target) << ")";
|
||||
MS_LOG(INFO) << buf.str();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
|
||||
// Based on anf node input information, build a simple graph for latter analyzation.
  OrderedMap<AnfNodePtr, NodeRelation> node_rels;
  auto get_info = [&node_rels](const AnfNodePtr &node) {
    if (node_rels.count(node) == 0) {
      node_rels.insert({node, NodeRelation()});
    }
    return &(node_rels[node]);
  };

  for (const auto &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }

    auto prior_node = get_info(node);
    for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
      // Parameter for ControlDepend when depend mode is 1.
      if (!input->isa<CNode>() && !input->isa<Parameter>()) {
        continue;
      }
      auto behind_node = get_info(input);
      prior_node->pres.insert(input);
      behind_node->nexts.insert(node);
    }
  }

  ProcessDependCNode(&node_rels);
  ProcessControlDependCNode(&node_rels);
  ProcessThroughPassCNode(
    [](const AnfNodePtr &node) {
      return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
    },
    &node_rels);
  ProcessThroughPassCNode([](const AnfNodePtr &node) { return node->isa<Parameter>(); }, &node_rels);
  ProcessTailMakeTupleCNode(&node_rels);
  ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);

  return node_rels;
}

std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
  const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
  // Get interesting nodes: multi-input nodes, multi-output nodes, no-input nodes and no-output nodes.
  auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
    GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);

  // Get streams and group them.
  std::set<AnfNodePtr> seen;
  std::vector<std::vector<AnfNodePtrList>> groups;

  SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);

  DumpParallelGroups(groups);
  return groups;
}

std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset(
  int start, const std::vector<int> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
  const std::set<int> &excludes) {
  // Get unused nodes by offset index; the result also contains the node at the start index.
  int node_limit = nodes.size();
  if (start >= node_limit) {
    MS_LOG(EXCEPTION) << "Index offset exceeds the limit of given nodes.";
  }
  AnfNodePtrList target_nodes = {nodes[start]};
  std::vector<int> valid_indices;
  std::vector<int> unused;
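  // Collect indices of candidates at or after `start` that are neither used nor
  // excluded; the entries of `offsets` then index into this compacted list.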
  for (size_t i = start; i < used.size(); ++i) {
    if (!used[i] && excludes.count(i) == 0) {
      unused.push_back(i);
    }
  }
  int limit = unused.size();
  for (auto offset : offsets) {
    if (offset >= limit) {
      MS_LOG(EXCEPTION) << "Index offset exceeds the limit of unused nodes.";
    }
    if (unused[offset] >= node_limit) {
      MS_LOG(EXCEPTION) << "Index offset exceeds the limit of nodes.";
    }
    valid_indices.push_back(unused[offset]);
    target_nodes.push_back(nodes[unused[offset]]);
  }

  return std::make_tuple(target_nodes, valid_indices);
}

std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
  size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
  std::map<AnfNodePtr, int> *sorted_indices) {
  auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
    MS_EXCEPTION_IF_NULL(node);
    if (indices->find(node) == indices->end()) {
      MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
    }
    return (*indices)[node];
  };

  std::vector<ParallelInfo> parallel_infos;
  std::vector<bool> origin_candidates_used(origin_size, false);
  std::vector<bool> sorted_candidates_used(candidates.size(), false);

  for (size_t i = 0; i < candidates.size(); ++i) {
    if (sorted_candidates_used[i]) {
      continue;
    }

    int max_benefit = 0;
    ParallelInfo best_parallel_info;
    std::set<int> bad_set;
    size_t unused_num = 0;
    for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
      unused_num += sorted_candidates_used[j] ? 0 : 1;
    }
    if (unused_num < 1) {
      break;
    }

    unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);

    size_t begin = 1, end = unused_num;
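    // Binary search for the largest number of extra candidates that still gives a
    // positive fusion benefit; note this relies on the assumption that the cost
    // model's benefit is monotonic in the number of fused candidates.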
    while (begin <= end) {
      size_t mid = (begin + end) / 2;
      std::vector<int> tc(mid);
      std::iota(tc.begin(), tc.end(), 1);
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>());
      int benefit;
      std::tie(std::ignore, benefit) = cost_model_ptr_->CalFuseInfo(other_candidates);
      if (benefit > 0) {
        begin = mid + 1;
      } else {
        end = mid - 1;
      }
    }
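
    // `begin - 1` is now the largest extra-candidate count with a positive benefit;
    // rerun the cost model on that set to record its dim infos and final benefit.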
    if (begin > 1) {
      std::vector<int> tc(begin - 1);
      std::iota(tc.begin(), tc.end(), 1);
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(i, tc, sorted_candidates_used, candidates, std::set<int>());
      auto [dim_infos, benefit] = cost_model_ptr_->CalFuseInfo(other_candidates);
      if (benefit <= 0) {
        MS_LOG(EXCEPTION) << "Internal error in candidate search!";
      }
      max_benefit = benefit;
      best_parallel_info = ParallelInfo(other_candidates, dim_infos);
      i += begin - 1;
    }

    if (max_benefit > 0) {
      parallel_infos.push_back(best_parallel_info);
      for (const auto &node : best_parallel_info.nodes()) {
        sorted_candidates_used[get_index(sorted_indices, node)] = true;
        origin_candidates_used[get_index(origin_indices, node)] = true;
      }
    }
  }

  // The current nodes are not suitable to fuse, so pop the first node to try other fusion possibilities.
  if (parallel_infos.size() == 0) {
    origin_candidates_used[get_index(origin_indices, candidates[0])] = true;
  }

  return std::make_tuple(origin_candidates_used, parallel_infos);
}

std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
  const AnfNodePtrList &cs) {
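  // `cs` may contain null placeholders for already-exhausted streams; only the
  // non-null entries are real fusion candidates.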
  std::map<AnfNodePtr, int> origin_indices;
  std::vector<size_t> indices;
  for (size_t i = 0; i < cs.size(); ++i) {
    if (cs[i]) {
      origin_indices.insert({cs[i], i});
      indices.push_back(i);
    }
  }

  // A computation-heavy node can cover more lighter nodes' cost, so sort them first.
  std::map<size_t, int> cal_amounts;
  for (auto id : indices) {
    cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
  }
  std::sort(indices.begin(), indices.end(),
            [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });

  AnfNodePtrList candidates;
  for (size_t i = 0; i < indices.size(); ++i) {
    candidates.push_back(cs[indices[i]]);
  }

  std::map<AnfNodePtr, int> sorted_indices;
  for (size_t i = 0; i < candidates.size(); ++i) {
    sorted_indices.insert({candidates[i], i});
  }

  return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
}

void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
                                                      std::vector<ParallelInfo> *parallel_infos) {
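  // Each AnfNodePtrList in `group` is one stream. `tails` keeps a cursor into every
  // stream, and each round the current front nodes of all streams form one
  // candidate set for fusion.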
  std::vector<AnfNodePtrList::const_iterator> tails;
  std::vector<AnfNodePtrList::const_iterator> ended;
  for (const auto &node_list : group) {
    tails.push_back(node_list.begin());
    ended.push_back(node_list.end());
  }
  auto get_candidates = [&tails, &ended]() {
    AnfNodePtrList candidates;
    for (size_t id = 0; id < tails.size(); ++id) {
      candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
    }
    return candidates;
  };
  auto update_tails = [&tails](const std::vector<bool> &used) {
    if (used.size() != tails.size()) {
      MS_LOG(EXCEPTION) << "Judged nodes size is not equal to the remaining ones!";
    }
    for (size_t id = 0; id < used.size(); ++id) {
      if (used[id]) {
        tails[id]++;
      }
    }
  };
  auto valid_candidate_num = [](const AnfNodePtrList &cs) {
    return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
  };

  auto candidates = get_candidates();
  while (valid_candidate_num(candidates) > 1) {
    auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
    std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
                   [](const ParallelInfo &pi) { return pi; });
    update_tails(used);
    candidates = get_candidates();
  }
}

std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
  const std::vector<std::vector<AnfNodePtrList>> &groups) {
  // Find core-fusable groups with the cost model.
  std::vector<ParallelInfo> parallel_infos;
  for (const auto &group : groups) {
    SearchFuseNodesInParallelGroup(group, &parallel_infos);
  }

  return parallel_infos;
}

void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
  for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
    const auto &fuse_nodes = parallel_info.nodes();
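    // Attribute payload: {position of this node within the fused group, dim info
    // produced by the parallel cost model for it}.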
    std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
    if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
      MakeCNodeSafeForAttr(fuse_nodes[i]);
      AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
    } else {
      auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
      auto out_node = node_g->output();
      if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
        auto inputs = out_node->cast<CNodePtr>()->inputs();
        for (size_t j = 1; j < inputs.size(); ++j) {
          MakeCNodeSafeForAttr(inputs[j]);
          AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
        }
      } else {
        MakeCNodeSafeForAttr(out_node);
        AnfAlgo::SetNodeAttr(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
      }
    }
  }
}

void PostProcessForNewSubGraphCNode(const AnfNodePtr &node,
                                    const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  const auto &users = mng->node_users()[node];
  std::vector<std::pair<AnfNodePtr, int>> sons;
  for (const auto &[user, index] : users) {
    if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
      sons.emplace_back(user, index);
      continue;
    }
    auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
    sons.emplace_back(fake_first_grad_son, grad_index);
  }

  AnfNodePtrList latter_to_delete;
  for (const auto &[son, index] : sons) {
    if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
      continue;
    }

    latter_to_delete.push_back(son);
  }

  if (latter_to_delete.empty()) {
    return;
  }
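
  // Bypass the redundant Depend users of the fused node by replacing each deleted
  // Depend with its real input; if every user is such a Depend, keep one so the
  // execution-order constraint is preserved.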
  std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
  if (latter_to_delete.size() == sons.size()) {
    // Keep one Depend node relation and delete the others.
    ++delete_begin;
  }
  for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
    auto depend_anfnode = *delete_begin;
    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_cnode);
    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
    mng->Replace(depend_anfnode, depend_prior_node);
  }
}

bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
                                                 const std::shared_ptr<session::KernelGraph> &kernel_graph) {
  bool changed = false;

  for (size_t i = 0; i < parallel_infos.size(); ++i) {
    const auto &fuse_nodes = parallel_infos[i].nodes();
    if (fuse_nodes.size() <= 1) {
      continue;
    }
    changed = true;
    SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
    AnfNodePtr sg_node;
    std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
    PostProcessForNewSubGraphCNode(sg_node, kernel_graph);
    DumpParallelFusionDetail(fuse_nodes, sg_node);
  }

  return changed;
}

bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
  MS_EXCEPTION_IF_NULL(kernel_graph);

  cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
  MS_EXCEPTION_IF_NULL(cost_model_ptr_);

  auto nodes = TopoSort(kernel_graph->get_return());
  std::reverse(nodes.begin(), nodes.end());

  auto node_rels = GenAnalysisGraph(nodes);
  auto groups = SearchParallelGroups(node_rels);
  auto parallel_infos = SearchFusableParallelCNodes(groups);

  // Create the core-fused subgraphs and change the origin graph.
  return CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
}
}  // namespace opt
}  // namespace mindspore

@@ -0,0 +1,122 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_

#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>

#include "base/base.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/graph_kernel/parallel_cost_model.h"
#include "backend/session/kernel_graph.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace opt {
class ParallelInfo {
 public:
  ParallelInfo() = default;
  ParallelInfo(const AnfNodePtrList &nodes, const std::vector<DimInfoPtr> &dims) : nodes_(nodes), dims_(dims) {}
  ParallelInfo(const ParallelInfo &obj) {
    nodes_ = obj.nodes_;
    dims_ = obj.dims_;
  }
  ~ParallelInfo() = default;

  size_t GetSize() const {
    if (nodes_.size() != dims_.size()) {
      MS_LOG(EXCEPTION) << "Internal error in parallel info!";
    }
    return nodes_.size();
  }
  const AnfNodePtrList &nodes() const { return nodes_; }
  const std::vector<DimInfoPtr> &dims() const { return dims_; }

 private:
  AnfNodePtrList nodes_;
  std::vector<DimInfoPtr> dims_;
};

class ParallelConfig {
 public:
  ParallelConfig() = default;
  explicit ParallelConfig(size_t max_n) : max_num_for_fuse_(max_n) {}
  explicit ParallelConfig(const ParallelConfig &obj) { max_num_for_fuse_ = obj.max_num_for_fuse_; }
  ~ParallelConfig() = default;
  size_t max_num_for_fuse() { return max_num_for_fuse_; }

 private:
  size_t max_num_for_fuse_{10};  // Fusing too many nodes together may produce a bad result.
};

struct NodeRelation {
 public:
  NodeRelation() {}
  ~NodeRelation() = default;
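  // `pres` holds the producer (input) nodes of a node; `nexts` holds its consumers.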
  OrderedSet<AnfNodePtr> pres;
  OrderedSet<AnfNodePtr> nexts;
};

class ParallelOpFusion : public Pass {
 public:
  ParallelOpFusion(const std::string &target, const ParallelConfig &config)
      : Pass("parallel_fusion"), target_(target), config_(config) {}
  ~ParallelOpFusion() override = default;
  bool Run(const FuncGraphPtr &graph) override;

 private:
  std::tuple<AnfNodePtrList, std::vector<int>> GetAvaliableNodesByOffset(int start, const std::vector<int> &offsets,
                                                                         const std::vector<bool> &used,
                                                                         const AnfNodePtrList &nodes,
                                                                         const std::set<int> &excludes);

  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> DoSearchInSortedCandidates(
    size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
    std::map<AnfNodePtr, int> *sorted_indices);

  std::tuple<std::vector<bool>, std::vector<ParallelInfo>> SearchFuseNodesInCandidates(const AnfNodePtrList &cs);

  void SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
                                      std::vector<ParallelInfo> *parallel_infos);

  std::vector<ParallelInfo> SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> &groups);

  void SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info);

  bool CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
                                 const std::shared_ptr<session::KernelGraph> &kernel_graph);

  OrderedMap<AnfNodePtr, NodeRelation> GenAnalysisGraph(const AnfNodePtrList &nodes);
  std::vector<std::vector<AnfNodePtrList>> SearchParallelGroups(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels);

  std::string target_;
  ParallelConfig config_;
  ParallelCostModelPtr cost_model_ptr_;
  std::set<AnfNodePtr> virtual_noout_nodes_;
  std::set<AnfNodePtr> ignore_noin_nodes_;
};
using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_PARALLEL_FUSION_H_

@@ -43,6 +43,7 @@
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/clean_all_in_once.h"
|
||||
#include "backend/optimizer/graph_kernel/depend_formater.h"
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||
|
@ -51,6 +52,7 @@
|
|||
#include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
|
||||
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
|
||||
#include "backend/optimizer/graph_kernel/value_graph_binder.h"
|
||||
#include "backend/optimizer/graph_kernel/parallel_fusion.h"
|
||||
#include "backend/optimizer/pass/communication_op_fusion.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "common/trans.h"
|
||||
|
@ -179,6 +181,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
|||
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
  std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
  pm->AddPass(std::make_shared<opt::DependFormater>());  // Make more fusion opportunities.
  pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
  pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
  pm->AddPass(std::make_shared<opt::BasicOpsFusion>());

@@ -196,7 +199,8 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
  // will be exposed, use GetitemTuple Pass to delete them.
  pm->AddPass(std::make_shared<opt::GetitemTuple>());
  pm->AddPass(std::make_shared<opt::AtomicCleanInsertter>());
  pm->AddPass(std::make_shared<opt::CleanAllInOnce>());
  pm->AddPass(std::make_shared<opt::DependFormater>());  // Prevent fake loop in parallel fusion.
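  // ParallelConfig(7) caps how many kernels may be fused into one parallel group;
  // see ParallelConfig::max_num_for_fuse_.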
  pm->AddPass(std::make_shared<opt::ParallelOpFusion>(kGPUDevice, opt::ParallelConfig(7)));
  pm->AddPass(std::make_shared<opt::BindValueToGraph>());
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(kernel_graph);

@@ -1,5 +1,5 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -382,6 +382,7 @@ constexpr auto kAttrPadding = "padding";
constexpr auto kAttrIsGrad = "is_grad";
constexpr auto kAttrRecompute = "recompute";
constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";

// attr value
constexpr auto kValueTargetSwitch = "target_switch";

@@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==========================================================================
"""test graph parallel case"""
import model


def injective_graph(shape):
    gb = model.GraphBuilder()
    with gb.graph_scope('injective') as _:
        a1 = gb.tensor(shape, 'float32')
        a2 = gb.emit('Abs', a1)
        a3 = gb.emit('Abs', a2)
        gb.emit('Abs', a3)
    return gb.get()[0]


def reduce_graph(shape, reduce_axis):
    gb = model.GraphBuilder()
    with gb.graph_scope('reduce') as _:
        a1 = gb.tensor(shape, 'float32')
        a2 = gb.emit('Abs', a1)
        a3 = gb.emit('Abs', a2)
        gb.emit('ReduceSum', a3, 'C', attrs={'reduce_axis': reduce_axis})
    return gb.get()[0]


def control_graph(shape):
    gb = model.GraphBuilder()
    with gb.graph_scope('control') as _:
        a1 = gb.tensor(shape, 'float32')
        a2 = gb.emit('Abs', a1)
        gb.emit('ControlDepend', a2)
    return gb.get()[0]


def block_fusion(graphs):
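    # Ask the cost model whether the given graphs can run as one block-level fused
    # kernel; a positive gain means the fusion is expected to pay off.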
    gain = model.parallel_estimate(graphs)
    print("fusion = {}, bottleneck = {}, gain = {}".format(gain.fusion_type, gain.bottleneck, gain.gain))
    return gain.fusion_type == "block_fusion" and gain.gain > 0


if __name__ == "__main__":
    assert block_fusion([injective_graph([40, 1024]), injective_graph([40, 1024])])
    assert block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([24, 1024])])
    assert not block_fusion([reduce_graph([1024, 1024], [1]), injective_graph([50, 1024])])
    assert not block_fusion([reduce_graph([1024, 1024], [0, 1]), injective_graph([1024, 1024])])
    assert block_fusion([control_graph([20, 128]), injective_graph([40, 1024])])