!26515 Refactor FuseNodesToSubGraph and decouple it from graph_kernel_helper

Merge pull request !26515 from DeshiChen/1111_fusenodes
This commit is contained in:
i-robot 2021-11-23 07:58:56 +00:00 committed by Gitee
commit 511441a27e
21 changed files with 410 additions and 379 deletions

View File

@ -17,7 +17,9 @@
#include "backend/optimizer/graph_kernel/adapter/callback_impl.h"
#include <algorithm>
#include <string>
#include <vector>
#include <utility>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
@ -76,4 +78,62 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) {
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return kernel::GetProcessorStr(node); }
std::string CallbackImpl::GetProcessorFromContext() { return kernel::GetStrProcessorFromContext(); }
// Sets the selected kernel build info (format / device dtype / processor) for a
// fused GraphKernel CNode, and mirrors the per-input info onto the parameters
// of its inner FuncGraph. Input info is taken from each real input's producing
// kernel; output info is taken from the inner graph's real output nodes.
void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
std::vector<std::string> graph_input_format;
std::vector<TypeId> graph_input_type;
std::vector<std::string> graph_output_format;
std::vector<TypeId> graph_output_type;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto fg = GetCNodeFuncGraph(node);
MS_EXCEPTION_IF_NULL(fg);
auto &inputs = cnode->inputs();
// inputs[0] is the FuncGraph value node, so real inputs start at index 1.
for (size_t i = 1; i < inputs.size(); ++i) {
auto kernel_with_index = AnfUtils::VisitKernel(inputs[i], 0);
if (kernel_with_index.first->isa<ValueNode>()) {
// Constant tensor input: use default format and the tensor's own dtype.
auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
MS_EXCEPTION_IF_NULL(tensor);
(void)graph_input_format.emplace_back(kOpFormat_DEFAULT);
(void)graph_input_type.emplace_back(tensor->data_type());
} else {
// Non-constant input: inherit format/dtype from the producing kernel's output.
auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
(void)graph_input_format.emplace_back(std::move(input_format));
auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
(void)graph_input_type.emplace_back(input_type);
}
// Mirror the info just collected (the .back() entries) onto the matching
// inner-graph parameter; parameter i-1 corresponds to cnode input i.
fg->parameters()[i - 1]->set_kernel_info(std::make_shared<device::KernelInfo>());
kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
para_info_builder.SetOutputsFormat({graph_input_format.back()});
para_info_builder.SetOutputsDeviceType({graph_input_type.back()});
para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i - 1].get());
}
// Collect the inner graph's real outputs, flattening a top-level MakeTuple.
AnfNodePtrList outputs;
if (IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
auto fg_output = fg->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(fg_output);
outputs.assign(fg_output->inputs().begin() + 1, fg_output->inputs().end());
} else {
outputs.push_back(fg->output());
}
for (size_t i = 0; i < outputs.size(); ++i) {
// NOTE(review): the input loop above uses AnfUtils::VisitKernel while this
// loop uses AnfAlgo::VisitKernel — confirm the difference is intentional.
auto kernel_with_index = AnfAlgo::VisitKernel(outputs[i], 0);
auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
graph_output_format.push_back(output_format);
graph_output_type.push_back(output_type);
}
// Build and attach the selected kernel info for the fused node itself.
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetInputsFormat(graph_input_format);
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto graph_selected_info = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, node.get());
}
} // namespace mindspore::graphkernel

View File

@ -34,6 +34,7 @@ class CallbackImpl : public Callback {
std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
std::string GetProcessor(const AnfNodePtr &node) override;
std::string GetProcessorFromContext() override;
void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override;
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_

View File

@ -29,6 +29,7 @@
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "debug/anf_ir_dump.h"
@ -360,7 +361,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
CorrectKernelBuildInfo(composite_node, new_input);
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}
@ -435,7 +436,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));

View File

@ -25,6 +25,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
@ -82,7 +83,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
}
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}

View File

@ -25,6 +25,7 @@
#include <unordered_map>
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/anf.h"
#include "utils/context/graph_kernel_flags.h"
@ -645,14 +646,12 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
}
if (!change_anf_graph) continue;
ReorganizeEmptyGraph(lg);
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
auto new_funcgraph = LiteGraph2AnfGraph(lg);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
EliminateRedundantParameters(new_funcgraph, &inputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs);
mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
do_simplify = true;

View File

@ -25,29 +25,27 @@
#include "base/core_ops.h"
#include "ir/func_graph.h"
#include "utils/utils.h"
#include "utils/anf_utils.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore::graphkernel {
AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
const std::unordered_set<AnfNodePtr> &seen) {
namespace {
// find outputs of nodes
AnfNodePtrList FindOutputs(const AnfNodePtrList &nodes, const AnfNodePtrToAnfNodePtrMap &eqv) {
AnfNodePtrList output;
if (users.size() == 0) {
return output;
}
auto mng = nodes[0]->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
auto &users = mng->node_users();
for (auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
// only CNode can be an output.
if (!node->isa<CNode>()) continue;
auto iter = users.find(node);
if (iter == users.end()) {
continue;
}
if (iter == users.end()) continue;
auto &node_users = iter->second;
const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users),
[&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
const bool is_outer_user = (seen.find(u.first) == seen.end());
return is_outer_user;
});
if (has_outer_user) {
// if any user of the `node` is not in the nodes list, the `node` is an output.
if (std::any_of(std::begin(node_users), std::end(node_users),
[&eqv](const std::pair<AnfNodePtr, int> &u) { return eqv.find(u.first) == eqv.end(); })) {
output.emplace_back(node);
}
}
@ -56,12 +54,11 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
auto &input_list = *inputs_ptr;
auto &eqv = *eqv_ptr;
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
eqv[node] = node;
} else if (eqv.find(node) == eqv.end()) {
input_list.push_back(node);
inputs_ptr->push_back(node);
eqv[node] = fg->add_parameter();
eqv[node]->set_abstract(node->abstract());
eqv[node]->set_kernel_info(node->kernel_info_ptr());
@ -69,43 +66,150 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
return eqv[node];
}
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &node_list) {
bool InlineInnerFuncGraph(const FuncGraphPtr &fg) {
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = false;
auto cnodes = fg->GetOrderedCnodes();
for (const auto &n : cnodes) {
auto graph_kernel_g = GetCNodeFuncGraph(n);
if (graph_kernel_g == nullptr) continue;
AnfNodePtrList inp(n->inputs().begin() + 1, n->inputs().end());
auto out = InlineClone(graph_kernel_g, fg, inp, n->input(0)->scope());
mng->Replace(n, out);
changed = true;
}
return changed;
}
// Flattens a nested MakeTuple output of `fg` into a single-level MakeTuple and
// refreshes the output's abstract to an AbstractTuple over the flat elements.
// No-op (except the abstract refresh) when the output is already flat; returns
// early when the graph output is not a MakeTuple at all.
void EliminateMakeTuple(const FuncGraphPtr &fg) {
if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
return;
}
auto out_cnode = fg->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_cnode);
// SpreadTuples keeps input 0 (the MakeTuple primitive value node) and
// recursively inlines the elements of any inner MakeTuple.
AnfNodePtrList new_args = GkUtils::SpreadTuples(out_cnode->inputs());
// A size change means some inner tuple was flattened; rebuild the output node.
if (new_args.size() != out_cnode->size()) {
auto new_out = fg->NewCNode(new_args);
auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng);
mng->Replace(out_cnode, new_out);
}
AbstractBasePtrList abs_list;
// Skip new_args[0] (the primitive) when collecting element abstracts.
std::transform(new_args.begin() + 1, new_args.end(), std::back_inserter(abs_list),
[](const AnfNodePtr &node) { return node->abstract(); });
// fg->output() here is the (possibly just replaced) flat MakeTuple.
fg->output()->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
}
// Replaces every non-scalar constant tensor used inside `fg` with a fresh
// Parameter and appends the original value node to `inputs_ptr`, so the
// constant is fed from outside the fused graph instead of being embedded.
// Returns true when at least one tensor was converted.
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
  MS_EXCEPTION_IF_NULL(inputs_ptr);
  auto cnodes = fg->GetOrderedCnodes();
  AnfNodePtrList value_nodes;
  for (const auto &cnode : cnodes) {
    auto &inputs = cnode->inputs();
    for (size_t i = 1; i < inputs.size(); ++i) {
      const auto &tnode = inputs[i];
      auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
      // Scalar (single-element) tensors stay inline; only real tensors move out.
      if (tensor == nullptr || tensor->DataSize() == 1) {
        continue;
      }
      // A ValueNode may feed several CNodes. Collect it only once: a duplicate
      // entry would add a dangling extra parameter (the second Replace is a
      // no-op) and push the same vnode twice into inputs_ptr, misaligning the
      // graph's parameters with its inputs.
      if (std::find(value_nodes.begin(), value_nodes.end(), tnode) == value_nodes.end()) {
        value_nodes.push_back(tnode);
      }
    }
  }
  if (value_nodes.empty()) return false;
  auto mng = fg->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (auto &vnode : value_nodes) {
    auto parameter = fg->add_parameter();
    parameter->set_abstract(vnode->abstract());
    parameter->set_kernel_info(vnode->kernel_info_ptr());
    mng->Replace(vnode, parameter);
    inputs_ptr->push_back(vnode);
  }
  return true;
}
// Checks whether `out` ultimately produces a tuple. For a MakeTuple node (or a
// call to a FuncGraph whose output is one), stores the tuple elements into
// `real_outs` and returns true; otherwise returns false.
bool IsTupleOutput(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
  if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
    const auto &tuple_inputs = out->cast<CNodePtr>()->inputs();
    real_outs->assign(tuple_inputs.begin() + 1, tuple_inputs.end());
    return true;
  }
  auto called_graph = GetCNodeFuncGraph(out);
  if (called_graph == nullptr) {
    return false;
  }
  // Follow the called graph down to its real output.
  return IsTupleOutput(called_graph->output(), real_outs);
}
// Redirects all users of the fused nodes' original `outputs` to the new fused
// CNode. With a single output this is a direct replacement; with multiple
// outputs each original output becomes a TupleGetItem on the fused node, and
// outputs that were themselves tuples have their users' getitem indices
// shifted so they address the correct slot in the flattened fused output.
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
// single out
if (outputs.size() == 1) {
mng->Replace(outputs[0], new_fuse_cnode);
return;
}
// `offset` accumulates the extra slots consumed by earlier tuple outputs,
// since the fused node's output tuple is flat.
size_t offset = 0;
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
AnfNodePtrList real_outs;
// the output is a single tensor: replace it with getitem(fused, flat index).
if (!IsTupleOutput(outputs[out_idx], &real_outs)) {
auto gt_idx = MakeValue(SizeToLong(out_idx + offset));
AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
auto new_out = func_graph->NewCNode(gt_inputs);
new_out->set_abstract(outputs[out_idx]->abstract());
mng->Replace(outputs[out_idx], new_out);
continue;
}
// The output is a tuple: rewrite each user TupleGetItem, shifting its index
// by the flat position of this output.
auto users = mng->node_users()[outputs[out_idx]]; // use a copy, the original user map is changed in for-loop.
for (auto &user : users) {
auto getitem_node = user.first;
if (!getitem_node->isa<CNode>() || !IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
continue;
}
auto value_ptr = GetValueNode(getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
MS_EXCEPTION_IF_NULL(value_ptr);
auto old_gt_idx = GetValue<int64_t>(value_ptr);
auto gt_idx = MakeValue(SizeToLong(out_idx + offset) + old_gt_idx);
AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
auto new_getitem_node = func_graph->NewCNode(gt_inputs);
new_getitem_node->set_abstract(getitem_node->abstract());
mng->Replace(getitem_node, new_getitem_node);
}
// A tuple of k elements occupies k flat slots but only one entry in `outputs`.
offset += real_outs.size() - 1;
}
}
} // namespace
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes) {
FuncGraphPtr fg = nullptr;
{
// limit the lifetime of guard.
TraceGuard guard(
std::make_shared<TraceSegmentTransform>(node_list[0]->cast<CNodePtr>()->func_graph()->debug_info()));
TraceGuard guard(std::make_shared<TraceSegmentTransform>(nodes[0]->cast<CNodePtr>()->func_graph()->debug_info()));
fg = std::make_shared<FuncGraph>();
}
AnfNodePtrList input_list;
AnfNodePtrToAnfNodePtrMap eqv;
// Merge CNodes into a AnfGraph that represents a linear instruction segment
for (auto node : node_list) {
auto &input_nodes = node->cast<CNodePtr>()->inputs();
auto fn = input_nodes[0];
std::vector<AnfNodePtr> new_args{fn};
if (IsPrimitive(fn, prim::kPrimDepend) && input_nodes.size() >= kDependInputSize &&
eqv.find(input_nodes[kDependAttachNodeIndex]) == eqv.end()) {
new_args.emplace_back(RefSubGraphNode(fg, input_nodes[kRealInputIndexInDepend], &input_list, &eqv));
const size_t value_start_index = 2;
for (size_t i = value_start_index; i < input_nodes.size(); ++i) {
new_args.emplace_back(NewValueNode(MakeValue(0)));
}
} else {
(void)std::transform(
std::begin(input_nodes) + 1, std::end(input_nodes), std::back_inserter(new_args),
[&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); });
}
for (auto &node : nodes) {
auto &node_inputs = node->cast<CNodePtr>()->inputs();
std::vector<AnfNodePtr> new_args{node_inputs[0]};
(void)std::transform(
std::begin(node_inputs) + 1, std::end(node_inputs), std::back_inserter(new_args),
[&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); });
TraceGuard tg(std::make_shared<TraceSegmentTransform>(node->debug_info()));
eqv[node] = fg->NewCNode(new_args);
eqv[node]->set_abstract(node->abstract());
eqv[node]->set_kernel_info(node->kernel_info_ptr());
}
std::unordered_set<AnfNodePtr> eqv_keys;
(void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()),
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
auto mgr = node_list[0]->func_graph()->manager();
auto outputs = GetOutput(node_list, mgr->node_users(), eqv_keys);
auto outputs = FindOutputs(nodes, eqv);
AnfNodePtr fg_output;
if (outputs.size() > 1) {
std::vector<AnfNodePtr> output_args;
@ -120,4 +224,52 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(con
fg->set_output(fg_output);
return std::make_tuple(fg, input_list, outputs);
}
// Transform nodes (including basic and composite nodes) into a single new
// graph, and collect their inputs and outputs. Composite (GraphKernel) nodes
// among `nodes` are inlined so the resulting graph is flat.
// Returns {new graph, its input nodes in the outer graph, the original output nodes}.
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes) {
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes);
// The fresh graph may not have a manager yet; attach one so the passes below
// can replace nodes.
FuncGraphManagerPtr mng = fg->manager();
if (mng == nullptr) {
mng = Manage(fg, false);
fg->set_manager(mng);
}
// Expand inner GraphKernel calls into the new graph.
InlineInnerFuncGraph(fg);
// eliminate tuple of tuple, and set Abstract for output MakeTuple
EliminateMakeTuple(fg);
// Hoist non-scalar constant tensors out as extra graph inputs/parameters.
ConvertNonscalarTensorToParameter(fg, &inputs);
return std::make_tuple(fg, inputs, outputs);
}
// Creates a CNode in `main_fg` that calls the fused sub-graph `sub_fg` with
// `inputs`. The new node takes its abstract from the sub-graph's output, and
// its kernel info is filled in by the registered backend callback.
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) {
  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(inputs.size() + 1);
  call_inputs.push_back(NewValueNode(sub_fg));
  (void)call_inputs.insert(call_inputs.end(), inputs.cbegin(), inputs.cend());
  auto fused_node = main_fg->NewCNode(call_inputs);
  fused_node->set_abstract(sub_fg->output()->abstract());
  // Backend-specific kernel-info selection is delegated to the callback.
  Callback::Instance()->SetGraphKernelNodeKernelInfo(fused_node);
  return fused_node;
}
// Fuses `nodes` of `main_graph` into one GraphKernel node: builds a single
// sub-graph from them, creates the fused call node, rewires all users of the
// original outputs to it, and records the generated kernel name (with optional
// `postfix`) on the sub-graph. Returns the new fused node.
AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph,
const std::string &postfix) {
auto mng = main_graph->manager();
if (mng == nullptr) {
mng = Manage(main_graph, true);
main_graph->set_manager(mng);
}
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = BuildSingleGraphFromNodes(nodes);
auto fuse_new_node = CreateNewFuseCNode(main_graph, fg, inputs);
// Redirect users of the old outputs to the fused node (via TupleGetItem when
// there are multiple outputs).
ReplaceNewFuseCNode(main_graph, fuse_new_node, outputs);
auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", postfix);
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
return fuse_new_node;
}
} // namespace mindspore::graphkernel

