forked from mindspore-Ecosystem/mindspore
Call the pass "opt::GetitemTuple" in BuildSingleGraphFromNodes.
this pass may eliminate some nodes and make the parameter unused, so we should call it before the function EliminateRedundantParameters.
This commit is contained in:
parent
555c9dc929
commit
820e620bcf
|
@ -27,6 +27,7 @@
|
|||
#include "include/common/utils/utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/ordered_set.h"
|
||||
#include "backend/common/pass/getitem_tuple.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_callback.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
||||
|
||||
|
@ -83,7 +84,7 @@ bool InlineInnerFuncGraph(const FuncGraphPtr &fg) {
|
|||
return changed;
|
||||
}
|
||||
|
||||
void EliminateMakeTuple(const FuncGraphPtr &fg) {
|
||||
void EliminateTupleOfTuple(const FuncGraphPtr &fg) {
|
||||
if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
|
||||
return;
|
||||
}
|
||||
|
@ -259,15 +260,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNod
|
|||
AnfNodePtrList outputs;
|
||||
std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes);
|
||||
|
||||
FuncGraphManagerPtr mng = fg->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(fg, false);
|
||||
fg->set_manager(mng);
|
||||
}
|
||||
FuncGraphManagerPtr mng = GkUtils::GetFuncGraphManager(fg);
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
|
||||
(void)InlineInnerFuncGraph(fg);
|
||||
// eliminate tuple of tuple, and set Abstract for output MakeTuple
|
||||
EliminateMakeTuple(fg);
|
||||
EliminateTupleOfTuple(fg);
|
||||
// eliminate the inner MakeTuple-GetItem edges
|
||||
(void)std::static_pointer_cast<opt::Pass>(std::make_shared<opt::GetitemTuple>())->Run(fg);
|
||||
(void)ConvertNonscalarTensorToParameter(fg, &inputs);
|
||||
|
||||
return std::make_tuple(fg, inputs, outputs);
|
||||
|
|
|
@ -28,14 +28,11 @@
|
|||
#include "utils/ms_context.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "include/common/utils/context/graph_kernel_flags.h"
|
||||
#include "backend/common/pass/getitem_tuple.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_callback.h"
|
||||
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
||||
#include "common/graph_kernel/core/graph_builder.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
using opt::GetitemTuple;
|
||||
|
||||
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
|
||||
std::vector<OpWithLevel> clusterable_ops_with_level = {
|
||||
// all target
|
||||
|
@ -405,8 +402,6 @@ void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const s
|
|||
(void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
|
||||
[this](size_t id) { return this->nodes_[id]; });
|
||||
auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<GetitemTuple>();
|
||||
(void)eliminate_getitem_pass->Run(GetCNodeFuncGraph(new_node));
|
||||
if (GraphKernelFlags::GetInstance().dump_as_text) {
|
||||
DumpClusterInfo(old_nodes, new_node);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue