Rename AllSwap to _AllSwap
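Rename the AllSwap communication primitive to _AllSwap to mark it as an internal operator. The rename touches the C++ primitive name table, the Python operator imports and the __all__ export list, the bprop registration, the regenerated bprop MindIR cache, and the distributed-operator tests.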
parent 9fb2f887e5
commit f5889750f4
@@ -455,7 +455,7 @@ inline const PrimitivePtr kPrimNeighborExchangeV2 = std::make_shared<Primitive>(
 inline const PrimitivePtr kPrimNeighborExchangeV2Grad = std::make_shared<Primitive>("NeighborExchangeV2Grad");
 inline const PrimitivePtr kPrimAllToAll = std::make_shared<Primitive>("AlltoAll");
 inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
-inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
+inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("_AllSwap");
 inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
 inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");
 inline const PrimitivePtr kPrimReduceScatter = std::make_shared<Primitive>("ReduceScatter");
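On the C++ side only the registered name string changes; the kPrimAllSwap symbol keeps its spelling, so existing C++ references still compile and now resolve to the renamed operator.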
@@ -24,7 +24,7 @@ from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, NeighborExchange, AlltoAll, NeighborExchangeV2,
                                    Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
-                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap,
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, _AllSwap,
                                    _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather)
 from .grad_base import bprop_getters
 from ..operations._inner_ops import Send, Receive
@@ -352,10 +352,10 @@ def get_bprop_reduce_scatter(self):
     return bprop


-@bprop_getters.register(AllSwap)
+@bprop_getters.register(_AllSwap)
 def get_bprop_allswap(self):
-    """Generate bprop for AllSwap."""
-    all_swap_grad = AllSwap(self.group)
+    """Generate bprop for _AllSwap."""
+    all_swap_grad = _AllSwap(self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         all_swap_grad.set_prim_instance_name(instance_name)
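The hunk above cuts off before the closure that actually computes the gradient. For readers following the rename, a minimal sketch of the complete getter after this commit is shown below; the bprop body is reconstructed from the usual swap-the-sizes pattern (the gradient of an all-swap is an all-swap with send_size and recv_size exchanged) and is illustrative rather than quoted from the file. bprop_getters, _AllSwap, and zeros_like are the names imported in the earlier hunk.

# Illustrative reconstruction, not a verbatim quote of the file.
@bprop_getters.register(_AllSwap)
def get_bprop_allswap(self):
    """Generate bprop for _AllSwap."""
    all_swap_grad = _AllSwap(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_swap_grad.set_prim_instance_name(instance_name)

    def bprop(x, send_size, recv_size, out, dout):
        # Route the incoming gradient back along the reverse pattern by
        # swapping the size arguments; the sizes themselves get zero gradients.
        dx = all_swap_grad(dout, recv_size, send_size)
        return (dx, zeros_like(send_size), zeros_like(recv_size))

    return bprop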
@@ -1,9 +1,8 @@
[binary hunk: regenerated serialized bprop MindIR cache. The header version string moves from "0.1.0 MindSpore*1.5.0" to "0.1.0 MindSpore*1.6.0" and the embedded graph and checksum identifiers change; the remaining bytes are not meaningful as text.]
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                        BatchToSpace, SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence,
                        EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
                        TensorScatterMax, TensorScatterMin, TensorScatterSub, ScatterElements, ExtractVolumePatches)
-from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, AllSwap, ReduceScatter, Broadcast,
+from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                        _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
                        _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
@@ -361,7 +361,7 @@ __all__ = [
     'UnsortedSegmentProd',
     "AllGather",
     "AllReduce",
-    "AllSwap",
+    "_AllSwap",
     "ReduceScatter",
     "Broadcast",
     "ReduceOp",
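Note that "_AllSwap" stays in __all__ after the rename: the operator remains importable from mindspore.ops.operations, while the leading underscore marks it as internal, matching neighbours such as _HostAllGather and _MiniStepAllGather in the import above.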
@@ -570,11 +570,11 @@ class Broadcast(PrimitiveWithInfer):
         return x_dtype


-class AllSwap(PrimitiveWithCheck):
+class _AllSwap(PrimitiveWithCheck):
     """
-    AllSwap is a collective operation.
+    _AllSwap is a collective operation.

-    AllSwap sends data from the all processes to the all processes in the specified group. It has two phases:
+    _AllSwap sends data from all processes to all processes in the specified group. It has two phases:

     - The scatter phase: On each process, the operand is split into the send size of blocks along the
       0-th axis, and the blocks are scattered to all processes, e.g., the ith block is sent to the ith process.
@@ -600,7 +600,7 @@ class AllSwap(PrimitiveWithCheck):

     @prim_attr_register
     def __init__(self, group=GlobalComm.WORLD_COMM_GROUP):
-        """Initialize AllSwap"""
+        """Initialize _AllSwap"""
         validator.check_value_type('group', _get_group(group), (str,), self.name)
         self.init_prim_io_names(inputs=['tensor_in', 'send_size', 'recv_size'], outputs=['tensor_out'])
         self.add_prim_attr('group', _get_group(group))
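Putting the pieces together, a hedged usage sketch after the rename is shown below. The two-rank group, tensor shape, and size values are assumptions for illustration; per the init_prim_io_names call above, the op takes tensor_in plus int64 send_size/recv_size vectors, which the test further down treats as per-rank element counts. This only runs inside a launched distributed job after mindspore.communication.init().

# Hedged sketch; group size, shapes, and counts are illustrative assumptions.
import mindspore
from mindspore import Tensor, nn
from mindspore.communication import init
from mindspore.ops.operations.comm_ops import _AllSwap

class SwapHalfNet(nn.Cell):
    """Exchange half of the flattened rows with the peer rank."""
    def __init__(self, rows, cols):
        super(SwapHalfNet, self).__init__()
        self.swap = _AllSwap()     # defaults to GlobalComm.WORLD_COMM_GROUP
        half = (rows // 2) * cols  # element count, mirroring the test below
        self.send_size = Tensor([half, half], mindspore.int64)
        self.recv_size = Tensor([half, half], mindspore.int64)

    def construct(self, x):
        return self.swap(x, self.send_size, self.recv_size)

# init()  # required in a real multi-process launch before building the cell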
@@ -26,7 +26,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, AlltoAll, ReduceOp, ReduceScatter
-from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
+from mindspore.ops.operations.comm_ops import Broadcast, _AllSwap
 from mindspore.ops.operations.array_ops import Gather
 import mindspore
@@ -128,7 +128,7 @@ class AllSwapNet(nn.Cell):
     def __init__(self, batch_size, input_channel, out_channel):
         super(AllSwapNet, self).__init__()
         self.dense = Dense(input_channel, out_channel)
-        self.allswap = AllSwap()
+        self.allswap = _AllSwap()
         self.relu = ReLU()
         part_slice = batch_size / 2
         self.send_size = Tensor([0, part_slice*out_channel, part_slice*out_channel], mindspore.int64)
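The three entries in send_size read as per-rank element counts: nothing goes to rank 0 and part_slice rows' worth of elements go to each of ranks 1 and 2. That interpretation is inferred from the test values rather than stated anywhere in the diff.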