diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.cc b/mindspore/ccsrc/runtime/framework/graph_compiler.cc index eca94ea32f..add58f3f18 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.cc @@ -56,55 +56,50 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { return graph->graph_id(); } -void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector &inputs, - VectorRef *outputs) { - MS_EXCEPTION_IF_NULL(session_); - auto graph = session_->GetGraph(graph_id); - MS_EXCEPTION_IF_NULL(graph); - auto actor_set = GraphScheduler::GetInstance().Fetch(graph); - MS_EXCEPTION_IF_NULL(actor_set); - GraphScheduler::GetInstance().Run(actor_set); -} - -void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info, - std::vector *input_tensors, - const std::vector &tensors_mask, VectorRef *outputs) { +GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info, + std::vector *input_tensors, + const std::vector &tensors_mask) { // Check if the graph cache exists. - if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) { - // Prepare the graph - MS_EXCEPTION_IF_NULL(session_); - auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); + auto iter = run_op_graphs_.find(graph_info); + if (iter != run_op_graphs_.end()) { + const auto &graph = iter->second; MS_EXCEPTION_IF_NULL(graph); - - MS_EXCEPTION_IF_NULL(device_context_); - device_context_->SetOperatorInfo(graph->execution_order()); - - device_context_->OptimizeSingleOpGraph(graph); - MS_EXCEPTION_IF_NULL(session_); - session_->RunOpHideNopNode(graph); - - device_context_->CreateKernel(graph->execution_order()); - run_op_graphs_[graph_info] = graph; + return graph->graph_id(); } - - session_->EraseValueNodeTensor(tensors_mask, input_tensors); - - // wait for allreduce - for (auto &tensor : *input_tensors) { - if (tensor->NeedWaitDevice()) { - tensor->WaitDevice(); - } - } - - // run op - auto graph = run_op_graphs_[graph_info]; + // Generate kernel graph. + MS_EXCEPTION_IF_NULL(session_); + auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); MS_EXCEPTION_IF_NULL(graph); + + MS_EXCEPTION_IF_NULL(device_context_); + device_context_->SetOperatorInfo(graph->execution_order()); + + device_context_->OptimizeSingleOpGraph(graph); + MS_EXCEPTION_IF_NULL(session_); + session_->RunOpHideNopNode(graph); session_->RunOpRemoveNopNode(graph); + // Generate 'KernelMod' for kernel in graph. + device_context_->CreateKernel(graph->execution_order()); + + // Transform graph to actor DAG, contains build and link. GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep); - auto actor_set = GraphScheduler::GetInstance().Fetch(graph); - MS_EXCEPTION_IF_NULL(actor_set); - GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep); + run_op_graphs_[graph_info] = graph; + return graph->graph_id(); +} + +KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const { + MS_EXCEPTION_IF_NULL(session_); + return session_->GetGraph(graph_id); +} + +KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const { + auto iter = run_op_graphs_.find(graph_info); + if (iter == run_op_graphs_.end()) { + MS_LOG(ERROR) << "Can't find graph for: " << graph_info; + return nullptr; + } + return iter->second; } } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.h b/mindspore/ccsrc/runtime/framework/graph_compiler.h index 7fe17bd863..1db27cc0e8 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.h +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.h @@ -41,13 +41,15 @@ class GraphCompiler { // the detailed implementation of compiling graph is in 'CompileGraphImpl'. GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs); - // Run a graph and get the output in Graph mode. - void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); + // Construct single op kernel graph and compile the kernel graph in PyNative mode. + GraphId CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info, + std::vector *input_tensors, const std::vector &tensors_mask); - // Construct single op kernel graph, compile and run the kernel graph in PyNative mode. - void CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info, - std::vector *input_tensors, const std::vector &tensors_mask, - VectorRef *outputs); + // Get graph by graph id, if not exist return nullptr, used in Graph mode. + KernelGraphPtr Fetch(GraphId graph_id) const; + + // Get graph by graph info, if not exist return nullptr, used in PyNative mode. + KernelGraphPtr Fetch(const GraphInfo &graph_info) const; private: GraphCompiler() = default; diff --git a/mindspore/ccsrc/runtime/hardware/device_context_manager.cc b/mindspore/ccsrc/runtime/hardware/device_context_manager.cc index c7d850d1f8..216142b985 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context_manager.cc +++ b/mindspore/ccsrc/runtime/hardware/device_context_manager.cc @@ -34,7 +34,7 @@ void DeviceContextManager::ClearDeviceContexts() { device_contexts_.clear(); } -DeviceContext *DeviceContextManager::CreateOrGetDeviceContext(const DeviceContextKey &device_context_key) { +DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key) { std::string device_context_key_str = device_context_key.ToString(); std::lock_guard guard(lock_); diff --git a/mindspore/ccsrc/runtime/hardware/device_context_manager.h b/mindspore/ccsrc/runtime/hardware/device_context_manager.h index 916db4b311..5d78bd3042 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context_manager.h +++ b/mindspore/ccsrc/runtime/hardware/device_context_manager.h @@ -36,7 +36,7 @@ class DeviceContextManager { return instance; } void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator); - DeviceContext *CreateOrGetDeviceContext(const DeviceContextKey &device_context_key); + DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key); void ClearDeviceContexts(); private: