!8689 [GraphKernel] Split shape ops for more fusion opportunities.
From: @tronzhang Reviewed-by: @gaoxiong1, @ckey_dou Signed-off-by: @ckey_dou
commit 9969c83f75
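For context, the core transformation of this commit in isolation: a shape op (currently just Reshape) that feeds several consumers is cloned once per consumer, so each consumer area can absorb its own private copy during fusion instead of every area being chained through one shared node. A minimal, runnable sketch of that idea in plain Python (Node and the helper are illustrative stand-ins, not MindSpore API):

    # Minimal sketch of the transformation, not the actual pass: a multi-user
    # node is cloned so that every user owns a private copy of it.
    class Node:
        def __init__(self, op, inputs=()):
            self.op = op
            self.inputs = list(inputs)

    def split_multi_user_node(node, users):
        """Redirect each user to its own clone of `node`."""
        for user in users:
            clone = Node(node.op, node.inputs)   # same op, same inputs
            idx = user.inputs.index(node)
            user.inputs[idx] = clone             # redirect the edge

    x = Node("Param")
    reshape = Node("Reshape", [x])
    add, mul = Node("Add", [reshape]), Node("Mul", [reshape])
    split_multi_user_node(reshape, [add, mul])
    assert add.inputs[0] is not mul.inputs[0]    # each user now has its own Reshape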
@@ -87,7 +87,8 @@ class GraphSplitByPattern:
             _redirect_relation(a.out_relations)
         for a, _ in area.out_relations.items():
             _redirect_relation(a.in_relations)
-        self.mode = self.MODE_COMPOSITE
+        if self.pattern > PrimLib.RESHAPE:
+            self.mode = self.MODE_COMPOSITE

     def check_circle(self, to):
         """Check circle. It returns false if circle exists"""
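With the new guard, an area whose pattern is RESHAPE or below is no longer promoted to composite mode after fusing. A hedged reading of the condition (the MODE_* values are assumed for illustration; the pattern values come from the PrimLib hunk further down):

    # Reading of the new guard; MODE_* values are assumed, not shown in this diff.
    RESHAPE, ELEMWISE = 1, 2                 # new PrimLib numbering (see below)
    MODE_BASIC, MODE_COMPOSITE = 1, 2        # assumed values

    def mode_after_fuse(pattern, mode=MODE_BASIC):
        # Previously the area was unconditionally switched to composite.
        return MODE_COMPOSITE if pattern > RESHAPE else mode

    assert mode_after_fuse(RESHAPE) == MODE_BASIC       # pure-reshape area stays basic
    assert mode_after_fuse(ELEMWISE) == MODE_COMPOSITE  # anything above is composite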
@@ -148,7 +149,7 @@ class GraphSplitByPattern:
     def split(self):
         """Split graph by pattern"""
         def _elemwise_depth(dom):
-            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
                 return None
             a, r = list(dom.in_relations.items())[0]
             if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE:
@@ -156,7 +157,7 @@ class GraphSplitByPattern:
             return [a]

         def _elemwise_width(dom):
-            if dom.pattern > PrimLib.BROADCAST:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
                 return None
             fused = []
             for a, r in dom.in_relations.items():
@@ -165,7 +166,7 @@ class GraphSplitByPattern:
             return fused

         def _broadcast_depth(dom):
-            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
                 return None
             a, r = list(dom.in_relations.items())[0]
             if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
@@ -174,7 +175,7 @@ class GraphSplitByPattern:
             return [a]

         def _broadcast_width(dom):
-            if dom.pattern > PrimLib.BROADCAST:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
                 return None
             fused = []
             for a, r in dom.in_relations.items():
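All four guards in split() change for the same reason: after the renumbering in the PrimLib hunk below, RESHAPE (1) compares less than BROADCAST (3), so the old `pattern > PrimLib.BROADCAST` test would wrongly admit reshape-pattern areas into elemwise/broadcast fusion. A quick check against the new numbering:

    # Why `> BROADCAST` becomes `not in (ELEMWISE, BROADCAST)` after renumbering.
    UNKNOWN, RESHAPE, ELEMWISE, BROADCAST, REDUCE = 0, 1, 2, 3, 4

    def old_reject(pattern):
        return pattern > BROADCAST

    def new_reject(pattern):
        return pattern not in (ELEMWISE, BROADCAST)

    assert not old_reject(RESHAPE)  # old guard would wrongly admit a reshape area
    assert new_reject(RESHAPE)      # new guard rejects it
    assert old_reject(REDUCE) and new_reject(REDUCE)  # reduces rejected as before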
@@ -65,11 +65,12 @@ class PrimLib:
     """Prim lib"""

     UNKNOWN = 0
-    ELEMWISE = 1
-    BROADCAST = 2
-    REDUCE = 3
-    TRANSFORM = 4
-    CONTROL = 5
+    RESHAPE = 1
+    ELEMWISE = 2
+    BROADCAST = 3
+    REDUCE = 4
+    TRANSFORM = 5
+    CONTROL = 6

     class Prim:
         """Prim"""
@@ -81,6 +82,11 @@ class PrimLib:
             if relation_func is None:
                 self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)

+        def default_reshape_relation(self, op, input_idx):
+            axis_relation, elem_relation = self.unknown_relation(op, input_idx)
+            elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
+            return axis_relation, elem_relation
+
         def default_elemwise_broadcast_relation(self, op, input_idx):
             """Process elemwise and broadcast relation"""
             out_shape = op.output.shape
@@ -116,6 +122,7 @@ class PrimLib:

         default_relation_func = [
             unknown_relation,
+            default_reshape_relation,
             default_elemwise_broadcast_relation,
             default_elemwise_broadcast_relation,
             default_reduce_relation,
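Because `default_relation_func` is indexed by the pattern value (the `iter_type` used in `__init__` above), the renumbering requires a new entry at index 1. Restating the alignment, with handler names as strings:

    # The relation table is indexed by the pattern value, so each handler must
    # sit at its pattern's index; RESHAPE (1) gets the new reshape handler.
    UNKNOWN, RESHAPE, ELEMWISE, BROADCAST, REDUCE = 0, 1, 2, 3, 4

    handlers = [
        "unknown_relation",                     # UNKNOWN = 0
        "default_reshape_relation",             # RESHAPE = 1 (new slot)
        "default_elemwise_broadcast_relation",  # ELEMWISE = 2
        "default_elemwise_broadcast_relation",  # BROADCAST = 3
        "default_reduce_relation",              # REDUCE = 4
    ]
    assert handlers[RESHAPE] == "default_reshape_relation"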
@@ -154,11 +161,15 @@ class PrimLib:
         'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
+        'Tanh': Prim(ELEMWISE),
-        'ExpandDims': Prim(ELEMWISE),
+        'ExpandDims': Prim(RESHAPE),
         'InplaceAssign': Prim(ELEMWISE),
         '@ReduceInit': Prim(ELEMWISE),
-        'Reshape': Prim(ELEMWISE),
+        'Reshape': Prim(RESHAPE),
+        'Squeeze': Prim(RESHAPE),
+        'Flatten': Prim(RESHAPE),
+        'FlattenGrad': Prim(RESHAPE),
         'Transpose': Prim(TRANSFORM),
         'Tile': Prim(BROADCAST),
     }

     default_primtive = Prim(UNKNOWN)
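Retagging `Reshape` and `ExpandDims` from ELEMWISE to RESHAPE, and adding `Squeeze`, `Flatten`, and `FlattenGrad`, ties the primitive table to the rest of the change: these pure shape manipulations now pick up `default_reshape_relation` via the table above, and Reshape in particular is exactly what the new C++ splitter pass below duplicates.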
@@ -50,6 +50,7 @@ class OpInfer:
         return shape

     default_infer_shape_func = [
         None,
+        None,
         default_elementwise_infer.__func__,
         lambda inputs, attrs: max([t.shape for t in inputs]),
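The added `None` keeps `default_infer_shape_func` aligned with the shifted pattern values: like UNKNOWN at index 0, the new RESHAPE slot at index 1 has no generic shape-inference rule, which is why Reshape gets a special case in `infer_shape_func` below.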
@@ -70,7 +71,8 @@ class OpInfer:

     infer_shape_func = {
         # add special infer func here
-        'InplaceAssign': lambda inputs, attrs: inputs[2].shape
+        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
+        'Reshape': lambda inputs, attrs: attrs["shape"],
     }
     infer_dtype_func = {
         # add special infer func here
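Besides fixing the previously missing comma after the `InplaceAssign` entry, the hunk adds the special rule for Reshape: the output shape is read directly from the op's `shape` attribute. Exercising both rules standalone (FakeTensor is an illustrative stand-in):

    # The two special-case inference rules, exercised on their own.
    class FakeTensor:
        def __init__(self, shape):
            self.shape = shape

    infer_shape_func = {
        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
        'Reshape': lambda inputs, attrs: attrs["shape"],
    }

    inputs = [FakeTensor([4, 4]), FakeTensor([1]), FakeTensor([16])]
    assert infer_shape_func['Reshape'](inputs, {"shape": [2, 8]}) == [2, 8]
    assert infer_shape_func['InplaceAssign'](inputs, {}) == [16]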
@@ -0,0 +1,99 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
+#include <algorithm>
+#include <vector>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <queue>
+#include <map>
+#include <unordered_map>
+#include "frontend/optimizer/irpass.h"
+#include "pipeline/jit/parse/python_adapter.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
+#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
+#include "debug/anf_ir_dump.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
+  std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape};
+  auto &users = mng->node_users();
+  return std::any_of(shape_ops.begin(), shape_ops.end(),
+                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&
+         users[node].size() > 1;
+}
+
+AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
+  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto cnode = anf_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  CNodePtr node = kernel_graph->NewCNode(cnode->inputs());
+  node->set_abstract(cnode->abstract());
+  node->set_forward(cnode->forward().first, cnode->forward().second);
+  node->set_inputs_value(cnode->inputs_value());
+  ScopePtr scope = (anf_node->scope() != kDefaultScope) ? anf_node->scope() : kDefaultScope;
+  node->set_scope(scope);
+  node->set_kernel_info(cnode->kernel_info_ptr());
+  return node;
+}
+
+void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
+  auto &users = mng->node_users();
+  AnfNodePtrList splitted_nodes;
+  for (size_t i = 0; i < users[node].size(); ++i) {
+    splitted_nodes.push_back(CloneCNode(node));
+  }
+
+  const auto &index_set = users[node];
+  int i = 0;
+  for (auto [user, index] : index_set) {
+    auto user_node = user->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(user_node);
+    user_node->set_input(index, splitted_nodes[i]);
+    i++;
+  }
+}
+}  // namespace
+
+bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+  bool changed = false;
+  auto todos = TopoSort(func_graph->get_return());
+  for (const auto &anf_node : todos) {
+    auto node = anf_node->cast<CNodePtr>();
+    if (node != nullptr && IsMultiUserShapeOps(node, mng)) {
+      SplitNode(node, mng);
+      changed = true;
+    }
+  }
+
+  mng->RemoveRoots();
+  mng->KeepRoots({func_graph});
+  return changed;
+}
+}  // namespace opt
+}  // namespace mindspore
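Two details of the pass are worth noting. SplitNode clones once per user edge, not per user, since node_users() yields (user, input index) pairs; a consumer that reads the node through two inputs therefore receives two distinct clones. And after rewiring, the original node is left without users, so the closing RemoveRoots()/KeepRoots() calls presumably let the manager prune it along with any other unreachable nodes. The shape_ops vector currently holds only prim::kPrimReshape, leaving room for more shape ops later.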
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
+#include <memory>
+#include "ir/func_graph.h"
+#include "backend/optimizer/common/pass.h"
+
+namespace mindspore {
+namespace opt {
+class ShapeOpsSplitter : public Pass {
+ public:
+  ShapeOpsSplitter() : Pass("shape_ops_splitter") {}
+  ~ShapeOpsSplitter() override = default;
+  bool Run(const FuncGraphPtr &func_graph);
+};
+using ShapeOpsSplitterPtr = std::shared_ptr<ShapeOpsSplitter>;
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_

@@ -42,6 +42,7 @@
 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
+#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
 #include "backend/optimizer/graph_kernel/value_graph_binder.h"
 #include "backend/optimizer/pass/communication_op_fusion.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
@@ -164,6 +165,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
   pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
+  pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
   pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
   pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
   pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
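The insertion point matters: ShapeOpsSplitter runs after GraphKernelExpander, so reshapes materialized by expansion are split as well, and before BasicOpsFusion/CompositeOpsFusion, so each clone can be fused into its own consumer; GraphKernelCSE running afterwards can presumably fold back clones that end up inside the same kernel.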