!26292 Add GPU operator NeighborExchange

Merge pull request !26292 from Cononlly/master
i-robot 2021-11-18 05:16:41 +00:00 committed by Gitee
commit ec4cd6933d
8 changed files with 363 additions and 50 deletions

mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt

@@ -73,6 +73,7 @@ if(ENABLE_GPU)
file(GLOB_RECURSE GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "gpu/*.cc")
list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_collective_gpu_kernel.cc")
+list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_p2p_gpu_kernel.cc")
list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_send_gpu_kernel.cc")
list(REMOVE_ITEM GPU_SRC_LIST "gpu/nccl/nccl_recv_gpu_kernel.cc")
list(REMOVE_ITEM GPU_SRC_LIST "gpu/trt/trt_kernel.cc")

mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.cc

@@ -58,15 +58,5 @@ MS_REG_GPU_KERNEL_ONE(Broadcast,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
NcclCollectiveGpuKernel, int)
-MS_REG_GPU_KERNEL_ONE(
-AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-NcclCollectiveGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(
-AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-NcclCollectiveGpuKernel, half)
-MS_REG_GPU_KERNEL_ONE(AllToAllv,
-KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-NcclCollectiveGpuKernel, int)
} // namespace kernel
} // namespace mindspore

mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h

@@ -31,14 +31,12 @@ enum NcclKernelType {
NCCL_ALL_GATHER,
NCCL_REDUCE_SCATTER,
NCCL_BROADCAST,
-NCCL_ALLTOALLV,
NCCL_INVALID_TYPE = 255
};
const std::map<std::string, NcclKernelType> kNcclTypeMap = {{"AllReduce", NCCL_ALL_REDUCE},
{"AllGather", NCCL_ALL_GATHER},
{"ReduceScatter", NCCL_REDUCE_SCATTER},
-{"Broadcast", NCCL_BROADCAST},
-{"AllToAllv", NCCL_ALLTOALLV}};
+{"Broadcast", NCCL_BROADCAST}};
template <typename T>
class NcclCollectiveGpuKernel : public NcclGpuKernel {
@@ -72,10 +70,6 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
LaunchBroadcast(inputs, outputs, stream_ptr);
break;
}
-case NCCL_ALLTOALLV: {
-LaunchAllToAllv(inputs, outputs, stream_ptr);
-break;
-}
default: {
MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported.";
}
@@ -214,37 +208,6 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
}
}
-void LaunchAllToAllv(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
-void *stream_ptr) {
-T *input_addr = nullptr;
-T *output_addr = nullptr;
-cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
-auto nccl_recv_func = reinterpret_cast<Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv"));
-auto nccl_send_func = reinterpret_cast<Send>(dlsym(const_cast<void *>(collective_handle_), "Send"));
-auto nccl_gstart_func = reinterpret_cast<GroupStart>(dlsym(const_cast<void *>(collective_handle_), "GroupStart"));
-auto nccl_gend_func = reinterpret_cast<GroupEnd>(dlsym(const_cast<void *>(collective_handle_), "GroupEnd"));
-MS_EXCEPTION_IF_NULL(nccl_recv_func);
-MS_EXCEPTION_IF_NULL(nccl_send_func);
-MS_EXCEPTION_IF_NULL(nccl_gstart_func);
-MS_EXCEPTION_IF_NULL(nccl_gend_func);
-// This implementation refers to NVIDIA NCCL 2.11 doc.
-CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "AllToAllv: ncclGroupStart failed");
-for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
-input_addr = GetDeviceAddress<T>(inputs, i);
-output_addr = GetDeviceAddress<T>(outputs, i);
-CHECK_NCCL_RET_WITH_EXCEPT(
-kernel_node_,
-(*nccl_send_func)(input_addr, input_size_list_[i] / sizeof(T), nccl_data_type_, i, stream, group_name_),
-"AllToAllv: ncclSend failed");
-CHECK_NCCL_RET_WITH_EXCEPT(
-kernel_node_,
-(*nccl_recv_func)(output_addr, output_size_list_[i] / sizeof(T), nccl_data_type_, i, stream, group_name_),
-"AllToAllv: ncclRecv failed");
-}
-CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "AllToAllv: ncclGroupEnd failed");
-}
void InferCommType(const CNodePtr &kernel_node) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kNcclTypeMap.find(kernel_name);
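The LaunchAllToAllv removed above used the loop index as the peer rank: inside one NCCL group, buffer i was sent to rank i and buffer i was received from rank i, i.e. a plain all-to-all over the whole group. A minimal single-process NumPy sketch of those semantics, with illustrative names only (no NCCL or GPUs involved):

import numpy as np

def simulate_all_to_all(inputs_per_rank):
    """inputs_per_rank[src][dst] is the buffer rank src sends to rank dst."""
    world = len(inputs_per_rank)
    # outputs_per_rank[dst][src] is what rank dst receives from rank src.
    return [[inputs_per_rank[src][dst] for src in range(world)]
            for dst in range(world)]

world = 4
inputs = [[np.full(2, 10 * src + dst, dtype=np.float32) for dst in range(world)]
          for src in range(world)]
outputs = simulate_all_to_all(inputs)
# Rank 0 receives buffer 0 of every rank: values 0, 10, 20, 30.
assert [int(buf[0]) for buf in outputs[0]] == [0, 10, 20, 30]

The new NcclP2PGpuKernel below generalizes this by taking the peer ranks from explicit attributes instead of the loop index.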

mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_p2p_gpu_kernel.cc

@@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nccl/nccl_p2p_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
NcclP2PGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
NcclP2PGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(AllToAllv,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
NcclP2PGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16),
NcclP2PGpuKernel, float, half)
MS_REG_GPU_KERNEL_TWO(
AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
NcclP2PGpuKernel, half, float)
MS_REG_GPU_KERNEL_TWO(
NeighborExchange,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
NcclP2PGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
NeighborExchange,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
NcclP2PGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(NeighborExchange,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
NcclP2PGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
NeighborExchange,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16),
NcclP2PGpuKernel, float, half)
MS_REG_GPU_KERNEL_TWO(
NeighborExchange,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
NcclP2PGpuKernel, half, float)
} // namespace kernel
} // namespace mindspore
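Each MS_REG_GPU_KERNEL_TWO line above binds an operator name plus an (input dtype, output dtype) signature to one instantiation of NcclP2PGpuKernel<T, I>, which is how the mixed float32/float16 variants coexist with the same-dtype ones. A rough Python analogue of that registration-and-dispatch table, with hypothetical helper names:

import numpy as np

KERNEL_REGISTRY = {}

def register(op_name, in_dtype, out_dtype):
    def deco(fn):
        KERNEL_REGISTRY[(op_name, in_dtype, out_dtype)] = fn
        return fn
    return deco

@register("NeighborExchange", np.float32, np.float16)
def neighbor_exchange_f32_f16(buf):
    return buf.astype(np.float16)  # stands in for NcclP2PGpuKernel<float, half>

@register("NeighborExchange", np.float32, np.float32)
def neighbor_exchange_f32_f32(buf):
    return buf.copy()              # stands in for NcclP2PGpuKernel<float, float>

def launch(op_name, buf, out_dtype):
    # Select the instantiation by the full dtype signature, as the macro table does.
    return KERNEL_REGISTRY[(op_name, buf.dtype.type, out_dtype)](buf)

assert launch("NeighborExchange", np.ones(4, np.float32), np.float16).dtype == np.float16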

mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_p2p_gpu_kernel.h

@@ -0,0 +1,224 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_P2P_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_P2P_GPU_KERNEL_H_
#include <dlfcn.h>
#include <stdint.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
namespace mindspore {
namespace kernel {
enum NcclKernelType { NCCL_ALLTOALLV = 0, NCCL_NEIGHBOREXCHANGE = 1, NCCL_INVALID_TYPE = 255 };
const std::map<std::string, NcclKernelType> kNcclTypeMap = {{"AllToAllv", NCCL_ALLTOALLV},
{"NeighborExchange", NCCL_NEIGHBOREXCHANGE}};
template <typename T, typename I>
class NcclP2PGpuKernel : public NcclGpuKernel {
public:
NcclP2PGpuKernel() { ResetResource(); }
~NcclP2PGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
switch (nccl_kernel_type_) {
case NCCL_ALLTOALLV: {
LaunchAllToAllv(inputs, outputs, stream_ptr);
break;
}
case NCCL_NEIGHBOREXCHANGE: {
LaunchAllToAllv(inputs, outputs, stream_ptr);
break;
}
default: {
MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported.";
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_node_ = kernel_node;
InferCommType(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (input_num > 0) {
input_nccl_data_type_ = nccl_dtype(AnfAlgo::GetInputDeviceDataType(kernel_node, 0));
}
if (output_num > 0) {
output_nccl_data_type_ = nccl_dtype(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0));
}
for (size_t i = 0; i < input_num; ++i) {
auto shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, i);
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
MS_LOG(WARNING) << "For 'NcclP2PGpuKernel', input shape is null ";
InitSizeLists();
return true;
}
size_t size = sizeof(T);
for (size_t j = 0; j < shape.size(); j++) {
size *= IntToSize(shape[j]);
}
input_size_list_.push_back(size);
input_size_ += size;
}
for (size_t i = 0; i < output_num; ++i) {
auto shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, i);
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
MS_LOG(WARNING) << "For 'NcclP2PGpuKernel', output shape is null";
InitSizeLists();
return true;
}
size_t size = sizeof(I);
for (size_t j = 0; j < shape.size(); j++) {
size *= IntToSize(shape[j]);
}
output_size_list_.push_back(size);
output_size_ += size;
}
group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup);
MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_;
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
auto comm_stream_attr = prim->GetAttr("stream_id");
if (comm_stream_attr) {
comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
MS_EXCEPTION_IF_NULL(comm_stream_);
}
// Used by AllToAllv and NeighborExchange: explicit peer rank lists.
auto send_rank_ids_attr = prim->GetAttr(kAttrSendRankIds);
auto recv_rank_ids_attr = prim->GetAttr(kAttrRecvRankIds);
if (send_rank_ids_attr && recv_rank_ids_attr) {
send_rank_ids_ = GetValue<std::vector<int64_t>>(send_rank_ids_attr);
recv_rank_ids_ = GetValue<std::vector<int64_t>>(recv_rank_ids_attr);
}
collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
MS_EXCEPTION_IF_NULL(collective_handle_);
return true;
}
void ResetResource() noexcept override {
nccl_kernel_type_ = NCCL_INVALID_TYPE;
input_size_ = 0;
output_size_ = 0;
root_ = 0;
is_null_input_ = false;
collective_handle_ = nullptr;
comm_stream_ = nullptr;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override { return; }
private:
void LaunchAllToAllv(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
void *stream_ptr) {
T *input_addr = nullptr;
I *output_addr = nullptr;
cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
// The sizes of send_rank_ids_ and recv_rank_ids_ must match the input and output lists,
// otherwise indexing the rank id vectors below would run out of bounds.
if (send_rank_ids_.size() != input_size_list_.size()) {
MS_LOG(EXCEPTION) << "Trying to use AllToAllv, but the size of send_rank_ids does not equal the input list size.";
}
if (recv_rank_ids_.size() != output_size_list_.size()) {
MS_LOG(EXCEPTION) << "Trying to use AllToAllv, but the size of recv_rank_ids does not equal the output list size.";
}
auto nccl_recv_func = reinterpret_cast<Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv"));
auto nccl_send_func = reinterpret_cast<Send>(dlsym(const_cast<void *>(collective_handle_), "Send"));
auto nccl_gstart_func = reinterpret_cast<GroupStart>(dlsym(const_cast<void *>(collective_handle_), "GroupStart"));
auto nccl_gend_func = reinterpret_cast<GroupEnd>(dlsym(const_cast<void *>(collective_handle_), "GroupEnd"));
MS_EXCEPTION_IF_NULL(nccl_recv_func);
MS_EXCEPTION_IF_NULL(nccl_send_func);
MS_EXCEPTION_IF_NULL(nccl_gstart_func);
MS_EXCEPTION_IF_NULL(nccl_gend_func);
// This implementation follows the grouped point-to-point calls described in the NVIDIA NCCL 2.11 documentation.
CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "AllToAllv: ncclGroupStart failed");
for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
input_addr = GetDeviceAddress<T>(inputs, i);
CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
(*nccl_send_func)(input_addr, input_size_list_[i] / sizeof(T), input_nccl_data_type_,
send_rank_ids_[i], stream, group_name_),
"AllToAllv: ncclSend failed");
}
for (int i = 0; i < SizeToInt(output_size_list_.size()); ++i) {
output_addr = GetDeviceAddress<I>(outputs, i);
CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_,
(*nccl_recv_func)(output_addr, output_size_list_[i] / sizeof(I),
output_nccl_data_type_, recv_rank_ids_[i], stream, group_name_),
"AllToAllv: ncclRecv failed");
}
CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "AllToAllv: ncclGroupEnd failed");
}
void InferCommType(const CNodePtr &kernel_node) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kNcclTypeMap.find(kernel_name);
if (iter == kNcclTypeMap.end()) {
MS_LOG(EXCEPTION) << "Kernel " << kernel_name << " is not supported.";
} else {
nccl_kernel_type_ = iter->second;
}
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
auto root_rank = prim->GetAttr(kAttrRootRank);
if (root_rank) {
root_ = static_cast<int>(GetValue<int64_t>(root_rank));
}
return;
}
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
NcclKernelType nccl_kernel_type_;
size_t input_size_;
size_t output_size_;
int root_;
bool is_null_input_;
const void *collective_handle_;
cudaStream_t comm_stream_;
ncclDataType_t output_nccl_data_type_;
ncclDataType_t input_nccl_data_type_;
std::vector<int64_t> send_rank_ids_;
std::vector<int64_t> recv_rank_ids_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NCCL_P2P_GPU_KERNEL_H_
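The heart of the kernel is LaunchAllToAllv: every ncclSend/ncclRecv is posted between ncclGroupStart and ncclGroupEnd so NCCL can schedule the whole exchange without deadlocking on send/recv ordering, the peers come from the send_rank_ids/recv_rank_ids attributes rather than the loop index, and the T/I template pair lets the received buffers use a different dtype than the sent ones. A single-process NumPy sketch of that exchange semantics (helper names are hypothetical, not part of the kernel):

import numpy as np

def simulate_group_exchange(inputs, send_ids, recv_ids, out_dtype):
    """Deliver every posted send to the matching posted receive, in order."""
    world = len(inputs)
    # mailbox[(src, dst)] collects the buffers sent from src to dst.
    mailbox = {}
    for src in range(world):
        for j, dst in enumerate(send_ids[src]):
            mailbox.setdefault((src, dst), []).append(inputs[src][j])
    outputs = []
    for dst in range(world):
        # The T -> I template pair is modeled as a dtype cast on receive.
        outputs.append([mailbox[(src, dst)].pop(0).astype(out_dtype)
                        for src in recv_ids[dst]])
    return outputs

world = 8
inputs = [[np.ones(8, np.float32) * r] for r in range(world)]
send_ids = [[(r - 1) % world] for r in range(world)]   # send to the "left" neighbor
recv_ids = [[(r + 1) % world] for r in range(world)]   # receive from the "right" neighbor
outputs = simulate_group_exchange(inputs, send_ids, recv_ids, np.float16)
# Matches the expectation in the new system test: rank r ends up with rank (r+1)%8's data.
assert outputs[0][0][0] == 1.0 and outputs[7][0][0] == 0.0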

mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc

@@ -658,7 +658,8 @@ OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
Attr recv_type = {RECV_TYPE, dtype};
-OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type};
+Attr group = {GROUP, MakeValue(g_device_manager->world_group())};
+OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type, group};
return attrs;
}
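The attribute order built here (send ranks, recv ranks, recv shapes, send shapes, recv type, group) mirrors the operator's Python constructor, and the fix simply supplies the world group explicitly instead of leaving it unset. A sketch of the corresponding Python-side call, with placeholder rank ids and shapes, assuming communication has been initialized as in the system test below:

import mindspore as ms
from mindspore.ops import operations as P

neighbor_exchange = P.comm_ops.NeighborExchange(
    send_rank_ids=[1],            # placeholder peer ids
    recv_rank_ids=[1],
    recv_shapes=([8],),
    send_shapes=([8],),
    recv_type=ms.float32,
    group="nccl_world_group")     # the attr this change appends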

tests/st/nccl: mpirun launcher for the NCCL operator tests

@@ -67,4 +67,16 @@ def test_nccl_send_recv_op():
def test_nccl_all_to_all_op():
    return_code = os.system("mpirun -n 8 pytest -s test_nccl_all_to_all_op.py")
    assert return_code == 0
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_single
+def test_nccl_neighbor_exchange_op():
+    """
+    Feature: NeighborExchange GPU operator
+    Description: see details in test_nccl_neighbor_exchange_op.py
+    Expectation: success, return_code == 0
+    """
+    return_code = os.system(
+        "mpirun -n 8 pytest -s test_nccl_neighbor_exchange_op.py")
+    assert return_code == 0

tests/st/nccl/test_nccl_neighbor_exchange_op.py

@@ -0,0 +1,65 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.ops import operations as P
import mindspore as ms
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
init()
rank = get_rank()
size = get_group_size()
x = np.asarray([1, 1, 1, 1, 1, 1, 1, 1]).astype(np.float32) * rank
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.neighborexchange = P.comm_ops.NeighborExchange(
            send_rank_ids=[(rank - 1) % 8],
            recv_rank_ids=[(rank + 1) % 8],
            recv_shapes=tuple([[8]]),
            send_shapes=tuple([[8]]),
            recv_type=ms.float32,
            group="nccl_world_group")

    def construct(self, inputs):
        return self.neighborexchange(inputs)


def test_neighborexchange():
    """
    Feature: NeighborExchange operator on GPU
    Description: each device sends its tensor to the previous rank and receives from the next rank,
    e.g. rank 0 sends to rank 7 and receives from rank 1.
    Expectation: on rank i, result == [1, 1, 1, 1, 1, 1, 1, 1] * ((i + 1) % 8)
    """
    neighborexchange = Net()
    expect0 = np.asarray([1, 1, 1, 1, 1, 1, 1, 1]).astype(
        np.float32) * ((rank + 1) % 8)
    inputs = []
    inputs.append(Tensor(x))
    inputs = tuple(inputs)
    output0 = neighborexchange(inputs)[0].asnumpy()
    diff0 = np.abs(output0 - expect0)
    error0 = np.ones(shape=expect0.shape) * 1.0e-5
    assert np.all(diff0 < error0)
    assert output0.shape == expect0.shape
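The expectation follows directly from the ring wiring: since every rank sends to (rank - 1) % 8 and receives from (rank + 1) % 8, rank i ends up holding rank (i + 1) % 8's buffer. A quick single-process check of that arithmetic, runnable without GPUs or mpirun:

import numpy as np

world = 8
data = {r: np.ones(8, np.float32) * r for r in range(world)}
received = {}
for r in range(world):
    received[(r - 1) % world] = data[r]   # rank r sends its buffer to rank r-1
for i in range(world):
    np.testing.assert_array_equal(
        received[i], np.ones(8, np.float32) * ((i + 1) % world))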