!30611 Fix rpc route bugs

Merge pull request !30611 from ZPaC/sync-route-table
This commit is contained in:
i-robot 2022-02-28 15:50:35 +00:00 committed by Gitee
commit 4367377000
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
23 changed files with 443 additions and 43 deletions

View File

@ -80,6 +80,9 @@ ActorAddress ActorRouteTableProxy::LookupRoute(const std::string &actor_id) cons
lookup_success = true;
}
} while (!lookup_success && CURRENT_TIMESTAMP_MILLI <= timeout_ts);
if (!lookup_success) {
MS_LOG(EXCEPTION) << "Failed to lookup actor address for " << actor_id;
}
return lookup_route_rsp_msg;
}

View File

@ -31,15 +31,19 @@ using ps::core::ActorAddress;
using ps::core::GeneralResponseMsg;
using ps::core::NodeCommand;
// The timeout in milliseconds for one lookup.
constexpr uint32_t kDefaultLookupTimeout = 5000;
// The time in milliseconds between two lookup operations.
constexpr auto kLookupInterval = 100;
constexpr uint32_t kLookupInterval = 100;
// Actor route table proxy for nodes like workers and server. This class helps update actor route table in scheduler
// across the network.
class ActorRouteTableProxy {
public:
explicit ActorRouteTableProxy(const std::shared_ptr<ps::core::AbstractNode> &node, uint32_t lookup_timout)
: node_(node), lookup_timeout_(std::chrono::milliseconds(lookup_timout)) {}
explicit ActorRouteTableProxy(const std::shared_ptr<ps::core::AbstractNode> &node,
uint32_t lookup_timeout = kDefaultLookupTimeout)
: node_(node), lookup_timeout_(std::chrono::milliseconds(lookup_timeout)) {}
~ActorRouteTableProxy() = default;
// Register actor address to the route table stored in scheduler.
@ -55,9 +59,11 @@ class ActorRouteTableProxy {
// The node variable helps proxy to communicate with scheduler, e.g., SendMessage.
std::shared_ptr<ps::core::AbstractNode> node_;
// The timeout window for lookup route operation because time of route lookup_timout of each process is different.
// The timeout window for lookup route operation because time of route lookup_timeout of each process is different.
std::chrono::milliseconds lookup_timeout_;
};
using ActorRouteTableProxyPtr = std::shared_ptr<ActorRouteTableProxy>;
} // namespace cluster
} // namespace distributed
} // namespace mindspore

View File

