diff --git a/mindspore/ccsrc/distributed/CMakeLists.txt b/mindspore/ccsrc/distributed/CMakeLists.txt index 9b40a5a4391..26f0a14cd6f 100644 --- a/mindspore/ccsrc/distributed/CMakeLists.txt +++ b/mindspore/ccsrc/distributed/CMakeLists.txt @@ -19,6 +19,13 @@ if(NOT ENABLE_CPU OR WIN32 OR APPLE) endforeach() endif() +if("${ENABLE_RDMA}" STREQUAL "ON") # Quoted so that an unset or empty ENABLE_RDMA compares as "" instead of breaking the if() syntax. + include_directories(/usr/include/umdk) +else() + list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES "rpc/rdma/rdma_client.cc") + list(REMOVE_ITEM _DISTRIBUTED_SRC_FILES "rpc/rdma/rdma_server.cc") +endif() + set_property(SOURCE ${_DISTRIBUTED_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DISTRIBUTED) add_library(_mindspore_distributed_obj OBJECT ${_DISTRIBUTED_SRC_FILES}) diff --git a/mindspore/ccsrc/distributed/cluster/topology/meta_server_node.h b/mindspore/ccsrc/distributed/cluster/topology/meta_server_node.h index fdb3685d547..53b8255c4b3 100644 --- a/mindspore/ccsrc/distributed/cluster/topology/meta_server_node.h +++ b/mindspore/ccsrc/distributed/cluster/topology/meta_server_node.h @@ -142,7 +142,7 @@ class MetaServerNode : public NodeBase { // All the handlers for compute graph node's system messages processing. // The `system` means the built-in messages used for cluster topology construction. - std::map system_msg_handlers_; + std::map system_msg_handlers_; // All the handlers for compute graph node's user-defined messages processing. // The `user-defined` means that this kind of message is user defined and has customized message handler. 
diff --git a/mindspore/ccsrc/distributed/constants.h b/mindspore/ccsrc/distributed/constants.h index 158a73d515e..64ac6056b2e 100644 --- a/mindspore/ccsrc/distributed/constants.h +++ b/mindspore/ccsrc/distributed/constants.h @@ -22,6 +22,13 @@ #include #include #include +#include + +#include "actor/log.h" +#include "actor/msg.h" +#include "utils/ms_utils.h" +#include "utils/log_adapter.h" +#include "include/backend/visible.h" namespace mindspore { namespace distributed { @@ -88,6 +95,22 @@ constexpr char kControlDstOpName[] = "ControlDst"; // This macro the current timestamp in milliseconds. #define CURRENT_TIMESTAMP_MILLI \ (std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch())) + +using MessageHandler = std::function; + +/** + * @description: The callback function type for allocating memory after receiving data from the peer. + * @param {size_t} size: Size of the memory to be allocated. + * @return {void *}: A pointer to the newly allocated memory. + */ +using MemAllocateCallback = std::function; + +/** + * @description: The callback function type for releasing memory after it has been sent to the peer. + * @param {void *} data: The memory to be released, which should be allocated on the heap. + * @return {bool}: Whether the memory is successfully released. + */ +using MemFreeCallback = std::function; } // namespace distributed } // namespace mindspore #endif // MINDSPORE_CCSRC_DISTRIBUTED_CONSTANTS_H_ diff --git a/mindspore/ccsrc/distributed/rpc/rdma/constants.h b/mindspore/ccsrc/distributed/rpc/rdma/constants.h new file mode 100644 index 00000000000..29a34e6620d --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/rdma/constants.h @@ -0,0 +1,66 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_ + +#include +#include + +#include "utils/dlopen_macro.h" +#include "distributed/constants.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +inline void *LoadURPC() { // Opens liburpc.so on first use and caches the handle in a function-local static. + static void *urpc_handle = nullptr; + if (urpc_handle == nullptr) { + urpc_handle = dlopen("liburpc.so", RTLD_LAZY | RTLD_LOCAL); + if (urpc_handle == nullptr) { + auto err = GetDlErrorMsg(); + MS_LOG(EXCEPTION) << "dlopen liburpc.so failed. Error message: " << err; + } + } + return urpc_handle; +} +static void *kURPCHandle = LoadURPC(); // NOTE(review): namespace-scope static in a header, so every including translation unit gets its own copy and calls LoadURPC() during its dynamic initialization; a dlopen failure is reported through MS_LOG(EXCEPTION) at that point - confirm this is intended. + +#define REG_URPC_METHOD(name, return_type, ...) \ + constexpr const char *k##name##Name = #name; \ + using name##FunObj = std::function; \ + using name##FunPtr = return_type (*)(__VA_ARGS__); \ + const name##FunPtr name##_func = DlsymFuncObj(name, kURPCHandle); + +// The symbols of liburpc.so to be dynamically loaded. Each entry defines k<name>Name, <name>FunObj, <name>FunPtr and <name>_func. 
+REG_URPC_METHOD(urpc_init, int, struct urpc_config *) +REG_URPC_METHOD(urpc_uninit, void) +REG_URPC_METHOD(urpc_connect, urpc_session_t *, const char *, uint16_t, urma_jfs_t *) +REG_URPC_METHOD(urpc_close, void, urpc_session_t *) +REG_URPC_METHOD(urpc_register_memory, int, void *, int) +REG_URPC_METHOD(urpc_register_serdes, int, const char *, const urpc_serdes_t *, urpc_tx_cb_t, void *) +REG_URPC_METHOD(urpc_register_handler, int, urpc_handler_info_t *, uint32_t *) +REG_URPC_METHOD(urpc_register_raw_handler_explicit, int, urpc_raw_handler_t, void *, urpc_tx_cb_t, void *, uint32_t) +REG_URPC_METHOD(urpc_unregister_handler, void, const char *, uint32_t) +REG_URPC_METHOD(urpc_query_capability, int, struct urpc_cap *) +REG_URPC_METHOD(urpc_send_request, int, urpc_session_t *, struct urpc_send_wr *, struct urpc_send_option *) +REG_URPC_METHOD(urpc_call, int, urpc_session_t *, const char *, void *, void **, struct urpc_send_option *) +REG_URPC_METHOD(urpc_call_sgl, int, urpc_session_t *, const char *, void *, void **, struct urpc_send_option *) +REG_URPC_METHOD(urpc_get_default_allocator, struct urpc_buffer_allocator *) +} // namespace rpc +} // namespace distributed +} // namespace mindspore +#endif // MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_CONSTANTS_H_ diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc new file mode 100644 index 00000000000..66a6a416db8 --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "distributed/rpc/rdma/rdma_client.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +bool RDMAClient::Initialize() { return true; } // NOTE(review): every method in this file is a placeholder that ignores its arguments and returns a constant. + +void RDMAClient::Finalize() {} + +bool RDMAClient::Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) { + return true; +} + +bool RDMAClient::IsConnected(const std::string &dst_url) { return false; } // Always false, even after Connect() has returned true. + +bool RDMAClient::Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; } + +bool RDMAClient::SendSync(std::unique_ptr &&msg, size_t *const send_bytes) { return true; } // Does not write *send_bytes. + +void RDMAClient::SendAsync(std::unique_ptr &&msg) {} + +bool RDMAClient::Flush(const std::string &dst_url) { return true; } +} // namespace rpc +} // namespace distributed +} // namespace mindspore diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.h b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.h new file mode 100644 index 00000000000..b9de43041f4 --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_client.h @@ -0,0 +1,58 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_CLIENT_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_CLIENT_H_ + +#include +#include +#include +#include + +#include "distributed/rpc/rdma/constants.h" +#include "distributed/rpc/rpc_client_base.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +class BACKEND_EXPORT RDMAClient : public RPCClientBase { // Overrides every virtual method declared by RPCClientBase. + public: + explicit RDMAClient(bool enable_ssl = false) : RPCClientBase(enable_ssl) {} + ~RDMAClient() override = default; + + bool Initialize() override; + void Finalize() override; + bool Connect( // NOTE(review): the default arguments exist only on this override (RPCClientBase::Connect declares none), so they apply only to calls made through an RDMAClient-typed expression. + const std::string &dst_url, size_t retry_count = 60, const MemFreeCallback &free_cb = [](void *data) { + MS_ERROR_IF_NULL(data); + delete static_cast(data); + return true; + }) override; + bool IsConnected(const std::string &dst_url) override; + bool Disconnect(const std::string &dst_url, size_t timeout_in_sec = 5) override; + + bool SendSync(std::unique_ptr &&msg, size_t *const send_bytes = nullptr) override; + void SendAsync(std::unique_ptr &&msg) override; + + bool Flush(const std::string &dst_url) override; + + private: +}; +} // namespace rpc +} // namespace distributed +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_CLIENT_H_ diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc new file mode 100644 index 00000000000..49309c51650 --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "distributed/rpc/rdma/rdma_server.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +bool RDMAServer::Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) { return true; } // NOTE(review): placeholder - url and allocate_cb are ignored. + +void RDMAServer::Finalize() {} + +void RDMAServer::SetMessageHandler(const MessageHandler &handler) {} + +std::string RDMAServer::GetIP() const { return ""; } // Placeholder: returns an empty string rather than ip_. + +uint32_t RDMAServer::GetPort() const { return 0; } // Placeholder: returns 0 rather than port_. +} // namespace rpc +} // namespace distributed +} // namespace mindspore diff --git a/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.h b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.h new file mode 100644 index 00000000000..524fbe8770f --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/rdma/rdma_server.h @@ -0,0 +1,47 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_SERVER_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_SERVER_H_ + +#include +#include + +#include "distributed/rpc/rdma/constants.h" +#include "distributed/rpc/rpc_server_base.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +class BACKEND_EXPORT RDMAServer : public RPCServerBase { + public: + explicit RDMAServer(bool enable_ssl = false) : RPCServerBase(enable_ssl) {} + ~RDMAServer() override = default; + + bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb = {}) override; // NOTE(review): declaring only this overload hides RPCServerBase::Initialize(const MemAllocateCallback &) for RDMAServer-typed callers - confirm that is intended. + void Finalize() override; + void SetMessageHandler(const MessageHandler &handler) override; + + std::string GetIP() const override; + uint32_t GetPort() const override; + + private: +}; +} // namespace rpc +} // namespace distributed +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DISTRIBUTED_RPC_RDMA_RDMA_SERVER_H_ diff --git a/mindspore/ccsrc/distributed/rpc/rpc_client_base.h b/mindspore/ccsrc/distributed/rpc/rpc_client_base.h new file mode 100644 index 00000000000..173b36c4ffe --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/rpc_client_base.h @@ -0,0 +1,63 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_RPC_CLIENT_BASE_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RPC_CLIENT_BASE_H_ + +#include +#include + +#include "distributed/constants.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +class BACKEND_EXPORT RPCClientBase { + public: + explicit RPCClientBase(bool enable_ssl = false) : enable_ssl_(enable_ssl) {} + virtual ~RPCClientBase() = default; + + // Build or destroy the rpc client. + virtual bool Initialize() { return true; } + virtual void Finalize() {} + + // Connect to the specified server. + // Function free_cb is bound to each connection of the client. It frees the real memory after the message is sent to the peer. + virtual bool Connect(const std::string &dst_url, size_t retry_count, const MemFreeCallback &free_cb) { return true; } + + // Check if the connection to dst_url has been established. + virtual bool IsConnected(const std::string &dst_url) { return false; } + + // Disconnect from the specified server. + virtual bool Disconnect(const std::string &dst_url, size_t timeout_in_sec) { return true; } + + // Send the message from the source to the destination synchronously; send_bytes presumably receives the number of bytes sent (this base implementation leaves it untouched). + virtual bool SendSync(std::unique_ptr &&msg, size_t *const send_bytes) { return true; } + + // Send the message from the source to the destination asynchronously. + virtual void SendAsync(std::unique_ptr &&msg) {} + + // Force the data in the send buffer to be sent out. 
+ virtual bool Flush(const std::string &dst_url) { return true; } + + protected: + bool enable_ssl_; +}; +} // namespace rpc +} // namespace distributed +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DISTRIBUTED_RPC_RPC_CLIENT_BASE_H_ diff --git a/mindspore/ccsrc/distributed/rpc/rpc_server_base.h b/mindspore/ccsrc/distributed/rpc/rpc_server_base.h new file mode 100644 index 00000000000..0c26b117963 --- /dev/null +++ b/mindspore/ccsrc/distributed/rpc/rpc_server_base.h @@ -0,0 +1,59 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_DISTRIBUTED_RPC_RPC_SERVER_BASE_H_ +#define MINDSPORE_CCSRC_DISTRIBUTED_RPC_RPC_SERVER_BASE_H_ + +#include +#include + +#include "distributed/constants.h" + +namespace mindspore { +namespace distributed { +namespace rpc { +class BACKEND_EXPORT RPCServerBase { + public: + explicit RPCServerBase(bool enable_ssl) : ip_(""), port_(0), enable_ssl_(enable_ssl) {} + virtual ~RPCServerBase() = default; + + // Init server using the specified url, with memory allocating function. + virtual bool Initialize(const std::string &url, const MemAllocateCallback &allocate_cb) { return true; } + + // Init server using local IP and random port. + virtual bool Initialize(const MemAllocateCallback &allocate_cb) { return true; } + + // Destroy the rpc server. + virtual void Finalize() {} + + // Set the message processing handler. 
+ virtual void SetMessageHandler(const MessageHandler &handler) {} + + // Return the IP and port bound to this server. + virtual std::string GetIP() const { return ip_; } + virtual uint32_t GetPort() const { return port_; } + + protected: + std::string ip_; + uint32_t port_; + + bool enable_ssl_; +}; +} // namespace rpc +} // namespace distributed +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_DISTRIBUTED_RPC_RPC_SERVER_BASE_H_ diff --git a/mindspore/ccsrc/distributed/rpc/tcp/constants.h b/mindspore/ccsrc/distributed/rpc/tcp/constants.h index e1c64356829..406a9048b52 100644 --- a/mindspore/ccsrc/distributed/rpc/tcp/constants.h +++ b/mindspore/ccsrc/distributed/rpc/tcp/constants.h @@ -24,30 +24,14 @@ #include #include -#include "actor/log.h" -#include "actor/msg.h" +#include "distributed/constants.h" namespace mindspore { namespace distributed { namespace rpc { -using MessageHandler = std::function; using DeleteCallBack = void (*)(const std::string &from, const std::string &to); using ConnectionCallBack = std::function; -/** - * @description: The callback function type for allocating memory after receiving data for the peer. - * @param {size_t} size: Size of the memory to be allocated. - * @return {void *}: A pointer to the newly allocated memory. - */ -using MemAllocateCallback = std::function; - -/** - * @description: The callback function for releasing memory after sending it to the peer. - * @param {void} *data: The memory to be released, which should be allocated on heap. - * @return {bool}: Whether the memory is successfully released. 
- */ -using MemFreeCallback = std::function; - constexpr int SEND_MSG_IO_VEC_LEN = 5; constexpr int RECV_MSG_IO_VEC_LEN = 4; diff --git a/scripts/build/process_options.sh b/scripts/build/process_options.sh index 931d2345174..83f7814d8b1 100755 --- a/scripts/build/process_options.sh +++ b/scripts/build/process_options.sh @@ -91,7 +91,7 @@ process_options() E) check_on_off $OPTARG E export ENABLE_RDMA="$OPTARG" - echo "enable RDMA for RPC $ENABLE_RDMA" ;; + echo "RDMA for RPC $ENABLE_RDMA" ;; A) build_option_proc_upper_a ;; W) diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 830a0e7ba2b..e15ce81ca69 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -94,6 +94,11 @@ if(ENABLE_MINDDATA) ./distributed/cluster/topology/*.cc ./distributed/recovery/*.cc ./distributed/embedding_cache/*.cc) + if("${ENABLE_RDMA}" STREQUAL "ON") # Quoted so that an unset or empty ENABLE_RDMA does not break the if() syntax. + include_directories(/usr/include/umdk) + file(GLOB_RECURSE UT_DISTRIBUTED_RDMA_SRCS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ./distributed/rpc/rdma/*.cc) + list(APPEND UT_DISTRIBUTED_SRCS ${UT_DISTRIBUTED_RDMA_SRCS}) # Append: file(GLOB_RECURSE) into UT_DISTRIBUTED_SRCS would replace the list instead of extending it. + endif() list(APPEND UT_SRCS ${UT_DISTRIBUTED_SRCS}) endif() if(NOT ENABLE_PYTHON) @@ -227,6 +232,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/plugin/device/cpu/kernel/rpc/rpc_recv_kernel.cc" "../../../mindspore/ccsrc/distributed/persistent/*.cc" "../../../mindspore/ccsrc/distributed/rpc/tcp/*.cc" + "../../../mindspore/ccsrc/distributed/rpc/rdma/*.cc" "../../../mindspore/ccsrc/distributed/cluster/topology/*.cc" "../../../mindspore/ccsrc/distributed/embedding_cache/*.cc" "../../../mindspore/ccsrc/plugin/device/ascend/hal/profiler/*.cc" @@ -239,6 +245,11 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ext_info_handle.cc" "../../../mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.cc" ) +if(NOT "${ENABLE_RDMA}" STREQUAL "ON") # Mirrors the else() branch in mindspore/ccsrc/distributed/CMakeLists.txt: any value other than ON drops the RDMA sources. + list(REMOVE_ITEM MINDSPORE_SRC_LIST 
"../../../mindspore/ccsrc/distributed/rpc/rdma/rdma_client.cc" + "../../../mindspore/ccsrc/distributed/rpc/rdma/rdma_server.cc") +endif() list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/util.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/ascend/optimizer/create_node_helper.cc") diff --git a/tests/ut/cpp/distributed/rpc/rdma/rdma_test.cc b/tests/ut/cpp/distributed/rpc/rdma/rdma_test.cc new file mode 100644 index 00000000000..a39a330871a --- /dev/null +++ b/tests/ut/cpp/distributed/rpc/rdma/rdma_test.cc @@ -0,0 +1,67 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#define private public +#include "distributed/rpc/rdma/rdma_server.h" +#include "distributed/rpc/rdma/rdma_client.h" +#include "distributed/rpc/rdma/constants.h" +#include "common/common_test.h" +#undef private + +namespace mindspore { +namespace distributed { +namespace rpc { +class RDMATest : public UT::Common { + public: + RDMATest() = default; + ~RDMATest() = default; +}; + +/// Feature: RDMA communication. +/// Description: test basic connection function between RDMA client and server. +/// Expectation: RDMA client successfully connects to RDMA server and sends a simple message. 
+TEST_F(RDMATest, TestRDMAConnection) {
+  pid_t server_pid = fork();  // pid_t rather than size_t: fork() reports failure as -1 and waitpid() expects a pid_t.
+  ASSERT_NE(server_pid, -1);
+  if (server_pid == 0) {
+    auto rdma_server = std::make_shared<RDMAServer>();
+    (void)rdma_server->Initialize(kLocalHost);
+    _exit(0);  // End the child here; returning would hand the child back to the gtest runner.
+  }
+  sleep(2);
+  pid_t client_pid = fork();
+  ASSERT_NE(client_pid, -1);
+  if (client_pid == 0) {
+    auto rdma_client = std::make_shared<RDMAClient>();
+    (void)rdma_client->Initialize();
+    _exit(0);  // Same reason as for the server child.
+  }
+  // Only the parent reaches this point; the exit statuses are not inspected, so no status buffer is passed.
+  (void)waitpid(client_pid, nullptr, WUNTRACED | WCONTINUED);
+  (void)waitpid(server_pid, nullptr, WUNTRACED | WCONTINUED);
+}
+}  // namespace rpc
+}  // namespace distributed
+}  // namespace mindspore