!12926 【GraphKernel】Process for UpdateState node

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-03-10 16:12:17 +08:00 committed by Gitee
commit bc38590e53
5 changed files with 498 additions and 133 deletions

View File

@ -28,6 +28,7 @@
#include "debug/anf_ir_dump.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
namespace mindspore {
namespace opt {
@ -50,9 +51,10 @@ void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
}
} // namespace
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
bool merge_repeated_getitem = false) {
bool merge_repeated_getitem) {
MS_EXCEPTION_IF_NULL(mng);
MS_EXCEPTION_IF_NULL(getitem_list);
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
@ -194,121 +196,6 @@ class UnifyRepeatedGetitem : public Pass {
}
};
/* Merge the get_item nodes that have same index.
* subgraph graph_kernel(%para1, %para2)
* %1 = TensorAdd(%para1, %para2)
* %2 = Neg(%1)
* %3 = make_tuple(%1, %2)
* return (%3)
* %1 = call @graph_kernel(%p1, %p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = ControlDepend(%0, %2)
* %5 = other_user(%3)
* --->
* subgraph graph_kernel(%para1, %para2)
* %1 = TensorAdd(%para1, %para2)
* %2 = Neg(%1)
* %3 = make_tuple(%1, %2)
* return (%3)
* %1 = call @graph_kernel(%p1, %p2)
* %3 = tuple_getitem(%1, 1)
* %4 = ControlDepend(%0, %3)
* %5 = other_user(%3)
*
* Then the output 0 can be eliminate in the later pass.
*/
class EliminateGetitemForControlDepend : public Pass {
public:
bool Run(const FuncGraphPtr &func_graph) {
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
for (const auto &node : todos) {
getitems_.clear();
GetGraphKernelGetitemList(mng, node, &getitems_, false);
if (getitems_.empty()) continue;
indexes_.clear();
GetIndexesToControlDepend(mng);
FilterRedundantOutputs(node);
if (indexes_.empty()) continue;
size_t index = GetFinalIndex(node);
changed = ReplaceGetitems(mng, index) || changed;
}
return changed;
}
private:
AnfNodePtrList getitems_; // Users of GraphKernel node with multiple outputs.
std::vector<size_t> indexes_; // Indexes of MakeTuple to be eliminated.
bool ReplaceGetitems(const FuncGraphManagerPtr &mng, size_t index) {
MS_EXCEPTION_IF_NULL(getitems_[index]);
bool changed = false;
for (auto i : indexes_) {
if (i != index) {
MS_EXCEPTION_IF_NULL(getitems_[i]);
mng->Replace(getitems_[i], getitems_[index]);
changed = true;
}
}
return changed;
}
// Find the redundant output index.
// the real output should have multiple users.
void FilterRedundantOutputs(const AnfNodePtr &node) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
auto &users = mng->node_users();
auto maketuple = func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple);
std::vector<size_t> result;
for (auto i : indexes_) {
auto real_output = maketuple->input(i + 1);
if (users[real_output].size() > 1) {
result.push_back(i);
}
}
indexes_ = std::move(result);
}
// Get the nodes that only have ControlDepend users.
void GetIndexesToControlDepend(const FuncGraphManagerPtr &mng) {
for (size_t i = 0; i < getitems_.size(); ++i) {
const AnfNodePtr &getitem = getitems_[i];
if (getitem == nullptr) {
continue;
}
const auto &getitem_user = mng->node_users()[getitem];
if (std::all_of(getitem_user.begin(), getitem_user.end(), [](const std::pair<AnfNodePtr, int> &user) {
return IsPrimitiveCNode(user.first, prim::kPrimControlDepend);
})) {
indexes_.push_back(i);
}
}
}
size_t GetFinalIndex(const AnfNodePtr &node) {
auto is_redundant_index = [this](size_t i) {
return std::find(indexes_.begin(), indexes_.end(), i) != indexes_.end();
};
for (size_t i = 0; i < getitems_.size(); ++i) {
if (getitems_[i] != nullptr && !is_redundant_index(i)) {
return i;
}
}
return indexes_[0];
}
};
} // namespace
// Remove the output without user or with virtual user (like ControlDepend)
bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
if (mng == nullptr) {
@ -319,13 +206,11 @@ bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
changed = Process(func_graph) || changed;
changed = std::make_shared<EliminateHangingOutput>()->Run(func_graph) || changed;
return changed;
}
// update the GetItem(node, i) to GetItem(node, i - offset)
void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
void EliminateHangingOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
if (offset == 0) return;
MS_EXCEPTION_IF_NULL(getitem);
auto index = GetIndex(getitem);
@ -336,7 +221,7 @@ void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, siz
SetIndex(getitem, index);
}
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
MS_EXCEPTION_IF_NULL(func_graph);
auto old_maketuple = func_graph->output()->cast<CNodePtr>();
@ -379,7 +264,7 @@ AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, co
return graph_kernel_node;
}
bool EliminateRedundantOutput::Process(const FuncGraphPtr &func_graph) {
bool EliminateHangingOutput::Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = FindGraphKernelsWithMultiOutput(func_graph);

View File

@ -20,17 +20,54 @@
namespace mindspore {
namespace opt {
/* Eliminate the output without external user
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0) // the getitem(1) does not exist.
* %3 = op(%2)
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* %2 = Sub(p1, p2)
* return make_tuple(%1, %2)
* --->
* %1 = call @graph_kernel(p1, p2)
* %3 = op(%1) // if only one output remains, the getitem is not used
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* return %1 // the Sub was eliminated
*/
class EliminateHangingOutput : public Pass {
public:
bool Run(const FuncGraphPtr &func_graph) override;
private:
// update the GetItem(node, i) to GetItem(node, i - offset)
void UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset);
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
};
// Remove the output without user or with virtual user (like UpdateState)
class EliminateRedundantOutput : public Pass {
public:
EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
~EliminateRedundantOutput() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
bool Process(const FuncGraphPtr &func_graph);
void UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset);
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
};
bool IsSideEffectNode(const AnfNodePtr &node);
AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph);
/**
* @brief Get the GraphKernel's user getitems
*
* @param mng FuncGraphManagerPtr for the main func_graph
* @param node The cnode that indicates the GraphKernel
* @param getitem_list The user getitem list.
* @param merge_repeated_getitem If true, getitems with same index will be merged,
* otherwise, only one getitem will be outputted.
* @return If the graph was changed, returns true, otherwise returns false.
*/
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
bool merge_repeated_getitem = false);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_

View File

@ -40,6 +40,7 @@
#include "backend/optimizer/graph_kernel/optimize_assign.h"
#include "backend/optimizer/graph_kernel/split_assign.h"
#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
@ -56,6 +57,9 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
if (is_ascend) {
pm->AddPass(std::make_shared<ReorderOps>());
}
// Spread the MakeTuple input of UpdateState
pm->AddPass(std::make_shared<SpreadUpdateState>());
return pm;
}
@ -99,6 +103,8 @@ PassManagerPtr GraphKernelOptimizer::Split() {
// Make certain nodes redundant so that they are used by only one user,
// which can avoid unnecessary input-output and get better performance.
if (is_gpu) {
// preprocess for ShapeOpsSplitter
pm->AddPass(std::make_shared<ExtendOutputForUpdateState>());
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops));
}
@ -106,15 +112,16 @@ PassManagerPtr GraphKernelOptimizer::Split() {
// Split kernel according to costmodel
pm->AddPass(std::make_shared<GraphKernelSplitter>());
// Eliminate the redundant node that is copied above but not handled by GraphKernelSplitter
if (is_gpu) {
pm->AddPass(std::make_shared<GraphKernelCSE>());
pm->AddPass(std::make_shared<EliminateRedundantOutput>());
}
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them.
pm->AddPass(std::make_shared<GetitemTuple>());
// Eliminate the redundant node that is copied above but not handled by GraphKernelSplitter
if (is_gpu) {
pm->AddPass(std::make_shared<MergeOutputForUpdateState>());
pm->AddPass(std::make_shared<GraphKernelCSE>());
pm->AddPass(std::make_shared<EliminateRedundantOutput>());
}
return pm;
}
@ -146,6 +153,9 @@ PassManagerPtr GraphKernelOptimizer::PostProcess() {
auto pm = std::make_shared<PassManager>("graphkernel_stage7_postprocess");
// Add the new tensors to the kernel_graph
pm->AddPass(std::make_shared<BindValueToGraph>());
// Make Tuple for the inputs of UpdateState. (the reverse of SpreadUpdateState)
pm->AddPass(std::make_shared<ShrinkUpdateState>());
return pm;
}
@ -163,6 +173,12 @@ void GraphKernelOptimizer::Run(const KernelGraphPtr &kernel_graph) {
optimizer->AddPassManager(HighLevelOpt2());
optimizer->AddPassManager(Combine());
optimizer->AddPassManager(PostProcess());
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
(void)optimizer->Optimize(kernel_graph);
}

View File

@ -0,0 +1,264 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include <vector>
#include <set>
#include <memory>
#include <utility>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
namespace mindspore {
namespace opt {
AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
auto todos = TopoSort(func_graph->get_return());
AnfNodePtrList result;
std::copy_if(todos.begin(), todos.end(), std::back_inserter(result),
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimUpdateState); });
return result;
}
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
AnfNodePtrList result;
for (size_t i = begin_index; i < nodes.size(); i++) {
if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) {
auto mt = nodes[i]->cast<CNodePtr>();
// recursively spread all inner tuples.
auto mt_inputs = SpreadTuples(mt->inputs(), 1);
result.insert(result.end(), mt_inputs.begin(), mt_inputs.end());
} else {
result.push_back(nodes[i]);
}
}
return result;
}
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
auto todos = GetUpdateStateList(func_graph);
bool changed = false;
for (auto node : todos) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) {
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)};
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
cnode->set_inputs(node_inputs);
changed = true;
}
}
if (changed) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
auto todos = GetUpdateStateList(func_graph);
bool changed = false;
for (auto node : todos) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
AbstractBasePtrList abs_list;
std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
[](const AnfNodePtr &inp) { return inp->abstract(); });
mt_inputs.insert(mt_inputs.begin(), NewValueNode(prim::kPrimMakeTuple));
auto mt_node = func_graph->NewCNode(mt_inputs);
mt_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
mt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfNodePtrList inputs = {cnode->input(0), cnode->input(1), mt_node};
cnode->set_inputs(inputs);
changed = true;
}
if (changed) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
for (const auto &node : todos) {
GetGraphKernelGetitemList(mng, node, &getitems_, false);
if (getitems_.empty()) continue;
FindIndexesToUpdateState(mng);
if (indexes_.empty()) continue;
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
FilterIndexes(sub_func_graph);
if (indexes_.empty()) continue;
for (auto idx : indexes_) {
changed = ProcessIndex(func_graph, sub_func_graph, idx) || changed;
}
}
if (changed) {
std::make_shared<SpreadUpdateState>()->Run(func_graph);
std::make_shared<EliminateHangingOutput>()->Run(func_graph);
}
return changed;
}
void ExtendOutputForUpdateState::FindIndexesToUpdateState(const FuncGraphManagerPtr &mng) {
indexes_.clear();
external_user_type_.clear();
external_user_type_.resize(getitems_.size(), ExternalUserType::kNormalOp);
for (size_t i = 0; i < getitems_.size(); ++i) {
const AnfNodePtr &getitem = getitems_[i];
if (getitem == nullptr) continue;
const auto &getitem_user = mng->node_users()[getitem];
auto IsUpdateState = [](const std::pair<AnfNodePtr, int> &user) {
return IsPrimitiveCNode(user.first, prim::kPrimUpdateState);
};
if (std::all_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
external_user_type_[i] = ExternalUserType::kUpdateState;
indexes_.push_back(i);
} else if (std::any_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
external_user_type_[i] = ExternalUserType::kMix;
indexes_.push_back(i);
}
}
}
void ExtendOutputForUpdateState::FilterIndexes(const FuncGraphPtr &func_graph) {
auto output_node = func_graph->output()->cast<CNodePtr>();
// do not process the side-effect nodes.
indexes_.erase(std::remove_if(indexes_.begin(), indexes_.end(),
[&output_node](size_t i) { return IsSideEffectNode(output_node->input(i + 1)); }),
indexes_.end());
}
std::vector<size_t> ExtendOutputForUpdateState::FindAllOutputs(const FuncGraphPtr &func_graph, size_t index) {
auto output_node = func_graph->output()->cast<CNodePtr>();
auto index_node = output_node->input(index);
std::vector<size_t> group;
// if the `out_node` is a user (direct or indirect) of the `index_node`, returns true
auto DependsOnIndexNode = [&index_node](const AnfNodePtr &out_node) -> bool {
bool result = false;
auto IncludeFunc = [&result, &index_node](const AnfNodePtr &node) {
if (node == index_node) {
result = true;
return EXCLUDE;
}
return result ? EXCLUDE : FOLLOW;
};
static_cast<void>(DeepLinkedGraphSearch(out_node, IncludeFunc));
return result;
};
for (size_t i = 1; i < output_node->size(); i++) {
auto out = output_node->input(i);
// only process the nodes that depend on index_node.
if (!DependsOnIndexNode(out)) continue;
// 1. always extend to the side-effect nodes
// 2. if the external users are only UpdateState, the related output will be eliminated,
// so only the getitem with realkernel user can be extended to.
if (IsSideEffectNode(out) ||
(getitems_[i - 1] != nullptr && external_user_type_[i - 1] != ExternalUserType::kUpdateState)) {
group.push_back(i - 1);
}
}
return group;
}
bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph,
size_t index) {
auto group = FindAllOutputs(sub_func_graph, index + 1);
AnfNodePtr new_node = nullptr;
if (group.size() == 1 && group[0] == index) {
return false;
}
if (group.empty()) {
// the output is not side-effect node, but it hasn't realkernel user.
// replace the getitem with a value node that is unrelated to the original node.
// and this value node will be removed at the later pass.
MS_LOG(INFO) << "The " << getitems_[index]->fullname_with_scope() << " only has UpdateState user.";
new_node = NewValueNode(kUMonad)->cast<AnfNodePtr>();
new_node->set_abstract(kUMonad->ToAbstract());
} else {
// Create MakeTuple, even though the group size is 1, the following pass will spread the MakeTuple,
// so it's unnecessary to set abstract for it.
AnfNodePtrList mt_input = {NewValueNode(prim::kPrimMakeTuple)};
std::transform(group.begin(), group.end(), std::back_inserter(mt_input),
[this](size_t idx) { return getitems_[idx]; });
new_node = func_graph->NewCNode(mt_input)->cast<AnfNodePtr>();
}
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
for (auto user : mng->node_users()[getitems_[index]]) {
user.first->cast<CNodePtr>()->set_input(user.second, new_node);
}
return true;
}
bool MergeOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
auto todos = GetUpdateStateList(func_graph);
bool changed = false;
for (auto node : todos) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
AnfNodePtrList inputs = {cnode->input(0), cnode->input(1)};
std::set<AnfNodePtr> node_set;
for (size_t i = 2; i < cnode->size(); ++i) {
auto input = cnode->input(i);
if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
// only keep one GetItem for that link to the same node.
auto gt_input = input->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (node_set.insert(gt_input).second) {
inputs.push_back(input);
}
} else if (!HasAbstractUMonad(input)) /*filter the UMonad that was added in "ExtendOutputForUpdateState" */ {
if (node_set.insert(input).second) {
inputs.push_back(input);
}
}
}
if (inputs.size() < cnode->size()) {
cnode->set_inputs(inputs);
changed = true;
}
}
if (changed) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}
return changed;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,163 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
/**
* @brief Spread the input tuple of UpdateState
* @example
* %1 = op1
* %2 = op2
* %3 = make_tuple(%1, %2)
* UpdateState(U, %3)
* -->
* %1 = op1
* %2 = op2
* UpdateState(U, %1, %2)
*/
class SpreadUpdateState : public Pass {
public:
SpreadUpdateState() : Pass("spread_update_state") {}
~SpreadUpdateState() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
/**
* @brief Shrink the inputs of UpdateState to a tuple
* @example
* %1 = op1
* %2 = op2
* UpdateState(U, %1, %2)
* -->
* %1 = op1
* %2 = op2
* %3 = make_tuple(%1, %2)
* UpdateState(U, %3)
*/
class ShrinkUpdateState : public Pass {
public:
ShrinkUpdateState() : Pass("shrink_update_state") {}
~ShrinkUpdateState() override = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
/**
* @brief Spread the MakeTuple in node list
* @param nodes
* @param begin_index
* @example
* input
* nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
* begin_index: 1
* output
* [b, i, j, c, d, x, y, z]
* @return std::vector<AnfNodePtr>
*/
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
/**
* @brief Extend the getitem for UpdateState
* @example
* In this example, the Cast is an output of GraphKernel and only links to an UpdateState,
* it has two users in GraphKernel, Add and Sub, which are all outputs.
* after processing, the Cast was eliminate from output list and the Add and Sub was linked to UpdateState.
*
* graph_kernel:
* %1 = Cast(p1)
* %2 = Add(%1, p2) // depends on Cast
* %3 = Sub(%2, p3) // depends on Cast
* %4 = Mul(p1, p2) // not depends on Cast
* return make_tuple(%1, %2, %3, %4)
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0) // The Cast
* %3 = UpdateState(U, %2)
* -->
* graph_kernel:
* %1 = Cast(p1)
* %2 = Add(%1, p2) // depends on Cast
* %3 = Sub(%2, p3) // depends on Cast
* %4 = Mul(p1, p2) // not depends on Cast
* return make_tuple(%2, %3, %4) // the Cast was eliminated from output list
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0) // the Add
* %3 = tuple_getitem(%1, 1) // the Sub
* %4 = UpdateState(U, %2, %3)
*/
class ExtendOutputForUpdateState : public Pass {
public:
ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {}
~ExtendOutputForUpdateState() = default;
bool Run(const FuncGraphPtr &func_graph) override;
private:
// Get the nodes that have external UpdateState user.
void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng);
void FilterIndexes(const FuncGraphPtr &func_graph);
// Find all the func_graph's outputs that depends (directly or indirectly) on the indicated(index) node.
std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index);
bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index);
enum ExternalUserType {
kNormalOp, // only has normal operators
kUpdateState, // only has UpdateState(s)
kMix, // UpdateState mix with normal operator
};
AnfNodePtrList getitems_; // Users of the GraphKernel nodes.
std::vector<size_t> indexes_; // Indexes of GetItem to be processed.
std::vector<ExternalUserType> external_user_type_; // The type of getitem's users.
};
/**
* @brief Merge UpdateState's inputs which link to the same node
* @example
* graph_kernel:
* %1 = Cast(p1)
* %2 = Add(%1, p2)
* %3 = Sub(%2, p3)
* %4 = Mul(p1, p2)
* return make_tuple(%1, %2, %3, %4)
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = tuple_getitem(%1, 2)
* %5 = UpdateState(U, %2, %3, %4) // the %2 %3 %4 are all link to %1
* -->
* main graph:
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* %4 = tuple_getitem(%1, 2)
* %5 = UpdateState(U, %2) // only keep %2
*/
class MergeOutputForUpdateState : public Pass {
public:
MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {}
~MergeOutputForUpdateState() = default;
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_