remove run graph method and add get graph method in graph compiler

This commit is contained in:
lizhenyu 2021-04-01 11:48:25 +08:00
parent 1ad05f52cb
commit 66c3a14303
4 changed files with 47 additions and 50 deletions

View File

@ -56,22 +56,17 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
return graph->graph_id();
}
void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(session_);
auto graph = session_->GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(graph);
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
MS_EXCEPTION_IF_NULL(actor_set);
GraphScheduler::GetInstance().Run(actor_set);
}
void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
std::vector<tensor::TensorPtr> *input_tensors,
const std::vector<int64_t> &tensors_mask, VectorRef *outputs) {
const std::vector<int64_t> &tensors_mask) {
// Check if the graph cache exists.
if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) {
// Prepare the graph
auto iter = run_op_graphs_.find(graph_info);
if (iter != run_op_graphs_.end()) {
const auto &graph = iter->second;
MS_EXCEPTION_IF_NULL(graph);
return graph->graph_id();
}
// Generate kernel graph.
MS_EXCEPTION_IF_NULL(session_);
auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(graph);
@ -82,29 +77,29 @@ void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const Gr
device_context_->OptimizeSingleOpGraph(graph);
MS_EXCEPTION_IF_NULL(session_);
session_->RunOpHideNopNode(graph);
device_context_->CreateKernel(graph->execution_order());
run_op_graphs_[graph_info] = graph;
}
session_->EraseValueNodeTensor(tensors_mask, input_tensors);
// wait for allreduce
for (auto &tensor : *input_tensors) {
if (tensor->NeedWaitDevice()) {
tensor->WaitDevice();
}
}
// run op
auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph);
session_->RunOpRemoveNopNode(graph);
// Generate 'KernelMod' for kernel in graph.
device_context_->CreateKernel(graph->execution_order());
// Transform graph to actor DAG, contains build and link.
GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
MS_EXCEPTION_IF_NULL(actor_set);
GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep);
run_op_graphs_[graph_info] = graph;
return graph->graph_id();
}
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
MS_EXCEPTION_IF_NULL(session_);
return session_->GetGraph(graph_id);
}
KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
auto iter = run_op_graphs_.find(graph_info);
if (iter == run_op_graphs_.end()) {
MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
return nullptr;
}
return iter->second;
}
} // namespace runtime
} // namespace mindspore

View File

@ -41,13 +41,15 @@ class GraphCompiler {
// the detailed implementation of compiling graph is in 'CompileGraphImpl'.
GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs);
// Run a graph and get the output in Graph mode.
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
// Construct single op kernel graph and compile the kernel graph in PyNative mode.
GraphId CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask);
// Construct single op kernel graph, compile and run the kernel graph in PyNative mode.
void CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask,
VectorRef *outputs);
// Get graph by graph id, if not exist return nullptr, used in Graph mode.
KernelGraphPtr Fetch(GraphId graph_id) const;
// Get graph by graph info, if not exist return nullptr, used in PyNative mode.
KernelGraphPtr Fetch(const GraphInfo &graph_info) const;
private:
GraphCompiler() = default;

View File

@ -34,7 +34,7 @@ void DeviceContextManager::ClearDeviceContexts() {
device_contexts_.clear();
}
DeviceContext *DeviceContextManager::CreateOrGetDeviceContext(const DeviceContextKey &device_context_key) {
DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key) {
std::string device_context_key_str = device_context_key.ToString();
std::lock_guard<std::mutex> guard(lock_);

View File

@ -36,7 +36,7 @@ class DeviceContextManager {
return instance;
}
void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
DeviceContext *CreateOrGetDeviceContext(const DeviceContextKey &device_context_key);
DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key);
void ClearDeviceContexts();
private: