forked from mindspore-Ecosystem/mindspore
!2761 Record unreused arg in kernel graph
Merge pull request !2761 from chenfei_mindspore/split-real-inputs-to-reuse-args-and-not-reuse-args
This commit is contained in:
commit
f6b6ef2796
|
@ -102,7 +102,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
|
|||
memo->insert(graph.get());
|
||||
|
||||
MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString();
|
||||
graph->SetExecOrderByDefault();
|
||||
|
||||
auto nodes = graph->execution_order();
|
||||
auto end_goto = graph->get_end_goto();
|
||||
if (end_goto != nullptr) {
|
||||
|
@ -128,6 +128,7 @@ static void AssignLabelForGotoSwitch(NotNull<std::shared_ptr<session::KernelGrap
|
|||
for (auto &cg : graph->child_graph_order()) {
|
||||
AssignLabelForGotoSwitch(NOT_NULL(cg), memo);
|
||||
}
|
||||
graph->SetExecOrderByDefault();
|
||||
}
|
||||
|
||||
void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph>> graph) {
|
||||
|
|
|
@ -199,7 +199,6 @@ class AnfRuntimeAlgorithm {
|
|||
static bool IsScalarInput(const CNodePtr &cnode, size_t index);
|
||||
static bool IsScalarOutput(const CNodePtr &cnode, size_t index);
|
||||
static void ReorderExecList(NotNull<std::vector<CNodePtr> *> node_list);
|
||||
static bool IsWhileTrueGraph(const KernelGraphPtr &child_graph);
|
||||
// get fix output precision of cnode.
|
||||
static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node);
|
||||
// get fix output precision from prev node, input_idx is the input index of current node related to prev node.
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#include "utils/union_find_set.h"
|
||||
#include "device/ascend/ascend_label_assign.h"
|
||||
|
||||
static constexpr size_t kCNodePrim = 0;
|
||||
static constexpr size_t kCNodeCallArg = 1;
|
||||
|
@ -35,17 +36,25 @@ namespace mindspore {
|
|||
namespace session {
|
||||
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
|
||||
auto &nodes = parent_graph->execution_order();
|
||||
CNodePtr last_jump_node = nullptr;
|
||||
for (auto &node : nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
|
||||
return node;
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
|
||||
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
|
||||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
|
||||
return node;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
|
||||
if (child_graph->get_start_label() == node->input(kCNodeCallArg)) {
|
||||
return node;
|
||||
}
|
||||
last_jump_node = node;
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
|
||||
if (child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
|
||||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue)) {
|
||||
return node;
|
||||
}
|
||||
last_jump_node = node;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
|
||||
return nullptr;
|
||||
if (last_jump_node == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
|
||||
}
|
||||
return last_jump_node;
|
||||
}
|
||||
|
||||
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
|
||||
|
@ -90,6 +99,9 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
|
|||
if (!arg->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
if (kg->unreuse_args().find(arg) != kg->unreuse_args().end()) {
|
||||
continue;
|
||||
}
|
||||
union_find_set->Union(arg, para);
|
||||
}
|
||||
}
|
||||
|
@ -133,24 +145,28 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
|
|||
}
|
||||
}
|
||||
|
||||
static AnfNodePtr GetMainParameter(NotNull<KernelGraphPtr> root_kg, const AnfNodePtr key,
|
||||
const std::set<AnfNodePtr> ¶meter_reuse_set) {
|
||||
AnfNodePtr main_parameter = key;
|
||||
std::set<AnfNodePtr> root_inputs_set;
|
||||
const auto &root_inputs_vector = root_kg->inputs();
|
||||
root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
|
||||
for (auto &node : parameter_reuse_set) {
|
||||
if (root_inputs_set.find(node) != root_inputs_set.end()) {
|
||||
main_parameter = node;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return main_parameter;
|
||||
}
|
||||
|
||||
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
|
||||
auto parameter_reuse_sets = parameter_set->GetSets();
|
||||
for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
|
||||
if (parameter_reuse_set.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
AnfNodePtr main_parameter = key;
|
||||
std::set<AnfNodePtr> root_inputs_set;
|
||||
const auto &root_inputs_vector = root_kg->inputs();
|
||||
root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
|
||||
for (auto &node : parameter_reuse_set) {
|
||||
if (root_inputs_set.find(node) != root_inputs_set.end()) {
|
||||
main_parameter = node;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto main_parameter = GetMainParameter(root_kg, key, parameter_reuse_set);
|
||||
std::set<KernelGraphPtr> memo;
|
||||
RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo));
|
||||
}
|
||||
|
@ -168,6 +184,7 @@ CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start) {
|
|||
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
(void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
|
||||
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kg);
|
||||
std::map<uint32_t, KernelGraphPtr> graph_id_map;
|
||||
for (auto &g : memo) {
|
||||
MS_EXCEPTION_IF_NULL(g);
|
||||
|
@ -177,12 +194,13 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
|
|||
}
|
||||
graph_id_map[g->graph_id()] = g;
|
||||
}
|
||||
|
||||
// Insert Assign
|
||||
ChildGraphDataAssign(graph_id_map);
|
||||
// Make UnionFindSet
|
||||
UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
|
||||
// Reuse Parameter
|
||||
ReuseParameter(kg, NOT_NULL(¶meter_set));
|
||||
// Insert Assign
|
||||
ChildGraphDataAssign(graph_id_map);
|
||||
}
|
||||
|
||||
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
||||
|
@ -193,6 +211,7 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
|||
void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGraphPtr> &graph_id_map) {
|
||||
for (auto &iter : graph_id_map) {
|
||||
auto &kg = iter.second;
|
||||
MS_LOG(INFO) << "Data assign graph:" << kg->graph_id();
|
||||
MS_EXCEPTION_IF_NULL(kg);
|
||||
std::set<std::pair<AnfNodePtr, AnfNodePtr>> memo;
|
||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs = kg->real_inputs();
|
||||
|
@ -206,8 +225,14 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|||
} else {
|
||||
memo.emplace(parameter, arg);
|
||||
}
|
||||
if (arg->isa<Parameter>()) {
|
||||
auto unreuse_args_map = kg->unreuse_args();
|
||||
auto unreuse_arg_iter = unreuse_args_map.find(arg);
|
||||
if (unreuse_arg_iter == unreuse_args_map.end()) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
if (!arg->isa<Parameter>()) {
|
||||
MS_LOG(EXCEPTION) << "Reused arg must be parameter, arg:" << arg->DebugString() << ".";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
|
||||
<< ", arg:" << arg->DebugString();
|
||||
continue;
|
||||
|
@ -220,6 +245,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|||
NOT_NULL(parameter));
|
||||
}
|
||||
}
|
||||
kg->SetExecOrderByDefault();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -353,7 +379,6 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
|
|||
// 5 recurse sub graph
|
||||
CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo);
|
||||
new_inputs.push_back(sub_label);
|
||||
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
|
||||
cur_node->set_inputs(new_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
||||
|
@ -394,7 +419,6 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|||
}
|
||||
std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]);
|
||||
|
||||
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
|
||||
|
@ -477,6 +501,16 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
|
|||
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
|
||||
if (assign_node != nullptr) {
|
||||
auto jump_node = GetJumpNode(from_graph, to_graph);
|
||||
const auto &from_graph_exe_order = from_graph->execution_order();
|
||||
auto jump_node_iter = std::find(from_graph_exe_order.begin(), from_graph_exe_order.end(), jump_node);
|
||||
if (jump_node_iter == from_graph_exe_order.end()) {
|
||||
MS_EXCEPTION_IF_NULL(jump_node);
|
||||
MS_LOG(EXCEPTION) << "Can't find node:" << jump_node->DebugString() << " in graph:" << from_graph->graph_id();
|
||||
}
|
||||
// insert assign between jump_node -1 and jump_node
|
||||
if (jump_node_iter != from_graph_exe_order.begin()) {
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(*(jump_node_iter - 1)), NOT_NULL(assign_node));
|
||||
}
|
||||
if (jump_node != nullptr) {
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
|
||||
}
|
||||
|
@ -501,8 +535,6 @@ AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg,
|
|||
auto assign_node = kg->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(assign_node);
|
||||
assign_node->set_abstract(to->abstract());
|
||||
// append the assign at the end of from graph
|
||||
InsertDependToGraph(kg, NOT_NULL(assign_node));
|
||||
return assign_node;
|
||||
}
|
||||
|
||||
|
@ -527,7 +559,6 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|||
|
||||
std::vector<CNodePtr> execution_order;
|
||||
uint32_t child_order_index = 0;
|
||||
|
||||
for (auto &node : cnodes) {
|
||||
execution_order.push_back(node);
|
||||
if (node == graph->get_end_goto()) {
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "session/kernel_graph.h"
|
||||
#include "utils/base_ref.h"
|
||||
#include "utils/contract.h"
|
||||
#include "utils/union_find_set.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
|
|
@ -202,7 +202,8 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt
|
|||
}
|
||||
|
||||
static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, const std::vector<AnfNodePtr> &args,
|
||||
KernelGraph *child_graph) {
|
||||
const KernelGraphPtr &graph, KernelGraphPtr child_graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id();
|
||||
if (args.empty()) {
|
||||
|
@ -214,18 +215,25 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters,
|
|||
}
|
||||
child_graph->SetExecOrderByDefault();
|
||||
for (size_t i = 0; i < parameters.size(); i++) {
|
||||
MS_LOG(INFO) << "parameters[" << i << "]" << parameters[i]->DebugString() << ",args[" << i << "]"
|
||||
<< args[i]->DebugString();
|
||||
if (args[i] == parameters[i]) {
|
||||
child_graph->SetRealInput(parameters[i], args[i]);
|
||||
MS_LOG(INFO) << "Parameter and arg are same.";
|
||||
continue;
|
||||
}
|
||||
child_graph->SetRealInput(parameters[i], args[i]);
|
||||
if (memo->find(child_graph) != memo->end() || !args[i]->isa<Parameter>()) {
|
||||
MS_LOG(INFO) << "Add unreused arg,graph:" << graph->graph_id();
|
||||
child_graph->AddUnreuseArgs(args[i], graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of
|
||||
// graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2]
|
||||
static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
|
||||
static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_EXCEPTION_IF_NULL(memo.get());
|
||||
auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall);
|
||||
for (auto &call_node : call_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
|
@ -235,7 +243,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
|
|||
std::vector<AnfNodePtr> real_args =
|
||||
std::vector<AnfNodePtr>(call_node->inputs().begin() + 2, call_node->inputs().end());
|
||||
std::vector<AnfNodePtr> child_inputs = child_graphs[0]->inputs();
|
||||
BindCallArgsWithParameter(child_inputs, real_args, child_graphs[0].get());
|
||||
BindCallArgsWithParameter(child_inputs, real_args, graph, child_graphs[0], memo);
|
||||
if (split_flag) {
|
||||
call_node->set_inputs(std::vector<AnfNodePtr>(call_node->inputs().begin(), call_node->inputs().begin() + 2));
|
||||
}
|
||||
|
@ -256,8 +264,8 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
|
|||
}
|
||||
return ret;
|
||||
};
|
||||
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get());
|
||||
BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get());
|
||||
BindCallArgsWithParameter(child_graphs[0]->inputs(), get_partial_args(2), graph, child_graphs[0], memo);
|
||||
BindCallArgsWithParameter(child_graphs[1]->inputs(), get_partial_args(3), graph, child_graphs[1], memo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -306,8 +314,6 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
LinkChildGraphs(NOT_NULL(root_graph));
|
||||
// resource initialize
|
||||
InitRuntimeResource();
|
||||
// assign label
|
||||
AssignLabel(NOT_NULL(root_graph));
|
||||
// recurse compile child root_graph
|
||||
std::set<KernelGraphPtr> memo;
|
||||
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
|
@ -665,12 +671,6 @@ void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
|
|||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
device::ascend::AscendLabelAssign::GetInstance().AssignLabel(kernel_graph);
|
||||
MS_LOG(INFO) << "Finish!";
|
||||
}
|
||||
|
||||
void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
|
||||
MS_LOG(INFO) << "Start!";
|
||||
struct timeval start_time, end_time;
|
||||
|
@ -1582,14 +1582,17 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|||
auto input = cnode->inputs()[input_idx];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
AnfNodePtr new_parameter = nullptr;
|
||||
// check whether input has been put into args of call, if mulptiple use of one parameter or cnode, only set one
|
||||
// parameter in graph inputs and one arg in call node
|
||||
auto call_input_it = std::find(call_node_inputs.begin(), call_node_inputs.end(), input);
|
||||
if (call_input_it != call_node_inputs.end()) {
|
||||
cnode->set_input(input_idx, new_graph_inputs[std::distance(call_node_inputs.begin(), call_input_it)]);
|
||||
continue;
|
||||
}
|
||||
// value node consider move to new graph
|
||||
if (input->isa<ValueNode>()) {
|
||||
cnode->set_input(input_idx, input);
|
||||
continue;
|
||||
} else if (input->isa<Parameter>()) {
|
||||
// parameter reuse and should attention mulptiple use of one parameter
|
||||
cnode->set_input(input_idx, input);
|
||||
new_parameter = input;
|
||||
} else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) {
|
||||
// if is cnode and not in current child graph
|
||||
new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get());
|
||||
|
@ -1598,12 +1601,8 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|||
// if is a cnode and in current graph
|
||||
continue;
|
||||
}
|
||||
// if mulptiple use of one parameter or cnode, only set one parameter in graph inputs and one arg in call node
|
||||
// args
|
||||
if (std::find(call_node_inputs.begin(), call_node_inputs.end(), new_parameter) == call_node_inputs.end()) {
|
||||
new_graph_inputs.push_back(new_parameter);
|
||||
call_node_inputs.push_back(input);
|
||||
}
|
||||
new_graph_inputs.push_back(new_parameter);
|
||||
call_node_inputs.push_back(input);
|
||||
}
|
||||
}
|
||||
// set graph inputs of new graph
|
||||
|
@ -1631,7 +1630,7 @@ void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
|
|||
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
|
||||
auto root_graph_output = AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0).first;
|
||||
if (AnfAlgo::CheckPrimitiveType(root_graph_output, prim::kPrimCall)) {
|
||||
SplitGraph(root_graph, {prim::kPrimReturn});
|
||||
SplitGraph(root_graph, {prim::kPrimReturn}, NOT_NULL(&memo));
|
||||
for (auto &child_graph : root_graph->child_graph_order()) {
|
||||
RecurseSplitGraph(NOT_NULL(child_graph), NOT_NULL(&memo));
|
||||
}
|
||||
|
@ -1672,7 +1671,8 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
|
|||
return new_call;
|
||||
}
|
||||
|
||||
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
|
||||
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id();
|
||||
bool split_flag = false;
|
||||
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
|
||||
|
@ -1710,14 +1710,13 @@ void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<Pri
|
|||
split_flag = true;
|
||||
}
|
||||
AscendControlParser::UpdateChildGraphOrder(graph);
|
||||
UpdateRealInput(graph, split_flag);
|
||||
UpdateRealInput(graph, split_flag, memo);
|
||||
MS_LOG(INFO) << "Split graph[" << graph->graph_id() << "] end";
|
||||
// recurse to split child graph
|
||||
}
|
||||
|
||||
void AscendSession::RecurseSplitGraph(NotNull<KernelGraphPtr> graph, const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
memo->insert(graph.get());
|
||||
SplitGraph(graph, {prim::kPrimCall});
|
||||
SplitGraph(graph, {prim::kPrimCall}, memo);
|
||||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
if (memo->find(child_graph) == memo->end()) {
|
||||
RecurseSplitGraph(NOT_NULL(child_graph), memo);
|
||||
|
|
|
@ -77,7 +77,6 @@ class AscendSession : public SessionBasic {
|
|||
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
|
||||
void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const;
|
||||
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
|
||||
void MemoryAlloc(KernelGraph *kernel_graph) const;
|
||||
void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
|
||||
|
@ -100,7 +99,8 @@ class AscendSession : public SessionBasic {
|
|||
void SetFinalGraphOutput(const ValuePtr &value);
|
||||
void SetFinalGraphOutput(const VectorRef &vec_output);
|
||||
|
||||
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
|
||||
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
// split graphs with recurse from root graph
|
||||
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
|
||||
void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
|
||||
|
|
|
@ -103,6 +103,23 @@ AnfNodePtr MakeValueNode(const AnfNodePtr &node) {
|
|||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
|
||||
return new_value_node;
|
||||
}
|
||||
|
||||
bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
|
||||
if (left == right) {
|
||||
return true;
|
||||
}
|
||||
if (left == nullptr || right == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
|
||||
return false;
|
||||
}
|
||||
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
|
||||
return AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
|
||||
AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
std::vector<AnfNodePtr> KernelGraph::outputs() const {
|
||||
auto graph_output = output();
|
||||
|
@ -219,6 +236,19 @@ void KernelGraph::SetExecOrderByDefault() {
|
|||
if (node == start_label_ || node == end_goto_) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsSameLabel(node, end_goto_)) {
|
||||
end_goto_ = node;
|
||||
MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsSameLabel(node, start_label_)) {
|
||||
start_label_ = node;
|
||||
MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id();
|
||||
continue;
|
||||
}
|
||||
|
||||
re_order.push_back(node);
|
||||
}
|
||||
if (end_goto_ != nullptr) {
|
||||
|
@ -751,10 +781,9 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
|
|||
}
|
||||
// update front to backend map
|
||||
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
|
||||
// update output depend relations
|
||||
node_output_edges_[new_anf_node.get()] = it->second;
|
||||
(void)node_output_edges_.erase(old_anf_node);
|
||||
}
|
||||
// if change the ir of graph, regenerate execution order of graph
|
||||
SetExecOrderByDefault();
|
||||
// update graph inputs in child graph
|
||||
auto it_real_inputs = std::find_if(real_inputs_.begin(), real_inputs_.end(),
|
||||
[&old_anf_node](const std::pair<AnfNodePtr, std::vector<AnfNodePtr>> &n) -> bool {
|
||||
|
@ -770,7 +799,7 @@ void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodeP
|
|||
return n.first == new_anf_node.get();
|
||||
});
|
||||
if (iter != real_inputs_.end()) {
|
||||
MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
|
||||
MS_LOG(WARNING) << new_anf_node->DebugString() << " Already exist in real inputs, will be rewrited.";
|
||||
iter->second = old_args;
|
||||
} else {
|
||||
real_inputs_.emplace_back(new_anf_node, old_args);
|
||||
|
@ -827,6 +856,10 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
|
|||
}
|
||||
}
|
||||
|
||||
void KernelGraph::AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph) {
|
||||
unreuse_args_[arg] = from_graph;
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateCallRealInput() {
|
||||
MS_LOG(INFO) << "Update graph id: " << graph_id_;
|
||||
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_map;
|
||||
|
@ -839,6 +872,17 @@ void KernelGraph::UpdateCallRealInput() {
|
|||
// if real input is a call node ,find the child graph output act as the new real input
|
||||
auto tmp_real_input = GetCallRealOutputs(real_input);
|
||||
std::copy(tmp_real_input.begin(), tmp_real_input.end(), std::back_inserter(new_real_inputs));
|
||||
// replace the call in unreuse_args_
|
||||
auto unreuse_arg_it = unreuse_args_.find(real_input);
|
||||
if (unreuse_arg_it != unreuse_args_.end()) {
|
||||
auto old_graph = unreuse_arg_it->second;
|
||||
for (auto new_real_input : new_real_inputs) {
|
||||
// if call reference graph output is parameter, it will be allowed to reuse
|
||||
if (!new_real_input->isa<Parameter>()) {
|
||||
unreuse_args_[new_real_input] = old_graph;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
real_inputs_map.emplace_back(parameter, new_real_inputs);
|
||||
}
|
||||
|
|
|
@ -130,6 +130,9 @@ class KernelGraph : public FuncGraph {
|
|||
// get real inputs
|
||||
const std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> &real_inputs() const { return real_inputs_; }
|
||||
void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg);
|
||||
// mark unreused args
|
||||
void AddUnreuseArgs(const AnfNodePtr &arg, const std::shared_ptr<KernelGraph> &from_graph);
|
||||
const std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> &unreuse_args() const { return unreuse_args_; }
|
||||
// used to dump ir
|
||||
std::string ToString() const override;
|
||||
// update the real input if the node is a call
|
||||
|
@ -205,6 +208,7 @@ class KernelGraph : public FuncGraph {
|
|||
std::shared_ptr<KernelGraph> parent_graph_;
|
||||
// record real parameters,inputs_ is the formal parameters
|
||||
std::vector<std::pair<AnfNodePtr, std::vector<AnfNodePtr>>> real_inputs_;
|
||||
std::map<AnfNodePtr, std::shared_ptr<KernelGraph>> unreuse_args_;
|
||||
|
||||
CNodePtr start_label_;
|
||||
CNodePtr end_goto_;
|
||||
|
|
|
@ -99,6 +99,19 @@ class ControlIfbyIfbyIf(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
class ControlSimpleWhile(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.addn = op.AddN()
|
||||
|
||||
def construct(self, x, y, input_data):
|
||||
out = input_data
|
||||
while x:
|
||||
out = self.addn([input_data, input_data, input_data])
|
||||
x = y
|
||||
return out
|
||||
|
||||
|
||||
class ControlMixedWhileIf(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -204,6 +217,22 @@ def test_if_by_if_by_if():
|
|||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_simple_while():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
x = np.array(True).astype(np.bool)
|
||||
y = np.array(False).astype(np.bool)
|
||||
input_shape = (127, 7, 53, 31)
|
||||
input_data = np.random.randn(*input_shape).astype(np.float32)
|
||||
net = ControlSimpleWhile()
|
||||
output = net(Tensor(x), Tensor(y), Tensor(input_data))
|
||||
expect = input_data * 3
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue