!41844 Decouple GraphKernel from MS backend

Merge pull request !41844 from DeshiChen/0913_passmanager
This commit is contained in:
i-robot 2022-09-20 02:07:33 +00:00 committed by Gitee
commit c46b894690
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 93 additions and 13 deletions

View File

@ -21,7 +21,7 @@
#include <string>
#include "utils/hash_map.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/pass.h"
#include "ir/func_graph.h"
#include "common/graph_kernel/model/lite_graph.h"

View File

@ -16,7 +16,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_ELIMINATE_REDUNDANT_OUTPUT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_ELIMINATE_REDUNDANT_OUTPUT_H_
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/pass.h"
namespace mindspore::graphkernel {
/* Eliminate the output without external user

View File

@ -27,7 +27,6 @@
#include "include/common/utils/utils.h"
#include "utils/anf_utils.h"
#include "utils/ordered_set.h"
#include "backend/common/pass/getitem_tuple.h"
#include "common/graph_kernel/core/graph_kernel_callback.h"
#include "common/graph_kernel/core/graph_kernel_utils.h"
#include "ir/func_graph_cloner.h"
@ -313,8 +312,7 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNod
(void)InlineInnerFuncGraph(fg);
// eliminate tuple of tuple, and set Abstract for output MakeTuple
EliminateTupleOfTuple(fg);
// eliminate the inner MakeTuple-GetItem edges
(void)std::static_pointer_cast<opt::Pass>(std::make_shared<opt::GetitemTuple>())->Run(fg);
(void)EliminateMaketupleGetitem(fg);
(void)ConvertNonscalarTensorToParameter(fg, &inputs);
return std::make_tuple(fg, inputs, outputs);
@ -347,4 +345,26 @@ AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const Fu
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
return fuse_new_node;
}
// Eliminate redundant MakeTuple-Getitem edges
bool EliminateMaketupleGetitem(const FuncGraphPtr &fg) {
auto nodes = fg->GetOrderedCnodes();
auto mng = GkUtils::GetFuncGraphManager(fg);
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
for (const auto &node : nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
continue;
}
auto gt = node->cast<CNodePtr>();
auto mt = gt->input(kRealInputNodeIndexInTupleGetItem)->cast<CNodePtr>();
if (mt == nullptr || !IsPrimitiveCNode(mt, prim::kPrimMakeTuple)) {
continue;
}
auto idx = AnfUtils::GetIntValue(gt->input(kInputNodeOutputIndexInTupleGetItem));
mng->Replace(node, mt->input(idx + 1));
changed = true;
}
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -31,5 +31,6 @@ AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const Fu
const std::string &postfix = "");
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
bool RemoveNonScalarConstTensorFromParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
bool EliminateMaketupleGetitem(const FuncGraphPtr &fg);
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_BUILDER_H_

View File

@ -22,7 +22,7 @@
#include "utils/hash_map.h"
#include "ir/anf.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/pass.h"
namespace mindspore {
namespace graphkernel {

View File

@ -25,8 +25,6 @@ if(MSLITE_ENABLE_GRAPH_KERNEL)
${CCSRC_DIR}/common/graph_kernel/split_model/*.cc
${CCSRC_DIR}/common/graph_kernel/graph_kernel_flags.cc
${CCSRC_DIR}/kernel/akg/akg_kernel_json_generator.cc
${CCSRC_DIR}/backend/common/pass/getitem_tuple.cc
${CCSRC_DIR}/backend/common/optimizer/optimizer.cc
)
set(CCSRC_SRC
${CCSRC_SRC}

View File

@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_MAKETUPLE_GETITEM_H_
#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_MAKETUPLE_GETITEM_H_
#include <vector>
#include "backend/common/optimizer/pass.h"
#include "ir/func_graph.h"
#include "common/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
/**
* @brief Eliminate redundant MakeTuple-Getitem edges
* @example
* %1 = op1
* %2 = op2
* %3 = make_tuple(%1, %2)
* %4 = tuple_getitem(%3, 0)
* %5 = tuple_getitem(%3, 1)
* %6 = op6(%4, %5)
* -->
* %1 = op1
* %2 = op2
* %6 = op6(%1, %2)
*/
class ElimMaketupleGetitem : public opt::Pass {
public:
ElimMaketupleGetitem() : Pass("elim_maketuple_getitem") {}
~ElimMaketupleGetitem() override = default;
bool Run(const FuncGraphPtr &func_graph) override { return EliminateMaketupleGetitem(func_graph); }
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_ELIMINATE_MAKETUPLE_GETITEM_H_

View File

@ -20,9 +20,8 @@
#include <memory>
#include "ir/func_graph.h"
#include "common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/graph_optimizer.h"
#include "backend/common/pass/getitem_tuple.h"
#include "common/graph_kernel/core/arithmetic_simplify.h"
#include "common/graph_kernel/core/eliminate_redundant_output.h"
#include "common/graph_kernel/core/shape_ops_splitter.h"
@ -37,9 +36,9 @@
#include "tools/graph_kernel/converter/insert_abstract.h"
#include "tools/graph_kernel/converter/graph_kernel_splitter_lite.h"
#include "tools/graph_kernel/converter/parameter_to_tensor.h"
#include "tools/graph_kernel/converter/eliminate_maketuple_getitem.h"
namespace mindspore::graphkernel {
using opt::GetitemTuple;
using opt::GraphOptimizer;
constexpr size_t kStagePreProcess = 0;
constexpr size_t kStageCluster = 1;
@ -91,8 +90,8 @@ GkPassManagerPtr GraphKernelOptimizer::Split() const {
pm->Add(std::make_shared<GraphKernelSplitterWithTuning>(), OptLevel_1);
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// will be exposed, use GetitemTuple Pass to delete them.
pm->Add(std::make_shared<GetitemTuple>(), OptLevel_1);
// will be exposed, use ElimMaketupleGetitem Pass to delete them.
pm->Add(std::make_shared<ElimMaketupleGetitem>(), OptLevel_1);
// Eliminate the redundant node that is copied above but not handled by GraphKernelSplitter
pm->Add(std::make_shared<MergeOutputForUpdateState>(), OptLevel_1);

View File

@ -24,6 +24,7 @@
#include "ir/graph_utils.h"
#include "src/common/file_utils.h"
#include "utils/file_utils.h"
#include "src/common/utils.h"
namespace mindspore::graphkernel {
namespace dumpir {
@ -467,4 +468,16 @@ void GraphKernelPassManagerLite::DumpPassIR(const FuncGraphPtr &func_graph, cons
dumpir::DumpIR(filename, func_graph, true);
}
}
// transplant this function from pass_manager_extends.cc because the implement was moved to PassManagerLite.
bool GraphKernelPassManagerLite::RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const {
bool changed = false;
auto begin_time = lite::GetTimeUs();
if (pass->Run(func_graph)) {
changed = true;
}
auto end_time = lite::GetTimeUs();
MS_LOG(INFO) << "Run pass " << GetPassFullname(pass_id, pass) << " in " << (end_time - begin_time) << " us.";
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -23,6 +23,7 @@
#include "common/graph_kernel/core/graph_kernel_pass_manager.h"
namespace mindspore::graphkernel {
using opt::PassPtr;
class GraphKernelPassManagerLite : public GraphKernelPassManager {
public:
using GraphKernelPassManager::GraphKernelPassManager;
@ -30,6 +31,7 @@ class GraphKernelPassManagerLite : public GraphKernelPassManager {
protected:
void DumpPassIR(const FuncGraphPtr &func_graph, const std::string &pass_fullname) const override;
bool RunPass(const FuncGraphPtr &func_graph, size_t pass_id, const PassPtr &pass) const override;
bool dump_ir_{false};
};