From 3c2264730d38ca82c5ee484dd2b135271263ffab Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Mon, 29 Mar 2021 14:38:05 +0800 Subject: [PATCH] add graph compiler --- .../ccsrc/runtime/framework/graph_compiler.h | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 mindspore/ccsrc/runtime/framework/graph_compiler.h diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.h b/mindspore/ccsrc/runtime/framework/graph_compiler.h new file mode 100644 index 00000000000..924adfaaa6c --- /dev/null +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.h @@ -0,0 +1,71 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ +#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_ + +#include +#include +#include +#include +#include "runtime/hardware/device_context.h" +#include "backend/session/session_basic.h" + +namespace mindspore { +namespace runtime { +class GraphCompiler { + public: + static GraphCompiler &GetInstance() { + static GraphCompiler instance; + return instance; + } + + // Set device context which is initialized, the function must be called + // before using GraphCompiler and after changing device type or device id. + void set_device_context(device::DeviceContext *device_context); + + // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, + // the detailed implementation of compiling graph is in 'CompileGraphImpl'. + GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs); + + // Run a graph and get the output in Graph mode. + void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); + + // Construct single op kernel graph, compile and run the kernel graph in PyNative mode. + void CompileAndRunGraph(OpRunInfo *op_run_info, const GraphInfo &graph_info, + std::vector *input_tensors, const std::vector &tensors_mask, + VectorRef *outputs); + + private: + GraphCompiler() = default; + ~GraphCompiler() = default; + DISABLE_COPY_AND_ASSIGN(GraphCompiler); + + // The implementation of compiling graph in Graph Mode, including optimizing graph, + // setting operator info, creating kernel and transforming kernel graph to ActorSet. + GraphId CompileGraphImpl(const KernelGraphPtr &graph); + + device::DeviceContext *device_context_{nullptr}; + + // Single op kernel graph cache for PyNative mode. + std::unordered_map> run_op_graphs_; + + // The member variable 'session_' will be removed after removing session module. + session::SessionPtr session_{nullptr}; +}; +} // namespace runtime +} // namespace mindspore +#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_