From d1a0ded6c2f5c9077464ff67e6d3231680cba00f Mon Sep 17 00:00:00 2001 From: chenfei Date: Fri, 27 Mar 2020 14:49:16 +0800 Subject: [PATCH] use first depend create parameter --- mindspore/ccsrc/operator/ops.cc | 3 + mindspore/ccsrc/operator/ops.h | 3 + mindspore/ccsrc/session/ascend_session.cc | 59 +++++++++++---- mindspore/ccsrc/session/ascend_session.h | 6 +- mindspore/ccsrc/session/kernel_graph.cc | 3 +- mindspore/ccsrc/session/kernel_graph.h | 5 ++ mindspore/ccsrc/session/session_basic.cc | 90 ++++++++++++++++------- mindspore/ccsrc/session/session_basic.h | 5 +- mindspore/ccsrc/vm/backend.cc | 2 +- 9 files changed, 128 insertions(+), 48 deletions(-) mode change 100644 => 100755 mindspore/ccsrc/operator/ops.cc mode change 100644 => 100755 mindspore/ccsrc/operator/ops.h mode change 100644 => 100755 mindspore/ccsrc/session/ascend_session.cc mode change 100644 => 100755 mindspore/ccsrc/session/ascend_session.h mode change 100644 => 100755 mindspore/ccsrc/session/kernel_graph.cc mode change 100644 => 100755 mindspore/ccsrc/session/kernel_graph.h mode change 100644 => 100755 mindspore/ccsrc/session/session_basic.cc mode change 100644 => 100755 mindspore/ccsrc/session/session_basic.h mode change 100644 => 100755 mindspore/ccsrc/vm/backend.cc diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc old mode 100644 new mode 100755 index 12e6b70a6f7..f3053cac7d5 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -154,6 +154,9 @@ const PrimitivePtr kPrimMul = std::make_shared("Mul"); const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); const PrimitivePtr kPrimSquare = std::make_shared("Square"); +const PrimitivePtr kPrimEqual = std::make_shared("Equal"); +const PrimitivePtr kPrimLess = std::make_shared("Less"); +const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); // NN const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h old mode 100644 new mode 100755 index 5fbf2b70679..2dc7072972f --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -160,6 +160,9 @@ extern const PrimitivePtr kPrimMul; extern const PrimitivePtr kPrimMinimum; extern const PrimitivePtr kPrimMaximum; extern const PrimitivePtr kPrimSquare; +extern const PrimitivePtr kPrimEqual; +extern const PrimitivePtr kPrimLess; +extern const PrimitivePtr kPrimLessEqual; // NN extern const PrimitivePtr kPrimFlatten; diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc old mode 100644 new mode 100755 index 34c05aed088..f255b2f15fd --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -506,11 +506,13 @@ void AscendSession::InsertSwitchToGraph(GraphId condition_graph_id, GraphId true kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); - // condition graph's output must be single output - if (condition_graph->outputs().size() != 1) { - MS_LOG(EXCEPTION) << "Condition_graph output num " << condition_graph_id << " should be 1"; + auto cond_output_it = condition_output_.find(condition_graph_id); + if (cond_output_it == condition_output_.end()) { + MS_LOG(EXCEPTION) << "Can't find condition graph" << condition_graph_id; } - AnfNodePtr cond_output_kernel = condition_graph->outputs()[0]; + auto cond_output_kernel = + AnfAlgo::VisitKernel(condition_graph->GetBackendAnfByFrontAnf(cond_output_it->second), 0).first; + MS_EXCEPTION_IF_NULL(cond_output_kernel); std::vector inputs = {NewValueNode(switch_primitive), cond_output_kernel, counter_const}; CNodePtr switch_node = condition_graph->NewCNode(inputs); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), switch_node.get()); @@ -569,12 +571,14 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { } } -void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id) { +void AscendSession::SwitchCompile(GraphId cond_graph_id, GraphId true_graph_id, GraphId false_graph_id, + const AnfNodePtr &output) { if (switches_.find(cond_graph_id) != switches_.end()) { MS_LOG(WARNING) << "Condition graph" << cond_graph_id << " has been set before "; return; } switches_[cond_graph_id] = std::pair(true_graph_id, false_graph_id); + condition_output_[cond_graph_id] = output; MS_LOG(INFO) << "New switch compile " << cond_graph_id << " " << true_graph_id << " " << false_graph_id; // set the type of condition graph auto cond_graph_index = ExecOrderOfChildGraph(final_graph_id_, cond_graph_id); @@ -682,12 +686,14 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An auto from_graph_id = GetGraphIdByNode(front_anf); auto from_graph = GetGraph(from_graph_id); MS_EXCEPTION_IF_NULL(from_graph); - + auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get()); + auto to_graph = GetGraph(to_graph_id); + auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); + MS_EXCEPTION_IF_NULL(to_graph); MS_LOG(INFO) << "Set node[" << front_anf->DebugString() << "] of graph[" << from_graph_id << "]to node[" << backend_parameter->DebugString() << "] of graph[" << AnfAlgo::GetGraphId(backend_parameter.get()) << "]"; // a node should not assign to itself - auto backend_arg = from_graph->GetBackendAnfByFrontAnf(front_anf); if (backend_arg.get() == backend_parameter.get()) { return; } @@ -703,15 +709,16 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, const An return; } } - InsertMultipleAssignToGraph(from_graph_id, backend_arg, backend_parameter); - // if front anf is a parameter, we can assign the value back, because backend_parameter - // won't be changed in it's graph unless it's a weight. If backend_parameter is a weight, - // we do should assign the value back. - auto to_graph_id = AnfAlgo::GetGraphId(backend_parameter.get()); - auto to_graph = GetGraph(to_graph_id); - MS_EXCEPTION_IF_NULL(to_graph); + // if a parameter is a weight and not linked to any executable node,device type will be kTypeUnknown,set it's device + // type same to arg + if (AnfAlgo::GetOutputDeviceDataType(backend_parameter, 0) == kTypeUnknown) { + AnfAlgo::SetSelectKernelBuildInfo(AnfAlgo::GetSelectKernelBuildInfo(backend_arg), backend_parameter.get()); + } + InsertAssignToGraph(from_graph_id, backend_arg, backend_parameter); + // if front anf is a parameter,we can assign the value back,because backend_parameter won't be change in it's graph + // unless it's a weigth.If backend_parameter is a weight,we do should assign the value back if (backend_arg->isa() && !to_graph->execution_order().empty()) { - InsertMultipleAssignToGraph(to_graph_id, backend_parameter, backend_arg); + InsertAssignToGraph(to_graph_id, backend_parameter, backend_arg); } MS_LOG(INFO) << "Finish!"; } @@ -755,7 +762,25 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { DumpGraphInputArgs(args); UpdateGraphOrder(g); std::vector graph_inputs = to_graph->inputs(); + auto valid_inputs = to_graph->ValidInputs(); + size_t real_args_size = 0; + for (size_t i = 0; i < args.size(); i++) { + real_args_size += AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem}).size(); + } + if (real_args_size != graph_inputs.size()) { + for (size_t j = 0; j < valid_inputs.size(); j++) { + if (valid_inputs[j]) { + MS_LOG(INFO) << "index: " << j << ", nodes: " << graph_inputs[j]->DebugString(); + } + } + MS_LOG(WARNING) << "real_args_size: " << real_args_size << ", graph_inputs.size(): " << graph_inputs.size() + << " not equal"; + } size_t input_index = 0; + if (graph_inputs.size() != valid_inputs.size()) { + MS_LOG(EXCEPTION) << "graph_inputs.size(): " << graph_inputs.size() + << ", valid_inputs.size(): " << valid_inputs.size() << " not equal"; + } for (size_t i = 0; i < args.size(); i++) { if (input_index >= graph_inputs.size()) { MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); @@ -763,6 +788,10 @@ void AscendSession::SetChildGraphInput(GraphId g, const VectorRef &args) { if (utils::isa(args[i])) { // arg is a anf node for (const auto &real_arg : AnfAlgo::GetAllOutput(utils::cast(args[i]), {prim::kPrimTupleGetItem})) { + if (!valid_inputs[input_index]) { + MS_LOG(DEBUG) << "Invalid input arg" << real_arg->DebugString(); + continue; + } SetChildGraphParameter(real_arg, graph_inputs[input_index]); input_index++; } diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h old mode 100644 new mode 100755 index caec4b35f76..c45ab6630a4 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -49,9 +49,8 @@ class AscendSession : public SessionBasic { // set output of final graph void SetFinalGraphOutput(const BaseRef &output) override; // insert switch and set the relative active ops - void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g) override; - // set args of child graph. the arg maybe come from a output of other child graphs, - // or from final graph's parameter + void SwitchCompile(GraphId cond_g, GraphId true_g, GraphId false_g, const AnfNodePtr &condition_output) override; + // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter void SetChildGraphInput(GraphId g, const VectorRef &args) override; // get graph id in child graphs by ME front anf node pointer GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override; @@ -116,6 +115,7 @@ class AscendSession : public SessionBasic { std::unordered_map while_condition_graphs_; // record all conditions std::unordered_map> switches_; + std::unordered_map condition_output_; // final_graph_id is used in every root graph has it's own session situation GraphId final_graph_id_; }; diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc old mode 100644 new mode 100755 index 84ff6b81a27..dbf6e07e7e1 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -372,8 +372,7 @@ void KernelGraph::UpdateControlDependRelations(const std::vector &de MS_EXCEPTION_IF_NULL(depend_node); std::vector prior_nodes = {prior_node}; std::vector depend_nodes = {depend_node}; - MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "],depend node[" << depend_node->DebugString() - << "],depend_mode=[" << AnfAlgo::GetNodeAttr(cnode, "depend_mode") << "]"; + MS_LOG(INFO) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString(); if (prior_node->isa()) { prior_nodes = GetOutputNodes(prior_node); } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h old mode 100644 new mode 100755 index e11f6807f54..ff964482bba --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -86,6 +86,9 @@ class KernelGraph : public FuncGraph { bool executable() const { return executable_; } // set executable of graph void set_executable(bool executable) { executable_ = executable; } + // set invalid inputs for control sink + std::vector *MutableValidInputs() { return &valid_inputs_; } + std::vector ValidInputs() { return valid_inputs_; } private: // remove value node form graph @@ -118,6 +121,8 @@ class KernelGraph : public FuncGraph { std::unordered_map>> node_output_edges_; // graph needn't execute bool executable_; + // valid inputs + std::vector valid_inputs_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc old mode 100644 new mode 100755 index ede3ae7419a..d2a255229d7 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -243,29 +243,38 @@ ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) { return new_value_node; } -ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) { +ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(anf); if (!anf->isa()) { MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter"; } auto graph_inputs = graph->MutableInputs(); MS_EXCEPTION_IF_NULL(graph_inputs); + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); ParameterPtr new_parameter = graph->NewParameter(anf->cast()); - graph->FrontBackendlMapAdd(anf, new_parameter); graph_inputs->push_back(new_parameter); + valid_inputs->push_back(valid_input); return new_parameter; } -std::vector CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { +std::vector CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); std::vector parameters; std::vector pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); + auto valid_inputs = graph->MutableValidInputs(); + MS_EXCEPTION_IF_NULL(valid_inputs); + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); auto create_parameter = [&](const AbstractBasePtr &abstract) -> void { auto parameter = graph->NewParameter(); MS_EXCEPTION_IF_NULL(parameter); parameter->set_abstract(abstract); - parameters.push_back(graph->NewParameter(parameter)); + auto new_parameter = graph->NewParameter(parameter); + parameters.push_back(new_parameter); + valid_inputs->push_back(valid_input); + graph_inputs->push_back(new_parameter); }; for (const auto &out_node : pre_graph_out) { MS_EXCEPTION_IF_NULL(out_node); @@ -287,18 +296,15 @@ std::vector CreateParameterFromTuple(const AnfNodePtr &node, KernelG return parameters; } -AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) { +AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) { MS_EXCEPTION_IF_NULL(anf); if (!anf->isa()) { - MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a cnode"; + MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode"; } - MS_LOG(INFO) << "create a new parameter from cnode[" << anf->DebugString() << "]"; - auto parameters = CreateParameterFromTuple(anf, graph); - auto graph_inputs = graph->MutableInputs(); - MS_EXCEPTION_IF_NULL(graph_inputs); - (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(*graph_inputs)); + MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]"; + auto parameters = CreateParameterFromTuple(anf, valid_input, graph); if (parameters.empty()) { - MS_LOG(EXCEPTION) << "no parameter exist!!"; + MS_LOG(EXCEPTION) << "No parameter exist!!"; } if (parameters.size() == 1) { return parameters[0]; @@ -307,7 +313,7 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input)); auto make_tuple = graph->NewCNode(make_tuple_input); MS_EXCEPTION_IF_NULL(make_tuple); - MS_LOG(INFO) << "new make tuple [" << make_tuple->DebugString() << "] of parameters"; + MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters"; return make_tuple; } @@ -397,14 +403,20 @@ void DumpGraphOutput(const Any &any, size_t recurse_level = 0) { GraphId SessionBasic::graph_sum_ = 0; -CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) { +CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, + bool *from_other_graph, + std::unordered_map *other_graph_cnode) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(from_other_graph); + MS_EXCEPTION_IF_NULL(other_graph_cnode); + *from_other_graph = false; // get primitive of old node auto prim = AnfAlgo::GetCNodePrimitive(cnode); MS_EXCEPTION_IF_NULL(prim); // push attr to inputs[0] of new cnode std::vector cnode_inputs = {std::make_shared(std::make_shared(*prim))}; + // if has multiple depends,only select first depend as parameter for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { auto anf = cnode->inputs()[input_idx]; MS_EXCEPTION_IF_NULL(anf); @@ -412,6 +424,9 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; + } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) { + cnode_inputs.push_back((*other_graph_cnode)[anf]); + continue; } else if (anf->isa() && !IsValueNode(anf)) { // if input is a value node, auto new_value_node = CreateNewValueNode(anf, graph); @@ -421,38 +436,60 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) continue; } else if (anf->isa()) { // if anf is a parameter - cnode_inputs.emplace_back(CreateNewParameterFromParameter(anf, graph)); + auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); + cnode_inputs.push_back(new_parameter); + if (GetGraphIdByNode(anf) == kInvalidGraphId) { + graph->FrontBackendlMapAdd(anf, new_parameter); + } else { + (*other_graph_cnode)[anf] = new_parameter; + } continue; } else if (anf->isa()) { + *from_other_graph = true; // the input node is a cnode from other graph - cnode_inputs.emplace_back(CreateNewParameterFromCNode(anf, graph)); + auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph); + cnode_inputs.push_back(parameter_from_cnode); + (*other_graph_cnode)[anf] = parameter_from_cnode; continue; } - MS_LOG(EXCEPTION) << "unexpected input[" << anf->DebugString() << "]"; + MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } - return graph->NewCNode(cnode_inputs); + TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); + auto new_cnode = graph->NewCNode(cnode_inputs); + TraceManager::EndTrace(); + return new_cnode; } KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { + std::unordered_map other_graph_cnode; auto graph = std::make_shared(); graph->set_graph_id(graph_sum_); + MS_LOG(INFO) << "Create graph: " << graph_sum_; + size_t from_other_graph_depend_num = 0; for (const auto &node : lst) { MS_EXCEPTION_IF_NULL(node); - MS_LOG(DEBUG) << "start create new cnode,node = " << node->DebugString(); + MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); if (!node->isa()) { - MS_LOG(EXCEPTION) << "Inst node " << node->DebugString() << " is not CNode"; + MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode"; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - TraceManager::DebugTrace(std::make_shared(cnode->debug_info())); // create a new cnode object - auto new_cnode = CreateNewCNode(cnode, graph.get()); + bool from_other_graph = false; + // only first depend from other graph can create + bool valid_input = true; + if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { + valid_input = false; + } + auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) { + from_other_graph_depend_num++; + } MS_EXCEPTION_IF_NULL(new_cnode); new_cnode->set_abstract(cnode->abstract()); new_cnode->set_scope(cnode->scope()); // record map relations between anf from ME and new anf node used in backend graph->FrontBackendlMapAdd(node, new_cnode); - TraceManager::EndTrace(); } // add a make_tuple at the end of graph as output graph->set_output(ConstructOutput(outputs, graph)); @@ -631,12 +668,15 @@ void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector &graph) { MS_EXCEPTION_IF_NULL(graph); std::vector output_args; - auto FindEqu = [graph](const AnfNodePtr &out) -> AnfNodePtr { + auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { auto backend_anf = graph->GetBackendAnfByFrontAnf(out); if (backend_anf != nullptr) { return backend_anf; } - MS_LOG(EXCEPTION) << "Can not find the node in the equiv map!"; + for (const auto &output : outputs) { + MS_LOG(INFO) << "output:" << output->DebugString(); + } + MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; }; output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args), diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h old mode 100644 new mode 100755 index 9aadb78cb27..f1872e375cb --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -69,14 +69,15 @@ class SessionBasic { std::shared_ptr ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); - CNodePtr CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph); + CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph, + std::unordered_map *other_graph_cnode); // set parameters of final graph virtual GraphId SetFinalGraphInput(const std::vector &) { return kInvalidGraphId; } // set output of final graph virtual void SetFinalGraphOutput(const BaseRef &) {} // insert switch and set the relative active ops - virtual void SwitchCompile(GraphId, GraphId, GraphId) {} + virtual void SwitchCompile(GraphId, GraphId, GraphId, const AnfNodePtr &) {} // set args of child graph.the arg maybe come from a output of other child graphs,or from final graph's parameter virtual void SetChildGraphInput(GraphId, const VectorRef &) {} // get graph id in child graphs by ME front anf node pointer diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc old mode 100644 new mode 100755 index 9355cca99cb..e69d25d2dcb --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -136,7 +136,7 @@ void MsBackend::SetSwitchGraph() { MS_LOG(EXCEPTION) << "cond not a anf node:" << curr_switch_.ToString(); } MS_LOG(DEBUG) << "switch compile:" << cond_g << ", " << true_g << ", " << false_g; - sess_->SwitchCompile(cond_g, true_g, false_g); + sess_->SwitchCompile(cond_g, true_g, false_g, utils::cast(curr_switch_)); } is_switch_call_ = false; MS_LOG(DEBUG) << "end SetSwitchGraph:" << curr_cond << ", " << is_switch_call_;