From fe2567b59311446ea79d9b377ca7cc2435662e76 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Sat, 11 Feb 2023 11:14:51 +0800 Subject: [PATCH] Enhance rdma impl --- mindspore/ccsrc/distributed/constants.h | 4 + .../ccsrc/distributed/rpc/rdma/constants.h | 6 +- .../ccsrc/distributed/rpc/rdma/rdma_client.cc | 78 ++++++++++++----- .../ccsrc/distributed/rpc/rdma/rdma_client.h | 6 ++ .../ccsrc/distributed/rpc/rdma/rdma_server.cc | 49 +++++++---- .../ccsrc/distributed/rpc/rdma/rdma_server.h | 4 + .../ccsrc/distributed/rpc/rpc_client_base.h | 17 +++- .../ccsrc/distributed/rpc/rpc_server_base.h | 4 +- .../ccsrc/distributed/rpc/tcp/tcp_client.h | 27 +++--- .../ccsrc/distributed/rpc/tcp/tcp_server.h | 24 +++--- .../graph_scheduler/actor/rpc/recv_actor.cc | 14 ++- .../graph_scheduler/actor/rpc/recv_actor.h | 6 +- .../graph_scheduler/actor/rpc/rpc_actor.h | 10 +++ .../graph_scheduler/actor/rpc/send_actor.cc | 13 ++- .../graph_scheduler/actor/rpc/send_actor.h | 8 +- .../ut/cpp/distributed/rpc/rdma/rdma_test.cc | 85 ++++++++++++++++++- 16 files changed, 267 insertions(+), 88 deletions(-) diff --git a/mindspore/ccsrc/distributed/constants.h b/mindspore/ccsrc/distributed/constants.h index 4828c12ac17..e27c4fd05b1 100644 --- a/mindspore/ccsrc/distributed/constants.h +++ b/mindspore/ccsrc/distributed/constants.h @@ -95,6 +95,10 @@ constexpr char kControlDstOpName[] = "ControlDst"; static const char URL_PROTOCOL_IP_SEPARATOR[] = "://"; static const char URL_IP_PORT_SEPARATOR[] = ":"; +constexpr char kEnableRDMA[] = "enable_rdma"; +constexpr char kRDMADevName[] = "rdma_dev"; +constexpr char kRDMAIP[] = "rdma_ip"; + // This macro the current timestamp in milliseconds. #define CURRENT_TIMESTAMP_MILLI \ (std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch())) diff --git a/mindspore/ccsrc/distributed/rpc/rdma/constants.h b/mindspore/ccsrc/distributed/rpc/rdma/constants.h index 016b746ef64..4e44775f280 100644 --- a/mindspore/ccsrc/distributed/rpc/rdma/constants.h +++ b/mindspore/ccsrc/distributed/rpc/rdma/constants.h @@ -18,7 +18,9 @@ #define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_ #include +#include #include +#include #include "utils/dlopen_macro.h" #include "distributed/constants.h" @@ -114,8 +116,10 @@ constexpr uint32_t kServerWorkingThreadNum = 4; constexpr uint32_t kClientPollingThreadNum = 4; struct req_cb_arg { - int *rsp_received; + bool rsp_received; struct urpc_buffer_allocator *allocator; + std::mutex *mtx; + std::condition_variable *cv; }; } // namespace rpc } // namespace distributed diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc index 16cc0bcecf4..24869d9bdb8 100644 --- a/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc @@ -21,7 +21,10 @@ namespace mindspore { namespace distributed { namespace rpc { -bool RDMAClient::Initialize() { return true; } +bool RDMAClient::Initialize() { + // The initialization of URPC RDMA is implemented in 'Connect' function. + return true; +} void RDMAClient::Finalize() { if (urpc_session_ != nullptr) { @@ -63,7 +66,13 @@ bool RDMAClient::Connect(const std::string &dst_url, size_t retry_count, const M bool RDMAClient::IsConnected(const std::string &dst_url) { return false; } -bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; } +bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) { + if (urpc_session_ != nullptr) { + urpc_close_func(urpc_session_); + urpc_session_ = nullptr; + } + return true; +} bool RDMAClient::SendSync(std::unique_ptr &&msg, size_t *const send_bytes) { MS_EXCEPTION_IF_NULL(msg); @@ -77,34 +86,59 @@ bool RDMAClient::SendSync(std::unique_ptr &&msg, size_t *const send sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY; sgl.sge_num = 1; - int rsp_received = 0; - struct req_cb_arg cb_arg = {}; - cb_arg.rsp_received = &rsp_received; - cb_arg.allocator = urpc_allocator_; + struct urpc_send_wr send_wr = {}; + send_wr.func_id = kInterProcessDataHandleID; + send_wr.send_mode = URPC_SEND_MODE_SYNC; + send_wr.req = &sgl; + struct urpc_sgl rsp_sgl = {0}; + send_wr.sync.rsp = &rsp_sgl; + + if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) { + MS_LOG(ERROR) << "Failed to send request for function call: " << send_wr.func_id; + return false; + } + MS_LOG(INFO) << "Server response message is " << reinterpret_cast(rsp_sgl.sge[0].addr); + + return true; +} + +void RDMAClient::SendAsync(std::unique_ptr &&msg) { + MS_EXCEPTION_IF_NULL(msg); + size_t msg_size = msg->size; + void *msg_buf = msg->data; + MS_EXCEPTION_IF_NULL(msg_buf); + + struct urpc_sgl sgl; + sgl.sge[0].addr = reinterpret_cast(msg_buf); + sgl.sge[0].length = msg_size; + sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY; + sgl.sge_num = 1; + + std::unique_lock lock(mtx_); + cb_arg_.rsp_received = false; + cb_arg_.allocator = urpc_allocator_; + cb_arg_.mtx = &mtx_; + cb_arg_.cv = &cv_; + lock.unlock(); struct urpc_send_wr send_wr = {}; send_wr.func_id = kInterProcessDataHandleID; send_wr.send_mode = URPC_SEND_MODE_ASYNC; send_wr.req = &sgl; send_wr.async.cb.wo_ctx = urpc_rsp_cb; - send_wr.async.cb_arg = &cb_arg; + send_wr.async.cb_arg = &cb_arg_; - if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) != kURPCSuccess) { - MS_LOG(ERROR) << "Failed to send request."; - return false; + if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) { + MS_LOG(EXCEPTION) << "Failed to send request to server."; + return; } - - size_t sleep_time_us = 200000; - // Wait till server responds. - while (rsp_received == 0) { - usleep(sleep_time_us); - } - return true; } -void RDMAClient::SendAsync(std::unique_ptr &&msg) {} - -bool RDMAClient::Flush(const std::string &dst_url) { return true; } +bool RDMAClient::Flush(const std::string &dst_url) { + std::unique_lock lock(mtx_); + cv_.wait(lock, [this]() { return cb_arg_.rsp_received; }); + return true; +} void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) { MS_ERROR_IF_NULL_WO_RET_VAL(rsp); @@ -116,7 +150,9 @@ void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) { MS_LOG(INFO) << "Server response message is " << reinterpret_cast(rsp->sge[0].addr); struct req_cb_arg *cb_arg = static_cast(arg); - *(cb_arg->rsp_received) = 1; + std::unique_lock lock(*(cb_arg->mtx)); + cb_arg->rsp_received = true; + cb_arg->cv->notify_all(); } } // namespace rpc } // namespace distributed diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.h b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.h index 42c2fbdcc9a..4a3badbdd24 100644 --- a/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.h +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.h @@ -65,6 +65,12 @@ class BACKEND_EXPORT RDMAClient : public RPCClientBase { struct urpc_buffer_allocator *urpc_allocator_; urpc_session_t *urpc_session_; + + // The variables for synchronization of async messages. + std::mutex mtx_; + std::condition_variable cv_; + + struct req_cb_arg cb_arg_; }; } // namespace rpc } // namespace distributed diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc index 295f927c8d6..b63ca6cb44f 100644 --- a/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc @@ -31,21 +31,18 @@ bool RDMAServer::Initialize(const std::string &url, const MemAllocateCallback &a port_ = port; // Init URPC for RMDA server. - struct urpc_config urpc_cfg = {}; - urpc_cfg.mode = URPC_MODE_SERVER; - urpc_cfg.sfeature = 0; - urpc_cfg.model = URPC_THREAD_MODEL_R2C; - urpc_cfg.worker_num = kServerWorkingThreadNum; - urpc_cfg.transport.dev_name = dev_name_; - urpc_cfg.transport.ip_addr = ip_addr_; - urpc_cfg.transport.port = port_; - urpc_cfg.transport.max_sge = 0; - urpc_cfg.allocator = nullptr; - if (urpc_init_func(&urpc_cfg) != kURPCSuccess) { - MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_ - << ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc."; - } - return true; + return InitializeURPC(); +} + +bool RDMAServer::Initialize(const MemAllocateCallback &allocate_cb) { + dev_name_ = const_cast(common::GetEnv(kRDMADevName).c_str()); + ip_addr_ = const_cast(common::GetEnv(kRDMAIP).c_str()); + port_ = 0; + MS_LOG(INFO) << "Initialize RDMA server. Device name: " << dev_name_ << ", ip address: " << ip_addr_ + << ", port: " << port_; + + // Init URPC for RMDA server. + return InitializeURPC(); } void RDMAServer::Finalize() { @@ -67,9 +64,27 @@ void RDMAServer::SetMessageHandler(const MessageHandler &handler) { } } -std::string RDMAServer::GetIP() const { return ""; } +std::string RDMAServer::GetIP() const { return ip_addr_; } -uint32_t RDMAServer::GetPort() const { return 0; } +uint32_t RDMAServer::GetPort() const { return static_cast(port_); } + +bool RDMAServer::InitializeURPC() { + struct urpc_config urpc_cfg = {}; + urpc_cfg.mode = URPC_MODE_SERVER; + urpc_cfg.sfeature = 0; + urpc_cfg.model = URPC_THREAD_MODEL_R2C; + urpc_cfg.worker_num = kServerWorkingThreadNum; + urpc_cfg.transport.dev_name = dev_name_; + urpc_cfg.transport.ip_addr = ip_addr_; + urpc_cfg.transport.port = port_; + urpc_cfg.transport.max_sge = 0; + urpc_cfg.allocator = nullptr; + if (urpc_init_func(&urpc_cfg) != kURPCSuccess) { + MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_ + << ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc."; + } + return true; +} void RDMAServer::urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp) { MS_ERROR_IF_NULL_WO_RET_VAL(req); diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.h b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.h index 7de0db74655..63e7fff148e 100644 --- a/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.h +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.h @@ -37,6 +37,7 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase { ~RDMAServer() override = default; bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override; + bool Initialize(const MemAllocateCallback &allocate_cb = {}) override; void Finalize() override; void SetMessageHandler(const MessageHandler &handler) override; @@ -47,6 +48,9 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase { struct urpc_buffer_allocator *urpc_allocator_; private: + // Initialize urpc configuration according to dev_name_, ip_addr_ and port_. + bool InitializeURPC(); + // The message callback for urpc. This method will call user-set message handler. static void urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp); // The callback after this server responding. diff --git a/mindspore/ccsrc/distributed/rpc/rpc_client_base.h b/mindspore/ccsrc/distributed/rpc/rpc_client_base.h index 173b36c4ffe..1fb7f0fac8e 100644 --- a/mindspore/ccsrc/distributed/rpc/rpc_client_base.h +++ b/mindspore/ccsrc/distributed/rpc/rpc_client_base.h @@ -27,7 +27,7 @@ namespace distributed { namespace rpc { class BACKEND_EXPORT RPCClientBase { public: - explicit RPCClientBase(bool enable_ssl = false) : enable_ssl_(enable_ssl) {} + explicit RPCClientBase(bool enable_ssl) : enable_ssl_(enable_ssl) {} virtual ~RPCClientBase() = default; // Build or destroy the rpc client. @@ -36,20 +36,29 @@ class BACKEND_EXPORT RPCClientBase { // Connect to the specified server. // Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer. - virtual bool Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) { return true; } + virtual bool Connect( + const std::string &dst_url, size_t retry_count = 60, const MemFreeCallback &free_cb = [](void *data) { + MS_ERROR_IF_NULL(data); + delete static_cast(data); + return true; + }) { + return true; + } // Check if the connection to dst_url has been established. virtual bool IsConnected(const std::string &dst_url) { return false; } // Disconnect from the specified server. - virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; } + virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) { return true; } // Send the message from the source to the destination synchronously and return the byte size by this method call. - virtual bool SendSync(std::unique_ptr &&msg, size_t *const send_bytes) { return true; } + virtual bool SendSync(std::unique_ptr &&msg, size_t *const send_bytes = nullptr) { return true; } // Send the message from the source to the destination asynchronously. virtual void SendAsync(std::unique_ptr &&msg) {} + virtual MessageBase *ReceiveSync(std::unique_ptr &&msg, uint32_t timeout = 30) { return nullptr; } + // Force the data in the send buffer to be sent out. virtual bool Flush(const std::string &dst_url) { return true; } diff --git a/mindspore/ccsrc/distributed/rpc/rpc_server_base.h b/mindspore/ccsrc/distributed/rpc/rpc_server_base.h index 0c26b117963..255f3f08ef4 100644 --- a/mindspore/ccsrc/distributed/rpc/rpc_server_base.h +++ b/mindspore/ccsrc/distributed/rpc/rpc_server_base.h @@ -31,10 +31,10 @@ class BACKEND_EXPORT RPCServerBase { virtual ~RPCServerBase() = default; // Init server using the specified url, with memory allocating function. - virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) { return true; } + virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) { return true; } // Init server using local IP and random port. - virtual bool Initialize(const MemAllocateCallback &allocate_cb) { return true; } + virtual bool Initialize(const MemAllocateCallback &allocate_cb = {}) { return true; } // Destroy the tcp server. virtual void Finalize() {} diff --git a/mindspore/ccsrc/distributed/rpc/tcp/tcp_client.h b/mindspore/ccsrc/distributed/rpc/tcp/tcp_client.h index d876fdea265..67f82bf7b42 100644 --- a/mindspore/ccsrc/distributed/rpc/tcp/tcp_client.h +++ b/mindspore/ccsrc/distributed/rpc/tcp/tcp_client.h @@ -22,6 +22,7 @@ #include #include +#include "distributed/rpc/rpc_client_base.h" #include "distributed/rpc/tcp/tcp_comm.h" #include "utils/ms_utils.h" #include "include/backend/visible.h" @@ -29,14 +30,14 @@ namespace mindspore { namespace distributed { namespace rpc { -class BACKEND_EXPORT TCPClient { +class BACKEND_EXPORT TCPClient : public RPCClientBase { public: - explicit TCPClient(bool enable_ssl = false) : enable_ssl_(enable_ssl) {} - ~TCPClient() = default; + explicit TCPClient(bool enable_ssl = false) : RPCClientBase(enable_ssl) {} + ~TCPClient() override = default; // Build or destroy the TCP client. - bool Initialize(); - void Finalize(); + bool Initialize() override; + void Finalize() override; // Connect to the specified server. // Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer. @@ -45,26 +46,26 @@ class BACKEND_EXPORT TCPClient { MS_ERROR_IF_NULL(data); delete static_cast(data); return true; - }); + }) override; // Check if the connection to dst_url has been established. - bool IsConnected(const std::string &dst_url); + bool IsConnected(const std::string &dst_url) override; // Disconnect from the specified server. - bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5); + bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) override; // Send the message from the source to the destination synchronously and return the byte size by this method call. - bool SendSync(std::unique_ptr &&msg, size_t *const send_bytes = nullptr); + bool SendSync(std::unique_ptr &&msg, size_t *const send_bytes = nullptr) override; // Send the message from the source to the destination asynchronously. - void SendAsync(std::unique_ptr &&msg); + void SendAsync(std::unique_ptr &&msg) override; // Retrieve a message from tcp server specified by the input message. // Returns nullptr after timeout. - MessageBase *ReceiveSync(std::unique_ptr &&msg, uint32_t timeout = 30); + MessageBase *ReceiveSync(std::unique_ptr &&msg, uint32_t timeout = 30) override; // Force the data in the send buffer to be sent out. - bool Flush(const std::string &dst_url); + bool Flush(const std::string &dst_url) override; private: // The basic TCP communication component used by the client. @@ -78,8 +79,6 @@ class BACKEND_EXPORT TCPClient { // The received message from the meta server by calling the method `ReceiveSync`. MessageBase *received_message_{nullptr}; - bool enable_ssl_; - DISABLE_COPY_AND_ASSIGN(TCPClient); }; } // namespace rpc diff --git a/mindspore/ccsrc/distributed/rpc/tcp/tcp_server.h b/mindspore/ccsrc/distributed/rpc/tcp/tcp_server.h index 92d3d3fd6bd..90024fe0cbb 100644 --- a/mindspore/ccsrc/distributed/rpc/tcp/tcp_server.h +++ b/mindspore/ccsrc/distributed/rpc/tcp/tcp_server.h @@ -20,6 +20,7 @@ #include #include +#include "distributed/rpc/rpc_server_base.h" #include "distributed/rpc/tcp/tcp_comm.h" #include "utils/ms_utils.h" #include "include/backend/visible.h" @@ -27,26 +28,26 @@ namespace mindspore { namespace distributed { namespace rpc { -class BACKEND_EXPORT TCPServer { +class BACKEND_EXPORT TCPServer : public RPCServerBase { public: - explicit TCPServer(bool enable_ssl = false) : enable_ssl_(enable_ssl) {} - ~TCPServer() = default; + explicit TCPServer(bool enable_ssl = false) : RPCServerBase(enable_ssl) {} + ~TCPServer() override = default; // Init the tcp server using the specified url. - bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}); + bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override; // Init the tcp server using local IP and random port. - bool Initialize(const MemAllocateCallback &allocate_cb = {}); + bool Initialize(const MemAllocateCallback &allocate_cb = {}) override; // Destroy the tcp server. - void Finalize(); + void Finalize() override; // Set the message processing handler. - void SetMessageHandler(const MessageHandler &handler); + void SetMessageHandler(const MessageHandler &handler) override; // Return the IP and port binded by this server. - std::string GetIP() const; - uint32_t GetPort() const; + std::string GetIP() const override; + uint32_t GetPort() const override; private: bool InitializeImpl(const std::string &url, const MemAllocateCallback &allocate_cb); @@ -54,11 +55,6 @@ class BACKEND_EXPORT TCPServer { // The basic TCP communication component used by the server. std::unique_ptr tcp_comm_{nullptr}; - std::string ip_{""}; - uint32_t port_{0}; - - bool enable_ssl_; - DISABLE_COPY_AND_ASSIGN(TCPServer); }; } // namespace rpc diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc index 1cc1628ea78..1b8d151e2ce 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.cc @@ -21,7 +21,6 @@ #include #include #include "proto/topology.pb.h" -#include "distributed/rpc/tcp/constants.h" #include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h" #include "backend/common/optimizer/helper.h" @@ -62,15 +61,24 @@ void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &r } bool RecvActor::StartServer() { - // Step 1: Create a tcp server and start listening. + // Step 1: Create a rpc server and start listening. + +#ifdef ENABLE_RDMA + if (common::GetEnv(kEnableRDMA) == "1") { + server_ = std::make_unique(); + } else { + server_ = std::make_unique(); + } +#else server_ = std::make_unique(); +#endif MS_EXCEPTION_IF_NULL(server_); // Set the memory allocating callback using void* message. std::function allocate_callback = std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1); if (!server_->Initialize(allocate_callback)) { - MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor"; + MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor"; } ip_ = server_->GetIP(); port_ = server_->GetPort(); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h index f8b05000781..17471837f75 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/recv_actor.h @@ -62,7 +62,7 @@ class RecvActor : public RpcActor { // Start recv actor server and register this server address to actor route table in scheduler by proxy. bool StartServer(); - // Finalize tcp server. + // Finalize rpc server. void Clear() override; void StopRpcAtException() override; @@ -102,7 +102,7 @@ class RecvActor : public RpcActor { */ void *AllocateMemByDeviceRes(size_t size); - std::unique_ptr server_; + std::unique_ptr server_; // The variables used to ensure thread-safe of op context visited by recv actor. bool is_context_valid_; @@ -129,7 +129,7 @@ class RecvActor : public RpcActor { // it, e.g., infer shape for RpcRecv kernel and call Resize(). void PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize); - // The message callback of the tcp server. + // The message callback of the rpc server. MessageBase *HandleMessage(MessageBase *const msg); // The network address of this recv actor. It's generated automatically by rpc module. diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/rpc_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/rpc_actor.h index 5a1d4b3b3fc..358c812373c 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/rpc_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/rpc_actor.h @@ -26,6 +26,10 @@ #include "distributed/cluster/cluster_context.h" #include "distributed/rpc/tcp/tcp_client.h" #include "distributed/rpc/tcp/tcp_server.h" +#ifdef ENABLE_RDMA +#include "distributed/rpc/rdma/rdma_client.h" +#include "distributed/rpc/rdma/rdma_server.h" +#endif #include "proto/rpc.pb.h" #include "proto/topology.pb.h" @@ -35,9 +39,15 @@ using distributed::cluster::ActorRouteTableProxy; using distributed::cluster::ActorRouteTableProxyPtr; using distributed::cluster::ClusterContext; using distributed::cluster::topology::ActorAddress; +using distributed::rpc::RPCClientBase; +using distributed::rpc::RPCServerBase; using distributed::rpc::TCPClient; using distributed::rpc::TCPServer; using mindspore::device::KernelInfo; +#ifdef ENABLE_RDMA +using distributed::rpc::RDMAClient; +using distributed::rpc::RDMAServer; +#endif // The inter-process edge mark between two nodes. constexpr char kInterProcessEdgeMark[] = "->"; diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc index 441c8d302a2..6852648df19 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.cc @@ -27,7 +27,7 @@ SendActor::~SendActor() { (void)client_->Disconnect(server_url_); client_->Finalize(); } catch (const std::exception &) { - MS_LOG(ERROR) << "Failed to disconnect and finalize for tcp client in send actor."; + MS_LOG(ERROR) << "Failed to disconnect and finalize for rpc client in send actor."; } client_ = nullptr; } @@ -40,10 +40,19 @@ void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &s } bool SendActor::ConnectServer() { +#ifdef ENABLE_RDMA + if (common::GetEnv(kEnableRDMA) == "1") { + client_ = std::make_unique(); + } else { + client_ = std::make_unique(); + } +#else client_ = std::make_unique(); +#endif MS_EXCEPTION_IF_NULL(client_); + if (!client_->Initialize()) { - MS_LOG(EXCEPTION) << "Failed to initialize tcp server for send actor."; + MS_LOG(EXCEPTION) << "Failed to initialize rpc server for send actor."; } // Lookup actor addresses for each peer actor. for (const auto &peer_actor_id : peer_actor_ids_) { diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h index e3be2690bc4..d50362b6731 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/rpc/send_actor.h @@ -49,7 +49,7 @@ class SendActor : public RpcActor { // Flush and wait for sent data to be passed to kernel. void FlushData() override; - // Finalize tcp client. + // Finalize rpc client. void Clear() override; protected: @@ -76,8 +76,8 @@ class SendActor : public RpcActor { */ virtual void Flush(); - // The tcp client connection to multiple servers. - std::unique_ptr client_; + // The rpc client connection to multiple servers. + std::unique_ptr client_; private: /** @@ -136,7 +136,7 @@ class SendActor : public RpcActor { std::vector peer_actor_ids_; mindspore::HashMap peer_actor_urls_; - // The url of the peer recv actor's tcp server. + // The url of the peer recv actor's server. std::string server_url_; }; diff --git a/tests/ut/cpp/distributed/rpc/rdma/rdma_test.cc b/tests/ut/cpp/distributed/rpc/rdma/rdma_test.cc index 7642fad79a5..588211a0d91 100644 --- a/tests/ut/cpp/distributed/rpc/rdma/rdma_test.cc +++ b/tests/ut/cpp/distributed/rpc/rdma/rdma_test.cc @@ -45,7 +45,9 @@ class RDMATest : public UT::Common { std::unique_ptr RDMATest::CreateMessage(const std::string &msg) { std::unique_ptr message = std::make_unique(); size_t msg_size = msg.size(); - ASSERT_TRUE(msg_size != 0); + if (msg_size == 0) { + MS_LOG(EXCEPTION) << "msg_size is 0!"; + } void *data = malloc(msg_size + 1); (void)memcpy_s(data, msg_size, msg.c_str(), msg_size); message->data = data; @@ -57,6 +59,36 @@ std::unique_ptr RDMATest::CreateMessage(const std::string &msg) { /// Description: test basic connection function between RDMA client and server. /// Expectation: RDMA client successfully connects to RDMA server and sends a simple message. TEST_F(RDMATest, TestRDMAConnection) { + std::string url = "127.0.0.1:10969"; + size_t server_pid = fork(); + if (server_pid == 0) { + std::shared_ptr rdma_server = std::make_shared(); + MS_EXCEPTION_IF_NULL(rdma_server); + ASSERT_TRUE(rdma_server->Initialize(url)); + sleep(3); + rdma_server->Finalize(); + return; + } + sleep(1); + size_t client_pid = fork(); + if (client_pid == 0) { + std::shared_ptr rdma_client = std::make_shared(); + MS_EXCEPTION_IF_NULL(rdma_client); + ASSERT_TRUE(rdma_client->Initialize()); + ASSERT_TRUE(rdma_client->Connect(url)); + rdma_client->Finalize(); + return; + } + + int wstatus; + (void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED); + (void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED); +} + +/// Feature: RDMA communication. +/// Description: test SendSync interface for RDMA client and server. +/// Expectation: RDMA client successfully sends two messages to RDMA server synchronously. +TEST_F(RDMATest, TestRDMASendSync) { std::string url = "127.0.0.1:10969"; size_t server_pid = fork(); if (server_pid == 0) { @@ -70,6 +102,7 @@ TEST_F(RDMATest, TestRDMAConnection) { }; rdma_server->SetMessageHandler(msg_handler); sleep(3); + rdma_server->Finalize(); return; } sleep(1); @@ -80,8 +113,54 @@ TEST_F(RDMATest, TestRDMAConnection) { ASSERT_TRUE(rdma_client->Initialize()); ASSERT_TRUE(rdma_client->Connect(url)); - auto message = CreateMessage("Hello server!"); - ASSERT_TRUE(rdma_client->SendSync(std::move(message))); + auto message1 = CreateMessage("Hello server sync!"); + ASSERT_TRUE(rdma_client->SendSync(std::move(message1))); + auto message2 = CreateMessage("Hello server sync!"); + ASSERT_TRUE(rdma_client->SendSync(std::move(message2))); + rdma_client->Finalize(); + return; + } + + int wstatus; + (void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED); + (void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED); +} + +/// Feature: RDMA communication. +/// Description: test SendAsync interface for RDMA client and server. +/// Expectation: RDMA client successfully sends two messages to RDMA server asynchronously. +TEST_F(RDMATest, TestRDMASendAsync) { + std::string url = "127.0.0.1:10969"; + size_t server_pid = fork(); + if (server_pid == 0) { + std::shared_ptr rdma_server = std::make_shared(); + MS_EXCEPTION_IF_NULL(rdma_server); + ASSERT_TRUE(rdma_server->Initialize(url)); + + auto msg_handler = [](MessageBase *const msg) { + MS_LOG(INFO) << "Receive message from client: " << static_cast(msg->data); + return nullptr; + }; + rdma_server->SetMessageHandler(msg_handler); + sleep(3); + rdma_server->Finalize(); + return; + } + sleep(1); + size_t client_pid = fork(); + if (client_pid == 0) { + std::shared_ptr rdma_client = std::make_shared(); + MS_EXCEPTION_IF_NULL(rdma_client); + ASSERT_TRUE(rdma_client->Initialize()); + ASSERT_TRUE(rdma_client->Connect(url)); + + auto message1 = CreateMessage("Hello server async!"); + rdma_client->SendAsync(std::move(message1)); + ASSERT_TRUE(rdma_client->Flush(url)); + auto message2 = CreateMessage("Hello server async!"); + rdma_client->SendAsync(std::move(message2)); + ASSERT_TRUE(rdma_client->Flush(url)); + rdma_client->Finalize(); return; }