Enhance RDMA impl

ZPaC 2023-02-11 11:14:51 +08:00
parent 364e292587
commit fe2567b593
16 changed files with 267 additions and 88 deletions

View File

@ -95,6 +95,10 @@ constexpr char kControlDstOpName[] = "ControlDst";
static const char URL_PROTOCOL_IP_SEPARATOR[] = "://";
static const char URL_IP_PORT_SEPARATOR[] = ":";
constexpr char kEnableRDMA[] = "enable_rdma";
constexpr char kRDMADevName[] = "rdma_dev";
constexpr char kRDMAIP[] = "rdma_ip";
// This macro returns the current timestamp in milliseconds.
#define CURRENT_TIMESTAMP_MILLI \
(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()))
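A minimal usage sketch of the new constants and the timestamp macro (not part of the diff; common::GetEnv is the surrounding MindSpore helper, and the device/IP values are placeholders):

bool use_rdma = common::GetEnv(kEnableRDMA) == "1";
std::string rdma_dev = common::GetEnv(kRDMADevName);  // e.g. "mlx5_0" (placeholder)
std::string rdma_ip = common::GetEnv(kRDMAIP);        // e.g. "192.168.1.10" (placeholder)
auto start = CURRENT_TIMESTAMP_MILLI;
// ... do some work ...
auto elapsed_ms = (CURRENT_TIMESTAMP_MILLI - start).count();  // elapsed milliseconds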

View File

@ -18,7 +18,9 @@
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_
#include <urpc.h>
#include <mutex>
#include <string>
#include <condition_variable>
#include "utils/dlopen_macro.h"
#include "distributed/constants.h"
@ -114,8 +116,10 @@ constexpr uint32_t kServerWorkingThreadNum = 4;
constexpr uint32_t kClientPollingThreadNum = 4;
struct req_cb_arg {
int *rsp_received;
bool rsp_received;
struct urpc_buffer_allocator *allocator;
std::mutex *mtx;
std::condition_variable *cv;
};
} // namespace rpc
} // namespace distributed
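A minimal sketch of the wait/notify pattern the reworked req_cb_arg enables (an assumption about intended usage, mirroring the client code below):

std::mutex mtx;
std::condition_variable cv;
struct req_cb_arg cb_arg = {};
cb_arg.rsp_received = false;
cb_arg.mtx = &mtx;
cb_arg.cv = &cv;

// Callback thread: mark the response as received and wake the waiter.
{
  std::unique_lock<std::mutex> lock(*cb_arg.mtx);
  cb_arg.rsp_received = true;
  cb_arg.cv->notify_all();
}

// Waiting thread: block until the callback has run.
{
  std::unique_lock<std::mutex> lock(mtx);
  cv.wait(lock, [&cb_arg]() { return cb_arg.rsp_received; });
}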

View File

@ -21,7 +21,10 @@
namespace mindspore {
namespace distributed {
namespace rpc {
bool RDMAClient::Initialize() { return true; }
bool RDMAClient::Initialize() {
// The initialization of URPC RDMA is implemented in the 'Connect' function.
return true;
}
void RDMAClient::Finalize() {
if (urpc_session_ != nullptr) {
@ -63,7 +66,13 @@ bool RDMAClient::Connect(const std::string &dst_url, size_t retry_count, const M
bool RDMAClient::IsConnected(const std::string &dst_url) { return false; }
bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; }
bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
if (urpc_session_ != nullptr) {
urpc_close_func(urpc_session_);
urpc_session_ = nullptr;
}
return true;
}
bool RDMAClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) {
MS_EXCEPTION_IF_NULL(msg);
@ -77,34 +86,59 @@ bool RDMAClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send
sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
sgl.sge_num = 1;
int rsp_received = 0;
struct req_cb_arg cb_arg = {};
cb_arg.rsp_received = &rsp_received;
cb_arg.allocator = urpc_allocator_;
struct urpc_send_wr send_wr = {};
send_wr.func_id = kInterProcessDataHandleID;
send_wr.send_mode = URPC_SEND_MODE_SYNC;
send_wr.req = &sgl;
struct urpc_sgl rsp_sgl = {0};
send_wr.sync.rsp = &rsp_sgl;
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) {
MS_LOG(ERROR) << "Failed to send request for function call: " << send_wr.func_id;
return false;
}
MS_LOG(INFO) << "Server response message is " << reinterpret_cast<char *>(rsp_sgl.sge[0].addr);
return true;
}
void RDMAClient::SendAsync(std::unique_ptr<MessageBase> &&msg) {
MS_EXCEPTION_IF_NULL(msg);
size_t msg_size = msg->size;
void *msg_buf = msg->data;
MS_EXCEPTION_IF_NULL(msg_buf);
struct urpc_sgl sgl;
sgl.sge[0].addr = reinterpret_cast<uintptr_t>(msg_buf);
sgl.sge[0].length = msg_size;
sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
sgl.sge_num = 1;
std::unique_lock<std::mutex> lock(mtx_);
cb_arg_.rsp_received = false;
cb_arg_.allocator = urpc_allocator_;
cb_arg_.mtx = &mtx_;
cb_arg_.cv = &cv_;
lock.unlock();
struct urpc_send_wr send_wr = {};
send_wr.func_id = kInterProcessDataHandleID;
send_wr.send_mode = URPC_SEND_MODE_ASYNC;
send_wr.req = &sgl;
send_wr.async.cb.wo_ctx = urpc_rsp_cb;
send_wr.async.cb_arg = &cb_arg;
send_wr.async.cb_arg = &cb_arg_;
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) != kURPCSuccess) {
MS_LOG(ERROR) << "Failed to send request.";
return false;
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) {
MS_LOG(EXCEPTION) << "Failed to send request to server.";
return;
}
size_t sleep_time_us = 200000;
// Wait till server responds.
while (rsp_received == 0) {
usleep(sleep_time_us);
}
return true;
}
void RDMAClient::SendAsync(std::unique_ptr<MessageBase> &&msg) {}
bool RDMAClient::Flush(const std::string &dst_url) { return true; }
bool RDMAClient::Flush(const std::string &dst_url) {
std::unique_lock<std::mutex> lock(mtx_);
cv_.wait(lock, [this]() { return cb_arg_.rsp_received; });
return true;
}
void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) {
MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
@ -116,7 +150,9 @@ void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) {
MS_LOG(INFO) << "Server response message is " << reinterpret_cast<char *>(rsp->sge[0].addr);
struct req_cb_arg *cb_arg = static_cast<struct req_cb_arg *>(arg);
*(cb_arg->rsp_received) = 1;
std::unique_lock<std::mutex> lock(*(cb_arg->mtx));
cb_arg->rsp_received = true;
cb_arg->cv->notify_all();
}
} // namespace rpc
} // namespace distributed
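A caller-side sketch of the two send paths after this change (an assumption about usage; the url and message contents are placeholders, and the unit tests below follow the same pattern):

RDMAClient client;
client.Initialize();
client.Connect("127.0.0.1:10969");

auto sync_msg = std::make_unique<MessageBase>();  // fill data/size before sending
client.SendSync(std::move(sync_msg));             // blocks until the server responds

auto async_msg = std::make_unique<MessageBase>();
client.SendAsync(std::move(async_msg));           // returns immediately
client.Flush("127.0.0.1:10969");                  // waits on cv_ until urpc_rsp_cb sets rsp_received

client.Finalize();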

View File

@ -65,6 +65,12 @@ class BACKEND_EXPORT RDMAClient : public RPCClientBase {
struct urpc_buffer_allocator *urpc_allocator_;
urpc_session_t *urpc_session_;
// The variables for synchronization of async messages.
std::mutex mtx_;
std::condition_variable cv_;
struct req_cb_arg cb_arg_;
};
} // namespace rpc
} // namespace distributed

View File

@ -31,21 +31,18 @@ bool RDMAServer::Initialize(const std::string &url, const MemAllocateCallback &a
port_ = port;
// Init URPC for RDMA server.
struct urpc_config urpc_cfg = {};
urpc_cfg.mode = URPC_MODE_SERVER;
urpc_cfg.sfeature = 0;
urpc_cfg.model = URPC_THREAD_MODEL_R2C;
urpc_cfg.worker_num = kServerWorkingThreadNum;
urpc_cfg.transport.dev_name = dev_name_;
urpc_cfg.transport.ip_addr = ip_addr_;
urpc_cfg.transport.port = port_;
urpc_cfg.transport.max_sge = 0;
urpc_cfg.allocator = nullptr;
if (urpc_init_func(&urpc_cfg) != kURPCSuccess) {
MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_
<< ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc.";
}
return true;
return InitializeURPC();
}
bool RDMAServer::Initialize(const MemAllocateCallback &allocate_cb) {
dev_name_ = const_cast<char *>(common::GetEnv(kRDMADevName).c_str());
ip_addr_ = const_cast<char *>(common::GetEnv(kRDMAIP).c_str());
port_ = 0;
MS_LOG(INFO) << "Initialize RDMA server. Device name: " << dev_name_ << ", ip address: " << ip_addr_
<< ", port: " << port_;
// Init URPC for RDMA server.
return InitializeURPC();
}
void RDMAServer::Finalize() {
@ -67,9 +64,27 @@ void RDMAServer::SetMessageHandler(const MessageHandler &handler) {
}
}
std::string RDMAServer::GetIP() const { return ""; }
std::string RDMAServer::GetIP() const { return ip_addr_; }
uint32_t RDMAServer::GetPort() const { return 0; }
uint32_t RDMAServer::GetPort() const { return static_cast<uint32_t>(port_); }
bool RDMAServer::InitializeURPC() {
struct urpc_config urpc_cfg = {};
urpc_cfg.mode = URPC_MODE_SERVER;
urpc_cfg.sfeature = 0;
urpc_cfg.model = URPC_THREAD_MODEL_R2C;
urpc_cfg.worker_num = kServerWorkingThreadNum;
urpc_cfg.transport.dev_name = dev_name_;
urpc_cfg.transport.ip_addr = ip_addr_;
urpc_cfg.transport.port = port_;
urpc_cfg.transport.max_sge = 0;
urpc_cfg.allocator = nullptr;
if (urpc_init_func(&urpc_cfg) != kURPCSuccess) {
MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_
<< ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc.";
}
return true;
}
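A short sketch of the two initialization paths that now share InitializeURPC (an assumption about usage; the device name and address below are placeholders):

RDMAServer url_server;
url_server.Initialize("192.168.1.10:10969");  // ip/port presumably taken from the url, then InitializeURPC()

// Without a url, the device and ip come from the environment (keys defined in the new constants):
//   rdma_dev=mlx5_0
//   rdma_ip=192.168.1.10
RDMAServer env_server;
env_server.Initialize();  // reads kRDMADevName/kRDMAIP, port_ = 0, then InitializeURPC()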
void RDMAServer::urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp) {
MS_ERROR_IF_NULL_WO_RET_VAL(req);

View File

@ -37,6 +37,7 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase {
~RDMAServer() override = default;
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override;
bool Initialize(const MemAllocateCallback &allocate_cb = {}) override;
void Finalize() override;
void SetMessageHandler(const MessageHandler &handler) override;
@ -47,6 +48,9 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase {
struct urpc_buffer_allocator *urpc_allocator_;
private:
// Initialize urpc configuration according to dev_name_, ip_addr_ and port_.
bool InitializeURPC();
// The message callback for urpc. This method will call the user-set message handler.
static void urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp);
// The callback invoked after this server responds.

View File

@ -27,7 +27,7 @@ namespace distributed {
namespace rpc {
class BACKEND_EXPORT RPCClientBase {
public:
explicit RPCClientBase(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
explicit RPCClientBase(bool enable_ssl) : enable_ssl_(enable_ssl) {}
virtual ~RPCClientBase() = default;
// Build or destroy the rpc client.
@ -36,20 +36,29 @@ class BACKEND_EXPORT RPCClientBase {
// Connect to the specified server.
// Function free_cb is bound to each of the client's connections. It frees the real memory after the message is sent to the peer.
virtual bool Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) { return true; }
virtual bool Connect(
const std::string &dst_url, size_t retry_count = 60, const MemFreeCallback &free_cb = [](void *data) {
MS_ERROR_IF_NULL(data);
delete static_cast<char *>(data);
return true;
}) {
return true;
}
// Check if the connection to dst_url has been established.
virtual bool IsConnected(const std::string &dst_url) { return false; }
// Disconnect from the specified server.
virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; }
virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) { return true; }
// Send the message from the source to the destination synchronously and return the byte size by this method call.
virtual bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) { return true; }
virtual bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) { return true; }
// Send the message from the source to the destination asynchronously.
virtual void SendAsync(std::unique_ptr<MessageBase> &&msg) {}
virtual MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30) { return nullptr; }
// Force the data in the send buffer to be sent out.
virtual bool Flush(const std::string &dst_url) { return true; }
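The defaults added here are repeated in the overrides in tcp_client.h below. A brief sketch of why that matters (an assumption about the intent): default arguments on virtual functions are resolved from the static type of the pointer, so a call through RPCClientBase* uses these base-class defaults even when it dispatches to a derived override; keeping base and derived defaults identical makes both call styles behave the same. The url below is a placeholder.

RPCClientBase *client = new TCPClient();
client->Initialize();
client->Connect("127.0.0.1:8080");  // retry_count = 60 and the default free_cb come from RPCClientBase
client->Finalize();
delete client;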

View File

@ -31,10 +31,10 @@ class BACKEND_EXPORT RPCServerBase {
virtual ~RPCServerBase() = default;
// Init server using the specified url, with memory allocating function.
virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) { return true; }
virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) { return true; }
// Init server using local IP and random port.
virtual bool Initialize(const MemAllocateCallback &allocate_cb) { return true; }
virtual bool Initialize(const MemAllocateCallback &allocate_cb = {}) { return true; }
// Destroy the tcp server.
virtual void Finalize() {}

View File

@ -22,6 +22,7 @@
#include <mutex>
#include <condition_variable>
#include "distributed/rpc/rpc_client_base.h"
#include "distributed/rpc/tcp/tcp_comm.h"
#include "utils/ms_utils.h"
#include "include/backend/visible.h"
@ -29,14 +30,14 @@
namespace mindspore {
namespace distributed {
namespace rpc {
class BACKEND_EXPORT TCPClient {
class BACKEND_EXPORT TCPClient : public RPCClientBase {
public:
explicit TCPClient(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
~TCPClient() = default;
explicit TCPClient(bool enable_ssl = false) : RPCClientBase(enable_ssl) {}
~TCPClient() override = default;
// Build or destroy the TCP client.
bool Initialize();
void Finalize();
bool Initialize() override;
void Finalize() override;
// Connect to the specified server.
// Function free_cb is bound to each of the client's connections. It frees the real memory after the message is sent to the peer.
@ -45,26 +46,26 @@ class BACKEND_EXPORT TCPClient {
MS_ERROR_IF_NULL(data);
delete static_cast<char *>(data);
return true;
});
}) override;
// Check if the connection to dst_url has been established.
bool IsConnected(const std::string &dst_url);
bool IsConnected(const std::string &dst_url) override;
// Disconnect from the specified server.
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5);
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) override;
// Send the message from the source to the destination synchronously and return the byte size by this method call.
bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr);
bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) override;
// Send the message from the source to the destination asynchronously.
void SendAsync(std::unique_ptr<MessageBase> &&msg);
void SendAsync(std::unique_ptr<MessageBase> &&msg) override;
// Retrieve a message from tcp server specified by the input message.
// Returns nullptr after timeout.
MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30);
MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30) override;
// Force the data in the send buffer to be sent out.
bool Flush(const std::string &dst_url);
bool Flush(const std::string &dst_url) override;
private:
// The basic TCP communication component used by the client.
@ -78,8 +79,6 @@ class BACKEND_EXPORT TCPClient {
// The received message from the meta server by calling the method `ReceiveSync`.
MessageBase *received_message_{nullptr};
bool enable_ssl_;
DISABLE_COPY_AND_ASSIGN(TCPClient);
};
} // namespace rpc

View File

@ -20,6 +20,7 @@
#include <string>
#include <memory>
#include "distributed/rpc/rpc_server_base.h"
#include "distributed/rpc/tcp/tcp_comm.h"
#include "utils/ms_utils.h"
#include "include/backend/visible.h"
@ -27,26 +28,26 @@
namespace mindspore {
namespace distributed {
namespace rpc {
class BACKEND_EXPORT TCPServer {
class BACKEND_EXPORT TCPServer : public RPCServerBase {
public:
explicit TCPServer(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
~TCPServer() = default;
explicit TCPServer(bool enable_ssl = false) : RPCServerBase(enable_ssl) {}
~TCPServer() override = default;
// Init the tcp server using the specified url.
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {});
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override;
// Init the tcp server using local IP and random port.
bool Initialize(const MemAllocateCallback &allocate_cb = {});
bool Initialize(const MemAllocateCallback &allocate_cb = {}) override;
// Destroy the tcp server.
void Finalize();
void Finalize() override;
// Set the message processing handler.
void SetMessageHandler(const MessageHandler &handler);
void SetMessageHandler(const MessageHandler &handler) override;
// Return the IP and port bound by this server.
std::string GetIP() const;
uint32_t GetPort() const;
std::string GetIP() const override;
uint32_t GetPort() const override;
private:
bool InitializeImpl(const std::string &url, const MemAllocateCallback &allocate_cb);
@ -54,11 +55,6 @@ class BACKEND_EXPORT TCPServer {
// The basic TCP communication component used by the server.
std::unique_ptr<TCPComm> tcp_comm_{nullptr};
std::string ip_{""};
uint32_t port_{0};
bool enable_ssl_;
DISABLE_COPY_AND_ASSIGN(TCPServer);
};
} // namespace rpc

View File

@ -21,7 +21,6 @@
#include <functional>
#include <condition_variable>
#include "proto/topology.pb.h"
#include "distributed/rpc/tcp/constants.h"
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
#include "backend/common/optimizer/helper.h"
@ -62,15 +61,24 @@ void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &r
}
bool RecvActor::StartServer() {
// Step 1: Create a tcp server and start listening.
// Step 1: Create an rpc server and start listening.
#ifdef ENABLE_RDMA
if (common::GetEnv(kEnableRDMA) == "1") {
server_ = std::make_unique<RDMAServer>();
} else {
server_ = std::make_unique<TCPServer>();
}
#else
server_ = std::make_unique<TCPServer>();
#endif
MS_EXCEPTION_IF_NULL(server_);
// Set the memory allocating callback using void* message.
std::function<void *(size_t size)> allocate_callback =
std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
if (!server_->Initialize(allocate_callback)) {
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor";
MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor";
}
ip_ = server_->GetIP();
port_ = server_->GetPort();
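The switch above is driven entirely by environment variables. A hedged configuration sketch for selecting the RDMA path (the device and address values are placeholders):

// enable_rdma=1          -> kEnableRDMA, turns the RDMA branch on
// rdma_dev=mlx5_0        -> kRDMADevName, read by RDMAServer::Initialize()
// rdma_ip=192.168.1.10   -> kRDMAIP, read by RDMAServer::Initialize()
// With enable_rdma unset or not equal to "1", the recv actor keeps using TCPServer as before.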

View File

@ -62,7 +62,7 @@ class RecvActor : public RpcActor {
// Start recv actor server and register this server address to actor route table in scheduler by proxy.
bool StartServer();
// Finalize tcp server.
// Finalize rpc server.
void Clear() override;
void StopRpcAtException() override;
@ -102,7 +102,7 @@ class RecvActor : public RpcActor {
*/
void *AllocateMemByDeviceRes(size_t size);
std::unique_ptr<TCPServer> server_;
std::unique_ptr<RPCServerBase> server_;
// The variables used to ensure thread-safe of op context visited by recv actor.
bool is_context_valid_;
@ -129,7 +129,7 @@ class RecvActor : public RpcActor {
// it, e.g., infer shape for RpcRecv kernel and call Resize().
void PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize);
// The message callback of the tcp server.
// The message callback of the rpc server.
MessageBase *HandleMessage(MessageBase *const msg);
// The network address of this recv actor. It's generated automatically by rpc module.

View File

@ -26,6 +26,10 @@
#include "distributed/cluster/cluster_context.h"
#include "distributed/rpc/tcp/tcp_client.h"
#include "distributed/rpc/tcp/tcp_server.h"
#ifdef ENABLE_RDMA
#include "distributed/rpc/rdma/rdma_client.h"
#include "distributed/rpc/rdma/rdma_server.h"
#endif
#include "proto/rpc.pb.h"
#include "proto/topology.pb.h"
@ -35,9 +39,15 @@ using distributed::cluster::ActorRouteTableProxy;
using distributed::cluster::ActorRouteTableProxyPtr;
using distributed::cluster::ClusterContext;
using distributed::cluster::topology::ActorAddress;
using distributed::rpc::RPCClientBase;
using distributed::rpc::RPCServerBase;
using distributed::rpc::TCPClient;
using distributed::rpc::TCPServer;
using mindspore::device::KernelInfo;
#ifdef ENABLE_RDMA
using distributed::rpc::RDMAClient;
using distributed::rpc::RDMAServer;
#endif
// The inter-process edge mark between two nodes.
constexpr char kInterProcessEdgeMark[] = "->";

View File

@ -27,7 +27,7 @@ SendActor::~SendActor() {
(void)client_->Disconnect(server_url_);
client_->Finalize();
} catch (const std::exception &) {
MS_LOG(ERROR) << "Failed to disconnect and finalize for tcp client in send actor.";
MS_LOG(ERROR) << "Failed to disconnect and finalize for rpc client in send actor.";
}
client_ = nullptr;
}
@ -40,10 +40,19 @@ void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &s
}
bool SendActor::ConnectServer() {
#ifdef ENABLE_RDMA
if (common::GetEnv(kEnableRDMA) == "1") {
client_ = std::make_unique<RDMAClient>();
} else {
client_ = std::make_unique<TCPClient>();
}
#else
client_ = std::make_unique<TCPClient>();
#endif
MS_EXCEPTION_IF_NULL(client_);
if (!client_->Initialize()) {
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for send actor.";
MS_LOG(EXCEPTION) << "Failed to initialize rpc server for send actor.";
}
// Lookup actor addresses for each peer actor.
for (const auto &peer_actor_id : peer_actor_ids_) {

View File

@ -49,7 +49,7 @@ class SendActor : public RpcActor {
// Flush and wait for sent data to be passed to kernel.
void FlushData() override;
// Finalize tcp client.
// Finalize rpc client.
void Clear() override;
protected:
@ -76,8 +76,8 @@ class SendActor : public RpcActor {
*/
virtual void Flush();
// The tcp client connection to multiple servers.
std::unique_ptr<TCPClient> client_;
// The rpc client connection to multiple servers.
std::unique_ptr<RPCClientBase> client_;
private:
/**
@ -136,7 +136,7 @@ class SendActor : public RpcActor {
std::vector<std::string> peer_actor_ids_;
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
// The url of the peer recv actor's tcp server.
// The url of the peer recv actor's server.
std::string server_url_;
};

View File

@ -45,7 +45,9 @@ class RDMATest : public UT::Common {
std::unique_ptr<MessageBase> RDMATest::CreateMessage(const std::string &msg) {
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
size_t msg_size = msg.size();
ASSERT_TRUE(msg_size != 0);
if (msg_size == 0) {
MS_LOG(EXCEPTION) << "msg_size is 0!";
}
void *data = malloc(msg_size + 1);
(void)memcpy_s(data, msg_size, msg.c_str(), msg_size);
message->data = data;
@ -57,6 +59,36 @@ std::unique_ptr<MessageBase> RDMATest::CreateMessage(const std::string &msg) {
/// Description: test basic connection function between RDMA client and server.
/// Expectation: RDMA client successfully connects to RDMA server and sends a simple message.
TEST_F(RDMATest, TestRDMAConnection) {
std::string url = "127.0.0.1:10969";
size_t server_pid = fork();
if (server_pid == 0) {
std::shared_ptr<RDMAServer> rdma_server = std::make_shared<RDMAServer>();
MS_EXCEPTION_IF_NULL(rdma_server);
ASSERT_TRUE(rdma_server->Initialize(url));
sleep(3);
rdma_server->Finalize();
return;
}
sleep(1);
size_t client_pid = fork();
if (client_pid == 0) {
std::shared_ptr<RDMAClient> rdma_client = std::make_shared<RDMAClient>();
MS_EXCEPTION_IF_NULL(rdma_client);
ASSERT_TRUE(rdma_client->Initialize());
ASSERT_TRUE(rdma_client->Connect(url));
rdma_client->Finalize();
return;
}
int wstatus;
(void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED);
(void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED);
}
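A small hedged note on the fork pattern used in these tests: fork() returns pid_t (negative on failure), so a slightly more defensive variant of the child/parent split would look like this (sketch only, not part of the diff):

pid_t server_pid = fork();
if (server_pid < 0) {
  MS_LOG(EXCEPTION) << "fork for the RDMA server process failed.";
} else if (server_pid == 0) {
  // Child process: run the RDMA server, then _exit so gtest teardown is not run twice.
  _exit(0);
}
// Parent process: continue with the client fork and waitpid as above.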
/// Feature: RDMA communication.
/// Description: test SendSync interface for RDMA client and server.
/// Expectation: RDMA client successfully sends two messages to RDMA server synchronously.
TEST_F(RDMATest, TestRDMASendSync) {
std::string url = "127.0.0.1:10969";
size_t server_pid = fork();
if (server_pid == 0) {
@ -70,6 +102,7 @@ TEST_F(RDMATest, TestRDMAConnection) {
};
rdma_server->SetMessageHandler(msg_handler);
sleep(3);
rdma_server->Finalize();
return;
}
sleep(1);
@ -80,8 +113,54 @@ TEST_F(RDMATest, TestRDMAConnection) {
ASSERT_TRUE(rdma_client->Initialize());
ASSERT_TRUE(rdma_client->Connect(url));
auto message = CreateMessage("Hello server!");
ASSERT_TRUE(rdma_client->SendSync(std::move(message)));
auto message1 = CreateMessage("Hello server sync!");
ASSERT_TRUE(rdma_client->SendSync(std::move(message1)));
auto message2 = CreateMessage("Hello server sync!");
ASSERT_TRUE(rdma_client->SendSync(std::move(message2)));
rdma_client->Finalize();
return;
}
int wstatus;
(void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED);
(void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED);
}
/// Feature: RDMA communication.
/// Description: test SendAsync interface for RDMA client and server.
/// Expectation: RDMA client successfully sends two messages to RDMA server asynchronously.
TEST_F(RDMATest, TestRDMASendAsync) {
std::string url = "127.0.0.1:10969";
size_t server_pid = fork();
if (server_pid == 0) {
std::shared_ptr<RDMAServer> rdma_server = std::make_shared<RDMAServer>();
MS_EXCEPTION_IF_NULL(rdma_server);
ASSERT_TRUE(rdma_server->Initialize(url));
auto msg_handler = [](MessageBase *const msg) {
MS_LOG(INFO) << "Receive message from client: " << static_cast<char *>(msg->data);
return nullptr;
};
rdma_server->SetMessageHandler(msg_handler);
sleep(3);
rdma_server->Finalize();
return;
}
sleep(1);
size_t client_pid = fork();
if (client_pid == 0) {
std::shared_ptr<RDMAClient> rdma_client = std::make_shared<RDMAClient>();
MS_EXCEPTION_IF_NULL(rdma_client);
ASSERT_TRUE(rdma_client->Initialize());
ASSERT_TRUE(rdma_client->Connect(url));
auto message1 = CreateMessage("Hello server async!");
rdma_client->SendAsync(std::move(message1));
ASSERT_TRUE(rdma_client->Flush(url));
auto message2 = CreateMessage("Hello server async!");
rdma_client->SendAsync(std::move(message2));
ASSERT_TRUE(rdma_client->Flush(url));
rdma_client->Finalize();
return;
}