forked from mindspore-Ecosystem/mindspore
!11622 【GraphKernel】Moved ShapeOpsSplitter before GraphKernelSplitter
From: @dayschan
Reviewed-by: @gaoxiong1, @dylangeng
Signed-off-by: @gaoxiong1
This commit is contained in:
commit
9efbef72fc
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -34,12 +34,12 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
|
||||
auto kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(anf_node->func_graph());
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto func_graph = anf_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
|
||||
CNodePtr node = kernel_graph->NewCNode(cnode->inputs());
|
||||
CNodePtr node = func_graph->NewCNode(cnode->inputs());
|
||||
node->set_abstract(cnode->abstract());
|
||||
node->set_forward(cnode->forward().first, cnode->forward().second);
|
||||
node->set_inputs_value(cnode->inputs_value());
|
||||
|
@ -90,19 +90,38 @@ bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
|
|||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
if (changed) {
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
bool result = false;
|
||||
bool changed;
|
||||
do {
|
||||
changed = Process(func_graph);
|
||||
result |= changed;
|
||||
} while (changed);
|
||||
for (const auto &anf_node : todos) {
|
||||
if (AnfAlgo::IsGraphKernel(anf_node)) {
|
||||
auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
|
||||
bool changed = false;
|
||||
do {
|
||||
changed = Process(sub_graph);
|
||||
result = result || changed;
|
||||
} while (changed);
|
||||
}
|
||||
}
|
||||
|
||||
if (result) {
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -185,14 +185,14 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
|
|||
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
|
||||
pm->AddPass(std::make_shared<opt::DependFormater>()); // Make more fusion opportunity.
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelExpander>());
|
||||
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
|
||||
pm->AddPass(std::make_shared<opt::BasicOpsFusion>());
|
||||
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
|
||||
pm->AddPass(std::make_shared<opt::RaiseReductionPrecision>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
||||
pm->AddPass(std::make_shared<opt::ArithmeticSimplify>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>(duplicated_ops));
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
||||
pm->AddPass(std::make_shared<opt::TensorPromotion>());
|
||||
pm->AddPass(std::make_shared<opt::ShapeOpsSplitter>(duplicated_ops));
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
|
||||
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
|
||||
// The CSE may output a graph with repeated outputs.
|
||||
|
|
Loading…
Reference in New Issue