forked from OSSInnovation/mindspore
remove run graph method and add get graph method in graph compiler
This commit is contained in:
parent
1ad05f52cb
commit
66c3a14303
|
@ -56,22 +56,17 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
|
|||
return graph->graph_id();
|
||||
}
|
||||
|
||||
void GraphCompiler::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
|
||||
VectorRef *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
auto graph = session_->GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
GraphScheduler::GetInstance().Run(actor_set);
|
||||
}
|
||||
|
||||
void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors,
|
||||
const std::vector<int64_t> &tensors_mask, VectorRef *outputs) {
|
||||
const std::vector<int64_t> &tensors_mask) {
|
||||
// Check if the graph cache exists.
|
||||
if (run_op_graphs_.find(graph_info) == run_op_graphs_.end()) {
|
||||
// Prepare the graph
|
||||
auto iter = run_op_graphs_.find(graph_info);
|
||||
if (iter != run_op_graphs_.end()) {
|
||||
const auto &graph = iter->second;
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
return graph->graph_id();
|
||||
}
|
||||
// Generate kernel graph.
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
@ -82,29 +77,29 @@ void GraphCompiler::CompileAndRunGraph(session::OpRunInfo *op_run_info, const Gr
|
|||
device_context_->OptimizeSingleOpGraph(graph);
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
session_->RunOpHideNopNode(graph);
|
||||
|
||||
device_context_->CreateKernel(graph->execution_order());
|
||||
run_op_graphs_[graph_info] = graph;
|
||||
}
|
||||
|
||||
session_->EraseValueNodeTensor(tensors_mask, input_tensors);
|
||||
|
||||
// wait for allreduce
|
||||
for (auto &tensor : *input_tensors) {
|
||||
if (tensor->NeedWaitDevice()) {
|
||||
tensor->WaitDevice();
|
||||
}
|
||||
}
|
||||
|
||||
// run op
|
||||
auto graph = run_op_graphs_[graph_info];
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
session_->RunOpRemoveNopNode(graph);
|
||||
|
||||
// Generate 'KernelMod' for kernel in graph.
|
||||
device_context_->CreateKernel(graph->execution_order());
|
||||
|
||||
// Transform graph to actor DAG, contains build and link.
|
||||
GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
|
||||
auto actor_set = GraphScheduler::GetInstance().Fetch(graph);
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
GraphScheduler::GetInstance().Run(actor_set, GraphExecutionStrategy::kStep);
|
||||
run_op_graphs_[graph_info] = graph;
|
||||
return graph->graph_id();
|
||||
}
|
||||
|
||||
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
|
||||
MS_EXCEPTION_IF_NULL(session_);
|
||||
return session_->GetGraph(graph_id);
|
||||
}
|
||||
|
||||
KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
|
||||
auto iter = run_op_graphs_.find(graph_info);
|
||||
if (iter == run_op_graphs_.end()) {
|
||||
MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
|
||||
return nullptr;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -41,13 +41,15 @@ class GraphCompiler {
|
|||
// the detailed implementation of compiling graph is in 'CompileGraphImpl'.
|
||||
GraphId CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs);
|
||||
|
||||
// Run a graph and get the output in Graph mode.
|
||||
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);
|
||||
// Construct single op kernel graph and compile the kernel graph in PyNative mode.
|
||||
GraphId CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask);
|
||||
|
||||
// Construct single op kernel graph, compile and run the kernel graph in PyNative mode.
|
||||
void CompileAndRunGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
||||
std::vector<tensor::TensorPtr> *input_tensors, const std::vector<int64_t> &tensors_mask,
|
||||
VectorRef *outputs);
|
||||
// Get graph by graph id, if not exist return nullptr, used in Graph mode.
|
||||
KernelGraphPtr Fetch(GraphId graph_id) const;
|
||||
|
||||
// Get graph by graph info, if not exist return nullptr, used in PyNative mode.
|
||||
KernelGraphPtr Fetch(const GraphInfo &graph_info) const;
|
||||
|
||||
private:
|
||||
GraphCompiler() = default;
|
||||
|
|
|
@ -34,7 +34,7 @@ void DeviceContextManager::ClearDeviceContexts() {
|
|||
device_contexts_.clear();
|
||||
}
|
||||
|
||||
DeviceContext *DeviceContextManager::CreateOrGetDeviceContext(const DeviceContextKey &device_context_key) {
|
||||
DeviceContext *DeviceContextManager::GetOrCreateDeviceContext(const DeviceContextKey &device_context_key) {
|
||||
std::string device_context_key_str = device_context_key.ToString();
|
||||
std::lock_guard<std::mutex> guard(lock_);
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class DeviceContextManager {
|
|||
return instance;
|
||||
}
|
||||
void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
|
||||
DeviceContext *CreateOrGetDeviceContext(const DeviceContextKey &device_context_key);
|
||||
DeviceContext *GetOrCreateDeviceContext(const DeviceContextKey &device_context_key);
|
||||
void ClearDeviceContexts();
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue