forked from mindspore-Ecosystem/mindspore
Enhance rdma impl
This commit is contained in:
parent
364e292587
commit
fe2567b593
|
@ -95,6 +95,10 @@ constexpr char kControlDstOpName[] = "ControlDst";
|
||||||
static const char URL_PROTOCOL_IP_SEPARATOR[] = "://";
|
static const char URL_PROTOCOL_IP_SEPARATOR[] = "://";
|
||||||
static const char URL_IP_PORT_SEPARATOR[] = ":";
|
static const char URL_IP_PORT_SEPARATOR[] = ":";
|
||||||
|
|
||||||
|
constexpr char kEnableRDMA[] = "enable_rdma";
|
||||||
|
constexpr char kRDMADevName[] = "rdma_dev";
|
||||||
|
constexpr char kRDMAIP[] = "rdma_ip";
|
||||||
|
|
||||||
// This macro the current timestamp in milliseconds.
|
// This macro the current timestamp in milliseconds.
|
||||||
#define CURRENT_TIMESTAMP_MILLI \
|
#define CURRENT_TIMESTAMP_MILLI \
|
||||||
(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()))
|
(std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()))
|
||||||
|
|
|
@ -18,7 +18,9 @@
|
||||||
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_
|
#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_
|
||||||
|
|
||||||
#include <urpc.h>
|
#include <urpc.h>
|
||||||
|
#include <mutex>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <condition_variable>
|
||||||
|
|
||||||
#include "utils/dlopen_macro.h"
|
#include "utils/dlopen_macro.h"
|
||||||
#include "distributed/constants.h"
|
#include "distributed/constants.h"
|
||||||
|
@ -114,8 +116,10 @@ constexpr uint32_t kServerWorkingThreadNum = 4;
|
||||||
constexpr uint32_t kClientPollingThreadNum = 4;
|
constexpr uint32_t kClientPollingThreadNum = 4;
|
||||||
|
|
||||||
struct req_cb_arg {
|
struct req_cb_arg {
|
||||||
int *rsp_received;
|
bool rsp_received;
|
||||||
struct urpc_buffer_allocator *allocator;
|
struct urpc_buffer_allocator *allocator;
|
||||||
|
std::mutex *mtx;
|
||||||
|
std::condition_variable *cv;
|
||||||
};
|
};
|
||||||
} // namespace rpc
|
} // namespace rpc
|
||||||
} // namespace distributed
|
} // namespace distributed
|
||||||
|
|
|
@ -21,7 +21,10 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
namespace rpc {
|
namespace rpc {
|
||||||
bool RDMAClient::Initialize() { return true; }
|
bool RDMAClient::Initialize() {
|
||||||
|
// The initialization of URPC RDMA is implemented in 'Connect' function.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void RDMAClient::Finalize() {
|
void RDMAClient::Finalize() {
|
||||||
if (urpc_session_ != nullptr) {
|
if (urpc_session_ != nullptr) {
|
||||||
|
@ -63,7 +66,13 @@ bool RDMAClient::Connect(const std::string &dst_url, size_t retry_count, const M
|
||||||
|
|
||||||
bool RDMAClient::IsConnected(const std::string &dst_url) { return false; }
|
bool RDMAClient::IsConnected(const std::string &dst_url) { return false; }
|
||||||
|
|
||||||
bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; }
|
bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) {
|
||||||
|
if (urpc_session_ != nullptr) {
|
||||||
|
urpc_close_func(urpc_session_);
|
||||||
|
urpc_session_ = nullptr;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool RDMAClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) {
|
bool RDMAClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) {
|
||||||
MS_EXCEPTION_IF_NULL(msg);
|
MS_EXCEPTION_IF_NULL(msg);
|
||||||
|
@ -77,34 +86,59 @@ bool RDMAClient::SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send
|
||||||
sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
|
sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
|
||||||
sgl.sge_num = 1;
|
sgl.sge_num = 1;
|
||||||
|
|
||||||
int rsp_received = 0;
|
struct urpc_send_wr send_wr = {};
|
||||||
struct req_cb_arg cb_arg = {};
|
send_wr.func_id = kInterProcessDataHandleID;
|
||||||
cb_arg.rsp_received = &rsp_received;
|
send_wr.send_mode = URPC_SEND_MODE_SYNC;
|
||||||
cb_arg.allocator = urpc_allocator_;
|
send_wr.req = &sgl;
|
||||||
|
struct urpc_sgl rsp_sgl = {0};
|
||||||
|
send_wr.sync.rsp = &rsp_sgl;
|
||||||
|
|
||||||
|
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) {
|
||||||
|
MS_LOG(ERROR) << "Failed to send request for function call: " << send_wr.func_id;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Server response message is " << reinterpret_cast<char *>(rsp_sgl.sge[0].addr);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RDMAClient::SendAsync(std::unique_ptr<MessageBase> &&msg) {
|
||||||
|
MS_EXCEPTION_IF_NULL(msg);
|
||||||
|
size_t msg_size = msg->size;
|
||||||
|
void *msg_buf = msg->data;
|
||||||
|
MS_EXCEPTION_IF_NULL(msg_buf);
|
||||||
|
|
||||||
|
struct urpc_sgl sgl;
|
||||||
|
sgl.sge[0].addr = reinterpret_cast<uintptr_t>(msg_buf);
|
||||||
|
sgl.sge[0].length = msg_size;
|
||||||
|
sgl.sge[0].flag = URPC_SGE_FLAG_ZERO_COPY;
|
||||||
|
sgl.sge_num = 1;
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
|
cb_arg_.rsp_received = false;
|
||||||
|
cb_arg_.allocator = urpc_allocator_;
|
||||||
|
cb_arg_.mtx = &mtx_;
|
||||||
|
cb_arg_.cv = &cv_;
|
||||||
|
lock.unlock();
|
||||||
|
|
||||||
struct urpc_send_wr send_wr = {};
|
struct urpc_send_wr send_wr = {};
|
||||||
send_wr.func_id = kInterProcessDataHandleID;
|
send_wr.func_id = kInterProcessDataHandleID;
|
||||||
send_wr.send_mode = URPC_SEND_MODE_ASYNC;
|
send_wr.send_mode = URPC_SEND_MODE_ASYNC;
|
||||||
send_wr.req = &sgl;
|
send_wr.req = &sgl;
|
||||||
send_wr.async.cb.wo_ctx = urpc_rsp_cb;
|
send_wr.async.cb.wo_ctx = urpc_rsp_cb;
|
||||||
send_wr.async.cb_arg = &cb_arg;
|
send_wr.async.cb_arg = &cb_arg_;
|
||||||
|
|
||||||
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) != kURPCSuccess) {
|
if (urpc_send_request_func(urpc_session_, &send_wr, nullptr) < 0) {
|
||||||
MS_LOG(ERROR) << "Failed to send request.";
|
MS_LOG(EXCEPTION) << "Failed to send request to server.";
|
||||||
return false;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t sleep_time_us = 200000;
|
|
||||||
// Wait till server responds.
|
|
||||||
while (rsp_received == 0) {
|
|
||||||
usleep(sleep_time_us);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void RDMAClient::SendAsync(std::unique_ptr<MessageBase> &&msg) {}
|
bool RDMAClient::Flush(const std::string &dst_url) {
|
||||||
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
bool RDMAClient::Flush(const std::string &dst_url) { return true; }
|
cv_.wait(lock, [this]() { return cb_arg_.rsp_received; });
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) {
|
void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
|
MS_ERROR_IF_NULL_WO_RET_VAL(rsp);
|
||||||
|
@ -116,7 +150,9 @@ void RDMAClient::urpc_rsp_cb(struct urpc_sgl *rsp, int err, void *arg) {
|
||||||
|
|
||||||
MS_LOG(INFO) << "Server response message is " << reinterpret_cast<char *>(rsp->sge[0].addr);
|
MS_LOG(INFO) << "Server response message is " << reinterpret_cast<char *>(rsp->sge[0].addr);
|
||||||
struct req_cb_arg *cb_arg = static_cast<struct req_cb_arg *>(arg);
|
struct req_cb_arg *cb_arg = static_cast<struct req_cb_arg *>(arg);
|
||||||
*(cb_arg->rsp_received) = 1;
|
std::unique_lock<std::mutex> lock(*(cb_arg->mtx));
|
||||||
|
cb_arg->rsp_received = true;
|
||||||
|
cb_arg->cv->notify_all();
|
||||||
}
|
}
|
||||||
} // namespace rpc
|
} // namespace rpc
|
||||||
} // namespace distributed
|
} // namespace distributed
|
||||||
|
|
|
@ -65,6 +65,12 @@ class BACKEND_EXPORT RDMAClient : public RPCClientBase {
|
||||||
|
|
||||||
struct urpc_buffer_allocator *urpc_allocator_;
|
struct urpc_buffer_allocator *urpc_allocator_;
|
||||||
urpc_session_t *urpc_session_;
|
urpc_session_t *urpc_session_;
|
||||||
|
|
||||||
|
// The variables for synchronization of async messages.
|
||||||
|
std::mutex mtx_;
|
||||||
|
std::condition_variable cv_;
|
||||||
|
|
||||||
|
struct req_cb_arg cb_arg_;
|
||||||
};
|
};
|
||||||
} // namespace rpc
|
} // namespace rpc
|
||||||
} // namespace distributed
|
} // namespace distributed
|
||||||
|
|
|
@ -31,21 +31,18 @@ bool RDMAServer::Initialize(const std::string &url, const MemAllocateCallback &a
|
||||||
port_ = port;
|
port_ = port;
|
||||||
|
|
||||||
// Init URPC for RMDA server.
|
// Init URPC for RMDA server.
|
||||||
struct urpc_config urpc_cfg = {};
|
return InitializeURPC();
|
||||||
urpc_cfg.mode = URPC_MODE_SERVER;
|
}
|
||||||
urpc_cfg.sfeature = 0;
|
|
||||||
urpc_cfg.model = URPC_THREAD_MODEL_R2C;
|
bool RDMAServer::Initialize(const MemAllocateCallback &allocate_cb) {
|
||||||
urpc_cfg.worker_num = kServerWorkingThreadNum;
|
dev_name_ = const_cast<char *>(common::GetEnv(kRDMADevName).c_str());
|
||||||
urpc_cfg.transport.dev_name = dev_name_;
|
ip_addr_ = const_cast<char *>(common::GetEnv(kRDMAIP).c_str());
|
||||||
urpc_cfg.transport.ip_addr = ip_addr_;
|
port_ = 0;
|
||||||
urpc_cfg.transport.port = port_;
|
MS_LOG(INFO) << "Initialize RDMA server. Device name: " << dev_name_ << ", ip address: " << ip_addr_
|
||||||
urpc_cfg.transport.max_sge = 0;
|
<< ", port: " << port_;
|
||||||
urpc_cfg.allocator = nullptr;
|
|
||||||
if (urpc_init_func(&urpc_cfg) != kURPCSuccess) {
|
// Init URPC for RMDA server.
|
||||||
MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_
|
return InitializeURPC();
|
||||||
<< ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc.";
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void RDMAServer::Finalize() {
|
void RDMAServer::Finalize() {
|
||||||
|
@ -67,9 +64,27 @@ void RDMAServer::SetMessageHandler(const MessageHandler &handler) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string RDMAServer::GetIP() const { return ""; }
|
std::string RDMAServer::GetIP() const { return ip_addr_; }
|
||||||
|
|
||||||
uint32_t RDMAServer::GetPort() const { return 0; }
|
uint32_t RDMAServer::GetPort() const { return static_cast<uint32_t>(port_); }
|
||||||
|
|
||||||
|
bool RDMAServer::InitializeURPC() {
|
||||||
|
struct urpc_config urpc_cfg = {};
|
||||||
|
urpc_cfg.mode = URPC_MODE_SERVER;
|
||||||
|
urpc_cfg.sfeature = 0;
|
||||||
|
urpc_cfg.model = URPC_THREAD_MODEL_R2C;
|
||||||
|
urpc_cfg.worker_num = kServerWorkingThreadNum;
|
||||||
|
urpc_cfg.transport.dev_name = dev_name_;
|
||||||
|
urpc_cfg.transport.ip_addr = ip_addr_;
|
||||||
|
urpc_cfg.transport.port = port_;
|
||||||
|
urpc_cfg.transport.max_sge = 0;
|
||||||
|
urpc_cfg.allocator = nullptr;
|
||||||
|
if (urpc_init_func(&urpc_cfg) != kURPCSuccess) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to call urpc_init. Device name: " << dev_name_ << ", ip address: " << ip_addr_
|
||||||
|
<< ", port: " << port_ << ". Please refer to URPC log directory: /var/log/umdk/urpc.";
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void RDMAServer::urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp) {
|
void RDMAServer::urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp) {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(req);
|
MS_ERROR_IF_NULL_WO_RET_VAL(req);
|
||||||
|
|
|
@ -37,6 +37,7 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase {
|
||||||
~RDMAServer() override = default;
|
~RDMAServer() override = default;
|
||||||
|
|
||||||
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override;
|
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override;
|
||||||
|
bool Initialize(const MemAllocateCallback &allocate_cb = {}) override;
|
||||||
void Finalize() override;
|
void Finalize() override;
|
||||||
void SetMessageHandler(const MessageHandler &handler) override;
|
void SetMessageHandler(const MessageHandler &handler) override;
|
||||||
|
|
||||||
|
@ -47,6 +48,9 @@ class BACKEND_EXPORT RDMAServer : public RPCServerBase {
|
||||||
struct urpc_buffer_allocator *urpc_allocator_;
|
struct urpc_buffer_allocator *urpc_allocator_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Initialize urpc configuration according to dev_name_, ip_addr_ and port_.
|
||||||
|
bool InitializeURPC();
|
||||||
|
|
||||||
// The message callback for urpc. This method will call user-set message handler.
|
// The message callback for urpc. This method will call user-set message handler.
|
||||||
static void urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp);
|
static void urpc_req_handler(struct urpc_sgl *req, void *arg, struct urpc_sgl *rsp);
|
||||||
// The callback after this server responding.
|
// The callback after this server responding.
|
||||||
|
|
|
@ -27,7 +27,7 @@ namespace distributed {
|
||||||
namespace rpc {
|
namespace rpc {
|
||||||
class BACKEND_EXPORT RPCClientBase {
|
class BACKEND_EXPORT RPCClientBase {
|
||||||
public:
|
public:
|
||||||
explicit RPCClientBase(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
|
explicit RPCClientBase(bool enable_ssl) : enable_ssl_(enable_ssl) {}
|
||||||
virtual ~RPCClientBase() = default;
|
virtual ~RPCClientBase() = default;
|
||||||
|
|
||||||
// Build or destroy the rpc client.
|
// Build or destroy the rpc client.
|
||||||
|
@ -36,20 +36,29 @@ class BACKEND_EXPORT RPCClientBase {
|
||||||
|
|
||||||
// Connect to the specified server.
|
// Connect to the specified server.
|
||||||
// Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer.
|
// Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer.
|
||||||
virtual bool Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) { return true; }
|
virtual bool Connect(
|
||||||
|
const std::string &dst_url, size_t retry_count = 60, const MemFreeCallback &free_cb = [](void *data) {
|
||||||
|
MS_ERROR_IF_NULL(data);
|
||||||
|
delete static_cast<char *>(data);
|
||||||
|
return true;
|
||||||
|
}) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the connection to dst_url has been established.
|
// Check if the connection to dst_url has been established.
|
||||||
virtual bool IsConnected(const std::string &dst_url) { return false; }
|
virtual bool IsConnected(const std::string &dst_url) { return false; }
|
||||||
|
|
||||||
// Disconnect from the specified server.
|
// Disconnect from the specified server.
|
||||||
virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; }
|
virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) { return true; }
|
||||||
|
|
||||||
// Send the message from the source to the destination synchronously and return the byte size by this method call.
|
// Send the message from the source to the destination synchronously and return the byte size by this method call.
|
||||||
virtual bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes) { return true; }
|
virtual bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) { return true; }
|
||||||
|
|
||||||
// Send the message from the source to the destination asynchronously.
|
// Send the message from the source to the destination asynchronously.
|
||||||
virtual void SendAsync(std::unique_ptr<MessageBase> &&msg) {}
|
virtual void SendAsync(std::unique_ptr<MessageBase> &&msg) {}
|
||||||
|
|
||||||
|
virtual MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30) { return nullptr; }
|
||||||
|
|
||||||
// Force the data in the send buffer to be sent out.
|
// Force the data in the send buffer to be sent out.
|
||||||
virtual bool Flush(const std::string &dst_url) { return true; }
|
virtual bool Flush(const std::string &dst_url) { return true; }
|
||||||
|
|
||||||
|
|
|
@ -31,10 +31,10 @@ class BACKEND_EXPORT RPCServerBase {
|
||||||
virtual ~RPCServerBase() = default;
|
virtual ~RPCServerBase() = default;
|
||||||
|
|
||||||
// Init server using the specified url, with memory allocating function.
|
// Init server using the specified url, with memory allocating function.
|
||||||
virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) { return true; }
|
virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) { return true; }
|
||||||
|
|
||||||
// Init server using local IP and random port.
|
// Init server using local IP and random port.
|
||||||
virtual bool Initialize(const MemAllocateCallback &allocate_cb) { return true; }
|
virtual bool Initialize(const MemAllocateCallback &allocate_cb = {}) { return true; }
|
||||||
|
|
||||||
// Destroy the tcp server.
|
// Destroy the tcp server.
|
||||||
virtual void Finalize() {}
|
virtual void Finalize() {}
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
|
|
||||||
|
#include "distributed/rpc/rpc_client_base.h"
|
||||||
#include "distributed/rpc/tcp/tcp_comm.h"
|
#include "distributed/rpc/tcp/tcp_comm.h"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
#include "include/backend/visible.h"
|
#include "include/backend/visible.h"
|
||||||
|
@ -29,14 +30,14 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
namespace rpc {
|
namespace rpc {
|
||||||
class BACKEND_EXPORT TCPClient {
|
class BACKEND_EXPORT TCPClient : public RPCClientBase {
|
||||||
public:
|
public:
|
||||||
explicit TCPClient(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
|
explicit TCPClient(bool enable_ssl = false) : RPCClientBase(enable_ssl) {}
|
||||||
~TCPClient() = default;
|
~TCPClient() override = default;
|
||||||
|
|
||||||
// Build or destroy the TCP client.
|
// Build or destroy the TCP client.
|
||||||
bool Initialize();
|
bool Initialize() override;
|
||||||
void Finalize();
|
void Finalize() override;
|
||||||
|
|
||||||
// Connect to the specified server.
|
// Connect to the specified server.
|
||||||
// Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer.
|
// Function free_cb binds with client's each connection. It frees the real memory after message is sent to the peer.
|
||||||
|
@ -45,26 +46,26 @@ class BACKEND_EXPORT TCPClient {
|
||||||
MS_ERROR_IF_NULL(data);
|
MS_ERROR_IF_NULL(data);
|
||||||
delete static_cast<char *>(data);
|
delete static_cast<char *>(data);
|
||||||
return true;
|
return true;
|
||||||
});
|
}) override;
|
||||||
|
|
||||||
// Check if the connection to dst_url has been established.
|
// Check if the connection to dst_url has been established.
|
||||||
bool IsConnected(const std::string &dst_url);
|
bool IsConnected(const std::string &dst_url) override;
|
||||||
|
|
||||||
// Disconnect from the specified server.
|
// Disconnect from the specified server.
|
||||||
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5);
|
bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) override;
|
||||||
|
|
||||||
// Send the message from the source to the destination synchronously and return the byte size by this method call.
|
// Send the message from the source to the destination synchronously and return the byte size by this method call.
|
||||||
bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr);
|
bool SendSync(std::unique_ptr<MessageBase> &&msg, size_t *const send_bytes = nullptr) override;
|
||||||
|
|
||||||
// Send the message from the source to the destination asynchronously.
|
// Send the message from the source to the destination asynchronously.
|
||||||
void SendAsync(std::unique_ptr<MessageBase> &&msg);
|
void SendAsync(std::unique_ptr<MessageBase> &&msg) override;
|
||||||
|
|
||||||
// Retrieve a message from tcp server specified by the input message.
|
// Retrieve a message from tcp server specified by the input message.
|
||||||
// Returns nullptr after timeout.
|
// Returns nullptr after timeout.
|
||||||
MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30);
|
MessageBase *ReceiveSync(std::unique_ptr<MessageBase> &&msg, uint32_t timeout = 30) override;
|
||||||
|
|
||||||
// Force the data in the send buffer to be sent out.
|
// Force the data in the send buffer to be sent out.
|
||||||
bool Flush(const std::string &dst_url);
|
bool Flush(const std::string &dst_url) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The basic TCP communication component used by the client.
|
// The basic TCP communication component used by the client.
|
||||||
|
@ -78,8 +79,6 @@ class BACKEND_EXPORT TCPClient {
|
||||||
// The received message from the meta server by calling the method `ReceiveSync`.
|
// The received message from the meta server by calling the method `ReceiveSync`.
|
||||||
MessageBase *received_message_{nullptr};
|
MessageBase *received_message_{nullptr};
|
||||||
|
|
||||||
bool enable_ssl_;
|
|
||||||
|
|
||||||
DISABLE_COPY_AND_ASSIGN(TCPClient);
|
DISABLE_COPY_AND_ASSIGN(TCPClient);
|
||||||
};
|
};
|
||||||
} // namespace rpc
|
} // namespace rpc
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "distributed/rpc/rpc_server_base.h"
|
||||||
#include "distributed/rpc/tcp/tcp_comm.h"
|
#include "distributed/rpc/tcp/tcp_comm.h"
|
||||||
#include "utils/ms_utils.h"
|
#include "utils/ms_utils.h"
|
||||||
#include "include/backend/visible.h"
|
#include "include/backend/visible.h"
|
||||||
|
@ -27,26 +28,26 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace distributed {
|
namespace distributed {
|
||||||
namespace rpc {
|
namespace rpc {
|
||||||
class BACKEND_EXPORT TCPServer {
|
class BACKEND_EXPORT TCPServer : public RPCServerBase {
|
||||||
public:
|
public:
|
||||||
explicit TCPServer(bool enable_ssl = false) : enable_ssl_(enable_ssl) {}
|
explicit TCPServer(bool enable_ssl = false) : RPCServerBase(enable_ssl) {}
|
||||||
~TCPServer() = default;
|
~TCPServer() override = default;
|
||||||
|
|
||||||
// Init the tcp server using the specified url.
|
// Init the tcp server using the specified url.
|
||||||
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {});
|
bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override;
|
||||||
|
|
||||||
// Init the tcp server using local IP and random port.
|
// Init the tcp server using local IP and random port.
|
||||||
bool Initialize(const MemAllocateCallback &allocate_cb = {});
|
bool Initialize(const MemAllocateCallback &allocate_cb = {}) override;
|
||||||
|
|
||||||
// Destroy the tcp server.
|
// Destroy the tcp server.
|
||||||
void Finalize();
|
void Finalize() override;
|
||||||
|
|
||||||
// Set the message processing handler.
|
// Set the message processing handler.
|
||||||
void SetMessageHandler(const MessageHandler &handler);
|
void SetMessageHandler(const MessageHandler &handler) override;
|
||||||
|
|
||||||
// Return the IP and port binded by this server.
|
// Return the IP and port binded by this server.
|
||||||
std::string GetIP() const;
|
std::string GetIP() const override;
|
||||||
uint32_t GetPort() const;
|
uint32_t GetPort() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool InitializeImpl(const std::string &url, const MemAllocateCallback &allocate_cb);
|
bool InitializeImpl(const std::string &url, const MemAllocateCallback &allocate_cb);
|
||||||
|
@ -54,11 +55,6 @@ class BACKEND_EXPORT TCPServer {
|
||||||
// The basic TCP communication component used by the server.
|
// The basic TCP communication component used by the server.
|
||||||
std::unique_ptr<TCPComm> tcp_comm_{nullptr};
|
std::unique_ptr<TCPComm> tcp_comm_{nullptr};
|
||||||
|
|
||||||
std::string ip_{""};
|
|
||||||
uint32_t port_{0};
|
|
||||||
|
|
||||||
bool enable_ssl_;
|
|
||||||
|
|
||||||
DISABLE_COPY_AND_ASSIGN(TCPServer);
|
DISABLE_COPY_AND_ASSIGN(TCPServer);
|
||||||
};
|
};
|
||||||
} // namespace rpc
|
} // namespace rpc
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
#include "proto/topology.pb.h"
|
#include "proto/topology.pb.h"
|
||||||
#include "distributed/rpc/tcp/constants.h"
|
|
||||||
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
|
#include "plugin/device/cpu/kernel/rpc/rpc_recv_kernel.h"
|
||||||
#include "backend/common/optimizer/helper.h"
|
#include "backend/common/optimizer/helper.h"
|
||||||
|
|
||||||
|
@ -62,15 +61,24 @@ void RecvActor::SetRouteInfo(uint32_t, const std::string &, const std::string &r
|
||||||
}
|
}
|
||||||
|
|
||||||
bool RecvActor::StartServer() {
|
bool RecvActor::StartServer() {
|
||||||
// Step 1: Create a tcp server and start listening.
|
// Step 1: Create a rpc server and start listening.
|
||||||
|
|
||||||
|
#ifdef ENABLE_RDMA
|
||||||
|
if (common::GetEnv(kEnableRDMA) == "1") {
|
||||||
|
server_ = std::make_unique<RDMAServer>();
|
||||||
|
} else {
|
||||||
|
server_ = std::make_unique<TCPServer>();
|
||||||
|
}
|
||||||
|
#else
|
||||||
server_ = std::make_unique<TCPServer>();
|
server_ = std::make_unique<TCPServer>();
|
||||||
|
#endif
|
||||||
MS_EXCEPTION_IF_NULL(server_);
|
MS_EXCEPTION_IF_NULL(server_);
|
||||||
|
|
||||||
// Set the memory allocating callback using void* message.
|
// Set the memory allocating callback using void* message.
|
||||||
std::function<void *(size_t size)> allocate_callback =
|
std::function<void *(size_t size)> allocate_callback =
|
||||||
std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
|
std::bind(&RecvActor::AllocateMessage, this, std::placeholders::_1);
|
||||||
if (!server_->Initialize(allocate_callback)) {
|
if (!server_->Initialize(allocate_callback)) {
|
||||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for recv actor";
|
MS_LOG(EXCEPTION) << "Failed to initialize rpc server for recv actor";
|
||||||
}
|
}
|
||||||
ip_ = server_->GetIP();
|
ip_ = server_->GetIP();
|
||||||
port_ = server_->GetPort();
|
port_ = server_->GetPort();
|
||||||
|
|
|
@ -62,7 +62,7 @@ class RecvActor : public RpcActor {
|
||||||
// Start recv actor server and register this server address to actor route table in scheduler by proxy.
|
// Start recv actor server and register this server address to actor route table in scheduler by proxy.
|
||||||
bool StartServer();
|
bool StartServer();
|
||||||
|
|
||||||
// Finalize tcp server.
|
// Finalize rpc server.
|
||||||
void Clear() override;
|
void Clear() override;
|
||||||
|
|
||||||
void StopRpcAtException() override;
|
void StopRpcAtException() override;
|
||||||
|
@ -102,7 +102,7 @@ class RecvActor : public RpcActor {
|
||||||
*/
|
*/
|
||||||
void *AllocateMemByDeviceRes(size_t size);
|
void *AllocateMemByDeviceRes(size_t size);
|
||||||
|
|
||||||
std::unique_ptr<TCPServer> server_;
|
std::unique_ptr<RPCServerBase> server_;
|
||||||
|
|
||||||
// The variables used to ensure thread-safe of op context visited by recv actor.
|
// The variables used to ensure thread-safe of op context visited by recv actor.
|
||||||
bool is_context_valid_;
|
bool is_context_valid_;
|
||||||
|
@ -129,7 +129,7 @@ class RecvActor : public RpcActor {
|
||||||
// it, e.g., infer shape for RpcRecv kernel and call Resize().
|
// it, e.g., infer shape for RpcRecv kernel and call Resize().
|
||||||
void PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize);
|
void PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize);
|
||||||
|
|
||||||
// The message callback of the tcp server.
|
// The message callback of the rpc server.
|
||||||
MessageBase *HandleMessage(MessageBase *const msg);
|
MessageBase *HandleMessage(MessageBase *const msg);
|
||||||
|
|
||||||
// The network address of this recv actor. It's generated automatically by rpc module.
|
// The network address of this recv actor. It's generated automatically by rpc module.
|
||||||
|
|
|
@ -26,6 +26,10 @@
|
||||||
#include "distributed/cluster/cluster_context.h"
|
#include "distributed/cluster/cluster_context.h"
|
||||||
#include "distributed/rpc/tcp/tcp_client.h"
|
#include "distributed/rpc/tcp/tcp_client.h"
|
||||||
#include "distributed/rpc/tcp/tcp_server.h"
|
#include "distributed/rpc/tcp/tcp_server.h"
|
||||||
|
#ifdef ENABLE_RDMA
|
||||||
|
#include "distributed/rpc/rdma/rdma_client.h"
|
||||||
|
#include "distributed/rpc/rdma/rdma_server.h"
|
||||||
|
#endif
|
||||||
#include "proto/rpc.pb.h"
|
#include "proto/rpc.pb.h"
|
||||||
#include "proto/topology.pb.h"
|
#include "proto/topology.pb.h"
|
||||||
|
|
||||||
|
@ -35,9 +39,15 @@ using distributed::cluster::ActorRouteTableProxy;
|
||||||
using distributed::cluster::ActorRouteTableProxyPtr;
|
using distributed::cluster::ActorRouteTableProxyPtr;
|
||||||
using distributed::cluster::ClusterContext;
|
using distributed::cluster::ClusterContext;
|
||||||
using distributed::cluster::topology::ActorAddress;
|
using distributed::cluster::topology::ActorAddress;
|
||||||
|
using distributed::rpc::RPCClientBase;
|
||||||
|
using distributed::rpc::RPCServerBase;
|
||||||
using distributed::rpc::TCPClient;
|
using distributed::rpc::TCPClient;
|
||||||
using distributed::rpc::TCPServer;
|
using distributed::rpc::TCPServer;
|
||||||
using mindspore::device::KernelInfo;
|
using mindspore::device::KernelInfo;
|
||||||
|
#ifdef ENABLE_RDMA
|
||||||
|
using distributed::rpc::RDMAClient;
|
||||||
|
using distributed::rpc::RDMAServer;
|
||||||
|
#endif
|
||||||
|
|
||||||
// The inter-process edge mark between two nodes.
|
// The inter-process edge mark between two nodes.
|
||||||
constexpr char kInterProcessEdgeMark[] = "->";
|
constexpr char kInterProcessEdgeMark[] = "->";
|
||||||
|
|
|
@ -27,7 +27,7 @@ SendActor::~SendActor() {
|
||||||
(void)client_->Disconnect(server_url_);
|
(void)client_->Disconnect(server_url_);
|
||||||
client_->Finalize();
|
client_->Finalize();
|
||||||
} catch (const std::exception &) {
|
} catch (const std::exception &) {
|
||||||
MS_LOG(ERROR) << "Failed to disconnect and finalize for tcp client in send actor.";
|
MS_LOG(ERROR) << "Failed to disconnect and finalize for rpc client in send actor.";
|
||||||
}
|
}
|
||||||
client_ = nullptr;
|
client_ = nullptr;
|
||||||
}
|
}
|
||||||
|
@ -40,10 +40,19 @@ void SendActor::SetRouteInfo(uint32_t, const std::string &, const std::string &s
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SendActor::ConnectServer() {
|
bool SendActor::ConnectServer() {
|
||||||
|
#ifdef ENABLE_RDMA
|
||||||
|
if (common::GetEnv(kEnableRDMA) == "1") {
|
||||||
|
client_ = std::make_unique<RDMAClient>();
|
||||||
|
} else {
|
||||||
|
client_ = std::make_unique<TCPClient>();
|
||||||
|
}
|
||||||
|
#else
|
||||||
client_ = std::make_unique<TCPClient>();
|
client_ = std::make_unique<TCPClient>();
|
||||||
|
#endif
|
||||||
MS_EXCEPTION_IF_NULL(client_);
|
MS_EXCEPTION_IF_NULL(client_);
|
||||||
|
|
||||||
if (!client_->Initialize()) {
|
if (!client_->Initialize()) {
|
||||||
MS_LOG(EXCEPTION) << "Failed to initialize tcp server for send actor.";
|
MS_LOG(EXCEPTION) << "Failed to initialize rpc server for send actor.";
|
||||||
}
|
}
|
||||||
// Lookup actor addresses for each peer actor.
|
// Lookup actor addresses for each peer actor.
|
||||||
for (const auto &peer_actor_id : peer_actor_ids_) {
|
for (const auto &peer_actor_id : peer_actor_ids_) {
|
||||||
|
|
|
@ -49,7 +49,7 @@ class SendActor : public RpcActor {
|
||||||
// Flush and wait for sent data to be passed to kernel.
|
// Flush and wait for sent data to be passed to kernel.
|
||||||
void FlushData() override;
|
void FlushData() override;
|
||||||
|
|
||||||
// Finalize tcp client.
|
// Finalize rpc client.
|
||||||
void Clear() override;
|
void Clear() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
@ -76,8 +76,8 @@ class SendActor : public RpcActor {
|
||||||
*/
|
*/
|
||||||
virtual void Flush();
|
virtual void Flush();
|
||||||
|
|
||||||
// The tcp client connection to multiple servers.
|
// The rpc client connection to multiple servers.
|
||||||
std::unique_ptr<TCPClient> client_;
|
std::unique_ptr<RPCClientBase> client_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
|
@ -136,7 +136,7 @@ class SendActor : public RpcActor {
|
||||||
std::vector<std::string> peer_actor_ids_;
|
std::vector<std::string> peer_actor_ids_;
|
||||||
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
|
mindspore::HashMap<std::string, std::string> peer_actor_urls_;
|
||||||
|
|
||||||
// The url of the peer recv actor's tcp server.
|
// The url of the peer recv actor's server.
|
||||||
std::string server_url_;
|
std::string server_url_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,9 @@ class RDMATest : public UT::Common {
|
||||||
std::unique_ptr<MessageBase> RDMATest::CreateMessage(const std::string &msg) {
|
std::unique_ptr<MessageBase> RDMATest::CreateMessage(const std::string &msg) {
|
||||||
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
|
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
|
||||||
size_t msg_size = msg.size();
|
size_t msg_size = msg.size();
|
||||||
ASSERT_TRUE(msg_size != 0);
|
if (msg_size == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "msg_size is 0!";
|
||||||
|
}
|
||||||
void *data = malloc(msg_size + 1);
|
void *data = malloc(msg_size + 1);
|
||||||
(void)memcpy_s(data, msg_size, msg.c_str(), msg_size);
|
(void)memcpy_s(data, msg_size, msg.c_str(), msg_size);
|
||||||
message->data = data;
|
message->data = data;
|
||||||
|
@ -57,6 +59,36 @@ std::unique_ptr<MessageBase> RDMATest::CreateMessage(const std::string &msg) {
|
||||||
/// Description: test basic connection function between RDMA client and server.
|
/// Description: test basic connection function between RDMA client and server.
|
||||||
/// Expectation: RDMA client successfully connects to RDMA server and sends a simple message.
|
/// Expectation: RDMA client successfully connects to RDMA server and sends a simple message.
|
||||||
TEST_F(RDMATest, TestRDMAConnection) {
|
TEST_F(RDMATest, TestRDMAConnection) {
|
||||||
|
std::string url = "127.0.0.1:10969";
|
||||||
|
size_t server_pid = fork();
|
||||||
|
if (server_pid == 0) {
|
||||||
|
std::shared_ptr<RDMAServer> rdma_server = std::make_shared<RDMAServer>();
|
||||||
|
MS_EXCEPTION_IF_NULL(rdma_server);
|
||||||
|
ASSERT_TRUE(rdma_server->Initialize(url));
|
||||||
|
sleep(3);
|
||||||
|
rdma_server->Finalize();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sleep(1);
|
||||||
|
size_t client_pid = fork();
|
||||||
|
if (client_pid == 0) {
|
||||||
|
std::shared_ptr<RDMAClient> rdma_client = std::make_shared<RDMAClient>();
|
||||||
|
MS_EXCEPTION_IF_NULL(rdma_client);
|
||||||
|
ASSERT_TRUE(rdma_client->Initialize());
|
||||||
|
ASSERT_TRUE(rdma_client->Connect(url));
|
||||||
|
rdma_client->Finalize();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int wstatus;
|
||||||
|
(void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED);
|
||||||
|
(void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: RDMA communication.
|
||||||
|
/// Description: test SendSync interface for RDMA client and server.
|
||||||
|
/// Expectation: RDMA client successfully sends two messages to RDMA server synchronously.
|
||||||
|
TEST_F(RDMATest, TestRDMASendSync) {
|
||||||
std::string url = "127.0.0.1:10969";
|
std::string url = "127.0.0.1:10969";
|
||||||
size_t server_pid = fork();
|
size_t server_pid = fork();
|
||||||
if (server_pid == 0) {
|
if (server_pid == 0) {
|
||||||
|
@ -70,6 +102,7 @@ TEST_F(RDMATest, TestRDMAConnection) {
|
||||||
};
|
};
|
||||||
rdma_server->SetMessageHandler(msg_handler);
|
rdma_server->SetMessageHandler(msg_handler);
|
||||||
sleep(3);
|
sleep(3);
|
||||||
|
rdma_server->Finalize();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
sleep(1);
|
sleep(1);
|
||||||
|
@ -80,8 +113,54 @@ TEST_F(RDMATest, TestRDMAConnection) {
|
||||||
ASSERT_TRUE(rdma_client->Initialize());
|
ASSERT_TRUE(rdma_client->Initialize());
|
||||||
ASSERT_TRUE(rdma_client->Connect(url));
|
ASSERT_TRUE(rdma_client->Connect(url));
|
||||||
|
|
||||||
auto message = CreateMessage("Hello server!");
|
auto message1 = CreateMessage("Hello server sync!");
|
||||||
ASSERT_TRUE(rdma_client->SendSync(std::move(message)));
|
ASSERT_TRUE(rdma_client->SendSync(std::move(message1)));
|
||||||
|
auto message2 = CreateMessage("Hello server sync!");
|
||||||
|
ASSERT_TRUE(rdma_client->SendSync(std::move(message2)));
|
||||||
|
rdma_client->Finalize();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int wstatus;
|
||||||
|
(void)waitpid(client_pid, &wstatus, WUNTRACED | WCONTINUED);
|
||||||
|
(void)waitpid(server_pid, &wstatus, WUNTRACED | WCONTINUED);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: RDMA communication.
|
||||||
|
/// Description: test SendAsync interface for RDMA client and server.
|
||||||
|
/// Expectation: RDMA client successfully sends two messages to RDMA server asynchronously.
|
||||||
|
TEST_F(RDMATest, TestRDMASendAsync) {
|
||||||
|
std::string url = "127.0.0.1:10969";
|
||||||
|
size_t server_pid = fork();
|
||||||
|
if (server_pid == 0) {
|
||||||
|
std::shared_ptr<RDMAServer> rdma_server = std::make_shared<RDMAServer>();
|
||||||
|
MS_EXCEPTION_IF_NULL(rdma_server);
|
||||||
|
ASSERT_TRUE(rdma_server->Initialize(url));
|
||||||
|
|
||||||
|
auto msg_handler = [](MessageBase *const msg) {
|
||||||
|
MS_LOG(INFO) << "Receive message from client: " << static_cast<char *>(msg->data);
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
rdma_server->SetMessageHandler(msg_handler);
|
||||||
|
sleep(3);
|
||||||
|
rdma_server->Finalize();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sleep(1);
|
||||||
|
size_t client_pid = fork();
|
||||||
|
if (client_pid == 0) {
|
||||||
|
std::shared_ptr<RDMAClient> rdma_client = std::make_shared<RDMAClient>();
|
||||||
|
MS_EXCEPTION_IF_NULL(rdma_client);
|
||||||
|
ASSERT_TRUE(rdma_client->Initialize());
|
||||||
|
ASSERT_TRUE(rdma_client->Connect(url));
|
||||||
|
|
||||||
|
auto message1 = CreateMessage("Hello server async!");
|
||||||
|
rdma_client->SendAsync(std::move(message1));
|
||||||
|
ASSERT_TRUE(rdma_client->Flush(url));
|
||||||
|
auto message2 = CreateMessage("Hello server async!");
|
||||||
|
rdma_client->SendAsync(std::move(message2));
|
||||||
|
ASSERT_TRUE(rdma_client->Flush(url));
|
||||||
|
|
||||||
rdma_client->Finalize();
|
rdma_client->Finalize();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue