From f5889750f4d240f1544f52b3d66ba8dee9a357fd Mon Sep 17 00:00:00 2001
From: huangxinjing
Date: Mon, 13 Dec 2021 11:24:10 +0800
Subject: [PATCH] Rename AllSwap to _AllSwap

---
 mindspore/core/base/core_ops.h                    |  2 +-
 mindspore/ops/_grad/grad_comm_ops.py              |  8 ++++----
 mindspore/ops/bprop_mindir/Broadcast_bprop.mindir | 15 +++++++--------
 mindspore/ops/operations/__init__.py              |  4 ++--
 mindspore/ops/operations/comm_ops.py              |  8 ++++----
 tests/ut/python/communication/test_comm.py        |  4 ++--
 6 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index cd0ae743d18..290c1516ed3 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -455,7 +455,7 @@ inline const PrimitivePtr kPrimNeighborExchangeV2 = std::make_shared<Primitive>(
 inline const PrimitivePtr kPrimNeighborExchangeV2Grad = std::make_shared<Primitive>("NeighborExchangeV2Grad");
 inline const PrimitivePtr kPrimAllToAll = std::make_shared<Primitive>("AlltoAll");
 inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
-inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
+inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("_AllSwap");
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
 inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 424ae096732..3e2af65ffd7 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -24,7 +24,7 @@ from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce,
                                    NeighborExchange, AlltoAll, NeighborExchangeV2, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
-                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
                                    _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
 from .grad_base import bprop_getters
 from ..operations._inner_ops import Send, Receive
@@ -352,10 +352,10 @@ def get_bprop_reduce_scatter(self):
     return bprop
 
 
-@bprop_getters.register(AllSwap)
+@bprop_getters.register(_AllSwap)
 def get_bprop_allswap(self):
-    """Generate bprop for AllSwap."""
-    all_swap_grad = AllSwap(self.group)
+    """Generate bprop for _AllSwap."""
+    all_swap_grad = _AllSwap(self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         all_swap_grad.set_prim_instance_name(instance_name)
diff --git a/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir b/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir
index 94fca38022e..c7d72a88349 100644
--- a/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir
+++ b/mindspore/ops/bprop_mindir/Broadcast_bprop.mindir
@@ -1,9 +1,8 @@
-0.1.0 MindSpore*1.6.0:
-l
- bprop.20:doutbprop.20:[CNode]21:1bprop.20:[CNode]21:1"S-Prim-MakeTuple:Default/S-Prim-MakeTuple-op13bprop.20*
-
-bprop.20:x*
- bprop.20:out*
- bprop.20:dout2
-bprop.20:[CNode]21:1:@96c75d48466ae9dd2ae51ee64181426e1bf1c36337f7c6cf3bdd01083bfb1a6eP
\ No newline at end of file
+0.1.0 MindSpore*1.5.0:
+e
+ bprop.8:doutbprop.8:[CNode]:1bprop.8:[CNode]:1"S-Prim-MakeTuple:Default/S-Prim-MakeTuple-op13bprop.8*
+ bprop.8:x*
+ bprop.8:out*
+ bprop.8:dout2
+bprop.8:[CNode]:1:@6114cc535041d10038b41a1d2a9da09c7ffdfc140121c87b8837270a9795c3a0P
\ No newline at end of file
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 013008d3001..c246da9df59 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                         BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
                         EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
                         TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
-from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, AllSwap, ReduceScatter, Broadcast,
+from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                        _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
                        _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
@@ -361,7 +361,7 @@ __all__ = [
     'UnsortedSegmentProd',
     "AllGather",
     "AllReduce",
-    "AllSwap",
+    "_AllSwap",
     "ReduceScatter",
     "Broadcast",
     "ReduceOp",
diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py
index a6c9c44ea04..4020e7d86ab 100644
--- a/mindspore/ops/operations/comm_ops.py
+++ b/mindspore/ops/operations/comm_ops.py
@@ -570,11 +570,11 @@ class Broadcast(PrimitiveWithInfer):
         return x_dtype
 
 
-class AllSwap(PrimitiveWithCheck):
+class _AllSwap(PrimitiveWithCheck):
     """
-    AllSwap is a collective operation.
+    _AllSwap is a collective operation.
 
-    AllSwap sends data from the all processes to the all processes in the specified group. It has two phases:
+    _AllSwap sends data from all processes to all processes in the specified group. It has two phases:
 
     - The scatter phase: On each process, the operand is split into the send size of blocks along the
       0-th axis, and the blocks are scattered to all processes, e.g., the ith block is sent to the ith process.
@@ -600,7 +600,7 @@ class _AllSwap(PrimitiveWithCheck):
 
     @prim_attr_register
     def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
-        """Initialize AllSwap"""
+        """Initialize _AllSwap"""
         validator.check_value_type('group', _get_group(group), (str,), self.name)
         self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
         self.add_prim_attr('group', _get_group(group))
diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py
index f644591982a..27191e8ea2c 100644
--- a/tests/ut/python/communication/test_comm.py
+++ b/tests/ut/python/communication/test_comm.py
@@ -26,7 +26,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, AlltoAll, ReduceOp, ReduceScatter
-from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
+from mindspore.ops.operations.comm_ops import Broadcast, _AllSwap
 from mindspore.ops.operations.array_ops import Gather
 import mindspore
 
@@ -128,7 +128,7 @@ class AllSwapNet(nn.Cell):
     def __init__(self, batch_size, input_channel, out_channel):
         super(AllSwapNet, self).__init__()
         self.dense = Dense(input_channel, out_channel)
-        self.allswap = AllSwap()
+        self.allswap = _AllSwap()
         self.relu = ReLU()
         part_slice = batch_size / 2
         self.send_size = Tensor([0, part_slice*out_channel, part_slice*out_channel], mindspore.int64)
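
For context, a minimal usage sketch of the renamed operator, modeled on the
AllSwapNet test above. The SwapNet cell, the three-rank layout, and the
send/recv element counts are illustrative assumptions; only the
(tensor_in, send_size, recv_size) input convention comes from the
init_prim_io_names call in comm_ops.py, and the sketch presumes a distributed
job initialized with mindspore.communication.init().

# Usage sketch for _AllSwap (not part of the patch). Assumes a 3-rank
# distributed job already initialized via mindspore.communication.init();
# SwapNet and the size values below are hypothetical.
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import Dense, ReLU
from mindspore.ops.operations.comm_ops import _AllSwap


class SwapNet(nn.Cell):
    def __init__(self, batch_size, input_channel, out_channel):
        super(SwapNet, self).__init__()
        self.dense = Dense(input_channel, out_channel)
        self.allswap = _AllSwap()  # defaults to the world communication group
        self.relu = ReLU()
        half = batch_size // 2
        # Element counts exchanged with each rank; int64 tensors as in the test.
        self.send_size = Tensor([0, half * out_channel, half * out_channel], mindspore.int64)
        self.recv_size = Tensor([half * out_channel, half * out_channel, 0], mindspore.int64)

    def construct(self, x):
        x = self.dense(x)
        # Scatter phase: x is split along axis 0 according to send_size and
        # sent to the other ranks; gather phase: the received blocks are
        # concatenated along axis 0 to form tensor_out.
        x = self.allswap(x, self.send_size, self.recv_size)
        return self.relu(x)


net = SwapNet(batch_size=4, input_channel=8, out_channel=8)
out = net(Tensor(np.ones((4, 8)), mindspore.float32))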