[GraphKernel] replace graph kernel node with custom in lite.

This commit is contained in:
chenlei_autodiff 2022-02-25 15:07:34 +08:00
parent 0341d96dd6
commit 9fabf8ae0d
7 changed files with 163 additions and 32 deletions

View File

@ -28,6 +28,7 @@
#include "kernel/akg/akg_kernel_json_generator.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/anf_utils.h"
#include "utils/file_utils.h"
#include "utils/log_adapter.h"
@ -157,6 +158,7 @@ bool AkgKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) {
return false;
}
std::vector<std::string> json_list;
std::string kernels_name = "";
for (const auto &node : node_list) {
graphkernel::DumpOption option;
option.get_compute_capability = true;
@ -172,14 +174,23 @@ bool AkgKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) {
GetValidKernelNodes(fg, &node_list, &input_list, &output_list);
akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list);
auto json_kernel_name = akg_kernel_json_generator.kernel_name();
AnfUtils::SetNodeAttr("kernel_name", MakeValue(json_kernel_name + "_kernel"), node->cast<CNodePtr>());
if (find(json_list.begin(), json_list.end(), json_kernel_name) != json_list.end()) {
continue;
}
json_list.push_back(json_kernel_name);
kernels_name += dir_path.value() + "/" + json_kernel_name + ".o ";
if (!SaveJsonInfo(dir_path.value() + "/" + json_kernel_name, akg_kernel_json_generator.kernel_json_str())) {
return false;
}
}
return CompileJsonsInList(dir_path.value(), json_list);
auto res = CompileJsonsInList(dir_path.value(), json_list);
if (res) {
auto cmd = "g++ -fPIC -shared " + kernels_name + " -o " + dir_path.value() + "/akgkernels.so";
if (system(cmd.c_str()) == 0) {
return true;
}
}
return false;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/graph_kernel/lite_adapter/build_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "common/graph_kernel/lite_adapter/akg_build.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ops/custom.h"
#include "utils/log_adapter.h"
namespace mindspore::graphkernel {
namespace {
void BuildAKGKernel(const std::vector<AnfNodePtr> &node_list) {
AnfNodePtrList anf_list;
for (auto &node : node_list) {
if (AnfUtils::IsGraphKernel(node)) {
anf_list.push_back(node);
}
}
graphkernel::AkgKernelBuilder gk;
if (!gk.CompileJsonsInAnfnodes(anf_list)) {
MS_LOG(EXCEPTION) << "Graph kernel compile fail";
}
}
} // namespace
AnfNodePtr KernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
if (func_graph == nullptr || cnode == nullptr) {
return nullptr;
}
auto primc = std::make_shared<ops::Custom>();
if (primc == nullptr) {
return nullptr;
}
primc->set_type("GraphKernel");
std::map<std::string, std::vector<uint8_t>> custom_attrs;
auto fg = GetCNodeFuncGraph(cnode);
MS_EXCEPTION_IF_NULL(fg);
auto kernel_name = GetValue<std::string>(fg->get_attr("kernel_name"));
std::vector<uint8_t> kernel_name_str(kernel_name.begin(), kernel_name.end());
custom_attrs["kernel_name"] = kernel_name_str;
primc->set_attr(custom_attrs);
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
auto custom_cnode = func_graph->NewCNode(primc, inputs);
custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
custom_cnode->set_abstract(cnode->abstract()->Clone());
return custom_cnode;
}
bool KernelBuilder::Run(const FuncGraphPtr &func_graph) {
auto node_list = TopoSort(func_graph->get_return());
BuildAKGKernel(node_list);
bool changed = false;
auto manager = Manage(func_graph, true);
MS_EXCEPTION_IF_NULL(manager);
for (auto &node : node_list) {
if (!AnfUtils::IsGraphKernel(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
auto custom_cnode = CreateCustomOp(func_graph, cnode);
if (custom_cnode == nullptr) {
MS_LOG(EXCEPTION) << "Create custom op fail for " << cnode->fullname_with_scope();
}
manager->Replace(node, custom_cnode);
changed = true;
}
return changed;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,33 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_BUILD_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_BUILD_KERNEL_H_
#include "ir/func_graph.h"
#include "backend/common/optimizer/pass.h"
namespace mindspore::graphkernel {
class KernelBuilder : public opt::Pass {
public:
KernelBuilder() : Pass("build_kernel_lite") {}
~KernelBuilder() override = default;
AnfNodePtr CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_BUILD_KERNEL_H_

View File

@ -23,17 +23,23 @@ bool ConvertConstInputToAttr::Run(const FuncGraphPtr &func_graph) {
bool changed = false;
auto nodes = TopoSort(func_graph->get_return());
for (auto node : nodes) {
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
continue;
auto graph_kernel_fg = GetCNodeFuncGraph(node);
if (graph_kernel_fg != nullptr && graph_kernel_fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
auto toposet = TopoSort(graph_kernel_fg->get_return());
for (auto sub_node : toposet) {
if (sub_node == nullptr || !AnfUtils::IsRealCNodeKernel(sub_node)) {
continue;
}
auto cnode = sub_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
opt::ConstInputToAttrInfoRegister reg;
if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfUtils::GetCNodeName(cnode), &reg)) {
continue;
}
changed = true;
opt::ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
}
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
opt::ConstInputToAttrInfoRegister reg;
if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfUtils::GetCNodeName(cnode), &reg)) {
continue;
}
changed = true;
opt::ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
}
return changed;
}

View File

@ -28,7 +28,7 @@
#include "common/graph_kernel/core/eliminate_redundant_output.h"
#include "common/graph_kernel/core/shape_ops_splitter.h"
#include "common/graph_kernel/core/update_state_formatter.h"
#include "common/graph_kernel/lite_adapter/akg_build.h"
#include "common/graph_kernel/lite_adapter/build_kernel.h"
#include "common/graph_kernel/lite_adapter/convert_const_input_to_attr.h"
#include "common/graph_kernel/lite_adapter/graph_kernel_pass_manager.h"
@ -36,12 +36,6 @@ namespace mindspore::graphkernel {
using opt::GetitemTuple;
using opt::GraphOptimizer;
PassManagerPtr GraphKernelOptimizer::PreProcess() const {
auto pm = std::make_shared<GraphKernelPassManager>(0, "preprocess");
pm->AddPass(std::make_shared<ConvertConstInputToAttr>(), OptLevel_1);
return pm;
}
PassManagerPtr GraphKernelOptimizer::Cluster() const {
auto pm = std::make_shared<GraphKernelPassManager>(0, "cluster");
// Expand complex basic kernels to composite kernels
@ -49,6 +43,7 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
// Cluster basic kernels and composite kernels
pm->AddPass(std::make_shared<GraphKernelCluster>(), OptLevel_1);
pm->AddPass(std::make_shared<ConvertConstInputToAttr>(), OptLevel_1);
// Eliminate the outputs without external user
pm->AddPass(std::make_shared<EliminateRedundantOutput>(), OptLevel_1);
@ -77,11 +72,18 @@ PassManagerPtr GraphKernelOptimizer::Split() const {
return pm;
}
PassManagerPtr GraphKernelOptimizer::PostProcess() const {
auto pm = std::make_shared<GraphKernelPassManager>(1, "postprocess");
// build akg and replace graph kernel nodes
pm->AddPass(std::make_shared<KernelBuilder>(), OptLevel_1);
return pm;
}
void GraphKernelOptimizer::Run(const FuncGraphPtr &kernel_graph) {
auto optimizer = std::make_shared<GraphOptimizer>("graph_kernel_optimizer");
optimizer->AddPassManager(PreProcess());
optimizer->AddPassManager(Cluster());
optimizer->AddPassManager(Split());
optimizer->AddPassManager(PostProcess());
auto mng = kernel_graph->manager();
if (mng == nullptr) {
@ -89,17 +91,6 @@ void GraphKernelOptimizer::Run(const FuncGraphPtr &kernel_graph) {
kernel_graph->set_manager(mng);
}
(void)optimizer->Optimize(kernel_graph);
auto node_list = kernel_graph->GetOrderedCnodes();
AnfNodePtrList anf_list;
for (auto &node : node_list) {
if (AnfUtils::IsGraphKernel(node)) {
anf_list.push_back(node);
}
}
graphkernel::AkgKernelBuilder gk;
if (!gk.CompileJsonsInAnfnodes(anf_list)) {
MS_LOG(WARNING) << "Graph kernel compile fail";
}
}
void GraphKernelOptimize(const FuncGraphPtr &kernel_graph) { GraphKernelOptimizer().Run(kernel_graph); }

View File

@ -28,12 +28,12 @@ class GraphKernelOptimizer {
void Run(const FuncGraphPtr &kernel_graph);
private:
// before graph_kernel
PassManagerPtr PreProcess() const;
// Cluster kernels
PassManagerPtr Cluster() const;
// Split kernels
PassManagerPtr Split() const;
// Post-process
PassManagerPtr PostProcess() const;
};
void GraphKernelOptimize(const FuncGraphPtr &kernel_graph);

View File

@ -216,6 +216,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/gpu/
list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_grad_fusion.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/akg_build.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/build_kernel.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST
"../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/callback_impl.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST