diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index 6a12c810fd7..e4128f16c64 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -691,22 +691,27 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
 }
 
 void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                              const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
+                              std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                              const std::vector<int64_t> &tensors_mask) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
+  EraseValueNodeTensor(tensors_mask, input_tensors);
+
   auto graph = run_op_graphs_[graph_info];
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
   // malloc mem
-  RunOpMemoryAlloc(input_tensors, graph.get());
+  RunOpMemoryAlloc(*input_tensors, graph.get());
   // Build dynamic kernel
   if (op_run_info.is_dynamic_shape) {
     BuildDynamicKernel(graph);
   }
   // load input data to device
-  LoadInputData(graph, input_tensors);
+  LoadInputData(graph, *input_tensors);
   // run op
   Execute(graph, false);
   // get output
-  UpdateOutputs(graph, outputs, input_tensors);
+  UpdateOutputs(graph, outputs, *input_tensors);
   RunOpMemoryClear(graph.get());
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
 }
@@ -736,7 +741,8 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector
     // Build and run current single op
     BuildOpImpl(run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
     VectorRef op_outputs;
-    RunOpImpl(run_info, graph_info, input_tensor_info.input_tensors, &op_outputs);
+    RunOpImpl(run_info, graph_info, &input_tensor_info.input_tensors, &op_outputs,
+              input_tensor_info.input_tensors_mask);
 
     // Handle inputs and outputs of current op
     HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h
index e46a98a6c69..4c94beba98f 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.h
+++ b/mindspore/ccsrc/backend/session/ascend_session.h
@@ -60,7 +60,8 @@ class AscendSession : public SessionBasic {
                    const std::vector<tensor::TensorPtr> &input_tensors,
                    const std::vector<int64_t> &tensors_mask) override;
   void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                 const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
+                 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                 const std::vector<int64_t> &tensors_mask) override;
   void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                          VectorRef *outputs) override;
 
diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc
index 7477cee6f0c..18e5d418933 100644
--- a/mindspore/ccsrc/backend/session/executor.cc
+++ b/mindspore/ccsrc/backend/session/executor.cc
@@ -125,14 +125,9 @@ void RunGraphTask::Run() {
   ExecutorManager::Instance().OnRunGraphFinished();
 }
 
-void BuildOpTask::Run() {
-  MS_EXCEPTION_IF_NULL(session_);
-  session_->BuildOpImpl(*op_run_info_, graph_info_, input_tensors_, tensors_mask_);
-}
-
 void RunOpTask::Run() {
   MS_EXCEPTION_IF_NULL(session_);
-  session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
+  session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_);
 }
 
 void RunOpsInGraphTask::Run() {
@@ -340,25 +335,16 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
   task_cond_var_.notify_all();
 }
 
-void Executor::BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                       const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask) {
-  auto task = std::make_shared<BuildOpTask>();
-  task->session_ = session;
-  task->op_run_info_ = op_run_info;
-  task->graph_info_ = graph_info;
-  task->input_tensors_ = input_tensors;
-  task->tensors_mask_ = tensors_mask;
-  SyncRunTask(task);
-}
-
 void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                     const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
+                     std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                     const std::vector<int64_t> &tensors_mask) {
   auto task = std::make_shared<RunOpTask>();
   task->session_ = session;
   task->op_run_info_ = op_run_info;
   task->graph_info_ = graph_info;
   task->input_tensors_ = input_tensors;
-  for (auto &tensor : input_tensors) {
+  task->tensors_mask_ = tensors_mask;
+  for (auto &tensor : *input_tensors) {
     if (tensor->NeedWait()) {
       tensor->Wait();
     }
diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h
index cdd66f50c38..0f006625d62 100644
--- a/mindspore/ccsrc/backend/session/executor.h
+++ b/mindspore/ccsrc/backend/session/executor.h
@@ -110,17 +110,6 @@ class RunOpsInGraphTask : public Task {
   GraphId graph_id_{0};
 };
 
-class BuildOpTask : public Task {
- public:
-  BuildOpTask() { type_ = kBuildOp; }
-  ~BuildOpTask() override = default;
-  void Run() override;
-  OpRunInfo *op_run_info_{nullptr};
-  GraphInfo graph_info_;
-  std::vector<tensor::TensorPtr> input_tensors_;
-  std::vector<int64_t> tensors_mask_;
-};
-
 class RunOpTask : public Task {
  public:
   RunOpTask() { type_ = kRunOp; }
@@ -128,8 +117,9 @@ class RunOpTask : public Task {
   void Run() override;
   OpRunInfo *op_run_info_{nullptr};
   GraphInfo graph_info_;
-  std::vector<tensor::TensorPtr> input_tensors_;
+  std::vector<tensor::TensorPtr> *input_tensors_;
   VectorRef outputs_;
+  std::vector<int64_t> tensors_mask_;
 };
 
 class CreateCommGroupTask : public Task {
@@ -170,10 +160,9 @@ class Executor {
                 VectorRef *outputs);
   void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                      VectorRef *outputs);
-  void BuildOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-               const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
   void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-             const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
+             std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+             const std::vector<int64_t> &tensors_mask);
   void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                      VectorRef *outputs);
   void OnRunGraphFinished();
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index 5313ba338d1..cfd2238347a 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -398,17 +398,22 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
 }
 
 void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                           const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
+                           std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                           const std::vector<int64_t> &tensors_mask) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
+  EraseValueNodeTensor(tensors_mask, input_tensors);
+
   auto kernel_graph = run_op_graphs_[graph_info];
   MS_EXCEPTION_IF_NULL(kernel_graph);
   // Remove NopOp from execution graph
   opt::RemoveNopNode(kernel_graph.get());
-  RunOpAllocateMemory(input_tensors, kernel_graph.get());
+  RunOpAllocateMemory(*input_tensors, kernel_graph.get());
   // Execute the computation
-  LoadInputData(kernel_graph, input_tensors);
+  LoadInputData(kernel_graph, *input_tensors);
   Execute(kernel_graph);
   // Fetch outputs
-  UpdateOutputs(kernel_graph, outputs, input_tensors);
+  UpdateOutputs(kernel_graph, outputs, *input_tensors);
   RunOpClearMemory(kernel_graph.get());
 }
 
diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h
index 3ac31ccbd01..67033f00b9a 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.h
+++ b/mindspore/ccsrc/backend/session/gpu_session.h
@@ -40,7 +40,8 @@ class GPUSession : public SessionBasic {
                    const std::vector<tensor::TensorPtr> &input_tensors,
                    const std::vector<int64_t> &tensors_mask) override;
   void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                 const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
+                 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                 const std::vector<int64_t> &tensors_mask) override;
 
  private:
  void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 6a5f4afdf37..7143a1ef1ba 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -1593,17 +1593,11 @@ void SessionBasic::BuildGraph(GraphId graph_id) {
   executor_->BuildGraph(shared_from_this(), graph_id);
 }
 
-void SessionBasic::BuildOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                           const std::vector<tensor::TensorPtr> &input_tensors,
-                           const std::vector<int64_t> &tensors_mask) {
-  MS_EXCEPTION_IF_NULL(executor_);
-  executor_->BuildOp(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
-}
-
 void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                         const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
+                         std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                         const std::vector<int64_t> &tensors_mask) {
   MS_EXCEPTION_IF_NULL(executor_);
-  executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
+  executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs, tensors_mask);
 }
 
 void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
@@ -1623,6 +1617,22 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector
   executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
 }
 
+void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
+                                        std::vector<tensor::TensorPtr> *input_tensors) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  if (input_tensors->size() != tensors_mask.size()) {
+    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
+                      << tensors_mask.size();
+  }
+  std::vector<tensor::TensorPtr> new_input_tensors;
+  for (size_t index = 0; index < tensors_mask.size(); ++index) {
+    if (tensors_mask[index] != kValueNodeTensorMask) {
+      new_input_tensors.emplace_back(input_tensors->at(index));
+    }
+  }
+  *input_tensors = new_input_tensors;
+}
+
 void SessionBasic::UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs) {
   bool is_dynamic = false;
   for (const auto &graph : all_graphs) {
diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h
index 2037e971bbc..620016d6116 100644
--- a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -76,9 +76,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void BuildGraph(GraphId graphId);
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
   void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
-  void BuildOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
-               const std::vector<int64_t> &tensors_mask);
-  void RunOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
+  void RunOp(OpRunInfo *, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+             const std::vector<int64_t> &tensors_mask);
   void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
 
   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
@@ -137,7 +136,6 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   friend class CompileGraphTask;
   friend class BuildGraphTask;
   friend class RunGraphTask;
-  friend class BuildOpTask;
   friend class RunOpTask;
   friend class RunOpsInGraphTask;
   virtual bool IsSupportSummary() { return true; }
@@ -156,7 +154,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                            const std::vector<tensor::TensorPtr> &input_tensors,
                            const std::vector<int64_t> &tensors_mask) {}
   virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                         const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {}
+                         std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                         const std::vector<int64_t> &tensors_mask) {}
   virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                  VectorRef *outputs) {}
   void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
@@ -165,6 +164,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                              const std::vector<tensor::TensorPtr> &inputs_const) const;
+  void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors);
   void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                      const std::vector<tensor::TensorPtr> &input_tensors) const;
   void Reorder(std::vector<CNodePtr> *node_list);
 
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index a7b3b00f458..9013192b1fd 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -471,21 +471,6 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector
   op_prim->EndRecordAddAttr();
 }
 
-void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
-  MS_EXCEPTION_IF_NULL(input_tensors);
-  if (input_tensors->size() != tensors_mask.size()) {
-    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
-                      << tensors_mask.size();
-  }
-  std::vector<tensor::TensorPtr> new_input_tensors;
-  for (size_t index = 0; index < tensors_mask.size(); ++index) {
-    if (tensors_mask[index] != kValueNodeTensorMask) {
-      new_input_tensors.emplace_back(input_tensors->at(index));
-    }
-  }
-  *input_tensors = new_input_tensors;
-}
-
 BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
   if (utils::isa<VectorRef>(base_ref)) {
     auto ref_list = utils::cast<VectorRef>(base_ref);
@@ -1301,10 +1286,8 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
                            op_exec_info->is_mixed_precision_cast,
                            op_exec_info->next_op_name,
                            op_exec_info->next_input_index};
-  session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask);
-  EraseValueNodeTensor(tensors_mask, &input_tensors);
  VectorRef outputs;
-  session->RunOp(&op_run_info, graph_info, input_tensors, &outputs);
+  session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask);
   auto result = BaseRefToPyData(outputs);
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
   *status = PYNATIVE_SUCCESS;
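
Note on the behavioural change: callers no longer invoke BuildOp and a free EraseValueNodeTensor helper before RunOp; each backend's RunOpImpl now builds the single-op graph from the full input list and then strips the inputs whose mask entry marks them as value nodes, via the new SessionBasic::EraseValueNodeTensor shown in the patch. The snippet below is only a minimal, standalone sketch of that filtering step, written with plain standard-library types instead of MindSpore's tensor::TensorPtr; the constant kValueNodeMask and the function/variable names in it are illustrative stand-ins (the concrete value of kValueNodeTensorMask is not shown in the diff), not MindSpore API.

// Minimal sketch of the value-node filtering idea, independent of MindSpore types.
// Assumption: a mask entry equal to kValueNodeMask marks an input that was folded
// into the graph as a value node and must not be passed as a runtime input.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

constexpr int64_t kValueNodeMask = 1;  // stand-in for kValueNodeTensorMask

// Drop every input whose mask entry says "value node", keeping the rest in order.
void EraseValueNodeInputs(const std::vector<int64_t> &mask, std::vector<std::string> *inputs) {
  assert(inputs != nullptr && inputs->size() == mask.size());
  std::vector<std::string> kept;
  for (size_t i = 0; i < mask.size(); ++i) {
    if (mask[i] != kValueNodeMask) {
      kept.emplace_back((*inputs)[i]);
    }
  }
  *inputs = kept;  // the caller's vector now matches what the built graph expects as runtime inputs
}

int main() {
  std::vector<std::string> inputs = {"x", "axis_const", "y"};
  std::vector<int64_t> mask = {0, kValueNodeMask, 0};
  EraseValueNodeInputs(mask, &inputs);
  assert(inputs.size() == 2 && inputs[0] == "x" && inputs[1] == "y");
  return 0;
}

Moving the filter into SessionBasic and doing it inside RunOpImpl (after BuildOpImpl has consumed the unfiltered list) keeps the Ascend and GPU sessions consistent and removes the duplicated helper that previously lived in pynative_execute.cc.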