Bugfix in graph_kernel_cluster when the op connects to UpdateState

In PR !16073, the tail inputs of UpdateState is ignored in TransformSegmentToAnfGraph,
but we spread the inputs of UpdateState in the pass SpreadUpdateState, so the error occurs.

In this submission, we call ShrinkUpdateState before clustering operators and
call SpreadUpdateState again after that, to avoid the problem.
The ParallelFusion also calls the TransformSegmentToAnfGraph interface, so we change it together.

Note: this submission is a temporary solution, we will rewrite the TransformSegmentToAnfGraph in
graphkernel module, without these special processes.

To speed up the pass SpreadUpdateState, we create new UpdateState node and use "mng.Replace",
instead of setting inputs and use "mng.KeepRoots".
This commit is contained in:
dayschan 2021-06-03 11:08:13 +08:00
parent 635c2b0adb
commit 5d980d9702
5 changed files with 23 additions and 18 deletions

View File

@ -70,8 +70,7 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr
bool changed = false;
for (const auto &user : users) {
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
// Sometime, the user of MakeTuple is not a TupleGetItem, but a UpdateState.
continue;
MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem, but got: " << user.first->DebugString();
}
auto &getitem = user.first;
auto idx = GetIndex(getitem);

View File

@ -32,6 +32,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
namespace mindspore {
namespace opt {
@ -481,6 +482,7 @@ void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
}
bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
(void)std::make_shared<ShrinkUpdateState>()->Run(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
Init(func_graph);
@ -493,6 +495,7 @@ bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
mng->KeepRoots({func_graph});
}
Clean();
(void)std::make_shared<SpreadUpdateState>()->Run(func_graph);
return changed;
}
} // namespace opt

View File

@ -36,6 +36,7 @@
#include "frontend/operator/ops.h"
#include "ir/func_graph_cloner.h"
#include "vm/segment_runner.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
namespace mindspore {
namespace opt {
@ -710,6 +711,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
(void)std::make_shared<ShrinkUpdateState>()->Run(graph);
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -724,7 +726,9 @@ bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
auto parallel_infos = SearchFusableParallelCNodes(groups);
// Create core-fuse subgraph and change origin graph.
return CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
(void)std::make_shared<SpreadUpdateState>()->Run(graph);
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -50,7 +50,8 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
return result;
}
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) {
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
const FuncGraphPtr &func_graph) {
AnfNodePtrList result;
for (auto node : nodes) {
if (node->abstract()->isa<abstract::AbstractTuple>()) {
@ -65,7 +66,6 @@ AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nod
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
MS_EXCEPTION_IF_NULL(tuple_getitem);
tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i));
tuple_getitem->set_abstract(node_abstract[i]);
tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>());
result.push_back(tuple_getitem);
@ -76,30 +76,29 @@ AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nod
}
return result;
}
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
auto todos = GetUpdateStateList(func_graph);
bool changed = false;
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
for (auto node : todos) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
// extend inputs of update if which have multiple outputs
inputs = ExtendInputsOfUpdate(inputs, func_graph);
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) {
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)};
// extend inputs of UpdateState if which have multiple outputs
inputs = ExtendInputsOfUpdateState(inputs, func_graph);
if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
AnfNodePtrList node_inputs = {cnode->input(kAnfPrimitiveIndex), cnode->input(kUpdateStateStateInput)};
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
cnode->set_inputs(node_inputs);
// Create a new UpdateState
auto new_node = func_graph->NewCNode(node_inputs);
new_node->set_abstract(node->abstract());
mng->Replace(node, new_node);
changed = true;
}
}
if (changed) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}

View File

@ -39,7 +39,7 @@ class SpreadUpdateState : public Pass {
public:
SpreadUpdateState() : Pass("spread_update_state") {}
~SpreadUpdateState() override = default;
AnfNodePtrList ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
AnfNodePtrList ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
bool Run(const FuncGraphPtr &func_graph) override;
};