@ -34,6 +34,7 @@ ClusterContext::ClusterContext()
scheduler_host_(kLocalHost),
scheduler_port_(kDefaultSchedPort),
node_(nullptr),
abstract_node_(nullptr),
node_role_(""),
cluster_config_(nullptr) {}
@ -84,6 +85,14 @@ bool ClusterContext::Initialize() {
return false;
}
// Step 3: Initialize some modules for the node, e.g., actor route table proxy.
if (!IsScheduler()) {
// Only node which is not the scheduler needs route table proxy.
actor_route_table_proxy_ =
std::make_shared<ActorRouteTableProxy>(std::dynamic_pointer_cast<ps::core::AbstractNode>(node_));
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
}
inited_ = true;
finalized_ = false;
return true;
@ -105,6 +114,8 @@ bool ClusterContext::Finalize(uint32_t timeout) {
return true;
}
bool ClusterContext::IsScheduler() { return (abstract_node_ == nullptr) ? true : false; }
const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }
const std::string &ClusterContext::node_role() const { return node_role_; }
@ -120,6 +131,8 @@ uint32_t ClusterContext::node_num(const std::string &node_role) {
bool ClusterContext::initialized() const { return inited_; }
const ActorRouteTableProxyPtr &ClusterContext::actor_route_table_proxy() const { return actor_route_table_proxy_; }
void ClusterContext::InitClusterConfig() {
InitNodeRole();
InitSchedulerIp();
@ -154,6 +167,7 @@ bool ClusterContext::BuildCluster() {
MS_LOG(ERROR) << "Building network failed.";
return false;
}
abstract_node_ = std::dynamic_pointer_cast<ps::core::AbstractNode>(node_);
MS_LOG(INFO) << "Cluster is successfully initialized.";
return true;
}

View File

@ -33,6 +33,7 @@
#include "ps/core/worker_node.h"
#include "ps/core/server_node.h"
#include "ps/core/scheduler_node.h"
#include "distributed/cluster/actor_route_table_proxy.h"
namespace mindspore {
namespace distributed {
@ -57,6 +58,10 @@ class ClusterContext {
// Finalize the cluster and process exits. If timeout is set to UINT32_MAX, this method will block without timeout.
bool Finalize(uint32_t timeout = kDefaultFinishTimeout);
// Return whether this node is the scheduler node.
// In a cluster, the scheduler node is special because it's responsible for building network.
bool IsScheduler();
// Return node object of this process.
const std::shared_ptr<ps::core::Node> &node() const;
@ -69,6 +74,9 @@ class ClusterContext {
// Return cluster is initialized.
bool initialized() const;
// Return actor route proxy for AbstractNode.
const ActorRouteTableProxyPtr &actor_route_table_proxy() const;
private:
ClusterContext();
@ -106,11 +114,17 @@ class ClusterContext {
// The node could be Worker, Server or Scheduler, etc.
std::shared_ptr<ps::core::Node> node_;
// abstract_node_ is nullptr only when this is node is scheduler.
std::shared_ptr<ps::core::AbstractNode> abstract_node_;
// The role of this process in the cluster.
std::string node_role_;
// The configuration of this cluster.
std::unique_ptr<ps::core::ClusterConfig> cluster_config_;
// The actor route table proxy. It only created in abstract nodes because scheduler does not use proxy.
ActorRouteTableProxyPtr actor_route_table_proxy_;
};
} // namespace cluster
} // namespace distributed

View File

@ -275,12 +275,13 @@ CNodePtr GraphSplitter::GenerateSendNode(const AnfNodePtr &input, const AnfNodeP
MS_EXCEPTION_IF_NULL(peer);
std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
auto mock_value = GenerateMockValueNode(true, input);
MS_EXCEPTION_IF_NULL(mock_value);
ValueNodePtr mock_value = nullptr;
if (IsPrimitiveCNode(input, prim::kPrimUpdateState)) {
mock_value = GenerateMockValueNode(false);
send_inputs.push_back(mock_value);
send_inputs.push_back(input);
} else {
mock_value = GenerateMockValueNode(true, input);
send_inputs.push_back(input);
}
CNodePtr send_node = func_graph_->NewCNode(send_inputs);

View File

@ -30,7 +30,6 @@
#include "include/common/utils/anfalgo.h"
#include "kernel/common_utils.h"
#include "ir/anf.h"
#include "runtime/graph_scheduler/graph_scheduler.h"
#include "actor/actormgr.h"
#include "include/common/thread_pool.h"

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RPC_RPC_KERNEL_H_
#include <vector>
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
@ -35,6 +36,9 @@ class RpcKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override { return; }
// Set remote data as input.
void SetRemoteInput(std::unique_ptr<MessageBase> &&) {}
private:
};
} // namespace kernel

View File

@ -312,10 +312,12 @@ bool AbstractNode::SendToScheduler(const void *message, size_t len, NodeCommand
// Assign the response value from scheduler.
if (output != nullptr) {
if (received_scheduler_messages_.count(request_id) == 0) {
MS_LOG(ERROR) << "The response message of " << node_cmd << " is not received yet.";
MS_LOG(ERROR) << "The response message of command " << node_cmd << ", request_id " << request_id
<< " is not received yet.";
return false;
}
*output = received_scheduler_messages_[request_id];
(void)received_scheduler_messages_.erase(request_id);
}
return ret;
}

View File

@ -107,9 +107,6 @@ bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
if (receive_messages_.count(request_id) != 0) {
(void)receive_messages_.erase(request_id);
}
if (received_scheduler_messages_.count(request_id) != 0) {
(void)received_scheduler_messages_.erase(request_id);
}
msgs_lock.unlock();
return res;
}

View File

@ -3,6 +3,13 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/src)
file(GLOB_RECURSE GRAPH_SCHEDULER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "rpc_node_scheduler.cc")
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "actor/rpc/recv_actor.cc")
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "actor/rpc/rpc_actor.cc")
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "actor/rpc/send_actor.cc")
endif()
set_property(SOURCE ${GRAPH_SCHEDULER_SRC_LIST}
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_RUNTIME_FRAMEWORK)
add_library(_mindspore_runtime_graph_scheduler_obj OBJECT ${GRAPH_SCHEDULER_SRC_LIST})

