forked from mindspore-Ecosystem/mindspore
!2176 Fix code review problems of session
Merge pull request !2176 from chenfei_mindspore/code-review-of-session
This commit is contained in:
commit
d4d0faaad9
|
@ -60,7 +60,7 @@ static void UpdateLabelSwitch(NotNull<CNodePtr> node) {
|
|||
MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, MakeValue<std::vector<uint32_t>>(label_list), node.get());
|
||||
node->set_inputs({node->input(0), node->input(1)});
|
||||
node->set_inputs({node->input(kAnfPrimitiveIndex), node->input(kFirstDataInputIndex)});
|
||||
}
|
||||
|
||||
static void AssignLabelForLabelSet(NotNull<std::shared_ptr<session::KernelGraph>> graph, NotNull<uint32_t *> label_id,
|
||||
|
|
|
@ -43,10 +43,12 @@ static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFind
|
|||
const std::map<AnfNodePtr, std::vector<AnfNodePtr>> &real_inputs = kg->real_inputs();
|
||||
for (auto &iter : real_inputs) {
|
||||
auto ¶ = iter.first;
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
if (para->isa<Parameter>()) {
|
||||
union_find_set->Add(para);
|
||||
}
|
||||
for (auto &arg : iter.second) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (!arg->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -69,6 +71,7 @@ static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<Union
|
|||
for (auto &iter : real_inputs) {
|
||||
auto ¶ = iter.first;
|
||||
for (auto &arg : iter.second) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (!arg->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -104,6 +107,7 @@ static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr>
|
|||
if (para == main_parameter.get()) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(para);
|
||||
MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
|
||||
<< main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
|
||||
kg->ReplaceNode(NOT_NULL(para), main_parameter);
|
||||
|
@ -185,6 +189,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|||
for (auto &arg : args) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (arg->isa<Parameter>()) {
|
||||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString()
|
||||
<< ", arg:" << arg->DebugString();
|
||||
continue;
|
||||
|
@ -237,12 +242,12 @@ NotNull<CNodePtr> AscendControlParser::ProcessKernelGraph(NotNull<KernelGraphPtr
|
|||
if (cnode->size() < kCNodePrim + 1) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
|
||||
}
|
||||
AnfNodePtr fn = cnode->input(kCNodePrim);
|
||||
AnfNodePtr fn = cnode->input(kAnfPrimitiveIndex);
|
||||
if (!IsPrimitive(fn, prim::kPrimCall) || cnode->size() < kCNodeCallArg + 1) {
|
||||
MS_LOG(DEBUG) << "continue node " << cnode->DebugString();
|
||||
continue;
|
||||
}
|
||||
AnfNodePtr arg = cnode->input(kCNodeCallArg);
|
||||
AnfNodePtr arg = cnode->input(kFirstDataInputIndex);
|
||||
if (IsValueNode<KernelGraph>(arg)) {
|
||||
RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo);
|
||||
} else if (!arg->isa<CNode>()) {
|
||||
|
@ -268,7 +273,7 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|||
auto return_node = kg->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimDepend->name())),
|
||||
return_node->input(1), attch_node.get()};
|
||||
return_node->input(kFirstDataInputIndex), attch_node.get()};
|
||||
auto depend_node = kg->NewCNode(inputs);
|
||||
return_node->set_input(1, depend_node);
|
||||
}
|
||||
|
@ -305,10 +310,13 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod
|
|||
|
||||
void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "process call func " << cur_node->DebugString();
|
||||
MS_LOG(INFO) << "Process call func " << cur_node->DebugString();
|
||||
|
||||
// 1 get kernel graph
|
||||
const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs();
|
||||
if (kCNodeCallArg >= origin_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size();
|
||||
}
|
||||
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelGotoOpName))};
|
||||
if (!IsValueNode<KernelGraph>(origin_inputs[kCNodeCallArg])) {
|
||||
MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode";
|
||||
|
@ -332,12 +340,12 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP
|
|||
new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end());
|
||||
cur_node->set_inputs(new_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "success process call func " << cur_node->DebugString();
|
||||
MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
|
||||
const CNodePtr &next_node, const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
|
||||
MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
|
||||
|
||||
if (cur_node->size() < kCNodeSwitchLength) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength;
|
||||
|
@ -369,13 +377,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
|
|||
new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end());
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "success process switch func " << cur_node->DebugString();
|
||||
MS_LOG(INFO) << "Succeed processing switch func " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node,
|
||||
const CNodePtr &next_node,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "process switch node " << cur_node->DebugString();
|
||||
MS_LOG(INFO) << "Process switch node " << cur_node->DebugString();
|
||||
|
||||
if (cur_node->size() < kCNodeSwitchLayerLength) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLayerLength;
|
||||
|
@ -396,6 +404,9 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|||
}
|
||||
// 3 recurse sub graph
|
||||
const std::vector<AnfNodePtr> &origin_switch_inputs = cur_node->inputs();
|
||||
if (kCNodeSwitchCond >= origin_switch_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index out of range:" << origin_switch_inputs.size() << ".";
|
||||
}
|
||||
std::vector<AnfNodePtr> new_switch_inputs = {
|
||||
std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
|
||||
origin_switch_inputs[kCNodeSwitchCond]};
|
||||
|
@ -410,7 +421,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
|
|||
new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end());
|
||||
cur_node->set_inputs(new_switch_inputs);
|
||||
cur_node->set_abstract(nullptr);
|
||||
MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString();
|
||||
MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
|
||||
}
|
||||
|
||||
std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
|
||||
|
@ -419,11 +430,15 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
|
|||
}
|
||||
// 2.1 branch kernel graph and args
|
||||
auto partial_cnode = utils::cast<CNodePtr>(node.get());
|
||||
MS_EXCEPTION_IF_NULL(partial_cnode);
|
||||
if (partial_cnode->size() < kCNodePartialLength) {
|
||||
MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength;
|
||||
}
|
||||
|
||||
auto partial_inputs = partial_cnode->inputs();
|
||||
const auto &partial_inputs = partial_cnode->inputs();
|
||||
if (kCNodePartialFunc >= partial_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
|
||||
}
|
||||
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
|
||||
return {partial_cnode, branch_kg};
|
||||
}
|
||||
|
@ -451,7 +466,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|||
|
||||
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
|
||||
MS_LOG(INFO) << "Graph:" << graph->graph_id() << " start";
|
||||
if (memo->find(graph) != memo->end()) {
|
||||
return {};
|
||||
}
|
||||
|
@ -473,6 +488,9 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|||
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
|
||||
MS_LOG(EXCEPTION) << "Check label index fail";
|
||||
}
|
||||
if (child_order_index >= graph->child_graph_order().size()) {
|
||||
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
|
||||
}
|
||||
auto child_graph = graph->child_graph_order()[child_order_index++];
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
||||
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
||||
|
@ -516,7 +534,7 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
|
|||
}
|
||||
|
||||
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
|
||||
MS_LOG(INFO) << "graph id:" << kg->graph_id();
|
||||
MS_LOG(INFO) << "Graph id:" << kg->graph_id();
|
||||
kg->SetExecOrderByDefault();
|
||||
auto call_nodes = kg->FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
|
||||
std::vector<KernelGraphPtr> child_graph_order;
|
||||
|
|
|
@ -104,6 +104,7 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &ar
|
|||
if (abstract->isa<abstract::AbstractTuple>() &&
|
||||
!AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
real_args_size += tuple_abstract->size();
|
||||
continue;
|
||||
}
|
||||
|
@ -181,19 +182,19 @@ static std::vector<std::vector<CNodePtr>> GetChildList(const std::vector<CNodePt
|
|||
static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, const std::vector<AnfNodePtr> &args,
|
||||
KernelGraph *child_graph) {
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id();
|
||||
MS_LOG(INFO) << "Start bind parameter of child graph:" << child_graph->graph_id();
|
||||
if (args.empty()) {
|
||||
return;
|
||||
}
|
||||
if (parameters.size() != args.size()) {
|
||||
MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
|
||||
MS_LOG(EXCEPTION) << "Graph:" << child_graph->graph_id() << " parameters size:" << parameters.size()
|
||||
<< " and args size:" << args.size() << " not equal!";
|
||||
}
|
||||
child_graph->SetExecOrderByDefault();
|
||||
for (size_t i = 0; i < parameters.size(); i++) {
|
||||
if (args[i] == parameters[i]) {
|
||||
child_graph->SetRealInput(parameters[i], args[i]);
|
||||
MS_LOG(INFO) << "Parameter and arg are same";
|
||||
MS_LOG(INFO) << "Parameter and arg are same.";
|
||||
continue;
|
||||
}
|
||||
child_graph->SetRealInput(parameters[i], args[i]);
|
||||
|
@ -238,7 +239,7 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph) {
|
|||
static void RecurseToUpdateCallRealInput(NotNull<KernelGraphPtr> graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
memo->insert(graph.get());
|
||||
MS_LOG(INFO) << "start graph id:" << graph->graph_id();
|
||||
MS_LOG(INFO) << "Start graph id:" << graph->graph_id();
|
||||
for (auto &child_graph : graph->child_graph_order()) {
|
||||
if (memo->find(child_graph) != memo->end()) {
|
||||
MS_LOG(INFO) << "Child graph:" << child_graph->graph_id()
|
||||
|
@ -295,9 +296,13 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
|||
}
|
||||
|
||||
void AscendSession::SetFinalGraphSummaryFlag(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto graph_order = GetGraphOrder(kernel_graph->graph_id());
|
||||
for (auto graph_id : graph_order) {
|
||||
auto child_graph = GetGraph(graph_id);
|
||||
if (child_graph == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (child_graph->summary_node_exist()) {
|
||||
kernel_graph->set_summary_node_exist(true);
|
||||
return;
|
||||
|
@ -688,14 +693,15 @@ void AscendSession::ExportChildGraphs(const GraphId graph_id) {
|
|||
save_graphs_path = ".";
|
||||
}
|
||||
if (graph_id == final_graph_id_) {
|
||||
auto &graph_order = GetGraphOrder(final_graph_id_);
|
||||
auto &graph_type = GetGraphOrderType(final_graph_id_);
|
||||
const auto &graph_order = GetGraphOrder(final_graph_id_);
|
||||
const auto &graph_type = GetGraphOrderType(final_graph_id_);
|
||||
for (size_t i = 0; i < graph_order.size(); i++) {
|
||||
if (graph_type[i] == BRANCH_END || graph_type[i] == BRANCH_START) {
|
||||
continue;
|
||||
}
|
||||
auto child_graph = GetGraph(graph_order[i]);
|
||||
const auto child_graph = GetGraph(graph_order[i]);
|
||||
MS_LOG(DEBUG) << "Start export child graph " << graph_order[i];
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
std::string file_path = save_graphs_path + "/graph_build_" + std::to_string(child_graph->graph_id()) + ".ir";
|
||||
DumpIR(file_path, child_graph, true);
|
||||
DumpIRProto(child_graph, "vm_build_" + std::to_string(child_graph->graph_id()));
|
||||
|
@ -772,6 +778,7 @@ void AscendSession::GetSummaryNodes(KernelGraph *graph) {
|
|||
|
||||
AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodePtr &true_output) {
|
||||
auto fake_graph = GetGraph(fake_graph_id);
|
||||
MS_EXCEPTION_IF_NULL(fake_graph);
|
||||
auto output_item_with_index = AnfAlgo::VisitKernelWithReturnType(true_output, 0);
|
||||
auto create_parameter = [&](const AbstractBasePtr &abstract) -> AnfNodePtr {
|
||||
auto parameter = fake_graph->NewParameter();
|
||||
|
@ -792,7 +799,7 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP
|
|||
if (abstract->isa<abstract::AbstractTuple>()) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
|
||||
MS_LOG(INFO) << "Tuple size [" << tuple_abstract->size() << "]";
|
||||
return create_parameter((*tuple_abstract)[output_idx]);
|
||||
}
|
||||
return create_parameter(cnode->abstract());
|
||||
|
@ -984,6 +991,7 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
|
|||
if (false_graph_id != kInvalidGraphId) {
|
||||
// false graph and condition in graph same stream
|
||||
auto condition_graph = GetGraph(cond_graph_id);
|
||||
MS_EXCEPTION_IF_NULL(condition_graph);
|
||||
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
|
||||
// if false graph is a condition graph and has been switch compiled before,it's false should be updated again
|
||||
auto cond_it = switches_.find(false_graph_id);
|
||||
|
@ -991,6 +999,9 @@ void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id,
|
|||
cond_graph_id = cond_it->first;
|
||||
false_graph_id = cond_it->second.second;
|
||||
condition_graph = GetGraph(cond_graph_id);
|
||||
if (condition_graph == nullptr) {
|
||||
continue;
|
||||
}
|
||||
SetStreamDistinctionLabel(GetGraph(false_graph_id), condition_graph->stream_distinction_label(), true);
|
||||
cond_it = switches_.find(false_graph_id);
|
||||
}
|
||||
|
@ -1427,6 +1438,7 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
|
|||
// count the output of every anf node
|
||||
std::set<AnfNodePtr> has_output_nodes;
|
||||
for (auto &anf_node : list) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
for (auto &input : anf_node->inputs()) {
|
||||
(void)has_output_nodes.insert(input);
|
||||
}
|
||||
|
@ -1435,12 +1447,13 @@ static void ConstructSplitedGraphOutput(const KernelGraphPtr &new_kernel_graph,
|
|||
auto make_tuple_primitve = NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()));
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {make_tuple_primitve};
|
||||
int output_idx = 0;
|
||||
MS_EXCEPTION_IF_NULL(new_kernel_graph);
|
||||
for (auto &anf_node : list) {
|
||||
if (AnfAlgo::CheckPrimitiveType(anf_node, prim::kPrimReturn)) {
|
||||
new_kernel_graph->set_return(anf_node);
|
||||
}
|
||||
if (has_output_nodes.find(anf_node) == has_output_nodes.end()) {
|
||||
MS_LOG(INFO) << "output[" << output_idx++ << "]:" << anf_node->DebugString();
|
||||
MS_LOG(INFO) << "Output[" << output_idx++ << "]:" << anf_node->DebugString();
|
||||
make_tuple_inputs.push_back(anf_node);
|
||||
}
|
||||
}
|
||||
|
@ -1458,6 +1471,7 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|||
std::vector<AnfNodePtr> new_graph_inputs;
|
||||
// create new parameter from cnode
|
||||
for (auto &anf_node : list) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
|
||||
auto input = cnode->inputs()[input_idx];
|
||||
|
@ -1536,12 +1550,12 @@ AnfNodePtr AscendSession::BindNewCallToNewGraph(NotNull<KernelGraphPtr> graph,
|
|||
auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list);
|
||||
std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input));
|
||||
auto new_call = graph->NewCNode(new_call_input);
|
||||
AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call);
|
||||
AnfAlgo::SetNodeAttr("graph_id", MakeValue(graph->graph_id()), new_call);
|
||||
return new_call;
|
||||
}
|
||||
|
||||
void AscendSession::SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims) {
|
||||
MS_LOG(INFO) << "start,graph_id:" << graph->graph_id();
|
||||
MS_LOG(INFO) << "Start,graph_id:" << graph->graph_id();
|
||||
auto apply_list = GetCNodes(TopoSort(graph->get_return()));
|
||||
// update the root graph child graph order
|
||||
AscendControlParser::UpdateChildGraphOrder(graph);
|
||||
|
|
|
@ -780,12 +780,12 @@ void SessionBasic::Summary(KernelGraph *graph) {
|
|||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
GetSummaryNodes(graph);
|
||||
auto summary_outputs = graph->summary_nodes();
|
||||
// do not exist summary node
|
||||
if (summary_outputs.empty()) {
|
||||
bool exist_summary = graph->summary_node_exist();
|
||||
if (!exist_summary) {
|
||||
return;
|
||||
}
|
||||
GetSummaryNodes(graph);
|
||||
auto summary_outputs = graph->summary_nodes();
|
||||
std::map<std::string, tensor::TensorPtr> params_list;
|
||||
// fetch outputs apply kernel in session & run callback functions
|
||||
for (auto &output_item : summary_outputs) {
|
||||
|
|
|
@ -229,6 +229,7 @@ const int kValueNodeTensorMask = 2;
|
|||
|
||||
// define special index in special node
|
||||
constexpr auto kAnfPrimitiveIndex = 0;
|
||||
constexpr auto kFirstDataInputIndex = 1;
|
||||
constexpr auto kAnfPartialFuncGraphIndex = 1;
|
||||
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
|
||||
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
|
||||
|
|
Loading…
Reference in New Issue