!35501 Add ms collective ops implementation

Merge pull request !35501 from chengang/add_ms_collective_ops
commit 5cd41e5940
i-robot 2022-06-08 09:14:49 +00:00, committed by Gitee
5 changed files with 330 additions and 4 deletions

File: plugin/device/cpu/hal/hardware/CMakeLists.txt

@@ -3,7 +3,8 @@ if(ENABLE_CPU)
     list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "mpi_collective_comm_lib.cc" "mpi_communication_group.cc")
     if(WIN32 OR APPLE)
-        list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_comm_lib.cc" "allreduce_impl.cc")
+        list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_comm_lib.cc" "allreduce_impl.cc"
+            "ms_collective_ops_impl.cc")
         list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_topo.cc")
     endif()
     if(ENABLE_MPI)

File: plugin/device/cpu/hal/hardware/ms_collective_ops_impl.cc (new file)

@@ -0,0 +1,203 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <functional>  // for std::plus used in the chunk offset accumulation below
#include <numeric>
#include <vector>      // for the std::vector chunk offsets and sizes below
#include "plugin/device/cpu/hal/hardware/ms_collective_ops_impl.h"
#include "distributed/cluster/cluster_context.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace device {
namespace cpu {
namespace {
const char kCollectivePhaseRing[] = "ring";
const char kCollectivePhaseGather[] = "gather";
const char kCollectivePhaseReduce[] = "reduce";
const char kCollectivePhaseBroadcast[] = "broadcast";
}  // namespace

bool MSCollectiveOpsImpl::Initialize() {
  MS_EXCEPTION_IF_NULL(topo_node_);
  rank_id_ = topo_node_->rank_id();
  return true;
}

template <typename T>
bool MSCollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count) {
  MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
  MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
  size_t chunk_size = send_count;
  std::vector<size_t> chunk_sizes(rank_size_, chunk_size);
  // Store offsets to get every data chunk's address.
  std::vector<size_t> chunk_offset;
  for (size_t i = 0; i < rank_size_; i++) {
    size_t ofs = std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + SizeToLong(i), static_cast<size_t>(0),
                                 std::plus<size_t>());
    chunk_offset.push_back(ofs);
  }

  uint32_t send_to_rank = (rank_id_ + 1) % rank_size_;
  uint32_t recv_from_rank = (rank_id_ - 1 + rank_size_) % rank_size_;
  MS_LOG(DEBUG) << "Ring AllGather count:" << send_count << ", rank_size:" << rank_size_ << ", rank_id_:" << rank_id_
                << ", chunk_size:" << chunk_size << ", chunk_sizes:" << chunk_sizes
                << ", send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank;

  T *output_buff = reinterpret_cast<T *>(recvbuff);
  size_t src_size = send_count * sizeof(T);
  size_t dst_size = send_count * sizeof(T);
  int ret = memcpy_s(output_buff + chunk_offset[rank_id_], dst_size, sendbuff, src_size);
  if (ret != 0) {
    MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
                  << ", dest size is " << dst_size << ", src size is " << src_size;
    return false;
  }
  return RingAllGatherImpl(send_to_rank, recv_from_rank, output_buff, chunk_offset, chunk_sizes);
}

template <typename T>
bool MSCollectiveOpsImpl::RingAllGatherImpl(uint32_t send_to_rank, uint32_t recv_from_rank, T *output_buff,
                                            const std::vector<size_t> &chunk_offset,
                                            const std::vector<size_t> &chunk_sizes) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // If enable recovery, set timeout 300s to prevent networking flapping.
  uint32_t timeout =
    context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY) ? kCollectiveCommMaxTimeout : kCollectiveCommTimeout;

  for (size_t i = 0; i < rank_size_ - 1; i++) {
    size_t send_chunk_index = (rank_id_ - i + rank_size_) % rank_size_;
    T *send_chunk = output_buff + chunk_offset[send_chunk_index];
    topo_node_->SendAsync(send_to_rank, send_chunk, chunk_sizes[send_chunk_index] * sizeof(T));
    size_t recv_chunk_index = (rank_id_ - i - 1 + rank_size_) % rank_size_;
    T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
    MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
                  << ", send count:" << chunk_sizes[send_chunk_index]
                  << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;

    MessageBase *message = nullptr;
    if (!topo_node_->Receive(recv_from_rank, &message, timeout)) {
      MS_LOG(ERROR) << "Failed to receive data from rank " << recv_from_rank;
      return false;
    }
    MS_EXCEPTION_IF_NULL(message);
    auto ret =
      memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), message->body.data(), message->body.length());
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
                    << ", dest size is " << (chunk_sizes[recv_chunk_index] * sizeof(T)) << ", src size is "
                    << message->body.length();
      return false;
    }
    delete message;
    message = nullptr;

    if (!topo_node_->WaitForSend(send_to_rank)) {
      MS_LOG(ERROR) << "Failed to send data to rank: " << send_to_rank;
      return false;
    }
  }
  return true;
}

template <typename T>
bool MSCollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                    const CommunicationGroupInfo &group_info) {
  std::unique_lock<std::mutex> lock(mtx_);
  MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
  MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);

  // Initialize collective communication parameters.
  rank_id_ = topo_node_->rank_id();
  rank_size_ = group_info.size;
  if (rank_size_ == 0) {
    MS_LOG(ERROR) << "Rank size should not be 0.";
    return false;
  }
  if (rank_size_ == 1) {
    MS_LOG(INFO) << "Rank size is 1. Do nothing.";
    return true;
  }

  auto group_to_global_ranks = group_info.group_to_global_ranks;
  if (group_to_global_ranks.empty()) {
    MS_LOG(ERROR) << "The group is empty.";
    return false;
  }
  uint32_t group_rank_size = SizeToUint(group_info.group_ranks.size());
  uint32_t global_root_rank = group_to_global_ranks[root];

  // Broadcast data to processes which are not the root.
  MS_LOG(DEBUG) << "Start broadcast from root to other processes.";
  if (rank_id_ == global_root_rank) {
    for (uint32_t i = 1; i < group_rank_size; i++) {
      uint32_t dst_rank = group_to_global_ranks[i];
      MS_LOG(DEBUG) << "Broadcast data to process " << dst_rank;
      topo_node_->SendAsync(dst_rank, const_cast<void *>(sendbuff), count * sizeof(T));
      if (!topo_node_->WaitForSend(dst_rank)) {
        MS_LOG(ERROR) << "Failed to send data to rank: " << dst_rank;
        return false;
      }
    }
  } else {
    MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
    MessageBase *message = nullptr;
    if (!topo_node_->Receive(global_root_rank, &message)) {
      MS_LOG(ERROR) << "Failed to receive data from rank " << global_root_rank;
      return false;
    }
    MS_EXCEPTION_IF_NULL(message);
    int ret = memcpy_s(recvbuff, count * sizeof(T), message->body.data(), message->body.length());
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
                    << ", dest size is " << (count * sizeof(T)) << ", src size is " << message->body.length();
      return false;
    }
  }
  MS_LOG(DEBUG) << "End broadcast.";
  return true;
}

template <typename T>
bool MSCollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t send_count) {
  std::unique_lock<std::mutex> lock(mtx_);
  MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
  MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);

  // Initialize collective communication parameters.
  rank_id_ = topo_node_->rank_id();
  rank_size_ = topo_node_->rank_size();
  if (rank_size_ == 0) {
    MS_LOG(ERROR) << "Rank size should not be 0.";
    return false;
  }
  if (rank_size_ == 1) {
    MS_LOG(INFO) << "Rank size is 1. Do nothing.";
    return true;
  }

  return RingAllGather<T>(sendbuff, recvbuff, send_count);
}
}  // namespace cpu
}  // namespace device
} // namespace mindspore
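
For reference: at iteration i of RingAllGatherImpl above, rank r sends chunk (r - i) mod n to rank (r + 1) mod n and receives chunk (r - i - 1) mod n from rank (r - 1) mod n, so after n - 1 iterations every rank holds all n chunks. The standalone sketch below is hypothetical and not part of this commit; it only reuses the index formulas from the code above to simulate that chunk rotation for four ranks and print the fully gathered buffers.

// ring_allgather_sim.cc -- hypothetical illustration of the chunk schedule used by RingAllGatherImpl.
#include <iostream>
#include <vector>

int main() {
  const int n = 4;      // rank_size
  const int count = 2;  // send_count per rank
  // One receive buffer per simulated rank, laid out as n chunks of `count` elements; -1 marks "not received yet".
  std::vector<std::vector<int>> buf(n, std::vector<int>(n * count, -1));
  for (int r = 0; r < n; ++r) {
    for (int k = 0; k < count; ++k) buf[r][r * count + k] = r;  // local chunk, as the first memcpy_s does
  }
  for (int i = 0; i < n - 1; ++i) {
    for (int r = 0; r < n; ++r) {
      int recv_from = (r - 1 + n) % n;           // recv_from_rank
      int recv_chunk = (r - i - 1 + 2 * n) % n;  // recv_chunk_index (== the chunk recv_from sends this round)
      for (int k = 0; k < count; ++k) buf[r][recv_chunk * count + k] = buf[recv_from][recv_chunk * count + k];
    }
  }
  for (int r = 0; r < n; ++r) {
    std::cout << "rank " << r << ":";
    for (int v : buf[r]) std::cout << ' ' << v;
    std::cout << '\n';  // expect every rank to print: 0 0 1 1 2 2 3 3
  }
  return 0;
}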

File: plugin/device/cpu/hal/hardware/ms_collective_ops_impl.h (new file)

@@ -0,0 +1,116 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_OPS_IMPL_H_

#include <map>
#include <memory>
#include <mutex>  // for the std::mutex member declared below
#include <string>
#include <vector>
#include <functional>
#include "plugin/device/cpu/hal/hardware/ms_collective_topo.h"
namespace mindspore {
namespace device {
namespace cpu {
// The timeout for server collective communication in case of network jitter.
constexpr uint32_t kCollectiveCommTimeout = 30;
// The max timeout for server collective communication, used in disaster recovery to prevent networking flapping.
constexpr uint32_t kCollectiveCommMaxTimeout = 300;

// The collective communication groups which are composed of multiple processes. Refer to MPI_Group.
struct CommunicationGroupInfo {
  // This group's rank size.
  uint32_t size;

  // This process's global rank id.
  uint32_t global_rank;

  // The group ranks consists of global ranks of the processes.
  std::vector<uint32_t> group_ranks;

  // The mapping of global ranks and group ranks.
  std::map<uint32_t, uint32_t> global_to_group_ranks;
  std::map<uint32_t, uint32_t> group_to_global_ranks;
};

// MSCollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
// supported for the elastic scaling feature of the server.
class MSCollectiveOpsImpl {
 public:
  explicit MSCollectiveOpsImpl(std::shared_ptr<TopologyNode> topo_node)
      : rank_id_(0), rank_size_(0), topo_node_(topo_node) {}
  ~MSCollectiveOpsImpl() = default;

  bool Initialize();

  template <typename T>
  bool AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count);

  template <typename T>
  bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count);

  // Collective broadcast within the specified group. The parameter "root" is the group rank of the root process.
  // Normally 0.
  template <typename T>
  bool Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                 const CommunicationGroupInfo &group_info);

 private:
  MSCollectiveOpsImpl(const MSCollectiveOpsImpl &) = delete;
  MSCollectiveOpsImpl &operator=(const MSCollectiveOpsImpl &) = delete;

  // Implementation of RingAllGather.
  template <typename T>
  bool RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count);

  template <typename T>
  bool RingAllGatherImpl(uint32_t send_to_rank, uint32_t recv_from_rank, T *output_buff,
                         const std::vector<size_t> &chunk_offset, const std::vector<size_t> &chunk_sizes);

  uint32_t rank_id_;
  uint32_t rank_size_;

  std::shared_ptr<TopologyNode> topo_node_{nullptr};

  // The mutex to ensure that collective communication is threadsafe.
  std::mutex mtx_;
};

template bool MSCollectiveOpsImpl::AllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::AllGather<uint64_t>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::AllGather<int>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::AllGather<char>(const void *sendbuff, void *recvbuff, size_t send_count);

template bool MSCollectiveOpsImpl::RingAllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::RingAllGather<uint64_t>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::RingAllGather<int>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::RingAllGather<char>(const void *sendbuff, void *recvbuff, size_t send_count);

template bool MSCollectiveOpsImpl::Broadcast<float>(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                                    const CommunicationGroupInfo &group_info);
template bool MSCollectiveOpsImpl::Broadcast<uint64_t>(const void *sendbuff, void *recvbuff, size_t count,
                                                       uint32_t root, const CommunicationGroupInfo &group_info);
template bool MSCollectiveOpsImpl::Broadcast<int>(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                                  const CommunicationGroupInfo &group_info);
template bool MSCollectiveOpsImpl::Broadcast<char>(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                                   const CommunicationGroupInfo &group_info);
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_OPS_IMPL_H_
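
For reference, one plausible way to fill in the CommunicationGroupInfo struct declared above for a two-process sub-group made of global ranks {1, 3} in a four-process job; Broadcast resolves its group-rank arguments to global ranks through group_to_global_ranks. The helper below is hypothetical, not part of this commit, and assumes the mindspore::device::cpu declarations above are in scope.

// Hypothetical example only: build a CommunicationGroupInfo for the sub-group {1, 3},
// as seen from the process whose global rank is 3.
CommunicationGroupInfo BuildExampleGroupInfo() {
  CommunicationGroupInfo info;
  info.group_ranks = {1, 3};  // global ranks that belong to this group
  info.size = static_cast<uint32_t>(info.group_ranks.size());
  info.global_rank = 3;       // this process's global rank
  for (uint32_t group_rank = 0; group_rank < info.size; ++group_rank) {
    uint32_t global_rank = info.group_ranks[group_rank];
    info.group_to_global_ranks[group_rank] = global_rank;  // group rank -> global rank
    info.global_to_group_ranks[global_rank] = group_rank;  // global rank -> group rank
  }
  return info;
}

With this mapping, a call such as Broadcast<float>(sendbuff, recvbuff, count, /*root=*/0, info) treats global rank 1 as the root and global rank 3 as the remaining receiver.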

File: plugin/device/cpu/hal/hardware/ms_collective_topo.cc

@@ -104,7 +104,7 @@ bool TopologyNode::Finalize() {
 bool TopologyNode::SendAsync(size_t rank_id, void *data, size_t size) {
   if (tcp_clients_.find(rank_id) == tcp_clients_.end()) {
-    MS_LOG(ERROR) << "Cann not find tcp client for rank id: " << rank_id;
+    MS_LOG(ERROR) << "Cann not find tcp client for rank id: " << rank_id << ", local rank: " << rank_id_;
     return false;
   }
   auto &tcp_client = tcp_clients_[rank_id];
@@ -125,11 +125,11 @@ bool TopologyNode::SendAsync(size_t rank_id, void *data, size_t size) {
 bool TopologyNode::WaitForSend(size_t rank_id) {
   // Wait for all the pending data to be sent to the destination of specified rank id.
   if (tcp_clients_.find(rank_id) == tcp_clients_.end()) {
-    MS_LOG(ERROR) << "Can not find tcp client for rank id: " << rank_id;
+    MS_LOG(ERROR) << "Can not find tcp client for rank id: " << rank_id << ", local rank: " << rank_id_;
     return false;
   }
   if (node_addresses_.find(rank_id) == node_addresses_.end()) {
-    MS_LOG(ERROR) << "Can not find the address for rank id: " << rank_id;
+    MS_LOG(ERROR) << "Can not find the address for rank id: " << rank_id << ", local rank: " << rank_id_;
   }
   auto &tcp_client = tcp_clients_[rank_id];
   MS_EXCEPTION_IF_NULL(tcp_client);
@@ -152,12 +152,16 @@ bool TopologyNode::Receive(size_t rank_id, MessageBase **message, size_t timeout
     MS_EXCEPTION_IF_NULL(message);
     MS_EXCEPTION_IF_NULL(recv_msg);
     *message = recv_msg;
+  } else {
+    MS_LOG(ERROR) << "Failed to receive message from rank: " << rank_id << ", local rank: " << rank_id_;
   }
   return rt;
 }
 size_t TopologyNode::rank_id() { return rank_id_; }
 size_t TopologyNode::rank_size() { return total_node_num_; }
 MessageBase *const TopologyNode::HandleMessage(MessageBase *const message) {
   MS_EXCEPTION_IF_NULL(message);
   auto rank_id = std::stoi(message->name);

File: plugin/device/cpu/hal/hardware/ms_collective_topo.h

@@ -58,6 +58,8 @@ class TopologyNode {
   size_t rank_id();
   size_t rank_size();
  private:
+  // Handle the message received by the tcp server.
+  MessageBase *const HandleMessage(MessageBase *const message);
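
For orientation, the collective ops added in this change build on three TopologyNode calls that appear in the hunks above: SendAsync queues data for a peer, WaitForSend blocks until the queued data is flushed, and Receive blocks (with a timeout) for the peer's message. A rough sketch of that point-to-point pattern follows; ExchangeWithPeer is a hypothetical helper, not part of this commit, and it assumes the TopologyNode and MessageBase declarations above are in scope and that MessageBase::body carries the payload string, as the memcpy_s calls in ms_collective_ops_impl.cc treat it.

// Hypothetical helper, illustration only: one symmetric exchange with a single peer,
// using the TopologyNode calls exercised by MSCollectiveOpsImpl above.
bool ExchangeWithPeer(const std::shared_ptr<TopologyNode> &topo_node, size_t peer_rank,
                      void *send_data, size_t send_size, std::string *recv_data) {
  topo_node->SendAsync(peer_rank, send_data, send_size);  // queue our chunk for the peer
  MessageBase *message = nullptr;
  if (!topo_node->Receive(peer_rank, &message)) {  // wait (with the default timeout) for the peer's chunk
    return false;
  }
  *recv_data = message->body;  // copy out the payload before releasing the message
  delete message;
  message = nullptr;
  return topo_node->WaitForSend(peer_rank);  // make sure our own pending send has been flushed
}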