View File

@ -18,11 +18,16 @@
#include <unordered_map>
#include <tuple>
#include <string>
#include "ir/anf.h"
namespace mindspore::graphkernel {
using AnfNodePtrToAnfNodePtrMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &lst);
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes);
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes);
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs);
AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph,
const std::string &postfix = "");
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_BUILDER_H_

View File

@ -120,6 +120,13 @@ class Callback {
*/
virtual std::string GetProcessorFromContext() = 0;
/**
* @brief Set KernelInfo for a GraphKernel node, the info is extract from its inputs/outputs.
*
* @param[in] node the GraphKernel CNode.
*/
virtual void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) = 0;
private:
friend class CallbackImplRegister;
static void RegImpl(Callback *cb) { instance_.reset(cb); }

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include <sstream>
#include "base/core_ops.h"
#include "utils/anf_utils.h"
namespace mindspore::graphkernel {
// Builds a kernel name by joining (with '_') the optional prefix, the name of
// every real-kernel CNode in `nodes` (composite nodes contribute their inner
// graph's recorded name), and the optional postfix.
std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix,
                                            const std::string &postfix) {
  std::string result;
  if (!prefix.empty()) {
    result += prefix;
    result += "_";
  }
  for (const auto &node : nodes) {
    if (AnfUtils::IsGraphKernel(node)) {
      // Composite node: reuse the inner graph's recorded kernel name.
      auto fg_flag_val = GetCNodeFuncGraph(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
      result += GetValue<std::string>(fg_flag_val);
      result += "_";
    } else if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
      result += GetCNodePrimitive(node)->name();
      result += "_";
    }
  }
  if (!postfix.empty()) {
    result += postfix;
  }
  return result;
}
// Flattens nested MakeTuple nodes in `nodes` (starting at `begin_index`) into
// one flat list; nodes that are not MakeTuple are copied through unchanged.
AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList flattened;
  for (size_t idx = begin_index; idx < nodes.size(); ++idx) {
    const auto &cur = nodes[idx];
    if (!IsPrimitiveCNode(cur, prim::kPrimMakeTuple)) {
      flattened.push_back(cur);
      continue;
    }
    // Recurse past input 0 (the MakeTuple primitive) to spread inner tuples.
    auto make_tuple = cur->cast<CNodePtr>();
    auto inner = SpreadTuples(make_tuple->inputs(), 1);
    (void)flattened.insert(flattened.end(), inner.begin(), inner.end());
  }
  return flattened;
}
} // namespace mindspore::graphkernel

