forked from mindspore-Ecosystem/mindspore
Bugfix in graph_kernel_cluster when the op connects to UpdateState
In PR !16073, the tail inputs of UpdateState is ignored in TransformSegmentToAnfGraph, but we spread the inputs of UpdateState in the pass SpreadUpdateState, so the error occurs. In this submission, we call ShrinkUpdateState before clustering operators and call SpreadUpdateState again after that, to avoid the problem. The ParallelFusion also calls the TransformSegmentToAnfGraph interface, so we change it together. Note: this submission is a temporary solution, we will rewrite the TransformSegmentToAnfGraph in graphkernel module, without these special processes. To speed up the pass SpreadUpdateState, we create new UpdateState node and use "mng.Replace", instead of setting inputs and use "mng.KeepRoots".
This commit is contained in:
parent
635c2b0adb
commit
5d980d9702
|
@ -70,8 +70,7 @@ bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr
|
|||
bool changed = false;
|
||||
for (const auto &user : users) {
|
||||
if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) {
|
||||
// Sometime, the user of MakeTuple is not a TupleGetItem, but a UpdateState.
|
||||
continue;
|
||||
MS_LOG(EXCEPTION) << "User of MakeTuple should be GetItem, but got: " << user.first->DebugString();
|
||||
}
|
||||
auto &getitem = user.first;
|
||||
auto idx = GetIndex(getitem);
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -481,6 +482,7 @@ void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
|
||||
bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
|
||||
(void)std::make_shared<ShrinkUpdateState>()->Run(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
Init(func_graph);
|
||||
|
@ -493,6 +495,7 @@ bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
|
|||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
Clean();
|
||||
(void)std::make_shared<SpreadUpdateState>()->Run(func_graph);
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "frontend/operator/ops.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -710,6 +711,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
|
|||
|
||||
bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
(void)std::make_shared<ShrinkUpdateState>()->Run(graph);
|
||||
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
||||
|
@ -724,7 +726,9 @@ bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
|
|||
auto parallel_infos = SearchFusableParallelCNodes(groups);
|
||||
|
||||
// Create core-fuse subgraph and change origin graph.
|
||||
return CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
|
||||
bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
|
||||
(void)std::make_shared<SpreadUpdateState>()->Run(graph);
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -50,7 +50,8 @@ AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
|
|||
return result;
|
||||
}
|
||||
|
||||
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph) {
|
||||
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
|
||||
const FuncGraphPtr &func_graph) {
|
||||
AnfNodePtrList result;
|
||||
for (auto node : nodes) {
|
||||
if (node->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
|
@ -65,7 +66,6 @@ AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nod
|
|||
|
||||
auto tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
tuple_getitem->set_fullname_with_scope(node->fullname_with_scope() + "_TupleGetItem_" + std::to_string(i));
|
||||
tuple_getitem->set_abstract(node_abstract[i]);
|
||||
tuple_getitem->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
result.push_back(tuple_getitem);
|
||||
|
@ -76,30 +76,29 @@ AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdate(const AnfNodePtrList &nod
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
|
||||
auto todos = GetUpdateStateList(func_graph);
|
||||
bool changed = false;
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
for (auto node : todos) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->size() <= kUpdateStateRealInput) continue;
|
||||
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
|
||||
// extend inputs of update if which have multiple outputs
|
||||
inputs = ExtendInputsOfUpdate(inputs, func_graph);
|
||||
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) {
|
||||
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)};
|
||||
// extend inputs of UpdateState if which have multiple outputs
|
||||
inputs = ExtendInputsOfUpdateState(inputs, func_graph);
|
||||
if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
|
||||
AnfNodePtrList node_inputs = {cnode->input(kAnfPrimitiveIndex), cnode->input(kUpdateStateStateInput)};
|
||||
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
|
||||
cnode->set_inputs(node_inputs);
|
||||
// Create a new UpdateState
|
||||
auto new_node = func_graph->NewCNode(node_inputs);
|
||||
new_node->set_abstract(node->abstract());
|
||||
mng->Replace(node, new_node);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class SpreadUpdateState : public Pass {
|
|||
public:
|
||||
SpreadUpdateState() : Pass("spread_update_state") {}
|
||||
~SpreadUpdateState() override = default;
|
||||
AnfNodePtrList ExtendInputsOfUpdate(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
|
||||
AnfNodePtrList ExtendInputsOfUpdateState(const AnfNodePtrList &nodes, const FuncGraphPtr &func_graph);
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue