forked from mindspore-Ecosystem/mindspore
Enhance rdma impl
This commit is contained in:
parent
364e292587
commit
fe2567b593
|
@ -95,6 +95,10 @@ constexpr char kControlDstOpName[] = "ControlDst";
|
|||
static const char URL_PROTOCOL_IP_SEPARATOR[] = "://";
|
||||
static const char URL_IP_PORT_SEPARATOR[] = ":";
|
||||
|
||||
constexpr char kEnableRDMA[] = "enable_rdma";
|
||||
constexpr char kRDMADevName[] = "rdma_dev";
|
||||
constexpr char kRDMAIP[] = "rdma_ip";
|
||||
|
||||
// This macro evaluates to the current timestamp in milliseconds.
|
||||
#define CURRENT_TIMESTAMP_MILLI \
|
||||
(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()))
|
||||
|
|
|
@ -18,7 +18,9 @@
|
|||
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_
|
||||
|
||||
#include <urpc.h>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <condition_variable>
|
||||
|
||||
#include "utils/dlopen_macro.h"
|
||||
#include "distributed/constants.h"
|
||||
|
@ -114,8 +116,10 @@ constexpr uint32_t kServerWorkingThreadNum = 4;
|
|||
constexpr uint32_t kClientPollingThreadNum = 4;
|
||||
|
||||
struct req_cb_arg {
|
||||
int *rsp_received;
|
||||
bool rsp_received;
|
||||
struct urpc_buffer_allocator *allocator;
|
||||
std::mutex *mtx;
|
||||
std::condition_variable *cv;
|
||||
};
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
|
|
|
@ -21,7 +21,10 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
bool RDMAClient::Initialize() { return true; }
|
||||
// Build the RDMA client. This body performs no work and always returns true.
bool RDMAClient::Initialize() {
|
||||
// The initialization of URPC RDMA is implemented in 'Connect' function.
|
||||
return true;
|
||||
}
|
||||
|
||||
void RDMAClient::Finalize() {
|
||||
if (urpc_session_ != nullptr) {
|
||||
|
@ -63,7 +66,13 @@ bool RDMAClient::Connect(const std::string &dst_url, size_t retry_count, const M
|
|||
|
||||
bool RDMAClient::IsConnected(const std::string &dst_url) { return false; }
|
||||
|
||||
bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; }
|
||||
// Disconnect from the server by closing the URPC session, if one is open.
// Neither dst_url nor timeout_in_sec is read by this body; the single session held in urpc_session_ is closed.
// Always returns true, including when there was no session to close.
bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
|
||||
if (urpc_session_ != nullptr) {
|
||||
urpc_close_func(urpc_session_);
|
||||
// Clear the pointer so a repeated call skips the close.
urpc_session_ = nullptr;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RDMAClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) {
|
||||
MS_EXCEPTION_IF_NULL(msg);
|
||||
|
@ -77,34 +86,59 @@ bool RDMAClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send
|
|||
sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
|
||||
sgl.sge_num = 1;
|
||||
|
||||
int rsp_received = 0;
|
||||
struct req_cb_arg cb_arg = {};
|
||||
cb_arg.rsp_received = &rsp_received;
|
||||
cb_arg.allocator = urpc_allocator_;
|
||||
struct urpc_send_wr send_wr = {};
|
||||
send_wr.func_id = kInterProcessDataHandleID;
|
||||
send_wr.send_mode = URPC_SEND_MODE_SYNC;
|
||||
send_wr.req = &sgl;
|
||||
struct urpc_sgl rsp_sgl = {0};
|
||||
send_wr.sync.rsp = &rsp_sgl;
|
||||
|
||||
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) {
|
||||
MS_LOG(ERROR) << "Failed to send request for function call: " << send_wr.func_id;
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Server response message is " << reinterpret_cast<char *>(rsp_sgl.sge[0].addr);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void RDMAClient::SendAsync(std::unique_ptr<MessageBase> &&msg) {
|
||||
MS_EXCEPTION_IF_NULL(msg);
|
||||
size_t msg_size = msg->size;
|
||||
void *msg_buf = msg->data;
|
||||
MS_EXCEPTION_IF_NULL(msg_buf);
|
||||
|
||||
struct urpc_sgl sgl;
|
||||
sgl.sge[0].addr = reinterpret_cast<uintptr_t>(msg_buf);
|
||||
sgl.sge[0].length = msg_size;
|
||||
sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
|
||||
sgl.sge_num = 1;
|
||||
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
cb_arg_.rsp_received = false;
|
||||
cb_arg_.allocator = urpc_allocator_;
|
||||
cb_arg_.mtx = &mtx_;
|
||||
cb_arg_.cv = &cv_;
|
||||
lock.unlock();
|
||||
|
||||
struct urpc_send_wr send_wr = {};
|
||||
send_wr.func_id = kInterProcessDataHandleID;
|
||||
send_wr.send_mode = URPC_SEND_MODE_ASYNC;
|
||||
send_wr.req = &sgl;
|
||||
send_wr.async.cb.wo_ctx = urpc_rsp_cb;
|
||||
send_wr.async.cb_arg = &cb_arg;
|
||||
send_wr.async.cb_arg = &cb_arg_;
|
||||
|
||||
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) != kURPCSuccess) {
|
||||
MS_LOG(ERROR) << "Failed to send request.";
|
||||
return false;
|
||||
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) {
|
||||
MS_LOG(EXCEPTION) << "Failed to send request to server.";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t sleep_time_us = 200000;
|
||||
// Wait till server responds.
|
||||
while (rsp_received == 0) {
|
||||
usleep(sleep_time_us);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RDMAClient::SendAsync(std::unique_ptr<MessageBase> &&msg) {}
|
||||
|
||||
bool RDMAClient::Flush(const std::string &dst_url) { return true; }
|
||||
// Block the caller until cb_arg_.rsp_received becomes true, then return true.
// dst_url is not read by this body. The wait has no timeout.
// NOTE(review): the flag is presumably set by the URPC response callback after SendAsync; if that callback never
// runs (or no async send was issued), this call would not return - confirm against callers.
bool RDMAClient::Flush(const std::string &dst_url) {
|
||||
std::unique_lock<std::mutex> lock(mtx_);
|
||||
// The predicate form of wait() re-checks the flag on every wakeup.
cv_.wait(lock, [this]() { return cb_arg_.rsp_received; });
|
||||
return true;
|
||||
}
|
||||
|
||||
void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
|
||||
|
@ -116,7 +150,9 @@ void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) {
|
|||
|
||||
MS_LOG(INFO) << "Server response message is " << reinterpret_cast<char *>(rsp->sge[0].addr);
|
||||
struct req_cb_arg *cb_arg = static_cast<struct req_cb_arg *>(arg);
|
||||
*(cb_arg->rsp_received) = 1;
|
||||
std::unique_lock<std::mutex> lock(*(cb_arg->mtx));
|
||||
cb_arg->rsp_received = true;
|
||||
cb_arg->cv->notify_all();
|
||||
}
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
|
|
|
@ -65,6 +65,12 @@ class BACKEND_EXPORT RDMAClient : public RPCClientBase {
|
|||
|
||||
struct urpc_buffer_allocator *urpc_allocator_;
|
||||
urpc_session_t *urpc_session_;
|
||||
|
||||
// The variables for synchronization of async messages.
|
||||
std::mutex mtx_;
|
||||
std::condition_variable cv_;
|
||||
|
||||
struct req_cb_arg cb_arg_;
|
||||
};
|
||||
} // namespace rpc
|
||||
} // namespace distributed
|
||||
|
|
|
@ -31,21 +31,18 @@ bool RDMAServer::Initialize(const std::string &url, const MemAllocateCallback &a
|
|||
port_ = port;
|
||||
|
||||
// Init URPC for RDMA server.
|
||||
struct urpc_config urpc_cfg = {};
|
||||
urpc_cfg.mode = URPC_MODE_SERVER;
|
||||
urpc_cfg.sfeature = 0;
|
||||
urpc_cfg.model = URPC_THREAD_MODEL_R2C;
|
||||
urpc_cfg.worker_num = kServerWorkingThreadNum;
|
||||
urpc_cfg.transport.dev_name = dev_name_;
|
||||
urpc_cfg.transport.ip_addr = ip_addr_;
|
||||
urpc_cfg.transport.port = port_;
|
||||
urpc_cfg.transport.max_sge = 0;
|
||||
urpc_cfg.allocator = nullptr;
|
||||
if (urpc_init_func(&urpc_cfg) != kURPCSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_
|
||||
<< ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc.";
|
||||
}
|
||||
return true;
|
||||
return InitializeURPC();
|
||||
}
|
||||
|
||||
bool RDMAServer::Initialize(const MemAllocateCallback &allocate_cb) {
|
||||
dev_name_ = const_cast<char *>(common::GetEnv(kRDMADevName).c_str());
|
||||
ip_addr_ = const_cast<char *>(common::GetEnv(kRDMAIP).c_str());
|
||||
port_ = 0;
|
||||
MS_LOG(INFO) << "Initialize RDMA server. Device name: " << dev_name_ << ", ip address: " << ip_addr_
|
||||
<< ", port: " << port_;
|
||||
|
||||
// Init URPC for RMDA server.
|
||||
return InitializeURPC();
|
||||
}
|
||||
|
||||
void RDMAServer::Finalize() {
|
||||
|
@ -67,9 +64,27 @@ void RDMAServer::SetMessageHandler(const MessageHandler &handler) {
|
|||
}
|
||||
}
|
||||
|
||||
std::string RDMAServer::GetIP() const { return ""; }
|
||||
// Return the ip address stored in ip_addr_, converted to std::string.
// NOTE(review): if ip_addr_ is a raw char pointer that can still be null here, this conversion is undefined - confirm.
std::string RDMAServer::GetIP() const { return ip_addr_; }
|
||||
|
||||
uint32_t RDMAServer::GetPort() const { return 0; }
|
||||
// Return the port stored in port_, cast to uint32_t.
uint32_t RDMAServer::GetPort() const { return static_cast<uint32_t>(port_); }
|
||||
|
||||
// Fill a urpc_config for server mode from dev_name_, ip_addr_ and port_, then pass it to urpc_init_func.
// Returns true when urpc_init_func reports kURPCSuccess.
// NOTE(review): on failure MS_LOG(EXCEPTION) presumably throws, so this function would never return false - confirm.
bool RDMAServer::InitializeURPC() {
|
||||
struct urpc_config urpc_cfg = {};
|
||||
urpc_cfg.mode = URPC_MODE_SERVER;
|
||||
urpc_cfg.sfeature = 0;
|
||||
urpc_cfg.model = URPC_THREAD_MODEL_R2C;
|
||||
urpc_cfg.worker_num = kServerWorkingThreadNum;
|
||||
// Transport endpoint: taken from the members set by the calling Initialize overload.
urpc_cfg.transport.dev_name = dev_name_;
|
||||
urpc_cfg.transport.ip_addr = ip_addr_;
|
||||
urpc_cfg.transport.port = port_;
|
||||
urpc_cfg.transport.max_sge = 0;
|
||||
urpc_cfg.allocator = nullptr;
|
||||
if (urpc_init_func(&urpc_cfg) != kURPCSuccess) {
|
||||
MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_
|
||||
<< ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void RDMAServer::urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp) {
|
||||
MS_ERROR_IF_NULL_WO_RET_VAL(req);
|
||||
|
|
|
@ -37,6 +37,7 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase {
|
|||
~RDMAServer() override = default;
|
||||
|
||||
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override;
|
||||
bool Initialize(const MemAllocateCallback &allocate_cb = {}) override;
|
||||
void Finalize() override;
|
||||
void SetMessageHandler(const MessageHandler &handler) override;
|
||||
|
||||
|
@ -47,6 +48,9 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase {
|
|||
struct urpc_buffer_allocator *urpc_allocator_;
|
||||
|
||||
private:
|
||||
// Initialize urpc configuration according to dev_name_, ip_addr_ and port_.
|
||||
bool InitializeURPC();
|
||||
|
||||
// The message callback for urpc. This method will call user-set message handler.
|
||||
static void urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp);
|
||||
// The callback after this server responding.
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace distributed {
|
|||
namespace rpc {
|
||||
class BACKEND_EXPORT RPCClientBase {
|
||||
public:
|
||||
explicit RPCClientBase(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
|
||||
explicit RPCClientBase(bool enable_ssl) : enable_ssl_(enable_ssl) {}
|
||||
virtual ~RPCClientBase() = default;
|
||||
|
||||
// Build or destroy the rpc client.
|
||||
|
@ -36,20 +36,29 @@ class BACKEND_EXPORT RPCClientBase {
|
|||
|
||||
// Connect to the specified server.
|
||||
// Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer.
|
||||
virtual bool Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) { return true; }
|
||||
virtual bool Connect(
|
||||
const std::string &dst_url, size_t retry_count = 60, const MemFreeCallback &free_cb = [](void *data) {
|
||||
MS_ERROR_IF_NULL(data);
|
||||
delete static_cast<char *>(data);
|
||||
return true;
|
||||
}) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check if the connection to dst_url has been established.
|
||||
virtual bool IsConnected(const std::string &dst_url) { return false; }
|
||||
|
||||
// Disconnect from the specified server.
|
||||
virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; }
|
||||
virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) { return true; }
|
||||
|
||||
// Send the message from the source to the destination synchronously and return the byte size by this method call.
|
||||
virtual bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) { return true; }
|
||||
virtual bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) { return true; }
|
||||
|
||||
// Send the message from the source to the destination asynchronously.
|
||||
virtual void SendAsync(std::unique_ptr<MessageBase> &&msg) {}
|
||||
|
||||
virtual MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30) { return nullptr; }
|
||||
|
||||
// Force the data in the send buffer to be sent out.
|
||||
virtual bool Flush(const std::string &dst_url) { return true; }
|
||||
|
||||
|
|
|
@ -31,10 +31,10 @@ class BACKEND_EXPORT RPCServerBase {
|
|||
virtual ~RPCServerBase() = default;
|
||||
|
||||
// Init server using the specified url, with memory allocating function.
|
||||
virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) { return true; }
|
||||
virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) { return true; }
|
||||
|
||||
// Init server using local IP and random port.
|
||||
virtual bool Initialize(const MemAllocateCallback &allocate_cb) { return true; }
|
||||
virtual bool Initialize(const MemAllocateCallback &allocate_cb = {}) { return true; }
|
||||
|
||||
// Destroy the tcp server.
|
||||
virtual void Finalize() {}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
|
||||
#include "distributed/rpc/rpc_client_base.h"
|
||||
#include "distributed/rpc/tcp/tcp_comm.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "include/backend/visible.h"
|
||||
|
@ -29,14 +30,14 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
class BACKEND_EXPORT TCPClient {
|
||||
class BACKEND_EXPORT TCPClient : public RPCClientBase {
|
||||
public:
|
||||
explicit TCPClient(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
|
||||
~TCPClient() = default;
|
||||
explicit TCPClient(bool enable_ssl = false) : RPCClientBase(enable_ssl) {}
|
||||
~TCPClient() override = default;
|
||||
|
||||
// Build or destroy the TCP client.
|
||||
bool Initialize();
|
||||
void Finalize();
|
||||
bool Initialize() override;
|
||||
void Finalize() override;
|
||||
|
||||
// Connect to the specified server.
|
||||
// Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer.
|
||||
|
@ -45,26 +46,26 @@ class BACKEND_EXPORT TCPClient {
|
|||
MS_ERROR_IF_NULL(data);
|
||||
delete static_cast<char *>(data);
|
||||
return true;
|
||||
});
|
||||
}) override;
|
||||
|
||||
// Check if the connection to dst_url has been established.
|
||||
bool IsConnected(const std::string &dst_url);
|
||||
bool IsConnected(const std::string &dst_url) override;
|
||||
|
||||
// Disconnect from the specified server.
|
||||
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5);
|
||||
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) override;
|
||||
|
||||
// Send the message from the source to the destination synchronously and return the byte size by this method call.
|
||||
bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr);
|
||||
bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) override;
|
||||
|
||||
// Send the message from the source to the destination asynchronously.
|
||||
void SendAsync(std::unique_ptr<MessageBase> &&msg);
|
||||
void SendAsync(std::unique_ptr<MessageBase> &&msg) override;
|
||||
|
||||
// Retrieve a message from tcp server specified by the input message.
|
||||
// Returns nullptr after timeout.
|
||||
MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30);
|
||||
MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30) override;
|
||||
|
||||
// Force the data in the send buffer to be sent out.
|
||||
bool Flush(const std::string &dst_url);
|
||||
bool Flush(const std::string &dst_url) override;
|
||||
|
||||
private:
|
||||
// The basic TCP communication component used by the client.
|
||||
|
@ -78,8 +79,6 @@ class BACKEND_EXPORT TCPClient {
|
|||
// The received message from the meta server by calling the method `ReceiveSync`.
|
||||
MessageBase *received_message_{nullptr};
|
||||
|
||||
bool enable_ssl_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(TCPClient);
|
||||
};
|
||||
} // namespace rpc
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "distributed/rpc/rpc_server_base.h"
|
||||
#include "distributed/rpc/tcp/tcp_comm.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "include/backend/visible.h"
|
||||
|
@ -27,26 +28,26 @@
|
|||
namespace mindspore {
|
||||
namespace distributed {
|
||||
namespace rpc {
|
||||
class BACKEND_EXPORT TCPServer {
|
||||
class BACKEND_EXPORT TCPServer : public RPCServerBase {
|
||||
public:
|
||||
explicit TCPServer(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
|
||||
~TCPServer() = default;
|
||||
explicit TCPServer(bool enable_ssl = false) : RPCServerBase(enable_ssl) {}
|
||||
~TCPServer() override = default;
|
||||
|
||||
// Init the tcp server using the specified url.
|
||||
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {});
|
||||
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override;
|
||||
|
||||
// Init the tcp server using local IP and random port.
|
||||
bool Initialize(const MemAllocateCallback &allocate_cb = {});
|
||||
bool Initialize(const MemAllocateCallback &allocate_cb = {}) override;
|
||||
|
||||
// Destroy the tcp server.
|
||||
void Finalize();
|
||||
void Finalize() override;
|
||||
|
||||
// Set the message processing handler.
|
||||
void SetMessageHandler(const MessageHandler &handler);
|
||||
void SetMessageHandler(const MessageHandler &handler) override;
|
||||
|
||||
// Return the IP and port bound by this server.
|
||||
std::string GetIP() const;
|
||||
uint32_t GetPort() const;
|
||||
std::string GetIP() const override;
|
||||
uint32_t GetPort() const override;
|
||||
|
||||
private:
|
||||
bool InitializeImpl(const std::string &url, const MemAllocateCallback &allocate_cb);
|
||||
|
@ -54,11 +55,6 @@ class BACKEND_EXPORT TCPServer {
|
|||
// The basic TCP communication component used by the server.
|
||||
std::unique_ptr<TCPComm> tcp_comm_{nullptr};
|
||||
|
||||
std::string ip_{""};
|
||||
uint32_t port_{0};
|
||||
|
||||
bool enable_ssl_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(TCPServer);
|
||||
};
|
||||
} // namespace rpc
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include <functional>
|
||||
#include <condition_variable>
|
||||
#include "proto/topology.pb.h"
|
||||
#include "distributed/rpc/tcp/constants.h"
|
||||
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
|
@ -62,15 +61,24 @@ void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &r
|
|||
}
|
||||
|
||||
bool RecvActor::StartServer() {
|
||||
// Step 1: Create a tcp server and start listening.
|
||||
// Step 1: Create a rpc server and start listening.
|
||||
|
||||
#ifdef ENABLE_RDMA
|
||||
if (common::GetEnv(kEnableRDMA) == "1") {
|
||||
server_ = std::make_unique<RDMAServer>();
|
||||
} else {
|
||||
server_ = std::make_unique<TCPServer>();
|
||||
}
|
||||
#else
|
||||
server_ = std::make_unique<TCPServer>();
|
||||
#endif
|
||||
MS_EXCEPTION_IF_NULL(server_);
|
||||
|
||||
// Set the memory allocating callback using void* message.
|
||||
std::function<void *(size_t size)> allocate_callback =
|
||||
std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
|
||||
if (!server_->Initialize(allocate_callback)) {
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor";
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor";
|
||||
}
|
||||
ip_ = server_->GetIP();
|
||||
port_ = server_->GetPort();
|
||||
|
|
|
@ -62,7 +62,7 @@ class RecvActor : public RpcActor {
|
|||
// Start recv actor server and register this server address to actor route table in scheduler by proxy.
|
||||
bool StartServer();
|
||||
|
||||
// Finalize tcp server.
|
||||
// Finalize rpc server.
|
||||
void Clear() override;
|
||||
|
||||
void StopRpcAtException() override;
|
||||
|
@ -102,7 +102,7 @@ class RecvActor : public RpcActor {
|
|||
*/
|
||||
void *AllocateMemByDeviceRes(size_t size);
|
||||
|
||||
std::unique_ptr<TCPServer> server_;
|
||||
std::unique_ptr<RPCServerBase> server_;
|
||||
|
||||
// The variables used to ensure thread-safe of op context visited by recv actor.
|
||||
bool is_context_valid_;
|
||||
|
@ -129,7 +129,7 @@ class RecvActor : public RpcActor {
|
|||
// it, e.g., infer shape for RpcRecv kernel and call Resize().
|
||||
void PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize);
|
||||
|
||||
// The message callback of the tcp server.
|
||||
// The message callback of the rpc server.
|
||||
MessageBase *HandleMessage(MessageBase *const msg);
|
||||
|
||||
// The network address of this recv actor. It's generated automatically by rpc module.
|
||||
|
|
|
@ -26,6 +26,10 @@
|
|||
#include "distributed/cluster/cluster_context.h"
|
||||
#include "distributed/rpc/tcp/tcp_client.h"
|
||||
#include "distributed/rpc/tcp/tcp_server.h"
|
||||
#ifdef ENABLE_RDMA
|
||||
#include "distributed/rpc/rdma/rdma_client.h"
|
||||
#include "distributed/rpc/rdma/rdma_server.h"
|
||||
#endif
|
||||
#include "proto/rpc.pb.h"
|
||||
#include "proto/topology.pb.h"
|
||||
|
||||
|
@ -35,9 +39,15 @@ using distributed::cluster::ActorRouteTableProxy;
|
|||
using distributed::cluster::ActorRouteTableProxyPtr;
|
||||
using distributed::cluster::ClusterContext;
|
||||
using distributed::cluster::topology::ActorAddress;
|
||||
using distributed::rpc::RPCClientBase;
|
||||
using distributed::rpc::RPCServerBase;
|
||||
using distributed::rpc::TCPClient;
|
||||
using distributed::rpc::TCPServer;
|
||||
using mindspore::device::KernelInfo;
|
||||
#ifdef ENABLE_RDMA
|
||||
using distributed::rpc::RDMAClient;
|
||||
using distributed::rpc::RDMAServer;
|
||||
#endif
|
||||
|
||||
// The inter-process edge mark between two nodes.
|
||||
constexpr char kInterProcessEdgeMark[] = "->";
|
||||
|
|
|
@ -27,7 +27,7 @@ SendActor::~SendActor() {
|
|||
(void)client_->Disconnect(server_url_);
|
||||
client_->Finalize();
|
||||
} catch (const std::exception &) {
|
||||
MS_LOG(ERROR) << "Failed to disconnect and finalize for tcp client in send actor.";
|
||||
MS_LOG(ERROR) << "Failed to disconnect and finalize for rpc client in send actor.";
|
||||
}
|
||||
client_ = nullptr;
|
||||
}
|
||||
|
@ -40,10 +40,19 @@ void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &s
|
|||
}
|
||||
|
||||
bool SendActor::ConnectServer() {
|
||||
#ifdef ENABLE_RDMA
|
||||
if (common::GetEnv(kEnableRDMA) == "1") {
|
||||
client_ = std::make_unique<RDMAClient>();
|
||||
} else {
|
||||
client_ = std::make_unique<TCPClient>();
|
||||
}
|
||||
#else
|
||||
client_ = std::make_unique<TCPClient>();
|
||||
#endif
|
||||
MS_EXCEPTION_IF_NULL(client_);
|
||||
|
||||
if (!client_->Initialize()) {
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for send actor.";
|
||||
MS_LOG(EXCEPTION) << "Failed to initialize rpc server for send actor.";
|
||||
}
|
||||
// Lookup actor addresses for each peer actor.
|
||||
for (const auto &peer_actor_id : peer_actor_ids_) {
|
||||
|
|
|
@ -49,7 +49,7 @@ class SendActor : public RpcActor {
|
|||
// Flush and wait for sent data to be passed to kernel.
|
||||
void FlushData() override;
|
||||
|
||||
// Finalize tcp client.
|
||||
// Finalize rpc client.
|
||||
void Clear() override;
|
||||
|
||||
protected:
|
||||
|
@ -76,8 +76,8 @@ class SendActor : public RpcActor {
|
|||
*/
|
||||
virtual void Flush();
|
||||
|
||||
// The tcp client connection to multiple servers.
|
||||
std::unique_ptr<TCPClient> client_;
|
||||
// The rpc client connection to multiple servers.
|
||||
std::unique_ptr<RPCClientBase> client_;
|
||||
|
||||
private:
|
||||
/**
|
||||
|
@ -136,7 +136,7 @@ class SendActor : public RpcActor {
|
|||
std::vector<std::string> peer_actor_ids_;
|
||||
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
|
||||
|
||||
// The url of the peer recv actor's tcp server.
|
||||
// The url of the peer recv actor's server.
|
||||
std::string server_url_;
|
||||
};
|
||||
|
||||
|
|
|
@ -45,7 +45,9 @@ class RDMATest : public UT::Common {
|
|||
std::unique_ptr<MessageBase> RDMATest::CreateMessage(const std::string &msg) {
|
||||
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
|
||||
size_t msg_size = msg.size();
|
||||
ASSERT_TRUE(msg_size != 0);
|
||||
if (msg_size == 0) {
|
||||
MS_LOG(EXCEPTION) << "msg_size is 0!";
|
||||
}
|
||||
void *data = malloc(msg_size + 1);
|
||||
(void)memcpy_s(data, msg_size, msg.c_str(), msg_size);
|
||||
message->data = data;
|
||||
|
@ -57,6 +59,36 @@ std::unique_ptr<MessageBase> RDMATest::CreateMessage(const std::string &msg) {
|
|||
/// Description: test basic connection function between RDMA client and server.
|
||||
/// Expectation: RDMA client successfully connects to RDMA server and sends a simple message.
|
||||
TEST_F(RDMATest, TestRDMAConnection) {
  std::string url = "127.0.0.1:10969";

  // fork() returns pid_t and signals failure with -1. Storing the result in an unsigned type would turn that -1 into
  // a huge positive value which is neither detected as an error nor a valid argument for waitpid().
  pid_t server_pid = fork();
  ASSERT_TRUE(server_pid >= 0);
  if (server_pid == 0) {
    // Child process: run the server for a short while, then shut it down.
    std::shared_ptr<RDMAServer> rdma_server = std::make_shared<RDMAServer>();
    MS_EXCEPTION_IF_NULL(rdma_server);
    ASSERT_TRUE(rdma_server->Initialize(url));
    sleep(3);
    rdma_server->Finalize();
    return;
  }

  // Give the server time to start before the client connects.
  sleep(1);
  pid_t client_pid = fork();
  if (client_pid == 0) {
    // Child process: connect to the server and disconnect again.
    std::shared_ptr<RDMAClient> rdma_client = std::make_shared<RDMAClient>();
    MS_EXCEPTION_IF_NULL(rdma_client);
    ASSERT_TRUE(rdma_client->Initialize());
    ASSERT_TRUE(rdma_client->Connect(url));
    rdma_client->Finalize();
    return;
  }

  int wstatus;
  if (client_pid > 0) {
    (void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED);
  }
  // Always reap the server child, even when the client could not be forked.
  (void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED);
  ASSERT_TRUE(client_pid >= 0);
}
|
||||
|
||||
/// Feature: RDMA communication.
|
||||
/// Description: test SendSync interface for RDMA client and server.
|
||||
/// Expectation: RDMA client successfully sends two messages to RDMA server synchronously.
|
||||
TEST_F(RDMATest, TestRDMASendSync) {
|
||||
std::string url = "127.0.0.1:10969";
|
||||
size_t server_pid = fork();
|
||||
if (server_pid == 0) {
|
||||
|
@ -70,6 +102,7 @@ TEST_F(RDMATest, TestRDMAConnection) {
|
|||
};
|
||||
rdma_server->SetMessageHandler(msg_handler);
|
||||
sleep(3);
|
||||
rdma_server->Finalize();
|
||||
return;
|
||||
}
|
||||
sleep(1);
|
||||
|
@ -80,8 +113,54 @@ TEST_F(RDMATest, TestRDMAConnection) {
|
|||
ASSERT_TRUE(rdma_client->Initialize());
|
||||
ASSERT_TRUE(rdma_client->Connect(url));
|
||||
|
||||
auto message = CreateMessage("Hello server!");
|
||||
ASSERT_TRUE(rdma_client->SendSync(std::move(message)));
|
||||
auto message1 = CreateMessage("Hello server sync!");
|
||||
ASSERT_TRUE(rdma_client->SendSync(std::move(message1)));
|
||||
auto message2 = CreateMessage("Hello server sync!");
|
||||
ASSERT_TRUE(rdma_client->SendSync(std::move(message2)));
|
||||
rdma_client->Finalize();
|
||||
return;
|
||||
}
|
||||
|
||||
int wstatus;
|
||||
(void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED);
|
||||
(void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED);
|
||||
}
|
||||
|
||||
/// Feature: RDMA communication.
|
||||
/// Description: test SendAsync interface for RDMA client and server.
|
||||
/// Expectation: RDMA client successfully sends two messages to RDMA server asynchronously.
|
||||
TEST_F(RDMATest, TestRDMASendAsync) {
|
||||
std::string url = "127.0.0.1:10969";
|
||||
size_t server_pid = fork();
|
||||
if (server_pid == 0) {
|
||||
std::shared_ptr<RDMAServer> rdma_server = std::make_shared<RDMAServer>();
|
||||
MS_EXCEPTION_IF_NULL(rdma_server);
|
||||
ASSERT_TRUE(rdma_server->Initialize(url));
|
||||
|
||||
auto msg_handler = [](MessageBase *const msg) {
|
||||
MS_LOG(INFO) << "Receive message from client: " << static_cast<char *>(msg->data);
|
||||
return nullptr;
|
||||
};
|
||||
rdma_server->SetMessageHandler(msg_handler);
|
||||
sleep(3);
|
||||
rdma_server->Finalize();
|
||||
return;
|
||||
}
|
||||
sleep(1);
|
||||
size_t client_pid = fork();
|
||||
if (client_pid == 0) {
|
||||
std::shared_ptr<RDMAClient> rdma_client = std::make_shared<RDMAClient>();
|
||||
MS_EXCEPTION_IF_NULL(rdma_client);
|
||||
ASSERT_TRUE(rdma_client->Initialize());
|
||||
ASSERT_TRUE(rdma_client->Connect(url));
|
||||
|
||||
auto message1 = CreateMessage("Hello server async!");
|
||||
rdma_client->SendAsync(std::move(message1));
|
||||
ASSERT_TRUE(rdma_client->Flush(url));
|
||||
auto message2 = CreateMessage("Hello server async!");
|
||||
rdma_client->SendAsync(std::move(message2));
|
||||
ASSERT_TRUE(rdma_client->Flush(url));
|
||||
|
||||
rdma_client->Finalize();
|
||||
return;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue