diff --git a/mindspore/python/mindspore/ops/operations/comm_ops.py b/mindspore/python/mindspore/ops/operations/comm_ops.py index d9d8978151d..355bd3cb031 100644 --- a/mindspore/python/mindspore/ops/operations/comm_ops.py +++ b/mindspore/python/mindspore/ops/operations/comm_ops.py @@ -638,6 +638,35 @@ class NeighborExchange(Primitive): send_shapes (tuple(list(int))): Data shape which send to the send_rank_ids. recv_type (type): Data type which received from recv_rank_ids group (str): + + Example: + >>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn + >>> import os + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindspore import context + >>> from mindspore.communication import init + >>> import mindspore.nn as nn + >>> import mindspore.ops as ops + >>> import numpy as np + >>> class Net(nn.Cell): + ... def __init__(self): + ... super(Net, self).__init__() + ... self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1], + ... recv_shapes=([2, 2],), send_shapes=([3, 3],), + ... recv_type=ms.float32) + ... + ... + ... def construct(self, x): + ... out = self.neighborexchange((x,)) + ... + >>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + >>> init() + >>> net = Net() + >>> input_x = Tensor(np.ones([3, 3]), dtype = ms.float32) + >>> output = net(input_x) + >>> print(output) + [[2. 2.], [2. 2.]] """ @prim_attr_register @@ -676,6 +705,34 @@ class AlltoAll(PrimitiveWithInfer): Raises: TypeError: If group is not a string. + + Example: + >>> # This example should be run with 8 devices. Refer to the tutorial > Distributed Training on mindspore.cn + >>> import os + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindspore import context + >>> from mindspore.communication import init + >>> import mindspore.nn as nn + >>> import mindspore.ops as ops + >>> import numpy as np + >>> class Net(nn.Cell): + ... def __init__(self): + ... super(Net, self).__init__() + ... self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1) + ... + ... def construct(self, x): + ... out = self.alltoall(x) + ... return out + ... + >>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + >>> init() + >>> net = Net() + >>> rank_id = int(os.getenv("RANK_ID")) + >>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32) + >>> output = net(input_x) + >>> print(output) + [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]] """ @prim_attr_register @@ -728,6 +785,37 @@ class NeighborExchangeV2(Primitive): [top, bottom, left, right]. data_format (str): Data format, only support NCHW now. group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP". + + Example: + >>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn + >>> import os + >>> import mindspore as ms + >>> from mindspore import Tensor + >>> from mindspore import context + >>> from mindspore.communication import init + >>> import mindspore.nn as nn + >>> import mindspore.ops as ops + >>> import numpy as np + >>> class Net(nn.Cell): + ... def __init__(self): + ... super(Net, self).__init__() + ... self.neighborexchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + ... send_lens=[0, 1, 0, 0], + ... recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1], + ... recv_lens=[0, 1, 0, 0], + ... data_format="NCHW") + ... + ... def construct(self, x): + ... out = self.neighborexchangev2(x) + ... return out + ... + >>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend') + >>> init() + >>> input_x = Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32) + >>> net = Net() + >>> output = net(input_x) + >>> print(output) + [[[[1. 1.], [1. 1.], [2. 2.]]]] """ @prim_attr_register