View File

@ -17,6 +17,10 @@
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SET_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SET_H_
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#define ENABLE_RPC_ACTOR
#endif
#include <vector>
#include <string>
#include <memory>
@ -30,8 +34,6 @@
#include "runtime/graph_scheduler/actor/loop_count_actor.h"
#include "runtime/graph_scheduler/actor/kernel_actor.h"
#include "runtime/graph_scheduler/actor/custom_actor.h"
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
#include "runtime/graph_scheduler/actor/super_kernel_actor.h"
#include "runtime/graph_scheduler/actor/output_actor.h"
#include "runtime/graph_scheduler/actor/copy_actor.h"
@ -41,6 +43,11 @@
#include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"
#include "runtime/graph_scheduler/actor/control_flow/stack_actor.h"
#ifdef ENABLE_RPC_ACTOR
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
#endif
namespace mindspore {
namespace runtime {
using ActorInfo = std::string;
@ -63,6 +70,7 @@ struct ControlActorSet {
};
using ControlActorSetPtr = std::shared_ptr<ControlActorSet>;
#ifdef ENABLE_RPC_ACTOR
// Rpc actor set is a series of actors implemented to communicate with other processes. In distributed execution mode,
// the graph could be considered as partitioned to different processes, which is connected by these rpc actors. Send
// actors are in charge of sending data to other processes. Recv actors are in charge of receiving data from other
@ -72,6 +80,7 @@ struct RpcActorSet {
std::vector<RecvActorPtr> recv_actors_;
};
using RpcActorSetPtr = std::shared_ptr<RpcActorSet>;
#endif
// The actor set generated by graph transformer is the execution unit of actor runtime.
// It includes data source actor, kernel actor, switch actor, copy actor, loop count actor and output actor.
@ -98,7 +107,9 @@ struct ActorSet {
LoopCountActorPtr loop_count_actor_{nullptr};
OutputActorPtr output_actor_{nullptr};
ControlActorSetPtr control_actors_{nullptr};
#ifdef ENABLE_RPC_ACTOR
RpcActorSetPtr rpc_actors_{nullptr};
#endif
ActorInfo name_;
// The related statistics information of multi thread and single thread to decide whether use the multi thread.
bool is_multi_thread_execution_{true};

View File

@ -50,8 +50,8 @@ class KernelActor : public DebugAwareActor {
const std::set<size_t> &modifiable_ref_output_indexes,
const KernelTransformType &type = KernelTransformType::kKernelActor)
: DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid),
kernel_(kernel),
kernel_info_(nullptr),
kernel_(kernel),
is_dynamic_shape_(false),
real_input_num_(0),
strategy_(strategy),
@ -85,6 +85,10 @@ class KernelActor : public DebugAwareActor {
void Run(OpContext<DeviceTensor> *const context) override;
void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;
KernelInfo *kernel_info_;
// The kernel launch info is fetched by the device tensors.
KernelLaunchInfo launch_info_;
private:
friend class GraphScheduler;
friend class ControlNodeScheduler;
@ -111,7 +115,6 @@ class KernelActor : public DebugAwareActor {
// The info of kernel.
CNodePtr kernel_;
KernelInfo *kernel_info_;
bool is_dynamic_shape_;
// The real input number of kernel launch.
@ -138,9 +141,6 @@ class KernelActor : public DebugAwareActor {
// The device tensor of external reference is not the real data of this kernel, but need add to the memory_free_list_.
std::vector<DeviceTensor *> external_reference_tensors_;
// The kernel launch info is fetched by the device tensors.
KernelLaunchInfo launch_info_;
// Record the modifiable ref indexes. Used to refresh the ref data which are modified in the running.
std::set<size_t> modifiable_ref_input_indexes_;
std::set<size_t> modifiable_ref_output_indexes_;

View File

@ -16,10 +16,94 @@
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
#include <utility>
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
namespace mindspore {
namespace runtime {
void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &src_node_name, const std::string &) {
input_peer_node_name_.emplace_back(src_node_name);
void RecvActor::RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context) {
MS_ERROR_IF_NULL_WO_RET_VAL(msg);
MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
auto &sequential_num = context->sequential_num_;
(void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());
auto is_run = CheckRunningCondition(context);
MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
<< inter_process_edge_name_ << ". Check running condition:" << is_run;
// Parse the message from remote peer and set to rpc recv kernel.
auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
// We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
recv_kernel_mod->SetRemoteInput(std::move(msg));
if (is_run) {
Run(context);
}
return;
}
void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &recv_src_node_name,
const std::string &recv_dst_node_name) {
rpc_input_node_name_.emplace_back(recv_src_node_name);
input_inter_process_num_++;
}
bool RecvActor::StartServer() {
// Step 1: Create a tcp server and start listening.
std::string server_url = ip_ + ":" + std::to_string(port_);
server_ = std::make_unique<TCPServer>();
MS_EXCEPTION_IF_NULL(server_);
if (!server_->Initialize(server_url)) {
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor. Server url: " << server_url;
}
// Step 2: Set the message handler of the server.
// Step 2: Register the server address to route table. The server should not be connected before this step is done.
ActorAddress recv_actor_addresss;
recv_actor_addresss.set_actor_id(inter_process_edge_name_);
recv_actor_addresss.set_ip(ip_);
recv_actor_addresss.set_port(static_cast<uint32_t>(port_));
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
if (!actor_route_table_proxy_->RegisterRoute(inter_process_edge_name_, recv_actor_addresss)) {
MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_name_ << " " << server_url
<< " when starting server.";
}
return true;
}
bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
// Step 1: Judge data and control inputs are satisfied.
bool is_data_and_control_arrow_satisfied = AbstractActor::CheckRunningCondition(context);
if (!is_data_and_control_arrow_satisfied) {
return false;
}
if (input_inter_process_num_ != 0) {
// Step 2: Judge inter-process inputs are satisfied.
const auto &inter_process_iter = input_op_inter_process_.find(context->sequential_num_);
if (inter_process_iter == input_op_inter_process_.end()) {
return false;
}
const auto &current_inter_process_inputs = inter_process_iter->second;
if (current_inter_process_inputs.size() < input_inter_process_num_) {
return false;
} else if (current_inter_process_inputs.size() > input_inter_process_num_) {
MS_LOG(ERROR) << "Invalid inter process input num:" << current_inter_process_inputs.size()
<< " need:" << input_inter_process_num_ << " for actor:" << GetAID();
return false;
}
}
return true;
}
void RecvActor::HandleMessage(std::unique_ptr<MessageBase> &&msg) {
MS_ERROR_IF_NULL_WO_RET_VAL(msg);
MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
RunOpInterProcessData(std::move(msg), op_context_);
}
} // namespace runtime
} // namespace mindspore

View File

@ -36,11 +36,31 @@ class RecvActor : public RpcActor {
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor) {}
~RecvActor() override = default;
void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &src_node_name,
const std::string &dst_node_name) override;
// When an inter-process data received, this method is called.
void RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context);
// Set recv actor's source peer info, in another word, recv actor's input.
void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &recv_src_node_name,
const std::string &recv_dst_node_name) override;
// Start recv actor server and register this server address to actor route table in scheduler by proxy.
bool StartServer();
protected:
// Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked for
// recv actor.
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
private:
void HandleMessage(std::unique_ptr<MessageBase> &&msg);
friend class GraphScheduler;
// The network address of this recv actor. It's generated automatically by rpc module.
std::string ip_;
uint16_t port_;
std::unique_ptr<TCPServer> server_;
};
using RecvActorPtr = std::shared_ptr<RecvActor>;

View File

@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
namespace mindspore {
namespace runtime {
void RpcActor::SetInterProcessEdgeName(const std::string &src_node_name, const std::string &dst_node_name) {
inter_process_edge_name_ = src_node_name + kInterProcessEdgeMark + dst_node_name;
}
void RpcActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) { op_context_ = op_context; }
void RpcActor::SetActorRouteRableProxy(const ActorRouteTableProxyPtr &proxy) { actor_route_table_proxy_ = proxy; }
} // namespace runtime
} // namespace mindspore

View File

@ -23,10 +23,22 @@
#include <memory>
#include <utility>
#include "runtime/graph_scheduler/actor/kernel_actor.h"
#include "distributed/cluster/cluster_context.h"
#include "distributed/rpc/tcp/tcp_client.h"
#include "distributed/rpc/tcp/tcp_server.h"
namespace mindspore {
namespace runtime {
using distributed::cluster::ActorRouteTableProxy;
using distributed::cluster::ActorRouteTableProxyPtr;
using distributed::cluster::ClusterContext;
using distributed::rpc::TCPClient;
using distributed::rpc::TCPServer;
using mindspore::device::KernelInfo;
using ps::core::ActorAddress;
// The inter-process edge mark between two nodes.
constexpr char kInterProcessEdgeMark[] = "->";
// RpcActor is used to do rpc with other processes in distributed execution.
// Besides data arrows and controlling arrows, RpcActor also has inter-process arrows which is in charge of remote
@ -38,30 +50,50 @@ class RpcActor : public KernelActor {
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
const std::set<size_t> &modifiable_ref_output_indexes, const KernelTransformType &type)
: KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
modifiable_ref_input_indexes, modifiable_ref_output_indexes, type),
input_inter_process_num_(0) {}
virtual ~RpcActor() = default;
// Normally, an actor's op_context is passed by its input actor, but rpc actors could be triggered by inter-process
// arrows which do not contain op_context. So we need to set op_context manually.
void SetOpcontext(OpContext<DeviceTensor> *const op_context);
// Set the actor route proxy for rpc actors.
void SetActorRouteRableProxy(const ActorRouteTableProxyPtr &proxy);
// Set the inter-process edge name for rpc actor.
void SetInterProcessEdgeName(const std::string &src_node_name, const std::string &dst_node_name);
// Set some info which will be used for rpc routing.
virtual void SetRouteInfo(uint32_t peer_rank, const std::string &peer_role, const std::string &src_node_name,
const std::string &dst_node_name) {}
// When an inter-process data received, this method is called.
void RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context);
protected:
// Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked.
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override { return true; }
// The op context to run rpc actor inter-process op.
OpContext<DeviceTensor> *op_context_;
// After rpc kernel is launched, inter-process data could be sent.
void SendOutput(OpContext<DeviceTensor> *const context) override {}
// The inter-process edge name. It is also used as the actor id for route. It's a string consists of source node name
// and destination node name. The format is "source node name"->"destination node name". For each inter-process edge,
// this is is unique. Rpc actor with the same inter_process_edge_name_ should not be in the same process.
std::string inter_process_edge_name_;
// The node name of rpc actor's peers. They are not the name of send or recv nodes. Instead, they are the names of the
// nodes which use send node as output and recv node as input.
std::vector<std::string> rpc_input_node_name_;
std::vector<std::string> rpc_output_node_name_;
// The iter-process inputs number. This should be the same as size of vector rpc_input_node_name_.
size_t input_inter_process_num_;
// The inter-process inputs of each sequential number.
mindspore::HashMap<int, std::vector<std::string>> input_op_inter_process_;
// The node name of rpc actor's peers.
std::vector<std::string> input_peer_node_name_;
std::vector<std::string> output_peer_node_name_;
// The arrows represent inter-process communication.
std::vector<AID> inter_process_input_arrows_;
std::vector<AID> inter_process_output_arrows_;
ActorRouteTableProxyPtr actor_route_table_proxy_;
private:
friend class GraphScheduler;
};

