forked from mindspore-Ecosystem/mindspore
[GraphKernel] replace graph kernel node with custom in lite.
This commit is contained in:
parent
0341d96dd6
commit
9fabf8ae0d
|
@ -28,6 +28,7 @@
|
|||
#include "kernel/akg/akg_kernel_json_generator.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/file_utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -157,6 +158,7 @@ bool AkgKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) {
|
|||
return false;
|
||||
}
|
||||
std::vector<std::string> json_list;
|
||||
std::string kernels_name = "";
|
||||
for (const auto &node : node_list) {
|
||||
graphkernel::DumpOption option;
|
||||
option.get_compute_capability = true;
|
||||
|
@ -172,14 +174,23 @@ bool AkgKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) {
|
|||
GetValidKernelNodes(fg, &node_list, &input_list, &output_list);
|
||||
akg_kernel_json_generator.CollectFusedJson(node_list, input_list, output_list);
|
||||
auto json_kernel_name = akg_kernel_json_generator.kernel_name();
|
||||
AnfUtils::SetNodeAttr("kernel_name", MakeValue(json_kernel_name + "_kernel"), node->cast<CNodePtr>());
|
||||
if (find(json_list.begin(), json_list.end(), json_kernel_name) != json_list.end()) {
|
||||
continue;
|
||||
}
|
||||
json_list.push_back(json_kernel_name);
|
||||
kernels_name += dir_path.value() + "/" + json_kernel_name + ".o ";
|
||||
if (!SaveJsonInfo(dir_path.value() + "/" + json_kernel_name, akg_kernel_json_generator.kernel_json_str())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return CompileJsonsInList(dir_path.value(), json_list);
|
||||
auto res = CompileJsonsInList(dir_path.value(), json_list);
|
||||
if (res) {
|
||||
auto cmd = "g++ -fPIC -shared " + kernels_name + " -o " + dir_path.value() + "/akgkernels.so";
|
||||
if (system(cmd.c_str()) == 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/graph_kernel/lite_adapter/build_kernel.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "common/graph_kernel/lite_adapter/akg_build.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ops/custom.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
namespace {
|
||||
void BuildAKGKernel(const std::vector<AnfNodePtr> &node_list) {
|
||||
AnfNodePtrList anf_list;
|
||||
for (auto &node : node_list) {
|
||||
if (AnfUtils::IsGraphKernel(node)) {
|
||||
anf_list.push_back(node);
|
||||
}
|
||||
}
|
||||
graphkernel::AkgKernelBuilder gk;
|
||||
if (!gk.CompileJsonsInAnfnodes(anf_list)) {
|
||||
MS_LOG(EXCEPTION) << "Graph kernel compile fail";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr KernelBuilder::CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
if (func_graph == nullptr || cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto primc = std::make_shared<ops::Custom>();
|
||||
if (primc == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
primc->set_type("GraphKernel");
|
||||
std::map<std::string, std::vector<uint8_t>> custom_attrs;
|
||||
auto fg = GetCNodeFuncGraph(cnode);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto kernel_name = GetValue<std::string>(fg->get_attr("kernel_name"));
|
||||
std::vector<uint8_t> kernel_name_str(kernel_name.begin(), kernel_name.end());
|
||||
custom_attrs["kernel_name"] = kernel_name_str;
|
||||
primc->set_attr(custom_attrs);
|
||||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
auto custom_cnode = func_graph->NewCNode(primc, inputs);
|
||||
custom_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
|
||||
custom_cnode->set_abstract(cnode->abstract()->Clone());
|
||||
return custom_cnode;
|
||||
}
|
||||
|
||||
bool KernelBuilder::Run(const FuncGraphPtr &func_graph) {
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
BuildAKGKernel(node_list);
|
||||
bool changed = false;
|
||||
auto manager = Manage(func_graph, true);
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (auto &node : node_list) {
|
||||
if (!AnfUtils::IsGraphKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto custom_cnode = CreateCustomOp(func_graph, cnode);
|
||||
if (custom_cnode == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Create custom op fail for " << cnode->fullname_with_scope();
|
||||
}
|
||||
manager->Replace(node, custom_cnode);
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_BUILD_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_BUILD_KERNEL_H_
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/common/optimizer/pass.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
class KernelBuilder : public opt::Pass {
|
||||
public:
|
||||
KernelBuilder() : Pass("build_kernel_lite") {}
|
||||
~KernelBuilder() override = default;
|
||||
|
||||
AnfNodePtr CreateCustomOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_LITE_ADAPTER_BUILD_KERNEL_H_
|
|
@ -23,17 +23,23 @@ bool ConvertConstInputToAttr::Run(const FuncGraphPtr &func_graph) {
|
|||
bool changed = false;
|
||||
auto nodes = TopoSort(func_graph->get_return());
|
||||
for (auto node : nodes) {
|
||||
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
|
||||
continue;
|
||||
auto graph_kernel_fg = GetCNodeFuncGraph(node);
|
||||
if (graph_kernel_fg != nullptr && graph_kernel_fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
auto toposet = TopoSort(graph_kernel_fg->get_return());
|
||||
for (auto sub_node : toposet) {
|
||||
if (sub_node == nullptr || !AnfUtils::IsRealCNodeKernel(sub_node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = sub_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
opt::ConstInputToAttrInfoRegister reg;
|
||||
if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfUtils::GetCNodeName(cnode), ®)) {
|
||||
continue;
|
||||
}
|
||||
changed = true;
|
||||
opt::ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
|
||||
}
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
opt::ConstInputToAttrInfoRegister reg;
|
||||
if (!opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfUtils::GetCNodeName(cnode), ®)) {
|
||||
continue;
|
||||
}
|
||||
changed = true;
|
||||
opt::ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
#include "common/graph_kernel/core/eliminate_redundant_output.h"
|
||||
#include "common/graph_kernel/core/shape_ops_splitter.h"
|
||||
#include "common/graph_kernel/core/update_state_formatter.h"
|
||||
#include "common/graph_kernel/lite_adapter/akg_build.h"
|
||||
#include "common/graph_kernel/lite_adapter/build_kernel.h"
|
||||
#include "common/graph_kernel/lite_adapter/convert_const_input_to_attr.h"
|
||||
#include "common/graph_kernel/lite_adapter/graph_kernel_pass_manager.h"
|
||||
|
||||
|
@ -36,12 +36,6 @@ namespace mindspore::graphkernel {
|
|||
using opt::GetitemTuple;
|
||||
using opt::GraphOptimizer;
|
||||
|
||||
PassManagerPtr GraphKernelOptimizer::PreProcess() const {
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(0, "preprocess");
|
||||
pm->AddPass(std::make_shared<ConvertConstInputToAttr>(), OptLevel_1);
|
||||
return pm;
|
||||
}
|
||||
|
||||
PassManagerPtr GraphKernelOptimizer::Cluster() const {
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(0, "cluster");
|
||||
// Expand complex basic kernels to composite kernels
|
||||
|
@ -49,6 +43,7 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
|
|||
|
||||
// Cluster basic kernels and composite kernels
|
||||
pm->AddPass(std::make_shared<GraphKernelCluster>(), OptLevel_1);
|
||||
pm->AddPass(std::make_shared<ConvertConstInputToAttr>(), OptLevel_1);
|
||||
|
||||
// Eliminate the outputs without external user
|
||||
pm->AddPass(std::make_shared<EliminateRedundantOutput>(), OptLevel_1);
|
||||
|
@ -77,11 +72,18 @@ PassManagerPtr GraphKernelOptimizer::Split() const {
|
|||
return pm;
|
||||
}
|
||||
|
||||
PassManagerPtr GraphKernelOptimizer::PostProcess() const {
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(1, "postprocess");
|
||||
// build akg and replace graph kernel nodes
|
||||
pm->AddPass(std::make_shared<KernelBuilder>(), OptLevel_1);
|
||||
return pm;
|
||||
}
|
||||
|
||||
void GraphKernelOptimizer::Run(const FuncGraphPtr &kernel_graph) {
|
||||
auto optimizer = std::make_shared<GraphOptimizer>("graph_kernel_optimizer");
|
||||
optimizer->AddPassManager(PreProcess());
|
||||
optimizer->AddPassManager(Cluster());
|
||||
optimizer->AddPassManager(Split());
|
||||
optimizer->AddPassManager(PostProcess());
|
||||
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
|
@ -89,17 +91,6 @@ void GraphKernelOptimizer::Run(const FuncGraphPtr &kernel_graph) {
|
|||
kernel_graph->set_manager(mng);
|
||||
}
|
||||
(void)optimizer->Optimize(kernel_graph);
|
||||
auto node_list = kernel_graph->GetOrderedCnodes();
|
||||
AnfNodePtrList anf_list;
|
||||
for (auto &node : node_list) {
|
||||
if (AnfUtils::IsGraphKernel(node)) {
|
||||
anf_list.push_back(node);
|
||||
}
|
||||
}
|
||||
graphkernel::AkgKernelBuilder gk;
|
||||
if (!gk.CompileJsonsInAnfnodes(anf_list)) {
|
||||
MS_LOG(WARNING) << "Graph kernel compile fail";
|
||||
}
|
||||
}
|
||||
|
||||
void GraphKernelOptimize(const FuncGraphPtr &kernel_graph) { GraphKernelOptimizer().Run(kernel_graph); }
|
||||
|
|
|
@ -28,12 +28,12 @@ class GraphKernelOptimizer {
|
|||
void Run(const FuncGraphPtr &kernel_graph);
|
||||
|
||||
private:
|
||||
// before graph_kernel
|
||||
PassManagerPtr PreProcess() const;
|
||||
// Cluster kernels
|
||||
PassManagerPtr Cluster() const;
|
||||
// Split kernels
|
||||
PassManagerPtr Split() const;
|
||||
// Post-process
|
||||
PassManagerPtr PostProcess() const;
|
||||
};
|
||||
|
||||
void GraphKernelOptimize(const FuncGraphPtr &kernel_graph);
|
||||
|
|
|
@ -216,6 +216,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/gpu/
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
"../../../mindspore/ccsrc/plugin/device/gpu/optimizer/batch_norm_relu_grad_fusion.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/akg_build.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/build_kernel.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
"../../../mindspore/ccsrc/common/graph_kernel/lite_adapter/callback_impl.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
|
|
Loading…
Reference in New Issue