From 61464245965ef005a12a22c8e17016c993209681 Mon Sep 17 00:00:00 2001 From: rick_sanchez Date: Thu, 30 Apr 2020 15:55:52 +0800 Subject: [PATCH] refactor vm module for multigraph sink --- mindspore/ccsrc/session/ascend_session.cc | 73 ++++++++++++++-------- mindspore/ccsrc/session/ascend_session.h | 4 ++ mindspore/ccsrc/vm/transform.cc | 2 +- mindspore/ccsrc/vm/vm.cc | 74 ++++++++++++++++++----- mindspore/ccsrc/vm/vm.h | 5 ++ tests/st/control/test_multigraph_sink.py | 6 ++ 6 files changed, 123 insertions(+), 41 deletions(-) mode change 100755 => 100644 mindspore/ccsrc/session/ascend_session.cc diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc old mode 100755 new mode 100644 index 45f02d75483..1311cb465e7 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -564,42 +564,67 @@ AnfNodePtr AscendSession::CreateFakeOutput(GraphId fake_graph_id, const AnfNodeP return create_parameter_from_cnode(output_item_with_index.first, output_item_with_index.second); } -void AscendSession::SetFinalGraphOutput(const BaseRef &output) { - auto final_graph = GetGraph(final_graph_id_); - MS_EXCEPTION_IF_NULL(final_graph); - if (!utils::isa(output)) { - if (!utils::isa(output)) { - MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); - } - auto value_ptr = utils::cast(output); - auto value_node = NewValueNode(value_ptr); - MS_EXCEPTION_IF_NULL(value_node); - auto kernel_info = std::make_shared(); - value_node->set_kernel_info(kernel_info); - value_node->set_abstract(abstract::FromValue(value_ptr)); - final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); - final_graph->set_executable(false); - MS_LOG(INFO) << "Not anf output[" << output.ToString() << "]"; - return; - } +void AscendSession::SetFinalGraphOutput(const AnfNodePtr &node) { // get the backend anf node related to the output node of front - auto output_anf_node = utils::cast(output); - auto output_from_graph_id = GetGraphIdByNode(output_anf_node); + auto output_from_graph_id = GetGraphIdByNode(node); auto output_from_graph = GetGraph(output_from_graph_id); - MS_EXCEPTION_IF_NULL(output_anf_node); - MS_LOG(INFO) << "Set the output[" << output_anf_node->DebugString() << "] of graph[" << output_from_graph_id + MS_EXCEPTION_IF_NULL(node); + MS_LOG(INFO) << "Set the output[" << node->DebugString() << "] of graph[" << output_from_graph_id << "] to final graph"; MS_EXCEPTION_IF_NULL(output_from_graph); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); // if output is from final graph,it remarks no child graph exist if (final_graph_id_ == output_from_graph_id) { - MS_LOG(INFO) << "No child graph,output is " << output_anf_node->DebugString(); - final_graph->set_output(ConstructOutput({output_anf_node}, final_graph)); + MS_LOG(INFO) << "No child graph,output is " << node->DebugString(); + final_graph->set_output(ConstructOutput({node}, final_graph)); final_graph->set_executable(false); return; } final_graph->set_output(output_from_graph->output()); } +void AscendSession::SetFinalGraphOutput(const ValuePtr &value) { + auto value_node = NewValueNode(value); + auto kernel_info = std::make_shared(); + value_node->set_kernel_info(kernel_info); + value_node->set_abstract(abstract::FromValue(value)); + auto final_graph = GetGraph(final_graph_id_); + MS_EXCEPTION_IF_NULL(final_graph); + final_graph->set_output(final_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), value_node})); + final_graph->set_executable(false); + MS_LOG(INFO) << "Not anf output[" << value->ToString() << "]"; +} + +void AscendSession::SetFinalGraphOutput(const VectorRef &vec_output) { + for (auto &output : vec_output) { + if (utils::isa(output)) { + auto output_anf_node = utils::cast(output); + SetFinalGraphOutput(output_anf_node); + } else if (utils::isa(output)) { + auto value = utils::cast(output); + SetFinalGraphOutput(value); + } else { + MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); + } + } +} + +void AscendSession::SetFinalGraphOutput(const BaseRef &output) { + if (utils::isa(output)) { + auto output_anf_node = utils::cast(output); + SetFinalGraphOutput(output_anf_node); + } else if (utils::isa(output)) { + auto value = utils::cast(output); + SetFinalGraphOutput(value); + } else if (utils::isa(output)) { + auto vec_output = utils::cast(output); + SetFinalGraphOutput(vec_output); + } else { + MS_LOG(EXCEPTION) << "Unknown output type:" << output.ToString(); + } +} + KernelGraphPtr AscendSession::GetGraph(mindspore::GraphId graph_id) { auto it = graphs_.find(graph_id); if (it == graphs_.end()) { diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 4f695b2436f..4ab7797257a 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -88,6 +88,10 @@ class AscendSession : public SessionBasic { size_t SetChildGraphInput(const KernelGraphPtr &graph, const ValuePtr &value, size_t input_index); size_t SetChildGraphInput(const KernelGraphPtr &graph, const VectorRef &vec_args, size_t input_index); + void SetFinalGraphOutput(const AnfNodePtr &node); + void SetFinalGraphOutput(const ValuePtr &value); + void SetFinalGraphOutput(const VectorRef &vec_output); + // merge execution order list of child graphs void MergeGraphExecOrder(); // insert assion op to sync data bettween different graphs diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 9147f75fb25..e8b47a4bcdd 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -243,7 +243,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr &node) { AddInst(Instruction::kCall, args); args.clear(); - args.emplace_back(true); + args.emplace_back(node->input(1)); AddInst(Instruction::kSwitchReturn, args); args.clear(); diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index cf52aafdfef..3a34eba186f 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -141,17 +141,31 @@ void FinalVM::Popsp() { } } +void FinalVM::PushStatus(bool is_switch_call) { ret_status_.push(is_switch_call); } + +bool FinalVM::PopStatus() { + if (ret_status_.empty()) { + return false; + } + bool status = ret_status_.top(); + ret_status_.pop(); + return status; +} + void FinalVM::DoJmp(const BaseRef &jmp_orig) { MS_LOG(DEBUG) << "Start"; BaseRef jmp = jmp_orig; if (backend_->simu_flag()) { + bool is_switch_call = false; if (utils::isa(jmp)) { // need to inherit from Base MS_LOG(DEBUG) << "Start jump StructSwitch"; auto simu_value = utils::cast>(jmp); jmp = simu_value->fn_; backend_->set_curr_switch(simu_value->value_); + is_switch_call = true; } + PushStatus(is_switch_call); } if (utils::isa(jmp)) { // need to inherit from Base @@ -255,6 +269,13 @@ void FinalVM::InstSwitchReturn(const VectorRef &args) { MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << "."; return; } + + auto rv = Ref(-1); + if (utils::isa(rv) || utils::isa(rv)) { + auto &c = args[0]; + cond_out_[c] = rv; + } + Pop(1); Popsp(); } @@ -272,8 +293,20 @@ void FinalVM::InstReturn(const VectorRef &args) { int height = utils::cast(args[1]); auto rv = Ref(rpos); - if (backend_->simu_flag() && backend_->is_switch_call()) { - backend_->SetSwitchGraph(); + if (backend_->simu_flag()) { + auto c = backend_->curr_switch(); + auto status = PopStatus(); + if (status) { + auto iter = cond_out_.find(c); + if (iter != cond_out_.end()) { + rv = MergeArgs(rv, iter->second); + cond_out_.erase(iter); + } + } + + if (backend_->is_switch_call()) { + backend_->SetSwitchGraph(); + } } Pop(height); @@ -383,23 +416,32 @@ void FinalVM::MergeJmpArgs(const BaseRef &jmp, const BaseRef &c) { for (size_t i = 0; i < new_args.size(); ++i) { auto &old_arg = old_args[i]; auto &new_arg = new_args[i]; - if (utils::isa(old_arg)) { - auto old_vec_ref = utils::cast(old_arg); - if (utils::isa(new_arg)) { - auto new_vec_ref = utils::cast(new_arg); - std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref)); - } - new_arg = old_vec_ref; - } else if (utils::isa(new_arg)) { - auto new_vec_ref = utils::cast(new_arg); - new_vec_ref.push_back(old_arg); - new_arg = new_vec_ref; - } else { - new_arg = VectorRef({new_arg, old_arg}); - } + new_arg = MergeArgs(old_arg, new_arg); } } +BaseRef FinalVM::MergeArgs(const BaseRef &first, const BaseRef &second) { + MS_LOG(DEBUG) << __FUNCTION__ << ": " << first.ToString() << ", " << second.ToString(); + if (utils::isa(first)) { + auto old_vec_ref = utils::cast(first); + if (utils::isa(second)) { + auto new_vec_ref = utils::cast(second); + std::copy(new_vec_ref.begin(), new_vec_ref.end(), std::back_inserter(old_vec_ref)); + } else { + old_vec_ref.push_back(second); + } + return old_vec_ref; + } + + if (utils::isa(second)) { + auto new_vec_ref = utils::cast(second); + new_vec_ref.push_back(first); + return new_vec_ref; + } + + return VectorRef({first, second}); +} + void FinalVM::InstRealSwitch(const VectorRef &args) { const size_t args_size = 3; if (args.size() != args_size) { diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index a02eced44c5..a9832ab5ea3 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -125,17 +125,22 @@ class FinalVM { void Popp(); void Pushsp(); void Popsp(); + void PushStatus(bool is_switch_call); + bool PopStatus(); void DoJmp(const BaseRef &jmp); void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); + BaseRef MergeArgs(const BaseRef &first, const BaseRef &second); private: InstSet insts_; std::deque insts_stack_; std::stack retp_; std::stack retsp_; + std::stack ret_status_; int pc_; int sp_; std::unordered_map cond_jmp_; + std::unordered_map cond_out_; BackendPtr backend_; const InstFunctionMap inst_function_map = { {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, diff --git a/tests/st/control/test_multigraph_sink.py b/tests/st/control/test_multigraph_sink.py index 2b9a1a020aa..b145fb18f61 100644 --- a/tests/st/control/test_multigraph_sink.py +++ b/tests/st/control/test_multigraph_sink.py @@ -26,6 +26,7 @@ from mindspore.ops import operations as P def setup_module(module): context.set_context(mode = context.PYNATIVE_MODE, device_target = "Ascend") + c1 = Tensor([2], mstype.int32) c2 = Tensor([14], mstype.int32) c3 = Tensor([1], mstype.int32) @@ -149,6 +150,10 @@ def test_if_by_if(): assert output == expect +@pytest.mark.level0 +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.env_onecard def test_if_in_if(): output = if_in_if(c1, c2, c3) expect = Tensor([7], mstype.int32) @@ -194,6 +199,7 @@ def test_while_by_while_in_while(): expect = Tensor([350], mstype.int32) assert output == expect + @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training