View File

@ -16,10 +16,69 @@
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
#include <utility>
namespace mindspore {
namespace runtime {
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &, const std::string &dst_node_name) {
output_peer_node_name_.emplace_back(dst_node_name);
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
const std::string &send_dst_node_name) {
auto peer_actor_id = send_src_node_name + kInterProcessEdgeMark + send_dst_node_name;
peer_actor_ids_.emplace_back(peer_actor_id);
rpc_output_node_name_.emplace_back(send_dst_node_name);
}
bool SendActor::ConnectServer() {
client_ = std::make_unique<TCPClient>();
MS_EXCEPTION_IF_NULL(client_);
if (!client_->Initialize()) {
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for send actor.";
}
// Lookup actor addresses for each peer actor.
for (const auto &peer_actor_id : peer_actor_ids_) {
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
// If route is successfully looked up, peer_actor_address is not empty.
std::string server_url = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
if (!client_->Connect(server_url)) {
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url;
}
peer_actor_urls_[peer_actor_id] = server_url;
}
return true;
}
void SendActor::SendOutput(OpContext<DeviceTensor> *const context) {
MS_ERROR_IF_NULL_WO_RET_VAL(context);
MS_ERROR_IF_NULL_WO_RET_VAL(client_);
// Step 1: Send data and control outputs.
AbstractActor::SendOutput(context);
// Step 2: Erase inter-process inputs for this sequential number.
if (input_op_inter_process_.count(context->sequential_num_) != 0) {
input_op_inter_process_.erase(context->sequential_num_);
}
// Step 3: Send output data(inter-process data) to peers.
if (launch_info_.outputs_.empty()) {
MS_LOG(ERROR) << "Send kernel has no output tensor.";
return;
}
auto send_output = launch_info_.outputs_[0];
for (const auto &peer : peer_actor_urls_) {
std::string peer_server_url = peer.second;
auto message = BuildRpcMessage(send_output, peer_server_url);
MS_ERROR_IF_NULL_WO_RET_VAL(message);
client_->Send(std::move(message));
}
}
std::unique_ptr<MessageBase> SendActor::BuildRpcMessage(const kernel::AddressPtr &data, const std::string &server_url) {
MS_ERROR_IF_NULL_W_RET_VAL(data, nullptr);
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);
message->to = AID("", server_url);
message->body.assign(static_cast<char *>(data->addr), data->size);
return message;
}
} // namespace runtime
} // namespace mindspore

View File

@ -36,11 +36,28 @@ class SendActor : public RpcActor {
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
~SendActor() override = default;
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &src_node_name,
const std::string &dst_node_name) override;
// Set send actor's destination peer info, in another word, send actor's output.
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
const std::string &send_dst_node_name) override;
// Lookup peer actors' route and create connection to them.
bool ConnectServer();
protected:
// After rpc send kernel is launched, inter-process data should be sent.
void SendOutput(OpContext<DeviceTensor> *const context) override;
private:
// Client only supports to send MessageBase, so build MessageBase with data and url.
std::unique_ptr<MessageBase> BuildRpcMessage(const kernel::AddressPtr &data, const std::string &server_url);
friend class GraphScheduler;
// This send actor's destination peers' actor ids and route table.
std::vector<std::string> peer_actor_ids_;
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
std::unique_ptr<TCPClient> client_;
};
using SendActorPtr = std::shared_ptr<SendActor>;

View File

@ -282,6 +282,8 @@ void GraphScheduler::Initialize() {
&GraphScheduler::LinkDataArrowForDeviceTensorStore);
(void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
&GraphScheduler::LinkDataArrowForInternalParameter);
(void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSendActor, &GraphScheduler::LinkDataArrowForBaseActor);
(void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kRecvActor, &GraphScheduler::LinkDataArrowForBaseActor);
// Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
size_t actor_thread_num = 0;
@ -297,10 +299,12 @@ void GraphScheduler::Initialize() {
MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
<< ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num);
#ifdef ENABLE_RPC_ACTOR
// Create and initialize RpcNodeScheduler.
rpc_node_scheduler_ = std::make_unique<RpcNodeScheduler>();
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
rpc_node_scheduler_->Initialize();
#endif
BuildAndScheduleGlobalActor();
}
@ -393,6 +397,12 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
for (auto actor : actors) {
(void)actor_manager->Spawn(actor);
}
#ifdef ENABLE_RPC_ACTOR
// Build physical connections in 'RpcNodeScheduler::Schedule()' method. This costs some time.
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
rpc_node_scheduler_->Schedule();
#endif
}
void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceContext *> &device_contexts,
@ -410,6 +420,12 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
op_context.sequential_num_ = RandInt::Instance().Get();
op_context.results_ = &result;
#ifdef ENABLE_RPC_ACTOR
// Set OpContext to rpc node scheduler.
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
rpc_node_scheduler_->SetOpcontext(&op_context);
#endif
if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context, GraphExecutionStrategy::kStep);
MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
@ -543,8 +559,10 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info, memory_manager_aid_);
#ifdef ENABLE_RPC_ACTOR
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
actor_set->rpc_actors_ = rpc_node_scheduler_->Build(graph_compiler_info);
#endif
return actor_set;
}
@ -622,6 +640,12 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited()) {
control_node_scheduler_.Link(actor_set, graph_compiler_info);
}
#ifdef ENABLE_RPC_ACTOR
// Link inter-process arrows for rpc actors.
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
rpc_node_scheduler_->Link(actor_set);
#endif
}
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
@ -959,6 +983,7 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
const std::set<size_t> &ref_output_indexes) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(device_context);
#ifdef ENABLE_RPC_ACTOR
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
if (common::AnfAlgo::GetCNodeName(kernel) == kRpcSendOpName) {
auto send_actor =
@ -977,6 +1002,7 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
} else {
MS_LOG(EXCEPTION) << "Kernel " << kernel->fullname_with_scope() << " is not an rpc kernel.";
}
#endif
return nullptr;
}

