diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.cc
index 4726cfbf782..4a8a4be6b4d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.cc
@@ -57,5 +57,16 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(Broadcast,
                       KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       NcclCollectiveGpuKernel, int)
+
+MS_REG_GPU_KERNEL_ONE(
+  AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  NcclCollectiveGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  NcclCollectiveGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllToAllv,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclCollectiveGpuKernel, int)
+
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h
index 5d0294d25d2..bdee1ef4ec5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nccl/nccl_collective_gpu_kernel.h
@@ -31,14 +31,14 @@ enum NcclKernelType {
   NCCL_ALL_GATHER,
   NCCL_REDUCE_SCATTER,
   NCCL_BROADCAST,
+  NCCL_ALLTOALLV,
   NCCL_INVALID_TYPE = 255
 };
-const std::map<std::string, NcclKernelType> kNcclTypeMap = {
-  {"AllReduce", NCCL_ALL_REDUCE},
-  {"AllGather", NCCL_ALL_GATHER},
-  {"ReduceScatter", NCCL_REDUCE_SCATTER},
-  {"Broadcast", NCCL_BROADCAST},
-};
+const std::map<std::string, NcclKernelType> kNcclTypeMap = {{"AllReduce", NCCL_ALL_REDUCE},
+                                                            {"AllGather", NCCL_ALL_GATHER},
+                                                            {"ReduceScatter", NCCL_REDUCE_SCATTER},
+                                                            {"Broadcast", NCCL_BROADCAST},
+                                                            {"AllToAllv", NCCL_ALLTOALLV}};
 
 template <typename T>
 class NcclCollectiveGpuKernel : public NcclGpuKernel {
@@ -72,6 +72,10 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
         LaunchBroadcast(inputs, outputs, stream_ptr);
         break;
       }
+      case NCCL_ALLTOALLV: {
+        LaunchAllToAllv(inputs, outputs, stream_ptr);
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported.";
       }
@@ -195,8 +199,8 @@
 
   void LaunchBroadcast(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
                        void *stream_ptr) {
-    T *input_addr = GetDeviceAddress<T>(inputs, 0);
-    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    T *input_addr;
+    T *output_addr;
     cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
     auto broadcast_funcptr = reinterpret_cast<Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast"));
     MS_EXCEPTION_IF_NULL(broadcast_funcptr);
@@ -210,6 +214,37 @@
     }
   }
+
+  void LaunchAllToAllv(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
+                       void *stream_ptr) {
+    T *input_addr;
+    T *output_addr;
+    cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
+    auto nccl_recv_func = reinterpret_cast<Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv"));
+    auto nccl_send_func = reinterpret_cast<Send>(dlsym(const_cast<void *>(collective_handle_), "Send"));
+    auto nccl_gstart_func = reinterpret_cast<GroupStart>(dlsym(const_cast<void *>(collective_handle_), "GroupStart"));
+    auto nccl_gend_func = reinterpret_cast<GroupEnd>(dlsym(const_cast<void *>(collective_handle_), "GroupEnd"));
+    MS_EXCEPTION_IF_NULL(nccl_recv_func);
+    MS_EXCEPTION_IF_NULL(nccl_send_func);
+    MS_EXCEPTION_IF_NULL(nccl_gstart_func);
+    MS_EXCEPTION_IF_NULL(nccl_gend_func);
+
+    // This implementation follows the grouped point-to-point pattern in the NVIDIA NCCL 2.11 documentation.
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "AllToAllv: ncclGroupStart failed");
+    for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
+      input_addr = GetDeviceAddress<T>(inputs, i);
+      output_addr = GetDeviceAddress<T>(outputs, i);
+      CHECK_NCCL_RET_WITH_EXCEPT(
+        kernel_node_,
+        (*nccl_send_func)(input_addr, input_size_list_[i] / sizeof(T), nccl_data_type_, i, stream, group_name_),
+        "AllToAllv: ncclSend failed");
+      CHECK_NCCL_RET_WITH_EXCEPT(
+        kernel_node_,
+        (*nccl_recv_func)(output_addr, output_size_list_[i] / sizeof(T), nccl_data_type_, i, stream, group_name_),
+        "AllToAllv: ncclRecv failed");
+    }
+    CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "AllToAllv: ncclGroupEnd failed");
+  }
 
   void InferCommType(const CNodePtr &kernel_node) {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
     auto iter = kNcclTypeMap.find(kernel_name);
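
Note: LaunchAllToAllv above resolves Send, Recv, GroupStart and GroupEnd through dlsym on the collective handle, but the call pattern is the standard grouped point-to-point all-to-all from the NCCL (>= 2.7) documentation. A minimal standalone sketch of that pattern, with illustrative names and error handling omitted:

    // Each rank exchanges one equally sized chunk with every peer inside a single NCCL group,
    // so NCCL can schedule the sends and receives without deadlocking on rank order.
    #include <cuda_runtime.h>
    #include <nccl.h>

    ncclResult_t AllToAllSketch(const float *sendbuf, float *recvbuf, size_t count_per_rank,
                                int nranks, ncclComm_t comm, cudaStream_t stream) {
      ncclGroupStart();
      for (int peer = 0; peer < nranks; ++peer) {
        ncclSend(sendbuf + peer * count_per_rank, count_per_rank, ncclFloat32, peer, comm, stream);
        ncclRecv(recvbuf + peer * count_per_rank, count_per_rank, ncclFloat32, peer, comm, stream);
      }
      return ncclGroupEnd();  // the aggregated transfers are issued on `stream` when the group closes
    }
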
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.cc
new file mode 100644
index 00000000000..ace626e196f
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.cc
@@ -0,0 +1,186 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/gpu/alltoall_fusion.h"
+
+#include <string>
+#include <vector>
+
+#include "backend/session/anf_runtime_algorithm.h"
+#include "ir/primitive.h"
+#include "utils/utils.h"
+#include "runtime/device/gpu/kernel_info_setter.h"
+#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+constexpr size_t kCNodePrimitiveIdx = 0;
+constexpr size_t kAllToAllInputIdx = 1;
+
+typedef std::vector<int> (*GetGroupRanks)(const std::string &);
+
+inline int64_t NormalizeDim(const std::vector<size_t> &shape, int64_t dim) {
+  return dim < 0 ? SizeToLong(shape.size()) + dim : dim;
+}
+
+uint32_t GetRankSize(const std::string &group) {
+  uint32_t rank_size;
+  const void *collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
+  MS_EXCEPTION_IF_NULL(collective_handle_);
+
+  // Get group size
+  auto get_group_size_funcptr =
+    reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks"));
+  MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
+  std::vector<int> group_ranks = (*get_group_size_funcptr)(group);
+  rank_size = group_ranks.size();
+  return rank_size;
+}
+
+CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(all_to_all);
+  int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
+  int64_t split_dim = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitDim);
+  if (all_to_all->size() <= kAllToAllInputIdx) {
+    MS_LOG(EXCEPTION) << "Invalid cnode " << all_to_all->DebugString() << " input size " << all_to_all->size();
+  }
+
+  // Make a Split CNode.
+  auto all_to_all_input = all_to_all->input(kAllToAllInputIdx);
+  std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
+                                         all_to_all_input};
+  auto split = graph->NewCNode(split_input);
+  MS_EXCEPTION_IF_NULL(split);
+
+  // Check the validity of split_dim and the input shape.
+  auto dtype = AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
+  auto shape = AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
+  split_dim = NormalizeDim(shape, split_dim);
+  if (SizeToLong(shape.size()) <= split_dim) {
+    MS_LOG(EXCEPTION) << "Invalid split dim " << split_dim << " is over the shape size " << shape.size();
+  }
+  if (split_count == 0 || shape[LongToSize(split_dim)] % split_count != 0) {
+    MS_LOG(EXCEPTION) << "Invalid split count " << split_count << ": shape[" << split_dim
+                      << "] = " << shape[LongToSize(split_dim)] << " is not divisible by it.";
+  }
+  shape[LongToSize(split_dim)] /= split_count;
+
+  // Set Split CNode outputs type and shape, and CNode attributes.
+  std::vector<TypeId> dtypes(split_count, dtype);
+  std::vector<std::vector<size_t>> shapes(split_count, shape);
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(split_dim), split);
+  AnfAlgo::SetNodeAttr(kAttrOutputNum, MakeValue(split_count), split);
+  return split;
+}
+
+CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(all_to_all);
+  MS_EXCEPTION_IF_NULL(split);
+  int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
+  std::string group = AnfAlgo::GetNodeAttr<std::string>(all_to_all, kAttrGroup);
+  std::vector<AnfNodePtr> split_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, split, split_count, &split_outputs);
+  if (split_outputs.empty()) {
+    MS_LOG(EXCEPTION) << "The node " << split->DebugString() << " should have at least one output, but got 0.";
+  }
+
+  // Make an AllToAllv CNode.
+  std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
+  all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
+  auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
+  MS_EXCEPTION_IF_NULL(all_to_all_v);
+
+  // Prepare dtypes, shapes and ranks vectors.
+  auto single_shape = AnfAlgo::GetOutputInferShape(split_outputs[0], 0);
+  auto single_type = AnfAlgo::GetOutputInferDataType(split_outputs[0], 0);
+  std::vector<TypeId> dtypes(split_count, single_type);
+  std::vector<std::vector<size_t>> shapes(split_count, single_shape);
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
+  uint32_t rank_size = GetRankSize(group);
+  std::vector<int64_t> rank_ids(rank_size, 0);
+  for (uint32_t i = 0; i < rank_size; ++i) {
+    rank_ids[i] = static_cast<int64_t>(i);
+  }
+
+  // Set AllToAllv CNode outputs and attributes.
+  AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(rank_ids), all_to_all_v);
+  AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(rank_ids), all_to_all_v);
+  AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue(group), all_to_all_v);
+  MS_LOG(INFO) << "Created AllToAllv successfully, split count " << split_count << ", rank size " << rank_size;
+  return all_to_all_v;
+}
+
+CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(all_to_all);
+  MS_EXCEPTION_IF_NULL(all_to_all_v);
+  int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
+  int64_t concat_dim = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrConcatDim);
+  std::vector<AnfNodePtr> all_to_all_v_outputs;
+  CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, split_count, &all_to_all_v_outputs);
+  if (all_to_all_v_outputs.empty()) {
+    MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
+  }
+
+  // Make a Concat CNode.
+  std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
+  concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end());
+  auto concat = graph->NewCNode(concat_input);
+  MS_EXCEPTION_IF_NULL(concat);
+
+  // Check the validity of concat_dim.
+  auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
+  concat_dim = NormalizeDim(single_shape, concat_dim);
+  if (LongToSize(concat_dim) >= single_shape.size()) {
+    MS_LOG(EXCEPTION) << "Invalid concat dim " << concat_dim << ", out of range for shape size "
+                      << single_shape.size();
+  }
+
+  // Set Concat CNode outputs and attributes.
+  single_shape[LongToSize(concat_dim)] *= split_count;
+  AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[0], 0)}, {single_shape},
+                                      concat.get());
+  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(concat_dim), concat);
+  AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(split_count), concat);
+  std::vector<int64_t> dyn_input_size{split_count};
+  AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
+  return concat;
+}
+}  // namespace
+
+const BaseRef AllToAllFusion::DefinePattern() const {
+  return VectorRef({prim::kPrimAllToAll, std::make_shared<Var>()});
+}
+
+const AnfNodePtr AllToAllFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto all_to_all = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(all_to_all);
+
+  // Step 1: Split the AllToAll input Tensor into split_count parts along split_dim.
+  auto split = CreateSplitNode(graph, all_to_all);
+  // Step 2: AllToAllv sends one part to and receives one part from every rank in the group.
+  auto all_to_all_v = CreateAllToAllvNode(graph, all_to_all, split);
+  // Step 3: Concat the received parts into one Tensor along concat_dim.
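
Note: the pass rewrites AllToAll(split_count, split_dim, concat_dim) into Split -> AllToAllv -> Concat, so the only shape bookkeeping is dividing the split dimension and multiplying the concat dimension by split_count, after normalizing negative dims. A standalone sketch of that arithmetic, independent of AnfAlgo (example values only):

    #include <cassert>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int64_t NormalizeDim(const std::vector<size_t> &shape, int64_t dim) {
      return dim < 0 ? static_cast<int64_t>(shape.size()) + dim : dim;
    }

    int main() {
      std::vector<size_t> shape = {8, 16};      // AllToAll input shape on each rank
      const size_t split_count = 8;
      int64_t split_dim = 0, concat_dim = -2;   // dims may be negative

      split_dim = NormalizeDim(shape, split_dim);
      assert(shape[split_dim] % split_count == 0);
      std::vector<size_t> chunk = shape;
      chunk[split_dim] /= split_count;          // each Split output / AllToAllv chunk: {1, 16}

      concat_dim = NormalizeDim(chunk, concat_dim);
      std::vector<size_t> out = chunk;
      out[concat_dim] *= split_count;           // Concat output: {8, 16}

      std::cout << "chunk {" << chunk[0] << ", " << chunk[1] << "}, output {" << out[0] << ", " << out[1] << "}"
                << std::endl;
      return 0;
    }
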
+  auto concat = CreateConcatNode(graph, all_to_all, all_to_all_v);
+  return concat;
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.h
new file mode 100644
index 00000000000..4e053a65e8d
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/gpu/alltoall_fusion.h
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ALLTOALL_FUSION_H_
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ALLTOALL_FUSION_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+
+class AllToAllFusion : public PatternProcessPass {
+ public:
+  explicit AllToAllFusion(bool multigraph = true) : PatternProcessPass("alltoall_fusion", multigraph) {}
+  ~AllToAllFusion() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ALLTOALL_FUSION_H_
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index b7d5ec2c780..90829cb8b9b 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -23,6 +23,7 @@
 #include "backend/optimizer/common/common_backend_optimization.h"
 #include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
 #include "backend/optimizer/gpu/adam_fusion.h"
+#include "backend/optimizer/gpu/alltoall_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
@@ -160,6 +161,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
+  pm->AddPass(std::make_shared<opt::AllToAllFusion>());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc
index 5ff0ad44c85..f02b5489177 100644
--- a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc
+++ b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc
@@ -274,6 +274,7 @@ void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
+  pm->AddPass(std::make_shared<opt::AllToAllFusion>());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
diff --git a/mindspore/ccsrc/runtime/hardware/gpu/optimizer.h b/mindspore/ccsrc/runtime/hardware/gpu/optimizer.h
index b64c25ee71d..2230338cd18 100644
--- a/mindspore/ccsrc/runtime/hardware/gpu/optimizer.h
+++ b/mindspore/ccsrc/runtime/hardware/gpu/optimizer.h
@@ -23,6 +23,7 @@
 #include "backend/optimizer/common/common_backend_optimization.h"
 #include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
 #include "backend/optimizer/gpu/adam_fusion.h"
+#include "backend/optimizer/gpu/alltoall_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
 #include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
diff --git a/tests/st/nccl/test_nccl_all.py b/tests/st/nccl/test_nccl_all.py
index 480bf257668..656ad198ed1 100644
--- a/tests/st/nccl/test_nccl_all.py
+++ b/tests/st/nccl/test_nccl_all.py
@@ -60,3 +60,11 @@ def test_nccl_broadcast_op():
 def test_nccl_send_recv_op():
     return_code = os.system("mpirun -n 8 pytest -s test_nccl_send_recv_op.py")
     assert return_code == 0
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_single
+def test_nccl_all_to_all_op():
+    return_code = os.system("mpirun -n 8 pytest -s test_nccl_all_to_all_op.py")
+    assert return_code == 0
+
\ No newline at end of file
diff --git a/tests/st/nccl/test_nccl_all_to_all_op.py b/tests/st/nccl/test_nccl_all_to_all_op.py
new file mode 100644
index 00000000000..c6507c77329
--- /dev/null
+++ b/tests/st/nccl/test_nccl_all_to_all_op.py
@@ -0,0 +1,55 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+init()
+rank = get_rank()
+size = get_group_size()
+
+x = np.asarray([1, 1, 1, 1, 1, 1, 1, 1]).astype(np.float32) * (rank + 1)
+x1 = np.asarray([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.alltoall = P.comm_ops.AlltoAll(split_count=8, split_dim=0, concat_dim=0)
+
+    def construct(self, inputs):
+        return self.alltoall(inputs)
+
+
+def test_AlltoAll():
+    alltoall = Net()
+    expect0 = np.asarray([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
+    output0 = alltoall(Tensor(x)).asnumpy()
+    diff0 = output0 - expect0
+    error0 = np.ones(shape=expect0.shape) * 1.0e-5
+    assert np.all(diff0 < error0)
+    assert output0.shape == expect0.shape
+
+    expect1 = np.asarray([1, 1, 1, 1, 1, 1, 1, 1]).astype(np.float32) * (rank + 1)
+    output1 = alltoall(Tensor(x1)).asnumpy()
+    diff1 = output1 - expect1
+    error1 = np.ones(shape=expect1.shape) * 1.0e-5
+    assert np.all(diff1 < error1)
+    assert output1.shape == expect1.shape
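
Note: the expected arrays in test_nccl_all_to_all_op.py follow directly from the AlltoAll semantics used above: each of the 8 ranks splits its input into 8 chunks along dim 0, sends chunk j to rank j, and concatenates the chunks it receives in rank order, so element s of rank r's output is element r of rank s's input. A small host-side check of expect0/expect1, independent of MindSpore and NCCL (illustrative only):

    #include <cstdio>
    #include <vector>

    int main() {
      const int nranks = 8, len = 8, chunk = len / nranks;  // chunk == 1 in the test
      for (int kase = 0; kase < 2; ++kase) {
        for (int r = 0; r < nranks; ++r) {
          std::vector<float> out(len);
          for (int s = 0; s < nranks; ++s) {
            for (int k = 0; k < chunk; ++k) {
              // chunk s of rank r's output is chunk r of rank s's input
              int src = r * chunk + k;
              // case 0: rank s's input is [1]*8 * (s + 1); case 1: every rank's input is [1..8]
              out[s * chunk + k] = (kase == 0) ? 1.0f * (s + 1) : 1.0f * (src + 1);
            }
          }
          std::printf("case %d, rank %d:", kase, r);
          for (float e : out) std::printf(" %g", e);
          std::printf("\n");  // case 0 prints 1..8 on every rank; case 1 prints (rank + 1) eight times
        }
      }
      return 0;
    }
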