forked from mindspore-Ecosystem/mindspore
!14322 add graph compiler header file
From: @zyli2020 Reviewed-by: @limingqi107,@limingqi107,@cristoval Signed-off-by: @cristoval
This commit is contained in:
commit
3689f37fb1
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
class GraphCompiler {
|
||||
public:
|
||||
static GraphCompiler &GetInstance() {
|
||||
static GraphCompiler instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Set device context which is initialized, the function must be called
|
||||
// before using GraphCompiler and after changing device type or device id.
|
||||
void set_device_context(device::DeviceContext *device_context);
|
||||
|
||||
// Construct kernel graph from anf nodes list and compile kernel graph in Graph mode,
|
||||
// the detailed implementation of compiling graph is in 'CompileGraphImpl'.
|
||||
GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs);
|
||||
|
||||
// Run a graph and get the output in Graph mode.
|
||||
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
|
||||
// Construct single op kernel graph, compile and run the kernel graph in PyNative mode.
|
||||
void CompileAndRunGraph(OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask,
|
||||
VectorRef *outputs);
|
||||
|
||||
private:
|
||||
GraphCompiler() = default;
|
||||
~GraphCompiler() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(GraphCompiler);
|
||||
|
||||
// The implementation of compiling graph in Graph Mode, including optimizing graph,
|
||||
// setting operator info, creating kernel and transforming kernel graph to ActorSet.
|
||||
GraphId CompileGraphImpl(const KernelGraphPtr &graph);
|
||||
|
||||
device::DeviceContext *device_context_{nullptr};
|
||||
|
||||
// Single op kernel graph cache for PyNative mode.
|
||||
std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
|
||||
|
||||
// The member variable 'session_' will be removed after removing session module.
|
||||
session::SessionPtr session_{nullptr};
|
||||
};
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_GRAPH_COMPILER_H_
|
Loading…
Reference in New Issue