forked from mindspore-Ecosystem/mindspore
!16443 actor runtime support host and devcie
From: @limingqi107 Reviewed-by: @cristoval,@wilfchen Signed-off-by: @wilfchen
This commit is contained in:
commit
cd439940f4
|
@ -1109,12 +1109,20 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const {
|
||||
const auto &iter = internal_parameters_to_front_map_.find(parameter);
|
||||
if (iter != internal_parameters_to_front_map_.end()) {
|
||||
void KernelGraph::CacheInternalParameterToFrontNode(const AnfNodePtr ¶meter,
|
||||
const AnfWithOutIndex &front_node_with_index) {
|
||||
if (parameter == nullptr) {
|
||||
return;
|
||||
}
|
||||
internal_parameter_to_front_node_map_[parameter] = front_node_with_index;
|
||||
}
|
||||
|
||||
AnfWithOutIndex KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const {
|
||||
const auto &iter = internal_parameter_to_front_node_map_.find(parameter);
|
||||
if (iter != internal_parameter_to_front_node_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
return nullptr;
|
||||
return AnfWithOutIndex();
|
||||
}
|
||||
|
||||
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
|
||||
|
|
|
@ -71,7 +71,7 @@ class KernelGraph : public FuncGraph {
|
|||
parent_graph_ = graph.parent_graph_;
|
||||
start_label_ = graph.start_label_;
|
||||
end_goto_ = graph.end_goto_;
|
||||
internal_parameters_to_front_map_ = graph.internal_parameters_to_front_map_;
|
||||
internal_parameter_to_front_node_map_ = graph.internal_parameter_to_front_node_map_;
|
||||
front_to_internal_outputs_map_ = graph.front_to_internal_outputs_map_;
|
||||
internal_outputs_to_front_map_ = graph.internal_outputs_to_front_map_;
|
||||
internal_outputs_tensor_map_ = graph.internal_outputs_tensor_map_;
|
||||
|
@ -200,7 +200,9 @@ class KernelGraph : public FuncGraph {
|
|||
bool unique_target = false);
|
||||
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
|
||||
int dst_output_idx = -1);
|
||||
AnfNodePtr GetFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const;
|
||||
// Cache the internal parameter and corresponding to front node into internal_parameter_to_front_node_map_.
|
||||
void CacheInternalParameterToFrontNode(const AnfNodePtr ¶meter, const AnfWithOutIndex &front_node_with_index);
|
||||
AnfWithOutIndex GetFrontNodeByInternalParameter(const AnfNodePtr ¶meter) const;
|
||||
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
|
||||
bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const;
|
||||
bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const;
|
||||
|
@ -355,7 +357,10 @@ class KernelGraph : public FuncGraph {
|
|||
|
||||
CNodePtr start_label_;
|
||||
CNodePtr end_goto_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> internal_parameters_to_front_map_;
|
||||
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
|
||||
// related to the input of this kernel graph. The first of unordered map is the input of this kernel graph, the second
|
||||
// of unordered map is front node corresponding to the output of previous kernel graph.
|
||||
std::unordered_map<AnfNodePtr, AnfWithOutIndex> internal_parameter_to_front_node_map_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
|
||||
std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_;
|
||||
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
|
||||
|
|
|
@ -771,6 +771,8 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph,
|
|||
}
|
||||
cnode_inputs->push_back(parameter_from_cnode);
|
||||
(*other_graph_cnode)[anf] = parameter_from_cnode;
|
||||
KernelWithIndex front_node_with_index(anf, 0);
|
||||
graph->CacheInternalParameterToFrontNode(parameter_from_cnode, front_node_with_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,7 +48,21 @@ bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph) {
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
|
||||
// Judge whether node is internal parameter.
|
||||
if (graph->GetFrontNodeByInternalParameter(node) == nullptr) {
|
||||
const auto &front_node = graph->GetFrontNodeByInternalParameter(node);
|
||||
if (front_node.first == nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
|
||||
// Judge whether node is internal parameter.
|
||||
const auto &front_node = graph->GetFrontNodeByInternalParameter(node);
|
||||
if (front_node.first != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,6 +48,10 @@ bool IsDeviceQueueDSActor(const AnfNodePtr &node);
|
|||
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
||||
bool IsKernelActor(const AnfNodePtr &node);
|
||||
|
||||
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
|
||||
// related to the input of this kernel graph.
|
||||
bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph);
|
||||
|
||||
// Judge whether the device tensor of the node is persistent or not.
|
||||
bool IsPersistentDeviceTensor(const AnfNodePtr &node);
|
||||
} // namespace runtime
|
||||
|
|
|
@ -44,14 +44,14 @@ void CopyActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *contex
|
|||
}
|
||||
|
||||
void CopyActor::AllocateMemory(OpContext<DeviceTensor> *context) {
|
||||
std::vector<DeviceTensor *> alloc_list({output_device_tensor_});
|
||||
std::vector<DeviceTensor *> alloc_list({output_device_tensor_.get()});
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::AllocateMemory, alloc_list, output_device_context_, context,
|
||||
GetAID());
|
||||
}
|
||||
|
||||
void CopyActor::FreeMemory(OpContext<DeviceTensor> *context) {
|
||||
std::vector<DeviceTensor *> input_free_list({input_device_tensor_});
|
||||
std::vector<DeviceTensor *> output_free_list({output_device_tensor_});
|
||||
std::vector<DeviceTensor *> output_free_list({output_device_tensor_.get()});
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, input_free_list, input_device_context_, context);
|
||||
Async(memory_manager_aid_, &MemoryManagerActor::FreeMemory, output_free_list, output_device_context_, context);
|
||||
}
|
||||
|
@ -59,7 +59,7 @@ void CopyActor::FreeMemory(OpContext<DeviceTensor> *context) {
|
|||
void CopyActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
|
||||
if (!Copy(output_device_tensor_, input_device_tensor_)) {
|
||||
if (!Copy(output_device_tensor_.get(), input_device_tensor_)) {
|
||||
std::string error_info = "Copy device tensor failed: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
|
@ -146,8 +146,8 @@ void CopyActor::SendOutput(OpContext<DeviceTensor> *context) const {
|
|||
std::string error_info = "The output index is out of range: " + GetAID().Name();
|
||||
SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
|
||||
}
|
||||
auto data =
|
||||
std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, output_device_tensor_, op_arrow->to_input_index_);
|
||||
auto data = std::make_shared<OpData<DeviceTensor>>(op_arrow->to_op_id_, output_device_tensor_.get(),
|
||||
op_arrow->to_input_index_);
|
||||
Async(op_arrow->to_op_id_, &CopyActor::RunOpData, data, context);
|
||||
}
|
||||
|
||||
|
|
|
@ -36,14 +36,11 @@ using mindspore::device::DeviceContext;
|
|||
// -> OnMemoryAllocFinish -> Copy -> FreeMemory -> SendOutput.
|
||||
class CopyActor : public MemoryInterfaceActor {
|
||||
public:
|
||||
CopyActor(const std::string &name, const DeviceContext *input_device_context,
|
||||
const DeviceContext *output_device_context, const AID &memory_manager_aid)
|
||||
CopyActor(const std::string &name, const AID &memory_manager_aid)
|
||||
: MemoryInterfaceActor(name),
|
||||
memory_manager_aid_(memory_manager_aid),
|
||||
input_datas_num_(0),
|
||||
input_controls_num_(0),
|
||||
input_device_context_(input_device_context),
|
||||
output_device_context_(output_device_context),
|
||||
input_device_tensor_(nullptr),
|
||||
output_device_tensor_(nullptr) {}
|
||||
~CopyActor() override = default;
|
||||
|
@ -89,9 +86,10 @@ class CopyActor : public MemoryInterfaceActor {
|
|||
const DeviceContext *input_device_context_;
|
||||
const DeviceContext *output_device_context_;
|
||||
|
||||
// The device tensor for copy.
|
||||
// The input device tensor is saved from the input data.
|
||||
DeviceTensor *input_device_tensor_;
|
||||
DeviceTensor *output_device_tensor_;
|
||||
// The output device tensor is created in the copy actor build, so can't be the raw pointer.
|
||||
DeviceTensorPtr output_device_tensor_;
|
||||
};
|
||||
|
||||
using CopyActorPtr = std::shared_ptr<CopyActor>;
|
||||
|
|
|
@ -35,14 +35,10 @@ using mindspore::session::KernelWithIndex;
|
|||
using mindspore::tensor::TensorPtr;
|
||||
|
||||
// The output actor is used to receive the output result of actor which represents the graph output.
|
||||
class OutputActor : public ActorBase {
|
||||
class OutputActor : public OpActor<DeviceTensor> {
|
||||
public:
|
||||
OutputActor(std::string name, size_t loop_count, size_t outputs_num)
|
||||
: ActorBase(name),
|
||||
loop_count_(loop_count),
|
||||
current_count_(0),
|
||||
outputs_num_(outputs_num),
|
||||
current_outputs_num_(0) {
|
||||
: OpActor(name), loop_count_(loop_count), current_count_(0), outputs_num_(outputs_num), current_outputs_num_(0) {
|
||||
outputs_.resize(outputs_num);
|
||||
output_nodes_.resize(outputs_num);
|
||||
device_contexts_.resize(outputs_num);
|
||||
|
|
|
@ -28,42 +28,19 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
|
||||
namespace {
|
||||
KernelActor *FindKernelActor(const KernelMapActor &kernel_actors_map, const std::string &name) {
|
||||
auto iter = kernel_actors_map.find(name);
|
||||
if (iter != kernel_actors_map.end()) {
|
||||
return iter->second.get();
|
||||
bool IsNeedInsertCopyActor(const DeviceContext *from_devcie_context, const DeviceContext *to_devcie_context) {
|
||||
MS_EXCEPTION_IF_NULL(from_devcie_context);
|
||||
MS_EXCEPTION_IF_NULL(to_devcie_context);
|
||||
|
||||
if (from_devcie_context->GetDeviceAddressType() == to_devcie_context->GetDeviceAddressType()) {
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DeviceQueueDataSourceActor *FindDeviceQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
|
||||
for (auto &actor : data_source_actors) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
if (actor->GetAID().Name().find("_DeviceQueueDataSourceActor") != string::npos) {
|
||||
auto device_queue_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor.get());
|
||||
return device_queue_ds_actor;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
HostQueueDataSourceActor *FindHostQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
|
||||
for (auto &actor : data_source_actors) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
if (actor->GetAID().Name().find("_HostQueueDataSourceActor") != string::npos) {
|
||||
auto device_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor.get());
|
||||
return device_queue_ds_actor;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Update the reference count of device tensor by the output index of node.
|
||||
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx);
|
||||
void UpdateRefCount(DeviceTensor *device_tensor, bool is_max_ref_count = false) {
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
if (is_max_ref_count) {
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
|
@ -73,6 +50,13 @@ void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_c
|
|||
device_tensor->ResetRefCount();
|
||||
}
|
||||
|
||||
// Update the reference count of device tensor by the output index of node.
|
||||
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx, false);
|
||||
UpdateRefCount(device_tensor.get(), is_max_ref_count);
|
||||
}
|
||||
|
||||
// The branch processing of PrepareDataForValueNode that value type is tensor.
|
||||
void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
|
||||
const DeviceContext *device_context) {
|
||||
|
@ -207,8 +191,7 @@ void AllocateContinuousMemoryForInput(const AnfNodePtr &kernel, const DeviceCont
|
|||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
// In the scene of communication op and computing op parallel multi stream, the input address of communication op
|
||||
// can't be reused, so set the max reference count.
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
|
||||
if (device_tensor->GetPtr() == nullptr) {
|
||||
is_need_alloc_memory = true;
|
||||
|
@ -241,8 +224,7 @@ void AllocateContinuousMemoryForOutput(const AnfNodePtr &kernel, const DeviceCon
|
|||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
// One time application for continuous memory, so set the max reference count.
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
|
||||
if (device_tensor->GetPtr() == nullptr) {
|
||||
is_need_alloc_memory = true;
|
||||
|
@ -261,7 +243,23 @@ void AllocateContinuousMemoryForOutput(const AnfNodePtr &kernel, const DeviceCon
|
|||
}
|
||||
} // namespace
|
||||
|
||||
GraphScheduler::~GraphScheduler() {
|
||||
// Global maps clear.
|
||||
device_tensor_to_actor_.clear();
|
||||
actor_to_host_queue_.clear();
|
||||
actors_.clear();
|
||||
|
||||
// Local maps clear.
|
||||
actor_name_to_actor_.clear();
|
||||
output_to_actor_.clear();
|
||||
}
|
||||
|
||||
void GraphScheduler::Initialize() {
|
||||
// Local maps and vcetors clear.
|
||||
actor_name_to_actor_.clear();
|
||||
output_to_actor_.clear();
|
||||
copy_actors_.clear();
|
||||
|
||||
if (init_) {
|
||||
return;
|
||||
}
|
||||
|
@ -297,7 +295,11 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
|
|||
|
||||
PersistDeviceTensor(graph_compiler_info);
|
||||
const auto &actor_set = Build(graph_compiler_info);
|
||||
CacheGraphOutputToActor(graph_compiler_info);
|
||||
Link(actor_set.get(), graph_compiler_info, strategy);
|
||||
// The copy actors are built in the link, so need push into the actor set after link.
|
||||
actor_set->copy_actors_ = copy_actors_;
|
||||
|
||||
actors_.emplace(actor_set->name_, actor_set);
|
||||
|
||||
DumpActor(actor_set.get());
|
||||
|
@ -310,40 +312,37 @@ ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info
|
|||
|
||||
void GraphScheduler::Schedule(const ActorSet *actor_set) {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
auto actorMgr = ActorMgr::GetActorMgrRef();
|
||||
MS_EXCEPTION_IF_NULL(actorMgr);
|
||||
std::vector<ActorReference> actors;
|
||||
|
||||
// Schedule dats source actors.
|
||||
// Collect actors.
|
||||
for (auto &data_source_actor : actor_set->data_source_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(data_source_actor);
|
||||
auto base_actor = static_cast<ActorReference>(data_source_actor);
|
||||
(void)actorMgr->Spawn(base_actor);
|
||||
actors.emplace_back(static_cast<ActorReference>(data_source_actor));
|
||||
}
|
||||
|
||||
// Schedule kernel actors.
|
||||
for (auto &kernel_actor : actor_set->kernel_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
auto base_actor = static_cast<ActorReference>(kernel_actor);
|
||||
(void)actorMgr->Spawn(base_actor);
|
||||
actors.emplace_back(static_cast<ActorReference>(kernel_actor));
|
||||
}
|
||||
|
||||
// Schedule switch actors.
|
||||
for (auto &switch_actor : actor_set->switch_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(switch_actor);
|
||||
auto base_actor = static_cast<ActorReference>(switch_actor);
|
||||
(void)actorMgr->Spawn(base_actor);
|
||||
actors.emplace_back(static_cast<ActorReference>(switch_actor));
|
||||
}
|
||||
for (auto ©_actor : actor_set->copy_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(copy_actor);
|
||||
actors.emplace_back(static_cast<ActorReference>(copy_actor));
|
||||
}
|
||||
|
||||
// Schedule loop count actor.
|
||||
if (actor_set->loop_count_actor_ != nullptr) {
|
||||
auto base_actor = static_cast<ActorReference>(actor_set->loop_count_actor_);
|
||||
(void)actorMgr->Spawn(base_actor);
|
||||
actors.emplace_back(static_cast<ActorReference>(actor_set->loop_count_actor_));
|
||||
}
|
||||
if (actor_set->output_actor_ != nullptr) {
|
||||
actors.emplace_back(static_cast<ActorReference>(actor_set->output_actor_));
|
||||
}
|
||||
|
||||
// Schedule output actor.
|
||||
if (actor_set->output_actor_ != nullptr) {
|
||||
auto base_actor = static_cast<ActorReference>(actor_set->output_actor_);
|
||||
(void)actorMgr->Spawn(base_actor);
|
||||
// Schedule actors.
|
||||
auto actorMgr = ActorMgr::GetActorMgrRef();
|
||||
MS_EXCEPTION_IF_NULL(actorMgr);
|
||||
for (auto actor : actors) {
|
||||
(void)actorMgr->Spawn(actor);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -351,7 +350,8 @@ void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerIn
|
|||
const std::vector<std::vector<TensorPtr>> &input_tensors) {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
std::vector<TensorPtr> host_tensors;
|
||||
const auto &host_data_source_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
|
||||
std::string actor_name = actor_set->name_ + "_HostDSActor";
|
||||
const auto &host_data_source_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
|
||||
if (host_data_source_actor != nullptr) {
|
||||
host_tensors.resize(host_data_source_actor->data_nodes_.size());
|
||||
}
|
||||
|
@ -478,14 +478,39 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
|
|||
return actor_set;
|
||||
}
|
||||
|
||||
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
for (const auto &graph : graph_compiler_info.graphs_) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
const auto &outputs = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
|
||||
for (const auto &output : outputs) {
|
||||
const auto &output_with_index = AnfAlgo::VisitKernelWithReturnType(output, 0, true);
|
||||
MS_EXCEPTION_IF_NULL(output_with_index.first);
|
||||
const auto &front_node = graph->GetFrontAnfByBackendAnf(output_with_index.first);
|
||||
if (front_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto origin_output_with_index = KernelWithIndex(front_node, output_with_index.second);
|
||||
std::string actor_name;
|
||||
// Only cache the kernel actor and device queue data source actor.
|
||||
if (IsKernelActor(output_with_index.first)) {
|
||||
actor_name = output_with_index.first->fullname_with_scope();
|
||||
} else if (IsDeviceQueueDSActor(output_with_index.first)) {
|
||||
actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
const auto &actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
std::pair<OpActor<DeviceTensor> *, size_t> actor_pair(actor, output_with_index.second);
|
||||
output_to_actor_.emplace(origin_output_with_index, actor_pair);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
|
||||
GraphExecutionStrategy strategy) {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
KernelMapActor kernel_actors_temp_map;
|
||||
for (auto &actor : actor_set->kernel_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
kernel_actors_temp_map.emplace(actor->GetAID().Name(), actor);
|
||||
}
|
||||
|
||||
// Foreach the execution order to link the actors.
|
||||
for (const auto &graph : graph_compiler_info.graphs_) {
|
||||
|
@ -495,34 +520,21 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
|
|||
if (!IsKernelActor(kernel)) {
|
||||
continue;
|
||||
}
|
||||
auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());
|
||||
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
|
||||
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
|
||||
auto input_node = AnfAlgo::GetInputNode(kernel, i);
|
||||
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
|
||||
LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
|
||||
LinkControlArrowByAutoMonad(kernel_actor, input_node);
|
||||
if (HasAbstractMonad(input_node)) {
|
||||
continue; // No data arrow for monad input.
|
||||
}
|
||||
|
||||
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
|
||||
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
|
||||
auto from_kernel = from_kernel_with_output_idx.first;
|
||||
|
||||
if (IsDeviceQueueDSActor(from_kernel)) {
|
||||
// Link the data arrows of device queue data source actor.
|
||||
auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
|
||||
LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx,
|
||||
to_kernel_with_input_idx);
|
||||
} else if (IsHostQueueDSActor(from_kernel, graph)) {
|
||||
// Link the data arrows of host queue data source actor.
|
||||
auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
|
||||
LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
// Link the data arrows of kernel actor.
|
||||
auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
|
||||
LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
}
|
||||
// The gather of linking data allows of kernel by the different from kernel type.
|
||||
LinkDataArrow(kernel_actor, actor_set, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -531,15 +543,13 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
|
|||
LinkControlArrowForKernelActor(&(actor_set->kernel_actors_), actor_set->loop_count_actor_.get(), strategy);
|
||||
|
||||
// BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
|
||||
auto no_input_kernel_actors = BuildNoInputKernelActor(actor_set);
|
||||
actor_set->no_input_kernel_actors_.swap(no_input_kernel_actors);
|
||||
actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set);
|
||||
|
||||
// Link the control arrows of loop count actor, which depends on the no input kernel actors.
|
||||
LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set);
|
||||
|
||||
// Link the output result arrows for output actors.
|
||||
LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), actor_set->data_source_actors_,
|
||||
kernel_actors_temp_map, graph_compiler_info);
|
||||
LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
|
||||
}
|
||||
|
||||
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
|
||||
|
@ -558,10 +568,11 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (IsHostQueueDSActor(input_node, graph)) {
|
||||
if (host_queue_ds_actor == nullptr) {
|
||||
auto actor_name = graph_compiler_info.name_ + "_HostQueueDataSourceActor";
|
||||
auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
|
||||
MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
|
||||
host_queue_ds_actor =
|
||||
std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
|
||||
InsertActor(host_queue_ds_actor.get());
|
||||
data_source_actors.emplace_back(host_queue_ds_actor);
|
||||
}
|
||||
|
||||
|
@ -584,12 +595,12 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph
|
|||
const auto &iter = std::find_if(execution_order.begin(), execution_order.end(),
|
||||
[](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
|
||||
if (iter != execution_order.end()) {
|
||||
auto actor_name =
|
||||
graph_compiler_info.name_ + "_DeviceQueueDataSourceActor" + "_" + std::to_string(graph->graph_id());
|
||||
auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
|
||||
MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
|
||||
auto device_queue_ds_actor =
|
||||
std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
|
||||
MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
|
||||
InsertActor(device_queue_ds_actor.get());
|
||||
data_source_actors.emplace_back(device_queue_ds_actor);
|
||||
device_queue_ds_actor->data_kernel_ = *iter;
|
||||
}
|
||||
|
@ -610,6 +621,7 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
|
|||
auto kernel_actor =
|
||||
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
|
||||
MS_EXCEPTION_IF_NULL(kernel_actor);
|
||||
InsertActor(kernel_actor.get());
|
||||
kernel_actors.emplace_back(kernel_actor);
|
||||
}
|
||||
}
|
||||
|
@ -623,6 +635,7 @@ LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &g
|
|||
auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
|
||||
MS_LOG(INFO) << "Create loop count actor: " << actor_name;
|
||||
MS_EXCEPTION_IF_NULL(loop_count_actor);
|
||||
InsertActor(loop_count_actor.get());
|
||||
return loop_count_actor;
|
||||
}
|
||||
|
||||
|
@ -633,6 +646,7 @@ OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_c
|
|||
std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.origin_outputs_order_.size());
|
||||
MS_LOG(INFO) << "Create output actor: " << actor_name;
|
||||
MS_EXCEPTION_IF_NULL(output_actor);
|
||||
InsertActor(output_actor.get());
|
||||
return output_actor;
|
||||
}
|
||||
|
||||
|
@ -651,6 +665,67 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS
|
|||
return no_input_kernel_actors;
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_set, const KernelGraphPtr &graph,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
||||
auto from_kernel = from_kernel_with_output_idx.first;
|
||||
if (IsDeviceQueueDSActor(from_kernel)) {
|
||||
// Link the data arrows of device queue data source actor.
|
||||
std::string actor_name = actor_set->name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
|
||||
const auto &from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(FetchActor(actor_name));
|
||||
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else if (IsHostQueueDSActor(from_kernel, graph)) {
|
||||
// Link the data arrows of host queue data source actor.
|
||||
std::string actor_name = actor_set->name_ + "_HostDSActor";
|
||||
const auto &from_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
|
||||
LinkDataArrowForHostDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else if (IsKernelActor(from_kernel)) {
|
||||
// Link the data arrows of kernel actor.
|
||||
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(from_kernel->fullname_with_scope()));
|
||||
LinkDataArrowForKernelActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else if (IsInternalParameter(from_kernel, graph)) {
|
||||
// Link data arrow for internal parameter, convert internal parameter to actor by internal parameter cache to link.
|
||||
LinkDataArrowForInternalParameter(from_kernel, graph, to_actor, to_kernel_with_input_idx);
|
||||
} else if (IsPersistentDeviceTensor(from_kernel)) {
|
||||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
|
||||
static_cast<void *>(from_kernel.get()));
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid from kernel: " << from_kernel->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &internal_parameter,
|
||||
const KernelGraphPtr &graph, KernelActor *to_actor,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(internal_parameter);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
|
||||
// Parameter ---> front node ---> actor.
|
||||
auto front_node_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
|
||||
MS_EXCEPTION_IF_NULL(front_node_with_index.first);
|
||||
if (output_to_actor_.count(front_node_with_index) == 0) {
|
||||
MS_LOG(EXCEPTION) << "Can't find actor by node:" << front_node_with_index.first->fullname_with_scope();
|
||||
}
|
||||
auto actor_pair = output_to_actor_[front_node_with_index];
|
||||
|
||||
if (IsDeviceQueueDSActor(front_node_with_index.first)) {
|
||||
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
|
||||
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->data_kernel_, actor_pair.second);
|
||||
LinkDataArrowForDeviceDSActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else if (IsKernelActor(front_node_with_index.first)) {
|
||||
auto from_actor = dynamic_cast<KernelActor *>(actor_pair.first);
|
||||
auto from_kernel_with_output_idx = KernelWithIndex(from_actor->kernel_, actor_pair.second);
|
||||
LinkDataArrowForKernelActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid internal parameter: " << internal_parameter->fullname_with_scope();
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
|
@ -662,13 +737,17 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *f
|
|||
auto from_output_index = from_kernel_with_output_idx.second;
|
||||
auto to_input_index = to_kernel_with_input_idx.second;
|
||||
|
||||
auto to_aid = to_actor->GetAID();
|
||||
auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
|
||||
from_actor->output_op_arrows_.emplace_back(op_arrow);
|
||||
to_actor->input_datas_num_++;
|
||||
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
|
||||
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
auto to_aid = to_actor->GetAID();
|
||||
auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
|
||||
from_actor->output_op_arrows_.emplace_back(op_arrow);
|
||||
to_actor->input_datas_num_++;
|
||||
|
||||
// Update the reference count of device tensor.
|
||||
UpdateRefCount(from_kernel, from_output_index);
|
||||
// Update the reference count of device tensor.
|
||||
UpdateRefCount(from_kernel, from_output_index);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
|
||||
|
@ -701,16 +780,16 @@ void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_
|
|||
void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_kernel_with_input_idx) {
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
auto from_kernel = from_kernel_with_output_idx.first;
|
||||
MS_EXCEPTION_IF_NULL(from_kernel);
|
||||
auto from_output_index = from_kernel_with_output_idx.second;
|
||||
auto to_input_index = to_kernel_with_input_idx.second;
|
||||
|
||||
if (IsPersistentDeviceTensor(from_kernel)) {
|
||||
to_actor->device_tensor_store_keys_.emplace_back(to_input_index, static_cast<void *>(from_kernel.get()));
|
||||
} else if (IsKernelActor(from_kernel)) {
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
if (IsNeedInsertCopyActor(from_actor->device_context_, to_actor->device_context_)) {
|
||||
LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
|
||||
} else {
|
||||
auto to_aid = to_actor->GetAID();
|
||||
auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
|
||||
from_actor->output_op_arrows_.emplace_back(op_arrow);
|
||||
|
@ -721,6 +800,64 @@ void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, Kernel
|
|||
}
|
||||
}
|
||||
|
||||
// Link the data arrow from from_actor to to_actor through an intermediate CopyActor, which converts the device
// tensor between the two different device contexts. The copy actor is keyed by the from-kernel output, so multiple
// consumers of the same output share one copy actor: it is created (and fully wired on the input side) only on the
// first call, and subsequent calls only add the copy-actor -> to_actor output link.
// Fixes over previous version: corrected the misspelled "devcie" identifiers and added explicit null checks after
// the dynamic_casts of the from actor.
void GraphScheduler::LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
                                               KernelWithIndex from_kernel_with_output_idx,
                                               KernelWithIndex to_kernel_with_input_idx) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto from_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(from_kernel);
  auto to_device_context = to_actor->device_context_;
  MS_EXCEPTION_IF_NULL(to_device_context);
  auto from_output_index = from_kernel_with_output_idx.second;
  auto to_input_index = to_kernel_with_input_idx.second;

  // One copy actor per (from kernel, output index) pair; reuse it if it was already created for another consumer.
  std::string name =
    "copy_actor_" + from_kernel->fullname_with_scope() + "_output_index_" + std::to_string(from_output_index);
  CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
  // Link between from actor and copy actor.
  if (copy_actor == nullptr) {
    // Create the copy actor.
    auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
    copy_actors_.emplace_back(copy_actor_shared_ptr);
    copy_actor = copy_actor_shared_ptr.get();
    MS_EXCEPTION_IF_NULL(copy_actor);
    InsertActor(copy_actor);

    // Link the from actor output to the copy actor input (input index of copy actor is always 0).
    const DeviceContext *from_device_context = nullptr;
    auto op_arrow_to_copy = std::make_shared<OpArrow>(from_output_index, copy_actor->GetAID(), 0);
    if (IsDeviceQueueDSActor(from_kernel)) {
      auto real_from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
      MS_EXCEPTION_IF_NULL(real_from_actor);
      from_device_context = real_from_actor->device_context_;
      real_from_actor->output_op_arrows_.emplace_back(op_arrow_to_copy);
    } else if (IsKernelActor(from_kernel)) {
      auto real_from_actor = dynamic_cast<KernelActor *>(from_actor);
      MS_EXCEPTION_IF_NULL(real_from_actor);
      from_device_context = real_from_actor->device_context_;
      real_from_actor->output_op_arrows_.emplace_back(op_arrow_to_copy);
    }
    copy_actor->input_datas_num_++;

    // Set the member of the copy actor: allocate the output device tensor on the destination device with the same
    // size/format/type as the source device tensor.
    const auto &from_device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
    MS_EXCEPTION_IF_NULL(from_device_tensor);
    copy_actor->output_device_tensor_ = to_device_context->CreateDeviceAddress(
      nullptr, from_device_tensor->GetSize(), from_device_tensor->format(), from_device_tensor->type_id());
    MS_EXCEPTION_IF_NULL(from_device_context);
    copy_actor->input_device_context_ = from_device_context;
    copy_actor->output_device_context_ = to_device_context;

    // Update the reference count of device tensor.
    UpdateRefCount(from_device_tensor.get());
  }

  // If the copy actor already exists, only need link between copy actor and to actor.
  auto op_arrow_from_copy = std::make_shared<OpArrow>(0, to_actor->GetAID(), to_input_index);
  copy_actor->output_op_arrows_.emplace_back(op_arrow_from_copy);
  to_actor->input_datas_num_++;
  UpdateRefCount(copy_actor->output_device_tensor_.get());
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
|
||||
GraphExecutionStrategy strategy) {
|
||||
MS_EXCEPTION_IF_NULL(from_actors);
|
||||
|
@ -744,8 +881,7 @@ void GraphScheduler::LinkControlArrowForKernelActor(std::vector<KernelActorPtr>
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
|
||||
const KernelMapActor &kernel_actors_map) {
|
||||
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node) {
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
MS_EXCEPTION_IF_NULL(from_node);
|
||||
if (!from_node->isa<CNode>()) {
|
||||
|
@ -770,7 +906,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
|
|||
} else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
|
||||
// Make tuple node needs to be expanded.
|
||||
for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
|
||||
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), kernel_actors_map);
|
||||
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i));
|
||||
}
|
||||
return;
|
||||
} else {
|
||||
|
@ -785,12 +921,12 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An
|
|||
if (AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimUpdateState) ||
|
||||
AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimLoad) ||
|
||||
AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimMakeTuple)) {
|
||||
LinkControlArrowByAutoMonad(to_actor, real_depend_input, kernel_actors_map);
|
||||
LinkControlArrowByAutoMonad(to_actor, real_depend_input);
|
||||
return;
|
||||
}
|
||||
|
||||
// Link the control arrow between the kernel actors.
|
||||
auto from_actor = FindKernelActor(kernel_actors_map, real_depend_input->fullname_with_scope());
|
||||
const auto &from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_input->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
from_actor->output_op_controls_.emplace_back(to_actor->GetAID());
|
||||
to_actor->input_controls_num_++;
|
||||
|
@ -818,8 +954,6 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun
|
|||
}
|
||||
|
||||
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
||||
const std::vector<DataSourceActorPtr> &data_source_actors,
|
||||
const KernelMapActor &kernel_actors_map,
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
|
||||
|
@ -853,7 +987,8 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
|||
|
||||
// The graph output is from kernel actor.
|
||||
if (IsKernelActor(output_with_index.first)) {
|
||||
const auto &from_actor = FindKernelActor(kernel_actors_map, output_with_index.first->fullname_with_scope());
|
||||
const auto &from_actor =
|
||||
dynamic_cast<KernelActor *>(FetchActor(output_with_index.first->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
auto op_arrow = std::make_shared<OpArrow>(output_with_index.second, to_actor->GetAID(), iter->second);
|
||||
from_actor->output_result_arrows_.emplace_back(op_arrow);
|
||||
|
@ -861,10 +996,12 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
|||
}
|
||||
|
||||
// The graph output is from data source actor.
|
||||
std::string actor_name;
|
||||
DataSourceActor *from_actor = nullptr;
|
||||
size_t from_actor_output_index = 0;
|
||||
if (IsHostQueueDSActor(output_with_index.first, graph)) {
|
||||
const auto &host_queue_ds_actor = FindHostQueueDSActor(data_source_actors);
|
||||
actor_name = graph_compiler_info.name_ + "_HostDSActor";
|
||||
const auto &host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
|
||||
auto position_iter = host_queue_ds_actor->data_node_position_map_.find(output_with_index.first);
|
||||
if (position_iter == host_queue_ds_actor->data_node_position_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Parameter node: " << output_with_index.first->fullname_with_scope() << " is not exist.";
|
||||
|
@ -873,7 +1010,8 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
|||
UpdateRefCount(host_queue_ds_actor->data_nodes_[from_actor_output_index], output_with_index.second, true);
|
||||
from_actor = static_cast<DataSourceActor *>(host_queue_ds_actor);
|
||||
} else if (IsDeviceQueueDSActor(output_with_index.first)) {
|
||||
from_actor = FindDeviceQueueDSActor(data_source_actors);
|
||||
actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
|
||||
from_actor = dynamic_cast<DataSourceActor *>(FetchActor(actor_name));
|
||||
from_actor_output_index = output_with_index.second;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
|
@ -913,6 +1051,26 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
|
|||
}
|
||||
}
|
||||
|
||||
// Check the copy actors.
|
||||
for (const auto ©_actor : actor_set->copy_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(copy_actor);
|
||||
if (copy_actor->output_op_arrows_.size() + copy_actor->output_op_controls_.size() == 0) {
|
||||
MS_LOG(ERROR) << copy_actor->GetAID().Name() << " has no user.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t kCopyActorInputDataNum = 1;
|
||||
auto input_data_num = copy_actor->input_datas_num_;
|
||||
auto device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
|
||||
if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
|
||||
MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
|
||||
<< " is wrong, input data num: " << input_data_num
|
||||
<< ", device tensor store num: " << device_tensor_store_num
|
||||
<< ", total input num: " << kCopyActorInputDataNum;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check the loop count actor.
|
||||
const auto &loop_count_actor = actor_set->loop_count_actor_;
|
||||
if (loop_count_actor != nullptr) {
|
||||
|
@ -937,8 +1095,7 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
|
|||
}
|
||||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
|
||||
DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
}
|
||||
|
||||
for (auto &input_node : graph->input_nodes()) {
|
||||
|
@ -947,8 +1104,7 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler
|
|||
auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
|
||||
device_tensor->set_original_ref_count(SIZE_MAX);
|
||||
device_tensor->ResetRefCount();
|
||||
UpdateRefCount(device_tensor.get(), true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -963,6 +1119,22 @@ HostTensorQueue *GraphScheduler::FetchHostQueue(const ActorInfo &actor_info) con
|
|||
}
|
||||
}
|
||||
|
||||
// Register an actor in the name -> actor lookup map. Actor names must be unique; registering a duplicate name
// raises an exception rather than silently overwriting the existing entry.
void GraphScheduler::InsertActor(OpActor<DeviceTensor> *actor) {
  MS_EXCEPTION_IF_NULL(actor);
  const auto &actor_name = actor->GetAID().Name();
  if (actor_name_to_actor_.find(actor_name) != actor_name_to_actor_.end()) {
    MS_LOG(EXCEPTION) << "The actor already exists: " << actor_name;
  }
  actor_name_to_actor_[actor_name] = actor;
}
|
||||
|
||||
// Look up a previously inserted actor by name. Returns nullptr when no actor with that name has been registered.
OpActor<DeviceTensor> *GraphScheduler::FetchActor(const std::string actor_name) const {
  auto iter = actor_name_to_actor_.find(actor_name);
  return (iter == actor_name_to_actor_.end()) ? nullptr : iter->second;
}
|
||||
|
||||
void GraphScheduler::DumpActor(const ActorSet *actor_set) const {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
const auto &context_ptr = MsContext::GetInstance();
|
||||
|
@ -1024,7 +1196,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
|
|||
ofs << "\tactor_name:" << actor_name << "\tdevice_context:" << actor->device_context_->device_context_key().ToString()
|
||||
<< "\n";
|
||||
|
||||
if (actor_name.find("_DeviceQueueDataSourceActor") != string::npos) {
|
||||
if (actor_name.find("_DeviceDSActor") != string::npos) {
|
||||
// Dump the member info of device queue data source actor.
|
||||
const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
|
||||
const auto &data_kernel = device_queue_ds_actor->data_kernel_;
|
||||
|
@ -1038,7 +1210,7 @@ void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &of
|
|||
ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
|
||||
<< "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
|
||||
}
|
||||
} else if (actor_name.find("_HostQueueDataSourceActor") != string::npos) {
|
||||
} else if (actor_name.find("_HostDSActor") != string::npos) {
|
||||
// Dump the member info of host queue data source actor.
|
||||
const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
|
||||
ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "runtime/framework/actor/kernel_actor.h"
|
||||
#include "runtime/framework/actor/output_actor.h"
|
||||
#include "runtime/framework/actor/switch_actor.h"
|
||||
#include "runtime/framework/actor/copy_actor.h"
|
||||
#include "runtime/hardware/device_context.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
|
@ -75,13 +76,14 @@ struct GraphCompilerInfo {
|
|||
};
|
||||
|
||||
// The actor set generated by graph transformer is the execution unit of actor runtime.
|
||||
// It includes data source actor, kernel actor, loop count actor.
|
||||
// The data source actor is used to obtain data and process them into device tensors,
|
||||
// and then send them to kernel actor. The kernel actor is used to receive the device tensors to launch kernel.
|
||||
// Specifically notice the no input kernel actor, it means that this actor has no input device tensor, need be triggered
|
||||
// externally. The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
|
||||
// and decide whether to loop execution by loop count. The output actor is used to receive the output result of actor
|
||||
// which represents the graph output.
|
||||
// It includes data source actor, kernel actor, switch actor, copy actor, loop count actor and output actor.
|
||||
// The data source actor is used to obtain data and process them into device tensors, and send them to kernel actor.
|
||||
// The kernel actor is used to receive the device tensors to launch kernel. Specifically notice the no input
|
||||
// kernel actor, it means that this actor has no input device tensor, need be triggered externally.
|
||||
// The copy actor is used to convert the device tensor between the different device kernel.
|
||||
// The loop count actor is used to receive the control of tail kernel actor to represent the end of one step
|
||||
// and decide whether to loop execution by loop count.
|
||||
// The output actor is used to receive the output result of actor which represents the graph output.
|
||||
struct ActorSet {
|
||||
explicit ActorSet(const ActorInfo &name) : name_(name) {}
|
||||
std::vector<DataSourceActorPtr> data_source_actors_;
|
||||
|
@ -89,6 +91,7 @@ struct ActorSet {
|
|||
// No input kernel actors need be triggered specifically.
|
||||
std::vector<KernelActorPtr> no_input_kernel_actors_;
|
||||
std::vector<SwitchActorPtr> switch_actors_;
|
||||
std::vector<CopyActorPtr> copy_actors_;
|
||||
LoopCountActorPtr loop_count_actor_{nullptr};
|
||||
OutputActorPtr output_actor_{nullptr};
|
||||
ActorInfo name_;
|
||||
|
@ -129,7 +132,7 @@ class GraphScheduler {
|
|||
|
||||
private:
|
||||
GraphScheduler() = default;
|
||||
~GraphScheduler() = default;
|
||||
~GraphScheduler();
|
||||
DISABLE_COPY_AND_ASSIGN(GraphScheduler);
|
||||
|
||||
// Transform the nodes of graph to actors.
|
||||
|
@ -145,7 +148,20 @@ class GraphScheduler {
|
|||
OutputActorPtr BuildOutputActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
std::vector<KernelActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set);
|
||||
|
||||
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
|
||||
// previous graph and the head of next graph.
|
||||
void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
|
||||
// The processing of actors link.
|
||||
// The gather of linking data allows of kernel, it will call following functions by the different from actor type.
|
||||
void LinkDataArrow(KernelActor *to_actor, const ActorSet *actor_set, const KernelGraphPtr &graph,
|
||||
KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx);
|
||||
// Link data arrows for internal parameter, convert internal parameter to actor by internal parameter cache to link.
|
||||
void LinkDataArrowForInternalParameter(const AnfNodePtr &internal_parameter, const KernelGraphPtr &graph,
|
||||
KernelActor *to_actor, KernelWithIndex to_kernel_with_input_idx);
|
||||
// Link data arrows in the copy actor scene, insert the copy actor between from_actor and to_actor.
|
||||
void LinkDataArrowForCopyActor(OpActor<DeviceTensor> *from_actor, KernelActor *to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx, KernelWithIndex to_kernel_with_input_idx);
|
||||
void LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
|
||||
KernelWithIndex from_kernel_with_output_idx,
|
||||
KernelWithIndex to_to_kernel_with_input_idx);
|
||||
|
@ -158,12 +174,8 @@ class GraphScheduler {
|
|||
void LinkControlArrowForKernelActor(std::vector<KernelActorPtr> *from_actors, LoopCountActor *to_actor,
|
||||
GraphExecutionStrategy strategy);
|
||||
void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set);
|
||||
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
|
||||
const KernelMapActor &kernel_actors_map);
|
||||
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
|
||||
const std::vector<DataSourceActorPtr> &data_source_actors,
|
||||
const KernelMapActor &kernel_actors_map,
|
||||
const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node);
|
||||
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
|
||||
|
||||
// Check whether the actor set is valid.
|
||||
bool CheckActorValid(const ActorSet *actor_set) const;
|
||||
|
@ -174,6 +186,10 @@ class GraphScheduler {
|
|||
// Fetch the host tensor queue by actor info.
|
||||
HostTensorQueue *FetchHostQueue(const ActorInfo &actor_info) const;
|
||||
|
||||
// The operation of the map of actor_name_to_actor_.
|
||||
void InsertActor(OpActor<DeviceTensor> *actor);
|
||||
OpActor<DeviceTensor> *FetchActor(const std::string actor_name) const;
|
||||
|
||||
// Display the actor information of corresponding kernel graph.
|
||||
void DumpActor(const ActorSet *actor_set) const;
|
||||
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const;
|
||||
|
@ -181,11 +197,19 @@ class GraphScheduler {
|
|||
void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const;
|
||||
void DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const;
|
||||
|
||||
// The global maps, only be cleared in the deconstruction.
|
||||
std::unordered_map<ActorInfo, ActorSetPtr> actors_;
|
||||
std::unordered_map<ActorInfo, HostTensorQueuePtr> actor_to_host_queue_;
|
||||
// The second element of pair represents the output index of op actor corresponding to the device tensor.
|
||||
std::unordered_map<DeviceTensorPtr, std::pair<OpActor<DeviceTensor> *, size_t>> device_tensor_to_actor_;
|
||||
|
||||
// The second element of pair represents the output index of kernel actor corresponding to the device tensor.
|
||||
std::unordered_map<DeviceTensorPtr, std::pair<KernelActorPtr, int>> device_address_to_actor_;
|
||||
// The local maps and vectors, will be cleared at the beginning of each graph transform.
|
||||
std::unordered_map<std::string, OpActor<DeviceTensor> *> actor_name_to_actor_;
|
||||
// The second element of pair represents the output index of op actor corresponding to the graph output front node.
|
||||
std::map<KernelWithIndex, std::pair<OpActor<DeviceTensor> *, size_t>, session::KernelWithIndexCmp> output_to_actor_;
|
||||
// Because the copy actors are built during linking, all copy actors created in the link process need to be
// recorded here so they can be pushed into the actor set after linking.
|
||||
std::vector<CopyActorPtr> copy_actors_;
|
||||
|
||||
// The id of memory manager actor.
|
||||
AID memory_manager_aid_;
|
||||
|
|
Loading…
Reference in New Issue