From a7f8024c297ce62771a357c676afe3fbfadd517c Mon Sep 17 00:00:00 2001
From: lichenever
Date: Thu, 22 Jul 2021 17:30:14 +0800
Subject: [PATCH] add_replace_graph_for_conv2d

---
 .../parallel/graph_util/generate_graph.cc     |   8 ++
 .../parallel/graph_util/generate_graph.h      |   1 +
 .../frontend/parallel/ops_info/conv2d_info.cc | 108 +++++++++++++++++-
 .../frontend/parallel/ops_info/conv2d_info.h  |   6 +-
 .../frontend/parallel/ops_info/ops_utils.h    |   8 ++
 mindspore/core/base/core_ops.h                |   2 +-
 .../ops/{alltoallv.cc => neighborexchange.cc} |  10 +-
 .../ops/{alltoallv.h => neighborexchange.h}   |  22 ++--
 .../ops/_grad_experimental/grad_comm_ops.py   |  14 +--
 mindspore/ops/operations/_inner_ops.py        |  13 ++-
 ...alltoall_v.py => test_neighborexchange.py} |  14 +--
 11 files changed, 165 insertions(+), 41 deletions(-)
 rename mindspore/core/ops/{alltoallv.cc => neighborexchange.cc} (85%)
 rename mindspore/core/ops/{alltoallv.h => neighborexchange.h} (59%)
 rename tests/ut/python/parallel/{test_alltoall_v.py => test_neighborexchange.py} (80%)

diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc
index 184885e0d24..113227e56e3 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc
@@ -100,6 +100,14 @@ AnfNodePtr CreatInt64Imm(int64_t value) {
   return ValuePtrToAnfNodePtr(value_ptr);
 }
 
+AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple) {
+  std::vector<ValuePtr> value_list;
+  std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
+                 [](const int64_t value) { return MakeValue(value); });
+  ValueTuplePtr value_tuple_ptr = std::make_shared<ValueTuple>(value_list);
+  return ValuePtrToAnfNodePtr(value_tuple_ptr);
+}
+
 std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
   PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
   if (!prim) {
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
index 58847cae37c..55801c0af5f 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
@@ -41,6 +41,7 @@ AnfNodePtr CreatTypeInt(int64_t value);
 AnfNodePtr CreatInt64Imm(int64_t value);
 AnfNodePtr CreateInt32Tensor(int64_t value);
 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
+AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple);
 std::string HashInstanceName(const std::string &name);
 
 class GenerateGraph {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
index 9cb4600f855..8fc52daed14 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -25,6 +25,7 @@
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/strategy.h"
 #include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 #include "pipeline/jit/resource.h"
 
 namespace mindspore {
@@ -230,7 +231,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
 
   if (weight_strategy[0] > 1) {
     out_channel_shard_ = true;
-    new_out_channel_ = out_channel_ / weight_strategy[1];
+    new_out_channel_ = out_channel_ / weight_strategy[0];
   } else {
     out_channel_shard_ = false;
   }
@@ -514,7 +515,7 @@ void Conv2DInfo::InferSendRecvFlag() {
   MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the recv rank ids is " << recv_rank_ids_;
 }
 
-void Conv2DInfo::InferRecvShapes() {
+void Conv2DInfo::InferOverlapShapes() {
   if (left_need_recv_) {
     Shape left_recv_shape = input_slice_shape_;
     left_recv_shape[3] = overlap_left_size_;
@@ -535,6 +536,9 @@ void Conv2DInfo::InferStridedSliceAttrs() {
     left_strided_slice_end_ = input_slice_shape_;
     left_strided_slice_end_[3] = left_rank_overlap_right_size_;
     left_strided_slice_strides_ = {1, 1, 1, 1};
+    Shape left_send_shape = input_slice_shape_;
+    left_send_shape[3] = left_rank_overlap_right_size_;
+    send_shapes_.push_back(left_send_shape);
     MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
                  << left_strided_slice_end_;
   }
@@ -544,6 +548,9 @@
     right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
     right_strided_slice_end_ = input_slice_shape_;
     right_strided_slice_strides_ = {1, 1, 1, 1};
+    Shape right_send_shape = input_slice_shape_;
+    right_send_shape[3] = right_rank_overlap_left_size_;
+    send_shapes_.push_back(right_send_shape);
     MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
                  << right_strided_slice_end_;
   }
@@ -554,11 +561,101 @@ void Conv2DInfo::InferNewOperatorAttrs() {
 
   InferSendRecvFlag();
 
-  InferRecvShapes();
+  InferOverlapShapes();
 
   InferStridedSliceAttrs();
 }
 
+OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
+  auto type = cnode->Type();
+  MS_EXCEPTION_IF_NULL(type);
+  auto tensor_type = type->cast<mindspore::TensorTypePtr>();
+  MS_EXCEPTION_IF_NULL(tensor_type);
+  auto dtype = tensor_type->element();
+  MS_EXCEPTION_IF_NULL(dtype);
+  Attr send_ranks = {SEND_RANK_IDS, MakeValue(send_rank_ids_)};
+  Attr recv_ranks = {RECV_RANK_IDS, MakeValue(recv_rank_ids_)};
+  Attr send_shapes = {SEND_SHAPES, MakeValue(send_shapes_)};
+  Attr recv_shapes = {RECV_SHAPES, MakeValue(recv_shapes_)};
+  Attr recv_type = {RECV_TYPE, dtype};
+  OperatorAttrs attrs = {send_ranks, recv_ranks, recv_shapes, send_shapes, recv_type};
+  return attrs;
+}
+
+OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
+  Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
+  Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
+  Attr mode = {MODE, MakeValue(mode_)};
+  Attr pad_mode = {PAD_MODE, MakeValue("pad")};
+  Attr pad = {PAD, MakeValue(new_pad_list_)};
+  Attr stride = {STRIDE, MakeValue(stride_)};
+  Attr dilation = {DILATION, MakeValue(dilation_)};
+  Attr group = {GROUP, MakeValue(group_)};
+  Attr data_format = {DATA_FORMAT, MakeValue(format_)};
+  OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
+  return attrs;
+}
+
+Status Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
+  auto graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  GenerateGraph gen_g = GenerateGraph(attrs_);
+  if (gen_g.Init(cnode) != SUCCESS) {
+    MS_LOG(ERROR) << "GenerateGraph Init failed";
+    return FAILED;
+  }
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
+  std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  if (left_need_send_) {
+    auto slice_left_begin = CreatTuple(left_strided_slice_begin_);
+    auto slice_left_end = CreatTuple(left_strided_slice_end_);
+    auto slice_left_strided = CreatTuple(left_strided_slice_strides_);
+    auto slice_left = gen_g.PushBack(
+      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_left_begin, slice_left_end, slice_left_strided});
+    make_tuple_a_inputs.push_back(slice_left);
+  }
+  if (right_need_send_) {
+    auto slice_right_begin = CreatTuple(right_strided_slice_begin_);
+    auto slice_right_end = CreatTuple(right_strided_slice_end_);
+    auto slice_right_strided = CreatTuple(right_strided_slice_strides_);
+    auto slice_right = gen_g.PushBack(
+      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_right_begin, slice_right_end, slice_right_strided});
+    make_tuple_a_inputs.push_back(slice_right);
+  }
+  auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
+  auto alltoall_attrs = CreatNeighborExchangeAttrs(cnode);
+  auto alltoall_v = gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
+  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  if (left_need_recv_) {
+    std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                      CreatInt64Imm(0)};
+    auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
+    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), cnode->input(1),
+                                                   tuple_getitem_l};
+    auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
+    auto concat_l = gen_g.PushBack({gen_g.NewOpInst(CONCAT), make_tuple_l});
+    make_tuple_inputs.push_back(concat_l);
+  }
+  if (right_need_recv_) {
+    std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                      CreatInt64Imm(0)};
+    auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
+    make_tuple_inputs.push_back(tuple_getitem_r);
+  } else {
+    make_tuple_inputs.push_back(cnode->input(1));
+  }
+  auto make_tuple = graph->NewCNode(make_tuple_inputs);
+  Attr concat_axis = {AXIS, MakeValue(-1)};
+  OperatorAttrs concat_attrs = {concat_axis};
+  std::vector<AnfNodePtr> concat_inputs = {gen_g.NewOpInst(CONCAT, concat_attrs), make_tuple};
+  auto concat = graph->NewCNode(concat_inputs);
+  auto conv2d_attrs = CreatConv2DAttrs();
+  auto conv2d = gen_g.PushBack({gen_g.NewOpInst(CONV2D, conv2d_attrs), concat, cnode->input(2)});
+  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
+    std::make_pair(input_nodes, conv2d));
+  return SUCCESS;
+}
+
 ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
   if (!need_exchange_overlap_) {
     if (!out_channel_shard_) {
@@ -579,6 +676,11 @@ ReplaceGraphPtr Conv2DInfo::replace_graph(const CNodePtr &cnode) {
 
   InferNewOperatorAttrs();
 
+  if (ComputeReplaceGraph(cnode) != SUCCESS) {
+    return nullptr;
+  } else {
+    return replace_graph_;
+  }
   return nullptr;
 }
 
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
index c79aa8e3912..1ae1e4a752a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
@@ -55,9 +55,12 @@ class Conv2DInfo : public OperatorInfo {
   Status InferOverlapSize();
   void InferNewOperatorAttrs();
   void InferSendRecvFlag();
-  void InferRecvShapes();
+  void InferOverlapShapes();
   void InferStridedSliceAttrs();
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
+  OperatorAttrs CreatNeighborExchangeAttrs(const CNodePtr &cnode);
+  OperatorAttrs CreatConv2DAttrs();
+  Status ComputeReplaceGraph(const CNodePtr &cnode);
 
   int64_t out_channel_ = 1;
   std::vector<int64_t> kernel_size_;  // two integers
@@ -100,6 +103,7 @@ class Conv2DInfo : public OperatorInfo {
 
   std::vector<int64_t> send_rank_ids_;
   std::vector<int64_t> recv_rank_ids_;
+  Shapes send_shapes_;
   Shapes recv_shapes_;
 
   virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 90a255e2d7c..294bcc162ca 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -154,6 +154,11 @@ constexpr char REPLACE[] = "replace";
 constexpr char CONNSYMBOL[] = "/";
 constexpr char INSTANCE_NAME[] = "instance_name";
 constexpr char SPLIT_SENS[] = "split_sens";
+constexpr char SEND_RANK_IDS[] = "send_rank_ids";
+constexpr char RECV_RANK_IDS[] = "recv_rank_ids";
+constexpr char RECV_SHAPES[] = "recv_shapes";
+constexpr char SEND_SHAPES[] = "send_shapes";
+constexpr char RECV_TYPE[] = "recv_type";
 constexpr char SPLIT_TENSOR[] = "split_tensor";
 constexpr char DEV_MAT[] = "dev_mat";
 constexpr char TENSOR_MAP[] = "tensor_map";
@@ -193,6 +198,8 @@ constexpr char KERNEL_SIZE[] = "kernel_size";
 constexpr char MODE[] = "mode";
 constexpr char PAD_MODE[] = "pad_mode";
 constexpr char PAD_LIST[] = "pad_list";
+constexpr char PAD[] = "pad";
+constexpr char DATA_FORMAT[] = "data_format";
 constexpr char STRIDE[] = "stride";
 constexpr char DILATION[] = "dilation";
 constexpr char FORMAT[] = "format";
@@ -207,6 +214,7 @@ constexpr char VIRTUAL_DIV[] = "_VirtualDiv";
 constexpr char GET_TENSOR_SLICE[] = "_GetTensorSlice";
 constexpr char SPLIT[] = "Split";
 constexpr char ALL_TO_ALL[] = "_AlltoAll";
+constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange";
 constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis";
 constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
 constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 0471fa22c0e..5657d8f2df9 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -388,7 +388,7 @@ inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_Vir
 inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
 inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
 inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
-inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
+inline const PrimitivePtr kPrimNeighborExchange = std::make_shared<Primitive>("NeighborExchange");
 inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
diff --git a/mindspore/core/ops/alltoallv.cc b/mindspore/core/ops/neighborexchange.cc
similarity index 85%
rename from mindspore/core/ops/alltoallv.cc
rename to mindspore/core/ops/neighborexchange.cc
index 84946f4855d..ebb872f4a53 100644
--- a/mindspore/core/ops/alltoallv.cc
+++ b/mindspore/core/ops/neighborexchange.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
*/ -#include "ops/alltoallv.h" +#include "ops/neighborexchange.h" #include "ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" @@ -46,7 +46,7 @@ abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vec TypePtr InferType(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto prim_name = primitive->name(); - CheckAndConvertUtils::CheckInteger("AllToAllv infer", input_args.size(), kEqual, 1, prim_name); + CheckAndConvertUtils::CheckInteger("NeighborExchange infer", input_args.size(), kEqual, 1, prim_name); MS_EXCEPTION_IF_NULL(input_args[0]); auto recv_shapes = primitive->GetAttr(RecvShapes); MS_EXCEPTION_IF_NULL(recv_shapes); @@ -60,13 +60,13 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector(type_vec); } -AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { +AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { auto type = InferType(primitive, input_args); auto shape = InferShape(primitive, input_args); return abstract::MakeAbstract(shape, type); } -REGISTER_PRIMITIVE_EVAL_IMPL(AllToAllv, prim::kPrimAllToAllv, AllToAllvInfer, nullptr, true); +REGISTER_PRIMITIVE_EVAL_IMPL(NeighborExchange, prim::kPrimNeighborExchange, NeighborExchangeInfer, nullptr, true); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/alltoallv.h b/mindspore/core/ops/neighborexchange.h similarity index 59% rename from mindspore/core/ops/alltoallv.h rename to mindspore/core/ops/neighborexchange.h index 4391bf7fe77..58f1e53da42 100644 --- a/mindspore/core/ops/alltoallv.h +++ b/mindspore/core/ops/neighborexchange.h @@ -14,8 +14,8 @@ * limitations under the License. 
  */
 
-#ifndef MINDSPORE_CORE_OPS_ALLTOALLV_H_
-#define MINDSPORE_CORE_OPS_ALLTOALLV_H_
+#ifndef MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_
+#define MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_
 #include <vector>
 #include <memory>
 #include "ops/primitive_c.h"
@@ -24,20 +24,20 @@
 
 namespace mindspore {
 namespace ops {
-constexpr auto kNameAllToAllv = "AllToAllv";
+constexpr auto kNameNeighborExchange = "NeighborExchange";
 constexpr auto RecvShapes = "recv_shapes";
 constexpr auto RecvType = "recv_type";
 
-class AllToAllv : public PrimitiveC {
+class NeighborExchange : public PrimitiveC {
  public:
-  AllToAllv() : PrimitiveC(kNameAllToAllv) {}
-  ~AllToAllv() = default;
-  MS_DECLARE_PARENT(AllToAllv, PrimitiveC);
+  NeighborExchange() : PrimitiveC(kNameNeighborExchange) {}
+  ~NeighborExchange() = default;
+  MS_DECLARE_PARENT(NeighborExchange, PrimitiveC);
   void Init() {}
 };
-AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                               const std::vector<AbstractBasePtr> &input_args);
-using PrimAllToAllPtr = std::shared_ptr<AllToAllv>;
+AbstractBasePtr NeighborExchangeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const std::vector<AbstractBasePtr> &input_args);
+using PrimNeighborExchangePtr = std::shared_ptr<NeighborExchange>;
 }  // namespace ops
 }  // namespace mindspore
-#endif  // MINDSPORE_CORE_OPS_ALLTOALLV_H_
+#endif  // MINDSPORE_CORE_OPS_NEIGHBOREXCHANGE_H_
diff --git a/mindspore/ops/_grad_experimental/grad_comm_ops.py b/mindspore/ops/_grad_experimental/grad_comm_ops.py
index eac9f55fa33..878a2d094a1 100644
--- a/mindspore/ops/_grad_experimental/grad_comm_ops.py
+++ b/mindspore/ops/_grad_experimental/grad_comm_ops.py
@@ -15,19 +15,19 @@
 """Generate bprop for comm ops"""
 from .._grad.grad_base import bprop_getters
-from ..operations._inner_ops import AllToAllv
+from ..operations._inner_ops import NeighborExchange
 
 
-@bprop_getters.register(AllToAllv)
-def get_bprop_alltoallv(self):
-    """Generate bprop for AllToAllv."""
+@bprop_getters.register(NeighborExchange)
+def get_bprop_neighborexchange(self):
+    """Generate bprop for NeighborExchange."""
     group = self.group
     send_rank_ids = self.recv_rank_ids
     recv_rank_ids = self.send_rank_ids
-    recv_shapes = self.recv_shapes_backward
+    recv_shapes, send_shapes = self.send_shapes, self.recv_shapes
     recv_type = self.recv_type
-    alltoallv_grad = AllToAllv(send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes, recv_type, group)
+    neighborexchange_grad = NeighborExchange(send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type, group)
 
     def bprop(x, out, dout):
-        return (alltoallv_grad(dout),)
+        return (neighborexchange_grad(dout),)
 
     return bprop
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 82b76e16103..ab2a1609cdd 100755
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -492,29 +492,30 @@ class Receive(PrimitiveWithInfer):
         return self.dtype
 
 
-class AllToAllv(Primitive):
+class NeighborExchange(Primitive):
     """
-    AlltoAllv is a collective operation.
+    NeighborExchange is a collective operation.
 
-    AlltoAllv sends data from the local rank to ranks in the send_rank_ids, as while receive data from recv_rank_ids.
+    NeighborExchange sends data from the local rank to ranks in the send_rank_ids,
+    as well as receiving data from recv_rank_ids.
 
     Args:
         send_rank_ids (list): Ranks which the data is sent to.
        recv_rank_ids (list): Ranks which the data is received from.
        recv_shapes (list): Data shape which received from recv_rank_ids.
-        recv_shapes_backward (list): Data shape which received from send_rank_ids in the backward.
+        send_shapes (list): Data shape which is sent to the send_rank_ids.
         recv_type (type): Data type which received from recv_rank_ids
         group (str):
     """
 
     @prim_attr_register
-    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes_backward, recv_type,
+    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, send_shapes, recv_type,
                  group=GlobalComm.WORLD_COMM_GROUP):
         self.init_prim_io_names(inputs=['x'], outputs=['output'])
         self.send_rank_ids = send_rank_ids
         self.recv_rank_ids = recv_rank_ids
         self.recv_shapes = recv_shapes
-        self.recv_shapes_backward = recv_shapes_backward
+        self.send_shapes = send_shapes
         self.recv_type = recv_type
diff --git a/tests/ut/python/parallel/test_alltoall_v.py b/tests/ut/python/parallel/test_neighborexchange.py
similarity index 80%
rename from tests/ut/python/parallel/test_alltoall_v.py
rename to tests/ut/python/parallel/test_neighborexchange.py
index 3ecc315c3a5..787dd86704a 100644
--- a/tests/ut/python/parallel/test_alltoall_v.py
+++ b/tests/ut/python/parallel/test_neighborexchange.py
@@ -20,7 +20,7 @@ import mindspore.nn as nn
 from mindspore.common.api import _executor
 from mindspore.nn import TrainOneStepCell, Momentum
 from mindspore.ops import operations as P
-from mindspore.ops.operations._inner_ops import AllToAllv
+from mindspore.ops.operations._inner_ops import NeighborExchange
 
 
 class MatMulNet(nn.Cell):
@@ -28,8 +28,8 @@
         super(MatMulNet, self).__init__()
         self.matmul = P.MatMul()
         self.mul = P.Mul()
-        self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
-                                   recv_shapes_backward=([32, 32], [32, 16]), recv_type=ms.float32)
+        self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
+                                          send_shapes=([32, 32], [32, 16]), recv_type=ms.float32)
         self.weight1 = Parameter(weight1, "w1")
 
     def construct(self, x1, x2):
@@ -44,8 +44,8 @@ class MatMulNet2(nn.Cell):
         super(MatMulNet2, self).__init__()
         self.matmul = P.MatMul()
         self.mul = P.Mul()
-        self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
-                                   recv_shapes_backward=([32, 32],), recv_type=ms.float32)
+        self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
+                                          send_shapes=([32, 32],), recv_type=ms.float32)
         self.weight1 = Parameter(weight1, "w1")
 
     def construct(self, x1, x2):
@@ -68,13 +68,13 @@ def compile_net(net):
     _executor.compile(train_net, _x1, _x2)
 
 
-def test_AllToAllv_two_inputs():
+def test_NeighborExchange_two_inputs():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = MatMulNet(_w1)
     compile_net(net)
 
 
-def test_AllToAllv_single_input():
+def test_NeighborExchange_single_input():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     net = MatMulNet2(_w1)
     compile_net(net)