View File

@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
#include <string>
#include "ir/anf.h"
#include "ir/func_graph.h"
namespace mindspore::graphkernel {
// Static helper collection shared by the graph-kernel core passes.
class GkUtils {
public:
/**
* @brief Extract kernel name from nodes, only the real kernel CNode is processed.
* Names are joined with '_'; composite (GraphKernel) nodes contribute
* their inner graph's recorded kernel name.
* @param[in] nodes The node list to extract names from
* @param[in] prefix The prefix of result name (joined with '_')
* @param[in] postfix The postfix of result name (appended as-is)
* @return The string concatenated by the names of all cnodes
*/
static std::string ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix = "",
const std::string &postfix = "");
/**
* @brief Spread the MakeTuple in node list, recursively flattening nested tuples.
* @param[in] nodes The node list, possibly containing MakeTuple nodes
* @param[in] begin_index The index to start from (earlier entries are skipped)
* @example
* input
* nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
* begin_index: 1
* output
* [b, i, j, c, d, x, y, z]
* @return std::vector<AnfNodePtr>
*/
static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
};
} // namespace mindspore::graphkernel
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_

View File

@ -29,6 +29,7 @@
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
namespace {
@ -257,10 +258,7 @@ AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, cons
auto old_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(old_cnode);
AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end());
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs);
SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs);
auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs);
return graph_kernel_node;
}

