!11622 【GraphKernel】Moved ShapeOpsSplitter before GraphKernelSplitter

From: @dayschan
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @gaoxiong1
This commit is contained in:
mindspore-ci-bot 2021-02-02 10:46:45 +08:00 committed by Gitee
commit 9efbef72fc
2 changed files with 35 additions and 16 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,12 +34,12 @@ namespace mindspore {
namespace opt {
namespace {
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
MS_EXCEPTION_IF_NULL(kernel_graph);
auto func_graph = anf_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
CNodePtr node = kernel_graph->NewCNode(cnode->inputs());
CNodePtr node = func_graph->NewCNode(cnode->inputs());
node->set_abstract(cnode->abstract());
node->set_forward(cnode->forward().first, cnode->forward().second);
node->set_inputs_value(cnode->inputs_value());
@ -90,19 +90,38 @@ bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
changed = true;
}
}
mng->RemoveRoots();
mng->KeepRoots({func_graph});
if (changed) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto todos = TopoSort(func_graph->get_return());
bool result = false;
bool changed;
do {
changed = Process(func_graph);
result |= changed;
} while (changed);
for (const auto &anf_node : todos) {
if (AnfAlgo::IsGraphKernel(anf_node)) {
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
bool changed = false;
do {
changed = Process(sub_graph);
result = result || changed;
} while (changed);
}
}
if (result) {
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return result;
}
} // namespace opt

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -185,14 +185,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
pm->AddPass(std::make_shared<opt::DependFormater>()); // Make more fusion opportunity.
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
pm->AddPass(std::make_shared<opt::TensorPromotion>());
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
// The CSE may output a graph with repeated outputs.