!2176 Fix code review problems of session

Merge pull request !2176 from chenfei_mindspore/code-review-of-session
This commit is contained in:
mindspore-ci-bot 2020-06-20 10:15:59 +08:00 committed by Gitee
commit d4d0faaad9
5 changed files with 61 additions and 28 deletions

View File

@ -60,7 +60,7 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
}
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
node->set_inputs({node->input(0), node->input(1)});
node->set_inputs({node->input(kAnfPrimitiveIndex), node->input(kFirstDataInputIndex)});
}
static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id,

View File

@ -43,10 +43,12 @@ static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFind
const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &iter : real_inputs) {
auto &para = iter.first;
MS_EXCEPTION_IF_NULL(para);
if (para->isa<Parameter>()) {
union_find_set->Add(para);
}
for (auto &arg : iter.second) {
MS_EXCEPTION_IF_NULL(arg);
if (!arg->isa<Parameter>()) {
continue;
}
@ -69,6 +71,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
for (auto &iter : real_inputs) {
auto &para = iter.first;
for (auto &arg : iter.second) {
MS_EXCEPTION_IF_NULL(arg);
if (!arg->isa<Parameter>()) {
continue;
}
@ -104,6 +107,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
if (para == main_parameter.get()) {
continue;
}
MS_EXCEPTION_IF_NULL(para);
MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
<< main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
kg->ReplaceNode(NOT_NULL(para), main_parameter);
@ -185,6 +189,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
for (auto &arg : args) {
MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<Parameter>()) {
MS_EXCEPTION_IF_NULL(parameter);
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
<< ", arg:" << arg->DebugString();
continue;
@ -237,12 +242,12 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
if (cnode->size() < kCNodePrim + 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
AnfNodePtr fn = cnode->input(kCNodePrim);
AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
continue;
}
AnfNodePtr arg = cnode->input(kCNodeCallArg);
AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
if (IsValueNode<KernelGraph>(arg)) {
RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else if (!arg->isa<CNode>()) {
@ -268,7 +273,7 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul
auto return_node = kg->get_return();
MS_EXCEPTION_IF_NULL(return_node);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
return_node->input(1), attch_node.get()};
return_node->input(kFirstDataInputIndex), attch_node.get()};
auto depend_node = kg->NewCNode(inputs);
return_node->set_input(1, depend_node);
}
@ -305,10 +310,13 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process call func " << cur_node->DebugString();
MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
// 1 get kernel graph
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
if (kCNodeCallArg >= origin_inputs.size()) {
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
}
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
@ -332,12 +340,12 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
cur_node->set_inputs(new_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process call func " << cur_node->DebugString();
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
}
void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
if (cur_node->size() < kCNodeSwitchLength) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
@ -369,13 +377,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process switch func " << cur_node->DebugString();
MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
}
void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
const CNodePtr &next_node,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
if (cur_node->size() < kCNodeSwitchLayerLength) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
@ -396,6 +404,9 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
}
// 3 recurse sub graph
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << ".";
}
std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]};
@ -410,7 +421,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString();
MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
}
std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
@ -419,11 +430,15 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
}
// 2.1 branch kernel graph and args
auto partial_cnode = utils::cast<CNodePtr>(node.get());
MS_EXCEPTION_IF_NULL(partial_cnode);
if (partial_cnode->size() < kCNodePartialLength) {
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
}
auto partial_inputs = partial_cnode->inputs();
const auto &partial_inputs = partial_cnode->inputs();
if (kCNodePartialFunc >= partial_inputs.size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
}
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
return {partial_cnode, branch_kg};
}
@ -451,7 +466,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start";
if (memo->find(graph) != memo->end()) {
return {};
}
@ -473,6 +488,9 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
if (child_order_index >= graph->child_graph_order().size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
}
auto child_graph = graph->child_graph_order()[child_order_index++];
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
@ -516,7 +534,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
}
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
MS_LOG(INFO) << "graph id:" << kg->graph_id();
MS_LOG(INFO) << "Graph id:" << kg->graph_id();
kg->SetExecOrderByDefault();
auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
std::vector<KernelGraphPtr> child_graph_order;

View File

