forked from mindspore-Ecosystem/mindspore
!27871 add alltoall/neighborexchange/neighborexchangev2 example
Merge pull request !27871 from zhoufeng/code_docs_alltoall_example
This commit is contained in:
commit
5ba24ca880
|
@ -638,6 +638,35 @@ class NeighborExchange(Primitive):
|
|||
send_shapes (tuple(list(int))): Data shape which send to the send_rank_ids.
|
||||
recv_type (type): Data type which received from recv_rank_ids
|
||||
group (str):
|
||||
|
||||
Example:
|
||||
>>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn
|
||||
>>> import os
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore.communication import init
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import numpy as np
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.neighborexchange = ops.NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1],
|
||||
... recv_shapes=([2, 2],), send_shapes=([3, 3],),
|
||||
... recv_type=ms.float32)
|
||||
...
|
||||
...
|
||||
... def construct(self, x):
|
||||
... out = self.neighborexchange((x,))
|
||||
...
|
||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
>>> init()
|
||||
>>> net = Net()
|
||||
>>> input_x = Tensor(np.ones([3, 3]), dtype = ms.float32)
|
||||
>>> output = net(input_x)
|
||||
>>> print(output)
|
||||
[[2. 2.], [2. 2.]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -676,6 +705,34 @@ class AlltoAll(PrimitiveWithInfer):
|
|||
|
||||
Raises:
|
||||
TypeError: If group is not a string.
|
||||
|
||||
Example:
|
||||
>>> # This example should be run with 8 devices. Refer to the tutorial > Distributed Training on mindspore.cn
|
||||
>>> import os
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore.communication import init
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import numpy as np
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.alltoall = ops.AlltoAll(split_count = 8, split_dim = -2, concat_dim = -1)
|
||||
...
|
||||
... def construct(self, x):
|
||||
... out = self.alltoall(x)
|
||||
... return out
|
||||
...
|
||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
>>> init()
|
||||
>>> net = Net()
|
||||
>>> rank_id = int(os.getenv("RANK_ID"))
|
||||
>>> input_x = Tensor(np.ones([1, 1, 8, 1]) * rank_id, dtype = ms.float32)
|
||||
>>> output = net(input_x)
|
||||
>>> print(output)
|
||||
[[[[0. 1. 2. 3. 4. 5. 6. 7.]]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -728,6 +785,37 @@ class NeighborExchangeV2(Primitive):
|
|||
[top, bottom, left, right].
|
||||
data_format (str): Data format, only support NCHW now.
|
||||
group (str): The communication group to work on. Default: "GlobalComm.WORLD_COMM_GROUP".
|
||||
|
||||
Example:
|
||||
>>> # This example should be run with 2 devices. Refer to the tutorial > Distributed Training on mindspore.cn
|
||||
>>> import os
|
||||
>>> import mindspore as ms
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import context
|
||||
>>> from mindspore.communication import init
|
||||
>>> import mindspore.nn as nn
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import numpy as np
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self):
|
||||
... super(Net, self).__init__()
|
||||
... self.neighborexchangev2 = ops.NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
... send_lens=[0, 1, 0, 0],
|
||||
... recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
|
||||
... recv_lens=[0, 1, 0, 0],
|
||||
... data_format="NCHW")
|
||||
...
|
||||
... def construct(self, x):
|
||||
... out = self.neighborexchangev2(x)
|
||||
... return out
|
||||
...
|
||||
>>> context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
>>> init()
|
||||
>>> input_x = Tensor(np.ones([1, 1, 2, 2]), dtype = ms.float32)
|
||||
>>> net = Net()
|
||||
>>> output = net(input_x)
|
||||
>>> print(output)
|
||||
[[[[1. 1.], [1. 1.], [2. 2.]]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue