forked from mindspore-Ecosystem/mindspore
!12926 【GraphKernel】Process for UpdateState node
From: @dayschan Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
bc38590e53
|
@ -28,6 +28,7 @@
|
||||||
#include "debug/anf_ir_dump.h"
|
#include "debug/anf_ir_dump.h"
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -50,9 +51,10 @@ void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
|
||||||
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
|
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
|
||||||
}
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
|
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
|
||||||
bool merge_repeated_getitem = false) {
|
bool merge_repeated_getitem) {
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
MS_EXCEPTION_IF_NULL(getitem_list);
|
MS_EXCEPTION_IF_NULL(getitem_list);
|
||||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||||
|
@ -194,121 +196,6 @@ class UnifyRepeatedGetitem : public Pass {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Merge the get_item nodes that have same index.
|
|
||||||
* subgraph graph_kernel(%para1, %para2)
|
|
||||||
* %1 = TensorAdd(%para1, %para2)
|
|
||||||
* %2 = Neg(%1)
|
|
||||||
* %3 = make_tuple(%1, %2)
|
|
||||||
* return (%3)
|
|
||||||
* %1 = call @graph_kernel(%p1, %p2)
|
|
||||||
* %2 = tuple_getitem(%1, 0)
|
|
||||||
* %3 = tuple_getitem(%1, 1)
|
|
||||||
* %4 = ControlDepend(%0, %2)
|
|
||||||
* %5 = other_user(%3)
|
|
||||||
* --->
|
|
||||||
* subgraph graph_kernel(%para1, %para2)
|
|
||||||
* %1 = TensorAdd(%para1, %para2)
|
|
||||||
* %2 = Neg(%1)
|
|
||||||
* %3 = make_tuple(%1, %2)
|
|
||||||
* return (%3)
|
|
||||||
* %1 = call @graph_kernel(%p1, %p2)
|
|
||||||
* %3 = tuple_getitem(%1, 1)
|
|
||||||
* %4 = ControlDepend(%0, %3)
|
|
||||||
* %5 = other_user(%3)
|
|
||||||
*
|
|
||||||
* Then the output 0 can be eliminate in the later pass.
|
|
||||||
*/
|
|
||||||
class EliminateGetitemForControlDepend : public Pass {
|
|
||||||
public:
|
|
||||||
bool Run(const FuncGraphPtr &func_graph) {
|
|
||||||
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
|
||||||
auto mng = func_graph->manager();
|
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
|
||||||
bool changed = false;
|
|
||||||
for (const auto &node : todos) {
|
|
||||||
getitems_.clear();
|
|
||||||
GetGraphKernelGetitemList(mng, node, &getitems_, false);
|
|
||||||
if (getitems_.empty()) continue;
|
|
||||||
indexes_.clear();
|
|
||||||
GetIndexesToControlDepend(mng);
|
|
||||||
FilterRedundantOutputs(node);
|
|
||||||
if (indexes_.empty()) continue;
|
|
||||||
size_t index = GetFinalIndex(node);
|
|
||||||
changed = ReplaceGetitems(mng, index) || changed;
|
|
||||||
}
|
|
||||||
return changed;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
AnfNodePtrList getitems_; // Users of GraphKernel node with multiple outputs.
|
|
||||||
std::vector<size_t> indexes_; // Indexes of MakeTuple to be eliminated.
|
|
||||||
|
|
||||||
bool ReplaceGetitems(const FuncGraphManagerPtr &mng, size_t index) {
|
|
||||||
MS_EXCEPTION_IF_NULL(getitems_[index]);
|
|
||||||
bool changed = false;
|
|
||||||
for (auto i : indexes_) {
|
|
||||||
if (i != index) {
|
|
||||||
MS_EXCEPTION_IF_NULL(getitems_[i]);
|
|
||||||
mng->Replace(getitems_[i], getitems_[index]);
|
|
||||||
changed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return changed;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the redundant output index.
|
|
||||||
// the real output should have multiple users.
|
|
||||||
void FilterRedundantOutputs(const AnfNodePtr &node) {
|
|
||||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
|
||||||
auto mng = func_graph->manager();
|
|
||||||
if (mng == nullptr) {
|
|
||||||
mng = Manage(func_graph, true);
|
|
||||||
func_graph->set_manager(mng);
|
|
||||||
}
|
|
||||||
auto &users = mng->node_users();
|
|
||||||
auto maketuple = func_graph->output()->cast<CNodePtr>();
|
|
||||||
MS_EXCEPTION_IF_NULL(maketuple);
|
|
||||||
std::vector<size_t> result;
|
|
||||||
for (auto i : indexes_) {
|
|
||||||
auto real_output = maketuple->input(i + 1);
|
|
||||||
if (users[real_output].size() > 1) {
|
|
||||||
result.push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
indexes_ = std::move(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the nodes that only have ControlDepend users.
|
|
||||||
void GetIndexesToControlDepend(const FuncGraphManagerPtr &mng) {
|
|
||||||
for (size_t i = 0; i < getitems_.size(); ++i) {
|
|
||||||
const AnfNodePtr &getitem = getitems_[i];
|
|
||||||
if (getitem == nullptr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const auto &getitem_user = mng->node_users()[getitem];
|
|
||||||
if (std::all_of(getitem_user.begin(), getitem_user.end(), [](const std::pair<AnfNodePtr, int> &user) {
|
|
||||||
return IsPrimitiveCNode(user.first, prim::kPrimControlDepend);
|
|
||||||
})) {
|
|
||||||
indexes_.push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t GetFinalIndex(const AnfNodePtr &node) {
|
|
||||||
auto is_redundant_index = [this](size_t i) {
|
|
||||||
return std::find(indexes_.begin(), indexes_.end(), i) != indexes_.end();
|
|
||||||
};
|
|
||||||
for (size_t i = 0; i < getitems_.size(); ++i) {
|
|
||||||
if (getitems_[i] != nullptr && !is_redundant_index(i)) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return indexes_[0];
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// Remove the output without user or with virtual user (like ControlDepend)
|
|
||||||
bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
|
bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
|
||||||
auto mng = func_graph->manager();
|
auto mng = func_graph->manager();
|
||||||
if (mng == nullptr) {
|
if (mng == nullptr) {
|
||||||
|
@ -319,13 +206,11 @@ bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
|
||||||
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
|
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
|
||||||
changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
|
changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
|
||||||
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
|
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
|
||||||
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
|
changed = std::make_shared<EliminateHangingOutput>()->Run(func_graph) || changed;
|
||||||
changed = Process(func_graph) || changed;
|
|
||||||
return changed;
|
return changed;
|
||||||
}
|
}
|
||||||
|
|
||||||
// update the GetItem(node, i) to GetItem(node, i - offset)
|
void EliminateHangingOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
|
||||||
void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
|
|
||||||
if (offset == 0) return;
|
if (offset == 0) return;
|
||||||
MS_EXCEPTION_IF_NULL(getitem);
|
MS_EXCEPTION_IF_NULL(getitem);
|
||||||
auto index = GetIndex(getitem);
|
auto index = GetIndex(getitem);
|
||||||
|
@ -336,7 +221,7 @@ void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, siz
|
||||||
SetIndex(getitem, index);
|
SetIndex(getitem, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
|
AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
|
||||||
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
auto old_maketuple = func_graph->output()->cast<CNodePtr>();
|
auto old_maketuple = func_graph->output()->cast<CNodePtr>();
|
||||||
|
@ -379,7 +264,7 @@ AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, co
|
||||||
return graph_kernel_node;
|
return graph_kernel_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool EliminateRedundantOutput::Process(const FuncGraphPtr &func_graph) {
|
bool EliminateHangingOutput::Run(const FuncGraphPtr &func_graph) {
|
||||||
auto mng = func_graph->manager();
|
auto mng = func_graph->manager();
|
||||||
MS_EXCEPTION_IF_NULL(mng);
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
||||||
|
|
|
@ -20,17 +20,54 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
/* Eliminate the output without external user
|
||||||
|
* %1 = call @graph_kernel(p1, p2)
|
||||||
|
* %2 = tuple_getitem(%1, 0) // the getitem(1) does not exist.
|
||||||
|
* %3 = op(%2)
|
||||||
|
* graph_kernel:
|
||||||
|
* %1 = TensorAdd(p1, p2)
|
||||||
|
* %2 = Sub(p1, p2)
|
||||||
|
* return make_tuple(%1, %2)
|
||||||
|
* --->
|
||||||
|
* %1 = call @graph_kernel(p1, p2)
|
||||||
|
* %3 = op(%1) // if only one output remains, the getitem is not used
|
||||||
|
* graph_kernel:
|
||||||
|
* %1 = TensorAdd(p1, p2)
|
||||||
|
* return %1 // the Sub was eliminated
|
||||||
|
*/
|
||||||
|
class EliminateHangingOutput : public Pass {
|
||||||
|
public:
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// update the GetItem(node, i) to GetItem(node, i - offset)
|
||||||
|
void UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset);
|
||||||
|
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Remove the output without user or with virtual user (like UpdateState)
|
||||||
class EliminateRedundantOutput : public Pass {
|
class EliminateRedundantOutput : public Pass {
|
||||||
public:
|
public:
|
||||||
EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
|
EliminateRedundantOutput() : Pass("eliminate_redundant_output") {}
|
||||||
~EliminateRedundantOutput() override = default;
|
~EliminateRedundantOutput() override = default;
|
||||||
bool Run(const FuncGraphPtr &func_graph) override;
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
|
||||||
private:
|
|
||||||
bool Process(const FuncGraphPtr &func_graph);
|
|
||||||
void UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset);
|
|
||||||
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool IsSideEffectNode(const AnfNodePtr &node);
|
||||||
|
AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the GraphKernel's user getitems
|
||||||
|
*
|
||||||
|
* @param mng FuncGraphManagerPtr for the main func_graph
|
||||||
|
* @param node The cnode that indicates the GraphKernel
|
||||||
|
* @param getitem_list The user getitem list.
|
||||||
|
* @param merge_repeated_getitem If true, getitems with same index will be merged,
|
||||||
|
* otherwise, only one getitem will be outputted.
|
||||||
|
* @return If the graph was changed, returns true, otherwise returns false.
|
||||||
|
*/
|
||||||
|
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
|
||||||
|
bool merge_repeated_getitem = false);
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ELIMINATE_REDUNDANT_OUTPUT_H_
|
||||||
|
|
|
@ -40,6 +40,7 @@
|
||||||
#include "backend/optimizer/graph_kernel/optimize_assign.h"
|
#include "backend/optimizer/graph_kernel/optimize_assign.h"
|
||||||
#include "backend/optimizer/graph_kernel/split_assign.h"
|
#include "backend/optimizer/graph_kernel/split_assign.h"
|
||||||
#include "backend/optimizer/graph_kernel/reorder_ops.h"
|
#include "backend/optimizer/graph_kernel/reorder_ops.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
|
||||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -56,6 +57,9 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
|
||||||
if (is_ascend) {
|
if (is_ascend) {
|
||||||
pm->AddPass(std::make_shared<ReorderOps>());
|
pm->AddPass(std::make_shared<ReorderOps>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Spread the MakeTuple input of UpdateState
|
||||||
|
pm->AddPass(std::make_shared<SpreadUpdateState>());
|
||||||
return pm;
|
return pm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -99,6 +103,8 @@ PassManagerPtr GraphKernelOptimizer::Split() {
|
||||||
// Make certain nodes redundant so that they are used by only one user,
|
// Make certain nodes redundant so that they are used by only one user,
|
||||||
// which can avoid unnecessary input-output and get better performance.
|
// which can avoid unnecessary input-output and get better performance.
|
||||||
if (is_gpu) {
|
if (is_gpu) {
|
||||||
|
// preprocess for ShapeOpsSplitter
|
||||||
|
pm->AddPass(std::make_shared<ExtendOutputForUpdateState>());
|
||||||
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
|
std::vector<PrimitivePtr> duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast};
|
||||||
pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops));
|
pm->AddPass(std::make_shared<ShapeOpsSplitter>(duplicated_ops));
|
||||||
}
|
}
|
||||||
|
@ -106,15 +112,16 @@ PassManagerPtr GraphKernelOptimizer::Split() {
|
||||||
// Split kernel according to costmodel
|
// Split kernel according to costmodel
|
||||||
pm->AddPass(std::make_shared<GraphKernelSplitter>());
|
pm->AddPass(std::make_shared<GraphKernelSplitter>());
|
||||||
|
|
||||||
// Eliminate the redundant node that is copied above but not handled by GraphKernelSplitter
|
|
||||||
if (is_gpu) {
|
|
||||||
pm->AddPass(std::make_shared<GraphKernelCSE>());
|
|
||||||
pm->AddPass(std::make_shared<EliminateRedundantOutput>());
|
|
||||||
}
|
|
||||||
|
|
||||||
// After Simplify and Splitter, a lot of redundant getitem/maketuple
|
// After Simplify and Splitter, a lot of redundant getitem/maketuple
|
||||||
// will be exposed, use GetitemTuple Pass to delete them.
|
// will be exposed, use GetitemTuple Pass to delete them.
|
||||||
pm->AddPass(std::make_shared<GetitemTuple>());
|
pm->AddPass(std::make_shared<GetitemTuple>());
|
||||||
|
|
||||||
|
// Eliminate the redundant node that is copied above but not handled by GraphKernelSplitter
|
||||||
|
if (is_gpu) {
|
||||||
|
pm->AddPass(std::make_shared<MergeOutputForUpdateState>());
|
||||||
|
pm->AddPass(std::make_shared<GraphKernelCSE>());
|
||||||
|
pm->AddPass(std::make_shared<EliminateRedundantOutput>());
|
||||||
|
}
|
||||||
return pm;
|
return pm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,6 +153,9 @@ PassManagerPtr GraphKernelOptimizer::PostProcess() {
|
||||||
auto pm = std::make_shared<PassManager>("graphkernel_stage7_postprocess");
|
auto pm = std::make_shared<PassManager>("graphkernel_stage7_postprocess");
|
||||||
// Add the new tensors to the kernel_graph
|
// Add the new tensors to the kernel_graph
|
||||||
pm->AddPass(std::make_shared<BindValueToGraph>());
|
pm->AddPass(std::make_shared<BindValueToGraph>());
|
||||||
|
|
||||||
|
// Make Tuple for the inputs of UpdateState. (the reverse of SpreadUpdateState)
|
||||||
|
pm->AddPass(std::make_shared<ShrinkUpdateState>());
|
||||||
return pm;
|
return pm;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,6 +173,12 @@ void GraphKernelOptimizer::Run(const KernelGraphPtr &kernel_graph) {
|
||||||
optimizer->AddPassManager(HighLevelOpt2());
|
optimizer->AddPassManager(HighLevelOpt2());
|
||||||
optimizer->AddPassManager(Combine());
|
optimizer->AddPassManager(Combine());
|
||||||
optimizer->AddPassManager(PostProcess());
|
optimizer->AddPassManager(PostProcess());
|
||||||
|
|
||||||
|
auto mng = kernel_graph->manager();
|
||||||
|
if (mng == nullptr) {
|
||||||
|
mng = Manage(kernel_graph, true);
|
||||||
|
kernel_graph->set_manager(mng);
|
||||||
|
}
|
||||||
(void)optimizer->Optimize(kernel_graph);
|
(void)optimizer->Optimize(kernel_graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,264 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <set>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||||
|
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
AnfNodePtrList result;
|
||||||
|
std::copy_if(todos.begin(), todos.end(), std::back_inserter(result),
|
||||||
|
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimUpdateState); });
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
|
||||||
|
AnfNodePtrList result;
|
||||||
|
for (size_t i = begin_index; i < nodes.size(); i++) {
|
||||||
|
if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) {
|
||||||
|
auto mt = nodes[i]->cast<CNodePtr>();
|
||||||
|
// recursively spread all inner tuples.
|
||||||
|
auto mt_inputs = SpreadTuples(mt->inputs(), 1);
|
||||||
|
result.insert(result.end(), mt_inputs.begin(), mt_inputs.end());
|
||||||
|
} else {
|
||||||
|
result.push_back(nodes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
auto todos = GetUpdateStateList(func_graph);
|
||||||
|
bool changed = false;
|
||||||
|
for (auto node : todos) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (cnode->size() <= kUpdateStateRealInput) continue;
|
||||||
|
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
|
||||||
|
if (inputs.size() + 2 != cnode->size() || inputs[0] != cnode->input(2)) {
|
||||||
|
AnfNodePtrList node_inputs = {cnode->input(0), cnode->input(1)};
|
||||||
|
node_inputs.insert(node_inputs.end(), inputs.begin(), inputs.end());
|
||||||
|
cnode->set_inputs(node_inputs);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (changed) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
mng->RemoveRoots();
|
||||||
|
mng->KeepRoots({func_graph});
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
auto todos = GetUpdateStateList(func_graph);
|
||||||
|
bool changed = false;
|
||||||
|
for (auto node : todos) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
if (cnode->size() <= kUpdateStateRealInput) continue;
|
||||||
|
AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
|
||||||
|
AbstractBasePtrList abs_list;
|
||||||
|
std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
|
||||||
|
[](const AnfNodePtr &inp) { return inp->abstract(); });
|
||||||
|
mt_inputs.insert(mt_inputs.begin(), NewValueNode(prim::kPrimMakeTuple));
|
||||||
|
auto mt_node = func_graph->NewCNode(mt_inputs);
|
||||||
|
mt_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
|
||||||
|
mt_node->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||||
|
|
||||||
|
AnfNodePtrList inputs = {cnode->input(0), cnode->input(1), mt_node};
|
||||||
|
cnode->set_inputs(inputs);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (changed) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
mng->RemoveRoots();
|
||||||
|
mng->KeepRoots({func_graph});
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ExtendOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
bool changed = false;
|
||||||
|
for (const auto &node : todos) {
|
||||||
|
GetGraphKernelGetitemList(mng, node, &getitems_, false);
|
||||||
|
if (getitems_.empty()) continue;
|
||||||
|
FindIndexesToUpdateState(mng);
|
||||||
|
if (indexes_.empty()) continue;
|
||||||
|
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
|
||||||
|
FilterIndexes(sub_func_graph);
|
||||||
|
if (indexes_.empty()) continue;
|
||||||
|
for (auto idx : indexes_) {
|
||||||
|
changed = ProcessIndex(func_graph, sub_func_graph, idx) || changed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (changed) {
|
||||||
|
std::make_shared<SpreadUpdateState>()->Run(func_graph);
|
||||||
|
std::make_shared<EliminateHangingOutput>()->Run(func_graph);
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExtendOutputForUpdateState::FindIndexesToUpdateState(const FuncGraphManagerPtr &mng) {
|
||||||
|
indexes_.clear();
|
||||||
|
external_user_type_.clear();
|
||||||
|
external_user_type_.resize(getitems_.size(), ExternalUserType::kNormalOp);
|
||||||
|
for (size_t i = 0; i < getitems_.size(); ++i) {
|
||||||
|
const AnfNodePtr &getitem = getitems_[i];
|
||||||
|
if (getitem == nullptr) continue;
|
||||||
|
|
||||||
|
const auto &getitem_user = mng->node_users()[getitem];
|
||||||
|
auto IsUpdateState = [](const std::pair<AnfNodePtr, int> &user) {
|
||||||
|
return IsPrimitiveCNode(user.first, prim::kPrimUpdateState);
|
||||||
|
};
|
||||||
|
if (std::all_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
|
||||||
|
external_user_type_[i] = ExternalUserType::kUpdateState;
|
||||||
|
indexes_.push_back(i);
|
||||||
|
} else if (std::any_of(getitem_user.begin(), getitem_user.end(), IsUpdateState)) {
|
||||||
|
external_user_type_[i] = ExternalUserType::kMix;
|
||||||
|
indexes_.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ExtendOutputForUpdateState::FilterIndexes(const FuncGraphPtr &func_graph) {
|
||||||
|
auto output_node = func_graph->output()->cast<CNodePtr>();
|
||||||
|
// do not process the side-effect nodes.
|
||||||
|
indexes_.erase(std::remove_if(indexes_.begin(), indexes_.end(),
|
||||||
|
[&output_node](size_t i) { return IsSideEffectNode(output_node->input(i + 1)); }),
|
||||||
|
indexes_.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> ExtendOutputForUpdateState::FindAllOutputs(const FuncGraphPtr &func_graph, size_t index) {
|
||||||
|
auto output_node = func_graph->output()->cast<CNodePtr>();
|
||||||
|
auto index_node = output_node->input(index);
|
||||||
|
std::vector<size_t> group;
|
||||||
|
|
||||||
|
// if the `out_node` is a user (direct or indirect) of the `index_node`, returns true
|
||||||
|
auto DependsOnIndexNode = [&index_node](const AnfNodePtr &out_node) -> bool {
|
||||||
|
bool result = false;
|
||||||
|
auto IncludeFunc = [&result, &index_node](const AnfNodePtr &node) {
|
||||||
|
if (node == index_node) {
|
||||||
|
result = true;
|
||||||
|
return EXCLUDE;
|
||||||
|
}
|
||||||
|
return result ? EXCLUDE : FOLLOW;
|
||||||
|
};
|
||||||
|
static_cast<void>(DeepLinkedGraphSearch(out_node, IncludeFunc));
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t i = 1; i < output_node->size(); i++) {
|
||||||
|
auto out = output_node->input(i);
|
||||||
|
// only process the nodes that depend on index_node.
|
||||||
|
if (!DependsOnIndexNode(out)) continue;
|
||||||
|
|
||||||
|
// 1. always extend to the side-effect nodes
|
||||||
|
// 2. if the external users are only UpdateState, the related output will be eliminated,
|
||||||
|
// so only the getitem with realkernel user can be extended to.
|
||||||
|
if (IsSideEffectNode(out) ||
|
||||||
|
(getitems_[i - 1] != nullptr && external_user_type_[i - 1] != ExternalUserType::kUpdateState)) {
|
||||||
|
group.push_back(i - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return group;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ExtendOutputForUpdateState::ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph,
|
||||||
|
size_t index) {
|
||||||
|
auto group = FindAllOutputs(sub_func_graph, index + 1);
|
||||||
|
AnfNodePtr new_node = nullptr;
|
||||||
|
if (group.size() == 1 && group[0] == index) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (group.empty()) {
|
||||||
|
// the output is not side-effect node, but it hasn't realkernel user.
|
||||||
|
// replace the getitem with a value node that is unrelated to the original node.
|
||||||
|
// and this value node will be removed at the later pass.
|
||||||
|
MS_LOG(INFO) << "The " << getitems_[index]->fullname_with_scope() << " only has UpdateState user.";
|
||||||
|
new_node = NewValueNode(kUMonad)->cast<AnfNodePtr>();
|
||||||
|
new_node->set_abstract(kUMonad->ToAbstract());
|
||||||
|
} else {
|
||||||
|
// Create MakeTuple, even though the group size is 1, the following pass will spread the MakeTuple,
|
||||||
|
// so it's unnecessary to set abstract for it.
|
||||||
|
AnfNodePtrList mt_input = {NewValueNode(prim::kPrimMakeTuple)};
|
||||||
|
std::transform(group.begin(), group.end(), std::back_inserter(mt_input),
|
||||||
|
[this](size_t idx) { return getitems_[idx]; });
|
||||||
|
new_node = func_graph->NewCNode(mt_input)->cast<AnfNodePtr>();
|
||||||
|
}
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
for (auto user : mng->node_users()[getitems_[index]]) {
|
||||||
|
user.first->cast<CNodePtr>()->set_input(user.second, new_node);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MergeOutputForUpdateState::Run(const FuncGraphPtr &func_graph) {
|
||||||
|
auto todos = GetUpdateStateList(func_graph);
|
||||||
|
bool changed = false;
|
||||||
|
for (auto node : todos) {
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
AnfNodePtrList inputs = {cnode->input(0), cnode->input(1)};
|
||||||
|
std::set<AnfNodePtr> node_set;
|
||||||
|
for (size_t i = 2; i < cnode->size(); ++i) {
|
||||||
|
auto input = cnode->input(i);
|
||||||
|
if (IsPrimitiveCNode(input, prim::kPrimTupleGetItem)) {
|
||||||
|
// only keep one GetItem for that link to the same node.
|
||||||
|
auto gt_input = input->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||||
|
if (node_set.insert(gt_input).second) {
|
||||||
|
inputs.push_back(input);
|
||||||
|
}
|
||||||
|
} else if (!HasAbstractUMonad(input)) /*filter the UMonad that was added in "ExtendOutputForUpdateState" */ {
|
||||||
|
if (node_set.insert(input).second) {
|
||||||
|
inputs.push_back(input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (inputs.size() < cnode->size()) {
|
||||||
|
cnode->set_inputs(inputs);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (changed) {
|
||||||
|
auto mng = func_graph->manager();
|
||||||
|
MS_EXCEPTION_IF_NULL(mng);
|
||||||
|
mng->RemoveRoots();
|
||||||
|
mng->KeepRoots({func_graph});
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,163 @@
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/optimizer/common/pass.h"
|
||||||
|
#include "ir/func_graph.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
/**
|
||||||
|
* @brief Spread the input tuple of UpdateState
|
||||||
|
* @example
|
||||||
|
* %1 = op1
|
||||||
|
* %2 = op2
|
||||||
|
* %3 = make_tuple(%1, %2)
|
||||||
|
* UpdateState(U, %3)
|
||||||
|
* -->
|
||||||
|
* %1 = op1
|
||||||
|
* %2 = op2
|
||||||
|
* UpdateState(U, %1, %2)
|
||||||
|
*/
|
||||||
|
class SpreadUpdateState : public Pass {
|
||||||
|
public:
|
||||||
|
SpreadUpdateState() : Pass("spread_update_state") {}
|
||||||
|
~SpreadUpdateState() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Shrink the inputs of UpdateState to a tuple
|
||||||
|
* @example
|
||||||
|
* %1 = op1
|
||||||
|
* %2 = op2
|
||||||
|
* UpdateState(U, %1, %2)
|
||||||
|
* -->
|
||||||
|
* %1 = op1
|
||||||
|
* %2 = op2
|
||||||
|
* %3 = make_tuple(%1, %2)
|
||||||
|
* UpdateState(U, %3)
|
||||||
|
*/
|
||||||
|
class ShrinkUpdateState : public Pass {
|
||||||
|
public:
|
||||||
|
ShrinkUpdateState() : Pass("shrink_update_state") {}
|
||||||
|
~ShrinkUpdateState() override = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Spread the MakeTuple in node list
|
||||||
|
* @param nodes
|
||||||
|
* @param begin_index
|
||||||
|
* @example
|
||||||
|
* input
|
||||||
|
* nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
|
||||||
|
* begin_index: 1
|
||||||
|
* output
|
||||||
|
* [b, i, j, c, d, x, y, z]
|
||||||
|
* @return std::vector<AnfNodePtr>
|
||||||
|
*/
|
||||||
|
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Extend the getitem for UpdateState
|
||||||
|
* @example
|
||||||
|
* In this example, the Cast is an output of GraphKernel and only links to an UpdateState,
|
||||||
|
* it has two users in GraphKernel, Add and Sub, which are all outputs.
|
||||||
|
* after processing, the Cast was eliminate from output list and the Add and Sub was linked to UpdateState.
|
||||||
|
*
|
||||||
|
* graph_kernel:
|
||||||
|
* %1 = Cast(p1)
|
||||||
|
* %2 = Add(%1, p2) // depends on Cast
|
||||||
|
* %3 = Sub(%2, p3) // depends on Cast
|
||||||
|
* %4 = Mul(p1, p2) // not depends on Cast
|
||||||
|
* return make_tuple(%1, %2, %3, %4)
|
||||||
|
* main graph:
|
||||||
|
* %1 = call @graph_kernel(p1, p2)
|
||||||
|
* %2 = tuple_getitem(%1, 0) // The Cast
|
||||||
|
* %3 = UpdateState(U, %2)
|
||||||
|
* -->
|
||||||
|
* graph_kernel:
|
||||||
|
* %1 = Cast(p1)
|
||||||
|
* %2 = Add(%1, p2) // depends on Cast
|
||||||
|
* %3 = Sub(%2, p3) // depends on Cast
|
||||||
|
* %4 = Mul(p1, p2) // not depends on Cast
|
||||||
|
* return make_tuple(%2, %3, %4) // the Cast was eliminated from output list
|
||||||
|
* main graph:
|
||||||
|
* %1 = call @graph_kernel(p1, p2)
|
||||||
|
* %2 = tuple_getitem(%1, 0) // the Add
|
||||||
|
* %3 = tuple_getitem(%1, 1) // the Sub
|
||||||
|
* %4 = UpdateState(U, %2, %3)
|
||||||
|
*/
|
||||||
|
class ExtendOutputForUpdateState : public Pass {
|
||||||
|
public:
|
||||||
|
ExtendOutputForUpdateState() : Pass("extend_output_for_update_state") {}
|
||||||
|
~ExtendOutputForUpdateState() = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Get the nodes that have external UpdateState user.
|
||||||
|
void FindIndexesToUpdateState(const FuncGraphManagerPtr &mng);
|
||||||
|
void FilterIndexes(const FuncGraphPtr &func_graph);
|
||||||
|
// Find all the func_graph's outputs that depends (directly or indirectly) on the indicated(index) node.
|
||||||
|
std::vector<size_t> FindAllOutputs(const FuncGraphPtr &func_graph, size_t index);
|
||||||
|
bool ProcessIndex(const FuncGraphPtr &func_graph, const FuncGraphPtr &sub_func_graph, size_t index);
|
||||||
|
|
||||||
|
enum ExternalUserType {
|
||||||
|
kNormalOp, // only has normal operators
|
||||||
|
kUpdateState, // only has UpdateState(s)
|
||||||
|
kMix, // UpdateState mix with normal operator
|
||||||
|
};
|
||||||
|
AnfNodePtrList getitems_; // Users of the GraphKernel nodes.
|
||||||
|
std::vector<size_t> indexes_; // Indexes of GetItem to be processed.
|
||||||
|
std::vector<ExternalUserType> external_user_type_; // The type of getitem's users.
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Merge UpdateState's inputs which link to the same node
|
||||||
|
* @example
|
||||||
|
* graph_kernel:
|
||||||
|
* %1 = Cast(p1)
|
||||||
|
* %2 = Add(%1, p2)
|
||||||
|
* %3 = Sub(%2, p3)
|
||||||
|
* %4 = Mul(p1, p2)
|
||||||
|
* return make_tuple(%1, %2, %3, %4)
|
||||||
|
* main graph:
|
||||||
|
* %1 = call @graph_kernel(p1, p2)
|
||||||
|
* %2 = tuple_getitem(%1, 0)
|
||||||
|
* %3 = tuple_getitem(%1, 1)
|
||||||
|
* %4 = tuple_getitem(%1, 2)
|
||||||
|
* %5 = UpdateState(U, %2, %3, %4) // the %2 %3 %4 are all link to %1
|
||||||
|
* -->
|
||||||
|
* main graph:
|
||||||
|
* %1 = call @graph_kernel(p1, p2)
|
||||||
|
* %2 = tuple_getitem(%1, 0)
|
||||||
|
* %3 = tuple_getitem(%1, 1)
|
||||||
|
* %4 = tuple_getitem(%1, 2)
|
||||||
|
* %5 = UpdateState(U, %2) // only keep %2
|
||||||
|
*/
|
||||||
|
class MergeOutputForUpdateState : public Pass {
|
||||||
|
public:
|
||||||
|
MergeOutputForUpdateState() : Pass("merge_output_for_update_state") {}
|
||||||
|
~MergeOutputForUpdateState() = default;
|
||||||
|
bool Run(const FuncGraphPtr &func_graph) override;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_UPDATE_STATE_FORMATTER_H_
|
Loading…
Reference in New Issue