forked from mindspore-Ecosystem/mindspore
Add AlltoAll GPU operator and IR pass and tests
This commit is contained in:
parent
48998444a0
commit
a8088525d5
|
@ -57,5 +57,16 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(Broadcast,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
NcclCollectiveGpuKernel, int)
|
||||
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
NcclCollectiveGpuKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
AllToAllv, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
NcclCollectiveGpuKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(AllToAllv,
|
||||
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
NcclCollectiveGpuKernel, int)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,14 +31,14 @@ enum NcclKernelType {
|
|||
NCCL_ALL_GATHER,
|
||||
NCCL_REDUCE_SCATTER,
|
||||
NCCL_BROADCAST,
|
||||
NCCL_ALLTOALLV,
|
||||
NCCL_INVALID_TYPE = 255
|
||||
};
|
||||
const std::map<std::string, NcclKernelType> kNcclTypeMap = {
|
||||
{"AllReduce", NCCL_ALL_REDUCE},
|
||||
{"AllGather", NCCL_ALL_GATHER},
|
||||
{"ReduceScatter", NCCL_REDUCE_SCATTER},
|
||||
{"Broadcast", NCCL_BROADCAST},
|
||||
};
|
||||
const std::map<std::string, NcclKernelType> kNcclTypeMap = {{"AllReduce", NCCL_ALL_REDUCE},
|
||||
{"AllGather", NCCL_ALL_GATHER},
|
||||
{"ReduceScatter", NCCL_REDUCE_SCATTER},
|
||||
{"Broadcast", NCCL_BROADCAST},
|
||||
{"AllToAllv", NCCL_ALLTOALLV}};
|
||||
|
||||
template <typename T>
|
||||
class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
||||
|
@ -69,6 +69,10 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
LaunchBroadcast(inputs, outputs, stream_ptr);
|
||||
break;
|
||||
}
|
||||
case NCCL_ALLTOALLV: {
|
||||
LaunchAllToAllv(inputs, outputs, stream_ptr);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
MS_LOG(EXCEPTION) << "Kernel type " << nccl_kernel_type_ << " is not supported.";
|
||||
}
|
||||
|
@ -177,8 +181,8 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
|
||||
void LaunchBroadcast(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
|
||||
void *stream_ptr) {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, 0);
|
||||
T *input_addr;
|
||||
T *output_addr;
|
||||
cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
auto broadcast_funcptr = reinterpret_cast<Broadcast>(dlsym(const_cast<void *>(collective_handle_), "Broadcast"));
|
||||
MS_EXCEPTION_IF_NULL(broadcast_funcptr);
|
||||
|
@ -192,6 +196,37 @@ class NcclCollectiveGpuKernel : public NcclGpuKernel {
|
|||
}
|
||||
}
|
||||
|
||||
void LaunchAllToAllv(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
|
||||
void *stream_ptr) {
|
||||
T *input_addr;
|
||||
T *output_addr;
|
||||
cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
|
||||
auto nccl_recv_func = reinterpret_cast<Recv>(dlsym(const_cast<void *>(collective_handle_), "Recv"));
|
||||
auto nccl_send_func = reinterpret_cast<Send>(dlsym(const_cast<void *>(collective_handle_), "Send"));
|
||||
auto nccl_gstart_func = reinterpret_cast<GroupStart>(dlsym(const_cast<void *>(collective_handle_), "GroupStart"));
|
||||
auto nccl_gend_func = reinterpret_cast<GroupEnd>(dlsym(const_cast<void *>(collective_handle_), "GroupEnd"));
|
||||
MS_EXCEPTION_IF_NULL(nccl_recv_func);
|
||||
MS_EXCEPTION_IF_NULL(nccl_send_func);
|
||||
MS_EXCEPTION_IF_NULL(nccl_gstart_func);
|
||||
MS_EXCEPTION_IF_NULL(nccl_gend_func);
|
||||
|
||||
// This implementation refers to NVIDIA NCCL 2.11 doc.
|
||||
CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gstart_func)(), "AllToAllv: ncclGroupStart failed");
|
||||
for (int i = 0; i < SizeToInt(input_size_list_.size()); ++i) {
|
||||
input_addr = GetDeviceAddress<T>(inputs, i);
|
||||
output_addr = GetDeviceAddress<T>(outputs, i);
|
||||
CHECK_NCCL_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
(*nccl_send_func)(input_addr, input_size_list_[i] / sizeof(T), nccl_data_type_, i, stream, group_name_),
|
||||
"AllToAllv: ncclSend failed");
|
||||
CHECK_NCCL_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
(*nccl_recv_func)(output_addr, output_size_list_[i] / sizeof(T), nccl_data_type_, i, stream, group_name_),
|
||||
"AllToAllv: ncclRecv failed");
|
||||
}
|
||||
CHECK_NCCL_RET_WITH_EXCEPT(kernel_node_, (*nccl_gend_func)(), "AllToAllv: ncclGroupEnd failed");
|
||||
}
|
||||
|
||||
void InferCommType(const CNodePtr &kernel_node) {
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto iter = kNcclTypeMap.find(kernel_name);
|
||||
|
|
|
@ -0,0 +1,186 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/optimizer/gpu/alltoall_fusion.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
#include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kCNodePrimitiveIdx = 0;
|
||||
constexpr size_t kAllToAllInputIdx = 1;
|
||||
|
||||
typedef std::vector<int> (*GetGroupRanks)(const std::string &);
|
||||
|
||||
inline int64_t NormalizeDim(const std::vector<size_t> &shape, int64_t dim) {
|
||||
return dim < 0 ? SizeToLong(shape.size()) + dim : dim;
|
||||
}
|
||||
|
||||
uint32_t GetRankSize(const std::string &group) {
|
||||
uint32_t rank_size;
|
||||
const void *collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
|
||||
MS_EXCEPTION_IF_NULL(collective_handle_);
|
||||
|
||||
// Get group size
|
||||
auto get_group_size_funcptr =
|
||||
reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks"));
|
||||
MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
|
||||
std::vector<int> group_ranks = (*get_group_size_funcptr)(group);
|
||||
rank_size = group_ranks.size();
|
||||
return rank_size;
|
||||
}
|
||||
|
||||
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(all_to_all);
|
||||
int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
|
||||
int64_t split_dim = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitDim);
|
||||
if (all_to_all->size() <= kAllToAllInputIdx) {
|
||||
MS_LOG(EXCEPTION) << "Invalid cnode " << all_to_all->DebugString() << " input size " << all_to_all->size();
|
||||
}
|
||||
|
||||
// Make a split CNode.
|
||||
auto all_to_all_input = all_to_all->input(kAllToAllInputIdx);
|
||||
std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
|
||||
all_to_all_input};
|
||||
auto split = graph->NewCNode(split_input);
|
||||
MS_EXCEPTION_IF_NULL(split);
|
||||
|
||||
// Judge validity of split_dim and shape
|
||||
auto dtype = AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
|
||||
auto shape = AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
|
||||
split_dim = NormalizeDim(shape, split_dim);
|
||||
if (SizeToLong(shape.size()) <= split_dim) {
|
||||
MS_LOG(EXCEPTION) << "Invalid split dim " << split_dim << " is over the shape size " << shape.size();
|
||||
}
|
||||
if (split_count == 0 || shape[LongToSize(split_dim)] % split_count != 0) {
|
||||
MS_LOG(EXCEPTION) << "Invalid split count " << split_count << " cannot be divisible by shape[" << split_dim
|
||||
<< "] = " << shape[LongToSize(split_dim)];
|
||||
}
|
||||
shape[LongToSize(split_dim)] /= split_count;
|
||||
|
||||
// Set Split CNode outputs type and shape, and CNode attributes.
|
||||
std::vector<TypeId> dtypes(split_count, dtype);
|
||||
std::vector<std::vector<size_t>> shapes(split_count, shape);
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(split_dim), split);
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputNum, MakeValue<int64_t>(split_count), split);
|
||||
return split;
|
||||
}
|
||||
|
||||
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(all_to_all);
|
||||
MS_EXCEPTION_IF_NULL(split);
|
||||
int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
|
||||
std::string group = AnfAlgo::GetNodeAttr<std::string>(all_to_all, kAttrGroup);
|
||||
std::vector<AnfNodePtr> split_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(graph, split, split_count, &split_outputs);
|
||||
if (split_outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << split->DebugString() << " should have at least one output, but got 0.";
|
||||
}
|
||||
|
||||
// Make a all_to_all_v CNode.
|
||||
std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
|
||||
all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
|
||||
auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
|
||||
MS_EXCEPTION_IF_NULL(all_to_all_v);
|
||||
|
||||
// Prepare dtypes, shapes and ranks vectors.
|
||||
auto single_shape = AnfAlgo::GetOutputInferShape(split_outputs[0], 0);
|
||||
auto single_type = AnfAlgo::GetOutputInferDataType(split_outputs[0], 0);
|
||||
std::vector<TypeId> dtypes(split_count, single_type);
|
||||
std::vector<std::vector<size_t>> shapes(split_count, single_shape);
|
||||
AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
|
||||
uint32_t rank_size = GetRankSize(group);
|
||||
std::vector<int64_t> rank_ids(rank_size, 0);
|
||||
for (uint32_t i = 0; i < rank_size; ++i) {
|
||||
rank_ids[i] = static_cast<int64_t>(i);
|
||||
}
|
||||
|
||||
// Set AllToAllv CNode outputs and attributes.
|
||||
AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(rank_ids), all_to_all_v);
|
||||
AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(rank_ids), all_to_all_v);
|
||||
AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
|
||||
MS_LOG(INFO) << "Create AllToAllv success, split count " << split_count << ", rank size " << rank_size;
|
||||
return all_to_all_v;
|
||||
}
|
||||
|
||||
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(all_to_all);
|
||||
MS_EXCEPTION_IF_NULL(all_to_all_v);
|
||||
int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
|
||||
int64_t concat_dim = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrConcatDim);
|
||||
std::vector<AnfNodePtr> all_to_all_v_outputs;
|
||||
CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, split_count, &all_to_all_v_outputs);
|
||||
if (all_to_all_v_outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
|
||||
}
|
||||
|
||||
// Make a Concat CNode.
|
||||
std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
|
||||
concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end());
|
||||
auto concat = graph->NewCNode(concat_input);
|
||||
MS_EXCEPTION_IF_NULL(concat);
|
||||
|
||||
// Judge validity of concat_dim.
|
||||
auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
|
||||
concat_dim = NormalizeDim(single_shape, concat_dim);
|
||||
if (LongToSize(concat_dim) >= single_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid concat dim " << concat_dim << " is greater than shape size " << single_shape.size();
|
||||
}
|
||||
|
||||
// Set Concat CNode outputs and attributes.
|
||||
single_shape[LongToSize(concat_dim)] *= split_count;
|
||||
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[0], 0)}, {single_shape},
|
||||
concat.get());
|
||||
AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(concat_dim), concat);
|
||||
AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(split_count), concat);
|
||||
std::vector<int64_t> dyn_input_size{split_count};
|
||||
AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
|
||||
return concat;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef AllToAllFusion::DefinePattern() const {
|
||||
return VectorRef({prim::kPrimAllToAll, std::make_shared<SeqVar>()});
|
||||
}
|
||||
|
||||
const AnfNodePtr AllToAllFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto all_to_all = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(all_to_all);
|
||||
|
||||
// Step1: Split the AllToAll input Tensor into n_ranks parts along the AllToAll split_dim.
|
||||
auto split = CreateSplitNode(graph, all_to_all);
|
||||
// Step2: AllToAllv send and recv data to and from different rank.
|
||||
auto all_to_all_v = CreateAllToAllvNode(graph, all_to_all, split);
|
||||
// Step3: Concat all parts into one Tensor.
|
||||
auto concat = CreateConcatNode(graph, all_to_all, all_to_all_v);
|
||||
return concat;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ALLTOALL_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ALLTOALL_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
||||
class AllToAllFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit AllToAllFusion(bool multigraph = true) : PatternProcessPass("alltoall_fusion", multigraph) {}
|
||||
~AllToAllFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ALLTOALL_FUSION_H_
|
|
@ -23,6 +23,7 @@
|
|||
#include "backend/optimizer/common/common_backend_optimization.h"
|
||||
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
|
||||
#include "backend/optimizer/gpu/adam_fusion.h"
|
||||
#include "backend/optimizer/gpu/alltoall_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
|
||||
|
@ -160,6 +161,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
|
|||
pm->AddPass(std::make_shared<opt::MatMulBiasAddFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AllToAllFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
|
||||
|
|
|
@ -272,6 +272,7 @@ void GPUDeviceContext::FuseOperators(const KernelGraphPtr &graph) const {
|
|||
pm->AddPass(std::make_shared<opt::MatMulBiasAddFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AdamFusion>());
|
||||
pm->AddPass(std::make_shared<opt::AllToAllFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "backend/optimizer/common/common_backend_optimization.h"
|
||||
#include "backend/optimizer/gpu/adam_weight_decay_fusion.h"
|
||||
#include "backend/optimizer/gpu/adam_fusion.h"
|
||||
#include "backend/optimizer/gpu/alltoall_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_scale_fusion.h"
|
||||
#include "backend/optimizer/gpu/apply_momentum_weight_fusion.h"
|
||||
|
|
|
@ -60,3 +60,11 @@ def test_nccl_broadcast_op():
|
|||
def test_nccl_send_recv_op():
|
||||
return_code = os.system("mpirun -n 8 pytest -s test_nccl_send_recv_op.py")
|
||||
assert return_code == 0
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_single
|
||||
def test_nccl_all_to_all_op():
|
||||
return_code = os.system("mpirun -n 8 pytest -s test_nccl_all_to_all_op.py")
|
||||
assert return_code == 0
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
init()
|
||||
rank = get_rank()
|
||||
size = get_group_size()
|
||||
|
||||
x = np.asarray([1, 1, 1, 1, 1, 1, 1, 1]).astype(np.float32) * (rank + 1)
|
||||
x1 = np.asarray([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.alltoall = P.comm_ops.AlltoAll(split_count=8, split_dim=0, concat_dim=0)
|
||||
|
||||
def construct(self, inputs):
|
||||
return self.alltoall(inputs)
|
||||
|
||||
|
||||
def test_AlltoAll():
|
||||
alltoall = Net()
|
||||
expect0 = np.asarray([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
|
||||
output0 = alltoall(Tensor(x)).asnumpy()
|
||||
diff0 = output0 - expect0
|
||||
error0 = np.ones(shape=expect0.shape) * 1.0e-5
|
||||
assert np.all(diff0 < error0)
|
||||
assert output0.shape == expect0.shape
|
||||
|
||||
expect1 = np.asarray([1, 1, 1, 1, 1, 1, 1, 1]).astype(np.float32) * (rank + 1)
|
||||
output1 = alltoall(Tensor(x1)).asnumpy()
|
||||
diff1 = output1 - expect1
|
||||
error1 = np.ones(shape=expect1.shape) * 1.0e-5
|
||||
assert np.all(diff1 < error1)
|
||||
assert output1.shape == expect1.shape
|
Loading…
Reference in New Issue