Add ExpandDims to the GPU fusion list

Additional changes:
  Remove one restriction on GetItem in ops fusion.
  Add a while loop to the ShapeOpsSplitter pass so it runs until a fixed point is reached.
  Add ExpandDims to the shape_ops list.
This commit is contained in:
dayschan 2020-12-30 16:15:54 +08:00
parent 2645ed3c90
commit 8af78cd5ce
5 changed files with 32 additions and 29 deletions

View File

@ -66,36 +66,23 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
return EXCLUDE;
}
// The GetItem node should be fused with its real input and users.
// The GetItem node should be fused with its real input.
// If its real input is not in the fuse_list, the GetItem should be excluded.
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
if (fused_op.empty()) return AnfNodePtrList();
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
auto mng = fused_op[0]->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = true;
while (changed) {
changed = false;
AnfNodePtrList remove_list;
for (auto getitem : fused_op_set) {
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
// GetItem should be fused with its real input.
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (check_include(prev_node) == EXCLUDE) {
remove_list.push_back(getitem);
break;
}
// GetItem should be fused with its all users.
const auto &users = mng->node_users()[getitem];
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
return check_include(user.first) == EXCLUDE;
})) {
remove_list = DeepLinkedGraphSearch(getitem, check_include);
break;
}
}
if (!remove_list.empty()) {

View File

@ -753,7 +753,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimCast};
prim::kPrimCast, prim::kPrimExpandDims};
#else
std::vector<PrimitivePtr> fusible_basic_ops;
#endif

View File

@ -33,14 +33,6 @@
namespace mindspore {
namespace opt {
namespace {
bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape, prim::kPrimCast};
auto &users = mng->node_users();
return std::any_of(shape_ops.begin(), shape_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&
users[node].size() > 1;
}
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -75,7 +67,14 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
}
} // namespace
bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
// A node qualifies for splitting when it has more than one user and its
// primitive is one of the configured shape ops (shape_ops_).
bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
  auto &node_users = mng->node_users();
  // Cheap check first: single-user nodes never need to be duplicated.
  if (node_users[node].size() <= 1) {
    return false;
  }
  auto matches_node = [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); };
  return std::any_of(shape_ops_.begin(), shape_ops_.end(), matches_node);
}
bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
@ -96,5 +95,15 @@ bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
mng->KeepRoots({func_graph});
return changed;
}
// Runs the splitting pass to a fixed point: each Process() sweep may expose
// further multi-user shape ops, so iterate until a sweep makes no change.
// Returns true if any sweep modified the graph.
bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
  bool any_change = false;
  while (Process(func_graph)) {
    any_change = true;
  }
  return any_change;
}
} // namespace opt
} // namespace mindspore

View File

@ -16,6 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
#include <memory>
#include <vector>
#include "ir/func_graph.h"
#include "backend/optimizer/common/pass.h"
@ -23,9 +24,15 @@ namespace mindspore {
namespace opt {
class ShapeOpsSplitter : public Pass {
public:
ShapeOpsSplitter() : Pass("shape_ops_splitter") {}
explicit ShapeOpsSplitter(const std::vector<PrimitivePtr> &shape_ops)
: Pass("shape_ops_splitter"), shape_ops_(shape_ops) {}
~ShapeOpsSplitter() override = default;
bool Run(const FuncGraphPtr &func_graph);
private:
bool Process(const FuncGraphPtr &func_graph);
bool IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng);
const std::vector<PrimitivePtr> &shape_ops_;
};
using ShapeOpsSplitterPtr = std::shared_ptr<ShapeOpsSplitter>;
} // namespace opt

View File

@ -177,14 +177,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast};
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
pm->AddPass(std::make_shared<opt::TensorPromotion>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());