add collective communication topology node

This commit is contained in:
Parallels 2022-05-05 15:10:12 +08:00
parent 19464ee3e7
commit 68d8d9ee78
5 changed files with 415 additions and 2 deletions

View File

@ -2,8 +2,9 @@ if(ENABLE_CPU)
file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "mpi_collective_comm_lib.cc" "mpi_communication_group.cc")
if(WIN32)
if(WIN32 OR APPLE)
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_comm_lib.cc" "allreduce_impl.cc")
list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_topo.cc")
endif()
if(ENABLE_MPI)
set(MPI_COLLECTIVE_SRCS "mpi_collective_comm_lib.cc"
@ -23,4 +24,4 @@ if(ENABLE_CPU)
set_property(SOURCE ${HARDWARE_CPU_SRC_LIST} PROPERTY
COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_plugin_device_cpu_hal_hardware_obj OBJECT ${HARDWARE_CPU_SRC_LIST})
endif()
endif()

View File

@ -0,0 +1,177 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <memory>
#include <utility>
#include "plugin/device/cpu/hal/hardware/ms_collective_topo.h"
namespace mindspore {
namespace device {
namespace cpu {
bool TopologyNode::Initialize() {
// Initialize the rank id.
MS_EXCEPTION_IF_NULL(cgn_);
rank_id_ = cgn_->rank_id();
// Initialize the tcp server.
tcp_server_ = std::make_unique<distributed::rpc::TCPServer>();
RETURN_IF_FALSE_WITH_LOG(tcp_server_->Initialize(), "Failed to initialize the tcp server.");
tcp_server_->SetMessageHandler(std::bind(&TopologyNode::HandleMessage, this, std::placeholders::_1));
// Put the address of this topo node into meta server node.
auto ip = tcp_server_->GetIP();
auto port = tcp_server_->GetPort();
auto rank_name = "RNAK_ID_" + std::to_string(rank_id_);
auto address = ip + ":" + std::to_string(port);
cgn_->PutMetadata(rank_name, address);
// Get the address of the topo node of the next rank from meta server node and create an tcp connection to it.
// A thread is used because all the addresses of other rank are registered asynchronously into the meta server.
distributed::rpc::TCPClient *tcp_client = new distributed::rpc::TCPClient();
RETURN_IF_FALSE_WITH_LOG(tcp_client->Initialize(), "Failed to initialize the tcp client to the next rank.");
size_t next_rank_id = (rank_id_ + 1) % this->total_node_num_;
tcp_clients_[next_rank_id] = tcp_client;
// Because all the topo node address metadata are registered into the metadata server asynchronously, a separate
// thread is needed to fetch these metadata.
init_thread_ = std::thread([this, next_rank_id]() {
size_t retry = 60;
while (retry-- > 0) {
// Lookup the address from meta server node.
auto next_rank_name = "RNAK_ID_" + std::to_string(next_rank_id);
std::string next_rank_addr = this->cgn_->GetMetadata(next_rank_name);
if (next_rank_addr.length() > 0) {
if (this->tcp_clients_[next_rank_id]->Connect(next_rank_addr)) {
this->node_addresses_[next_rank_id] = next_rank_addr;
this->initialized_ = true;
break;
}
}
MS_LOG(INFO) << "Retry to get the address of next rank : " << next_rank_name;
static uint32_t interval = 3;
sleep(interval);
}
});
return true;
}
// Wait for the background connection thread started by Initialize() to finish and report whether the connection to
// the next rank has been established.
// The thread is joined only when it is joinable, so this method can be called repeatedly, or before Initialize() has
// started the thread, without throwing std::system_error.
bool TopologyNode::Initialized() {
  if (init_thread_.joinable()) {
    init_thread_.join();
  }
  return initialized_;
}
bool TopologyNode::Finalize() {
// Destroy the tcp server.
MS_EXCEPTION_IF_NULL(tcp_server_);
tcp_server_->Finalize();
tcp_server_.reset();
// Destroy the tcp clients.
for (auto iter = tcp_clients_.begin(); iter != tcp_clients_.end(); iter++) {
auto &client = iter->second;
if (client != nullptr) {
client->Finalize();
delete client;
client = nullptr;
}
}
// Destroy the received message queues.
for (auto iter = received_messages_.begin(); iter != received_messages_.end(); iter++) {
auto &queue = iter->second;
if (queue != nullptr) {
delete queue;
queue = nullptr;
}
}
return true;
}
bool TopologyNode::SendAsync(size_t rank_id, void *data, size_t size) {
if (tcp_clients_.find(rank_id) == tcp_clients_.end()) {
MS_LOG(ERROR) << "Cann not find tcp client for rank id: " << rank_id;
return false;
}
auto &tcp_client = tcp_clients_[rank_id];
MS_EXCEPTION_IF_NULL(tcp_client);
std::unique_ptr<MessageBase> message = std::make_unique<MessageBase>();
MS_EXCEPTION_IF_NULL(message);
message->name = std::to_string(rank_id_);
message->to = AID("", node_addresses_[rank_id]);
message->body.reserve(size);
message->body.append(static_cast<char *>(data), size);
tcp_client->SendAsync(std::move(message));
return true;
}
// Wait for all the pending data to be sent to the destination of specified rank id.
// Returns false if there is no tcp client or no known address for `rank_id`; otherwise returns the result of
// flushing the tcp client for that address.
bool TopologyNode::WaitForSend(size_t rank_id) {
  const auto client_iter = tcp_clients_.find(rank_id);
  if (client_iter == tcp_clients_.end()) {
    MS_LOG(ERROR) << "Can not find tcp client for rank id: " << rank_id;
    return false;
  }
  const auto address_iter = node_addresses_.find(rank_id);
  if (address_iter == node_addresses_.end()) {
    MS_LOG(ERROR) << "Can not find the address for rank id: " << rank_id;
    // Without an address there is nothing to flush; do not fall through and flush an empty address.
    return false;
  }
  auto &tcp_client = client_iter->second;
  MS_EXCEPTION_IF_NULL(tcp_client);
  return tcp_client->Flush(address_iter->second);
}
// Block until a message sent by the node of `rank_id` is available or `timeout` seconds have passed.
// On success the oldest pending message of that rank is removed from its queue, stored into `*message` and true is
// returned; on timeout `*message` is left untouched and false is returned.
// Throws (through MS_EXCEPTION_IF_NULL) if `message` is null.
bool TopologyNode::Receive(size_t rank_id, MessageBase **message, size_t timeout) {
  // Validate the output parameter before a message is popped, otherwise the popped message would be lost.
  MS_EXCEPTION_IF_NULL(message);

  std::unique_lock<std::mutex> lock(cond_mutex_);
  const auto has_pending_message = [this, rank_id] {
    const auto iter = this->received_messages_.find(rank_id);
    return iter != this->received_messages_.end() && iter->second != nullptr && !iter->second->empty();
  };
  if (!cond_var_.wait_for(lock, std::chrono::seconds(timeout), has_pending_message)) {
    return false;
  }

  // The predicate above guarantees that the queue exists and is not empty while the lock is held.
  auto *queue = this->received_messages_[rank_id];
  MS_EXCEPTION_IF_NULL(queue);
  auto *recv_msg = queue->front();
  queue->pop();
  MS_EXCEPTION_IF_NULL(recv_msg);
  *message = recv_msg;
  return true;
}
// Accessor of the rank id of this node. The value is assigned from the compute graph node in Initialize().
size_t TopologyNode::rank_id() {
  return rank_id_;
}
// Handler registered on the tcp server: put the received message into the queue of the rank which sent it and wake
// up the threads blocked in Receive().
// The rank id of the sender is parsed from the message name (see SendAsync()).
// NOTE(review): std::stoi throws if the name is not a number - confirm that only topology nodes send to this server.
// Always returns distributed::rpc::NULL_MSG.
MessageBase *const TopologyNode::HandleMessage(MessageBase *const message) {
  MS_EXCEPTION_IF_NULL(message);
  const auto rank_id = static_cast<size_t>(std::stoi(message->name));

  std::lock_guard<std::mutex> lock(cond_mutex_);
  // Reuse the queue of this rank if there is one and create it only on the first message. The lookup result must be
  // used in both cases, otherwise every message but the first one of a rank is pushed through a null pointer.
  auto &queue = received_messages_[rank_id];
  if (queue == nullptr) {
    queue = new std::queue<MessageBase *>();
  }
  queue->push(message);
  cond_var_.notify_all();
  return distributed::rpc::NULL_MSG;
}
} // namespace cpu
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,97 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_TOPO_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_TOPO_H_
#include <string>
#include <memory>
#include <queue>
#include <map>
#include <atomic>
#include <mutex>
#include <thread>
#include <condition_variable>
#include "actor/msg.h"
#include "distributed/rpc/tcp/tcp_client.h"
#include "distributed/rpc/tcp/tcp_server.h"
#include "distributed/cluster/topology/compute_graph_node.h"
namespace mindspore {
namespace device {
namespace cpu {
// A node of the cpu collective communication topology.
// Each node owns one tcp server which receives the messages of other ranks and one tcp client per destination rank.
// The addresses of the nodes are exchanged through the metadata of the compute graph node.
class TopologyNode {
 public:
  // `total_node_num` is the number of nodes in the topology, `cgn` is the compute graph node used to obtain the rank
  // id and to exchange the addresses. The rank id stays at the maximum value of size_t until Initialize() is called.
  TopologyNode(size_t total_node_num, std::shared_ptr<distributed::cluster::topology::ComputeGraphNode> cgn)
      : rank_id_(static_cast<size_t>(-1)), total_node_num_(total_node_num), cgn_(cgn), initialized_(false) {}

  // Destroying a joinable std::thread terminates the process, so the connection thread is joined here in case
  // neither Initialized() nor Finalize() has done so.
  ~TopologyNode() {
    if (init_thread_.joinable()) {
      init_thread_.join();
    }
  }

  // Init this topology node includes build tcp clients and server.
  bool Initialize();

  // Indicates whether this topo node has been initialized successfully.
  bool Initialized();

  // Destroy tcp clients and the tcp server.
  bool Finalize();

  // Send data asynchronously to the specified rank node.
  bool SendAsync(size_t rank_id, void *data, size_t size);

  // Wait for all the pending sending tasks to the rank_id to be finished.
  bool WaitForSend(size_t rank_id);

  // Receive one message sent from the specified rank node, waiting for at most `timeout` seconds.
  bool Receive(size_t rank_id, MessageBase **message, size_t timeout = 15);

  // The rank id of this node.
  size_t rank_id();

 private:
  // Handle the message received by the tcp server.
  MessageBase *const HandleMessage(MessageBase *const message);

  // The rank id of this node in the collective communication topology.
  size_t rank_id_;

  // The total topology node number.
  size_t total_node_num_;

  // The received messages sent from other rank nodes.
  std::map<size_t, std::queue<MessageBase *> *> received_messages_;

  // Synchronizer for receive message queue reads and writes.
  std::mutex cond_mutex_;
  std::condition_variable cond_var_;

  // The tcp clients for other ranks, each client is responsible for sending message to the specified rank node.
  std::map<size_t, distributed::rpc::TCPClient *> tcp_clients_;

  // Maintain the tcp addresses for other nodes if needed.
  std::map<size_t, std::string> node_addresses_;

  // The tcp server which is responsible for receiving messages from other rank nodes.
  std::unique_ptr<distributed::rpc::TCPServer> tcp_server_;

  // The compute graph node used to exchange the topology meta info(eg. ip:port) between topology nodes.
  std::shared_ptr<distributed::cluster::topology::ComputeGraphNode> cgn_;

  // Set to true by the connection thread once the connection to the next rank has been established.
  std::atomic<bool> initialized_;

  // The thread started by Initialize() which looks up the address of the next rank and connects to it.
  std::thread init_thread_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_TOPO_H_

View File

@ -78,6 +78,7 @@ if(ENABLE_MINDDATA)
./tbe/*.cc
./mindapi/*.cc
./runtime/graph_scheduler/*.cc
./plugin/device/cpu/hal/*.cc
)
if(NOT ENABLE_SECURITY)
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
@ -148,6 +149,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_graph_optimization.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/hal/hardware/ms_collective_topo.cc"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/cpu_kernel.cc"
"../../../mindspore/ccsrc/plugin/factory/ms_factory.h"
"../../../mindspore/ccsrc/plugin/device/cpu/kernel/sparse_apply_adam_cpu_kernel.cc"

View File

@ -0,0 +1,136 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "distributed/cluster/topology/compute_graph_node.h"
#include "distributed/cluster/topology/meta_server_node.h"
#include "plugin/device/cpu/hal/hardware/ms_collective_topo.h"
#include "utils/ms_utils.h"
#include "common/common_test.h"
namespace mindspore {
namespace device {
namespace cpu {
class TestMSCollectiveTopo : public UT::Common {
protected:
void SetUp() {}
void TearDown() {}
};
/// Feature: test create cpu collective topology node.
/// Description: create the topology nodes of a ring, send one message from every node to its next rank and receive
/// it on the other side.
/// Expectation: the topology nodes are created successfully and every message is received unchanged.
TEST_F(TestMSCollectiveTopo, InitCollectiveTopoNode) {
  std::string server_host = "127.0.0.1";
  std::string server_port = "8090";
  common::SetEnv(distributed::cluster::topology::kEnvMetaServerHost, server_host.c_str());
  common::SetEnv(distributed::cluster::topology::kEnvMetaServerPort, server_port.c_str());

  size_t total_node_num = 8;
  std::vector<std::shared_ptr<distributed::cluster::topology::ComputeGraphNode>> cgns;
  distributed::cluster::topology::MetaServerNode msn("meta_server_node", total_node_num);
  ASSERT_TRUE(msn.Initialize());

  for (size_t i = 0; i < total_node_num; ++i) {
    auto cgn =
      std::make_shared<distributed::cluster::topology::ComputeGraphNode>("compute_graph_node_" + std::to_string(i + 1));
    ASSERT_TRUE(cgn->Initialize());
    cgns.push_back(cgn);
  }

  // Wait until all the compute graph nodes are registered.
  size_t interval = 1;
  size_t retry = 30;
  while (((msn.GetAliveNodeNum() != total_node_num) ||
          (msn.TopologyState() != distributed::cluster::topology::TopoState::kInitialized)) &&
         (retry-- > 0)) {
    sleep(interval);
  }
  ASSERT_EQ(total_node_num, msn.GetAliveNodeNum());
  ASSERT_EQ(distributed::cluster::topology::TopoState::kInitialized, msn.TopologyState());

  // Create the topo nodes.
  std::vector<std::shared_ptr<TopologyNode>> topo_nodes;
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = std::make_shared<TopologyNode>(total_node_num, cgns[i]);
    topo_nodes.push_back(node);
    // A failure to start the tcp server or client must fail the test here instead of showing up later.
    ASSERT_TRUE(node->Initialize());
  }
  for (size_t i = 0; i < total_node_num; ++i) {
    ASSERT_TRUE(topo_nodes[i]->Initialized());
  }

  // Check the rank id of topo node.
  for (size_t i = 0; i < total_node_num; ++i) {
    ASSERT_EQ(i, topo_nodes[i]->rank_id());
  }

  // Test data communication.
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = topo_nodes[i];
    auto rank_id = node->rank_id();
    auto next_rank_id = (rank_id + 1) % total_node_num;
    std::string data = "model gradients " + std::to_string(rank_id);
    ASSERT_TRUE(node->SendAsync(next_rank_id, data.data(), data.length()));
  }

  // Flush all the sending data.
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = topo_nodes[i];
    auto rank_id = node->rank_id();
    auto next_rank_id = (rank_id + 1) % total_node_num;
    ASSERT_TRUE(node->WaitForSend(next_rank_id));
  }

  // Receive data from other rank nodes.
  for (size_t i = 0; i < total_node_num; ++i) {
    auto node = topo_nodes[i];
    auto rank_id = node->rank_id();
    auto upstream_rank_id = (rank_id > 0) ? (rank_id - 1) : (total_node_num - 1);
    MessageBase *message = nullptr;
    ASSERT_TRUE(node->Receive(upstream_rank_id, &message));
    ASSERT_NE(nullptr, message);
    ASSERT_EQ(std::to_string(upstream_rank_id), message->name);
    ASSERT_EQ("model gradients " + std::to_string(upstream_rank_id), message->body);
    delete message;
    message = nullptr;
  }

  // Destroy the topo nodes.
  for (size_t i = 0; i < total_node_num; ++i) {
    ASSERT_TRUE(topo_nodes[i]->Finalize());
  }
  for (auto &cgn : cgns) {
    cgn->Finalize();
  }

  // Wait until all the compute graph nodes are unregistered.
  retry = 30;
  while ((msn.GetAliveNodeNum() > 0 || msn.TopologyState() != distributed::cluster::topology::TopoState::kFinished) &&
         retry-- > 0) {
    sleep(interval);
  }
  ASSERT_EQ(0, msn.GetAliveNodeNum());
  ASSERT_EQ(distributed::cluster::topology::TopoState::kFinished, msn.TopologyState());
  msn.Finalize();
}
} // namespace cpu
} // namespace device
} // namespace mindspore