Rename the AllSwap operator to _AllSwap

huangxinjing 2021-12-13 11:24:10 +08:00
parent 9fb2f887e5
commit f5889750f4
6 changed files with 20 additions and 21 deletions
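For downstream code, the practical effect of the rename is that the operator is now exposed only under its underscore-prefixed (internal) name. A minimal before/after sketch of a hypothetical caller, assuming the mindspore.ops.operations.comm_ops module path used in the diffs below:

    # Before this commit, a hypothetical caller imported the public name:
    # from mindspore.ops.operations.comm_ops import AllSwap
    # all_swap = AllSwap()

    # After this commit the same operator carries the internal name:
    from mindspore.ops.operations.comm_ops import _AllSwap

    all_swap = _AllSwap()  # group defaults to GlobalComm.WORLD_COMM_GROUP, as in comm_ops.py below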


@@ -455,7 +455,7 @@ inline const PrimitivePtr kPrimNeighborExchangeV2 = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimNeighborExchangeV2Grad = std::make_shared<Primitive>("NeighborExchangeV2Grad");
inline const PrimitivePtr kPrimAllToAll = std::make_shared<Primitive>("AlltoAll");
inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
-inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
+inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("_AllSwap");
inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");


@@ -24,7 +24,7 @@ from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2,
Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
-ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
+ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@@ -352,10 +352,10 @@ def get_bprop_reduce_scatter(self):
return bprop
-@bprop_getters.register(AllSwap)
+@bprop_getters.register(_AllSwap)
def get_bprop_allswap(self):
"""Generate bprop for AllSwap."""
all_swap_grad = AllSwap(self.group)
"""Generate bprop for _AllSwap."""
all_swap_grad = _AllSwap(self.group)
if self.instance_name:
instance_name = "grad" + self.instance_name
all_swap_grad.set_prim_instance_name(instance_name)
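The hunk above is truncated before the inner bprop closure. For orientation, here is a sketch of what the full registration in grad_comm_ops.py looks like after the rename, relying on the module imports shown earlier in this file's diff; the closure body is an assumption based on the usual MindSpore pattern of running the swap in reverse for the gradient (send and receive sizes exchanged), not part of this diff:

    @bprop_getters.register(_AllSwap)
    def get_bprop_allswap(self):
        """Generate bprop for _AllSwap."""
        all_swap_grad = _AllSwap(self.group)
        if self.instance_name:
            instance_name = "grad" + self.instance_name
            all_swap_grad.set_prim_instance_name(instance_name)

        def bprop(x, send_size, recv_size, out, dout):
            # Assumed gradient rule: route dout back with send/recv sizes swapped,
            # so each rank recovers the gradient of the blocks it originally sent.
            dx = all_swap_grad(dout, recv_size, send_size)
            return (dx, zeros_like(send_size), zeros_like(recv_size))

        return bprop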


@@ -1,9 +1,8 @@
(binary change: a serialized .mindir bprop cache file is regenerated; the protobuf content is not human-readable, the only recoverable details being the "0.1.0 MindSpore*1.6.0" and "0.1.0 MindSpore*1.5.0" version headers and the content hashes of the two serialized bprop graphs)


@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
-from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, AllSwap, ReduceScatter, Broadcast,
+from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
@@ -361,7 +361,7 @@ __all__ = [
'UnsortedSegmentProd',
"AllGather",
"AllReduce",
"AllSwap",
"_AllSwap",
"ReduceScatter",
"Broadcast",
"ReduceOp",


@@ -570,11 +570,11 @@ class Broadcast(PrimitiveWithInfer):
return x_dtype
-class AllSwap(PrimitiveWithCheck):
+class _AllSwap(PrimitiveWithCheck):
"""
-AllSwap is a collective operation.
+_AllSwap is a collective operation.
-AllSwap sends data from the all processes to the all processes in the specified group. It has two phases:
+_AllSwap sends data from the all processes to the all processes in the specified group. It has two phases:
- The scatter phase: On each process, the operand is split into the send size of blocks along the
0-th axis, and the blocks are scattered to all processes, e.g., the ith block is send to the ith process.
@@ -600,7 +600,7 @@ class AllSwap(PrimitiveWithCheck):
@prim_attr_register
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
"""Initialize AllSwap"""
"""Initialize _AllSwap"""
validator.check_value_type('group', _get_group(group), (str,), self.name)
self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
self.add_prim_attr('group', _get_group(group))
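The docstring above describes a two-phase, all-to-all style exchange driven by explicit element counts, matching the tensor_in / send_size / recv_size I/O names registered in __init__. As a rough illustration only (the group size, shapes, and element counts below are assumptions, and a real run needs a launched multi-device job with an initialized communication backend):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.communication.management import init, GlobalComm
    from mindspore.ops.operations.comm_ops import _AllSwap

    init()  # requires a properly launched distributed job

    all_swap = _AllSwap(GlobalComm.WORLD_COMM_GROUP)

    x = Tensor(np.ones((4, 8), np.float32))  # local operand, split along axis 0
    # Hypothetical 2-rank layout: each rank exchanges 2 rows (2 * 8 = 16 elements) with each peer.
    send_size = Tensor([16, 16], ms.int64)   # elements sent to rank 0 and rank 1
    recv_size = Tensor([16, 16], ms.int64)   # elements received from rank 0 and rank 1
    out = all_swap(x, send_size, recv_size)

The send/recv size tensors in the test further down follow the same int64 element-count convention.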


@@ -26,7 +26,7 @@ from mindspore.nn import Momentum
from mindspore.nn import ReLU
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, AlltoAll, ReduceOp, ReduceScatter
-from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
+from mindspore.ops.operations.comm_ops import Broadcast, _AllSwap
from mindspore.ops.operations.array_ops import Gather
import mindspore
@@ -128,7 +128,7 @@ class AllSwapNet(nn.Cell):
def __init__(self, batch_size, input_channel, out_channel):
super(AllSwapNet, self).__init__()
self.dense = Dense(input_channel, out_channel)
-self.allswap = AllSwap()
+self.allswap = _AllSwap()
self.relu = ReLU()
part_slice = batch_size / 2
self.send_size = Tensor([0, part_slice*out_channel, part_slice*out_channel], mindspore.int64)