!16443 actor runtime support host and device

From: @limingqi107
Reviewed-by: @cristoval,@wilfchen
Signed-off-by: @wilfchen
This commit is contained in:
mindspore-ci-bot 2021-05-17 09:01:22 +08:00 committed by Gitee
commit cd439940f4
10 changed files with 378 additions and 155 deletions

View File

@ -1109,12 +1109,20 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
}
}
AnfNodePtr KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
const auto &iter = internal_parameters_to_front_map_.find(parameter);
if (iter != internal_parameters_to_front_map_.end()) {
void KernelGraph::CacheInternalParameterToFrontNode(const AnfNodePtr &parameter,
const AnfWithOutIndex &front_node_with_index) {
if (parameter == nullptr) {
return;
}
internal_parameter_to_front_node_map_[parameter] = front_node_with_index;
}
AnfWithOutIndex KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
const auto &iter = internal_parameter_to_front_node_map_.find(parameter);
if (iter != internal_parameter_to_front_node_map_.end()) {
return iter->second;
}
return nullptr;
return AnfWithOutIndex();
}
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {

View File

@ -71,7 +71,7 @@ class KernelGraph : public FuncGraph {
parent_graph_ = graph.parent_graph_;
start_label_ = graph.start_label_;
end_goto_ = graph.end_goto_;
internal_parameters_to_front_map_ = graph.internal_parameters_to_front_map_;
internal_parameter_to_front_node_map_ = graph.internal_parameter_to_front_node_map_;
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
@ -200,7 +200,9 @@ class KernelGraph : public FuncGraph {
bool unique_target = false);
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
int dst_output_idx = -1);
AnfNodePtr GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const;
// Cache the internal parameter and the corresponding front node into internal_parameter_to_front_node_map_.
void CacheInternalParameterToFrontNode(const AnfNodePtr &parameter, const AnfWithOutIndex &front_node_with_index);
AnfWithOutIndex GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const;
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const;
@ -355,7 +357,10 @@ class KernelGraph : public FuncGraph {
CNodePtr start_label_;
CNodePtr end_goto_;
std::unordered_map<AnfNodePtr, AnfNodePtr> internal_parameters_to_front_map_;
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
// related to the input of this kernel graph. The key of the unordered map is the input of this kernel graph, and the
// value is the front node corresponding to the output of the previous kernel graph.
std::unordered_map<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node_map_;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;

View File

@ -771,6 +771,8 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
}
cnode_inputs->push_back(parameter_from_cnode);
(*other_graph_cnode)[anf] = parameter_from_cnode;
KernelWithIndex front_node_with_index(anf, 0);
graph->CacheInternalParameterToFrontNode(parameter_from_cnode, front_node_with_index);
}
}
}

View File

@ -48,7 +48,21 @@ bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
// Judge whether node is internal parameter.
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
const auto &front_node = graph->GetFrontNodeByInternalParameter(node);
if (front_node.first == nullptr) {
return true;
}
}
return false;
}
bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
// Judge whether node is internal parameter.
const auto &front_node = graph->GetFrontNodeByInternalParameter(node);
if (front_node.first != nullptr) {
return true;
}
}

View File

@ -48,6 +48,10 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node);
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph);
bool IsKernelActor(const AnfNodePtr &node);
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
// related to the input of this kernel graph.
bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph);
// Judge whether the device tensor of the node is persistent or not.
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
} // namespace runtime

View File

@ -44,14 +44,14 @@ void CopyActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *contex
}
void CopyActor::AllocateMemory(OpContext<DeviceTensor> *context) {
std::vector<DeviceTensor *> alloc_list({output_device_tensor_});
std::vector<DeviceTensor *> alloc_list({output_device_tensor_.get()});
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, alloc_list, output_device_context_, context,
GetAID());
}
void CopyActor::FreeMemory(OpContext<DeviceTensor> *context) {
std::vector<DeviceTensor *> input_free_list({input_device_tensor_});
std::vector<DeviceTensor *> output_free_list({output_device_tensor_});
std::vector<DeviceTensor *> output_free_list({output_device_tensor_.get()});
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, input_free_list, input_device_context_, context);
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, output_free_list, output_device_context_, context);
}
@ -59,7 +59,7 @@ void CopyActor::FreeMemory(OpContext<DeviceTensor> *context) {
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
MS_EXCEPTION_IF_NULL(context);
if (!Copy(output_device_tensor_, input_device_tensor_)) {
if (!Copy(output_device_tensor_.get(), input_device_tensor_)) {
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
@ -146,8 +146,8 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *context) const {
std::string error_info = "The output index is out of range: " + GetAID().Name();
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
}
auto data =
std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, output_device_tensor_, op_arrow->to_input_index_);
auto data = std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, output_device_tensor_.get(),
op_arrow->to_input_index_);
Async(op_arrow->to_op_id_, &CopyActor::RunOpData, data, context);
}

