forked from mindspore-Ecosystem/mindspore
add collective communication topology node
parent 19464ee3e7
commit 68d8d9ee78
@ -2,8 +2,9 @@ if(ENABLE_CPU)
|
|||
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "mpi_collective_comm_lib.cc" "mpi_communication_group.cc")
|
||||
|
||||
if(WIN32)
|
||||
if(WIN32 OR APPLE)
|
||||
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_comm_lib.cc" "allreduce_impl.cc")
|
||||
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_topo.cc")
|
||||
endif()
|
||||
if(ENABLE_MPI)
|
||||
set(MPI_COLLECTIVE_SRCS "mpi_collective_comm_lib.cc"
|
||||
|
@@ -23,4 +24,4 @@ if(ENABLE_CPU)
         set_property(SOURCE ${HARDWARE_CPU_SRC_LIST} PROPERTY
             COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
         add_library(_mindspore_plugin_device_cpu_hal_hardware_obj OBJECT ${HARDWARE_CPU_SRC_LIST})
     endif()
 endif()
@@ -0,0 +1,177 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <string>
#include <memory>
#include <utility>
#include "plugin/device/cpu/hal/hardware/ms_collective_topo.h"

namespace mindspore {
namespace device {
namespace cpu {
bool TopologyNode::Initialize() {
  // Initialize the rank id.
  MS_EXCEPTION_IF_NULL(cgn_);
  rank_id_ = cgn_->rank_id();

  // Initialize the tcp server.
  tcp_server_ = std::make_unique<distributed::rpc::TCPServer>();
  RETURN_IF_FALSE_WITH_LOG(tcp_server_->Initialize(), "Failed to initialize the tcp server.");
  tcp_server_->SetMessageHandler(std::bind(&TopologyNode::HandleMessage, this, std::placeholders::_1));

  // Put the address of this topo node into the meta server node.
  auto ip = tcp_server_->GetIP();
  auto port = tcp_server_->GetPort();
  auto rank_name = "RANK_ID_" + std::to_string(rank_id_);
  auto address = ip + ":" + std::to_string(port);
  cgn_->PutMetadata(rank_name, address);

  // Get the address of the topo node of the next rank from the meta server node and create a tcp connection to it.
  // A thread is used because the addresses of the other ranks are registered asynchronously into the meta server.
  distributed::rpc::TCPClient *tcp_client = new distributed::rpc::TCPClient();
  RETURN_IF_FALSE_WITH_LOG(tcp_client->Initialize(), "Failed to initialize the tcp client to the next rank.");

  size_t next_rank_id = (rank_id_ + 1) % this->total_node_num_;
  tcp_clients_[next_rank_id] = tcp_client;

  // Because all the topo node address metadata are registered into the metadata server asynchronously, a separate
  // thread is needed to fetch these metadata.
  init_thread_ = std::thread([this, next_rank_id]() {
    size_t retry = 60;
    while (retry-- > 0) {
      // Look up the address from the meta server node.
      auto next_rank_name = "RANK_ID_" + std::to_string(next_rank_id);
      std::string next_rank_addr = this->cgn_->GetMetadata(next_rank_name);
      if (next_rank_addr.length() > 0) {
        if (this->tcp_clients_[next_rank_id]->Connect(next_rank_addr)) {
          this->node_addresses_[next_rank_id] = next_rank_addr;
          this->initialized_ = true;
          break;
        }
      }
      MS_LOG(INFO) << "Retry to get the address of the next rank: " << next_rank_name;
      static uint32_t interval = 3;
      sleep(interval);
    }
  });
  return true;
}

bool TopologyNode::Initialized() {
  // Wait for the address lookup thread to finish before reporting the result; guard against repeated joins.
  if (init_thread_.joinable()) {
    init_thread_.join();
  }
  return initialized_;
}

bool TopologyNode::Finalize() {
  // Destroy the tcp server.
  MS_EXCEPTION_IF_NULL(tcp_server_);
  tcp_server_->Finalize();
  tcp_server_.reset();

  // Destroy the tcp clients.
  for (auto iter = tcp_clients_.begin(); iter != tcp_clients_.end(); iter++) {
    auto &client = iter->second;
    if (client != nullptr) {
      client->Finalize();
      delete client;
      client = nullptr;
    }
  }

  // Destroy the received message queues.
  for (auto iter = received_messages_.begin(); iter != received_messages_.end(); iter++) {
    auto &queue = iter->second;
    if (queue != nullptr) {
      delete queue;
      queue = nullptr;
    }
  }
  return true;
}

bool TopologyNode::SendAsync(size_t rank_id, void *data, size_t size) {
  if (tcp_clients_.find(rank_id) == tcp_clients_.end()) {
    MS_LOG(ERROR) << "Cannot find tcp client for rank id: " << rank_id;
    return false;
  }
  auto &tcp_client = tcp_clients_[rank_id];
  MS_EXCEPTION_IF_NULL(tcp_client);

  std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
  MS_EXCEPTION_IF_NULL(message);

  message->name = std::to_string(rank_id_);
  message->to = AID("", node_addresses_[rank_id]);
  message->body.reserve(size);
  message->body.append(static_cast<char *>(data), size);

  tcp_client->SendAsync(std::move(message));
  return true;
}

bool TopologyNode::WaitForSend(size_t rank_id) {
  // Wait for all the pending data to be sent to the destination of the specified rank id.
  if (tcp_clients_.find(rank_id) == tcp_clients_.end()) {
    MS_LOG(ERROR) << "Cannot find tcp client for rank id: " << rank_id;
    return false;
  }
  if (node_addresses_.find(rank_id) == node_addresses_.end()) {
    MS_LOG(ERROR) << "Cannot find the address for rank id: " << rank_id;
    return false;
  }
  auto &tcp_client = tcp_clients_[rank_id];
  MS_EXCEPTION_IF_NULL(tcp_client);

  return tcp_client->Flush(node_addresses_[rank_id]);
}

bool TopologyNode::Receive(size_t rank_id, MessageBase **message, size_t timeout) {
  std::unique_lock<std::mutex> lock(cond_mutex_);
  bool rt = cond_var_.wait_for(lock, std::chrono::seconds(timeout), [this, rank_id] {
    return this->received_messages_.find(rank_id) != this->received_messages_.end() &&
           this->received_messages_[rank_id] != nullptr && this->received_messages_[rank_id]->size() > 0;
  });
  if (rt) {
    auto queue = this->received_messages_[rank_id];
    MS_EXCEPTION_IF_NULL(queue);
    auto recv_msg = queue->front();
    queue->pop();

    MS_EXCEPTION_IF_NULL(message);
    MS_EXCEPTION_IF_NULL(recv_msg);
    *message = recv_msg;
  }
  return rt;
}

size_t TopologyNode::rank_id() { return rank_id_; }

MessageBase *const TopologyNode::HandleMessage(MessageBase *const message) {
  MS_EXCEPTION_IF_NULL(message);
  auto rank_id = std::stoi(message->name);

  std::lock_guard<std::mutex> lock(cond_mutex_);
  std::queue<MessageBase *> *queue = nullptr;
  if (received_messages_.find(rank_id) == received_messages_.end()) {
    queue = new std::queue<MessageBase *>();
    received_messages_[rank_id] = queue;
  } else {
    // Reuse the existing queue; otherwise the push below would dereference a null pointer.
    queue = received_messages_[rank_id];
  }
  queue->push(message);
  cond_var_.notify_all();
  return distributed::rpc::NULL_MSG;
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
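The receive path above is a small producer/consumer handoff: the tcp server's message handler pushes into a per-rank queue under cond_mutex_ and notifies, while Receive() blocks on cond_var_ with a timeout until a matching message is available. The standalone sketch below illustrates only that pattern; the Mailbox class and its names are hypothetical stand-ins for TopologyNode's queue map, not part of this commit.

#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>

// Hypothetical illustration of the handoff used by HandleMessage()/Receive().
class Mailbox {
 public:
  // Producer side (server thread): enqueue a message and wake any waiter.
  void Push(const std::string &msg) {
    std::lock_guard<std::mutex> lock(mutex_);
    queue_.push(msg);
    cond_.notify_all();
  }

  // Consumer side: wait until a message arrives or the timeout expires.
  bool Receive(std::string *out, size_t timeout_sec) {
    std::unique_lock<std::mutex> lock(mutex_);
    bool ok = cond_.wait_for(lock, std::chrono::seconds(timeout_sec), [this] { return !queue_.empty(); });
    if (ok) {
      *out = queue_.front();
      queue_.pop();
    }
    return ok;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cond_;
  std::queue<std::string> queue_;
};

int main() {
  Mailbox mailbox;
  std::thread producer([&mailbox] { mailbox.Push("model gradients 0"); });
  std::string msg;
  if (mailbox.Receive(&msg, 15)) {
    std::cout << msg << std::endl;  // Prints "model gradients 0".
  }
  producer.join();
  return 0;
}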
@@ -0,0 +1,97 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_TOPO_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_TOPO_H_

#include <string>
#include <memory>
#include <queue>
#include <map>
#include <atomic>
#include <mutex>
#include <thread>
#include <condition_variable>
#include "actor/msg.h"
#include "distributed/rpc/tcp/tcp_client.h"
#include "distributed/rpc/tcp/tcp_server.h"
#include "distributed/cluster/topology/compute_graph_node.h"

namespace mindspore {
namespace device {
namespace cpu {
class TopologyNode {
 public:
  TopologyNode(size_t total_node_num, std::shared_ptr<distributed::cluster::topology::ComputeGraphNode> cgn)
      : rank_id_(-1), total_node_num_(total_node_num), cgn_(cgn), initialized_(false) {}
  ~TopologyNode() = default;

  // Initialize this topology node, including building the tcp clients and the tcp server.
  bool Initialize();

  // Indicates whether this topo node has been initialized successfully.
  bool Initialized();

  // Destroy the tcp clients and the tcp server.
  bool Finalize();

  // Send data asynchronously to the specified rank node.
  bool SendAsync(size_t rank_id, void *data, size_t size);

  // Wait for all the pending send tasks to the rank_id to be finished.
  bool WaitForSend(size_t rank_id);

  // Receive data sent from the specified rank node, waiting for up to `timeout` seconds.
  bool Receive(size_t rank_id, MessageBase **message, size_t timeout = 15);

  size_t rank_id();

 private:
  // Handle the messages received by the tcp server.
  MessageBase *const HandleMessage(MessageBase *const message);

  // The rank id of this node in the collective communication topology.
  size_t rank_id_;

  // The total number of topology nodes.
  size_t total_node_num_;

  // The received messages sent from other rank nodes, keyed by the sender's rank id.
  std::map<size_t, std::queue<MessageBase *> *> received_messages_;

  // Synchronize reads and writes on the received message queues.
  std::mutex cond_mutex_;
  std::condition_variable cond_var_;

  // The tcp clients for other ranks; each client is responsible for sending messages to one specified rank node.
  std::map<size_t, distributed::rpc::TCPClient *> tcp_clients_;

  // Maintain the tcp addresses of other nodes if needed.
  std::map<size_t, std::string> node_addresses_;

  // The tcp server which is responsible for receiving messages from other rank nodes.
  std::unique_ptr<distributed::rpc::TCPServer> tcp_server_;

  // The compute graph node used to exchange the topology meta info (e.g. ip:port) between topology nodes.
  std::shared_ptr<distributed::cluster::topology::ComputeGraphNode> cgn_;

  std::atomic<bool> initialized_;

  std::thread init_thread_;
};
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_TOPO_H_
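For reference, a minimal usage sketch of the class declared above: one rank joins the ring, sends to its next neighbour, and receives from its previous one. RunRingExchange is a hypothetical helper; cgn and total_node_num are assumed to come from the surrounding cluster bootstrap (a ComputeGraphNode already initialized against a running meta server node), as the unit test below does for eight ranks.

#include <memory>
#include <string>
#include "plugin/device/cpu/hal/hardware/ms_collective_topo.h"

bool RunRingExchange(size_t total_node_num,
                     std::shared_ptr<mindspore::distributed::cluster::topology::ComputeGraphNode> cgn) {
  mindspore::device::cpu::TopologyNode node(total_node_num, cgn);
  // Initialize() starts the address lookup thread; Initialized() joins it and reports the result.
  if (!node.Initialize() || !node.Initialized()) {
    return false;
  }

  // Send to the next rank in the ring and flush the pending data.
  size_t next_rank = (node.rank_id() + 1) % total_node_num;
  std::string payload = "model gradients " + std::to_string(node.rank_id());
  node.SendAsync(next_rank, payload.data(), payload.length());
  node.WaitForSend(next_rank);

  // Receive from the previous rank (blocks for up to the default 15-second timeout).
  size_t prev_rank = (node.rank_id() > 0) ? (node.rank_id() - 1) : (total_node_num - 1);
  mindspore::MessageBase *message = nullptr;
  bool ok = node.Receive(prev_rank, &message);
  if (ok) {
    delete message;
  }
  return node.Finalize() && ok;
}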
@@ -78,6 +78,7 @@ if(ENABLE_MINDDATA)
         ./tbe/*.cc
         ./mindapi/*.cc
         ./runtime/graph_scheduler/*.cc
+        ./plugin/device/cpu/hal/*.cc
         )
     if(NOT ENABLE_SECURITY)
         file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
@@ -148,6 +149,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "../../../mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc"
     "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc"
     "../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc"
+    "../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc"
     "../../../mindspore/ccsrc/plugin/device/cpu/kernel/cpu_kernel.cc"
     "../../../mindspore/ccsrc/plugin/factory/ms_factory.h"
     "../../../mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_adam_cpu_kernel.cc"
@@ -0,0 +1,136 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <vector>
#include "distributed/cluster/topology/compute_graph_node.h"
#include "distributed/cluster/topology/meta_server_node.h"
#include "plugin/device/cpu/hal/hardware/ms_collective_topo.h"
#include "utils/ms_utils.h"
#include "common/common_test.h"

namespace mindspore {
namespace device {
namespace cpu {
class TestMSCollectiveTopo : public UT::Common {
 protected:
  void SetUp() {}
  void TearDown() {}
};

/// Feature: test create cpu collective topology node.
/// Description: create the topology node.
/// Expectation: the topology node is created successfully.
TEST_F(TestMSCollectiveTopo, InitCollectiveTopoNode) {
  std::string server_host = "127.0.0.1";
  std::string server_port = "8090";
  common::SetEnv(distributed::cluster::topology::kEnvMetaServerHost, server_host.c_str());
  common::SetEnv(distributed::cluster::topology::kEnvMetaServerPort, server_port.c_str());

  size_t total_node_num = 8;
  std::vector<std::shared_ptr<distributed::cluster::topology::ComputeGraphNode>> cgns;
  distributed::cluster::topology::MetaServerNode msn("meta_server_node", total_node_num);
  ASSERT_TRUE(msn.Initialize());

  for (size_t i = 0; i < total_node_num; ++i) {
    auto cgn =
      std::make_shared<distributed::cluster::topology::ComputeGraphNode>("compute_graph_node_" + std::to_string(i + 1));
    ASSERT_TRUE(cgn->Initialize());
    cgns.push_back(cgn);
  }

  size_t interval = 1;
  size_t retry = 30;
  while (((msn.GetAliveNodeNum() != total_node_num) ||
          (msn.TopologyState() != distributed::cluster::topology::TopoState::kInitialized)) &&
         (retry-- > 0)) {
    sleep(interval);
  }

  ASSERT_EQ(total_node_num, msn.GetAliveNodeNum());
  ASSERT_EQ(distributed::cluster::topology::TopoState::kInitialized, msn.TopologyState());

  // Create the topo nodes.
  std::vector<std::shared_ptr<TopologyNode>> topo_nodes;
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = std::make_shared<TopologyNode>(total_node_num, cgns[i]);
    topo_nodes.push_back(node);
    node->Initialize();
  }
  for (size_t i = 0; i < total_node_num; ++i) {
    ASSERT_TRUE(topo_nodes[i]->Initialized());
  }

  // Check the rank id of topo node.
  for (size_t i = 0; i < total_node_num; ++i) {
    ASSERT_EQ(i, topo_nodes[i]->rank_id());
  }

  // Test data communication.
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = topo_nodes[i];
    auto rank_id = node->rank_id();
    auto next_rank_id = (rank_id + 1) % total_node_num;

    std::string data = "model gradients " + std::to_string(rank_id);
    node->SendAsync(next_rank_id, data.data(), data.length());
  }

  // Flush all the sending data.
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = topo_nodes[i];
    auto rank_id = node->rank_id();
    auto next_rank_id = (rank_id + 1) % total_node_num;
    node->WaitForSend(next_rank_id);
  }

  // Receive data from other rank nodes.
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = topo_nodes[i];
    auto rank_id = node->rank_id();
    auto upstream_rank_id = (rank_id > 0) ? (rank_id - 1) : (total_node_num - 1);
    MessageBase *message = nullptr;
    node->Receive(upstream_rank_id, &message);

    ASSERT_NE(nullptr, message);
    ASSERT_EQ(std::to_string(upstream_rank_id), message->name);
    ASSERT_EQ("model gradients " + std::to_string(upstream_rank_id), message->body);

    delete message;
    message = nullptr;
  }

  // Destroy the topo nodes.
  for (size_t i = 0; i < total_node_num; ++i) {
    topo_nodes[i]->Finalize();
  }

  for (auto &cgn : cgns) {
    cgn->Finalize();
  }

  retry = 30;
  while ((msn.GetAliveNodeNum() > 0 || msn.TopologyState() != distributed::cluster::topology::TopoState::kFinished) &&
         retry-- > 0) {
    sleep(interval);
  }
  ASSERT_EQ(0, msn.GetAliveNodeNum());
  ASSERT_EQ(distributed::cluster::topology::TopoState::kFinished, msn.TopologyState());

  msn.Finalize();
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore