!16116 graph scheduler support multi graphs

From: @limingqi107
Reviewed-by: @cristoval,@wilfchen
Signed-off-by: @wilfchen
mindspore-ci-bot 2021-05-11 19:33:57 +08:00 committed by Gitee
commit f356c51169
14 changed files with 408 additions and 312 deletions

View File

@ -2,6 +2,7 @@
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/include)
if(ENABLE_CPU)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/backend/kernel_compiler/cpu)

View File

@ -1109,6 +1109,14 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
}
}
AnfNodePtr KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
const auto &iter = internal_parameters_to_front_map_.find(parameter);
if (iter != internal_parameters_to_front_map_.end()) {
return iter->second;
}
return nullptr;
}
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
auto iter = front_to_internal_outputs_map_.find(front_node);
if (iter != front_to_internal_outputs_map_.end()) {

View File

@ -71,6 +71,7 @@ class KernelGraph : public FuncGraph {
parent_graph_ = graph.parent_graph_;
start_label_ = graph.start_label_;
end_goto_ = graph.end_goto_;
internal_parameters_to_front_map_ = graph.internal_parameters_to_front_map_;
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
@ -199,6 +200,7 @@ class KernelGraph : public FuncGraph {
bool unique_target = false);
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
int dst_output_idx = -1);
AnfNodePtr GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const;
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const;
@ -353,6 +355,7 @@ class KernelGraph : public FuncGraph {
CNodePtr start_label_;
CNodePtr end_goto_;
std::unordered_map<AnfNodePtr, AnfNodePtr> internal_parameters_to_front_map_;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;

View File

@ -756,11 +756,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
} else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameterFromParameter(anf, graph);
cnode_inputs->push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
graph->FrontBackendlMapAdd(anf, new_parameter);
} else {
(*other_graph_cnode)[anf] = new_parameter;
}
graph->FrontBackendlMapAdd(anf, new_parameter);
continue;
} else {
// the input node is a cnode from other graph

View File

@ -60,16 +60,16 @@ void TaskEmitActionForMindRT(const ResourcePtr &res) {
auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr);
MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
// The output of graph compiler is graph id.
// The output of graph compiler is actor.
res->results()[kOutput] = mindrt_bc_ptr->CompileGraphs(res->func_graph());
}
void ExecuteActionForMindRT(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (!res->results()[kOutput].is<GraphId>()) {
if (!res->results()[kOutput].is<compile::ActorInfo>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}
auto graph_id = res->results()[kOutput].cast<GraphId>();
const auto &actor_info = res->results()[kOutput].cast<compile::ActorInfo>();
// Get the mindRT backend.
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
@ -78,9 +78,9 @@ void ExecuteActionForMindRT(const ResourcePtr &res) {
// Construct the graph run function ptr.
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, actor_info](const VectorRef &args) -> BaseRef {
MS_LOG(INFO) << "Execute args size " << args.size();
auto outs = mindrt_bc_ptr->RunGraph(graph_id, args);
auto outs = mindrt_bc_ptr->RunGraph(actor_info, args);
MS_LOG(DEBUG) << "out size " << outs.size();
return outs[0];
});

View File

@ -60,6 +60,11 @@ void DataSourceActor::FreeMemory(OpContext<DeviceTensor> *context) {
void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") sends output data.";
MS_EXCEPTION_IF_NULL(context);
// No output.
if (output_op_arrows_.size() == 0) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}

View File

@ -118,6 +118,8 @@ class HostQueueDataSourceActor : public DataSourceActor {
HostTensorQueuePtr host_queue_;
// Input data nodes fetch data from host queue.
std::vector<AnfNodePtr> data_nodes_;
// The location of the data node in the data source actor.
std::unordered_map<AnfNodePtr, size_t> data_node_position_map_;
};
using DataSourceActorPtr = std::shared_ptr<DataSourceActor>;

View File

@ -173,6 +173,11 @@ void KernelActor::FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::v
void KernelActor::SendOutput(OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
// No output.
if ((output_op_arrows_.size() == 0) && (output_op_controls_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}
// Send output data.
for (auto &op_arrow : output_op_arrows_) {
MS_EXCEPTION_IF_NULL(op_arrow);

View File

@ -256,10 +256,6 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph
graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
// Transform graph to actor DAG, contains build and link.
GraphScheduler::GetInstance().Transform({graph}, {device_context_}, input_tensors, nullptr,
GraphExecutionStrategy::kStep);
run_op_graphs_[graph_info] = graph;
return graph->graph_id();
}

View File

@ -23,6 +23,7 @@
#include "utils/config_manager.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
#include "utils/ms_context.h"
#include "common/trans.h"
namespace mindspore {
@ -36,10 +37,14 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
return false;
}
bool IsHostQueueDSActor(const AnfNodePtr &node) {
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
return true;
// Judge whether node is internal parameter.
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
return true;
}
}
return false;
}
@ -384,32 +389,28 @@ void GraphScheduler::Initialize() {
(void)actorMgr->Spawn(base_actor, false);
}
ActorSet *GraphScheduler::Transform(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<TensorPtr> *input_tensors,
const std::vector<AnfNodePtr> *control_nodes, GraphExecutionStrategy strategy) {
if (graphs.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device_contexts.";
ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy) {
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
if (graph_compiler_info.graphs_.size() == 0) {
MS_LOG(EXCEPTION) << "The number of graphs is zero.";
}
if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
}
Initialize();
std::vector<ActorSetPtr> actor_sets;
for (size_t i = 0; i < graphs.size(); ++i) {
auto graph = graphs[i];
auto device_context = device_contexts[i];
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor begin.";
PersistDeviceTensor(graph);
auto actor_set = Build(graph, device_context);
actor_sets.emplace_back(actor_set);
graph_to_actors_.emplace(graph, actor_set);
Link(actor_set.get(), graph, strategy);
if (!CheckActorValid(actor_set.get())) {
MS_LOG(EXCEPTION) << "The actor set of " << graph->ToString() << " is invalid.";
}
MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor end.";
Initialize();
PersistDeviceTensor(graph_compiler_info);
const auto &actor_set = Build(graph_compiler_info);
Link(actor_set.get(), graph_compiler_info, strategy);
actors_.emplace(actor_set->name_, actor_set);
DumpActor(actor_set.get());
if (!CheckActorValid(actor_set.get())) {
MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid.";
}
return actor_sets[0].get();
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
return actor_set.get();
}
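
For illustration, a minimal standalone sketch of the name-keyed registry that Transform now relies on: the actor set is stored under its ActorInfo name and looked up again by Fetch. The types are simplified stand-ins, not the real GraphScheduler and ActorSet classes.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

using ActorInfo = std::string;

struct ActorSet {
  explicit ActorSet(const ActorInfo &name) : name_(name) {}
  ActorInfo name_;
};
using ActorSetPtr = std::shared_ptr<ActorSet>;

class Scheduler {
 public:
  // Transform builds the actor set and registers it under its name.
  ActorSet *Transform(const ActorInfo &name) {
    auto actor_set = std::make_shared<ActorSet>(name);
    actors_.emplace(name, actor_set);
    return actor_set.get();
  }
  // Fetch looks the actor set up again by the same name.
  ActorSet *Fetch(const ActorInfo &name) const {
    const auto &iter = actors_.find(name);
    return (iter != actors_.end()) ? iter->second.get() : nullptr;
  }

 private:
  std::unordered_map<ActorInfo, ActorSetPtr> actors_;
};

int main() {
  Scheduler scheduler;
  scheduler.Transform("kernel_graph_0_1");
  std::cout << scheduler.Fetch("kernel_graph_0_1")->name_ << std::endl;
  return 0;
}
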
void GraphScheduler::Schedule(const ActorSet *actor_set) {
@ -438,61 +439,70 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
}
}
void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors,
VectorRef *const &outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(outputs);
// Get the device context for the first kernel actor.
const auto &actor_set = Fetch(graph);
void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<TensorPtr>> &input_tensors, VectorRef *const &outputs) {
MS_EXCEPTION_IF_NULL(actor_set);
const auto &first_kernel_actor = actor_set->kernel_actors_[0];
MS_EXCEPTION_IF_NULL(first_kernel_actor);
const auto &device_context = first_kernel_actor->device_context_;
// 1.Prepare the data of device tensor store(value nodes of graph).
for (const auto &value_node : graph->graph_value_nodes()) {
if (AnfAlgo::OutputAddrExist(value_node, 0)) {
PrepareDataForValueNode(value_node, device_context);
}
}
// 1.Prepare the data of device tensor store(weights of graph), and fill the host tensors for non weighted parameters.
MS_EXCEPTION_IF_NULL(outputs);
std::vector<TensorPtr> host_tensors;
const auto &input_nodes = graph->input_nodes();
for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto &input_node = input_nodes[i];
const auto &input_tensor = (*input_tensors)[i];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
// Prepare the device data for weights.
PrepareDataForWeightNode(input_node, input_tensor, device_context);
} else {
// Fill the host tensors for non weighted parameters.
host_tensors.emplace_back(input_tensor);
const auto &host_data_source_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
if (host_data_source_actor != nullptr) {
host_tensors.resize(host_data_source_actor->data_nodes_.size());
}
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);
// 1.Prepare the data of device tensor store(value nodes of graph).
for (const auto &value_node : graph->graph_value_nodes()) {
if (AnfAlgo::OutputAddrExist(value_node, 0)) {
PrepareDataForValueNode(value_node, device_context);
}
}
// 1.Prepare the data of device tensor store(weights of graph), and fill host tensors for non weighted parameters.
const auto &input_nodes = graph->input_nodes();
const auto &tensors = input_tensors[i];
for (size_t j = 0; j < input_nodes.size(); ++j) {
const auto &input_node = input_nodes[j];
const auto &input_tensor = tensors[j];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
// Prepare the device data for weights.
PrepareDataForWeightNode(input_node, input_tensor, device_context);
} else if (IsHostQueueDSActor(input_node, graph)) {
MS_EXCEPTION_IF_NULL(host_data_source_actor);
// Fill the host tensors for non weighted parameters.
const auto &iter = host_data_source_actor->data_node_position_map_.find(input_node);
if (iter != host_data_source_actor->data_node_position_map_.end()) {
host_tensors[iter->second] = input_tensor;
}
}
}
// 2.Prepare the output tensor of graph.
for (const auto &output_node : graph->outputs()) {
MS_EXCEPTION_IF_NULL(output_node);
MS_LOG(INFO) << "Create node output: " << output_node->fullname_with_scope();
outputs->emplace_back(CreateOutputTensors(output_node, graph, tensors));
}
// 3.Prepare the continuous memory for communication kernel.
for (const auto &kernel : graph->execution_order()) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
AllocateContinuousMemoryForInput(kernel, device_context, graph->is_all_nop_node());
AllocateContinuousMemoryForOutput(kernel, device_context);
}
}
}
// 2.Prepare the data of host tensor queue(non weighted parameters of graph).
const auto &host_tensor_queue = FetchHostQueue(graph);
if (host_tensor_queue != nullptr) {
// 4.Prepare the data of host tensor queue(non weighted parameters of graph).
if (host_data_source_actor != nullptr) {
const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
MS_EXCEPTION_IF_NULL(host_tensor_queue);
host_tensor_queue->PushData(host_tensors);
}
// 3.Prepare the output tensor of graph.
for (const auto &output_node : graph->outputs()) {
MS_EXCEPTION_IF_NULL(output_node);
MS_LOG(INFO) << "Create node output: " << output_node->fullname_with_scope();
outputs->emplace_back(CreateOutputTensors(output_node, graph, *input_tensors));
}
// 4.Prepare the continuous memory for communication kernel.
for (const auto &kernel : graph->execution_order()) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
AllocateContinuousMemoryForInput(kernel, device_context, graph->is_all_nop_node());
AllocateContinuousMemoryForOutput(kernel, device_context);
}
}
}
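
As a sketch of the host tensor preparation above: host_tensors is sized by the number of data nodes held by the host queue data source actor, and each graph's input tensor is written at the position recorded in data_node_position_map_, so inputs from several graphs land in one flat payload. The snippet below uses plain strings and ints as stand-ins for the real node and tensor types.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using Tensor = int;  // placeholder for tensor::TensorPtr

int main() {
  // Positions assigned when the host queue data source actor was built.
  std::unordered_map<std::string, size_t> data_node_position_map = {
      {"g0_param_x", 0}, {"g0_param_y", 1}, {"g1_param_x", 0 /* shares the position of x */}};

  // Per-graph input nodes and tensors, in each graph's own input order.
  std::vector<std::vector<std::string>> input_nodes = {{"g0_param_x", "g0_param_y"}, {"g1_param_x"}};
  std::vector<std::vector<Tensor>> input_tensors = {{11, 22}, {11}};

  // host_tensors holds one slot per distinct data node and is filled by position.
  std::vector<Tensor> host_tensors(2);
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    for (size_t j = 0; j < input_nodes[i].size(); ++j) {
      const auto &iter = data_node_position_map.find(input_nodes[i][j]);
      if (iter != data_node_position_map.end()) {
        host_tensors[iter->second] = input_tensors[i][j];
      }
    }
  }
  assert(host_tensors[0] == 11 && host_tensors[1] == 22);
  return 0;
}
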
bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy) {
@ -544,36 +554,32 @@ bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strat
return true;
}
ActorSet *GraphScheduler::Fetch(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto iter = graph_to_actors_.find(graph);
if (iter != graph_to_actors_.end()) {
ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
auto iter = actors_.find(actor_info);
if (iter != actors_.end()) {
return iter->second.get();
} else {
MS_LOG(ERROR) << "Can't find the actors map of graph: " << graph->ToString();
MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
return nullptr;
}
}
ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceContext *device_context) {
auto actor_set = std::make_shared<ActorSet>();
ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
MS_EXCEPTION_IF_NULL(actor_set);
auto data_source_actors = BuildDataSourceActor(graph, device_context);
actor_set->data_source_actors_.swap(data_source_actors);
auto kernel_actors = BuildKernelActor(graph, device_context);
actor_set->kernel_actors_.swap(kernel_actors);
auto loop_count_actor = BuildLoopCountActor(graph);
actor_set->loop_count_actor_ = loop_count_actor;
auto host_queue = std::make_shared<HostTensorQueue>();
actor_to_host_queue_.emplace(actor_set->name_, host_queue);
actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
return actor_set;
}
void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set);
MS_EXCEPTION_IF_NULL(graph);
KernelMapActor kernel_actors_temp_map;
for (auto &actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(actor);
@ -581,113 +587,145 @@ void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, Grap
}
// Foreach the execution order to link the actors.
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (!IsKernelActor(kernel)) {
continue;
}
auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());
// Link the control arrows of kernel actor.
LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy);
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
if (HasAbstractMonad(input_node)) {
continue; // No data arrow for monad input.
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (!IsKernelActor(kernel)) {
continue;
}
auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
auto from_kernel = from_kernel_with_output_idx.first;
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
if (HasAbstractMonad(input_node)) {
continue; // No data arrow for monad input.
}
if (IsDeviceQueueDSActor(from_kernel)) {
// Link the data arrows of device queue data source actor.
auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel)) {
// Link the data arrows of host queue data source actor.
auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
// Link the data arrows of kernel actor.
auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
auto from_kernel = from_kernel_with_output_idx.first;
if (IsDeviceQueueDSActor(from_kernel)) {
// Link the data arrows of device queue data source actor.
auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx,
to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel, graph)) {
// Link the data arrows of host queue data source actor.
auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
// Link the data arrows of kernel actor.
auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
}
}
}
}
// Link the control arrows of kernel actor.
LinkControlArrowForKernelActor(&(actor_set->kernel_actors_), actor_set->loop_count_actor_.get(), strategy);
// BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
auto no_input_kernel_actors = BuildNoInputKernelActor(graph);
auto no_input_kernel_actors = BuildNoInputKernelActor(actor_set);
actor_set->no_input_kernel_actors_.swap(no_input_kernel_actors);
// Link the control arrows of loop count actor, which depends on the no input kernel actors.
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph);
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set);
}
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
const HostTensorQueuePtr &host_queue) {
std::vector<DataSourceActorPtr> data_source_actors;
// Build host queue data source actor.
HostQueueDSActorPtr host_queue_ds_actor = nullptr;
for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsHostQueueDSActor(input_node)) {
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph->ToString() + "_" + "HostQueueDataSourceActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
auto host_queue = std::make_shared<HostTensorQueue>();
graph_to_host_queue_.emplace(graph, host_queue);
host_queue_ds_actor =
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
data_source_actors.emplace_back(host_queue_ds_actor);
}
host_queue_ds_actor->data_nodes_.emplace_back(input_node);
}
}
size_t data_node_position = 0;
std::unordered_map<AnfNodePtr, size_t> front_node_position_temp_map;
// Build device queue data source actor.
auto execution_order = graph->execution_order();
auto iter = std::find_if(execution_order.begin(), execution_order.end(),
[](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
if (iter != execution_order.end()) {
auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor";
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor =
std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);
// Build host queue data source actor.
for (const auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsHostQueueDSActor(input_node, graph)) {
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph_compiler_info.name_ + "_HostQueueDataSourceActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
host_queue_ds_actor =
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
data_source_actors.emplace_back(host_queue_ds_actor);
}
const auto &front_node = graph->GetFrontAnfByBackendAnf(input_node);
// In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
// is saved in the host queue data source actor.
if ((front_node != nullptr) && (front_node_position_temp_map.count(front_node) > 0)) {
host_queue_ds_actor->data_node_position_map_.emplace(input_node, front_node_position_temp_map[front_node]);
continue;
}
host_queue_ds_actor->data_nodes_.emplace_back(input_node);
host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
front_node_position_temp_map.emplace(front_node, data_node_position);
data_node_position++;
}
}
// Build device queue data source actor.
const auto &execution_order = graph->execution_order();
const auto &iter = std::find_if(execution_order.begin(), execution_order.end(),
[](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
if (iter != execution_order.end()) {
auto actor_name =
graph_compiler_info.name_ + "_DeviceQueueDataSourceActor" + "_" + std::to_string(graph->graph_id());
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor =
std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
}
}
return data_source_actors;
}
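
The front-node de-duplication in BuildDataSourceActor is what lets one host queue position be shared across graphs: when several backend parameters map to the same front node, only the first is appended to data_nodes_ and the later ones reuse its position. A minimal standalone sketch of that bookkeeping, with string names standing in for the real AnfNodePtr identities:

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Backend parameter -> front parameter, collected across two kernel graphs.
  std::vector<std::pair<std::string, std::string>> backend_to_front = {
      {"g0_param_x", "x"}, {"g0_param_y", "y"}, {"g1_param_x", "x"}};

  std::vector<std::string> data_nodes;                              // one entry per distinct front node
  std::unordered_map<std::string, size_t> data_node_position_map;   // backend node -> position
  std::unordered_map<std::string, size_t> front_node_position_temp_map;
  size_t data_node_position = 0;

  for (const auto &node_pair : backend_to_front) {
    const auto &backend_node = node_pair.first;
    const auto &front_node = node_pair.second;
    // A later backend node that shares a front node reuses the first position.
    if (front_node_position_temp_map.count(front_node) > 0) {
      data_node_position_map.emplace(backend_node, front_node_position_temp_map[front_node]);
      continue;
    }
    data_nodes.emplace_back(backend_node);
    data_node_position_map.emplace(backend_node, data_node_position);
    front_node_position_temp_map.emplace(front_node, data_node_position);
    data_node_position++;
  }

  // g1_param_x reuses position 0 of g0_param_x; data_nodes keeps only two entries.
  std::cout << data_node_position_map["g1_param_x"] << " " << data_nodes.size() << std::endl;
  return 0;
}
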
std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const KernelGraphPtr &graph,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<KernelActorPtr> kernel_actors;
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (IsKernelActor(kernel)) {
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(kernel_actor);
kernel_actors.emplace_back(kernel_actor);
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (IsKernelActor(kernel)) {
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(kernel_actor);
kernel_actors.emplace_back(kernel_actor);
}
}
}
return kernel_actors;
}
std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
auto loop_count = ConfigManager::GetInstance().iter_num();
auto actor_name = graph_compiler_info.name_ + "_" + "LoopCountActor";
auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
MS_EXCEPTION_IF_NULL(loop_count_actor);
return loop_count_actor;
}
std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActorPtr> no_input_kernel_actors;
auto actor_set = Fetch(graph);
MS_EXCEPTION_IF_NULL(actor_set);
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
@ -699,16 +737,6 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const Kernel
return no_input_kernel_actors;
}
LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto loop_count = ConfigManager::GetInstance().iter_num();
auto actor_name = graph->ToString() + "_" + "LoopCountActor";
auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
MS_EXCEPTION_IF_NULL(loop_count_actor);
return loop_count_actor;
}
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
@ -740,19 +768,20 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;
auto data_nodes = from_actor->data_nodes_;
auto iter = find(data_nodes.begin(), data_nodes.end(), from_kernel);
if (iter == data_nodes.end()) {
// Get the position of from kernel in the data source actor.
auto iter = from_actor->data_node_position_map_.find(from_kernel);
if (iter == from_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Parameter node: " << from_kernel->fullname_with_scope() << " is not exist.";
}
auto position = IntToSize(std::distance(data_nodes.begin(), iter));
auto position = iter->second;
auto to_aid = to_actor->GetAID();
auto op_arrow = std::make_shared<OpArrow>(position, to_aid, to_input_index);
from_actor->output_op_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;
// Update the reference count of device tensor.
UpdateRefCount(from_kernel, from_output_index);
UpdateRefCount(from_actor->data_nodes_[position], from_output_index);
}
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
@ -778,25 +807,26 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
}
}
void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor,
const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(from_actor);
void GraphScheduler::LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(from_actors);
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(graph);
if (strategy == GraphExecutionStrategy::kStep) {
from_actor->input_controls_num_++;
}
for (auto &from_actor : *from_actors) {
MS_EXCEPTION_IF_NULL(from_actor);
if (strategy == GraphExecutionStrategy::kStep) {
from_actor->input_controls_num_++;
}
// The manager of graph member is weak ptr, so need created and used in the function IsNotRealUsedByOthers.
const auto &manager = Manage(graph, true);
MS_EXCEPTION_IF_NULL(manager);
if (opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) {
MS_EXCEPTION_IF_NULL(from_actor->kernel_);
MS_LOG(INFO) << from_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
auto to_aid = to_actor->GetAID();
from_actor->output_op_controls_.emplace_back(to_aid);
to_actor->input_controls_num_++;
// If the kernel actor has no output in the pipeline mode, then adds the output control to loop count actor.
if ((strategy == GraphExecutionStrategy::kPipeline) && (from_actor->output_op_arrows_.size() == 0) &&
(from_actor->output_op_controls_.size() == 0)) {
MS_EXCEPTION_IF_NULL(from_actor->kernel_);
MS_LOG(INFO) << from_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
auto to_aid = to_actor->GetAID();
from_actor->output_op_controls_.emplace_back(to_aid);
to_actor->input_controls_num_++;
}
}
}
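
To make the tail-actor rule above concrete: in pipeline mode, any kernel actor that ends up with neither data nor control outputs after linking reports to the loop count actor, otherwise the end of a step could never be detected. A simplified standalone sketch with plain structs, not the real actor classes:

#include <string>
#include <vector>

struct LoopCountActor {
  size_t input_controls_num_ = 0;
};

struct KernelActor {
  std::string name_;
  std::vector<std::string> output_op_arrows_;    // data outputs
  std::vector<std::string> output_op_controls_;  // control outputs
};

void LinkControlArrowForKernelActor(std::vector<KernelActor> *from_actors, LoopCountActor *to_actor) {
  for (auto &from_actor : *from_actors) {
    // A kernel actor without any output signals the loop count actor instead.
    if (from_actor.output_op_arrows_.empty() && from_actor.output_op_controls_.empty()) {
      from_actor.output_op_controls_.emplace_back("loop_count_actor");
      to_actor->input_controls_num_++;
    }
  }
}

int main() {
  std::vector<KernelActor> actors = {{"assign", {}, {}}, {"matmul", {"relu"}, {}}};
  LoopCountActor loop_count_actor;
  LinkControlArrowForKernelActor(&actors, &loop_count_actor);
  return loop_count_actor.input_controls_num_ == 1 ? 0 : 1;
}
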
@ -852,12 +882,9 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
to_actor->input_controls_num_++;
}
void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(loop_count_actor);
auto actor_set = Fetch(graph);
void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
MS_EXCEPTION_IF_NULL(loop_count_actor);
// Set the source data actor.
for (auto &data_source_actor : actor_set->data_source_actors_) {
@ -914,48 +941,58 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
return true;
}
void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
}
for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(device_tensor);
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
}
for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(device_tensor);
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
}
}
}
}
HostTensorQueue *GraphScheduler::FetchHostQueue(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
const auto &iter = graph_to_host_queue_.find(graph);
if (iter != graph_to_host_queue_.end()) {
HostTensorQueue *GraphScheduler::FetchHostQueue(const ActorInfo &actor_info) const {
const auto &iter = actor_to_host_queue_.find(actor_info);
if (iter != actor_to_host_queue_.end()) {
return iter->second.get();
} else {
return nullptr;
}
}
void GraphScheduler::DumpActor(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
const auto &actor_set = Fetch(graph);
void GraphScheduler::DumpActor(const ActorSet *actor_set) const {
MS_EXCEPTION_IF_NULL(actor_set);
std::string filename = "./actor_set_" + graph->ToString() + ".ir";
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (!save_graphs) {
return;
}
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
std::string filename = save_graphs_path + "/actor_set_" + actor_set->name_ + ".ir";
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file [" << filename << "] failed!";

View File

@ -35,12 +35,37 @@ namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
using KernelMapActor = std::unordered_map<std::string, KernelActorPtr>;
using ActorInfo = std::string;
enum class GraphExecutionStrategy {
kPipeline, // The actor running is triggered only by data.
kStep // The actor running need be triggered by control in addition.
};
// The graph compiler info generated by the graph compiler is the description of the executable graphs.
// The device context is the unified interface for interacting with the device of the corresponding graph.
// The input tensors are used to link graphs in the dynamic build scenario.
// The control nodes are used to link graphs in the control flow scenario.
// The origin parameters order is used to match the input args to the graph parameters.
struct GraphCompilerInfo {
GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<std::vector<TensorPtr> *> &input_tensors,
const std::vector<AnfNodePtr> &control_nodes,
const std::vector<AnfNodePtr> &origin_parameters_order, const std::string &name)
: graphs_(graphs),
device_contexts_(device_contexts),
input_tensors_(input_tensors),
control_nodes_(control_nodes),
origin_parameters_order_(origin_parameters_order),
name_(name) {}
std::vector<KernelGraphPtr> graphs_;
std::vector<DeviceContext *> device_contexts_;
std::vector<std::vector<TensorPtr> *> input_tensors_;
std::vector<AnfNodePtr> control_nodes_;
std::vector<AnfNodePtr> origin_parameters_order_;
std::string name_;
};
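
A small sketch of how this structure is meant to be consumed: graphs_ and device_contexts_ are parallel vectors indexed together, and Transform rejects an empty or mismatched pair before Build, Link and PrepareRun walk them in lock step. The types below are simplified placeholders, not the real KernelGraph and DeviceContext.

#include <stdexcept>
#include <string>
#include <vector>

struct Graph { int id; };
struct DeviceContext { std::string device_name; };

struct CompilerInfo {
  std::vector<Graph> graphs_;
  std::vector<DeviceContext *> device_contexts_;
  std::string name_;
};

void Transform(const CompilerInfo &info) {
  if (info.graphs_.empty()) {
    throw std::runtime_error("The number of graphs is zero.");
  }
  if (info.graphs_.size() != info.device_contexts_.size()) {
    throw std::runtime_error("The number of graphs is not equal to the number of device contexts.");
  }
  for (size_t i = 0; i < info.graphs_.size(); ++i) {
    // Each graph is built and prepared with the device context at the same index.
    const Graph &graph = info.graphs_[i];
    const DeviceContext *device_context = info.device_contexts_[i];
    (void)graph;
    (void)device_context;
  }
}

int main() {
  DeviceContext gpu{"GPU"};
  CompilerInfo info{{{0}, {1}}, {&gpu, &gpu}, "kernel_graph_0_1"};
  Transform(info);
  return 0;
}
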
// The actor set generated by graph transformer is the execution unit of actor runtime.
// It includes data source actor, kernel actor, loop count actor.
// The data source actor is used to obtain data and process them into device tensors,
@ -49,11 +74,13 @@ enum class GraphExecutionStrategy {
// externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
// and decide whether to loop execution by loop count.
struct ActorSet {
explicit ActorSet(const ActorInfo &name) : name_(name) {}
std::vector<DataSourceActorPtr> data_source_actors_;
std::vector<KernelActorPtr> kernel_actors_;
// No input kernel actors need be triggered specifically.
std::vector<KernelActorPtr> no_input_kernel_actors_;
LoopCountActorPtr loop_count_actor_{nullptr};
ActorInfo name_;
};
using ActorSetPtr = std::shared_ptr<ActorSet>;
@ -69,9 +96,7 @@ class GraphScheduler {
void Initialize();
// Transform graph to actor DAG, contains build and link.
ActorSet *Transform(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<TensorPtr> *input_tensors = nullptr,
const std::vector<AnfNodePtr> *control_nodes = nullptr,
ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info,
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
// Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling
@ -82,14 +107,15 @@ class GraphScheduler {
// 1. Prepare the data of device tensor store(such as weights and value nodes of graph).
// 2. Prepare the data of host tensor queue(such as non weighted parameters of graph).
// 3. Prepare the output tensor of graph.
// 4.Prepare the continuous memory for communication kernel.
void PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs);
// 4. Prepare the continuous memory for communication kernel.
void PrepareRun(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<TensorPtr>> &input_tensors, VectorRef *const &outputs);
// The processing entry of actors running.
bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);
// Fetch the actor set by kernel graph.
ActorSet *Fetch(const KernelGraphPtr &graph) const;
// Fetch the actor set by actor info.
ActorSet *Fetch(const ActorInfo &actor_info) const;
private:
GraphScheduler() = default;
@ -97,16 +123,16 @@ class GraphScheduler {
DISABLE_COPY_AND_ASSIGN(GraphScheduler);
// Transform the nodes of graph to actors.
ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context);
ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
// Link actors to DAG through the edge connection of graph and graph execution strategy.
void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy);
// The processing of actors build.
std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph,
const DeviceContext *device_context);
std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const KernelGraphPtr &graph);
LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);
std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
const HostTensorQueuePtr &host_queue);
std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set);
// The processing of actors link.
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
@ -118,9 +144,9 @@ class GraphScheduler {
void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx);
void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph,
void LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
GraphExecutionStrategy strategy);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set);
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
const KernelMapActor &kernel_actors_map);
@ -128,19 +154,19 @@ class GraphScheduler {
bool CheckActorValid(const ActorSet *actor_set) const;
// Persist device tensors of graph's some nodes(such as weights and value nodes).
void PersistDeviceTensor(const KernelGraphPtr &graph);
void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);
// Fetch the host tensor queue by kernel graph.
HostTensorQueue *FetchHostQueue(const KernelGraphPtr &graph) const;
// Fetch the host tensor queue by actor info.
HostTensorQueue *FetchHostQueue(const ActorInfo &actor_info) const;
// Display the actor information of corresponding kernel graph.
void DumpActor(const KernelGraphPtr &graph) const;
void DumpActor(const ActorSet *actor_set) const;
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actors_;
std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_;
std::unordered_map<ActorInfo, ActorSetPtr> actors_;
std::unordered_map<ActorInfo, HostTensorQueuePtr> actor_to_host_queue_;
// The second element of pair represents the output index of kernel actor corresponding to the device tensor.
std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;

View File

@ -29,7 +29,6 @@
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/graph_compiler.h"
#include "runtime/framework/graph_scheduler.h"
#include "utils/scoped_long_running.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
@ -237,37 +236,46 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string
graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
}
GraphId MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
ActorInfo MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphPtr root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
// Compile root graph.
auto root_graph_id = CompileGraph(root_graph);
graph_to_device_context_.clear();
control_nodes_.clear();
CompileGraph(root_graph);
// Compile sub graphs.
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
for (auto sub_graph : sub_graphs) {
if (sub_graph != func_graph && sub_graph != nullptr) {
(void)CompileGraph(sub_graph);
CompileGraph(sub_graph);
}
}
// Transform graph to actor DAG, and schedule the actor DAG.
// Construct the graph compiler info.
std::vector<KernelGraphPtr> graphs;
std::vector<DeviceContext *> device_contexts;
std::string name = "kernel_graph";
for (const auto &graph_id_to_context : graph_to_device_context_) {
graphs.emplace_back(runtime::GraphCompiler::GetInstance().Fetch(graph_id_to_context.first));
device_contexts.emplace_back(graph_id_to_context.second);
name.append("_").append(std::to_string(graph_id_to_context.first));
}
const auto &actor_set =
runtime::GraphScheduler::GetInstance().Transform(graphs, device_contexts, nullptr, &control_nodes_);
std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
auto graph_compiler_info = std::make_unique<GraphCompilerInfo>(graphs, device_contexts, input_tensors, control_nodes_,
root_graph->parameters(), name);
// Transform graph to actor DAG, and schedule the actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*(graph_compiler_info.get()));
runtime::GraphScheduler::GetInstance().Schedule(actor_set);
return root_graph_id;
actor_to_graph_compiler_info_.emplace(actor_set->name_, std::move(graph_compiler_info));
return actor_set->name_;
}
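
The ActorInfo returned here is just a name assembled from the compiled graph ids, and it later keys both actor_to_graph_compiler_info_ and the scheduler's actor registry. A hypothetical standalone sketch of the name construction (an ordered map is used here for a deterministic result; the real member is an unordered_map keyed by GraphId):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using GraphId = uint32_t;

int main() {
  // Graph id -> device context name, filled in by the per-graph compilation.
  std::map<GraphId, std::string> graph_to_device_context = {{0, "GPU"}, {1, "GPU"}, {2, "GPU"}};

  std::vector<GraphId> graphs;
  std::string name = "kernel_graph";
  for (const auto &graph_id_to_context : graph_to_device_context) {
    graphs.emplace_back(graph_id_to_context.first);
    name.append("_").append(std::to_string(graph_id_to_context.first));
  }

  std::cout << name << std::endl;  // prints kernel_graph_0_1_2
  return graphs.size() == 3 ? 0 : 1;
}
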
GraphId MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
void MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_partition_);
@ -309,12 +317,10 @@ GraphId MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
control_nodes_.push_back(cut_node);
}
}
return graph_to_device_context_.begin()->first;
}
VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) {
MS_LOG(INFO) << "Run graph begin, graph id: " << graph_id;
VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args) {
MS_LOG(INFO) << "Run actor begin, actor name: " << actor_info;
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
@ -322,39 +328,41 @@ VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) {
return VectorRef();
}
// Fetch the kernel graph.
const auto &kernel_graph = runtime::GraphCompiler::GetInstance().Fetch(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
// Fetch the graph compiler info.
const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
if (graph_iter == actor_to_graph_compiler_info_.end()) {
MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
}
const auto &graph_compiler_info = *(graph_iter->second.get());
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
// Transform args to input tensors.
std::vector<tensor::TensorPtr> inputs;
for (const auto &input_node : kernel_graph->input_nodes()) {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
MS_EXCEPTION_IF_NULL(front_node);
MS_EXCEPTION_IF_NULL(front_node->func_graph());
const auto &origin_parameters = front_node->func_graph()->parameters();
const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node);
if (iter == origin_parameters.end()) {
MS_LOG(EXCEPTION) << "Parameter node: " << front_node->fullname_with_scope() << " is not exist.";
std::vector<std::vector<tensor::TensorPtr>> input_tensors;
for (const auto &kernel_graph : graph_compiler_info.graphs_) {
std::vector<tensor::TensorPtr> input_tensor;
for (const auto &input_node : kernel_graph->input_nodes()) {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node);
if (iter == origin_parameters.end()) {
input_tensor.emplace_back(nullptr);
continue;
}
auto position = IntToSize(std::distance(origin_parameters.begin(), iter));
PushInputTensor(args[position], &input_tensor);
}
auto position = IntToSize(std::distance(origin_parameters.begin(), iter));
PushInputTensor(args[position], &inputs);
input_tensors.emplace_back(input_tensor);
}
// Fetch the actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(kernel_graph);
MS_EXCEPTION_IF_NULL(actor_set);
// Run actor DAG.
mindspore::ScopedLongRunning long_running;
VectorRef outputs;
runtime::GraphScheduler::GetInstance().PrepareRun(kernel_graph, &inputs, &outputs);
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
runtime::GraphScheduler::GetInstance().PrepareRun(actor_set, graph_compiler_info, input_tensors, &outputs);
if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) {
MS_LOG(EXCEPTION) << "The graph runs failed, graph id: " << graph_id
<< ", graph name: " << kernel_graph->ToString();
MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
}
MS_LOG(INFO) << "Run graph end, graph id: " << graph_id;
MS_LOG(INFO) << "Run actor end, actor name: " << actor_info;
return outputs;
}
} // namespace compile
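
RunGraph above splits the flat args into one input tensor vector per kernel graph by looking each backend input's front parameter up in origin_parameters_order_; inputs whose front node is not an origin parameter (for example internal parameters fed by other graphs) get a placeholder. A minimal sketch of that distribution with simplified stand-in types:

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

using Tensor = int;  // placeholder for tensor::TensorPtr

int main() {
  // Parameters of the root func graph, in the order the caller passes args.
  std::vector<std::string> origin_parameters = {"x", "y", "z"};
  std::vector<Tensor> args = {10, 20, 30};

  // Each kernel graph lists the front parameters its own inputs map to.
  std::vector<std::vector<std::string>> graph_front_inputs = {{"x", "y"}, {"y", "z"}};

  std::vector<std::vector<Tensor>> input_tensors;
  for (const auto &front_inputs : graph_front_inputs) {
    std::vector<Tensor> input_tensor;
    for (const auto &front_node : front_inputs) {
      const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node);
      if (iter == origin_parameters.end()) {
        input_tensor.emplace_back(Tensor());  // not an origin parameter, filled elsewhere
        continue;
      }
      auto position = static_cast<size_t>(std::distance(origin_parameters.begin(), iter));
      input_tensor.emplace_back(args[position]);
    }
    input_tensors.emplace_back(input_tensor);
  }

  assert(input_tensors[1][0] == 20 && input_tensors[1][1] == 30);
  return 0;
}
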

View File

@ -30,11 +30,15 @@
#include "vm/vm.h"
#include "backend/session/session_basic.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/graph_scheduler.h"
namespace mindspore {
namespace compile {
using OpRunInfo = session::OpRunInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
enum SwitchCondStatus {
kCondOk = 0,
kCondAlreadyRun,
@ -95,22 +99,24 @@ class MindRTBackend : public Backend {
MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id);
~MindRTBackend() override = default;
// The parameter root_graph is a root graph, and the root graph maybe contain multiple sub graphs,
// the return is the kernelGraph id of the root graph. It will traverse all subgraphs to call CompileGraph.
GraphId CompileGraphs(const FuncGraphPtr &root_graph);
// The parameter root_graph is a root graph, and the root graph may contain multiple sub graphs. It will traverse
// all sub graphs to call CompileGraph.
ActorInfo CompileGraphs(const FuncGraphPtr &root_graph);
// Compile single op kernel graph in the pyNative mode.
GraphId CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
ActorInfo CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
// Run Graph in the graph mode.
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
VectorRef RunGraph(const ActorInfo &actor_info, const VectorRef &args);
// Run Graph in the pyNative mode.
VectorRef RunGraph(const GraphInfo &graph_info, const VectorRef &args);
VectorRef RunGraph(const GraphInfo &graph_info, const std::vector<tensor::TensorPtr> &input_tensors);
private:
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
// the return is the corresponding kernelGraph id of the graph.
GraphId CompileGraph(const FuncGraphPtr &func_graph);
// The result of graph compiler is stored in graph_to_device_context_ and control_nodes_.
void CompileGraph(const FuncGraphPtr &func_graph);
// When compiling FuncGraph, it is divided according to the control nodes, and obtain the control nodes and several
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
@ -118,6 +124,8 @@ class MindRTBackend : public Backend {
std::unordered_map<GraphId, DeviceContext *> graph_to_device_context_;
std::vector<AnfNodePtr> control_nodes_;
std::unordered_map<ActorInfo, std::unique_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_;
GraphPartitionPtr graph_partition_;
std::string device_name_;
uint32_t device_id_;

View File

@ -20,6 +20,7 @@ message("PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${MS_CCSRC_PATH})
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
include_directories(${CMAKE_BINARY_DIR})