@ -104,6 +104,7 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
if (abstract->isa<abstract::AbstractTuple>() &&
!AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
real_args_size += tuple_abstract->size();
continue;
}
@ -181,19 +182,19 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt
static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
KernelGraph *child_graph) {
MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id();
if (args.empty()) {
return;
}
if (parameters.size() != args.size()) {
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
<< " and args size:" << args.size() << " not equal!";
}
child_graph->SetExecOrderByDefault();
for (size_t i = 0; i < parameters.size(); i++) {
if (args[i] == parameters[i]) {
child_graph->SetRealInput(parameters[i], args[i]);
MS_LOG(INFO) << "Parameter and arg are same";
MS_LOG(INFO) << "Parameter and arg are same.";
continue;
}
child_graph->SetRealInput(parameters[i], args[i]);
@ -238,7 +239,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo) {
memo->insert(graph.get());
MS_LOG(INFO) << "start graph id:" << graph->graph_id();
MS_LOG(INFO) << "Start graph id:" << graph->graph_id();
for (auto &child_graph : graph->child_graph_order()) {
if (memo->find(child_graph) != memo->end()) {
MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
@ -295,9 +296,13 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
}
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto graph_order = GetGraphOrder(kernel_graph->graph_id());
for (auto graph_id : graph_order) {
auto child_graph = GetGraph(graph_id);
if (child_graph == nullptr) {
continue;
}
if (child_graph->summary_node_exist()) {
kernel_graph->set_summary_node_exist(true);
return;
@ -688,14 +693,15 @@ void AscendSession::ExportChildGraphs(const GraphId graph_id) {
save_graphs_path = ".";
}
if (graph_id == final_graph_id_) {
auto &graph_order = GetGraphOrder(final_graph_id_);
auto &graph_type = GetGraphOrderType(final_graph_id_);
const auto &graph_order = GetGraphOrder(final_graph_id_);
const auto &graph_type = GetGraphOrderType(final_graph_id_);
for (size_t i = 0; i < graph_order.size(); i++) {
if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
continue;
}
auto child_graph = GetGraph(graph_order[i]);
const auto child_graph = GetGraph(graph_order[i]);
MS_LOG(DEBUG) << "Start export child graph " << graph_order[i];
MS_EXCEPTION_IF_NULL(child_graph);
std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(child_graph->graph_id()) + ".ir";
DumpIR(file_path, child_graph, true);
DumpIRProto(child_graph, "vm_build_" + std::to_string(child_graph->graph_id()));
@ -772,6 +778,7 @@ void AscendSession::GetSummaryNodes(KernelGraph *graph) {
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
auto fake_graph = GetGraph(fake_graph_id);
MS_EXCEPTION_IF_NULL(fake_graph);
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
auto parameter = fake_graph->NewParameter();
@ -792,7 +799,7 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
if (abstract->isa<abstract::AbstractTuple>()) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]";
return create_parameter((*tuple_abstract)[output_idx]);
}
return create_parameter(cnode->abstract());
@ -984,6 +991,7 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
if (false_graph_id != kInvalidGraphId) {
// false graph and condition in graph same stream
auto condition_graph = GetGraph(cond_graph_id);
MS_EXCEPTION_IF_NULL(condition_graph);
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
// if false graph is a condition graph and has been switch compiled before,it's false should be updated again
auto cond_it = switches_.find(false_graph_id);
@ -991,6 +999,9 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
cond_graph_id = cond_it->first;
false_graph_id = cond_it->second.second;
condition_graph = GetGraph(cond_graph_id);
if (condition_graph == nullptr) {
continue;
}
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
cond_it = switches_.find(false_graph_id);
}
@ -1427,6 +1438,7 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
// count the output of every anf node
std::set<AnfNodePtr> has_output_nodes;
for (auto &anf_node : list) {
MS_EXCEPTION_IF_NULL(anf_node);
for (auto &input : anf_node->inputs()) {
(void)has_output_nodes.insert(input);
}
@ -1435,12 +1447,13 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
int output_idx = 0;
MS_EXCEPTION_IF_NULL(new_kernel_graph);
for (auto &anf_node : list) {
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
new_kernel_graph->set_return(anf_node);
}
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString();
make_tuple_inputs.push_back(anf_node);
}
}
@ -1458,6 +1471,7 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
std::vector<AnfNodePtr> new_graph_inputs;
// create new parameter from cnode
for (auto &anf_node : list) {
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto input = cnode->inputs()[input_idx];
@ -1536,12 +1550,12 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
auto new_call = graph->NewCNode(new_call_input);
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call);
return new_call;
}
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id();
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order
AscendControlParser::UpdateChildGraphOrder(graph);

View File

@ -780,12 +780,12 @@ void SessionBasic::Summary(KernelGraph *graph) {
return;
}
MS_EXCEPTION_IF_NULL(graph);
GetSummaryNodes(graph);
auto summary_outputs = graph->summary_nodes();
// do not exist summary node
if (summary_outputs.empty()) {
bool exist_summary = graph->summary_node_exist();
if (!exist_summary) {
return;
}
GetSummaryNodes(graph);
auto summary_outputs = graph->summary_nodes();
std::map<std::string, tensor::TensorPtr> params_list;
// fetch outputs apply kernel in session & run callback functions
for (auto &output_item : summary_outputs) {

View File

@ -229,6 +229,7 @@ const int kValueNodeTensorMask = 2;
// define special index in special node
constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kFirstDataInputIndex = 1;
constexpr auto kAnfPartialFuncGraphIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;