!27871 add alltoall/neighborexchange/neighborexchangev2 example

Merge pull request !27871 from zhoufeng/code_docs_alltoall_example
i-robot 2021-12-20 01:40:56 +00:00 committed by Gitee
commit 5ba24ca880
1 changed files with 88 additions and 0 deletions

@@ -638,6 +638,35 @@ class NeighborExchange(Primitive):
send_shapes (tuple(list(int))): Shapes of the data sent to the send_rank_ids.
recv_type (type): Data type of the tensors received from the recv_rank_ids.
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
Examples:
>>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn
>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore import context
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
...                                                      recv_shapes=([2, 2],), send_shapes=([3, 3],),
...                                                      recv_type=ms.float32)
...
...     def construct(self, x):
...         out = self.neighborexchange((x,))
...         return out[0]
...
>>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
>>> init()
>>> net = Net()
>>> input_x = Tensor(np.ones([3, 3]), dtype=ms.float32)
>>> output = net(input_x)
>>> print(output)
[[2. 2.]
 [2. 2.]]
"""
@prim_attr_register
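The docstring example above only runs on real devices. As a companion, here is a minimal single-process sketch of the same exchange semantics in plain NumPy; the explicit per-rank send/recv lists and rank 1's input of 2 * ones are assumptions chosen to make the pairing well-defined and reproduce the printed output, not part of the MindSpore API:

import numpy as np

def neighbor_exchange(send_bufs, send_rank_ids, recv_rank_ids):
    # Simulated exchange: rank r's i-th received tensor is the buffer
    # that rank recv_rank_ids[r][i] addressed to rank r.
    out = []
    for r in range(len(send_bufs)):
        received = []
        for src in recv_rank_ids[r]:
            idx = send_rank_ids[src].index(r)  # which of src's sends targets r
            received.append(send_bufs[src][idx])
        out.append(tuple(received))
    return out

# assumed per-rank data: rank 0 holds ones([3, 3]), rank 1 holds 2 * ones([2, 2])
send_bufs = [[np.ones((3, 3))], [np.ones((2, 2)) * 2]]
outputs = neighbor_exchange(send_bufs, send_rank_ids=[[1], [0]], recv_rank_ids=[[1], [0]])
print(outputs[0][0])  # rank 0's received tensor: [[2. 2.] [2. 2.]]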
@@ -676,6 +705,34 @@ class AlltoAll(PrimitiveWithInfer):
Raises:
TypeError: If group is not a string.
Examples:
>>> # This example should be run with 8 devices. Refer to the tutorial > Distributed Training on mindspore.cn
>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore import context
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.alltoall = ops.AlltoAll(split_count=8, split_dim=-2, concat_dim=-1)
...
...     def construct(self, x):
...         out = self.alltoall(x)
...         return out
...
>>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
>>> init()
>>> net = Net()
>>> rank_id = int(os.getenv("RANK_ID"))
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype=ms.float32)
>>> output = net(input_x)
>>> print(output)
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
"""
@prim_attr_register
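The AlltoAll data movement in the example above can be checked without devices. A minimal single-process NumPy sketch simulating the 8 ranks follows; the helper name and the loop are illustration only, not the MindSpore implementation:

import numpy as np

def all_to_all(inputs, split_count, split_dim, concat_dim):
    # Each rank splits its tensor into split_count blocks along split_dim;
    # rank r's output is block r from every rank, concatenated along concat_dim.
    blocks = [np.split(x, split_count, axis=split_dim) for x in inputs]
    return [np.concatenate([blocks[src][r] for src in range(split_count)], axis=concat_dim)
            for r in range(split_count)]

inputs = [np.ones([1, 1, 8, 1]) * rank for rank in range(8)]  # as in the docstring
outputs = all_to_all(inputs, split_count=8, split_dim=-2, concat_dim=-1)
print(outputs[0])  # [[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]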
@@ -728,6 +785,37 @@ class NeighborExchangeV2(Primitive):
[top, bottom, left, right].
data_format (str): Data format; only NCHW is supported for now.
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
Examples:
>>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn
>>> import os
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> from mindspore import context
>>> from mindspore.communication import init
>>> import mindspore.nn as nn
>>> import mindspore.ops as ops
>>> import numpy as np
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.neighborexchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
...                                                          send_lens=[0, 1, 0, 0],
...                                                          recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
...                                                          recv_lens=[0, 1, 0, 0],
...                                                          data_format="NCHW")
...
...     def construct(self, x):
...         out = self.neighborexchangev2(x)
...         return out
...
>>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
>>> init()
>>> input_x = Tensor(np.ones([1, 1, 2, 2]), dtype=ms.float32)
>>> net = Net()
>>> output = net(input_x)
>>> print(output)
[[[[1. 1.]
   [1. 1.]
   [2. 2.]]]]
"""
@prim_attr_register
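The NeighborExchangeV2 example above performs a one-row halo exchange on the bottom edge (direction index 4 in send_rank_ids/recv_rank_ids, length 1 in send_lens/recv_lens). A minimal single-process NumPy sketch of that exchange; rank 1's input of 2 * ones is an assumption chosen to reproduce the printed output:

import numpy as np

def recv_bottom_halo(x_self, x_below, recv_len):
    # Append the top recv_len rows of the neighbor below (NCHW, H is axis 2);
    # the rows this rank sends are consumed symmetrically on the neighbor's side.
    halo = x_below[:, :, :recv_len, :]
    return np.concatenate([x_self, halo], axis=2)

rank0 = np.ones([1, 1, 2, 2])        # as in the docstring
rank1 = np.ones([1, 1, 2, 2]) * 2    # assumed neighbor data
print(recv_bottom_halo(rank0, rank1, recv_len=1))
# [[[[1. 1.]
#    [1. 1.]
#    [2. 2.]]]]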