View File

@ -28,12 +28,15 @@
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "runtime/graph_scheduler/control_node_scheduler.h"
#include "runtime/graph_scheduler/rpc_node_scheduler.h"
#include "runtime/graph_scheduler/actor/actor_set.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/graph_scheduler/actor/actor_dump.h"
#include "thread/actor_threadpool.h"
#ifdef ENABLE_RPC_ACTOR
#include "runtime/graph_scheduler/rpc_node_scheduler.h"
#endif
namespace mindspore {
namespace runtime {
using mindspore::device::DeviceContext;
@ -208,8 +211,10 @@ class GraphScheduler {
// In the control flow, used to build and link control actor.
ControlNodeScheduler control_node_scheduler_;
#ifdef ENABLE_RPC_ACTOR
// Used to build and link for rpc actors.
std::unique_ptr<RpcNodeScheduler> rpc_node_scheduler_{nullptr};
#endif
// The id of global actor.
AID memory_manager_aid_;

View File

@ -26,10 +26,21 @@ void RpcNodeScheduler::Initialize() {
RpcActorSetPtr RpcNodeScheduler::Build(const GraphCompilerInfo &) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
std::vector<RpcActorPtr> rpc_actors;
(void)rpc_actors.insert(rpc_actors.end(), rpc_actor_set_->send_actors_.begin(), rpc_actor_set_->send_actors_.end());
(void)rpc_actors.insert(rpc_actors.end(), rpc_actor_set_->recv_actors_.begin(), rpc_actor_set_->recv_actors_.end());
// Create route table proxy for each rpc actor and set.
for (auto &rpc_actor : rpc_actors) {
auto proxy = CreateRouteTableProxy();
MS_EXCEPTION_IF_NULL(proxy);
rpc_actor->SetActorRouteRableProxy(proxy);
}
return rpc_actor_set_;
}
void RpcNodeScheduler::Link(const ActorSetPtr &) {
void RpcNodeScheduler::Link(const ActorSet *) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
std::vector<SendActorPtr> send_actors = rpc_actor_set_->send_actors_;
std::vector<RecvActorPtr> recv_actors = rpc_actor_set_->recv_actors_;
@ -50,6 +61,7 @@ void RpcNodeScheduler::Link(const ActorSetPtr &) {
<< ", send_src_node_name: " << send_src_node_name
<< ", send_dst_node_name: " << send_dst_node_name;
}
send_actor->SetInterProcessEdgeName(send_src_node_name, send_dst_node_name);
send_actor->SetRouteInfo(send_dst_ranks[0], send_dst_roles[0], send_src_node_name, send_dst_node_name);
}
for (auto &recv_actor : recv_actors) {
@ -67,20 +79,66 @@ void RpcNodeScheduler::Link(const ActorSetPtr &) {
<< ", recv_src_node_name: " << recv_src_node_name
<< ", recv_dst_node_name: " << recv_dst_node_name;
}
recv_actor->SetInterProcessEdgeName(recv_src_node_name, recv_dst_node_name);
recv_actor->SetRouteInfo(recv_src_ranks[0], recv_src_roles[0], recv_src_node_name, recv_dst_node_name);
}
}
void RpcNodeScheduler::Schedule() {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
// Must start server and register route table before looking up route and connecting.
// Start servers of recv actors and register route table.
for (auto &recv_actor : rpc_actor_set_->recv_actors_) {
MS_EXCEPTION_IF_NULL(recv_actor);
if (!recv_actor->StartServer()) {
MS_LOG(EXCEPTION) << "Failed to start server for the recv actor.";
}
}
// Lookup route and connect to servers for send actors.
for (auto &send_actor : rpc_actor_set_->send_actors_) {
MS_EXCEPTION_IF_NULL(send_actor);
if (!send_actor->ConnectServer()) {
MS_LOG(EXCEPTION) << "Failed to connect servers for the send actor.";
}
}
}
void RpcNodeScheduler::InsertSendActor(const SendActorPtr &send_actor) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
MS_EXCEPTION_IF_NULL(send_actor);
rpc_actor_set_->send_actors_.emplace_back(send_actor);
(void)rpc_actor_set_->send_actors_.emplace_back(send_actor);
}
void RpcNodeScheduler::InsertRecvActor(const RecvActorPtr &recv_actor) {
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
MS_EXCEPTION_IF_NULL(recv_actor);
rpc_actor_set_->recv_actors_.emplace_back(recv_actor);
(void)rpc_actor_set_->recv_actors_.emplace_back(recv_actor);
}
void RpcNodeScheduler::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
MS_EXCEPTION_IF_NULL(op_context);
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
for (auto &recv_actor : rpc_actor_set_->recv_actors_) {
MS_EXCEPTION_IF_NULL(recv_actor);
recv_actor->SetOpcontext(op_context);
}
for (auto &send_actor : rpc_actor_set_->send_actors_) {
MS_EXCEPTION_IF_NULL(send_actor);
send_actor->SetOpcontext(op_context);
}
}
ActorRouteTableProxyPtr RpcNodeScheduler::CreateRouteTableProxy() {
ActorRouteTableProxyPtr actor_route_table_proxy;
if (!ClusterContext::instance()->IsScheduler()) {
auto node = ClusterContext::instance()->node();
actor_route_table_proxy =
std::make_shared<ActorRouteTableProxy>(std::dynamic_pointer_cast<ps::core::AbstractNode>(node));
MS_EXCEPTION_IF_NULL(actor_route_table_proxy);
}
return actor_route_table_proxy;
}
} // namespace runtime
} // namespace mindspore

