!8689 [GraphKernel] Split shape ops for more fusion opportunities.

From: @tronzhang
Reviewed-by: @gaoxiong1, @ckey_dou
Signed-off-by: @ckey_dou
Committed by mindspore-ci-bot on 2020-11-18 10:41:29 +08:00 (via Gitee)
Commit: 9969c83f75
6 changed files with 161 additions and 13 deletions
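The idea: a shape op such as Reshape that feeds several consumers ties all of them to a single node, so the fusion passes cannot absorb it into any one area. Splitting gives every consumer a private copy. A minimal Python sketch of that effect (toy Node type, not MindSpore's API):

# A toy model of the SplitNode logic added below: clone a multi-user
# shape op so each consumer owns a private copy that fusion can absorb.
class Node:
    def __init__(self, op, inputs):
        self.op, self.inputs = op, inputs

def split_shape_op(node, users):
    for user in users:
        clone = Node(node.op, node.inputs)  # fresh node, same op and inputs
        user.inputs = [clone if i is node else i for i in user.inputs]

x = Node('Parameter', [])
r = Node('Reshape', [x])
a, b = Node('Abs', [r]), Node('Neg', [r])
split_shape_op(r, [a, b])
assert a.inputs[0] is not b.inputs[0]  # each user now has its own Reshape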


@@ -87,7 +87,8 @@ class GraphSplitByPattern:
                 _redirect_relation(a.out_relations)
             for a, _ in area.out_relations.items():
                 _redirect_relation(a.in_relations)
-            self.mode = self.MODE_COMPOSITE
+            if self.pattern > PrimLib.RESHAPE:
+                self.mode = self.MODE_COMPOSITE
 
         def check_circle(self, to):
             """Check circle. It returns false if circle exists"""
@@ -148,7 +149,7 @@ class GraphSplitByPattern:
     def split(self):
         """Split graph by pattern"""
         def _elemwise_depth(dom):
-            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
                 return None
             a, r = list(dom.in_relations.items())[0]
             if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE:
@@ -156,7 +157,7 @@ class GraphSplitByPattern:
             return [a]
 
         def _elemwise_width(dom):
-            if dom.pattern > PrimLib.BROADCAST:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
                 return None
             fused = []
             for a, r in dom.in_relations.items():
@@ -165,7 +166,7 @@ class GraphSplitByPattern:
             return fused
 
         def _broadcast_depth(dom):
-            if dom.pattern > PrimLib.BROADCAST or len(dom.in_relations) != 1:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
                 return None
             a, r = list(dom.in_relations.items())[0]
             if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
@@ -174,7 +175,7 @@ class GraphSplitByPattern:
             return [a]
 
         def _broadcast_width(dom):
-            if dom.pattern > PrimLib.BROADCAST:
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
                 return None
             fused = []
             for a, r in dom.in_relations.items():
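Why these tests became explicit membership checks: once RESHAPE is inserted below ELEMWISE in PrimLib (next file), the old acceptance condition "pattern <= PrimLib.BROADCAST" would also admit reshape areas, which must not seed elemwise/broadcast fusion. A standalone sketch, assuming only the renumbered constants from model.py:

class PrimLib:
    UNKNOWN, RESHAPE, ELEMWISE, BROADCAST = 0, 1, 2, 3  # new numbering

for pattern in (PrimLib.RESHAPE, PrimLib.ELEMWISE):
    old_ok = pattern <= PrimLib.BROADCAST                      # admits RESHAPE too
    new_ok = pattern in (PrimLib.ELEMWISE, PrimLib.BROADCAST)  # excludes it
    print(pattern, old_ok, new_ok)  # -> 1 True False, then 2 True True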


@@ -65,11 +65,12 @@ class PrimLib:
     """Prim lib"""
 
     UNKNOWN = 0
-    ELEMWISE = 1
-    BROADCAST = 2
-    REDUCE = 3
-    TRANSFORM = 4
-    CONTROL = 5
+    RESHAPE = 1
+    ELEMWISE = 2
+    BROADCAST = 3
+    REDUCE = 4
+    TRANSFORM = 5
+    CONTROL = 6
 
     class Prim:
         """Prim"""
@@ -81,6 +82,11 @@ class PrimLib:
             if relation_func is None:
                 self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
 
+        def default_reshape_relation(self, op, input_idx):
+            axis_relation, elem_relation = self.unknown_relation(op, input_idx)
+            elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
+            return axis_relation, elem_relation
+
         def default_elemwise_broadcast_relation(self, op, input_idx):
             """Process elemwise and broadcast relation"""
             out_shape = op.output.shape
@@ -116,6 +122,7 @@ class PrimLib:
 
         default_relation_func = [
             unknown_relation,
+            default_reshape_relation,
            default_elemwise_broadcast_relation,
            default_elemwise_broadcast_relation,
            default_reduce_relation,
@ -154,11 +161,15 @@ class PrimLib:
'ControlDepend': Prim(CONTROL),
'Assign': Prim(ELEMWISE),
'Tanh': Prim(ELEMWISE),
'ExpandDims': Prim(ELEMWISE),
'ExpandDims': Prim(RESHAPE),
'InplaceAssign': Prim(ELEMWISE),
'@ReduceInit': Prim(ELEMWISE),
'Reshape': Prim(ELEMWISE),
'Reshape': Prim(RESHAPE),
'Squeeze': Prim(RESHAPE),
'Flatten': Prim(RESHAPE),
'FlattenGrad': Prim(RESHAPE),
'Transpose': Prim(TRANSFORM),
'Tile': Prim(BROADCAST),
}
default_primtive = Prim(UNKNOWN)


@@ -50,6 +50,7 @@ class OpInfer:
         return shape
 
    default_infer_shape_func = [
        None,
+       None,
        default_elementwise_infer.__func__,
        lambda inputs, attrs: max([t.shape for t in inputs]),
@@ -70,7 +71,8 @@ class OpInfer:
    infer_shape_func = {
        # add special infer func here
-       'InplaceAssign': lambda inputs, attrs: inputs[2].shape
+       'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
+       'Reshape': lambda inputs, attrs: attrs["shape"],
    }
 
    infer_dtype_func = {
        # add special infer func here
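The extra None slot keeps default_infer_shape_func aligned with PrimLib's shifted numbering (RESHAPE gets no default shape-infer), and the new Reshape entry takes the output shape straight from the op's attributes. A standalone check mirroring that lambda (not MindSpore's actual OpInfer):

infer_reshape = lambda inputs, attrs: attrs["shape"]
assert infer_reshape([], {"shape": [2, 3]}) == [2, 3]  # target shape comes from attrs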


new file: mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.cc
@@ -0,0 +1,99 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
#include <algorithm>
#include <vector>
#include <string>
#include <unordered_set>
#include <utility>
#include <queue>
#include <map>
#include <unordered_map>
#include "frontend/optimizer/irpass.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "debug/anf_ir_dump.h"

namespace mindspore {
namespace opt {
namespace {
bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
  std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape};
  auto &users = mng->node_users();
  return std::any_of(shape_ops.begin(), shape_ops.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&
         users[node].size() > 1;
}

AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
  auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  CNodePtr node = kernel_graph->NewCNode(cnode->inputs());
  node->set_abstract(cnode->abstract());
  node->set_forward(cnode->forward().first, cnode->forward().second);
  node->set_inputs_value(cnode->inputs_value());
  ScopePtr scope = (anf_node->scope() != kDefaultScope) ? anf_node->scope() : kDefaultScope;
  node->set_scope(scope);
  node->set_kernel_info(cnode->kernel_info_ptr());
  return node;
}

void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
  auto &users = mng->node_users();
  AnfNodePtrList splitted_nodes;
  for (size_t i = 0; i < users[node].size(); ++i) {
    splitted_nodes.push_back(CloneCNode(node));
  }

  const auto &index_set = users[node];
  int i = 0;
  for (auto [user, index] : index_set) {
    auto user_node = user->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(user_node);
    user_node->set_input(index, splitted_nodes[i]);
    i++;
  }
}
}  // namespace

bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  if (mng == nullptr) {
    mng = Manage(func_graph, true);
    func_graph->set_manager(mng);
  }

  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  for (const auto &anf_node : todos) {
    auto node = anf_node->cast<CNodePtr>();
    if (node != nullptr && IsMultiUserShapeOps(node, mng)) {
      SplitNode(node, mng);
      changed = true;
    }
  }

  mng->RemoveRoots();
  mng->KeepRoots({func_graph});
  return changed;
}
}  // namespace opt
}  // namespace mindspore


new file: mindspore/ccsrc/backend/optimizer/graph_kernel/shape_ops_splitter.h
@@ -0,0 +1,33 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
#include <memory>
#include "ir/func_graph.h"
#include "backend/optimizer/common/pass.h"

namespace mindspore {
namespace opt {
class ShapeOpsSplitter : public Pass {
 public:
  ShapeOpsSplitter() : Pass("shape_ops_splitter") {}
  ~ShapeOpsSplitter() override = default;
  bool Run(const FuncGraphPtr &func_graph) override;
};
using ShapeOpsSplitterPtr = std::shared_ptr<ShapeOpsSplitter>;
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_


@@ -42,6 +42,7 @@
 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_expander.h"
 #include "backend/optimizer/graph_kernel/graph_kernel_cse.h"
+#include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
 #include "backend/optimizer/graph_kernel/value_graph_binder.h"
 #include "backend/optimizer/pass/communication_op_fusion.h"
 #include "backend/optimizer/pass/getitem_tuple.h"
@@ -164,6 +165,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
   pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
+  pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
   pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
   pm->AddPass(std::make_shared<opt::CompositeOpsFusion>());
   pm->AddPass(std::make_shared<opt::GraphKernelCSE>());