!30389 Add impl for rpc actor

Merge pull request !30389 from ZPaC/add-impl-for-rpc
This commit is contained in:
i-robot 2022-02-23 01:31:17 +00:00 committed by Gitee
commit 9f416126eb
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
13 changed files with 248 additions and 20 deletions

View File

@ -37,7 +37,7 @@ bool Initialize() {
std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster::ClusterContext::instance()->node());
MS_EXCEPTION_IF_NULL(abstract_node);
collective::CollectiveManager::instance()->set_global_rank_id(abstract_node->rank_id());
collective::CollectiveManager::instance()->set_global_rank_size(IntToUint(abstract_node->worker_num()));
collective::CollectiveManager::instance()->set_global_rank_size(abstract_node->worker_num());
if (!InitializeCollective()) {
MS_LOG(ERROR) << "Failed to initialize collective communication.";

View File

@ -128,6 +128,15 @@ bool IsSkippedKernelActor(const AnfNodePtr &node) {
return false;
}
bool IsRpcActor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (IsKernelActor(node) &&
(AnfAlgo::GetCNodeName(node) == kRpcSendOpName || AnfAlgo::GetCNodeName(node) == kRpcRecvOpName)) {
return true;
}
return false;
}
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {

View File

@ -191,6 +191,8 @@ bool IsSwitchActor(const AnfNodePtr &node);
// The skip kernel doesn't run, it exists in the inplace optimizer.
bool IsSkippedKernelActor(const AnfNodePtr &node);
bool IsRpcActor(const AnfNodePtr &node);
// Internal parameter is not the origin parameter of func graph, it is the output of previous kernel graph which is
// related to the input of this kernel graph.
bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph);

View File

@ -98,6 +98,7 @@ struct ActorSet {
LoopCountActorPtr loop_count_actor_{nullptr};
OutputActorPtr output_actor_{nullptr};
ControlActorSetPtr control_actors_{nullptr};
RpcActorSetPtr rpc_actors_{nullptr};
ActorInfo name_;
// The related statistics information of multi thread and single thread to decide whether use the multi thread.
bool is_multi_thread_execution_{true};

View File

@ -0,0 +1,25 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
namespace mindspore {
namespace runtime {
void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &src_node_name, const std::string &) {
input_peer_node_name_.emplace_back(src_node_name);
}
} // namespace runtime
} // namespace mindspore

View File

@ -28,14 +28,17 @@ namespace runtime {
// RecvActor inherits from RpcActor and it's used to receive data from other processes.
class RecvActor : public RpcActor {
public:
RecvActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes)
explicit RecvActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes)
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor) {}
~RecvActor() override = default;
void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &src_node_name,
const std::string &dst_node_name) override;
private:
friend class GraphScheduler;
};

View File

@ -33,15 +33,31 @@ using mindspore::device::KernelInfo;
// communication with other processes. It supports both sync and async communication.
class RpcActor : public KernelActor {
public:
RpcActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes, const KernelTransformType &type)
explicit RpcActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes, const KernelTransformType &type)
: KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
virtual ~RpcActor() = default;
// Set some info which will be used for rpc routing.
virtual void SetRouteInfo(uint32_t peer_rank, const std::string &peer_role, const std::string &src_node_name,
const std::string &dst_node_name) {}
// When an inter-process data received, this method is called.
void RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context);
protected:
// Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked.
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override { return true; }
// After rpc kernel is launched, inter-process data could be sent.
void SendOutput(OpContext<DeviceTensor> *const context) override {}
// The node name of rpc actor's peers.
std::vector<std::string> input_peer_node_name_;
std::vector<std::string> output_peer_node_name_;
// The arrows represent inter-process communication.
std::vector<AID> inter_process_input_arrows_;
std::vector<AID> inter_process_output_arrows_;

View File

@ -0,0 +1,25 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
namespace mindspore {
namespace runtime {
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &, const std::string &dst_node_name) {
output_peer_node_name_.emplace_back(dst_node_name);
}
} // namespace runtime
} // namespace mindspore

View File

@ -28,14 +28,17 @@ namespace runtime {
// SendActor inherits from RpcActor and it's used to send data to other processes.
class SendActor : public RpcActor {
public:
SendActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes)
explicit SendActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes)
: RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
~SendActor() override = default;
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &src_node_name,
const std::string &dst_node_name) override;
private:
friend class GraphScheduler;
};

