Call the pass "opt::GetitemTuple" in BuildSingleGraphFromNodes.

this pass may eliminate some nodes and make the parameter unused,
so we should call it before the function EliminateRedundantParameters.
This commit is contained in:
dayschan 2022-03-09 22:40:10 +08:00
parent 555c9dc929
commit 820e620bcf
2 changed files with 7 additions and 12 deletions

View File

@ -27,6 +27,7 @@
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "utils/anf_utils.h" #include "utils/anf_utils.h"
#include "utils/ordered_set.h" #include "utils/ordered_set.h"
#include "backend/common/pass/getitem_tuple.h"
#include "common/graph_kernel/core/graph_kernel_callback.h" #include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h" #include "common/graph_kernel/core/graph_kernel_utils.h"
@ -83,7 +84,7 @@ bool InlineInnerFuncGraph(const FuncGraphPtr &fg) {
return changed; return changed;
} }
void EliminateMakeTuple(const FuncGraphPtr &fg) { void EliminateTupleOfTuple(const FuncGraphPtr &fg) {
if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) { if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
return; return;
} }
@ -259,15 +260,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNod
AnfNodePtrList outputs; AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes); std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes);
FuncGraphManagerPtr mng = fg->manager(); FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
if (mng == nullptr) { MS_EXCEPTION_IF_NULL(mng);
mng = Manage(fg, false);
fg->set_manager(mng);
}
(void)InlineInnerFuncGraph(fg); (void)InlineInnerFuncGraph(fg);
// eliminate tuple of tuple, and set Abstract for output MakeTuple // eliminate tuple of tuple, and set Abstract for output MakeTuple
EliminateMakeTuple(fg); EliminateTupleOfTuple(fg);
// eliminate the inner MakeTuple-GetItem edges
(void)std::static_pointer_cast<opt::Pass>(std::make_shared<opt::GetitemTuple>())->Run(fg);
(void)ConvertNonscalarTensorToParameter(fg, &inputs); (void)ConvertNonscalarTensorToParameter(fg, &inputs);
return std::make_tuple(fg, inputs, outputs); return std::make_tuple(fg, inputs, outputs);

View File

@ -28,14 +28,11 @@
#include "utils/ms_context.h" #include "utils/ms_context.h"
#include "utils/file_utils.h" #include "utils/file_utils.h"
#include "include/common/utils/context/graph_kernel_flags.h" #include "include/common/utils/context/graph_kernel_flags.h"
#include "backend/common/pass/getitem_tuple.h"
#include "common/graph_kernel/core/graph_kernel_callback.h" #include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h" #include "common/graph_kernel/core/graph_kernel_utils.h"
#include "common/graph_kernel/core/graph_builder.h" #include "common/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel { namespace mindspore::graphkernel {
using opt::GetitemTuple;
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() { std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
std::vector<OpWithLevel> clusterable_ops_with_level = { std::vector<OpWithLevel> clusterable_ops_with_level = {
// all target // all target
@ -405,8 +402,6 @@ void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const s
(void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes), (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
[this](size_t id) { return this->nodes_[id]; }); [this](size_t id) { return this->nodes_[id]; });
auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion"); auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<GetitemTuple>();
(void)eliminate_getitem_pass->Run(GetCNodeFuncGraph(new_node));
if (GraphKernelFlags::GetInstance().dump_as_text) { if (GraphKernelFlags::GetInstance().dump_as_text) {
DumpClusterInfo(old_nodes, new_node); DumpClusterInfo(old_nodes, new_node);
} }