View File

@ -32,7 +32,7 @@ using mindspore::session::KernelWithIndex;
// Scheduler for rpc actors, e.g., it adds inter-process arrows, generate router for actors, etc.
class RpcNodeScheduler {
public:
RpcNodeScheduler() = default;
RpcNodeScheduler() : rpc_actor_set_(nullptr) {}
~RpcNodeScheduler() = default;
// Create rpc actor set.
@ -42,13 +42,23 @@ class RpcNodeScheduler {
RpcActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
// Link rpc actors with inter-process arrows.
void Link(const ActorSetPtr &actor_set);
void Link(const ActorSet *actor_set);
// This should be called by 'GraphScheduler::Scheduler()' method.
// Used to start servers for recv actors and create connections for send actors.
void Schedule();
// Insert Send/Recv actors generated by GraphScheduler to the rpc actor set.
void InsertSendActor(const SendActorPtr &send_actor);
void InsertRecvActor(const RecvActorPtr &recv_actor);
// Set op_context to rpc actors.
void SetOpcontext(OpContext<DeviceTensor> *const op_context);
private:
// Create new route table proxy.
ActorRouteTableProxyPtr CreateRouteTableProxy();
RpcActorSetPtr rpc_actor_set_;
};
} // namespace runtime

View File

@ -180,6 +180,8 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/ps/*.cc"
"../../../mindspore/ccsrc/fl/*.cc"
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_service.cc"
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_proxy.cc"
"../../../mindspore/ccsrc/distributed/cluster/cluster_context.cc"
"../../../mindspore/ccsrc/distributed/persistent/*.cc"
"../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc"
"../../../mindspore/ccsrc/profiler/device/ascend/*.cc"