forked from mindspore-Ecosystem/mindspore
!16116 graph scheduler support multi graphs
From: @limingqi107 Reviewed-by: @cristoval,@wilfchen Signed-off-by: @wilfchen
commit f356c51169
@@ -2,6 +2,7 @@
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_BINARY_DIR})
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/include)

if(ENABLE_CPU)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/backend/kernel_compiler/cpu)

@@ -1109,6 +1109,14 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
}
}

AnfNodePtr KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
const auto &iter = internal_parameters_to_front_map_.find(parameter);
if (iter != internal_parameters_to_front_map_.end()) {
return iter->second;
}
return nullptr;
}

AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
auto iter = front_to_internal_outputs_map_.find(front_node);
if (iter != front_to_internal_outputs_map_.end()) {

@@ -71,6 +71,7 @@ class KernelGraph : public FuncGraph {
parent_graph_ = graph.parent_graph_;
start_label_ = graph.start_label_;
end_goto_ = graph.end_goto_;
internal_parameters_to_front_map_ = graph.internal_parameters_to_front_map_;
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;

@@ -199,6 +200,7 @@ class KernelGraph : public FuncGraph {
bool unique_target = false);
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
int dst_output_idx = -1);
AnfNodePtr GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const;
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const;

@@ -353,6 +355,7 @@ class KernelGraph : public FuncGraph {

CNodePtr start_label_;
CNodePtr end_goto_;
std::unordered_map<AnfNodePtr, AnfNodePtr> internal_parameters_to_front_map_;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;

@@ -756,11 +756,7 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
} else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameterFromParameter(anf, graph);
cnode_inputs->push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
graph->FrontBackendlMapAdd(anf, new_parameter);
} else {
(*other_graph_cnode)[anf] = new_parameter;
}
graph->FrontBackendlMapAdd(anf, new_parameter);
continue;
} else {
// the input node is a cnode from other graph

@@ -60,16 +60,16 @@ void TaskEmitActionForMindRT(const ResourcePtr &res) {
auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr);
MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);

// The output of graph compiler is graph id.
// The output of graph compiler is actor.
res->results()[kOutput] = mindrt_bc_ptr->CompileGraphs(res->func_graph());
}

void ExecuteActionForMindRT(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
if (!res->results()[kOutput].is<GraphId>()) {
if (!res->results()[kOutput].is<compile::ActorInfo>()) {
MS_LOG(EXCEPTION) << "Execute args error";
}
auto graph_id = res->results()[kOutput].cast<GraphId>();
const auto &actor_info = res->results()[kOutput].cast<compile::ActorInfo>();

// Get the mindRT backend.
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();

@@ -78,9 +78,9 @@ void ExecuteActionForMindRT(const ResourcePtr &res) {

// Construct the graph run function ptr.
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, graph_id](const VectorRef &args) -> BaseRef {
std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, actor_info](const VectorRef &args) -> BaseRef {
MS_LOG(INFO) << "Execute args size " << args.size();
auto outs = mindrt_bc_ptr->RunGraph(graph_id, args);
auto outs = mindrt_bc_ptr->RunGraph(actor_info, args);
MS_LOG(DEBUG) << "out size " << outs.size();
return outs[0];
});

@@ -60,6 +60,11 @@ void DataSourceActor::FreeMemory(OpContext<DeviceTensor> *context) {
void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) {
MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") sends output data.";
MS_EXCEPTION_IF_NULL(context);
// No output.
if (output_op_arrows_.size() == 0) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}

if (buffers_.size() == 0) {
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty.");
}

@@ -118,6 +118,8 @@ class HostQueueDataSourceActor : public DataSourceActor {
HostTensorQueuePtr host_queue_;
// Input data nodes fetch data from host queue.
std::vector<AnfNodePtr> data_nodes_;
// The location of the data node in the data source actor.
std::unordered_map<AnfNodePtr, size_t> data_node_position_map_;
};

using DataSourceActorPtr = std::shared_ptr<DataSourceActor>;

@@ -173,6 +173,11 @@ void KernelActor::FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::v

void KernelActor::SendOutput(OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
// No output.
if ((output_op_arrows_.size() == 0) && (output_op_controls_.size() == 0)) {
SET_OPCONTEXT_SUCCESS_RET((*context));
}

// Send output data.
for (auto &op_arrow : output_op_arrows_) {
MS_EXCEPTION_IF_NULL(op_arrow);

@@ -256,10 +256,6 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph

graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));

// Transform graph to actor DAG, contains build and link.
GraphScheduler::GetInstance().Transform({graph}, {device_context_}, input_tensors, nullptr,
GraphExecutionStrategy::kStep);
run_op_graphs_[graph_info] = graph;
return graph->graph_id();
}

@@ -23,6 +23,7 @@
#include "utils/config_manager.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
#include "utils/ms_context.h"
#include "common/trans.h"

namespace mindspore {

@@ -36,10 +37,14 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
return false;
}

bool IsHostQueueDSActor(const AnfNodePtr &node) {
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
return true;
// Judge whether node is internal parameter.
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
return true;
}
}
return false;
}

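The new overload only treats a non-weight Parameter as a host-queue input when the graph does not record it as an internal parameter (GetFrontNodeByInternalParameter returns nullptr), because an internal parameter receives its data from another graph's output rather than from the host queue. A minimal self-contained sketch of that decision, using toy stand-in types rather than the real MindSpore classes:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Toy stand-ins for AnfNodePtr / KernelGraph (simplified model, not MindSpore code).
struct Node { std::string name; bool is_parameter{true}; bool is_weight{false}; };
using NodePtr = std::shared_ptr<Node>;

struct Graph {
  std::unordered_map<NodePtr, NodePtr> internal_parameters_to_front_map_;
  NodePtr GetFrontNodeByInternalParameter(const NodePtr &p) const {
    auto it = internal_parameters_to_front_map_.find(p);
    return it == internal_parameters_to_front_map_.end() ? nullptr : it->second;
  }
};

// Mirrors the shape of the new IsHostQueueDSActor(node, graph) check.
bool IsHostQueueInput(const NodePtr &node, const Graph &graph) {
  if (node->is_parameter && !node->is_weight) {
    // An internal parameter is fed by another graph's output, not by the host queue.
    return graph.GetFrontNodeByInternalParameter(node) == nullptr;
  }
  return false;
}

int main() {
  Graph g;
  auto external = std::make_shared<Node>(Node{"x"});
  auto internal = std::make_shared<Node>(Node{"y_from_graph0"});
  g.internal_parameters_to_front_map_[internal] = std::make_shared<Node>(Node{"front_y"});
  std::cout << IsHostQueueInput(external, g) << " " << IsHostQueueInput(internal, g) << "\n";  // prints: 1 0
}
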
@@ -384,32 +389,28 @@ void GraphScheduler::Initialize() {
(void)actorMgr->Spawn(base_actor, false);
}

ActorSet *GraphScheduler::Transform(const std::vector<KernelGraphPtr> &graphs,
const std::vector<DeviceContext *> &device_contexts,
const std::vector<TensorPtr> *input_tensors,
const std::vector<AnfNodePtr> *control_nodes, GraphExecutionStrategy strategy) {
if (graphs.size() != device_contexts.size()) {
MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device_contexts.";
ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy) {
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
if (graph_compiler_info.graphs_.size() == 0) {
MS_LOG(EXCEPTION) << "The number of graphs is zero.";
}
if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
}
Initialize();
std::vector<ActorSetPtr> actor_sets;
for (size_t i = 0; i < graphs.size(); ++i) {
auto graph = graphs[i];
auto device_context = device_contexts[i];
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor begin.";
PersistDeviceTensor(graph);
auto actor_set = Build(graph, device_context);
actor_sets.emplace_back(actor_set);
graph_to_actors_.emplace(graph, actor_set);
Link(actor_set.get(), graph, strategy);

if (!CheckActorValid(actor_set.get())) {
MS_LOG(EXCEPTION) << "The actor set of " << graph->ToString() << " is invalid.";
}
MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor end.";
Initialize();

PersistDeviceTensor(graph_compiler_info);
const auto &actor_set = Build(graph_compiler_info);
Link(actor_set.get(), graph_compiler_info, strategy);
actors_.emplace(actor_set->name_, actor_set);

DumpActor(actor_set.get());
if (!CheckActorValid(actor_set.get())) {
MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid.";
}
return actor_sets[0].get();
MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
return actor_set.get();
}

void GraphScheduler::Schedule(const ActorSet *actor_set) {

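After this change Transform consumes a whole GraphCompilerInfo instead of one graph at a time: it validates the graph/device-context lists, then runs Initialize, PersistDeviceTensor, Build, Link, DumpActor and CheckActorValid once, and registers the single combined actor set under the compiler-info name. A rough, self-contained outline of that control flow with the heavy steps elided (toy stand-in types, not the real scheduler):

#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

struct ActorSet { std::string name_; };
using ActorSetPtr = std::shared_ptr<ActorSet>;

struct CompilerInfo {                 // toy stand-in for GraphCompilerInfo
  std::vector<int> graphs_;           // graph ids instead of KernelGraphPtr
  std::vector<std::string> device_contexts_;
  std::string name_;
};

class Scheduler {
 public:
  ActorSet *Transform(const CompilerInfo &info) {
    if (info.graphs_.empty()) throw std::runtime_error("The number of graphs is zero.");
    if (info.graphs_.size() != info.device_contexts_.size())
      throw std::runtime_error("graphs/device contexts size mismatch");
    // Initialize(); PersistDeviceTensor(info);  -- omitted in this sketch
    auto actor_set = std::make_shared<ActorSet>();
    actor_set->name_ = info.name_;
    // Build(info); Link(actor_set, info, strategy); CheckActorValid(actor_set);  -- omitted
    actors_.emplace(actor_set->name_, actor_set);  // one actor set per compiler info
    return actor_set.get();
  }
  ActorSet *Fetch(const std::string &name) const {
    auto it = actors_.find(name);
    return it == actors_.end() ? nullptr : it->second.get();
  }
 private:
  std::unordered_map<std::string, ActorSetPtr> actors_;
};

int main() {
  Scheduler s;
  CompilerInfo info{{1, 2}, {"GPU", "GPU"}, "kernel_graph_1_2"};
  auto *set = s.Transform(info);
  std::cout << (s.Fetch("kernel_graph_1_2") == set) << "\n";  // prints: 1
}
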
@@ -438,61 +439,70 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
}
}

void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors,
VectorRef *const &outputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(outputs);
// Get the device context for the first kernel actor.
const auto &actor_set = Fetch(graph);
void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<TensorPtr>> &input_tensors, VectorRef *const &outputs) {
MS_EXCEPTION_IF_NULL(actor_set);
const auto &first_kernel_actor = actor_set->kernel_actors_[0];
MS_EXCEPTION_IF_NULL(first_kernel_actor);
const auto &device_context = first_kernel_actor->device_context_;

// 1.Prepare the data of device tensor store(value nodes of graph).
for (const auto &value_node : graph->graph_value_nodes()) {
if (AnfAlgo::OutputAddrExist(value_node, 0)) {
PrepareDataForValueNode(value_node, device_context);
}
}

// 1.Prepare the data of device tensor store(weights of graph), and fill the host tensors for non weighted parameters.
MS_EXCEPTION_IF_NULL(outputs);
std::vector<TensorPtr> host_tensors;
const auto &input_nodes = graph->input_nodes();
for (size_t i = 0; i < input_nodes.size(); ++i) {
const auto &input_node = input_nodes[i];
const auto &input_tensor = (*input_tensors)[i];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
// Prepare the device data for weights.
PrepareDataForWeightNode(input_node, input_tensor, device_context);
} else {
// Fill the host tensors for non weighted parameters.
host_tensors.emplace_back(input_tensor);
const auto &host_data_source_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
if (host_data_source_actor != nullptr) {
host_tensors.resize(host_data_source_actor->data_nodes_.size());
}

for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);

// 1.Prepare the data of device tensor store(value nodes of graph).
for (const auto &value_node : graph->graph_value_nodes()) {
if (AnfAlgo::OutputAddrExist(value_node, 0)) {
PrepareDataForValueNode(value_node, device_context);
}
}

// 1.Prepare the data of device tensor store(weights of graph), and fill host tensors for non weighted parameters.
const auto &input_nodes = graph->input_nodes();
const auto &tensors = input_tensors[i];
for (size_t j = 0; j < input_nodes.size(); ++j) {
const auto &input_node = input_nodes[j];
const auto &input_tensor = tensors[j];
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
// Prepare the device data for weights.
PrepareDataForWeightNode(input_node, input_tensor, device_context);
} else if (IsHostQueueDSActor(input_node, graph)) {
MS_EXCEPTION_IF_NULL(host_data_source_actor);
// Fill the host tensors for non weighted parameters.
const auto &iter = host_data_source_actor->data_node_position_map_.find(input_node);
if (iter != host_data_source_actor->data_node_position_map_.end()) {
host_tensors[iter->second] = input_tensor;
}
}
}

// 2.Prepare the output tensor of graph.
for (const auto &output_node : graph->outputs()) {
MS_EXCEPTION_IF_NULL(output_node);
MS_LOG(INFO) << "Create node output: " << output_node->fullname_with_scope();
outputs->emplace_back(CreateOutputTensors(output_node, graph, tensors));
}

// 3.Prepare the continuous memory for communication kernel.
for (const auto &kernel : graph->execution_order()) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
AllocateContinuousMemoryForInput(kernel, device_context, graph->is_all_nop_node());
AllocateContinuousMemoryForOutput(kernel, device_context);
}
}
}

// 2.Prepare the data of host tensor queue(non weighted parameters of graph).
const auto &host_tensor_queue = FetchHostQueue(graph);
if (host_tensor_queue != nullptr) {
// 4.Prepare the data of host tensor queue(non weighted parameters of graph).
if (host_data_source_actor != nullptr) {
const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
MS_EXCEPTION_IF_NULL(host_tensor_queue);
host_tensor_queue->PushData(host_tensors);
}

// 3.Prepare the output tensor of graph.
for (const auto &output_node : graph->outputs()) {
MS_EXCEPTION_IF_NULL(output_node);
MS_LOG(INFO) << "Create node output: " << output_node->fullname_with_scope();
outputs->emplace_back(CreateOutputTensors(output_node, graph, *input_tensors));
}

// 4.Prepare the continuous memory for communication kernel.
for (const auto &kernel : graph->execution_order()) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
AllocateContinuousMemoryForInput(kernel, device_context, graph->is_all_nop_node());
AllocateContinuousMemoryForOutput(kernel, device_context);
}
}
}

bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy) {

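PrepareRun now receives one tensor vector per graph and fills a single host_tensors vector, sized to the host data source actor's data_nodes_, using data_node_position_map_ so that a parameter shared by several graphs lands in exactly one slot. A small stand-alone sketch of that filling step, with toy types standing in for the actor and tensor classes:

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using Tensor = int;  // stand-in for tensor::TensorPtr

struct HostDSActor {                       // toy HostQueueDataSourceActor
  std::vector<std::string> data_nodes_;
  std::unordered_map<std::string, size_t> data_node_position_map_;
};

// Fill one flat host-tensor vector from per-graph input tensors, keyed by node position.
std::vector<Tensor> FillHostTensors(const HostDSActor &actor,
                                    const std::vector<std::vector<std::string>> &graph_inputs,
                                    const std::vector<std::vector<Tensor>> &input_tensors) {
  std::vector<Tensor> host_tensors(actor.data_nodes_.size());
  for (size_t g = 0; g < graph_inputs.size(); ++g) {
    for (size_t j = 0; j < graph_inputs[g].size(); ++j) {
      auto iter = actor.data_node_position_map_.find(graph_inputs[g][j]);
      if (iter != actor.data_node_position_map_.end()) {
        host_tensors[iter->second] = input_tensors[g][j];
      }
    }
  }
  return host_tensors;
}

int main() {
  HostDSActor actor{{"p0", "p1"}, {{"p0", 0}, {"p1", 1}}};
  auto host = FillHostTensors(actor, {{"p0"}, {"p1"}}, {{10}, {20}});
  assert(host[0] == 10 && host[1] == 20);
}
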
@@ -544,36 +554,32 @@ bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strat
return true;
}

ActorSet *GraphScheduler::Fetch(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
auto iter = graph_to_actors_.find(graph);
if (iter != graph_to_actors_.end()) {
ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
auto iter = actors_.find(actor_info);
if (iter != actors_.end()) {
return iter->second.get();
} else {
MS_LOG(ERROR) << "Can't find the actors map of graph: " << graph->ToString();
MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
return nullptr;
}
}

ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceContext *device_context) {
auto actor_set = std::make_shared<ActorSet>();
ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
MS_EXCEPTION_IF_NULL(actor_set);

auto data_source_actors = BuildDataSourceActor(graph, device_context);
actor_set->data_source_actors_.swap(data_source_actors);

auto kernel_actors = BuildKernelActor(graph, device_context);
actor_set->kernel_actors_.swap(kernel_actors);

auto loop_count_actor = BuildLoopCountActor(graph);
actor_set->loop_count_actor_ = loop_count_actor;
auto host_queue = std::make_shared<HostTensorQueue>();
actor_to_host_queue_.emplace(actor_set->name_, host_queue);
actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);

return actor_set;
}

void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set);
MS_EXCEPTION_IF_NULL(graph);
KernelMapActor kernel_actors_temp_map;
for (auto &actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(actor);

@@ -581,113 +587,145 @@ void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, Grap
}

// Foreach the execution order to link the actors.
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (!IsKernelActor(kernel)) {
continue;
}
auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());
// Link the control arrows of kernel actor.
LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy);

for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
if (HasAbstractMonad(input_node)) {
continue;  // No data arrow for monad input.
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (!IsKernelActor(kernel)) {
continue;
}
auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());

KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
auto from_kernel = from_kernel_with_output_idx.first;
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
if (HasAbstractMonad(input_node)) {
continue;  // No data arrow for monad input.
}

if (IsDeviceQueueDSActor(from_kernel)) {
// Link the data arrows of device queue data source actor.
auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel)) {
// Link the data arrows of host queue data source actor.
auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
// Link the data arrows of kernel actor.
auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
auto from_kernel = from_kernel_with_output_idx.first;

if (IsDeviceQueueDSActor(from_kernel)) {
// Link the data arrows of device queue data source actor.
auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx,
to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel, graph)) {
// Link the data arrows of host queue data source actor.
auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
// Link the data arrows of kernel actor.
auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
}
}
}
}

// Link the control arrows of kernel actor.
LinkControlArrowForKernelActor(&(actor_set->kernel_actors_), actor_set->loop_count_actor_.get(), strategy);

// BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
auto no_input_kernel_actors = BuildNoInputKernelActor(graph);
auto no_input_kernel_actors = BuildNoInputKernelActor(actor_set);
actor_set->no_input_kernel_actors_.swap(no_input_kernel_actors);

// Link the control arrows of loop count actor, which depends on the no input kernel actors.
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph);
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set);
}

std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
const HostTensorQueuePtr &host_queue) {
std::vector<DataSourceActorPtr> data_source_actors;

// Build host queue data source actor.
HostQueueDSActorPtr host_queue_ds_actor = nullptr;
for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsHostQueueDSActor(input_node)) {
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph->ToString() + "_" + "HostQueueDataSourceActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
auto host_queue = std::make_shared<HostTensorQueue>();
graph_to_host_queue_.emplace(graph, host_queue);
host_queue_ds_actor =
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
data_source_actors.emplace_back(host_queue_ds_actor);
}
host_queue_ds_actor->data_nodes_.emplace_back(input_node);
}
}
size_t data_node_position = 0;
std::unordered_map<AnfNodePtr, size_t> front_node_position_temp_map;

// Build device queue data source actor.
auto execution_order = graph->execution_order();
auto iter = std::find_if(execution_order.begin(), execution_order.end(),
[](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
if (iter != execution_order.end()) {
auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor";
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor =
std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);
// Build host queue data source actor.
for (const auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsHostQueueDSActor(input_node, graph)) {
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph_compiler_info.name_ + "_HostQueueDataSourceActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
host_queue_ds_actor =
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
data_source_actors.emplace_back(host_queue_ds_actor);
}

const auto &front_node = graph->GetFrontAnfByBackendAnf(input_node);
// In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
// is saved in the host queue data source actor.
if ((front_node != nullptr) && (front_node_position_temp_map.count(front_node) > 0)) {
host_queue_ds_actor->data_node_position_map_.emplace(input_node, front_node_position_temp_map[front_node]);
continue;
}
host_queue_ds_actor->data_nodes_.emplace_back(input_node);
host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
front_node_position_temp_map.emplace(front_node, data_node_position);
data_node_position++;
}
}

// Build device queue data source actor.
const auto &execution_order = graph->execution_order();
const auto &iter = std::find_if(execution_order.begin(), execution_order.end(),
[](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
if (iter != execution_order.end()) {
auto actor_name =
graph_compiler_info.name_ + "_DeviceQueueDataSourceActor" + "_" + std::to_string(graph->graph_id());
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor =
std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
}
}
return data_source_actors;
}

std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const KernelGraphPtr &graph,
const DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
std::vector<KernelActorPtr> kernel_actors;

auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (IsKernelActor(kernel)) {
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(kernel_actor);
kernel_actors.emplace_back(kernel_actor);
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
const auto &graph = graph_compiler_info.graphs_[i];
const auto &device_context = graph_compiler_info.device_contexts_[i];
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();

for (auto &kernel : execution_order) {
if (IsKernelActor(kernel)) {
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(kernel_actor);
kernel_actors.emplace_back(kernel_actor);
}
}
}
return kernel_actors;
}

std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
auto loop_count = ConfigManager::GetInstance().iter_num();
auto actor_name = graph_compiler_info.name_ + "_" + "LoopCountActor";
auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
MS_EXCEPTION_IF_NULL(loop_count_actor);
return loop_count_actor;
}

std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActorPtr> no_input_kernel_actors;

auto actor_set = Fetch(graph);
MS_EXCEPTION_IF_NULL(actor_set);
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {

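The de-duplication in the new BuildDataSourceActor is the part that makes the multi-graph host queue work: when backend parameters from different graphs map to the same front node, only the first one is appended to data_nodes_, and later ones simply reuse its recorded position through front_node_position_temp_map. A compact stand-alone model of that bookkeeping (plain strings stand in for AnfNode pointers):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct BackendParam { std::string backend_name; std::string front_name; };

int main() {
  // Two graphs; the second graph re-exposes the same front parameter "x".
  std::vector<std::vector<BackendParam>> graphs = {{{"g0_x", "x"}, {"g0_w", "w"}},
                                                   {{"g1_x", "x"}}};
  std::vector<std::string> data_nodes;                     // data_nodes_ of the host DS actor
  std::unordered_map<std::string, size_t> node_position;   // data_node_position_map_
  std::unordered_map<std::string, size_t> front_position;  // front_node_position_temp_map
  size_t position = 0;
  for (const auto &graph : graphs) {
    for (const auto &param : graph) {
      if (front_position.count(param.front_name) > 0) {
        // Reuse the slot of the first backend node that shares this front node.
        node_position.emplace(param.backend_name, front_position[param.front_name]);
        continue;
      }
      data_nodes.emplace_back(param.backend_name);
      node_position.emplace(param.backend_name, position);
      front_position.emplace(param.front_name, position);
      ++position;
    }
  }
  std::cout << data_nodes.size() << " " << node_position["g1_x"] << "\n";  // prints: 2 0
}
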
@@ -699,16 +737,6 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const Kernel
return no_input_kernel_actors;
}

LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto loop_count = ConfigManager::GetInstance().iter_num();
auto actor_name = graph->ToString() + "_" + "LoopCountActor";
auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
MS_EXCEPTION_IF_NULL(loop_count_actor);
return loop_count_actor;
}

void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {

@@ -740,19 +768,20 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;

auto data_nodes = from_actor->data_nodes_;
auto iter = find(data_nodes.begin(), data_nodes.end(), from_kernel);
if (iter == data_nodes.end()) {
// Get the position of from kernel in the data source actor.
auto iter = from_actor->data_node_position_map_.find(from_kernel);
if (iter == from_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Parameter node: " << from_kernel->fullname_with_scope() << " is not exist.";
}
auto position = IntToSize(std::distance(data_nodes.begin(), iter));
auto position = iter->second;

auto to_aid = to_actor->GetAID();
auto op_arrow = std::make_shared<OpArrow>(position, to_aid, to_input_index);
from_actor->output_op_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;

// Update the reference count of device tensor.
UpdateRefCount(from_kernel, from_output_index);
UpdateRefCount(from_actor->data_nodes_[position], from_output_index);
}

void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,

@@ -778,25 +807,26 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
}
}

void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor,
const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(from_actor);
void GraphScheduler::LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(from_actors);
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(graph);

if (strategy == GraphExecutionStrategy::kStep) {
from_actor->input_controls_num_++;
}
for (auto &from_actor : *from_actors) {
MS_EXCEPTION_IF_NULL(from_actor);
if (strategy == GraphExecutionStrategy::kStep) {
from_actor->input_controls_num_++;
}

// The manager of graph member is weak ptr, so need created and used in the function IsNotRealUsedByOthers.
const auto &manager = Manage(graph, true);
MS_EXCEPTION_IF_NULL(manager);
if (opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) {
MS_EXCEPTION_IF_NULL(from_actor->kernel_);
MS_LOG(INFO) << from_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
auto to_aid = to_actor->GetAID();
from_actor->output_op_controls_.emplace_back(to_aid);
to_actor->input_controls_num_++;
// If the kernel actor has no output in the pipeline mode, then adds the output control to loop count actor.
if ((strategy == GraphExecutionStrategy::kPipeline) && (from_actor->output_op_arrows_.size() == 0) &&
(from_actor->output_op_controls_.size() == 0)) {
MS_EXCEPTION_IF_NULL(from_actor->kernel_);
MS_LOG(INFO) << from_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
auto to_aid = to_actor->GetAID();
from_actor->output_op_controls_.emplace_back(to_aid);
to_actor->input_controls_num_++;
}
}
}

@@ -852,12 +882,9 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
to_actor->input_controls_num_++;
}

void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(loop_count_actor);

auto actor_set = Fetch(graph);
void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
MS_EXCEPTION_IF_NULL(loop_count_actor);

// Set the source data actor.
for (auto &data_source_actor : actor_set->data_source_actors_) {

@@ -914,48 +941,58 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
return true;
}

void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);

for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
}

for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(device_tensor);
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
for (auto &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
continue;
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
}

for (auto &input_node : graph->input_nodes()) {
MS_EXCEPTION_IF_NULL(input_node);
if (IsPersistentDeviceTensor(input_node)) {
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(device_tensor);
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
}
}
}
}

HostTensorQueue *GraphScheduler::FetchHostQueue(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
const auto &iter = graph_to_host_queue_.find(graph);
if (iter != graph_to_host_queue_.end()) {
HostTensorQueue *GraphScheduler::FetchHostQueue(const ActorInfo &actor_info) const {
const auto &iter = actor_to_host_queue_.find(actor_info);
if (iter != actor_to_host_queue_.end()) {
return iter->second.get();
} else {
return nullptr;
}
}

void GraphScheduler::DumpActor(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
const auto &actor_set = Fetch(graph);
void GraphScheduler::DumpActor(const ActorSet *actor_set) const {
MS_EXCEPTION_IF_NULL(actor_set);
std::string filename = "./actor_set_" + graph->ToString() + ".ir";
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
if (!save_graphs) {
return;
}
auto save_graphs_path = context_ptr->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}

std::string filename = save_graphs_path + "/actor_set_" + actor_set->name_ + ".ir";
std::ofstream ofs(filename);
if (!ofs.is_open()) {
MS_LOG(ERROR) << "Open file [" << filename << "] failed!";

@@ -35,12 +35,37 @@ namespace runtime {
using mindspore::device::DeviceContext;
using mindspore::session::KernelWithIndex;
using KernelMapActor = std::unordered_map<std::string, KernelActorPtr>;
using ActorInfo = std::string;

enum class GraphExecutionStrategy {
kPipeline,  // The actor running is triggered only by data.
kStep  // The actor running need be triggered by control in addition.
};

// The graph compiler info generated by graph compiler is the express of executable graph.
// The device context is unified interface of interaction with device of corresponding graph.
// The input tensor is used to link graphs in the dynamic build scenario.
// The control node is used to link graphs in the control flow scenario.
// The origin parameters order is used to correspond to the input args.
struct GraphCompilerInfo {
GraphCompilerInfo(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<std::vector<TensorPtr> *> &input_tensors,
const std::vector<AnfNodePtr> &control_nodes,
const std::vector<AnfNodePtr> &origin_parameters_order, const std::string &name)
: graphs_(graphs),
device_contexts_(device_contexts),
input_tensors_(input_tensors),
control_nodes_(control_nodes),
origin_parameters_order_(origin_parameters_order),
name_(name) {}
std::vector<KernelGraphPtr> graphs_;
std::vector<DeviceContext *> device_contexts_;
std::vector<std::vector<TensorPtr> *> input_tensors_;
std::vector<AnfNodePtr> control_nodes_;
std::vector<AnfNodePtr> origin_parameters_order_;
std::string name_;
};

// The actor set generated by graph transformer is the execution unit of actor runtime.
// It includes data source actor, kernel actor, loop count actor.
// The data source actor is used to obtain data and process them into device tensors,

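GraphCompilerInfo is essentially a bundle of parallel vectors: graphs_[i] runs with device_contexts_[i] (and, when present, input_tensors_[i]), while name_ identifies the resulting actor set. A small self-contained sketch of building such a bundle and deriving its name from the graph ids, which mirrors what MindRTBackend::CompileGraphs does later in this commit (toy stand-in types, not the real runtime classes):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct KernelGraphStub { int graph_id; };        // stand-in for KernelGraphPtr
struct DeviceContextStub { std::string target; };

struct CompilerInfoStub {                        // stand-in for runtime::GraphCompilerInfo
  std::vector<KernelGraphStub> graphs_;
  std::vector<DeviceContextStub *> device_contexts_;
  std::string name_;
};

int main() {
  static DeviceContextStub gpu{"GPU"};
  std::vector<std::pair<int, DeviceContextStub *>> compiled = {{1, &gpu}, {2, &gpu}};

  CompilerInfoStub info;
  info.name_ = "kernel_graph";
  for (const auto &entry : compiled) {
    info.graphs_.push_back({entry.first});
    info.device_contexts_.push_back(entry.second);
    info.name_.append("_").append(std::to_string(entry.first));
  }
  // graphs_[i] and device_contexts_[i] are consumed together by the scheduler.
  std::cout << info.name_ << " " << info.graphs_.size() << "\n";  // prints: kernel_graph_1_2 2
}
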
@@ -49,11 +74,13 @@ enum class GraphExecutionStrategy {
// externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
// and decide whether to loop execution by loop count.
struct ActorSet {
explicit ActorSet(const ActorInfo &name) : name_(name) {}
std::vector<DataSourceActorPtr> data_source_actors_;
std::vector<KernelActorPtr> kernel_actors_;
// No input kernel actors need be triggered specifically.
std::vector<KernelActorPtr> no_input_kernel_actors_;
LoopCountActorPtr loop_count_actor_{nullptr};
ActorInfo name_;
};
using ActorSetPtr = std::shared_ptr<ActorSet>;

@@ -69,9 +96,7 @@ class GraphScheduler {
void Initialize();

// Transform graph to actor DAG, contains build and link.
ActorSet *Transform(const std::vector<KernelGraphPtr> &graphs, const std::vector<DeviceContext *> &device_contexts,
const std::vector<TensorPtr> *input_tensors = nullptr,
const std::vector<AnfNodePtr> *control_nodes = nullptr,
ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info,
GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);

// Schedule actors in the actor runtime. Single machine scheduling is supported currently, and distributed scheduling

@@ -82,14 +107,15 @@ class GraphScheduler {
// 1. Prepare the data of device tensor store(such as weights and value nodes of graph).
// 2. Prepare the data of host tensor queue(such as non weighted parameters of graph).
// 3. Prepare the output tensor of graph.
// 4.Prepare the continuous memory for communication kernel.
void PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs);
// 4. Prepare the continuous memory for communication kernel.
void PrepareRun(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
const std::vector<std::vector<TensorPtr>> &input_tensors, VectorRef *const &outputs);

// The processing entry of actors running.
bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline);

// Fetch the actor set by kernel graph.
ActorSet *Fetch(const KernelGraphPtr &graph) const;
// Fetch the actor set by actor info.
ActorSet *Fetch(const ActorInfo &actor_info) const;

private:
GraphScheduler() = default;

@@ -97,16 +123,16 @@ class GraphScheduler {
DISABLE_COPY_AND_ASSIGN(GraphScheduler);

// Transform the nodes of graph to actors.
ActorSetPtr Build(const KernelGraphPtr &graph, const DeviceContext *device_context);
ActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
// Link actors to DAG through the edge connection of graph and graph execution strategy.
void Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy);
void Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info, GraphExecutionStrategy strategy);

// The processing of actors build.
std::vector<DataSourceActorPtr> BuildDataSourceActor(const KernelGraphPtr &graph,
const DeviceContext *device_context);
std::vector<KernelActorPtr> BuildKernelActor(const KernelGraphPtr &graph, const DeviceContext *device_context);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const KernelGraphPtr &graph);
LoopCountActorPtr BuildLoopCountActor(const KernelGraphPtr &graph);
std::vector<DataSourceActorPtr> BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
const HostTensorQueuePtr &host_queue);
std::vector<KernelActorPtr> BuildKernelActor(const GraphCompilerInfo &graph_compiler_info);
LoopCountActorPtr BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set);

// The processing of actors link.
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,

@@ -118,9 +144,9 @@ class GraphScheduler {
void LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx);
void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph,
void LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
GraphExecutionStrategy strategy);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set);
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
const KernelMapActor &kernel_actors_map);

@@ -128,19 +154,19 @@ class GraphScheduler {
bool CheckActorValid(const ActorSet *actor_set) const;

// Persist device tensors of graph's some nodes(such as weights and value nodes).
void PersistDeviceTensor(const KernelGraphPtr &graph);
void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info);

// Fetch the hsot tensor queue by kernel graph.
HostTensorQueue *FetchHostQueue(const KernelGraphPtr &graph) const;
// Fetch the hsot tensor queue by actor info.
HostTensorQueue *FetchHostQueue(const ActorInfo &actor_info) const;

// Display the actor information of corresponding kernel graph.
void DumpActor(const KernelGraphPtr &graph) const;
void DumpActor(const ActorSet *actor_set) const;
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const;
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;

std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actors_;
std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_;
std::unordered_map<ActorInfo, ActorSetPtr> actors_;
std::unordered_map<ActorInfo, HostTensorQueuePtr> actor_to_host_queue_;

// The second element of pair represents the output index of kernel actor corresponding to the device tensor.
std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;

@@ -29,7 +29,6 @@
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/graph_compiler.h"
#include "runtime/framework/graph_scheduler.h"
#include "utils/scoped_long_running.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"

@@ -237,37 +236,46 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string
graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
}

GraphId MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
ActorInfo MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphPtr root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);

// Compile root graph.
auto root_graph_id = CompileGraph(root_graph);
graph_to_device_context_.clear();
control_nodes_.clear();
CompileGraph(root_graph);

// Compile sub graphs.
FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
for (auto sub_graph : sub_graphs) {
if (sub_graph != func_graph && sub_graph != nullptr) {
(void)CompileGraph(sub_graph);
CompileGraph(sub_graph);
}
}

// Transform graph to actor DAG, and schedule the actor DAG.
// Construct the graph compiler info.
std::vector<KernelGraphPtr> graphs;
std::vector<DeviceContext *> device_contexts;
std::string name = "kernel_graph";
for (const auto &graph_id_to_context : graph_to_device_context_) {
graphs.emplace_back(runtime::GraphCompiler::GetInstance().Fetch(graph_id_to_context.first));
device_contexts.emplace_back(graph_id_to_context.second);
name.append("_").append(std::to_string(graph_id_to_context.first));
}
const auto &actor_set =
runtime::GraphScheduler::GetInstance().Transform(graphs, device_contexts, nullptr, &control_nodes_);
std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
auto graph_compiler_info = std::make_unique<GraphCompilerInfo>(graphs, device_contexts, input_tensors, control_nodes_,
root_graph->parameters(), name);

// Transform graph to actor DAG, and schedule the actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*(graph_compiler_info.get()));
runtime::GraphScheduler::GetInstance().Schedule(actor_set);

return root_graph_id;
actor_to_graph_compiler_info_.emplace(actor_set->name_, std::move(graph_compiler_info));
return actor_set->name_;
}

GraphId MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
void MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(graph_partition_);

@@ -309,12 +317,10 @@ GraphId MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
control_nodes_.push_back(cut_node);
}
}

return graph_to_device_context_.begin()->first;
}

VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) {
MS_LOG(INFO) << "Run graph begin, graph id: " << graph_id;
VectorRef MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args) {
MS_LOG(INFO) << "Run actor begin, actor name: " << actor_info;
const auto &context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {

@@ -322,39 +328,41 @@ VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) {
return VectorRef();
}

// Fetch the kernel graph.
const auto &kernel_graph = runtime::GraphCompiler::GetInstance().Fetch(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
// Fetch the graph compiler info.
const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
if (graph_iter == actor_to_graph_compiler_info_.end()) {
MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
}
const auto &graph_compiler_info = *(graph_iter->second.get());
const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;

// Transform args to input tensors.
std::vector<tensor::TensorPtr> inputs;
for (const auto &input_node : kernel_graph->input_nodes()) {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
MS_EXCEPTION_IF_NULL(front_node);
MS_EXCEPTION_IF_NULL(front_node->func_graph());
const auto &origin_parameters = front_node->func_graph()->parameters();
const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node);
if (iter == origin_parameters.end()) {
MS_LOG(EXCEPTION) << "Parameter node: " << front_node->fullname_with_scope() << " is not exist.";
std::vector<std::vector<tensor::TensorPtr>> input_tensors;
for (const auto &kernel_graph : graph_compiler_info.graphs_) {
std::vector<tensor::TensorPtr> input_tensor;
for (const auto &input_node : kernel_graph->input_nodes()) {
const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node);
if (iter == origin_parameters.end()) {
input_tensor.emplace_back(nullptr);
continue;
}
auto position = IntToSize(std::distance(origin_parameters.begin(), iter));
PushInputTensor(args[position], &input_tensor);
}
auto position = IntToSize(std::distance(origin_parameters.begin(), iter));
PushInputTensor(args[position], &inputs);
input_tensors.emplace_back(input_tensor);
}

// Fetch the actor DAG.
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(kernel_graph);
MS_EXCEPTION_IF_NULL(actor_set);

// Run actor DAG.
mindspore::ScopedLongRunning long_running;
VectorRef outputs;
runtime::GraphScheduler::GetInstance().PrepareRun(kernel_graph, &inputs, &outputs);
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
runtime::GraphScheduler::GetInstance().PrepareRun(actor_set, graph_compiler_info, input_tensors, &outputs);
if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) {
MS_LOG(EXCEPTION) << "The graph runs failed, graph id: " << graph_id
<< ", graph name: " << kernel_graph->ToString();
MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
}

MS_LOG(INFO) << "Run graph end, graph id: " << graph_id;
MS_LOG(INFO) << "Run actor end, actor name: " << actor_info;
return outputs;
}
} // namespace compile

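The reworked RunGraph builds one input-tensor vector per graph: for each backend input it looks up the corresponding front parameter in origin_parameters_order_ and copies the run argument at that position, pushing nullptr when the parameter is not part of the original signature (for example an internal parameter fed by another graph). A reduced stand-alone version of that mapping, with plain strings and integers standing in for nodes and tensors:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

using Tensor = int;
const Tensor kNullTensor = -1;  // stands in for a nullptr tensor entry

int main() {
  // Original front-end parameter order and the run-time args aligned with it.
  std::vector<std::string> origin_parameters = {"a", "b", "c"};
  std::vector<Tensor> args = {10, 20, 30};
  // Backend input nodes of two graphs, already translated to their front parameter names
  // ("internal" has no front parameter in the original signature).
  std::vector<std::vector<std::string>> graph_inputs = {{"a", "b"}, {"internal", "c"}};

  std::vector<std::vector<Tensor>> input_tensors;
  for (const auto &inputs : graph_inputs) {
    std::vector<Tensor> input_tensor;
    for (const auto &front : inputs) {
      auto iter = std::find(origin_parameters.begin(), origin_parameters.end(), front);
      if (iter == origin_parameters.end()) {
        input_tensor.emplace_back(kNullTensor);
        continue;
      }
      auto position = static_cast<size_t>(std::distance(origin_parameters.begin(), iter));
      input_tensor.emplace_back(args[position]);
    }
    input_tensors.emplace_back(input_tensor);
  }
  std::cout << input_tensors[1][0] << " " << input_tensors[1][1] << "\n";  // prints: -1 30
}
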
@@ -30,11 +30,15 @@
#include "vm/vm.h"
#include "backend/session/session_basic.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/graph_scheduler.h"

namespace mindspore {
namespace compile {
using OpRunInfo = session::OpRunInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompilerInfo = runtime::GraphCompilerInfo;

enum SwitchCondStatus {
kCondOk = 0,
kCondAlreadyRun,

@@ -95,22 +99,24 @@ class MindRTBackend : public Backend {
MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id);
~MindRTBackend() override = default;

// The parameter root_graph is a root graph, and the root graph maybe contain multiple sub graphs,
// the return is the kernelGraph id of the root graph. It will traverse all subgraphs to call CompileGraph.
GraphId CompileGraphs(const FuncGraphPtr &root_graph);
// The parameter root_graph is a root graph, and the root graph maybe contain multiple sub graphs, It will traverse
// all sub graphs to call CompileGraph.
ActorInfo CompileGraphs(const FuncGraphPtr &root_graph);

// Compile single op kernel graph in the pyNative mode.
GraphId CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
ActorInfo CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);

// Run Graph in the graph mode.
VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
VectorRef RunGraph(const ActorInfo &actor_info, const VectorRef &args);

// Run Graph in the pyNative mode.
VectorRef RunGraph(const GraphInfo &graph_info, const VectorRef &args);
VectorRef RunGraph(const GraphInfo &graph_info, const std::vector<tensor::TensorPtr> &input_tensors);

private:
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
// the return is the corresponding kernelGraph id of the graph.
GraphId CompileGraph(const FuncGraphPtr &func_graph);
// The result of graph compiler is stored in graph_to_device_context_ and control_nodes_.
void CompileGraph(const FuncGraphPtr &func_graph);

// When compiling FuncGraph, it is divided according to the control nodes, and obtain the control nodes and several
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to

@@ -118,6 +124,8 @@ class MindRTBackend : public Backend {
std::unordered_map<GraphId, DeviceContext *> graph_to_device_context_;
std::vector<AnfNodePtr> control_nodes_;

std::unordered_map<ActorInfo, std::unique_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_;

GraphPartitionPtr graph_partition_;
std::string device_name_;
uint32_t device_id_;

@@ -20,6 +20,7 @@ message("PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
include_directories(${PYTHON_INCLUDE_DIRS})
include_directories(${MS_CCSRC_PATH})
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core)
include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/stub/runtime/)
include_directories(${CMAKE_BINARY_DIR})