forked from mindspore-Ecosystem/mindspore
!35501 Add ms collective ops implementation
Merge pull request !35501 from chengang/add_ms_collective_ops
Commit: 5cd41e5940
@@ -3,7 +3,8 @@ if(ENABLE_CPU)
     list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "mpi_collective_comm_lib.cc" "mpi_communication_group.cc")
     if(WIN32 OR APPLE)
-        list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_comm_lib.cc" "allreduce_impl.cc")
+        list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_comm_lib.cc" "allreduce_impl.cc"
+            "ms_collective_ops_impl.cc")
         list(REMOVE_ITEM HARDWARE_CPU_SRC_LIST "ms_collective_topo.cc")
     endif()
     if(ENABLE_MPI)
New file ms_collective_ops_impl.cc (203 lines):

/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <numeric>
#include "plugin/device/cpu/hal/hardware/ms_collective_ops_impl.h"
#include "distributed/cluster/cluster_context.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace device {
namespace cpu {
namespace {
const char kCollectivePhaseRing[] = "ring";
const char kCollectivePhaseGather[] = "gather";
const char kCollectivePhaseReduce[] = "reduce";
const char kCollectivePhaseBroadcast[] = "broadcast";
}  // namespace

bool MSCollectiveOpsImpl::Initialize() {
  MS_EXCEPTION_IF_NULL(topo_node_);
  rank_id_ = topo_node_->rank_id();
  return true;
}

template <typename T>
bool MSCollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count) {
  MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
  MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);

  size_t chunk_size = send_count;
  std::vector<size_t> chunk_sizes(rank_size_, chunk_size);

  // Store offsets to get every data chunk's address.
  std::vector<size_t> chunk_offset;
  for (size_t i = 0; i < rank_size_; i++) {
    size_t ofs = std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + SizeToLong(i), static_cast<size_t>(0),
                                 std::plus<size_t>());
    chunk_offset.push_back(ofs);
  }

  uint32_t send_to_rank = (rank_id_ + 1) % rank_size_;
  uint32_t recv_from_rank = (rank_id_ - 1 + rank_size_) % rank_size_;
  MS_LOG(DEBUG) << "Ring AllGather count:" << send_count << ", rank_size:" << rank_size_ << ", rank_id_:" << rank_id_
                << ", chunk_size:" << chunk_size << ", chunk_sizes:" << chunk_sizes
                << ", send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank;

  T *output_buff = reinterpret_cast<T *>(recvbuff);
  size_t src_size = send_count * sizeof(T);
  size_t dst_size = send_count * sizeof(T);
  int ret = memcpy_s(output_buff + chunk_offset[rank_id_], dst_size, sendbuff, src_size);
  if (ret != 0) {
    MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
                  << ", dest size is " << dst_size << ", src size is " << src_size;
    return false;
  }
  return RingAllGatherImpl(send_to_rank, recv_from_rank, output_buff, chunk_offset, chunk_sizes);
}

template <typename T>
bool MSCollectiveOpsImpl::RingAllGatherImpl(uint32_t send_to_rank, uint32_t recv_from_rank, T *output_buff,
                                            const std::vector<size_t> &chunk_offset,
                                            const std::vector<size_t> &chunk_sizes) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // If recovery is enabled, set the timeout to 300s to guard against network flapping.
  uint32_t timeout =
    context_ptr->get_param<bool>(MS_CTX_ENABLE_RECOVERY) ? kCollectiveCommMaxTimeout : kCollectiveCommTimeout;

  for (size_t i = 0; i < rank_size_ - 1; i++) {
    size_t send_chunk_index = (rank_id_ - i + rank_size_) % rank_size_;
    T *send_chunk = output_buff + chunk_offset[send_chunk_index];

    topo_node_->SendAsync(send_to_rank, send_chunk, chunk_sizes[send_chunk_index] * sizeof(T));

    size_t recv_chunk_index = (rank_id_ - i - 1 + rank_size_) % rank_size_;
    T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
    MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank
                  << ", send count:" << chunk_sizes[send_chunk_index]
                  << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;

    MessageBase *message = nullptr;
    if (!topo_node_->Receive(recv_from_rank, &message, timeout)) {
      MS_LOG(ERROR) << "Failed to receive data from rank " << recv_from_rank;
      return false;
    }

    MS_EXCEPTION_IF_NULL(message);
    auto ret =
      memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), message->body.data(), message->body.length());
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
                    << ", dest size is " << (chunk_sizes[recv_chunk_index] * sizeof(T)) << ", src size is "
                    << message->body.length();
      return false;
    }
    delete message;
    message = nullptr;

    if (!topo_node_->WaitForSend(send_to_rank)) {
      MS_LOG(ERROR) << "Failed to send data to rank: " << send_to_rank;
      return false;
    }
  }
  return true;
}

template <typename T>
bool MSCollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                    const CommunicationGroupInfo &group_info) {
  std::unique_lock<std::mutex> lock(mtx_);
  MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
  MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);

  // Initialize collective communication parameters.
  rank_id_ = topo_node_->rank_id();
  rank_size_ = group_info.size;
  if (rank_size_ == 0) {
    MS_LOG(ERROR) << "Rank size should not be 0.";
    return false;
  }
  if (rank_size_ == 1) {
    MS_LOG(INFO) << "Rank size is 1. Do nothing.";
    return true;
  }

  auto group_to_global_ranks = group_info.group_to_global_ranks;
  if (group_to_global_ranks.empty()) {
    MS_LOG(ERROR) << "The group is empty.";
    return false;
  }
  uint32_t group_rank_size = SizeToUint(group_info.group_ranks.size());
  uint32_t global_root_rank = group_to_global_ranks[root];

  // Broadcast data to processes which are not the root.
  MS_LOG(DEBUG) << "Start broadcast from root to other processes.";
  if (rank_id_ == global_root_rank) {
    for (uint32_t i = 1; i < group_rank_size; i++) {
      uint32_t dst_rank = group_to_global_ranks[i];
      MS_LOG(DEBUG) << "Broadcast data to process " << dst_rank;

      topo_node_->SendAsync(dst_rank, const_cast<void *>(sendbuff), count * sizeof(T));
      if (!topo_node_->WaitForSend(dst_rank)) {
        MS_LOG(ERROR) << "Failed to send data to rank: " << dst_rank;
        return false;
      }
    }
  } else {
    MS_LOG(DEBUG) << "Broadcast receive from rank 0.";

    MessageBase *message = nullptr;
    if (!topo_node_->Receive(global_root_rank, &message)) {
      MS_LOG(ERROR) << "Failed to receive data from rank " << global_root_rank;
      return false;
    }

    MS_EXCEPTION_IF_NULL(message);
    int ret = memcpy_s(recvbuff, count * sizeof(T), message->body.data(), message->body.length());
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"
                    << ", dest size is " << (count * sizeof(T)) << ", src size is " << message->body.length();
      return false;
    }
  }
  MS_LOG(DEBUG) << "End broadcast.";
  return true;
}

template <typename T>
bool MSCollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t send_count) {
  std::unique_lock<std::mutex> lock(mtx_);
  MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
  MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);

  // Initialize collective communication parameters.
  rank_id_ = topo_node_->rank_id();
  rank_size_ = topo_node_->rank_size();
  if (rank_size_ == 0) {
    MS_LOG(ERROR) << "Rank size should not be 0.";
    return false;
  }
  if (rank_size_ == 1) {
    MS_LOG(INFO) << "Rank size is 1. Do nothing.";
    return true;
  }

  return RingAllGather<T>(sendbuff, recvbuff, send_count);
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
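The ring all-gather above rotates one chunk per iteration: in iteration i, each rank forwards chunk (rank_id - i) mod rank_size to its right neighbour (rank_id + 1) and receives chunk (rank_id - i - 1) mod rank_size from its left neighbour, so after rank_size - 1 iterations every rank holds all chunks. The standalone sketch below is not part of the patch; the rank count of 4 is an arbitrary example, and the modular arithmetic is reordered so the unsigned subtraction cannot wrap around. It only prints the schedule implied by the index formulas in RingAllGatherImpl.

// Standalone sketch of the RingAllGatherImpl chunk schedule (illustrative only).
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t rank_size = 4;  // hypothetical ring size
  for (uint32_t rank_id = 0; rank_id < rank_size; ++rank_id) {
    // Same neighbours as RingAllGather: send to the right, receive from the left.
    uint32_t send_to_rank = (rank_id + 1) % rank_size;
    uint32_t recv_from_rank = (rank_id + rank_size - 1) % rank_size;
    std::printf("rank %u sends to %u, receives from %u\n", rank_id, send_to_rank, recv_from_rank);
    for (uint32_t i = 0; i < rank_size - 1; ++i) {
      // Equivalent to (rank_id_ - i + rank_size_) % rank_size_ in the patch,
      // rewritten to avoid unsigned wrap-around in this standalone example.
      uint32_t send_chunk_index = (rank_id + rank_size - i) % rank_size;
      uint32_t recv_chunk_index = (rank_id + rank_size - i - 1) % rank_size;
      std::printf("  iteration %u: forward chunk %u, receive chunk %u\n", i, send_chunk_index, recv_chunk_index);
    }
  }
  return 0;
}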
New file ms_collective_ops_impl.h (116 lines):

/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_OPS_IMPL_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <functional>
#include "plugin/device/cpu/hal/hardware/ms_collective_topo.h"

namespace mindspore {
namespace device {
namespace cpu {
// The timeout for server collective communication in case of network jitter.
constexpr uint32_t kCollectiveCommTimeout = 30;
// The maximum timeout for server collective communication, used in disaster recovery to guard against network flapping.
constexpr uint32_t kCollectiveCommMaxTimeout = 300;

// The collective communication groups which are composed of multiple processes. Refer to MPI_Group.
struct CommunicationGroupInfo {
  // This group's rank size.
  uint32_t size;

  // This process's global rank id.
  uint32_t global_rank;

  // The group ranks consist of the global ranks of the processes.
  std::vector<uint32_t> group_ranks;

  // The mappings between global ranks and group ranks.
  std::map<uint32_t, uint32_t> global_to_group_ranks;
  std::map<uint32_t, uint32_t> group_to_global_ranks;
};

// MSCollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
// supported for the elastic scaling feature of the server.
class MSCollectiveOpsImpl {
 public:
  explicit MSCollectiveOpsImpl(std::shared_ptr<TopologyNode> topo_node)
      : rank_id_(0), rank_size_(0), topo_node_(topo_node) {}
  ~MSCollectiveOpsImpl() = default;

  bool Initialize();

  template <typename T>
  bool AllReduce(const std::string &data_name, void *sendbuff, void *recvbuff, size_t count);

  template <typename T>
  bool AllGather(const void *sendbuff, void *recvbuff, size_t send_count);

  // Collective broadcast within the specified group. The parameter "root" is the group rank of the root process,
  // normally 0.
  template <typename T>
  bool Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                 const CommunicationGroupInfo &group_info);

 private:
  MSCollectiveOpsImpl(const MSCollectiveOpsImpl &) = delete;
  MSCollectiveOpsImpl &operator=(const MSCollectiveOpsImpl &) = delete;

  // Implementation of RingAllGather.
  template <typename T>
  bool RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count);

  template <typename T>
  bool RingAllGatherImpl(uint32_t send_to_rank, uint32_t recv_from_rank, T *output_buff,
                         const std::vector<size_t> &chunk_offset, const std::vector<size_t> &chunk_sizes);

  uint32_t rank_id_;
  uint32_t rank_size_;

  std::shared_ptr<TopologyNode> topo_node_{nullptr};

  // The mutex to ensure that collective communication is thread-safe.
  std::mutex mtx_;
};

template bool MSCollectiveOpsImpl::AllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::AllGather<uint64_t>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::AllGather<int>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::AllGather<char>(const void *sendbuff, void *recvbuff, size_t send_count);

template bool MSCollectiveOpsImpl::RingAllGather<float>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::RingAllGather<uint64_t>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::RingAllGather<int>(const void *sendbuff, void *recvbuff, size_t send_count);
template bool MSCollectiveOpsImpl::RingAllGather<char>(const void *sendbuff, void *recvbuff, size_t send_count);

template bool MSCollectiveOpsImpl::Broadcast<float>(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                                    const CommunicationGroupInfo &group_info);
template bool MSCollectiveOpsImpl::Broadcast<uint64_t>(const void *sendbuff, void *recvbuff, size_t count,
                                                       uint32_t root, const CommunicationGroupInfo &group_info);
template bool MSCollectiveOpsImpl::Broadcast<int>(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                                  const CommunicationGroupInfo &group_info);
template bool MSCollectiveOpsImpl::Broadcast<char>(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
                                                   const CommunicationGroupInfo &group_info);
}  // namespace cpu
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_MS_COLLECTIVE_OPS_IMPL_H_
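The header only declares the API; a hypothetical call site might look like the sketch below, which is not part of the patch. It assumes a TopologyNode that has already been created and initialized elsewhere, compilation inside the MindSpore source tree, and an illustrative two-rank group layout and buffer sizes. The helper name ExampleAllGatherAndBroadcast is invented for this example.

// Hypothetical call-site sketch (not part of the patch). Buffer sizes and the
// group layout are illustrative only.
#include <memory>
#include <vector>
#include "plugin/device/cpu/hal/hardware/ms_collective_ops_impl.h"

namespace mindspore {
namespace device {
namespace cpu {
bool ExampleAllGatherAndBroadcast(const std::shared_ptr<TopologyNode> &topo_node) {
  MSCollectiveOpsImpl collective_ops(topo_node);
  if (!collective_ops.Initialize()) {
    return false;
  }

  // AllGather: every rank contributes 8 floats; the receive buffer holds
  // rank_size * 8 floats, one chunk per rank.
  std::vector<float> send_data(8, 1.0f);
  std::vector<float> recv_data(topo_node->rank_size() * 8, 0.0f);
  if (!collective_ops.AllGather<float>(send_data.data(), recv_data.data(), send_data.size())) {
    return false;
  }

  // Broadcast within a two-process group whose group ranks 0 and 1 map to
  // global ranks 0 and 1 (an illustrative mapping).
  CommunicationGroupInfo group_info;
  group_info.size = 2;
  group_info.global_rank = static_cast<uint32_t>(topo_node->rank_id());
  group_info.group_ranks = {0, 1};
  group_info.global_to_group_ranks = {{0, 0}, {1, 1}};
  group_info.group_to_global_ranks = {{0, 0}, {1, 1}};
  std::vector<float> bcast_data(8, 0.0f);
  return collective_ops.Broadcast<float>(bcast_data.data(), bcast_data.data(), bcast_data.size(),
                                         /*root=*/0, group_info);
}
}  // namespace cpu
}  // namespace device
}  // namespace mindspore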
ms_collective_topo.cc:

@@ -104,7 +104,7 @@ bool TopologyNode::Finalize() {
 
 bool TopologyNode::SendAsync(size_t rank_id, void *data, size_t size) {
   if (tcp_clients_.find(rank_id) == tcp_clients_.end()) {
-    MS_LOG(ERROR) << "Cann not find tcp client for rank id: " << rank_id;
+    MS_LOG(ERROR) << "Cann not find tcp client for rank id: " << rank_id << ", local rank: " << rank_id_;
     return false;
   }
   auto &tcp_client = tcp_clients_[rank_id];

@@ -125,11 +125,11 @@ bool TopologyNode::SendAsync(size_t rank_id, void *data, size_t size) {
 bool TopologyNode::WaitForSend(size_t rank_id) {
   // Wait for all the pending data to be sent to the destination of specified rank id.
   if (tcp_clients_.find(rank_id) == tcp_clients_.end()) {
-    MS_LOG(ERROR) << "Can not find tcp client for rank id: " << rank_id;
+    MS_LOG(ERROR) << "Can not find tcp client for rank id: " << rank_id << ", local rank: " << rank_id_;
     return false;
   }
   if (node_addresses_.find(rank_id) == node_addresses_.end()) {
-    MS_LOG(ERROR) << "Can not find the address for rank id: " << rank_id;
+    MS_LOG(ERROR) << "Can not find the address for rank id: " << rank_id << ", local rank: " << rank_id_;
   }
   auto &tcp_client = tcp_clients_[rank_id];
   MS_EXCEPTION_IF_NULL(tcp_client);

@@ -152,12 +152,16 @@ bool TopologyNode::Receive(size_t rank_id, MessageBase **message, size_t timeout
     MS_EXCEPTION_IF_NULL(message);
     MS_EXCEPTION_IF_NULL(recv_msg);
     *message = recv_msg;
+  } else {
+    MS_LOG(ERROR) << "Failed to receive message from rank: " << rank_id << ", local rank: " << rank_id_;
   }
   return rt;
 }
 
 size_t TopologyNode::rank_id() { return rank_id_; }
 
+size_t TopologyNode::rank_size() { return total_node_num_; }
+
 MessageBase *const TopologyNode::HandleMessage(MessageBase *const message) {
   MS_EXCEPTION_IF_NULL(message);
   auto rank_id = std::stoi(message->name);
ms_collective_topo.h:

@@ -58,6 +58,8 @@ class TopologyNode {
   size_t rank_id();
 
+  size_t rank_size();
+
  private:
   // Handle the message received by the tcp server.
   MessageBase *const HandleMessage(MessageBase *const message);