!2176 Fix code review problems of session

Merge pull request !2176 from chenfei_mindspore/code-review-of-session
This commit is contained in:
mindspore-ci-bot 2020-06-20 10:15:59 +08:00 committed by Gitee
commit d4d0faaad9
5 changed files with 61 additions and 28 deletions

View File

@ -60,7 +60,7 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
} }
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get()); AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
node->set_inputs({node->input(0), node->input(1)}); node->set_inputs({node->input(kAnfPrimitiveIndex), node->input(kFirstDataInputIndex)});
} }
static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id, static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id,

View File

@ -43,10 +43,12 @@ static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFind
const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs(); const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &iter : real_inputs) { for (auto &iter : real_inputs) {
auto &para = iter.first; auto &para = iter.first;
MS_EXCEPTION_IF_NULL(para);
if (para->isa<Parameter>()) { if (para->isa<Parameter>()) {
union_find_set->Add(para); union_find_set->Add(para);
} }
for (auto &arg : iter.second) { for (auto &arg : iter.second) {
MS_EXCEPTION_IF_NULL(arg);
if (!arg->isa<Parameter>()) { if (!arg->isa<Parameter>()) {
continue; continue;
} }
@ -69,6 +71,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
for (auto &iter : real_inputs) { for (auto &iter : real_inputs) {
auto &para = iter.first; auto &para = iter.first;
for (auto &arg : iter.second) { for (auto &arg : iter.second) {
MS_EXCEPTION_IF_NULL(arg);
if (!arg->isa<Parameter>()) { if (!arg->isa<Parameter>()) {
continue; continue;
} }
@ -104,6 +107,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
if (para == main_parameter.get()) { if (para == main_parameter.get()) {
continue; continue;
} }
MS_EXCEPTION_IF_NULL(para);
MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to " MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
<< main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get()); << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
kg->ReplaceNode(NOT_NULL(para), main_parameter); kg->ReplaceNode(NOT_NULL(para), main_parameter);
@ -185,6 +189,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
for (auto &arg : args) { for (auto &arg : args) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
if (arg->isa<Parameter>()) { if (arg->isa<Parameter>()) {
MS_EXCEPTION_IF_NULL(parameter);
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
<< ", arg:" << arg->DebugString(); << ", arg:" << arg->DebugString();
continue; continue;
@ -237,12 +242,12 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
if (cnode->size() < kCNodePrim + 1) { if (cnode->size() < kCNodePrim + 1) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
} }
AnfNodePtr fn = cnode->input(kCNodePrim); AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) { if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
MS_LOG(DEBUG) << "continue node " << cnode->DebugString(); MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
continue; continue;
} }
AnfNodePtr arg = cnode->input(kCNodeCallArg); AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
if (IsValueNode<KernelGraph>(arg)) { if (IsValueNode<KernelGraph>(arg)) {
RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
} else if (!arg->isa<CNode>()) { } else if (!arg->isa<CNode>()) {
@ -268,7 +273,7 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul
auto return_node = kg->get_return(); auto return_node = kg->get_return();
MS_EXCEPTION_IF_NULL(return_node); MS_EXCEPTION_IF_NULL(return_node);
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())), std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
return_node->input(1), attch_node.get()}; return_node->input(kFirstDataInputIndex), attch_node.get()};
auto depend_node = kg->NewCNode(inputs); auto depend_node = kg->NewCNode(inputs);
return_node->set_input(1, depend_node); return_node->set_input(1, depend_node);
} }
@ -305,10 +310,13 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
const NotNull<std::set<KernelGraphPtr> *> memo) { const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process call func " << cur_node->DebugString(); MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
// 1 get kernel graph // 1 get kernel graph
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs(); const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
if (kCNodeCallArg >= origin_inputs.size()) {
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
}
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))}; std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) { if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
@ -332,12 +340,12 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
cur_node->set_inputs(new_inputs); cur_node->set_inputs(new_inputs);
cur_node->set_abstract(nullptr); cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process call func " << cur_node->DebugString(); MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
} }
void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) { const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
if (cur_node->size() < kCNodeSwitchLength) { if (cur_node->size() < kCNodeSwitchLength) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
@ -369,13 +377,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
cur_node->set_inputs(new_switch_inputs); cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr); cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process switch func " << cur_node->DebugString(); MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
} }
void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
const CNodePtr &next_node, const CNodePtr &next_node,
const NotNull<std::set<KernelGraphPtr> *> memo) { const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
if (cur_node->size() < kCNodeSwitchLayerLength) { if (cur_node->size() < kCNodeSwitchLayerLength) {
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength; MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
@ -396,6 +404,9 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
} }
// 3 recurse sub graph // 3 recurse sub graph
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs(); const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << ".";
}
std::vector<AnfNodePtr> new_switch_inputs = { std::vector<AnfNodePtr> new_switch_inputs = {
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
origin_switch_inputs[kCNodeSwitchCond]}; origin_switch_inputs[kCNodeSwitchCond]};
@ -410,7 +421,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
cur_node->set_inputs(new_switch_inputs); cur_node->set_inputs(new_switch_inputs);
cur_node->set_abstract(nullptr); cur_node->set_abstract(nullptr);
MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString(); MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
} }
std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) { std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
@ -419,11 +430,15 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
} }
// 2.1 branch kernel graph and args // 2.1 branch kernel graph and args
auto partial_cnode = utils::cast<CNodePtr>(node.get()); auto partial_cnode = utils::cast<CNodePtr>(node.get());
MS_EXCEPTION_IF_NULL(partial_cnode);
if (partial_cnode->size() < kCNodePartialLength) { if (partial_cnode->size() < kCNodePartialLength) {
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
} }
auto partial_inputs = partial_cnode->inputs(); const auto &partial_inputs = partial_cnode->inputs();
if (kCNodePartialFunc >= partial_inputs.size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
}
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]); auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
return {partial_cnode, branch_kg}; return {partial_cnode, branch_kg};
} }
@ -451,7 +466,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo) { const NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start";
if (memo->find(graph) != memo->end()) { if (memo->find(graph) != memo->end()) {
return {}; return {};
} }
@ -473,6 +488,9 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail"; MS_LOG(EXCEPTION) << "Check label index fail";
} }
if (child_order_index >= graph->child_graph_order().size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
}
auto child_graph = graph->child_graph_order()[child_order_index++]; auto child_graph = graph->child_graph_order()[child_order_index++];
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
@ -516,7 +534,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
} }
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) { void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
MS_LOG(INFO) << "graph id:" << kg->graph_id(); MS_LOG(INFO) << "Graph id:" << kg->graph_id();
kg->SetExecOrderByDefault(); kg->SetExecOrderByDefault();
auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name())); auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
std::vector<KernelGraphPtr> child_graph_order; std::vector<KernelGraphPtr> child_graph_order;

View File

@ -104,6 +104,7 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
if (abstract->isa<abstract::AbstractTuple>() && if (abstract->isa<abstract::AbstractTuple>() &&
!AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) { !AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
real_args_size += tuple_abstract->size(); real_args_size += tuple_abstract->size();
continue; continue;
} }
@ -181,19 +182,19 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt
static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args, static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters, const std::vector<AnfNodePtr> &args,
KernelGraph *child_graph) { KernelGraph *child_graph) {
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id(); MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id();
if (args.empty()) { if (args.empty()) {
return; return;
} }
if (parameters.size() != args.size()) { if (parameters.size() != args.size()) {
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
<< " and args size:" << args.size() << " not equal!"; << " and args size:" << args.size() << " not equal!";
} }
child_graph->SetExecOrderByDefault(); child_graph->SetExecOrderByDefault();
for (size_t i = 0; i < parameters.size(); i++) { for (size_t i = 0; i < parameters.size(); i++) {
if (args[i] == parameters[i]) { if (args[i] == parameters[i]) {
child_graph->SetRealInput(parameters[i], args[i]); child_graph->SetRealInput(parameters[i], args[i]);
MS_LOG(INFO) << "Parameter and arg are same"; MS_LOG(INFO) << "Parameter and arg are same.";
continue; continue;
} }
child_graph->SetRealInput(parameters[i], args[i]); child_graph->SetRealInput(parameters[i], args[i]);
@ -238,7 +239,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph, static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo) { const NotNull<std::set<KernelGraphPtr> *> memo) {
memo->insert(graph.get()); memo->insert(graph.get());
MS_LOG(INFO) << "start graph id:" << graph->graph_id(); MS_LOG(INFO) << "Start graph id:" << graph->graph_id();
for (auto &child_graph : graph->child_graph_order()) { for (auto &child_graph : graph->child_graph_order()) {
if (memo->find(child_graph) != memo->end()) { if (memo->find(child_graph) != memo->end()) {
MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
@ -295,9 +296,13 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
} }
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) { void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto graph_order = GetGraphOrder(kernel_graph->graph_id()); auto graph_order = GetGraphOrder(kernel_graph->graph_id());
for (auto graph_id : graph_order) { for (auto graph_id : graph_order) {
auto child_graph = GetGraph(graph_id); auto child_graph = GetGraph(graph_id);
if (child_graph == nullptr) {
continue;
}
if (child_graph->summary_node_exist()) { if (child_graph->summary_node_exist()) {
kernel_graph->set_summary_node_exist(true); kernel_graph->set_summary_node_exist(true);
return; return;
@ -688,14 +693,15 @@ void AscendSession::ExportChildGraphs(const GraphId graph_id) {
save_graphs_path = "."; save_graphs_path = ".";
} }
if (graph_id == final_graph_id_) { if (graph_id == final_graph_id_) {
auto &graph_order = GetGraphOrder(final_graph_id_); const auto &graph_order = GetGraphOrder(final_graph_id_);
auto &graph_type = GetGraphOrderType(final_graph_id_); const auto &graph_type = GetGraphOrderType(final_graph_id_);
for (size_t i = 0; i < graph_order.size(); i++) { for (size_t i = 0; i < graph_order.size(); i++) {
if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) { if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
continue; continue;
} }
auto child_graph = GetGraph(graph_order[i]); const auto child_graph = GetGraph(graph_order[i]);
MS_LOG(DEBUG) << "Start export child graph " << graph_order[i]; MS_LOG(DEBUG) << "Start export child graph " << graph_order[i];
MS_EXCEPTION_IF_NULL(child_graph);
std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(child_graph->graph_id()) + ".ir"; std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(child_graph->graph_id()) + ".ir";
DumpIR(file_path, child_graph, true); DumpIR(file_path, child_graph, true);
DumpIRProto(child_graph, "vm_build_" + std::to_string(child_graph->graph_id())); DumpIRProto(child_graph, "vm_build_" + std::to_string(child_graph->graph_id()));
@ -772,6 +778,7 @@ void AscendSession::GetSummaryNodes(KernelGraph *graph) {
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) { AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
auto fake_graph = GetGraph(fake_graph_id); auto fake_graph = GetGraph(fake_graph_id);
MS_EXCEPTION_IF_NULL(fake_graph);
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0); auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr { auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
auto parameter = fake_graph->NewParameter(); auto parameter = fake_graph->NewParameter();
@ -792,7 +799,7 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
if (abstract->isa<abstract::AbstractTuple>()) { if (abstract->isa<abstract::AbstractTuple>()) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>(); auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract); MS_EXCEPTION_IF_NULL(tuple_abstract);
MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]"; MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]";
return create_parameter((*tuple_abstract)[output_idx]); return create_parameter((*tuple_abstract)[output_idx]);
} }
return create_parameter(cnode->abstract()); return create_parameter(cnode->abstract());
@ -984,6 +991,7 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
if (false_graph_id != kInvalidGraphId) { if (false_graph_id != kInvalidGraphId) {
// false graph and condition in graph same stream // false graph and condition in graph same stream
auto condition_graph = GetGraph(cond_graph_id); auto condition_graph = GetGraph(cond_graph_id);
MS_EXCEPTION_IF_NULL(condition_graph);
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
// if false graph is a condition graph and has been switch compiled before,it's false should be updated again // if false graph is a condition graph and has been switch compiled before,it's false should be updated again
auto cond_it = switches_.find(false_graph_id); auto cond_it = switches_.find(false_graph_id);
@ -991,6 +999,9 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
cond_graph_id = cond_it->first; cond_graph_id = cond_it->first;
false_graph_id = cond_it->second.second; false_graph_id = cond_it->second.second;
condition_graph = GetGraph(cond_graph_id); condition_graph = GetGraph(cond_graph_id);
if (condition_graph == nullptr) {
continue;
}
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true); SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
cond_it = switches_.find(false_graph_id); cond_it = switches_.find(false_graph_id);
} }
@ -1427,6 +1438,7 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
// count the output of every anf node // count the output of every anf node
std::set<AnfNodePtr> has_output_nodes; std::set<AnfNodePtr> has_output_nodes;
for (auto &anf_node : list) { for (auto &anf_node : list) {
MS_EXCEPTION_IF_NULL(anf_node);
for (auto &input : anf_node->inputs()) { for (auto &input : anf_node->inputs()) {
(void)has_output_nodes.insert(input); (void)has_output_nodes.insert(input);
} }
@ -1435,12 +1447,13 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())); auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve}; std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
int output_idx = 0; int output_idx = 0;
MS_EXCEPTION_IF_NULL(new_kernel_graph);
for (auto &anf_node : list) { for (auto &anf_node : list) {
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) { if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
new_kernel_graph->set_return(anf_node); new_kernel_graph->set_return(anf_node);
} }
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) { if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString(); MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString();
make_tuple_inputs.push_back(anf_node); make_tuple_inputs.push_back(anf_node);
} }
} }
@ -1458,6 +1471,7 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
std::vector<AnfNodePtr> new_graph_inputs; std::vector<AnfNodePtr> new_graph_inputs;
// create new parameter from cnode // create new parameter from cnode
for (auto &anf_node : list) { for (auto &anf_node : list) {
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>(); auto cnode = anf_node->cast<CNodePtr>();
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
auto input = cnode->inputs()[input_idx]; auto input = cnode->inputs()[input_idx];
@ -1536,12 +1550,12 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list); auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input)); std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
auto new_call = graph->NewCNode(new_call_input); auto new_call = graph->NewCNode(new_call_input);
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call); AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call);
return new_call; return new_call;
} }
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) { void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id(); MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id();
auto apply_list = GetCNodes(TopoSort(graph->get_return())); auto apply_list = GetCNodes(TopoSort(graph->get_return()));
// update the root graph child graph order // update the root graph child graph order
AscendControlParser::UpdateChildGraphOrder(graph); AscendControlParser::UpdateChildGraphOrder(graph);

View File

@ -780,12 +780,12 @@ void SessionBasic::Summary(KernelGraph *graph) {
return; return;
} }
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
GetSummaryNodes(graph); bool exist_summary = graph->summary_node_exist();
auto summary_outputs = graph->summary_nodes(); if (!exist_summary) {
// do not exist summary node
if (summary_outputs.empty()) {
return; return;
} }
GetSummaryNodes(graph);
auto summary_outputs = graph->summary_nodes();
std::map<std::string, tensor::TensorPtr> params_list; std::map<std::string, tensor::TensorPtr> params_list;
// fetch outputs apply kernel in session & run callback functions // fetch outputs apply kernel in session & run callback functions
for (auto &output_item : summary_outputs) { for (auto &output_item : summary_outputs) {

View File

@ -229,6 +229,7 @@ const int kValueNodeTensorMask = 2;
// define special index in special node // define special index in special node
constexpr auto kAnfPrimitiveIndex = 0; constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kFirstDataInputIndex = 1;
constexpr auto kAnfPartialFuncGraphIndex = 1; constexpr auto kAnfPartialFuncGraphIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1; constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2; constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;