forked from mindspore-Ecosystem/mindspore
!30611 Fix rpc route bugs
Merge pull request !30611 from ZPaC/sync-route-table
This commit is contained in:
commit
4367377000
|
@ -80,6 +80,9 @@ ActorAddress ActorRouteTableProxy::LookupRoute(const std::string &actor_id) cons
|
|||
lookup_success = true;
|
||||
}
|
||||
} while (!lookup_success && CURRENT_TIMESTAMP_MILLI <= timeout_ts);
|
||||
if (!lookup_success) {
|
||||
MS_LOG(EXCEPTION) << "Failed to lookup actor address for " << actor_id;
|
||||
}
|
||||
|
||||
return lookup_route_rsp_msg;
|
||||
}
|
||||
|
|
|
@ -31,15 +31,19 @@ using ps::core::ActorAddress;
|
|||
using ps::core::GeneralResponseMsg;
|
||||
using ps::core::NodeCommand;
|
||||
|
||||
// The timeout in milliseconds for one lookup.
|
||||
constexpr uint32_t kDefaultLookupTimeout = 5000;
|
||||
|
||||
// The time in milliseconds between two lookup operations.
|
||||
constexpr auto kLookupInterval = 100;
|
||||
constexpr uint32_t kLookupInterval = 100;
|
||||
|
||||
// Actor route table proxy for nodes like workers and server. This class helps update actor route table in scheduler
|
||||
// across the network.
|
||||
class ActorRouteTableProxy {
|
||||
public:
|
||||
explicit ActorRouteTableProxy(const std::shared_ptr<ps::core::AbstractNode> &node, uint32_t lookup_timout)
|
||||
: node_(node), lookup_timeout_(std::chrono::milliseconds(lookup_timout)) {}
|
||||
explicit ActorRouteTableProxy(const std::shared_ptr<ps::core::AbstractNode> &node,
|
||||
uint32_t lookup_timeout = kDefaultLookupTimeout)
|
||||
: node_(node), lookup_timeout_(std::chrono::milliseconds(lookup_timeout)) {}
|
||||
~ActorRouteTableProxy() = default;
|
||||
|
||||
// Register actor address to the route table stored in scheduler.
|
||||
|
@ -55,9 +59,11 @@ class ActorRouteTableProxy {
|
|||
// The node variable helps proxy to communicate with scheduler, e.g., SendMessage.
|
||||
std::shared_ptr<ps::core::AbstractNode> node_;
|
||||
|
||||
// The timeout window for lookup route operation because time of route lookup_timout of each process is different.
|
||||
// The timeout window for lookup route operation because time of route lookup_timeout of each process is different.
|
||||
std::chrono::milliseconds lookup_timeout_;
|
||||
};
|
||||
|
||||
using ActorRouteTableProxyPtr = std::shared_ptr<ActorRouteTableProxy>;
|
||||
} // namespace cluster
|
||||
} // namespace distributed
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,7 @@ ClusterContext::ClusterContext()
|
|||
scheduler_host_(kLocalHost),
|
||||
scheduler_port_(kDefaultSchedPort),
|
||||
node_(nullptr),
|
||||
abstract_node_(nullptr),
|
||||
node_role_(""),
|
||||
cluster_config_(nullptr) {}
|
||||
|
||||
|
@ -84,6 +85,14 @@ bool ClusterContext::Initialize() {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Step 3: Initialize some modules for the node, e.g., actor route table proxy.
|
||||
if (!IsScheduler()) {
|
||||
// Only node which is not the scheduler needs route table proxy.
|
||||
actor_route_table_proxy_ =
|
||||
std::make_shared<ActorRouteTableProxy>(std::dynamic_pointer_cast<ps::core::AbstractNode>(node_));
|
||||
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
|
||||
}
|
||||
|
||||
inited_ = true;
|
||||
finalized_ = false;
|
||||
return true;
|
||||
|
@ -105,6 +114,8 @@ bool ClusterContext::Finalize(uint32_t timeout) {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Reports whether this process is treated as the scheduler, i.e. whether abstract_node_ is still unset.
bool ClusterContext::IsScheduler() { return abstract_node_ == nullptr; }
|
||||
|
||||
const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }
|
||||
|
||||
const std::string &ClusterContext::node_role() const { return node_role_; }
|
||||
|
@ -120,6 +131,8 @@ uint32_t ClusterContext::node_num(const std::string &node_role) {
|
|||
|
||||
bool ClusterContext::initialized() const { return inited_; }
|
||||
|
||||
const ActorRouteTableProxyPtr &ClusterContext::actor_route_table_proxy() const { return actor_route_table_proxy_; }
|
||||
|
||||
void ClusterContext::InitClusterConfig() {
|
||||
InitNodeRole();
|
||||
InitSchedulerIp();
|
||||
|
@ -154,6 +167,7 @@ bool ClusterContext::BuildCluster() {
|
|||
MS_LOG(ERROR) << "Building network failed.";
|
||||
return false;
|
||||
}
|
||||
abstract_node_ = std::dynamic_pointer_cast<ps::core::AbstractNode>(node_);
|
||||
MS_LOG(INFO) << "Cluster is successfully initialized.";
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "ps/core/worker_node.h"
|
||||
#include "ps/core/server_node.h"
|
||||
#include "ps/core/scheduler_node.h"
|
||||
#include "distributed/cluster/actor_route_table_proxy.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace distributed {
|
||||
|
@ -57,6 +58,10 @@ class ClusterContext {
|
|||
// Finalize the cluster and process exits. If timeout is set to UINT32_MAX, this method will block without timeout.
|
||||
bool Finalize(uint32_t timeout = kDefaultFinishTimeout);
|
||||
|
||||
// Return whether this node is the scheduler node.
|
||||
// In a cluster, the scheduler node is special because it's responsible for building network.
|
||||
bool IsScheduler();
|
||||
|
||||
// Return node object of this process.
|
||||
const std::shared_ptr<ps::core::Node> &node() const;
|
||||
|
||||
|
@ -69,6 +74,9 @@ class ClusterContext {
|
|||
// Return cluster is initialized.
|
||||
bool initialized() const;
|
||||
|
||||
// Return actor route proxy for AbstractNode.
|
||||
const ActorRouteTableProxyPtr &actor_route_table_proxy() const;
|
||||
|
||||
private:
|
||||
ClusterContext();
|
||||
|
||||
|
@ -106,11 +114,17 @@ class ClusterContext {
|
|||
// The node could be Worker, Server or Scheduler, etc.
|
||||
std::shared_ptr<ps::core::Node> node_;
|
||||
|
||||
// abstract_node_ is nullptr only when this node is the scheduler.
|
||||
std::shared_ptr<ps::core::AbstractNode> abstract_node_;
|
||||
|
||||
// The role of this process in the cluster.
|
||||
std::string node_role_;
|
||||
|
||||
// The configuration of this cluster.
|
||||
std::unique_ptr<ps::core::ClusterConfig> cluster_config_;
|
||||
|
||||
// The actor route table proxy. It is only created in abstract nodes because the scheduler does not use a proxy.
|
||||
ActorRouteTableProxyPtr actor_route_table_proxy_;
|
||||
};
|
||||
} // namespace cluster
|
||||
} // namespace distributed
|
||||
|
|
|
@ -275,12 +275,13 @@ CNodePtr GraphSplitter::GenerateSendNode(const AnfNodePtr &input, const AnfNodeP
|
|||
MS_EXCEPTION_IF_NULL(peer);
|
||||
|
||||
std::vector<AnfNodePtr> send_inputs = {NewValueNode(std::make_shared<Primitive>(kRpcSendOpName))};
|
||||
auto mock_value = GenerateMockValueNode(true, input);
|
||||
MS_EXCEPTION_IF_NULL(mock_value);
|
||||
ValueNodePtr mock_value = nullptr;
|
||||
if (IsPrimitiveCNode(input, prim::kPrimUpdateState)) {
|
||||
mock_value = GenerateMockValueNode(false);
|
||||
send_inputs.push_back(mock_value);
|
||||
send_inputs.push_back(input);
|
||||
} else {
|
||||
mock_value = GenerateMockValueNode(true, input);
|
||||
send_inputs.push_back(input);
|
||||
}
|
||||
CNodePtr send_node = func_graph_->NewCNode(send_inputs);
|
||||
|
|
|
@ -30,7 +30,6 @@
|
|||
#include "include/common/utils/anfalgo.h"
|
||||
#include "kernel/common_utils.h"
|
||||
#include "ir/anf.h"
|
||||
#include "runtime/graph_scheduler/graph_scheduler.h"
|
||||
#include "actor/actormgr.h"
|
||||
#include "include/common/thread_pool.h"
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RPC_RPC_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
|
||||
|
@ -35,6 +36,9 @@ class RpcKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override { return; }
|
||||
|
||||
// Set remote data as input.
|
||||
void SetRemoteInput(std::unique_ptr<MessageBase> &&) {}
|
||||
|
||||
private:
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -312,10 +312,12 @@ bool AbstractNode::SendToScheduler(const void *message, size_t len, NodeCommand
|
|||
// Assign the response value from scheduler.
|
||||
if (output != nullptr) {
|
||||
if (received_scheduler_messages_.count(request_id) == 0) {
|
||||
MS_LOG(ERROR) << "The response message of " << node_cmd << " is not received yet.";
|
||||
MS_LOG(ERROR) << "The response message of command " << node_cmd << ", request_id " << request_id
|
||||
<< " is not received yet.";
|
||||
return false;
|
||||
}
|
||||
*output = received_scheduler_messages_[request_id];
|
||||
(void)received_scheduler_messages_.erase(request_id);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -107,9 +107,6 @@ bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
|
|||
if (receive_messages_.count(request_id) != 0) {
|
||||
(void)receive_messages_.erase(request_id);
|
||||
}
|
||||
if (received_scheduler_messages_.count(request_id) != 0) {
|
||||
(void)received_scheduler_messages_.erase(request_id);
|
||||
}
|
||||
msgs_lock.unlock();
|
||||
return res;
|
||||
}
|
||||
|
|
|
@ -3,6 +3,13 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core/mindrt/src)
|
|||
|
||||
file(GLOB_RECURSE GRAPH_SCHEDULER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
|
||||
if(NOT ENABLE_CPU OR WIN32)
|
||||
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "rpc_node_scheduler.cc")
|
||||
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "actor/rpc/recv_actor.cc")
|
||||
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "actor/rpc/rpc_actor.cc")
|
||||
list(REMOVE_ITEM GRAPH_SCHEDULER_SRC_LIST "actor/rpc/send_actor.cc")
|
||||
endif()
|
||||
|
||||
set_property(SOURCE ${GRAPH_SCHEDULER_SRC_LIST}
|
||||
PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_RUNTIME_FRAMEWORK)
|
||||
add_library(_mindspore_runtime_graph_scheduler_obj OBJECT ${GRAPH_SCHEDULER_SRC_LIST})
|
|
@ -17,6 +17,10 @@
|
|||
#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SET_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SET_H_
|
||||
|
||||
#if ((defined ENABLE_CPU) && (!defined _WIN32))
|
||||
#define ENABLE_RPC_ACTOR
|
||||
#endif
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
@ -30,8 +34,6 @@
|
|||
#include "runtime/graph_scheduler/actor/loop_count_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/kernel_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/custom_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/super_kernel_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/output_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/copy_actor.h"
|
||||
|
@ -41,6 +43,11 @@
|
|||
#include "runtime/graph_scheduler/actor/control_flow/exit_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/control_flow/stack_actor.h"
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using ActorInfo = std::string;
|
||||
|
@ -63,6 +70,7 @@ struct ControlActorSet {
|
|||
};
|
||||
using ControlActorSetPtr = std::shared_ptr<ControlActorSet>;
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
// Rpc actor set is a series of actors implemented to communicate with other processes. In distributed execution mode,
|
||||
// the graph could be considered as partitioned to different processes, which is connected by these rpc actors. Send
|
||||
// actors are in charge of sending data to other processes. Recv actors are in charge of receiving data from other
|
||||
|
@ -72,6 +80,7 @@ struct RpcActorSet {
|
|||
std::vector<RecvActorPtr> recv_actors_;
|
||||
};
|
||||
using RpcActorSetPtr = std::shared_ptr<RpcActorSet>;
|
||||
#endif
|
||||
|
||||
// The actor set generated by graph transformer is the execution unit of actor runtime.
|
||||
// It includes data source actor, kernel actor, switch actor, copy actor, loop count actor and output actor.
|
||||
|
@ -98,7 +107,9 @@ struct ActorSet {
|
|||
LoopCountActorPtr loop_count_actor_{nullptr};
|
||||
OutputActorPtr output_actor_{nullptr};
|
||||
ControlActorSetPtr control_actors_{nullptr};
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
RpcActorSetPtr rpc_actors_{nullptr};
|
||||
#endif
|
||||
ActorInfo name_;
|
||||
// The related statistics information of multi thread and single thread to decide whether use the multi thread.
|
||||
bool is_multi_thread_execution_{true};
|
||||
|
|
|
@ -50,8 +50,8 @@ class KernelActor : public DebugAwareActor {
|
|||
const std::set<size_t> &modifiable_ref_output_indexes,
|
||||
const KernelTransformType &type = KernelTransformType::kKernelActor)
|
||||
: DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid),
|
||||
kernel_(kernel),
|
||||
kernel_info_(nullptr),
|
||||
kernel_(kernel),
|
||||
is_dynamic_shape_(false),
|
||||
real_input_num_(0),
|
||||
strategy_(strategy),
|
||||
|
@ -85,6 +85,10 @@ class KernelActor : public DebugAwareActor {
|
|||
void Run(OpContext<DeviceTensor> *const context) override;
|
||||
void SendRecorderInfo(OpContext<DeviceTensor> *const context) const override;
|
||||
|
||||
KernelInfo *kernel_info_;
|
||||
// The kernel launch info is fetched by the device tensors.
|
||||
KernelLaunchInfo launch_info_;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
friend class ControlNodeScheduler;
|
||||
|
@ -111,7 +115,6 @@ class KernelActor : public DebugAwareActor {
|
|||
|
||||
// The info of kernel.
|
||||
CNodePtr kernel_;
|
||||
KernelInfo *kernel_info_;
|
||||
bool is_dynamic_shape_;
|
||||
|
||||
// The real input number of kernel launch.
|
||||
|
@ -138,9 +141,6 @@ class KernelActor : public DebugAwareActor {
|
|||
// The device tensor of external reference is not the real data of this kernel, but need add to the memory_free_list_.
|
||||
std::vector<DeviceTensor *> external_reference_tensors_;
|
||||
|
||||
// The kernel launch info is fetched by the device tensors.
|
||||
KernelLaunchInfo launch_info_;
|
||||
|
||||
// Record the modifiable ref indexes. Used to refresh the ref data which are modified in the running.
|
||||
std::set<size_t> modifiable_ref_input_indexes_;
|
||||
std::set<size_t> modifiable_ref_output_indexes_;
|
||||
|
|
|
@ -16,10 +16,94 @@
|
|||
|
||||
#include "runtime/graph_scheduler/actor/rpc/recv_actor.h"
|
||||
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &src_node_name, const std::string &) {
|
||||
input_peer_node_name_.emplace_back(src_node_name);
|
||||
void RecvActor::RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(msg);
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);
|
||||
auto &sequential_num = context->sequential_num_;
|
||||
(void)input_op_inter_process_[sequential_num].emplace_back(msg->From().Name());
|
||||
|
||||
auto is_run = CheckRunningCondition(context);
|
||||
MS_LOG(INFO) << "Actor(" << GetAID().Name() << ") receive the input op inter-process. Edge is "
|
||||
<< inter_process_edge_name_ << ". Check running condition:" << is_run;
|
||||
|
||||
// Parse the message from remote peer and set to rpc recv kernel.
|
||||
auto recv_kernel_mod = dynamic_cast<kernel::RpcKernelMod *>(kernel_info_->MutableKernelMod());
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(recv_kernel_mod);
|
||||
// We set remote data by the interface of the rpc kernel, because currently there's no remote input for a kernel mod.
|
||||
recv_kernel_mod->SetRemoteInput(std::move(msg));
|
||||
|
||||
if (is_run) {
|
||||
Run(context);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &recv_src_node_name,
|
||||
const std::string &recv_dst_node_name) {
|
||||
rpc_input_node_name_.emplace_back(recv_src_node_name);
|
||||
input_inter_process_num_++;
|
||||
}
|
||||
|
||||
bool RecvActor::StartServer() {
|
||||
// Step 1: Create a tcp server and start listening.
|
||||
std::string server_url = ip_ + ":" + std::to_string(port_);
|
||||
server_ = std::make_unique<TCPServer>();
|
||||
MS_EXCEPTION_IF_NULL(server_);
|
||||
if (!server_->Initialize(server_url)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor. Server url: " << server_url;
|
||||
}
|
||||
|
||||
// Step 2: Set the message handler of the server.
|
||||
|
||||
// Step 2: Register the server address to route table. The server should not be connected before this step is done.
|
||||
ActorAddress recv_actor_addresss;
|
||||
recv_actor_addresss.set_actor_id(inter_process_edge_name_);
|
||||
recv_actor_addresss.set_ip(ip_);
|
||||
recv_actor_addresss.set_port(static_cast<uint32_t>(port_));
|
||||
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
|
||||
if (!actor_route_table_proxy_->RegisterRoute(inter_process_edge_name_, recv_actor_addresss)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to register route for " << inter_process_edge_name_ << " " << server_url
|
||||
<< " when starting server.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RecvActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
// Step 1: Judge data and control inputs are satisfied.
|
||||
bool is_data_and_control_arrow_satisfied = AbstractActor::CheckRunningCondition(context);
|
||||
if (!is_data_and_control_arrow_satisfied) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input_inter_process_num_ != 0) {
|
||||
// Step 2: Judge inter-process inputs are satisfied.
|
||||
const auto &inter_process_iter = input_op_inter_process_.find(context->sequential_num_);
|
||||
if (inter_process_iter == input_op_inter_process_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto ¤t_inter_process_inputs = inter_process_iter->second;
|
||||
if (current_inter_process_inputs.size() < input_inter_process_num_) {
|
||||
return false;
|
||||
} else if (current_inter_process_inputs.size() > input_inter_process_num_) {
|
||||
MS_LOG(ERROR) << "Invalid inter process input num:" << current_inter_process_inputs.size()
|
||||
<< " need:" << input_inter_process_num_ << " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Entry point for a message received from a remote peer: forwards it to RunOpInterProcessData() together with the
// op context stored in op_context_. The message is dropped when it, or the stored op context, is null.
void RecvActor::HandleMessage(std::unique_ptr<MessageBase> &&msg) {
  // An inter-process message carries no op context of its own, so the stored one is required.
  MS_ERROR_IF_NULL_WO_RET_VAL(msg);
  MS_ERROR_IF_NULL_WO_RET_VAL(op_context_);

  RunOpInterProcessData(std::move(msg), op_context_);
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,11 +36,31 @@ class RecvActor : public RpcActor {
|
|||
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor) {}
|
||||
~RecvActor() override = default;
|
||||
|
||||
void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &src_node_name,
|
||||
const std::string &dst_node_name) override;
|
||||
// When an inter-process data received, this method is called.
|
||||
void RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context);
|
||||
|
||||
// Set recv actor's source peer info, in other words, recv actor's input.
|
||||
void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &recv_src_node_name,
|
||||
const std::string &recv_dst_node_name) override;
|
||||
|
||||
// Start recv actor server and register this server address to actor route table in scheduler by proxy.
|
||||
bool StartServer();
|
||||
|
||||
protected:
|
||||
// Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked for
|
||||
// recv actor.
|
||||
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
|
||||
|
||||
private:
|
||||
void HandleMessage(std::unique_ptr<MessageBase> &&msg);
|
||||
|
||||
friend class GraphScheduler;
|
||||
|
||||
// The network address of this recv actor. It's generated automatically by rpc module.
|
||||
std::string ip_;
|
||||
uint16_t port_;
|
||||
|
||||
std::unique_ptr<TCPServer> server_;
|
||||
};
|
||||
|
||||
using RecvActorPtr = std::shared_ptr<RecvActor>;
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// Builds the inter-process edge name as "<src_node_name><kInterProcessEdgeMark><dst_node_name>" and stores it in
// inter_process_edge_name_.
void RpcActor::SetInterProcessEdgeName(const std::string &src_node_name, const std::string &dst_node_name) {
  std::string edge_name(src_node_name);
  edge_name.append(kInterProcessEdgeMark).append(dst_node_name);
  inter_process_edge_name_ = std::move(edge_name);
}
|
||||
|
||||
// Stores the given op context pointer in op_context_. The pointer is not owned and is not validated here.
void RpcActor::SetOpcontext(OpContext<DeviceTensor> *const op_context) { op_context_ = op_context; }
|
||||
|
||||
// Stores the actor route table proxy (a shared pointer copy) in actor_route_table_proxy_.
// NOTE(review): "Rable" in the method name looks like a typo of "Table"; it is part of the public interface.
void RpcActor::SetActorRouteRableProxy(const ActorRouteTableProxyPtr &proxy) { actor_route_table_proxy_ = proxy; }
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -23,10 +23,22 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include "runtime/graph_scheduler/actor/kernel_actor.h"
|
||||
#include "distributed/cluster/cluster_context.h"
|
||||
#include "distributed/rpc/tcp/tcp_client.h"
|
||||
#include "distributed/rpc/tcp/tcp_server.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using distributed::cluster::ActorRouteTableProxy;
|
||||
using distributed::cluster::ActorRouteTableProxyPtr;
|
||||
using distributed::cluster::ClusterContext;
|
||||
using distributed::rpc::TCPClient;
|
||||
using distributed::rpc::TCPServer;
|
||||
using mindspore::device::KernelInfo;
|
||||
using ps::core::ActorAddress;
|
||||
|
||||
// The inter-process edge mark between two nodes.
|
||||
constexpr char kInterProcessEdgeMark[] = "->";
|
||||
|
||||
// RpcActor is used to do rpc with other processes in distributed execution.
|
||||
// Besides data arrows and controlling arrows, RpcActor also has inter-process arrows which is in charge of remote
|
||||
|
@ -38,30 +50,50 @@ class RpcActor : public KernelActor {
|
|||
GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
|
||||
const std::set<size_t> &modifiable_ref_output_indexes, const KernelTransformType &type)
|
||||
: KernelActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
|
||||
modifiable_ref_input_indexes, modifiable_ref_output_indexes, type) {}
|
||||
modifiable_ref_input_indexes, modifiable_ref_output_indexes, type),
|
||||
input_inter_process_num_(0) {}
|
||||
virtual ~RpcActor() = default;
|
||||
|
||||
// Normally, an actor's op_context is passed by its input actor, but rpc actors could be triggered by inter-process
|
||||
// arrows which do not contain op_context. So we need to set op_context manually.
|
||||
void SetOpcontext(OpContext<DeviceTensor> *const op_context);
|
||||
|
||||
// Set the actor route proxy for rpc actors.
|
||||
void SetActorRouteRableProxy(const ActorRouteTableProxyPtr &proxy);
|
||||
|
||||
// Set the inter-process edge name for rpc actor.
|
||||
void SetInterProcessEdgeName(const std::string &src_node_name, const std::string &dst_node_name);
|
||||
|
||||
// Set some info which will be used for rpc routing.
|
||||
virtual void SetRouteInfo(uint32_t peer_rank, const std::string &peer_role, const std::string &src_node_name,
|
||||
const std::string &dst_node_name) {}
|
||||
|
||||
// When an inter-process data received, this method is called.
|
||||
void RunOpInterProcessData(std::unique_ptr<MessageBase> &&msg, OpContext<DeviceTensor> *const context);
|
||||
|
||||
protected:
|
||||
// Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked.
|
||||
bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override { return true; }
|
||||
// The op context to run rpc actor inter-process op.
|
||||
OpContext<DeviceTensor> *op_context_;
|
||||
|
||||
// After rpc kernel is launched, inter-process data could be sent.
|
||||
void SendOutput(OpContext<DeviceTensor> *const context) override {}
|
||||
// The inter-process edge name. It is also used as the actor id for route. It's a string consists of source node name
|
||||
// and destination node name. The format is "source node name"->"destination node name". For each inter-process edge,
|
||||
// this name is unique. Rpc actors with the same inter_process_edge_name_ should not be in the same process.
|
||||
std::string inter_process_edge_name_;
|
||||
|
||||
// The node name of rpc actor's peers. They are not the name of send or recv nodes. Instead, they are the names of the
|
||||
// nodes which use send node as output and recv node as input.
|
||||
std::vector<std::string> rpc_input_node_name_;
|
||||
std::vector<std::string> rpc_output_node_name_;
|
||||
|
||||
// The number of inter-process inputs. This should be the same as the size of vector rpc_input_node_name_.
|
||||
size_t input_inter_process_num_;
|
||||
|
||||
// The inter-process inputs of each sequential number.
|
||||
mindspore::HashMap<int, std::vector<std::string>> input_op_inter_process_;
|
||||
|
||||
// The node name of rpc actor's peers.
|
||||
std::vector<std::string> input_peer_node_name_;
|
||||
std::vector<std::string> output_peer_node_name_;
|
||||
// The arrows represent inter-process communication.
|
||||
std::vector<AID> inter_process_input_arrows_;
|
||||
std::vector<AID> inter_process_output_arrows_;
|
||||
|
||||
ActorRouteTableProxyPtr actor_route_table_proxy_;
|
||||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
};
|
||||
|
|
|
@ -16,10 +16,69 @@
|
|||
|
||||
#include "runtime/graph_scheduler/actor/rpc/send_actor.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &, const std::string &dst_node_name) {
|
||||
output_peer_node_name_.emplace_back(dst_node_name);
|
||||
void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &send_src_node_name,
|
||||
const std::string &send_dst_node_name) {
|
||||
auto peer_actor_id = send_src_node_name + kInterProcessEdgeMark + send_dst_node_name;
|
||||
peer_actor_ids_.emplace_back(peer_actor_id);
|
||||
rpc_output_node_name_.emplace_back(send_dst_node_name);
|
||||
}
|
||||
|
||||
bool SendActor::ConnectServer() {
|
||||
client_ = std::make_unique<TCPClient>();
|
||||
MS_EXCEPTION_IF_NULL(client_);
|
||||
if (!client_->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for send actor.";
|
||||
}
|
||||
// Lookup actor addresses for each peer actor.
|
||||
for (const auto &peer_actor_id : peer_actor_ids_) {
|
||||
MS_EXCEPTION_IF_NULL(actor_route_table_proxy_);
|
||||
auto peer_actor_address = actor_route_table_proxy_->LookupRoute(peer_actor_id);
|
||||
// If route is successfully looked up, peer_actor_address is not empty.
|
||||
std::string server_url = peer_actor_address.ip() + ":" + std::to_string(peer_actor_address.port());
|
||||
if (!client_->Connect(server_url)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to connect to server of actor " << peer_actor_id << ", server_url: " << server_url;
|
||||
}
|
||||
peer_actor_urls_[peer_actor_id] = server_url;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Sends the outputs of this actor after the rpc send kernel has been launched:
// data/control outputs go through AbstractActor::SendOutput(), then the first kernel output is wrapped into one
// message per connected peer and sent through the tcp client.
void SendActor::SendOutput(OpContext<DeviceTensor> *const context) {
  MS_ERROR_IF_NULL_WO_RET_VAL(context);
  MS_ERROR_IF_NULL_WO_RET_VAL(client_);
  // Step 1: Send data and control outputs.
  AbstractActor::SendOutput(context);

  // Step 2: Erase inter-process inputs for this sequential number. Erasing by key is a no-op when the key is absent,
  // so no separate existence check is needed.
  (void)input_op_inter_process_.erase(context->sequential_num_);

  // Step 3: Send output data(inter-process data) to peers.
  if (launch_info_.outputs_.empty()) {
    MS_LOG(ERROR) << "Send kernel has no output tensor.";
    return;
  }
  const auto &send_output = launch_info_.outputs_[0];
  for (const auto &peer : peer_actor_urls_) {
    // peer.second is the peer's server url; bind to it directly instead of copying the string.
    auto message = BuildRpcMessage(send_output, peer.second);
    MS_ERROR_IF_NULL_WO_RET_VAL(message);
    client_->Send(std::move(message));
  }
}
|
||||
|
||||
// Wraps a kernel output buffer into a MessageBase addressed to 'server_url'.
// The message body is a byte copy of data->size bytes starting at data->addr.
// Returns nullptr when 'data' is null.
std::unique_ptr<MessageBase> SendActor::BuildRpcMessage(const kernel::AddressPtr &data, const std::string &server_url) {
  MS_ERROR_IF_NULL_W_RET_VAL(data, nullptr);

  auto message = std::make_unique<MessageBase>();
  MS_ERROR_IF_NULL_W_RET_VAL(message, nullptr);

  // Only the url part of the destination AID is filled in; the actor name is left empty.
  message->to = AID("", server_url);
  const char *const payload = static_cast<char *>(data->addr);
  message->body.assign(payload, data->size);
  return message;
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,11 +36,28 @@ class SendActor : public RpcActor {
|
|||
modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor) {}
|
||||
~SendActor() override = default;
|
||||
|
||||
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &src_node_name,
|
||||
const std::string &dst_node_name) override;
|
||||
// Set send actor's destination peer info, in other words, send actor's output.
|
||||
void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
|
||||
const std::string &send_dst_node_name) override;
|
||||
|
||||
// Lookup peer actors' route and create connection to them.
|
||||
bool ConnectServer();
|
||||
|
||||
protected:
|
||||
// After rpc send kernel is launched, inter-process data should be sent.
|
||||
void SendOutput(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
// The client only supports sending MessageBase, so build a MessageBase from the data and the url.
|
||||
std::unique_ptr<MessageBase> BuildRpcMessage(const kernel::AddressPtr &data, const std::string &server_url);
|
||||
|
||||
friend class GraphScheduler;
|
||||
|
||||
// This send actor's destination peers' actor ids and route table.
|
||||
std::vector<std::string> peer_actor_ids_;
|
||||
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
|
||||
|
||||
std::unique_ptr<TCPClient> client_;
|
||||
};
|
||||
|
||||
using SendActorPtr = std::shared_ptr<SendActor>;
|
||||
|
|
|
@ -282,6 +282,8 @@ void GraphScheduler::Initialize() {
|
|||
&GraphScheduler::LinkDataArrowForDeviceTensorStore);
|
||||
(void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
|
||||
&GraphScheduler::LinkDataArrowForInternalParameter);
|
||||
(void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSendActor, &GraphScheduler::LinkDataArrowForBaseActor);
|
||||
(void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kRecvActor, &GraphScheduler::LinkDataArrowForBaseActor);
|
||||
|
||||
// Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
|
||||
size_t actor_thread_num = 0;
|
||||
|
@ -297,10 +299,12 @@ void GraphScheduler::Initialize() {
|
|||
MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
|
||||
<< ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num);
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
// Create and initialize RpcNodeScheduler.
|
||||
rpc_node_scheduler_ = std::make_unique<RpcNodeScheduler>();
|
||||
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
|
||||
rpc_node_scheduler_->Initialize();
|
||||
#endif
|
||||
|
||||
BuildAndScheduleGlobalActor();
|
||||
}
|
||||
|
@ -393,6 +397,12 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) {
|
|||
for (auto actor : actors) {
|
||||
(void)actor_manager->Spawn(actor);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
// Build physical connections in 'RpcNodeScheduler::Schedule()' method. This costs some time.
|
||||
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
|
||||
rpc_node_scheduler_->Schedule();
|
||||
#endif
|
||||
}
|
||||
|
||||
void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceContext *> &device_contexts,
|
||||
|
@ -410,6 +420,12 @@ void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceCont
|
|||
op_context.sequential_num_ = RandInt::Instance().Get();
|
||||
op_context.results_ = &result;
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
// Set OpContext to rpc node scheduler.
|
||||
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
|
||||
rpc_node_scheduler_->SetOpcontext(&op_context);
|
||||
#endif
|
||||
|
||||
if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
|
||||
actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context, GraphExecutionStrategy::kStep);
|
||||
MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
|
||||
|
@ -543,8 +559,10 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
|
|||
BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
|
||||
actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info, memory_manager_aid_);
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
|
||||
actor_set->rpc_actors_ = rpc_node_scheduler_->Build(graph_compiler_info);
|
||||
#endif
|
||||
return actor_set;
|
||||
}
|
||||
|
||||
|
@ -622,6 +640,12 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
|
|||
graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited()) {
|
||||
control_node_scheduler_.Link(actor_set, graph_compiler_info);
|
||||
}
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
// Link inter-process arrows for rpc actors.
|
||||
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
|
||||
rpc_node_scheduler_->Link(actor_set);
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
|
||||
|
@ -959,6 +983,7 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
|
|||
const std::set<size_t> &ref_output_indexes) {
|
||||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
MS_EXCEPTION_IF_NULL(device_context);
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
MS_EXCEPTION_IF_NULL(rpc_node_scheduler_);
|
||||
if (common::AnfAlgo::GetCNodeName(kernel) == kRpcSendOpName) {
|
||||
auto send_actor =
|
||||
|
@ -977,6 +1002,7 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
|
|||
} else {
|
||||
MS_LOG(EXCEPTION) << "Kernel " << kernel->fullname_with_scope() << " is not an rpc kernel.";
|
||||
}
|
||||
#endif
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -28,12 +28,15 @@
|
|||
#include "utils/hash_map.h"
|
||||
#include "utils/hash_set.h"
|
||||
#include "runtime/graph_scheduler/control_node_scheduler.h"
|
||||
#include "runtime/graph_scheduler/rpc_node_scheduler.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_set.h"
|
||||
#include "runtime/graph_scheduler/graph_compiler.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_dump.h"
|
||||
#include "thread/actor_threadpool.h"
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
#include "runtime/graph_scheduler/rpc_node_scheduler.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using mindspore::device::DeviceContext;
|
||||
|
@ -208,8 +211,10 @@ class GraphScheduler {
|
|||
// In the control flow, used to build and link control actor.
|
||||
ControlNodeScheduler control_node_scheduler_;
|
||||
|
||||
#ifdef ENABLE_RPC_ACTOR
|
||||
// Used to build and link for rpc actors.
|
||||
std::unique_ptr<RpcNodeScheduler> rpc_node_scheduler_{nullptr};
|
||||
#endif
|
||||
|
||||
// The id of global actor.
|
||||
AID memory_manager_aid_;
|
||||
|
|
|
@ -26,10 +26,21 @@ void RpcNodeScheduler::Initialize() {
|
|||
|
||||
// Gathers all send and recv actors held in the rpc actor set and attaches a freshly created route table proxy to each
// of them. The GraphCompilerInfo argument is not used.
// Returns the rpc actor set held by this scheduler. The actor set, every actor and every created proxy are validated
// with MS_EXCEPTION_IF_NULL before being used.
RpcActorSetPtr RpcNodeScheduler::Build(const GraphCompilerInfo &) {
  MS_EXCEPTION_IF_NULL(rpc_actor_set_);
  const auto &send_actors = rpc_actor_set_->send_actors_;
  const auto &recv_actors = rpc_actor_set_->recv_actors_;

  // Flatten send actors followed by recv actors into one list so both kinds are handled by the same loop.
  std::vector<RpcActorPtr> rpc_actors;
  rpc_actors.reserve(send_actors.size() + recv_actors.size());
  (void)rpc_actors.insert(rpc_actors.end(), send_actors.begin(), send_actors.end());
  (void)rpc_actors.insert(rpc_actors.end(), recv_actors.begin(), recv_actors.end());

  // Create route table proxy for each rpc actor and set.
  for (auto &rpc_actor : rpc_actors) {
    // Guard the dereference below, the same way Schedule() and SetOpcontext() validate each actor.
    MS_EXCEPTION_IF_NULL(rpc_actor);
    auto proxy = CreateRouteTableProxy();
    MS_EXCEPTION_IF_NULL(proxy);
    rpc_actor->SetActorRouteRableProxy(proxy);
  }

  return rpc_actor_set_;
}
|
||||
|
||||
void RpcNodeScheduler::Link(const ActorSetPtr &) {
|
||||
void RpcNodeScheduler::Link(const ActorSet *) {
|
||||
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
|
||||
std::vector<SendActorPtr> send_actors = rpc_actor_set_->send_actors_;
|
||||
std::vector<RecvActorPtr> recv_actors = rpc_actor_set_->recv_actors_;
|
||||
|
@ -50,6 +61,7 @@ void RpcNodeScheduler::Link(const ActorSetPtr &) {
|
|||
<< ", send_src_node_name: " << send_src_node_name
|
||||
<< ", send_dst_node_name: " << send_dst_node_name;
|
||||
}
|
||||
send_actor->SetInterProcessEdgeName(send_src_node_name, send_dst_node_name);
|
||||
send_actor->SetRouteInfo(send_dst_ranks[0], send_dst_roles[0], send_src_node_name, send_dst_node_name);
|
||||
}
|
||||
for (auto &recv_actor : recv_actors) {
|
||||
|
@ -67,20 +79,66 @@ void RpcNodeScheduler::Link(const ActorSetPtr &) {
|
|||
<< ", recv_src_node_name: " << recv_src_node_name
|
||||
<< ", recv_dst_node_name: " << recv_dst_node_name;
|
||||
}
|
||||
recv_actor->SetInterProcessEdgeName(recv_src_node_name, recv_dst_node_name);
|
||||
recv_actor->SetRouteInfo(recv_src_ranks[0], recv_src_roles[0], recv_src_node_name, recv_dst_node_name);
|
||||
}
|
||||
}
|
||||
|
||||
void RpcNodeScheduler::Schedule() {
|
||||
MS_EXCEPTION_IF_NULL(rpc_actor_set_);
|
||||
// Must start server and register route table before looking up route and connecting.
|
||||
|
||||
// Start servers of recv actors and register route table.
|
||||
for (auto &recv_actor : rpc_actor_set_->recv_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(recv_actor);
|
||||
if (!recv_actor->StartServer()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to start server for the recv actor.";
|
||||
}
|
||||
}
|
||||
// Lookup route and connect to servers for send actors.
|
||||
for (auto &send_actor : rpc_actor_set_->send_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(send_actor);
|
||||
if (!send_actor->ConnectServer()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to connect servers for the send actor.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a send actor generated by GraphScheduler to the rpc actor set.
// Both the actor set and the given actor are validated with MS_EXCEPTION_IF_NULL.
void RpcNodeScheduler::InsertSendActor(const SendActorPtr &send_actor) {
  MS_EXCEPTION_IF_NULL(rpc_actor_set_);
  MS_EXCEPTION_IF_NULL(send_actor);
  // Append exactly once: a second append would make the same actor be linked and connected twice later on.
  (void)rpc_actor_set_->send_actors_.emplace_back(send_actor);
}
|
||||
|
||||
// Adds a recv actor generated by GraphScheduler to the rpc actor set.
// Both the actor set and the given actor are validated with MS_EXCEPTION_IF_NULL.
void RpcNodeScheduler::InsertRecvActor(const RecvActorPtr &recv_actor) {
  MS_EXCEPTION_IF_NULL(rpc_actor_set_);
  MS_EXCEPTION_IF_NULL(recv_actor);
  // Append exactly once: a second append would make the same actor start its server twice later on.
  (void)rpc_actor_set_->recv_actors_.emplace_back(recv_actor);
}
|
||||
|
||||
// Hands the given op context to every rpc actor in the set: recv actors first, then send actors.
// The op context, the actor set and each actor are validated with MS_EXCEPTION_IF_NULL.
void RpcNodeScheduler::SetOpcontext(OpContext<DeviceTensor> *const op_context) {
  MS_EXCEPTION_IF_NULL(op_context);
  MS_EXCEPTION_IF_NULL(rpc_actor_set_);

  const auto &receivers = rpc_actor_set_->recv_actors_;
  for (const auto &receiver : receivers) {
    MS_EXCEPTION_IF_NULL(receiver);
    receiver->SetOpcontext(op_context);
  }

  const auto &senders = rpc_actor_set_->send_actors_;
  for (const auto &sender : senders) {
    MS_EXCEPTION_IF_NULL(sender);
    sender->SetOpcontext(op_context);
  }
}
|
||||
|
||||
// Creates a new route table proxy bound to this process's cluster node.
// Returns a null pointer when this process is the scheduler; otherwise returns a proxy built on the cluster node
// viewed as ps::core::AbstractNode. The cluster context, the cast node and the proxy are validated with
// MS_EXCEPTION_IF_NULL.
ActorRouteTableProxyPtr RpcNodeScheduler::CreateRouteTableProxy() {
  auto cluster_context = ClusterContext::instance();
  MS_EXCEPTION_IF_NULL(cluster_context);
  if (cluster_context->IsScheduler()) {
    // No proxy is created on the scheduler process.
    return nullptr;
  }

  // The cast yields null when the node is missing or is not an AbstractNode; reject that here instead of handing a
  // null node to the proxy, where it would only surface later on first use.
  auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(cluster_context->node());
  MS_EXCEPTION_IF_NULL(abstract_node);

  auto actor_route_table_proxy = std::make_shared<ActorRouteTableProxy>(abstract_node);
  MS_EXCEPTION_IF_NULL(actor_route_table_proxy);
  return actor_route_table_proxy;
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,7 +32,7 @@ using mindspore::session::KernelWithIndex;
|
|||
// Scheduler for rpc actors, e.g., it adds inter-process arrows, generate router for actors, etc.
|
||||
class RpcNodeScheduler {
|
||||
public:
|
||||
RpcNodeScheduler() = default;
|
||||
RpcNodeScheduler() : rpc_actor_set_(nullptr) {}
|
||||
~RpcNodeScheduler() = default;
|
||||
|
||||
// Create rpc actor set.
|
||||
|
@ -42,13 +42,23 @@ class RpcNodeScheduler {
|
|||
RpcActorSetPtr Build(const GraphCompilerInfo &graph_compiler_info);
|
||||
|
||||
// Link rpc actors with inter-process arrows.
|
||||
void Link(const ActorSetPtr &actor_set);
|
||||
void Link(const ActorSet *actor_set);
|
||||
|
||||
// This should be called by 'GraphScheduler::Schedule()' method.
|
||||
// Used to start servers for recv actors and create connections for send actors.
|
||||
void Schedule();
|
||||
|
||||
// Insert Send/Recv actors generated by GraphScheduler to the rpc actor set.
|
||||
void InsertSendActor(const SendActorPtr &send_actor);
|
||||
void InsertRecvActor(const RecvActorPtr &recv_actor);
|
||||
|
||||
// Set op_context to rpc actors.
|
||||
void SetOpcontext(OpContext<DeviceTensor> *const op_context);
|
||||
|
||||
private:
|
||||
// Create new route table proxy.
|
||||
ActorRouteTableProxyPtr CreateRouteTableProxy();
|
||||
|
||||
RpcActorSetPtr rpc_actor_set_;
|
||||
};
|
||||
} // namespace runtime
|
||||
|
|
|
@ -180,6 +180,8 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/ps/*.cc"
|
||||
"../../../mindspore/ccsrc/fl/*.cc"
|
||||
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_service.cc"
|
||||
"../../../mindspore/ccsrc/distributed/cluster/actor_route_table_proxy.cc"
|
||||
"../../../mindspore/ccsrc/distributed/cluster/cluster_context.cc"
|
||||
"../../../mindspore/ccsrc/distributed/persistent/*.cc"
|
||||
"../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc"
|
||||
"../../../mindspore/ccsrc/profiler/device/ascend/*.cc"
|
||||
|
|
Loading…
Reference in New Issue