Added ExpandDims into the GPU fusion list

Additionally: remove one restriction on GetItem in ops fusion, add a convergence loop to the ShapeOpsSplitter pass (rerun it until nothing changes), and add ExpandDims to the shape_ops list.
Commit 8af78cd5ce (parent 2645ed3c90), in a fork of mindspore-Ecosystem/mindspore.
@@ -66,36 +66,23 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
   return EXCLUDE;
 }
 
-// The GetItem node should be fused with its real input and users.
+// The GetItem node should be fused with its real input.
+// If its real input is not in the fuse_list, the GetItem should be excluded.
 AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
   if (fused_op.empty()) return AnfNodePtrList();
   std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
   auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
 
-  auto mng = fused_op[0]->func_graph()->manager();
-  MS_EXCEPTION_IF_NULL(mng);
-  bool changed = true;
-  while (changed) {
-    changed = false;
-    AnfNodePtrList remove_list;
-    for (auto getitem : fused_op_set) {
-      if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
-
-      // GetItem should be fused with its real input.
-      auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
-      if (check_include(prev_node) == EXCLUDE) {
-        remove_list.push_back(getitem);
-        break;
-      }
-
-      // GetItem should be fused with its all users.
-      const auto &users = mng->node_users()[getitem];
-      if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
-            return check_include(user.first) == EXCLUDE;
-          })) {
-        remove_list = DeepLinkedGraphSearch(getitem, check_include);
-        break;
-      }
-    }
-    if (!remove_list.empty()) {
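Reading the hunk header (-36/+23 lines) together with the commit message, the user-side veto is the removed restriction: after this change a GetItem stays in the fused set as long as its real input does, regardless of where its users live. A hedged sketch of that simplified filter, using toy types rather than the MindSpore API (the actual replacement body is not shown in this hunk):

#include <algorithm>
#include <iterator>
#include <set>
#include <vector>

// Toy stand-in for AnfNodePtr; the real code inspects CNode inputs.
struct Node {
  bool is_getitem = false;
  Node *real_input = nullptr;  // producer, meaningful for GetItem nodes
};

// Keep every candidate except "wild" GetItems, i.e. GetItems whose
// producer is outside the fused set. Users no longer veto inclusion.
std::vector<Node *> RemoveWildGetitems(const std::vector<Node *> &fused_op) {
  std::set<Node *> fused_set(fused_op.begin(), fused_op.end());
  std::vector<Node *> kept;
  std::copy_if(fused_op.begin(), fused_op.end(), std::back_inserter(kept),
               [&fused_set](Node *n) { return !n->is_getitem || fused_set.count(n->real_input) != 0; });
  return kept;
}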
@@ -753,7 +753,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
       prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
       prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
       prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
-      prim::kPrimCast};
+      prim::kPrimCast, prim::kPrimExpandDims};
 #else
   std::vector<PrimitivePtr> fusible_basic_ops;
 #endif
@@ -33,14 +33,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-bool IsMultiUserShapeOps(AnfNodePtr node, const FuncGraphManagerPtr &mng) {
-  std::vector<PrimitivePtr> shape_ops = {prim::kPrimReshape, prim::kPrimCast};
-  auto &users = mng->node_users();
-  return std::any_of(shape_ops.begin(), shape_ops.end(),
-                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); }) &&
-         users[node].size() > 1;
-}
-
 AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
   auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
   MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -75,7 +67,14 @@ void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng)
 }
 }  // namespace
 
-bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
+bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
+  auto &users = mng->node_users();
+  return users[node].size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(), [&node](const PrimitivePtr &prim) {
+           return IsPrimitiveCNode(node, prim);
+         });
+}
+
+bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   auto mng = func_graph->manager();
   if (mng == nullptr) {
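In other words, Process walks the graph and, via SplitNode/CloneCNode, gives every consumer of a multi-user shape op (now Reshape, ExpandDims, or Cast) its own clone, so each fused region can absorb a private copy instead of being blocked by a shared one. A minimal self-contained sketch of that splitting step, using a toy node type instead of AnfNodePtr/CNodePtr:

#include <memory>
#include <string>
#include <vector>

// Toy stand-in for a graph node; the real pass works on AnfNodePtr.
struct ToyNode {
  std::string op;  // e.g. "Reshape", "ExpandDims", "Cast"
  std::vector<std::shared_ptr<ToyNode>> inputs;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Give every consumer after the first its own clone of `node`, rewiring
// that consumer's input edge from the shared node to the fresh copy.
void SplitForConsumers(const ToyNodePtr &node, const std::vector<ToyNodePtr> &consumers) {
  for (size_t i = 1; i < consumers.size(); ++i) {
    auto clone = std::make_shared<ToyNode>(*node);  // same op, same inputs
    for (auto &input : consumers[i]->inputs) {
      if (input == node) input = clone;
    }
  }
}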
@@ -96,5 +95,15 @@ bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph)
   mng->KeepRoots({func_graph});
   return changed;
 }
+
+bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
+  bool result = false;
+  bool changed;
+  do {
+    changed = Process(func_graph);
+    result |= changed;
+  } while (changed);
+  return result;
+}
 }  // namespace opt
 }  // namespace mindspore
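The reason Run now loops: splitting one multi-user shape op can create a new multi-user shape op one step upstream (clone a Reshape that has two consumers and the Cast feeding it suddenly has two consumers), so a single Process sweep may leave work behind. The new Run is a standard fixed-point driver; a generic sketch of the pattern, separate from the patch itself:

#include <functional>

// Repeat a local rewrite until one full sweep reports no change.
// Mirrors the shape of ShapeOpsSplitter::Run above; `process` would be
// any pass step that returns whether it mutated the graph.
inline bool RunToFixedPoint(const std::function<bool()> &process) {
  bool result = false;
  bool changed;
  do {
    changed = process();
    result |= changed;  // remember if any sweep changed something
  } while (changed);
  return result;
}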
@@ -16,6 +16,7 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SHAPE_OPS_SPLITTER_H_
 #include <memory>
+#include <vector>
 #include "ir/func_graph.h"
 #include "backend/optimizer/common/pass.h"
 
@@ -23,9 +24,15 @@ namespace mindspore {
 namespace opt {
 class ShapeOpsSplitter : public Pass {
  public:
-  ShapeOpsSplitter() : Pass("shape_ops_splitter") {}
+  explicit ShapeOpsSplitter(const std::vector<PrimitivePtr> &shape_ops)
+      : Pass("shape_ops_splitter"), shape_ops_(shape_ops) {}
   ~ShapeOpsSplitter() override = default;
   bool Run(const FuncGraphPtr &func_graph);
+
+ private:
+  bool Process(const FuncGraphPtr &func_graph);
+  bool IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng);
+  const std::vector<PrimitivePtr> &shape_ops_;
 };
 using ShapeOpsSplitterPtr = std::shared_ptr<ShapeOpsSplitter>;
 }  // namespace opt
@@ -177,14 +177,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
   }
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>("graph_kernel_pm");
-  std::vector<PrimitivePtr> black_list = {prim::kPrimReshape, prim::kPrimCast};
+  std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
   pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
-  pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>());
+  pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
   pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
   pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
-  pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
+  pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
   pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
-  pm->AddPass(std::make_shared<opt::GraphKernelCSE>(black_list));
+  pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
   pm->AddPass(std::make_shared<opt::TensorPromotion>());
   pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
   pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
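For completeness, a hedged sketch of constructing the reworked pass on its own (`func_graph` stands for whatever graph the caller holds; the direct Run call is illustrative, since in the session the PassManager drives the pass). One detail worth noticing: the header stores `shape_ops_` as a const reference, so the vector handed to the constructor must outlive the pass; that holds above as long as the optimizer runs before `duplicated_ops` goes out of scope.

// Assumes the MindSpore types already used in this diff.
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
auto splitter = std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops);
// duplicated_ops must stay alive while `splitter` is used: the pass
// keeps a reference to it, not a copy.
bool changed = splitter->Run(func_graph);  // iterates Process() to a fixed point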