diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index 596cf6790dd..1bd384ff133 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -375,18 +375,16 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( return assign_add_one; } -bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &context, - const std::shared_ptr &kernel_graph_ptr) { +bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { if (!NeedInsertSwitch()) { return true; } - MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(kernel_graph_ptr); auto input_nodes = kernel_graph_ptr->inputs(); std::vector inputs; LoadSwitchInputs(&inputs); std::shared_ptr> inputsPtr = std::make_shared>(inputs); - context->SetResult(session::kInputCtrlTensors, inputsPtr); + kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr); size_t input_ctrl_size = inputs.size(); // inputs_node:include four ctrl nodes in the back. such as:conv,loop_cnt, ites_loop, zero, one. // deal four ctrl nodes. diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index 4c69641a340..33e3a2007ca 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -53,8 +53,7 @@ class KernelAdjust { } void Reorder(const std::shared_ptr &kernel_graph_ptr); void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); - bool StepLoadCtrlInputs(const std::shared_ptr &context, - const std::shared_ptr &kernel_graph_ptr); + bool StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr); void Profiling(NotNull kernel_graph_ptr); static bool NeedInsertSwitch(); CNodePtr CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr); diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 76bc0752f28..26dde71f787 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -517,7 +517,7 @@ void AscendSession::RunOpMemoryAlloc(const std::vector &input void AscendSession::GenerateTaskInfo(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "Start!"; - (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(context_, kernel_graph); + (void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); bool ret_ok = runtime_instance->GenTask(kernel_graph.get()); diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h old mode 100755 new mode 100644 index b0f27635d02..2fe9a9517b9 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -107,6 +107,12 @@ class KernelGraph : public FuncGraph { std::vector> child_graph_order() const { return child_graph_order_; } // checkout whether current graph is leaf graph bool IsLeafGraph() const; + // set input_tensors pointer of control parameter + void set_input_ctrl_tensors(const std::shared_ptr> &input_tensors_ptr) { + input_ctrl_tensors_ = input_tensors_ptr; + } + // get input_tensors pointer of control parameter + std::shared_ptr> input_ctrl_tensors() const { return input_ctrl_tensors_; } private: // remove value node form graph @@ -150,6 +156,8 @@ class KernelGraph : public FuncGraph { std::map> node_to_child_graphs_; // child graph execute order in root graph std::vector> child_graph_order_; + // input_tensors of control parameter + std::shared_ptr> input_ctrl_tensors_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index aef85b175b8..ad738621825 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -266,23 +266,12 @@ AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, return make_tuple; } -bool NeedInsertSwitch() { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && - ConfigManager::GetInstance().iter_num() > 1); -} - -size_t LoadCtrlInputTensor(const std::shared_ptr &context, std::vector *inputs) { - MS_EXCEPTION_IF_NULL(context); - if (!NeedInsertSwitch()) { - (void)context->results_.erase(kInputCtrlTensors); +size_t LoadCtrlInputTensor(const std::shared_ptr &graph, std::vector *inputs) { + MS_LOG(INFO) << "Load kInputCtrlTensors"; + auto inputs_params = graph->input_ctrl_tensors(); + if (inputs_params == nullptr) { return 0; } - MS_LOG(INFO) << "Load kInputCtrlTensors"; - auto inputs_params = - context->GetResult(kInputCtrlTensors).cast>>(); - MS_EXCEPTION_IF_NULL(inputs_params); if (inputs_params->empty()) { MS_LOG(EXCEPTION) << "Illegal empty inputs_params"; } @@ -686,11 +675,10 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap const std::vector &inputs_const) const { std::vector inputs(inputs_const); size_t input_ctrl_size = 1; - MS_EXCEPTION_IF_NULL(context_); - if (context_->HasResult(kInputCtrlTensors)) { - input_ctrl_size = LoadCtrlInputTensor(context_, &inputs); - } MS_EXCEPTION_IF_NULL(kernel_graph); + if (kernel_graph->input_ctrl_tensors()) { + input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); + } auto input_nodes = kernel_graph->inputs(); if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) { MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index 5f195d6b3a9..619f2385b4e 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -39,8 +39,7 @@ bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::ve } // namespace ascend void KernelAdjust::Reorder(const std::shared_ptr &kernel_graph_ptr) { return; } void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { return; } -bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &context, - const std::shared_ptr &kernel_graph_ptr) { +bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &kernel_graph_ptr) { return true; } bool KernelAdjust::NeedInsertSwitch() { return true; }