forked from mindspore-Ecosystem/mindspore
!26515 Refactor FuseNodesToSubGraph and decouple it from graph_kernel_helper
Merge pull request !26515 from DeshiChen/1111_fusenodes
commit 511441a27e
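The refactor is easiest to read as a call-site migration. The fragment below is an illustrative sketch, not part of the commit: it assumes a pass that already holds the fusable nodes in `old_nodes` and the enclosing `func_graph`, and it mirrors the GraphKernelCluster change further down in this diff.

// Before this commit, call sites used the coupled helper:
//   std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion");
// After this commit, the device-independent core builds and replaces the GraphKernel node in one call:
auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");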
@ -17,7 +17,9 @@
#include "backend/optimizer/graph_kernel/adapter/callback_impl.h"

#include <algorithm>
#include <string>
#include <vector>
#include <utility>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
@ -76,4 +78,62 @@ std::string CallbackImpl::GetOutputFormat(const AnfNodePtr &node, size_t i) {
std::string CallbackImpl::GetProcessor(const AnfNodePtr &node) { return kernel::GetProcessorStr(node); }

std::string CallbackImpl::GetProcessorFromContext() { return kernel::GetStrProcessorFromContext(); }

void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
  std::vector<std::string> graph_input_format;
  std::vector<TypeId> graph_input_type;
  std::vector<std::string> graph_output_format;
  std::vector<TypeId> graph_output_type;
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto fg = GetCNodeFuncGraph(node);
  MS_EXCEPTION_IF_NULL(fg);
  auto &inputs = cnode->inputs();
  // Collect the format/dtype of every real input and set kernel info on the matching subgraph parameter.
  for (size_t i = 1; i < inputs.size(); ++i) {
    auto kernel_with_index = AnfUtils::VisitKernel(inputs[i], 0);
    if (kernel_with_index.first->isa<ValueNode>()) {
      auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
      MS_EXCEPTION_IF_NULL(tensor);
      (void)graph_input_format.emplace_back(kOpFormat_DEFAULT);
      (void)graph_input_type.emplace_back(tensor->data_type());
    } else {
      auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
      (void)graph_input_format.emplace_back(std::move(input_format));
      auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
      (void)graph_input_type.emplace_back(input_type);
    }
    fg->parameters()[i - 1]->set_kernel_info(std::make_shared<device::KernelInfo>());
    kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
    para_info_builder.SetOutputsFormat({graph_input_format.back()});
    para_info_builder.SetOutputsDeviceType({graph_input_type.back()});
    para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
    para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
    AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i - 1].get());
  }
  // Collect the format/dtype of every output of the subgraph.
  AnfNodePtrList outputs;
  if (IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
    auto fg_output = fg->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(fg_output);
    outputs.assign(fg_output->inputs().begin() + 1, fg_output->inputs().end());
  } else {
    outputs.push_back(fg->output());
  }
  for (size_t i = 0; i < outputs.size(); ++i) {
    auto kernel_with_index = AnfAlgo::VisitKernel(outputs[i], 0);
    auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
    graph_output_format.push_back(output_format);
    graph_output_type.push_back(output_type);
  }
  // Build the KernelBuildInfo of the whole GraphKernel node from the collected formats and types.
  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
  graph_info_builder.SetInputsFormat(graph_input_format);
  graph_info_builder.SetInputsDeviceType(graph_input_type);
  graph_info_builder.SetOutputsFormat(graph_output_format);
  graph_info_builder.SetOutputsDeviceType(graph_output_type);
  graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto graph_selected_info = graph_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, node.get());
}
}  // namespace mindspore::graphkernel
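For context, a hedged sketch of how this implementation is reached: device-independent passes no longer call SetNewKernelInfo from graph_kernel_helper directly, they go through the Callback singleton, which dispatches to the backend-registered CallbackImpl above (see CreateNewFuseCNode later in this diff). The variable names here are illustrative only.

// fuse_cnode: a GraphKernel CNode whose input(0) is the fused sub-FuncGraph.
// Deprecated path: SetNewKernelInfo(fuse_cnode, sub_fg, inputs, outputs);
Callback::Instance()->SetGraphKernelNodeKernelInfo(fuse_cnode);
// Formats and device types are derived from the node's real inputs and the subgraph outputs.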
@ -34,6 +34,7 @@ class CallbackImpl : public Callback {
  std::string GetOutputFormat(const AnfNodePtr &node, size_t i) override;
  std::string GetProcessor(const AnfNodePtr &node) override;
  std::string GetProcessorFromContext() override;
  void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) override;
};
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADAPTER_CALLBACK_IMPL_H_
@ -29,6 +29,7 @@
#include "utils/log_adapter.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "debug/anf_ir_dump.h"
@ -360,7 +361,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
  CorrectKernelBuildInfo(composite_node, new_input);

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}
@ -435,7 +436,7 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
  auto broadcast_to_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph)});
  broadcast_to_composite_node->set_abstract(broadcast_to_node_inner->abstract());
  SetNewKernelInfo(broadcast_to_composite_node, new_sub_graph, {}, {broadcast_to_node_inner});
  auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
  auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
  new_sub_graph->set_attr("composite_type", MakeValue("atomic_clean"));
@ -25,6 +25,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
@ -82,7 +83,7 @@ void StitchAtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_
  }

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  auto new_graph_name = GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "atomic_add");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to atomic add graph " << new_graph_name;
}
@ -25,6 +25,7 @@
#include <unordered_map>

#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/anf.h"
#include "utils/context/graph_kernel_flags.h"
@ -645,14 +646,12 @@ bool ArithmeticSimplify::Run(const FuncGraphPtr &func_graph) {
  }
  if (!change_anf_graph) continue;
  ReorganizeEmptyGraph(lg);
  AnfNodePtrList outputs;
  auto new_funcgraph = LiteGraph2AnfGraph(lg, &outputs);
  auto new_funcgraph = LiteGraph2AnfGraph(lg);
  new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto cnode = node->cast<CNodePtr>();
  AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
  EliminateRedundantParameters(new_funcgraph, &inputs);
  auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs, outputs);
  SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
  auto new_node = CreateNewFuseCNode(func_graph, new_funcgraph, inputs);
  mng->Replace(node, new_node);
  mng->AddFuncGraph(new_funcgraph);
  do_simplify = true;
@ -25,29 +25,27 @@
|
|||
#include "base/core_ops.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_kernel_callback.h"
|
||||
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
|
||||
const std::unordered_set<AnfNodePtr> &seen) {
|
||||
namespace {
|
||||
// find outputs of nodes
|
||||
AnfNodePtrList FindOutputs(const AnfNodePtrList &nodes, const AnfNodePtrToAnfNodePtrMap &eqv) {
|
||||
AnfNodePtrList output;
|
||||
if (users.size() == 0) {
|
||||
return output;
|
||||
}
|
||||
auto mng = nodes[0]->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto &users = mng->node_users();
|
||||
for (auto &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
// only CNode can be an output.
|
||||
if (!node->isa<CNode>()) continue;
|
||||
auto iter = users.find(node);
|
||||
if (iter == users.end()) {
|
||||
continue;
|
||||
}
|
||||
if (iter == users.end()) continue;
|
||||
auto &node_users = iter->second;
|
||||
const bool has_outer_user = std::any_of(std::begin(node_users), std::end(node_users),
|
||||
[&seen](const std::pair<AnfNodePtr, int64_t> &u) -> bool {
|
||||
const bool is_outer_user = (seen.find(u.first) == seen.end());
|
||||
return is_outer_user;
|
||||
});
|
||||
if (has_outer_user) {
|
||||
// if any user of the `node` is not in the nodes list, the `node` is an output.
|
||||
if (std::any_of(std::begin(node_users), std::end(node_users),
|
||||
[&eqv](const std::pair<AnfNodePtr, int> &u) { return eqv.find(u.first) == eqv.end(); })) {
|
||||
output.emplace_back(node);
|
||||
}
|
||||
}
|
||||
|
@ -56,12 +54,11 @@ AnfNodePtrList GetOutput(const AnfNodePtrList &nodes, const NodeUsersMap &users,
|
|||
|
||||
AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNodePtrList *const inputs_ptr,
|
||||
AnfNodePtrToAnfNodePtrMap *eqv_ptr) {
|
||||
auto &input_list = *inputs_ptr;
|
||||
auto &eqv = *eqv_ptr;
|
||||
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
|
||||
eqv[node] = node;
|
||||
} else if (eqv.find(node) == eqv.end()) {
|
||||
input_list.push_back(node);
|
||||
inputs_ptr->push_back(node);
|
||||
eqv[node] = fg->add_parameter();
|
||||
eqv[node]->set_abstract(node->abstract());
|
||||
eqv[node]->set_kernel_info(node->kernel_info_ptr());
|
||||
|
@ -69,43 +66,150 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
|
|||
return eqv[node];
|
||||
}
|
||||
|
||||
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &node_list) {
|
||||
bool InlineInnerFuncGraph(const FuncGraphPtr &fg) {
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = false;
|
||||
auto cnodes = fg->GetOrderedCnodes();
|
||||
for (const auto &n : cnodes) {
|
||||
auto graph_kernel_g = GetCNodeFuncGraph(n);
|
||||
if (graph_kernel_g == nullptr) continue;
|
||||
AnfNodePtrList inp(n->inputs().begin() + 1, n->inputs().end());
|
||||
auto out = InlineClone(graph_kernel_g, fg, inp, n->input(0)->scope());
|
||||
mng->Replace(n, out);
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
void EliminateMakeTuple(const FuncGraphPtr &fg) {
|
||||
if (!IsPrimitiveCNode(fg->output(), prim::kPrimMakeTuple)) {
|
||||
return;
|
||||
}
|
||||
auto out_cnode = fg->output()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_cnode);
|
||||
AnfNodePtrList new_args = GkUtils::SpreadTuples(out_cnode->inputs());
|
||||
if (new_args.size() != out_cnode->size()) {
|
||||
auto new_out = fg->NewCNode(new_args);
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
mng->Replace(out_cnode, new_out);
|
||||
}
|
||||
AbstractBasePtrList abs_list;
|
||||
std::transform(new_args.begin() + 1, new_args.end(), std::back_inserter(abs_list),
|
||||
[](const AnfNodePtr &node) { return node->abstract(); });
|
||||
fg->output()->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
|
||||
}
|
||||
|
||||
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
|
||||
auto cnodes = fg->GetOrderedCnodes();
|
||||
AnfNodePtrList value_nodes;
|
||||
for (const auto &cnode : cnodes) {
|
||||
auto &inputs = cnode->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
const auto &tnode = inputs[i];
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
|
||||
if (tensor == nullptr || tensor->DataSize() == 1) {
|
||||
continue;
|
||||
}
|
||||
value_nodes.push_back(tnode);
|
||||
}
|
||||
}
|
||||
if (value_nodes.empty()) return false;
|
||||
auto mng = fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
for (auto &vnode : value_nodes) {
|
||||
auto parameter = fg->add_parameter();
|
||||
parameter->set_abstract(vnode->abstract());
|
||||
parameter->set_kernel_info(vnode->kernel_info_ptr());
|
||||
mng->Replace(vnode, parameter);
|
||||
inputs_ptr->push_back(vnode);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsTupleOutput(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
|
||||
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
|
||||
auto &inputs = out->cast<CNodePtr>()->inputs();
|
||||
real_outs->assign(inputs.begin() + 1, inputs.end());
|
||||
return true;
|
||||
}
|
||||
if (auto fg = GetCNodeFuncGraph(out); fg != nullptr) {
|
||||
return IsTupleOutput(fg->output(), real_outs);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
|
||||
const AnfNodePtrList &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
// single out
|
||||
if (outputs.size() == 1) {
|
||||
mng->Replace(outputs[0], new_fuse_cnode);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t offset = 0;
|
||||
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
|
||||
AnfNodePtrList real_outs;
|
||||
// the output is a single tensor
|
||||
if (!IsTupleOutput(outputs[out_idx], &real_outs)) {
|
||||
auto gt_idx = MakeValue(SizeToLong(out_idx + offset));
|
||||
AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
|
||||
gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
|
||||
auto new_out = func_graph->NewCNode(gt_inputs);
|
||||
new_out->set_abstract(outputs[out_idx]->abstract());
|
||||
mng->Replace(outputs[out_idx], new_out);
|
||||
continue;
|
||||
}
|
||||
|
||||
// the out is make tuple , modify the get_item node's value
|
||||
auto users = mng->node_users()[outputs[out_idx]]; // use a copy, the original user map is changed in for-loop.
|
||||
for (auto &user : users) {
|
||||
auto getitem_node = user.first;
|
||||
if (!getitem_node->isa<CNode>() || !IsPrimitiveCNode(getitem_node, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
auto value_ptr = GetValueNode(getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem));
|
||||
MS_EXCEPTION_IF_NULL(value_ptr);
|
||||
auto old_gt_idx = GetValue<int64_t>(value_ptr);
|
||||
auto gt_idx = MakeValue(SizeToLong(out_idx + offset) + old_gt_idx);
|
||||
AnfNodePtrList gt_inputs{NewValueNode(prim::kPrimTupleGetItem), new_fuse_cnode, NewValueNode(gt_idx)};
|
||||
gt_inputs.back()->set_abstract(gt_idx->ToAbstract());
|
||||
auto new_getitem_node = func_graph->NewCNode(gt_inputs);
|
||||
new_getitem_node->set_abstract(getitem_node->abstract());
|
||||
mng->Replace(getitem_node, new_getitem_node);
|
||||
}
|
||||
|
||||
offset += real_outs.size() - 1;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes) {
|
||||
FuncGraphPtr fg = nullptr;
|
||||
{
|
||||
// limit the lifetime of guard.
|
||||
TraceGuard guard(
|
||||
std::make_shared<TraceSegmentTransform>(node_list[0]->cast<CNodePtr>()->func_graph()->debug_info()));
|
||||
TraceGuard guard(std::make_shared<TraceSegmentTransform>(nodes[0]->cast<CNodePtr>()->func_graph()->debug_info()));
|
||||
fg = std::make_shared<FuncGraph>();
|
||||
}
|
||||
AnfNodePtrList input_list;
|
||||
AnfNodePtrToAnfNodePtrMap eqv;
|
||||
// Merge CNodes into a AnfGraph that represents a linear instruction segment
|
||||
for (auto node : node_list) {
|
||||
auto &input_nodes = node->cast<CNodePtr>()->inputs();
|
||||
auto fn = input_nodes[0];
|
||||
std::vector<AnfNodePtr> new_args{fn};
|
||||
if (IsPrimitive(fn, prim::kPrimDepend) && input_nodes.size() >= kDependInputSize &&
|
||||
eqv.find(input_nodes[kDependAttachNodeIndex]) == eqv.end()) {
|
||||
new_args.emplace_back(RefSubGraphNode(fg, input_nodes[kRealInputIndexInDepend], &input_list, &eqv));
|
||||
const size_t value_start_index = 2;
|
||||
for (size_t i = value_start_index; i < input_nodes.size(); ++i) {
|
||||
new_args.emplace_back(NewValueNode(MakeValue(0)));
|
||||
}
|
||||
} else {
|
||||
(void)std::transform(
|
||||
std::begin(input_nodes) + 1, std::end(input_nodes), std::back_inserter(new_args),
|
||||
[&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); });
|
||||
}
|
||||
for (auto &node : nodes) {
|
||||
auto &node_inputs = node->cast<CNodePtr>()->inputs();
|
||||
std::vector<AnfNodePtr> new_args{node_inputs[0]};
|
||||
(void)std::transform(
|
||||
std::begin(node_inputs) + 1, std::end(node_inputs), std::back_inserter(new_args),
|
||||
[&fg, &input_list, &eqv](const AnfNodePtr &node) { return RefSubGraphNode(fg, node, &input_list, &eqv); });
|
||||
TraceGuard tg(std::make_shared<TraceSegmentTransform>(node->debug_info()));
|
||||
eqv[node] = fg->NewCNode(new_args);
|
||||
eqv[node]->set_abstract(node->abstract());
|
||||
eqv[node]->set_kernel_info(node->kernel_info_ptr());
|
||||
}
|
||||
std::unordered_set<AnfNodePtr> eqv_keys;
|
||||
(void)std::transform(std::begin(eqv), std::end(eqv), std::inserter(eqv_keys, eqv_keys.end()),
|
||||
[](const std::pair<AnfNodePtr, AnfNodePtr> &elem) -> AnfNodePtr { return elem.first; });
|
||||
auto mgr = node_list[0]->func_graph()->manager();
|
||||
auto outputs = GetOutput(node_list, mgr->node_users(), eqv_keys);
|
||||
auto outputs = FindOutputs(nodes, eqv);
|
||||
AnfNodePtr fg_output;
|
||||
if (outputs.size() > 1) {
|
||||
std::vector<AnfNodePtr> output_args;
|
||||
|
@ -120,4 +224,52 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(con
  fg->set_output(fg_output);
  return std::make_tuple(fg, input_list, outputs);
}

// Transform nodes (including basic and composite nodes) to a new graph, and collect their inputs and outputs.
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes) {
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = BuildGraphFromNodes(nodes);

  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    mng = Manage(fg, false);
    fg->set_manager(mng);
  }

  InlineInnerFuncGraph(fg);
  // eliminate tuple of tuple, and set Abstract for output MakeTuple
  EliminateMakeTuple(fg);
  ConvertNonscalarTensorToParameter(fg, &inputs);

  return std::make_tuple(fg, inputs, outputs);
}

AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) {
  std::vector<AnfNodePtr> fn_inputs{NewValueNode(sub_fg)};
  fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
  auto fuse_cnode = main_fg->NewCNode(fn_inputs);
  fuse_cnode->set_abstract(sub_fg->output()->abstract());
  Callback::Instance()->SetGraphKernelNodeKernelInfo(fuse_cnode);
  return fuse_cnode;
}

AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph,
                                           const std::string &postfix) {
  auto mng = main_graph->manager();
  if (mng == nullptr) {
    mng = Manage(main_graph, true);
    main_graph->set_manager(mng);
  }
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = BuildSingleGraphFromNodes(nodes);
  auto fuse_new_node = CreateNewFuseCNode(main_graph, fg, inputs);
  ReplaceNewFuseCNode(main_graph, fuse_new_node, outputs);
  auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", postfix);
  fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
  return fuse_new_node;
}
}  // namespace mindspore::graphkernel
@ -18,11 +18,16 @@
#include <unordered_map>
#include <tuple>
#include <string>
#include "ir/anf.h"

namespace mindspore::graphkernel {
using AnfNodePtrToAnfNodePtrMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;

std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &lst);
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildGraphFromNodes(const AnfNodePtrList &nodes);
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> BuildSingleGraphFromNodes(const AnfNodePtrList &nodes);
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs);
AnfNodePtr ReplaceNodesWithGraphKernelNode(const AnfNodePtrList &nodes, const FuncGraphPtr &main_graph,
                                           const std::string &postfix = "");
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_BUILDER_H_
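A rough usage sketch of the slimmed-down builder API declared above (the names old_cnode, main_fg and sub_fg are assumptions for illustration); the explicit outputs list and the follow-up SetNewKernelInfo call are gone, matching the call sites elsewhere in this diff.

AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end());
auto graph_kernel_node = CreateNewFuseCNode(main_fg, sub_fg, inputs);
// The abstract comes from sub_fg->output(); kernel info is filled in through the Callback internally.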
@ -120,6 +120,13 @@ class Callback {
   */
  virtual std::string GetProcessorFromContext() = 0;

  /**
   * @brief Set KernelInfo for a GraphKernel node; the info is extracted from its inputs/outputs.
   *
   * @param[in] node the GraphKernel CNode.
   */
  virtual void SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) = 0;

 private:
  friend class CallbackImplRegister;
  static void RegImpl(Callback *cb) { instance_.reset(cb); }
@ -0,0 +1,56 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include <sstream>
#include "base/core_ops.h"
#include "utils/anf_utils.h"

namespace mindspore::graphkernel {
std::string GkUtils::ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix,
                                            const std::string &postfix) {
  std::stringstream name;
  if (!prefix.empty()) {
    name << prefix << "_";
  }
  for (const auto &node : nodes) {
    if (AnfUtils::IsGraphKernel(node)) {
      auto fg_flag_val = GetCNodeFuncGraph(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
      name << GetValue<std::string>(fg_flag_val) << "_";
    } else if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
      name << GetCNodePrimitive(node)->name() << "_";
    }
  }
  if (!postfix.empty()) {
    name << postfix;
  }
  return name.str();
}

AnfNodePtrList GkUtils::SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList result;
  for (size_t i = begin_index; i < nodes.size(); i++) {
    if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) {
      auto mt = nodes[i]->cast<CNodePtr>();
      // recursively spread all inner tuples.
      auto mt_inputs = SpreadTuples(mt->inputs(), 1);
      result.insert(result.end(), mt_inputs.begin(), mt_inputs.end());
    } else {
      result.push_back(nodes[i]);
    }
  }
  return result;
}
}  // namespace mindspore::graphkernel
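An illustrative call (not part of the new file) showing how the relocated helper names a fused subgraph; the splitter change later in this diff uses the same pattern on a sub_func_graph it has just created.

auto attr = GkUtils::ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));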
@ -0,0 +1,52 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_

#include <string>
#include "ir/anf.h"
#include "ir/func_graph.h"

namespace mindspore::graphkernel {
class GkUtils {
 public:
  /**
   * @brief Extract kernel name from nodes, only the real kernel CNode is processed.
   * @param[in] nodes The node list
   * @param[in] prefix The prefix of result name
   * @param[in] postfix The postfix of result name
   * @return The string concatenated by the names of all cnodes
   */
  static std::string ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix = "",
                                            const std::string &postfix = "");

  /**
   * @brief Spread the MakeTuple in node list
   * @param[in] nodes
   * @param[in] begin_index
   * @example
   *   input
   *     nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
   *     begin_index: 1
   *   output
   *     [b, i, j, c, d, x, y, z]
   * @return std::vector<AnfNodePtr>
   */
  static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
};
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
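A small hypothetical illustration of the documented SpreadTuples behaviour; it assumes a FuncGraph fg and nodes a, b, c that already exist, and only shows how nested MakeTuple inputs are flattened.

// nodes: [ a, MakeTuple(b, MakeTuple(c)) ]  ->  SpreadTuples(nodes): [ a, b, c ]
auto inner = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), c});
auto outer = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), b, inner});
AnfNodePtrList nodes{a, outer};
auto flat = GkUtils::SpreadTuples(nodes);  // {a, b, c}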
@ -29,6 +29,7 @@
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"

namespace mindspore::graphkernel {
namespace {
@ -257,10 +258,7 @@ AnfNodePtr EliminateHangingOutput::ReplaceMakeTuple(const AnfNodePtr &node, cons
  auto old_cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(old_cnode);
  AnfNodePtrList inputs(old_cnode->inputs().begin() + 1, old_cnode->inputs().end());
  AnfNodePtrList outputs;
  kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
  auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs, outputs);
  SetNewKernelInfo(graph_kernel_node, func_graph, inputs, outputs);
  auto graph_kernel_node = CreateNewFuseCNode(node->func_graph(), func_graph, inputs);
  return graph_kernel_node;
}
@ -33,6 +33,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"

namespace mindspore::graphkernel {
std::vector<PrimitivePtr> GraphKernelCluster::GetClusterableOpList() {
@ -404,10 +405,9 @@ bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {

void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
  AnfNodePtrList old_nodes;
  AnfNodePtr new_node;
  (void)std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
                       [this](size_t id) { return this->nodes_[id]; });
  std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion");
  auto new_node = ReplaceNodesWithGraphKernelNode(old_nodes, func_graph, "fusion");
  std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
  (void)eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
  if (GraphKernelFlags::GetInstance().dump_as_text) {
@ -37,6 +37,7 @@
#include "pybind_api/ir/primitive_py.h"
#include "runtime/device/kernel_info.h"
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"

namespace mindspore::graphkernel {
namespace {
@ -164,12 +165,9 @@ AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_grap
  auto func_graph = old_node->func_graph();
  std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
  AnfNodePtrList kernel_nodes;
  AnfNodePtrList outputs;
  EliminateRedundantParameters(new_func_graph, &inputs);
  kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
  kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
  auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
  SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs);
  auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs);
  MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
                << " with: " << graph_kernel_node->fullname_with_scope();
  return graph_kernel_node;
@ -67,37 +67,6 @@ bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
|
|||
return false;
|
||||
}
|
||||
|
||||
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
|
||||
AnfNodePtrList outs;
|
||||
auto out_node = fg->output();
|
||||
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
|
||||
std::vector<AnfNodePtr> output_args;
|
||||
auto out_cnode = out_node->cast<CNodePtr>();
|
||||
for (auto out : out_cnode->inputs()) {
|
||||
if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
|
||||
auto inputs = out->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
output_args.push_back(inputs[i]);
|
||||
}
|
||||
} else {
|
||||
output_args.push_back(out);
|
||||
}
|
||||
}
|
||||
if (output_args.size() != out_cnode->inputs().size()) {
|
||||
auto new_out = fg->NewCNode(output_args);
|
||||
mng->Replace(out_node, new_out);
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < output_args.size(); ++i) {
|
||||
outs.push_back(output_args[i]);
|
||||
}
|
||||
return outs;
|
||||
}
|
||||
|
||||
outs.push_back(out_node);
|
||||
return outs;
|
||||
}
|
||||
|
||||
bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, AnfNodePtrList> &in_and_out,
|
||||
const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||
std::map<std::string, AnfNodePtr> *address_node_map = nullptr) {
|
||||
|
@ -128,100 +97,6 @@ AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
|
|||
return out_spec;
|
||||
}
|
||||
|
||||
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
|
||||
MS_EXCEPTION_IF_NULL(inputs_ptr);
|
||||
auto nodes = TopoSort(fg->get_return());
|
||||
|
||||
std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace;
|
||||
for (const auto &node : nodes) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
const auto &tnode = inputs[i];
|
||||
auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
|
||||
if (tensor == nullptr || tensor->DataSize() == 1) {
|
||||
continue;
|
||||
}
|
||||
auto tensor_iter = std::find_if(
|
||||
v_replace.begin(), v_replace.end(),
|
||||
[&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
|
||||
if (tensor_iter == v_replace.end()) {
|
||||
(void)v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
|
||||
} else {
|
||||
tensor_iter->second.push_back(tnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (v_replace.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto mng = fg->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(fg, false);
|
||||
fg->set_manager(mng);
|
||||
}
|
||||
|
||||
auto &inputs = *inputs_ptr;
|
||||
for (auto iter : v_replace) {
|
||||
auto value_nodes = iter.second;
|
||||
if (value_nodes.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid value in map!";
|
||||
}
|
||||
|
||||
auto vnode = value_nodes[0];
|
||||
auto parameter = fg->add_parameter();
|
||||
parameter->set_abstract(vnode->abstract());
|
||||
parameter->set_kernel_info(vnode->kernel_info_ptr());
|
||||
for (const auto &value_node : value_nodes) {
|
||||
mng->Replace(value_node, parameter);
|
||||
}
|
||||
|
||||
inputs.push_back(vnode);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs.
|
||||
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
|
||||
AnfNodePtrList *src_outputs) {
|
||||
FuncGraphPtr fg;
|
||||
AnfNodePtrList inputs;
|
||||
AnfNodePtrList outputs;
|
||||
AnfNodePtrList *soutputs = (src_outputs != nullptr) ? src_outputs : &outputs;
|
||||
std::tie(fg, inputs, *soutputs) = BuildGraphFromNodes(fuse_nodes);
|
||||
|
||||
FuncGraphManagerPtr mng = fg->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(fg, false);
|
||||
fg->set_manager(mng);
|
||||
}
|
||||
|
||||
// Inline origin graphkernel
|
||||
auto cnodes = fg->GetOrderedCnodes();
|
||||
for (const auto &n : cnodes) {
|
||||
if (!AnfAlgo::IsGraphKernel(n)) {
|
||||
continue;
|
||||
}
|
||||
auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
|
||||
AnfNodePtrList ins;
|
||||
ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
|
||||
auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
|
||||
mng->Replace(n, out);
|
||||
}
|
||||
|
||||
EliminateMakeTuple(fg, mng);
|
||||
ConvertNonscalarTensorToParameter(fg, &inputs);
|
||||
|
||||
outputs.clear();
|
||||
kernel::GetFuncGraphOutputNodes(fg, &outputs);
|
||||
return std::make_tuple(fg, inputs, outputs);
|
||||
}
|
||||
|
||||
// Rebuild as node inputs or outputs have changed, processor comes from node itself
|
||||
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
|
||||
const std::vector<TypeId> &inputs_type,
|
||||
|
@ -254,6 +129,7 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str
|
|||
return graph_info_builder.Build();
|
||||
}
|
||||
|
||||
// Deprecated. use Callback->SetGraphKernelNodeKernelInfo.
|
||||
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||
const AnfNodePtrList &outputs) {
|
||||
std::vector<std::string> graph_input_format;
|
||||
|
@ -309,127 +185,6 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
|
|||
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
|
||||
}
|
||||
|
||||
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
|
||||
const AnfNodePtrList &outputs) {
|
||||
auto func_node = NewValueNode(fg);
|
||||
std::vector<AnfNodePtr> fn_inputs;
|
||||
fn_inputs.push_back(func_node);
|
||||
fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
|
||||
auto fuse_cnode = func_graph->NewCNode(fn_inputs);
|
||||
// Set output abstract
|
||||
if (outputs.size() > 1) {
|
||||
std::vector<AbstractBasePtr> out_specs;
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
out_specs.push_back(outputs[i]->abstract());
|
||||
}
|
||||
auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
|
||||
fuse_cnode->set_abstract(out_spec);
|
||||
} else {
|
||||
fuse_cnode->set_abstract(outputs[0]->abstract());
|
||||
}
|
||||
// Set parameter abstract.
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
|
||||
auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
|
||||
fg->parameters()[i]->set_abstract(input_abs);
|
||||
}
|
||||
return fuse_cnode;
|
||||
}
|
||||
|
||||
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
|
||||
const AnfNodePtrList &outputs) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
// single out
|
||||
if (outputs.size() == 1) {
|
||||
mng->Replace(outputs[0], new_fuse_cnode);
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> fn_inputs;
|
||||
size_t offset = 0;
|
||||
for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
|
||||
AnfNodePtrList real_outs;
|
||||
// not make tuple out, replace
|
||||
if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
|
||||
fn_inputs.clear();
|
||||
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||
fn_inputs.push_back(new_fuse_cnode);
|
||||
fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx + offset))));
|
||||
auto new_out = func_graph->NewCNode(fn_inputs);
|
||||
new_out->set_abstract(outputs[out_idx]->abstract());
|
||||
mng->Replace(outputs[out_idx], new_out);
|
||||
continue;
|
||||
}
|
||||
|
||||
// the out is make tuple , modify the get_item node's value
|
||||
auto users = mng->node_users()[outputs[out_idx]];
|
||||
for (auto &user : users) {
|
||||
auto use_node = user.first;
|
||||
if (!use_node->isa<CNode>() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
auto get_item_cnode = use_node->cast<CNodePtr>();
|
||||
auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(value_input);
|
||||
auto value_node = value_input->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto item_idx = GetValue<int64_t>(value_node->value());
|
||||
int64_t new_item_idx = SizeToLong(out_idx + offset) + item_idx;
|
||||
fn_inputs.clear();
|
||||
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
|
||||
fn_inputs.push_back(new_fuse_cnode);
|
||||
fn_inputs.push_back(NewValueNode(new_item_idx));
|
||||
auto new_out = func_graph->NewCNode(fn_inputs);
|
||||
new_out->set_abstract(get_item_cnode->abstract());
|
||||
mng->Replace(get_item_cnode, new_out);
|
||||
}
|
||||
|
||||
offset += real_outs.size() - 1;
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const FuncGraphPtr &kernel_graph,
|
||||
const std::string &postfix) {
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(kernel_graph, true);
|
||||
kernel_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
FuncGraphPtr fg;
|
||||
AnfNodePtrList inputs;
|
||||
AnfNodePtrList src_outputs;
|
||||
AnfNodePtrList outputs;
|
||||
|
||||
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs);
|
||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
|
||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs);
|
||||
// Handle get-item probleam.
|
||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs);
|
||||
|
||||
// set graphKernel attr
|
||||
std::string fuse_op_name = "";
|
||||
for (auto &fuse_node : fuse_nodes) {
|
||||
if (IsPrimitiveCNode(fuse_node)) {
|
||||
fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
|
||||
} else if (AnfAlgo::IsGraphKernel(fuse_node)) {
|
||||
auto fuse_cnode = fuse_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(fuse_cnode);
|
||||
auto graph_kernel_fg = GetValueNode<FuncGraphPtr>(fuse_cnode->input(kAnfPrimitiveIndex));
|
||||
auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
auto fuse_fg_name = GetValue<std::string>(fg_flag_val);
|
||||
fuse_op_name += fuse_fg_name + "_";
|
||||
}
|
||||
}
|
||||
fuse_op_name += postfix;
|
||||
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
||||
|
||||
return std::make_tuple(fuse_new_node, src_outputs);
|
||||
}
|
||||
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||
std::map<std::string, AnfNodePtr> *address_node_map) {
|
||||
MS_EXCEPTION_IF_NULL(op_desc);
|
||||
|
@ -466,15 +221,14 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
|
|||
}
|
||||
|
||||
FuncGraphPtr fg;
|
||||
AnfNodePtrList op_nodes, inputs, outputs;
|
||||
|
||||
if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) {
|
||||
fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
|
||||
} else {
|
||||
std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(nodes);
|
||||
inputs.clear();
|
||||
outputs.clear();
|
||||
std::tie(fg, std::ignore, std::ignore) = BuildSingleGraphFromNodes(nodes);
|
||||
}
|
||||
|
||||
AnfNodePtrList op_nodes, inputs, outputs;
|
||||
kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
|
||||
|
||||
auto mng = fg->manager();
|
||||
|
@ -524,22 +278,6 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc) {
|
|||
return fg;
|
||||
}
|
||||
|
||||
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) {
|
||||
std::stringstream name;
|
||||
if (prefix != "") {
|
||||
name << prefix << "_";
|
||||
}
|
||||
for (const auto &node : cnodes) {
|
||||
if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
|
||||
name << AnfAlgo::GetCNodeName(node) << "_";
|
||||
}
|
||||
}
|
||||
if (postfix != "") {
|
||||
name << postfix;
|
||||
}
|
||||
return name.str();
|
||||
}
|
||||
|
||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
|
|
@ -53,9 +53,6 @@ struct DataInfo {
  TypePtr type{nullptr};
};

bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr);
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
                                                                                AnfNodePtrList *src_outputs = nullptr);
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                      const AnfNodePtrList &outputs);
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
@ -66,19 +63,11 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str
                                                      const std::vector<TypeId> &inputs_type,
                                                      const std::vector<std::string> &output_formats,
                                                      const std::vector<TypeId> &output_types);
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                              const AnfNodePtrList &outputs);
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
                         const AnfNodePtrList &outputs);
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
                                                           const FuncGraphPtr &kernel_graph,
                                                           const std::string &postfix = "");
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
                   std::map<std::string, AnfNodePtr> *address_node_map);
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
FuncGraphPtr JsonDescToAnf(const std::string &json_desc);
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);

std::string GetFormat(const AnfNodePtr &node);
@ -27,6 +27,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "debug/anf_ir_dump.h"
#include "utils/context/graph_kernel_flags.h"
@ -632,7 +633,7 @@ class Splitter {
  graph_manager->AddFuncGraph(sub_func_graph);

  // set GraphKernel attr
  auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
  auto attr = GkUtils::ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
  sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));

  // set kernel info
@ -28,6 +28,7 @@
#include "frontend/operator/ops.h"
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"

namespace mindspore::graphkernel {
namespace {
@ -746,8 +747,7 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
  }
  changed = true;
  SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
  AnfNodePtr sg_node;
  std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
  auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel");
  AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
  DumpParallelFusionDetail(fuse_nodes, sg_node);
}
@ -31,6 +31,7 @@
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/model/lite_graph.h"
#include "backend/optimizer/graph_kernel/model/op_register.h"
#include "backend/optimizer/graph_kernel/core/graph_builder.h"

namespace mindspore::graphkernel {
namespace {
@ -438,13 +439,11 @@ bool TransformOpOptimizer::Run(const FuncGraphPtr &kernel_graph) {
  auto litegraph = AnfGraph2LiteGraph(sub_func_graph);
  if (Process(litegraph)) {
    changed = true;
    AnfNodePtrList outputs;
    auto new_funcgraph = LiteGraph2AnfGraph(litegraph, &outputs);
    auto new_funcgraph = LiteGraph2AnfGraph(litegraph);
    new_funcgraph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, sub_func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
    auto cnode = node->cast<CNodePtr>();
    AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
    auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs, outputs);
    SetNewKernelInfo(new_node, new_funcgraph, inputs, outputs);
    auto new_node = CreateNewFuseCNode(kernel_graph, new_funcgraph, inputs);
    (void)mng->Replace(node, new_node);
    mng->AddFuncGraph(new_funcgraph);
  }
@ -35,6 +35,7 @@
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"

namespace mindspore::graphkernel {
class TsaChecker : public AtomicAddChecker {
@ -133,7 +134,7 @@ AnfNodePtr TsaAtomicAddToFirstTensor::ProcessTsaFirstNode(const KernelGraphPtr &
  auto new_composite_node = main_graph->NewCNode({NewValueNode(new_sub_graph), tsa_first_input});
  new_composite_node->set_abstract(identity_node->abstract());
  SetNewKernelInfo(new_composite_node, new_sub_graph, {tsa_first_input}, {identity_node});
  auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
  auto graph_attr = GkUtils::ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "tsa_identity");
  new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
  new_sub_graph->set_attr("composite_type", MakeValue("tsa_identity"));
@ -198,7 +199,8 @@ void TsaAtomicAddToFirstTensor::ProcessOriginCNode(const AnfNodePtr &composite_n
  CorrectKernelBuildInfo(composite_node, outter_node);

  auto old_graph_name = GetValue<std::string>(sub_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  auto new_graph_name = ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
  auto new_graph_name =
    GkUtils::ExtractGraphKernelName(TopoSort(sub_graph->get_return()), "", "tensor_scatter_add_modified");
  sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(new_graph_name));
  MS_LOG(INFO) << "Convert " << old_graph_name << " to tensor scatter add graph " << new_graph_name;
}
@ -23,6 +23,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/core/graph_kernel_utils.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"

namespace mindspore::graphkernel {
@ -34,21 +35,6 @@ AnfNodePtrList GetUpdateStateList(const FuncGraphPtr &func_graph) {
  return result;
}

AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index) {
  AnfNodePtrList result;
  for (size_t i = begin_index; i < nodes.size(); i++) {
    if (IsPrimitiveCNode(nodes[i], prim::kPrimMakeTuple)) {
      auto mt = nodes[i]->cast<CNodePtr>();
      // recursively spread all inner tuples.
      auto mt_inputs = SpreadTuples(mt->inputs(), 1);
      result.insert(result.end(), mt_inputs.begin(), mt_inputs.end());
    } else {
      result.push_back(nodes[i]);
    }
  }
  return result;
}

AnfNodePtrList SpreadUpdateState::ExtendInputsOfUpdateState(const AnfNodePtrList &nodes,
                                                            const FuncGraphPtr &func_graph) {
  AnfNodePtrList result;
@ -85,7 +71,7 @@ bool SpreadUpdateState::Run(const FuncGraphPtr &func_graph) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->size() <= kUpdateStateRealInput) continue;
    auto inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
    auto inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
    // extend inputs of UpdateState if they have multiple outputs
    inputs = ExtendInputsOfUpdateState(inputs, func_graph);
    if (inputs.size() + kUpdateStateRealInput != cnode->size() || inputs[0] != cnode->input(kUpdateStateRealInput)) {
@ -110,7 +96,7 @@ bool ShrinkUpdateState::Run(const FuncGraphPtr &func_graph) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->size() <= kUpdateStateRealInput + 1) continue;
    AnfNodePtrList mt_inputs = SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
    AnfNodePtrList mt_inputs = GkUtils::SpreadTuples(cnode->inputs(), kUpdateStateRealInput);
    AbstractBasePtrList abs_list;
    std::transform(mt_inputs.begin(), mt_inputs.end(), std::back_inserter(abs_list),
                   [](const AnfNodePtr &inp) { return inp->abstract(); });
@ -61,20 +61,6 @@ class ShrinkUpdateState : public opt::Pass {
  bool Run(const FuncGraphPtr &func_graph) override;
};

/**
 * @brief Spread the MakeTuple in node list
 * @param nodes
 * @param begin_index
 * @example
 *   input
 *     nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
 *     begin_index: 1
 *   output
 *     [b, i, j, c, d, x, y, z]
 * @return std::vector<AnfNodePtr>
 */
AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);

/**
 * @brief Extend the getitem for UpdateState
 * @example