View File

@ -296,6 +296,11 @@ void GraphScheduler::Initialize() {
MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
<< ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num);
// Create and initialize RpcNodeScheduler.
rpc_node_scheduler_ = std::make_unique<RpcNodeScheduler>();
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
rpc_node_scheduler_->Initialize();
BuildAndScheduleGlobalActor();
}
@ -536,6 +541,9 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
actor_set->data_prepare_actor_ =
BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info, memory_manager_aid_);
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
actor_set->rpc_actors_ = rpc_node_scheduler_->Build(graph_compiler_info);
return actor_set;
}
@ -783,9 +791,14 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler
if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
auto ref_input_indexes = FetchModifiableRefInputIndex(kernel);
auto ref_output_indexes = FetchModifiableRefOutputIndex(kernel, graph);
auto kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
KernelActorPtr kernel_actor = nullptr;
if (IsRpcActor(kernel)) {
kernel_actor = GenerateRpcActor(kernel, device_context, strategy, ref_input_indexes, ref_output_indexes);
} else {
kernel_actor =
std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
}
MS_EXCEPTION_IF_NULL(kernel_actor);
InsertActor(kernel_actor.get());
(void)kernel_actors.emplace_back(kernel_actor);
@ -939,6 +952,33 @@ std::vector<AbstractActorPtr> GraphScheduler::BuildNoInputKernelActor(const Acto
return no_input_kernel_actors;
}
KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
GraphExecutionStrategy strategy,
const std::set<size_t> &ref_input_indexes,
const std::set<size_t> &ref_output_indexes) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
if (AnfAlgo::GetCNodeName(kernel) == kRpcSendOpName) {
auto send_actor =
std::make_shared<SendActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
MS_EXCEPTION_IF_NULL(send_actor);
rpc_node_scheduler_->InsertSendActor(send_actor);
return send_actor;
} else if (AnfAlgo::GetCNodeName(kernel) == kRpcRecvOpName) {
auto recv_actor =
std::make_shared<RecvActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_,
debug_aid_, recorder_aid_, strategy, ref_input_indexes, ref_output_indexes);
MS_EXCEPTION_IF_NULL(recv_actor);
rpc_node_scheduler_->InsertRecvActor(recv_actor);
return recv_actor;
} else {
MS_LOG(EXCEPTION) << "Kernel " << kernel->fullname_with_scope() << " is not an rpc kernel.";
}
return nullptr;
}
void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
std::vector<AbstractActor *> *const auto_monad_actors) {
MS_EXCEPTION_IF_NULL(graph);

View File

@ -28,6 +28,7 @@
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "runtime/graph_scheduler/control_node_scheduler.h"
#include "runtime/graph_scheduler/rpc_node_scheduler.h"
#include "runtime/graph_scheduler/actor/actor_set.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/graph_scheduler/actor/actor_dump.h"
@ -107,6 +108,11 @@ class GraphScheduler {
const HostTensorQueuePtr &host_queue);
std::vector<AbstractActorPtr> BuildNoInputKernelActor(const ActorSet *actor_set, GraphExecutionStrategy strategy);
// Generate rpc actor object inherited from kernel actor.
KernelActorPtr GenerateRpcActor(const CNodePtr &kernel, const DeviceContext *device_context,
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes);
// Cache the information of graph output node to actor between “build” and “link”, for linking between the tail of
// previous graph and the head of next graph.
void CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info);
@ -201,6 +207,9 @@ class GraphScheduler {
// In the control flow, used to build and link control actor.
ControlNodeScheduler control_node_scheduler_;
// Used to build and link for rpc actors.
std::unique_ptr<RpcNodeScheduler> rpc_node_scheduler_{nullptr};
// The id of global actor.
AID memory_manager_aid_;
const AID *recorder_aid_{nullptr};

View File

@ -0,0 +1,87 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/rpc_node_scheduler.h"
namespace mindspore {
namespace runtime {
void RpcNodeScheduler::Initialize() {
rpc_actor_set_ = std::make_shared<RpcActorSet>();
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
}
RpcActorSetPtr RpcNodeScheduler::Build(const GraphCompilerInfo &) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
return rpc_actor_set_;
}
void RpcNodeScheduler::Link(const ActorSetPtr &) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
std::vector<SendActorPtr> send_actors = rpc_actor_set_->send_actors_;
std::vector<RecvActorPtr> recv_actors = rpc_actor_set_->recv_actors_;
// The inter-process edge is connected to a remote peer. So the peer info attributes in the kernel should be
// sufficient for route table.
for (auto &send_actor : send_actors) {
CNodePtr rpc_send_kernel = send_actor->kernel();
MS_EXCEPTION_IF_NULL(rpc_send_kernel);
auto send_dst_ranks = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(rpc_send_kernel, kAttrSendDstRanks);
auto send_dst_roles = AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_send_kernel, kAttrSendDstRoles);
std::string send_src_node_name = AnfAlgo::GetNodeAttr<std::string>(rpc_send_kernel, kAttrSendSrcNodeName);
std::string send_dst_node_name = AnfAlgo::GetNodeAttr<std::string>(rpc_send_kernel, kAttrSendDstNodeName);
if (send_dst_ranks.empty() || send_dst_roles.empty()) {
MS_LOG(EXCEPTION) << "The attributes of send node " << rpc_send_kernel->fullname_with_scope()
<< " is invalid. send_dst_ranks: " << send_dst_ranks << ", send_dst_roles: " << send_dst_roles
<< ", send_src_node_name: " << send_src_node_name
<< ", send_dst_node_name: " << send_dst_node_name;
return;
}
send_actor->SetRouteInfo(send_dst_ranks[0], send_dst_roles[0], send_src_node_name, send_dst_node_name);
}
for (auto &recv_actor : recv_actors) {
CNodePtr rpc_recv_kernel = recv_actor->kernel();
MS_EXCEPTION_IF_NULL(rpc_recv_kernel);
auto recv_src_ranks = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(rpc_recv_kernel, kAttrRecvSrcRanks);
auto recv_src_roles = AnfAlgo::GetNodeAttr<std::vector<std::string>>(rpc_recv_kernel, kAttrRecvSrcRoles);
std::string recv_src_node_name = AnfAlgo::GetNodeAttr<std::string>(rpc_recv_kernel, kAttrRecvSrcNodeName);
std::string recv_dst_node_name = AnfAlgo::GetNodeAttr<std::string>(rpc_recv_kernel, kAttrRecvDstNodeName);
if (recv_src_ranks.empty() || recv_src_roles.empty()) {
MS_LOG(EXCEPTION) << "The attributes of recv node " << rpc_recv_kernel->fullname_with_scope()
<< " is invalid. recv_src_ranks: " << recv_src_ranks << ", recv_src_roles: " << recv_src_roles
<< ", recv_src_node_name: " << recv_src_node_name
<< ", recv_dst_node_name: " << recv_dst_node_name;
return;
}
recv_actor->SetRouteInfo(recv_src_ranks[0], recv_src_roles[0], recv_src_node_name, recv_dst_node_name);
}
}
void RpcNodeScheduler::InsertSendActor(const SendActorPtr &send_actor) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
MS_EXCEPTION_IF_NULL(send_actor);
rpc_actor_set_->send_actors_.emplace_back(send_actor);
}
void RpcNodeScheduler::InsertRecvActor(const RecvActorPtr &recv_actor) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
MS_EXCEPTION_IF_NULL(recv_actor);
rpc_actor_set_->recv_actors_.emplace_back(recv_actor);
}
} // namespace runtime
} // namespace mindspore

View File

@ -22,8 +22,6 @@
#include <memory>
#include "runtime/graph_scheduler/actor/actor_set.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "distributed/rpc/tcp/tcp_client.h"
#include "distributed/rpc/tcp/tcp_server.h"
namespace mindspore {
namespace runtime {
@ -37,11 +35,21 @@ class RpcNodeScheduler {
RpcNodeScheduler() = default;
~RpcNodeScheduler() = default;
// Cast some actors to rpc actors according to its kernel name.
RpcActorSetPtr Build(const ActorSetPtr &actor_set);
// Create rpc actor set.
void Initialize();
// Build rpc actors and return rpc actor set.
RpcActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
// Link rpc actors with inter-process arrows.
void Link(const ActorSetPtr &actor_set);
// Insert Send/Recv actors generated by GraphScheduler to the rpc actor set.
void InsertSendActor(const SendActorPtr &send_actor);
void InsertRecvActor(const RecvActorPtr &recv_actor);
private:
RpcActorSetPtr rpc_actor_set_;
};
} // namespace runtime
} // namespace mindspore