View File

@ -33,6 +33,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
@ -404,10 +405,9 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
AnfNodePtrList old_nodes;
AnfNodePtr new_node;
(void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
[this](size_t id) { return this->nodes_[id]; });
std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion");
auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
(void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
if (GraphKernelFlags::GetInstance().dump_as_text) {

View File

@ -37,6 +37,7 @@
#include "pybind_api/ir/primitive_py.h"
#include "runtime/device/kernel_info.h"
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
namespace {
@ -164,12 +165,9 @@ AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_grap
auto func_graph = old_node->func_graph();
std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
AnfNodePtrList kernel_nodes;
AnfNodePtrList outputs;
EliminateRedundantParameters(new_func_graph, &inputs);
kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs);
auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs);
MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
<< " with: " << graph_kernel_node->fullname_with_scope();
return graph_kernel_node;

View File

@ -67,37 +67,6 @@ bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
return false;
}
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
AnfNodePtrList outs;
auto out_node = fg->output();
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> output_args;
auto out_cnode = out_node->cast<CNodePtr>();
for (auto out : out_cnode->inputs()) {
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
auto inputs = out->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
output_args.push_back(inputs[i]);
}
} else {
output_args.push_back(out);
}
}
if (output_args.size() != out_cnode->inputs().size()) {
auto new_out = fg->NewCNode(output_args);
mng->Replace(out_node, new_out);
}
for (size_t i = 1; i < output_args.size(); ++i) {
outs.push_back(output_args[i]);
}
return outs;
}
outs.push_back(out_node);
return outs;
}
bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, AnfNodePtrList> &in_and_out,
const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map = nullptr) {
@ -128,100 +97,6 @@ AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
return out_spec;
}
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
MS_EXCEPTION_IF_NULL(inputs_ptr);
auto nodes = TopoSort(fg->get_return());
std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace;
for (const auto &node : nodes) {
if (!node->isa<CNode>()) {
continue;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
const auto &tnode = inputs[i];
auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
if (tensor == nullptr || tensor->DataSize() == 1) {
continue;
}
auto tensor_iter = std::find_if(
v_replace.begin(), v_replace.end(),
[&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
if (tensor_iter == v_replace.end()) {
(void)v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
} else {
tensor_iter->second.push_back(tnode);
}
}
}
if (v_replace.empty()) {
return false;
}
auto mng = fg->manager();
if (mng == nullptr) {
mng = Manage(fg, false);
fg->set_manager(mng);
}
auto &inputs = *inputs_ptr;
for (auto iter : v_replace) {
auto value_nodes = iter.second;
if (value_nodes.empty()) {
MS_LOG(EXCEPTION) << "Invalid value in map!";
}
auto vnode = value_nodes[0];
auto parameter = fg->add_parameter();
parameter->set_abstract(vnode->abstract());
parameter->set_kernel_info(vnode->kernel_info_ptr());
for (const auto &value_node : value_nodes) {
mng->Replace(value_node, parameter);
}
inputs.push_back(vnode);
}
return true;
}
// Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs.
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
AnfNodePtrList *src_outputs) {
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList outputs;
AnfNodePtrList *soutputs = (src_outputs != nullptr) ? src_outputs : &outputs;
std::tie(fg, inputs, *soutputs) = BuildGraphFromNodes(fuse_nodes);
FuncGraphManagerPtr mng = fg->manager();
if (mng == nullptr) {
mng = Manage(fg, false);
fg->set_manager(mng);
}
// Inline origin graphkernel
auto cnodes = fg->GetOrderedCnodes();
for (const auto &n : cnodes) {
if (!AnfAlgo::IsGraphKernel(n)) {
continue;
}
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
AnfNodePtrList ins;
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
mng->Replace(n, out);
}
EliminateMakeTuple(fg, mng);
ConvertNonscalarTensorToParameter(fg, &inputs);
outputs.clear();
kernel::GetFuncGraphOutputNodes(fg, &outputs);
return std::make_tuple(fg, inputs, outputs);
}
// Rebuild as node inputs or outputs have changed, processor comes from node itself
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
const std::vector<TypeId> &inputs_type,
@ -254,6 +129,7 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str
return graph_info_builder.Build();
}
// Deprecated. use Callback->SetGraphKernelNodeKernelInfo.
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs) {
std::vector<std::string> graph_input_format;
@ -309,127 +185,6 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
}
// Create a cnode in `func_graph` that calls the sub-graph `fg` with `inputs`.
// The new cnode's abstract is derived from `outputs` (a tuple abstract when there are several),
// and the sub-graph parameters' abstracts are refreshed from the real producers of `inputs`.
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                              const AnfNodePtrList &outputs) {
  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(inputs.size() + 1);
  call_inputs.push_back(NewValueNode(fg));
  call_inputs.insert(call_inputs.end(), inputs.begin(), inputs.end());
  auto fuse_cnode = func_graph->NewCNode(call_inputs);
  // Set output abstract: multiple outputs form an AbstractTuple, a single output is used directly.
  if (outputs.size() > 1) {
    std::vector<AbstractBasePtr> out_specs;
    out_specs.reserve(outputs.size());
    for (const auto &output : outputs) {
      out_specs.push_back(output->abstract());
    }
    fuse_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(out_specs));
  } else {
    fuse_cnode->set_abstract(outputs[0]->abstract());
  }
  // Set parameter abstract: follow each argument to its real kernel output and copy that abstract.
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
    fg->parameters()[i]->set_abstract(GetOutputAbstract(kernel_with_index.first, kernel_with_index.second));
  }
  return fuse_cnode;
}
// Redirect all users of the fused region's original outputs to the new fused cnode.
// A single output is replaced directly; multiple outputs are accessed through TupleGetItem
// using a flat index into the fused cnode's output tuple.
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
                         const AnfNodePtrList &outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  // single out
  if (outputs.size() == 1) {
    mng->Replace(outputs[0], new_fuse_cnode);
    return;
  }
  std::vector<AnfNodePtr> fn_inputs;
  // `offset` accumulates the extra flat output slots occupied by flattened MakeTuple outputs,
  // so `out_idx + offset` is always the flat index into the fused cnode's outputs.
  size_t offset = 0;
  for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
    AnfNodePtrList real_outs;
    // not make tuple out, replace: create a fresh TupleGetItem(new_fuse_cnode, flat index).
    if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
      fn_inputs.clear();
      fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
      fn_inputs.push_back(new_fuse_cnode);
      fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx + offset))));
      auto new_out = func_graph->NewCNode(fn_inputs);
      new_out->set_abstract(outputs[out_idx]->abstract());
      mng->Replace(outputs[out_idx], new_out);
      continue;
    }
    // the out is make tuple , modify the get_item node's value: each existing TupleGetItem user
    // keeps its relative index, shifted by this output's flat base index.
    auto users = mng->node_users()[outputs[out_idx]];
    for (auto &user : users) {
      auto use_node = user.first;
      // Only TupleGetItem users need rewriting; other users are left untouched.
      if (!use_node->isa<CNode>() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) {
        continue;
      }
      auto get_item_cnode = use_node->cast<CNodePtr>();
      auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(value_input);
      auto value_node = value_input->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto item_idx = GetValue<int64_t>(value_node->value());
      // New flat index = base of this output + the original item index inside the tuple.
      int64_t new_item_idx = SizeToLong(out_idx + offset) + item_idx;
      fn_inputs.clear();
      fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
      fn_inputs.push_back(new_fuse_cnode);
      fn_inputs.push_back(NewValueNode(new_item_idx));
      auto new_out = func_graph->NewCNode(fn_inputs);
      new_out->set_abstract(get_item_cnode->abstract());
      mng->Replace(get_item_cnode, new_out);
    }
    // A flattened tuple occupies real_outs.size() flat slots instead of one.
    offset += real_outs.size() - 1;
  }
}
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const FuncGraphPtr &kernel_graph,
const std::string &postfix) {
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
FuncGraphPtr fg;
AnfNodePtrList inputs;
AnfNodePtrList src_outputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs);
// Handle get-item probleam.
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs);
// set graphKernel attr
std::string fuse_op_name = "";
for (auto &fuse_node : fuse_nodes) {
if (IsPrimitiveCNode(fuse_node)) {
fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
} else if (AnfAlgo::IsGraphKernel(fuse_node)) {
auto fuse_cnode = fuse_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(fuse_cnode);
auto graph_kernel_fg = GetValueNode<FuncGraphPtr>(fuse_cnode->input(kAnfPrimitiveIndex));
auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
auto fuse_fg_name = GetValue<std::string>(fg_flag_val);
fuse_op_name += fuse_fg_name + "_";
}
}
fuse_op_name += postfix;
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
return std::make_tuple(fuse_new_node, src_outputs);
}
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map) {
MS_EXCEPTION_IF_NULL(op_desc);
@ -466,15 +221,14 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
}
FuncGraphPtr fg;
AnfNodePtrList op_nodes, inputs, outputs;
if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) {
fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
} else {
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(nodes);
inputs.clear();
outputs.clear();
std::tie(fg, std::ignore, std::ignore) = BuildSingleGraphFromNodes(nodes);
}
AnfNodePtrList op_nodes, inputs, outputs;
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
auto mng = fg->manager();
@ -524,22 +278,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc) {
return fg;
}
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) {
std::stringstream name;
if (prefix != "") {
name << prefix << "_";
}
for (const auto &node : cnodes) {
if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
name << AnfAlgo::GetCNodeName(node) << "_";
}
}
if (postfix != "") {
name << postfix;
}
return name.str();
}
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);