View File

@ -36,14 +36,11 @@ using mindspore::device::DeviceContext;
// -> OnMemoryAllocFinish -> Copy -> FreeMemory -> SendOutput.
class CopyActor : public MemoryInterfaceActor {
public:
CopyActor(const std::string &name, const DeviceContext *input_device_context,
const DeviceContext *output_device_context, const AID &memory_manager_aid)
CopyActor(const std::string &name, const AID &memory_manager_aid)
: MemoryInterfaceActor(name),
memory_manager_aid_(memory_manager_aid),
input_datas_num_(0),
input_controls_num_(0),
input_device_context_(input_device_context),
output_device_context_(output_device_context),
input_device_tensor_(nullptr),
output_device_tensor_(nullptr) {}
~CopyActor() override = default;
@ -89,9 +86,10 @@ class CopyActor : public MemoryInterfaceActor {
const DeviceContext *input_device_context_;
const DeviceContext *output_device_context_;
// The device tensor for copy.
// The input device tensor is saved from the input data.
DeviceTensor *input_device_tensor_;
DeviceTensor *output_device_tensor_;
// The output device tensor is created in the copy actor build, so can't be the raw pointer.
DeviceTensorPtr output_device_tensor_;
};
using CopyActorPtr = std::shared_ptr<CopyActor>;

View File

@ -35,14 +35,10 @@ using mindspore::session::KernelWithIndex;
using mindspore::tensor::TensorPtr;
// The output actor is used to receive the output result of actor which represents the graph output.
class OutputActor : public ActorBase {
class OutputActor : public OpActor<DeviceTensor> {
public:
OutputActor(std::string name, size_t loop_count, size_t outputs_num)
: ActorBase(name),
loop_count_(loop_count),
current_count_(0),
outputs_num_(outputs_num),
current_outputs_num_(0) {
: OpActor(name), loop_count_(loop_count), current_count_(0), outputs_num_(outputs_num), current_outputs_num_(0) {
outputs_.resize(outputs_num);
output_nodes_.resize(outputs_num);
device_contexts_.resize(outputs_num);

View File

@ -28,42 +28,19 @@
namespace mindspore {
namespace runtime {
namespace {
KernelActor *FindKernelActor(const KernelMapActor &kernel_actors_map, const std::string &name) {
auto iter = kernel_actors_map.find(name);
if (iter != kernel_actors_map.end()) {
return iter->second.get();
bool IsNeedInsertCopyActor(const DeviceContext *from_devcie_context, const DeviceContext *to_devcie_context) {
MS_EXCEPTION_IF_NULL(from_devcie_context);
MS_EXCEPTION_IF_NULL(to_devcie_context);
if (from_devcie_context->GetDeviceAddressType() == to_devcie_context->GetDeviceAddressType()) {
return false;
} else {
return true;
}
return nullptr;
}
// Return the first data source actor whose AID name contains "_DeviceQueueDataSourceActor",
// downcast to DeviceQueueDataSourceActor; return nullptr when no such actor exists.
DeviceQueueDataSourceActor *FindDeviceQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
for (auto &actor : data_source_actors) {
MS_EXCEPTION_IF_NULL(actor);
if (actor->GetAID().Name().find("_DeviceQueueDataSourceActor") != string::npos) {
// NOTE(review): dynamic_cast can yield nullptr on a type mismatch; callers must handle a
// null return either way, since the not-found path below also returns nullptr.
auto device_queue_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor.get());
return device_queue_ds_actor;
}
}
return nullptr;
}
// Return the first data source actor whose AID name contains "_HostQueueDataSourceActor",
// downcast to HostQueueDataSourceActor; return nullptr when no such actor exists.
HostQueueDataSourceActor *FindHostQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
  for (const auto &actor : data_source_actors) {
    MS_EXCEPTION_IF_NULL(actor);
    if (actor->GetAID().Name().find("_HostQueueDataSourceActor") == string::npos) {
      continue;
    }
    // dynamic_cast may still be nullptr on a type mismatch; the caller checks the result.
    return dynamic_cast<HostQueueDataSourceActor *>(actor.get());
  }
  return nullptr;
}
// Update the reference count of device tensor by the output index of node.
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false) {
MS_EXCEPTION_IF_NULL(node);
auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx);
void UpdateRefCount(DeviceTensor *device_tensor, bool is_max_ref_count = false) {
MS_EXCEPTION_IF_NULL(device_tensor);
if (is_max_ref_count) {
device_tensor->set_original_ref_count(SIZE_MAX);
@ -73,6 +50,13 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_c
device_tensor->ResetRefCount();
}
// Update the reference count of device tensor by the output index of node.
// Convenience overload: fetches the device tensor of `node` at `output_idx` and forwards to the
// DeviceTensor* overload; `is_max_ref_count` pins the count to SIZE_MAX (see that overload).
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false) {
MS_EXCEPTION_IF_NULL(node);
// The third argument `false` — presumably "do not visit through ref/skip nodes" so the raw
// output address is used; confirm against AnfAlgo::GetMutableOutputAddr's contract.
auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx, false);
UpdateRefCount(device_tensor.get(), is_max_ref_count);
}
// The branch processing of PrepareDataForValueNode that value type is tensor.
void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
const DeviceContext *device_context) {
@ -207,8 +191,7 @@ void AllocateContinuousMemoryForInput(const AnfNodePtr &kernel, const DeviceCont
MS_EXCEPTION_IF_NULL(device_tensor);
// In the scene of communication op and computing op parallel multi stream, the input address of communication op
// can't be reused, so set the max reference count.
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
UpdateRefCount(device_tensor.get(), true);
if (device_tensor->GetPtr() == nullptr) {
is_need_alloc_memory = true;
@ -241,8 +224,7 @@ void AllocateContinuousMemoryForOutput(const AnfNodePtr &kernel, const DeviceCon
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
MS_EXCEPTION_IF_NULL(device_tensor);
// One time application for continuous memory, so set the max reference count.
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
UpdateRefCount(device_tensor.get(), true);
if (device_tensor->GetPtr() == nullptr) {
is_need_alloc_memory = true;
@ -261,7 +243,23 @@ void AllocateContinuousMemoryForOutput(const AnfNodePtr &kernel, const DeviceCon
}
} // namespace
// Release all cached actor/tensor bookkeeping held by the scheduler on destruction.
GraphScheduler::~GraphScheduler() {
// Global maps clear: state accumulated across all transformed actor sets.
device_tensor_to_actor_.clear();
actor_to_host_queue_.clear();
actors_.clear();
// Local maps clear: per-transform caches (these are also reset at the start of Initialize()).
actor_name_to_actor_.clear();
output_to_actor_.clear();
}
void GraphScheduler::Initialize() {
// Local maps and vectors clear.
actor_name_to_actor_.clear();
output_to_actor_.clear();
copy_actors_.clear();
if (init_) {
return;
}
@ -297,7 +295,11 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
PersistDeviceTensor(graph_compiler_info);
const auto &actor_set = Build(graph_compiler_info);
CacheGraphOutputToActor(graph_compiler_info);
Link(actor_set.get(), graph_compiler_info, strategy);
// The copy actors are built in the link, so need push into the actor set after link.
actor_set->copy_actors_ = copy_actors_;
actors_.emplace(actor_set->name_, actor_set);
DumpActor(actor_set.get());
@ -310,40 +312,37 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
void GraphScheduler::Schedule(const ActorSet *actor_set) {
MS_EXCEPTION_IF_NULL(actor_set);
auto actorMgr = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actorMgr);
std::vector<ActorReference> actors;
// Schedule data source actors.
// Collect actors.
for (auto &data_source_actor : actor_set->data_source_actors_) {
MS_EXCEPTION_IF_NULL(data_source_actor);
auto base_actor = static_cast<ActorReference>(data_source_actor);
(void)actorMgr->Spawn(base_actor);
actors.emplace_back(static_cast<ActorReference>(data_source_actor));
}
// Schedule kernel actors.
for (auto &kernel_actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(kernel_actor);
auto base_actor = static_cast<ActorReference>(kernel_actor);
(void)actorMgr->Spawn(base_actor);
actors.emplace_back(static_cast<ActorReference>(kernel_actor));
}
// Schedule switch actors.
for (auto &switch_actor : actor_set->switch_actors_) {
MS_EXCEPTION_IF_NULL(switch_actor);
auto base_actor = static_cast<ActorReference>(switch_actor);
(void)actorMgr->Spawn(base_actor);
actors.emplace_back(static_cast<ActorReference>(switch_actor));
}
for (auto &copy_actor : actor_set->copy_actors_) {
MS_EXCEPTION_IF_NULL(copy_actor);
actors.emplace_back(static_cast<ActorReference>(copy_actor));
}
// Schedule loop count actor.
if (actor_set->loop_count_actor_ != nullptr) {
auto base_actor = static_cast<ActorReference>(actor_set->loop_count_actor_);
(void)actorMgr->Spawn(base_actor);
actors.emplace_back(static_cast<ActorReference>(actor_set->loop_count_actor_));
}
if (actor_set->output_actor_ != nullptr) {
actors.emplace_back(static_cast<ActorReference>(actor_set->output_actor_));
}
// Schedule output actor.
if (actor_set->output_actor_ != nullptr) {
auto base_actor = static_cast<ActorReference>(actor_set->output_actor_);
(void)actorMgr->Spawn(base_actor);
// Schedule actors.
auto actorMgr = ActorMgr::GetActorMgrRef();
MS_EXCEPTION_IF_NULL(actorMgr);
for (auto actor : actors) {
(void)actorMgr->Spawn(actor);
}
}
@ -351,7 +350,8 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
const std::vector<std::vector<TensorPtr>> &input_tensors) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<TensorPtr> host_tensors;
const auto &host_data_source_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
std::string actor_name = actor_set->name_ + "_HostDSActor";
const auto &host_data_source_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
if (host_data_source_actor != nullptr) {
host_tensors.resize(host_data_source_actor->data_nodes_.size());
}
@ -478,14 +478,39 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
return actor_set;
}
// Cache the mapping from each front-graph output (front node + output index) to the backend actor
// that produces it, into output_to_actor_, so later linking (e.g. of internal parameters) can
// resolve a front node to its producer actor.
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph);
const auto &outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
for (const auto &output : outputs) {
// Resolve the real producing kernel and its output index behind tuple-get-item wrappers.
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, true);
MS_EXCEPTION_IF_NULL(output_with_index.first);
const auto &front_node = graph->GetFrontAnfByBackendAnf(output_with_index.first);
// Backend outputs without a front-node counterpart need no cache entry.
if (front_node == nullptr) {
continue;
}
auto origin_output_with_index = KernelWithIndex(front_node, output_with_index.second);
std::string actor_name;
// Only cache the kernel actor and device queue data source actor.
if (IsKernelActor(output_with_index.first)) {
actor_name = output_with_index.first->fullname_with_scope();
} else if (IsDeviceQueueDSActor(output_with_index.first)) {
// Must match the actor name composed when the device queue DS actor is built.
actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
} else {
continue;
}
const auto &actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
std::pair<OpActor<DeviceTensor> *, size_t> actor_pair(actor, output_with_index.second);
output_to_actor_.emplace(origin_output_with_index, actor_pair);
}
}
}
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set);
KernelMapActor kernel_actors_temp_map;
for (auto &actor : actor_set->kernel_actors_) {
MS_EXCEPTION_IF_NULL(actor);
kernel_actors_temp_map.emplace(actor->GetAID().Name(), actor);
}
// Foreach the execution order to link the actors.
for (const auto &graph : graph_compiler_info.graphs_) {
@ -495,34 +520,21 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
if (!IsKernelActor(kernel)) {
continue;
}
auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(kernel_actor);
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
LinkControlArrowByAutoMonad(kernel_actor, input_node);
if (HasAbstractMonad(input_node)) {
continue; // No data arrow for monad input.
}
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
auto from_kernel = from_kernel_with_output_idx.first;
if (IsDeviceQueueDSActor(from_kernel)) {
// Link the data arrows of device queue data source actor.
auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx,
to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel, graph)) {
// Link the data arrows of host queue data source actor.
auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
// Link the data arrows of kernel actor.
auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
}
// The gather of linking data arrows of kernel by the different from kernel type.
LinkDataArrow(kernel_actor, actor_set, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
}
}
}
@ -531,15 +543,13 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
LinkControlArrowForKernelActor(&(actor_set->kernel_actors_), actor_set->loop_count_actor_.get(), strategy);
// BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
auto no_input_kernel_actors = BuildNoInputKernelActor(actor_set);
actor_set->no_input_kernel_actors_.swap(no_input_kernel_actors);
actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set);
// Link the control arrows of loop count actor, which depends on the no input kernel actors.
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set);
// Link the output result arrows for output actors.
LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), actor_set->data_source_actors_,
kernel_actors_temp_map, graph_compiler_info);
LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
}
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
@ -558,10 +568,11 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
MS_EXCEPTION_IF_NULL(input_node);
if (IsHostQueueDSActor(input_node, graph)) {
if (host_queue_ds_actor == nullptr) {
auto actor_name = graph_compiler_info.name_ + "_HostQueueDataSourceActor";
auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
host_queue_ds_actor =
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
InsertActor(host_queue_ds_actor.get());
data_source_actors.emplace_back(host_queue_ds_actor);
}
@ -584,12 +595,12 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
const auto &iter = std::find_if(execution_order.begin(), execution_order.end(),
[](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
if (iter != execution_order.end()) {
auto actor_name =
graph_compiler_info.name_ + "_DeviceQueueDataSourceActor" + "_" + std::to_string(graph->graph_id());
auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
auto device_queue_ds_actor =
std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
InsertActor(device_queue_ds_actor.get());
data_source_actors.emplace_back(device_queue_ds_actor);
device_queue_ds_actor->data_kernel_ = *iter;
}
@ -610,6 +621,7 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(kernel_actor);
InsertActor(kernel_actor.get());
kernel_actors.emplace_back(kernel_actor);
}
}
@ -623,6 +635,7 @@ LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &g
auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
MS_EXCEPTION_IF_NULL(loop_count_actor);
InsertActor(loop_count_actor.get());
return loop_count_actor;
}
@ -633,6 +646,7 @@ OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_c
std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.origin_outputs_order_.size());
MS_LOG(INFO) << "Create output actor: " << actor_name;
MS_EXCEPTION_IF_NULL(output_actor);
InsertActor(output_actor.get());
return output_actor;
}
@ -651,6 +665,67 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS
return no_input_kernel_actors;
}
// Dispatch the linking of one data arrow into `to_actor` according to the kind of the from kernel:
// device queue DS actor, host queue DS actor, kernel actor, internal parameter (output of a
// previous kernel graph), or persistent device tensor (recorded as a device tensor store key
// instead of an arrow). Any other kind is a fatal error.
void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_set, const KernelGraphPtr &graph,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(actor_set);
MS_EXCEPTION_IF_NULL(graph);
auto from_kernel = from_kernel_with_output_idx.first;
if (IsDeviceQueueDSActor(from_kernel)) {
// Link the data arrows of device queue data source actor.
// The name must match the one composed when the device queue DS actor was built.
std::string actor_name = actor_set->name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
const auto &from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(FetchActor(actor_name));
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsHostQueueDSActor(from_kernel, graph)) {
// Link the data arrows of host queue data source actor.
std::string actor_name = actor_set->name_ + "_HostDSActor";
const auto &from_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
LinkDataArrowForHostDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsKernelActor(from_kernel)) {
// Link the data arrows of kernel actor.
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(from_kernel->fullname_with_scope()));
LinkDataArrowForKernelActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsInternalParameter(from_kernel, graph)) {
// Link data arrow for internal parameter, convert internal parameter to actor by internal parameter cache to link.
LinkDataArrowForInternalParameter(from_kernel, graph, to_actor, to_kernel_with_input_idx);
} else if (IsPersistentDeviceTensor(from_kernel)) {
// Weights/constants come from the device tensor store at run time, not from another actor.
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
static_cast<void *>(from_kernel.get()));
} else {
MS_LOG(EXCEPTION) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
}
}
// Link the data arrow when the input is an internal parameter: the parameter stands for the output
// of a previous kernel graph, so resolve parameter -> front node -> producer actor through the
// output_to_actor_ cache (filled by CacheGraphOutputToActor), then link from that producer.
void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &internal_parameter,
const KernelGraphPtr &graph, KernelActor *to_actor,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(internal_parameter);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(to_actor);
// Parameter ---> front node ---> actor.
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
if (output_to_actor_.count(front_node_with_index) == 0) {
MS_LOG(EXCEPTION) << "Can't find actor by node:" << front_node_with_index.first->fullname_with_scope();
}
auto actor_pair = output_to_actor_[front_node_with_index];
if (IsDeviceQueueDSActor(front_node_with_index.first)) {
// NOTE(review): the dynamic_cast result is dereferenced (data_kernel_/kernel_) without a null
// check — presumably the cache guarantees the actor type matches the node kind; confirm, or
// add MS_EXCEPTION_IF_NULL(from_actor) after the casts.
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else if (IsKernelActor(front_node_with_index.first)) {
auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
LinkDataArrowForKernelActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
MS_LOG(EXCEPTION) << "Invalid internal parameter: " << internal_parameter->fullname_with_scope();
}
}
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
@ -662,13 +737,17 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *f
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;
auto to_aid = to_actor->GetAID();
auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
from_actor->output_op_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
auto to_aid = to_actor->GetAID();
auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
from_actor->output_op_arrows_.emplace_back(op_arrow);
to_actor->input_datas_num_++;
// Update the reference count of device tensor.
UpdateRefCount(from_kernel, from_output_index);
// Update the reference count of device tensor.
UpdateRefCount(from_kernel, from_output_index);
}
}
void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
@ -701,16 +780,16 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
auto from_kernel = from_kernel_with_output_idx.first;
MS_EXCEPTION_IF_NULL(from_kernel);
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;
if (IsPersistentDeviceTensor(from_kernel)) {
to_actor->device_tensor_store_keys_.emplace_back(to_input_index, static_cast<void *>(from_kernel.get()));
} else if (IsKernelActor(from_kernel)) {
MS_EXCEPTION_IF_NULL(from_actor);
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
} else {
auto to_aid = to_actor->GetAID();
auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
from_actor->output_op_arrows_.emplace_back(op_arrow);
@ -721,6 +800,64 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
}
}
// Insert (or reuse) a CopyActor between `from_actor` and `to_actor` when their device contexts
// differ, so the output is copied across devices before being consumed. One copy actor exists per
// (from kernel, output index) pair — encoded in the actor name — and is shared by all consumers.
// NOTE(review): "devcie" in the local names below is a typo for "device" carried through this
// commit; worth a follow-up rename (doc-only edit here, so the identifiers are left as-is).
void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_kernel_with_input_idx) {
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
auto from_kernel = from_kernel_with_output_idx.first;
MS_EXCEPTION_IF_NULL(from_kernel);
auto to_devcie_context = to_actor->device_context_;
MS_EXCEPTION_IF_NULL(to_devcie_context);
auto from_output_index = from_kernel_with_output_idx.second;
auto to_input_index = to_kernel_with_input_idx.second;
// The name keys the copy actor to a single producer output, enabling reuse across consumers.
std::string name =
"copy_actor_" + from_kernel->fullname_with_scope() + "_output_index_" + std::to_string(from_output_index);
CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
// Link between from actor and copy actor.
if (copy_actor == nullptr) {
// Create the copy actor.
auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
copy_actors_.emplace_back(copy_actor_shared_ptr);
copy_actor = copy_actor_shared_ptr.get();
MS_EXCEPTION_IF_NULL(copy_actor);
InsertActor(copy_actor);
// Link the from actor to the new copy actor.
const DeviceContext *from_devcie_context = nullptr;
auto op_arrow_to_copy = std::make_shared<OpArrow>(from_output_index, copy_actor->GetAID(), 0);
if (IsDeviceQueueDSActor(from_kernel)) {
auto real_from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
from_devcie_context = real_from_actor->device_context_;
real_from_actor->output_op_arrows_.emplace_back(op_arrow_to_copy);
} else if (IsKernelActor(from_kernel)) {
auto real_from_actor = dynamic_cast<KernelActor *>(from_actor);
from_devcie_context = real_from_actor->device_context_;
real_from_actor->output_op_arrows_.emplace_back(op_arrow_to_copy);
}
copy_actor->input_datas_num_++;
// Set the member of the copy actor.
const auto &from_device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
MS_EXCEPTION_IF_NULL(from_device_tensor);
// The output device address mirrors the source tensor's size/format/type but lives on the
// destination device; memory (nullptr here) is allocated later by the memory manager actor.
copy_actor->output_device_tensor_ = to_devcie_context->CreateDeviceAddress(
nullptr, from_device_tensor->GetSize(), from_device_tensor->format(), from_device_tensor->type_id());
// Stays nullptr when from_kernel is neither a device queue DS actor nor a kernel actor —
// this check turns that unexpected case into a fatal error.
MS_EXCEPTION_IF_NULL(from_devcie_context);
copy_actor->input_device_context_ = from_devcie_context;
copy_actor->output_device_context_ = to_devcie_context;
// Update the reference count of device tensor.
UpdateRefCount(from_device_tensor.get());
}
// If the copy actor already exists, only need link between copy actor and to actor.
auto op_arrow_from_copy = std::make_shared<OpArrow>(0, to_actor->GetAID(), to_input_index);
copy_actor->output_op_arrows_.emplace_back(op_arrow_from_copy);
to_actor->input_datas_num_++;
UpdateRefCount(copy_actor->output_device_tensor_.get());
}
void GraphScheduler::LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(from_actors);
@ -744,8 +881,7 @@ void GraphScheduler::LinkControlArrowForKernelActor(std::vector<KernelActorPtr>
}
}
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
const KernelMapActor &kernel_actors_map) {
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(from_node);
if (!from_node->isa<CNode>()) {
@ -770,7 +906,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
// Make tuple node needs to be expanded.
for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), kernel_actors_map);
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i));
}
return;
} else {
@ -785,12 +921,12 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
if (AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimUpdateState) ||
AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimLoad) ||
AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimMakeTuple)) {
LinkControlArrowByAutoMonad(to_actor, real_depend_input, kernel_actors_map);
LinkControlArrowByAutoMonad(to_actor, real_depend_input);
return;
}
// Link the control arrow between the kernel actors.
auto from_actor = FindKernelActor(kernel_actors_map, real_depend_input->fullname_with_scope());
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_input->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor);
from_actor->output_op_controls_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
@ -818,8 +954,6 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
}
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
const std::vector<DataSourceActorPtr> &data_source_actors,
const KernelMapActor &kernel_actors_map,
const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(to_actor);
@ -853,7 +987,8 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
// The graph output is from kernel actor.
if (IsKernelActor(output_with_index.first)) {
const auto &from_actor = FindKernelActor(kernel_actors_map, output_with_index.first->fullname_with_scope());
const auto &from_actor =
dynamic_cast<KernelActor *>(FetchActor(output_with_index.first->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor);
auto op_arrow = std::make_shared<OpArrow>(output_with_index.second, to_actor->GetAID(), iter->second);
from_actor->output_result_arrows_.emplace_back(op_arrow);
@ -861,10 +996,12 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
}
// The graph output is from data source actor.
std::string actor_name;
DataSourceActor *from_actor = nullptr;
size_t from_actor_output_index = 0;
if (IsHostQueueDSActor(output_with_index.first, graph)) {
const auto &host_queue_ds_actor = FindHostQueueDSActor(data_source_actors);
actor_name = graph_compiler_info.name_ + "_HostDSActor";
const auto &host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
auto position_iter = host_queue_ds_actor->data_node_position_map_.find(output_with_index.first);
if (position_iter == host_queue_ds_actor->data_node_position_map_.end()) {
MS_LOG(EXCEPTION) << "Parameter node: " << output_with_index.first->fullname_with_scope() << " is not exist.";
@ -873,7 +1010,8 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
UpdateRefCount(host_queue_ds_actor->data_nodes_[from_actor_output_index], output_with_index.second, true);
from_actor = static_cast<DataSourceActor *>(host_queue_ds_actor);
} else if (IsDeviceQueueDSActor(output_with_index.first)) {
from_actor = FindDeviceQueueDSActor(data_source_actors);
actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
from_actor = dynamic_cast<DataSourceActor *>(FetchActor(actor_name));
from_actor_output_index = output_with_index.second;
}
MS_EXCEPTION_IF_NULL(from_actor);
@ -913,6 +1051,26 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
}
}
// Check the copy actors.
for (const auto &copy_actor : actor_set->copy_actors_) {
MS_EXCEPTION_IF_NULL(copy_actor);
if (copy_actor->output_op_arrows_.size() + copy_actor->output_op_controls_.size() == 0) {
MS_LOG(ERROR) << copy_actor->GetAID().Name() << " has no user.";
return false;
}
const size_t kCopyActorInputDataNum = 1;
auto input_data_num = copy_actor->input_datas_num_;
auto device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
<< " is wrong, input data num: " << input_data_num
<< ", device tensor store num: " << device_tensor_store_num
<< ", total input num: " << kCopyActorInputDataNum;
return false;
}
}
// Check the loop count actor.
const auto &loop_count_actor = actor_set->loop_count_actor_;
if (loop_count_actor != nullptr) {
@ -937,8 +1095,7 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
}
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
UpdateRefCount(device_tensor.get(), true);
}
for (auto &input_node : graph->input_nodes()) {
@ -947,8 +1104,7 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
MS_EXCEPTION_IF_NULL(device_tensor);
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
device_tensor->set_original_ref_count(SIZE_MAX);
device_tensor->ResetRefCount();
UpdateRefCount(device_tensor.get(), true);
}
}
}
@ -963,6 +1119,22 @@ HostTensorQueue *GraphScheduler::FetchHostQueue(const ActorInfo &actor_info) con
}
}
// Register an actor into the name->actor map used by FetchActor during linking.
// Actor names must be globally unique within one transform; a duplicate indicates a scheduling bug.
void GraphScheduler::InsertActor(OpActor<DeviceTensor> *actor) {
  MS_EXCEPTION_IF_NULL(actor);
  // Compute the name once instead of constructing it for both the check and the insertion.
  const auto &actor_name = actor->GetAID().Name();
  if (actor_name_to_actor_.count(actor_name) > 0) {
    MS_LOG(EXCEPTION) << "The actor already exists: " << actor_name;
  }
  actor_name_to_actor_[actor_name] = actor;
}
// Look up a previously inserted actor by name; return nullptr when no actor was registered under it.
OpActor<DeviceTensor> *GraphScheduler::FetchActor(const std::string actor_name) const {
  auto iter = actor_name_to_actor_.find(actor_name);
  return (iter == actor_name_to_actor_.end()) ? nullptr : iter->second;
}
void GraphScheduler::DumpActor(const ActorSet *actor_set) const {
MS_EXCEPTION_IF_NULL(actor_set);
const auto &context_ptr = MsContext::GetInstance();
@ -1024,7 +1196,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
ofs << "\tactor_name:" << actor_name << "\tdevice_context:" << actor->device_context_->device_context_key().ToString()
<< "\n";
if (actor_name.find("_DeviceQueueDataSourceActor") != string::npos) {
if (actor_name.find("_DeviceDSActor") != string::npos) {
// Dump the member info of device queue data source actor.
const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
const auto &data_kernel = device_queue_ds_actor->data_kernel_;
@ -1038,7 +1210,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
}
} else if (actor_name.find("_HostQueueDataSourceActor") != string::npos) {
} else if (actor_name.find("_HostDSActor") != string::npos) {
// Dump the member info of host queue data source actor.
const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";

View File

@ -30,6 +30,7 @@
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/output_actor.h"
#include "runtime/framework/actor/switch_actor.h"
#include "runtime/framework/actor/copy_actor.h"
#include "runtime/hardware/device_context.h"
#include "backend/session/kernel_graph.h"
@ -75,13 +76,14 @@ struct GraphCompilerInfo {
};
// The actor set generated by graph transformer is the execution unit of actor runtime.
// It includes data source actor, kernel actor, loop count actor.
// The data source actor is used to obtain data and process them into device tensors,
// and then send them to kernel actor. The kernel actor is used to receive the device tensors to launch kernel.
// Specifically notice the no input kernel actor, it means that this actor has no input device tensor, need be triggered
// externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
// and decide whether to loop execution by loop count. The output actor is used to receive the output result of actor
// which represents the graph output.
// It includes data source actor, kernel actor, switch actor, copy actor, loop count actor and output actor.
// The data source actor is used to obtain data and process them into device tensors, and send them to kernel actor.
// The kernel actor is used to receive the device tensors to launch kernel. Specifically notice the no input
// kernel actor, it means that this actor has no input device tensor, need be triggered externally.
// The copy actor is used to convert the device tensor between the different device kernel.
// The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
// and decide whether to loop execution by loop count.
// The output actor is used to receive the output result of actor which represents the graph output.
struct ActorSet {
explicit ActorSet(const ActorInfo &name) : name_(name) {}
std::vector<DataSourceActorPtr> data_source_actors_;
@ -89,6 +91,7 @@ struct ActorSet {
// No input kernel actors need be triggered specifically.
std::vector<KernelActorPtr> no_input_kernel_actors_;
std::vector<SwitchActorPtr> switch_actors_;
std::vector<CopyActorPtr> copy_actors_;
LoopCountActorPtr loop_count_actor_{nullptr};
OutputActorPtr output_actor_{nullptr};
ActorInfo name_;
@ -129,7 +132,7 @@ class GraphScheduler {
private:
GraphScheduler() = default;
~GraphScheduler() = default;
~GraphScheduler();
DISABLE_COPY_AND_ASSIGN(GraphScheduler);
// Transform the nodes of graph to actors.
@ -145,7 +148,20 @@ class GraphScheduler {
OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set);
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
// previous graph and the head of next graph.
void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
// The processing of actors link.
// The gather of linking data arrows of kernel, it will call the following functions by the different from actor type.
void LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_set, const KernelGraphPtr &graph,
KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx);
// Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
void LinkDataArrowForInternalParameter(const AnfNodePtr &internal_parameter, const KernelGraphPtr &graph,
KernelActor *to_actor, KernelWithIndex to_kernel_with_input_idx);
// Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
void LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx);
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
KernelWithIndex from_kernel_with_output_idx,
KernelWithIndex to_to_kernel_with_input_idx);
@ -158,12 +174,8 @@ class GraphScheduler {
void LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
GraphExecutionStrategy strategy);
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set);
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
const KernelMapActor &kernel_actors_map);
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
const std::vector<DataSourceActorPtr> &data_source_actors,
const KernelMapActor &kernel_actors_map,
const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node);
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
// Check whether the actor set is valid.
bool CheckActorValid(const ActorSet *actor_set) const;
@ -174,6 +186,10 @@ class GraphScheduler {
// Fetch the host tensor queue by actor info.
HostTensorQueue *FetchHostQueue(const ActorInfo &actor_info) const;
// The operation of the map of actor_name_to_actor_.
void InsertActor(OpActor<DeviceTensor> *actor);
OpActor<DeviceTensor> *FetchActor(const std::string actor_name) const;
// Display the actor information of corresponding kernel graph.
void DumpActor(const ActorSet *actor_set) const;
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
@ -181,11 +197,19 @@ class GraphScheduler {
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
// The global maps, only be cleared in the deconstruction.
std::unordered_map<ActorInfo, ActorSetPtr> actors_;
std::unordered_map<ActorInfo, HostTensorQueuePtr> actor_to_host_queue_;
// The second element of pair represents the output index of op actor corresponding to the device tensor.
std::unordered_map<DeviceTensorPtr, std::pair<OpActor<DeviceTensor> *, size_t>> device_tensor_to_actor_;
// The second element of pair represents the output index of kernel actor corresponding to the device tensor.
std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;
// The local maps and vectors, will be cleared at the beginning of each graph transform.
std::unordered_map<std::string, OpActor<DeviceTensor> *> actor_name_to_actor_;
// The second element of pair represents the output index of op actor corresponding to the graph output front node.
std::map<KernelWithIndex, std::pair<OpActor<DeviceTensor> *, size_t>, session::KernelWithIndexCmp> output_to_actor_;
// Because the copy actors are built during linking, all copy actors created in the link process are recorded here
// and pushed into the actor set after linking.
std::vector<CopyActorPtr> copy_actors_;
// The id of memory manager actor.
AID memory_manager_aid_;