View File

@ -53,9 +53,6 @@ struct DataInfo {
TypePtr type{nullptr};
};
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
AnfNodePtrList *src_outputs = nullptr);
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs);
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
@ -66,19 +63,11 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str
const std::vector<TypeId> &inputs_type,
const std::vector<std::string> &output_formats,
const std::vector<TypeId> &output_types);
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
const AnfNodePtrList &outputs);
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs);
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const FuncGraphPtr &kernel_graph,
const std::string &postfix = "");
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map);
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
FuncGraphPtr JsonDescToAnf(const std::string &json_desc);
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
std::string GetFormat(const AnfNodePtr &node);

View File

@ -27,6 +27,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "debug/anf_ir_dump.h"
#include "utils/context/graph_kernel_flags.h"
@ -632,7 +633,7 @@ class Splitter {
graph_manager->AddFuncGraph(sub_func_graph);
// set GraphKernel attr
auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
auto attr = GkUtils::ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));
// set kernel info

View File

@ -28,6 +28,7 @@
#include "frontend/operator/ops.h"
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
namespace {
@ -746,8 +747,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
}
changed = true;
SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
AnfNodePtr sg_node;
std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel");
AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
DumpParallelFusionDetail(fuse_nodes, sg_node);
}

View File

@ -31,6 +31,7 @@
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
#include "backend/optimizer/graph_kernel/model/op_register.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
namespace mindspore::graphkernel {
namespace {
@ -438,13 +439,11 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
if (Process(litegraph)) {
changed = true;
AnfNodePtrList outputs;
auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
auto new_funcgraph = LiteGraph2AnfGraph(litegraph);
new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto cnode = node->cast<CNodePtr>();
AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs);
(void)mng->Replace(node, new_node);
mng->AddFuncGraph(new_funcgraph);
}

View File

@ -35,6 +35,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
namespace mindspore::graphkernel {
class TsaChecker : public AtomicAddChecker {
@ -133,7 +134,7 @@ AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &
auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input});
new_composite_node->set_abstract(identity_node->abstract());
SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node});
auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));
@ -198,7 +199,8 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n
CorrectKernelBuildInfo(composite_node, outter_node);
auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
auto new_graph_name =
GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name;
}

View File

@ -23,6 +23,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
namespace mindspore::graphkernel {
@ -34,21 +35,6 @@ AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
return result;
}
// Flatten MakeTuple nodes in `nodes` (from `begin_index` on) into a single list,
// recursively expanding nested tuples.
// e.g. [a, MakeTuple[x, MakeTuple[y, z]], b] with begin_index 0 -> [a, x, y, z, b]
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList spread_nodes;
  for (size_t idx = begin_index; idx < nodes.size(); ++idx) {
    const auto &node = nodes[idx];
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      spread_nodes.push_back(node);
      continue;
    }
    // A MakeTuple is replaced by its recursively flattened inputs; input 0 is the primitive itself.
    auto make_tuple = node->cast<CNodePtr>();
    auto inner = SpreadTuples(make_tuple->inputs(), 1);
    spread_nodes.insert(spread_nodes.end(), inner.begin(), inner.end());
  }
  return spread_nodes;
}
AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
const FuncGraphPtr &func_graph) {
AnfNodePtrList result;
@ -85,7 +71,7 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput) continue;
auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
auto inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
// extend inputs of UpdateState if which have multiple outputs
inputs = ExtendInputsOfUpdateState(inputs, func_graph);
if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
@ -110,7 +96,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= kUpdateStateRealInput + 1) continue;
AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
AnfNodePtrList mt_inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
AbstractBasePtrList abs_list;
std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
[](const AnfNodePtr &inp) { return inp->abstract(); });

View File

@ -61,20 +61,6 @@ class ShrinkUpdateState : public opt::Pass {
bool Run(const FuncGraphPtr &func_graph) override;
};
/**
* @brief Spread the MakeTuple in node list
* @param nodes
* @param begin_index
* @example
* input
* nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
* begin_index: 1
* output
* [b, i, j, c, d, x, y, z]
* @return std::vector<AnfNodePtr>
*/
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
/**
* @brief Extend the getitem for UpdateState
* @example