!40893 add muxsend/receive prim

Merge pull request !40893 from VectorSL/add-prim-muxsend/recev
This commit is contained in:
i-robot 2022-08-28 11:23:24 +00:00 committed by Gitee
commit 0b71a52aea
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 124 additions and 3 deletions

View File

@ -84,6 +84,8 @@ constexpr auto kHostAllGatherOpName = "HostAllGather";
constexpr auto kBroadcastOpName = "Broadcast";
constexpr auto kReceiveOpName = "Receive";
constexpr auto kHcomSendOpName = "Send";
constexpr auto kMuxReceiveOpName = "MuxReceive";
constexpr auto kMuxSendOpName = "MuxSend";
constexpr auto kReduceScatterOpName = "ReduceScatter";
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
constexpr auto kMemCpyAsyncOpName = "memcpy_async";

View File

@ -901,7 +901,7 @@ bool AnfAlgo::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &t
bool AnfAlgo::IsCommunicationOp(const AnfNodePtr &node) {
static const std::set<std::string> kCommunicationOpNames = {kAllReduceOpName, kAllGatherOpName, kBroadcastOpName,
kReduceScatterOpName, kHcomSendOpName, kReceiveOpName,
kAllToAllVOpName};
kAllToAllVOpName, kMuxReceiveOpName, kMuxSendOpName};
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;

View File

@ -23,6 +23,7 @@ from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
from ..._checkparam import Rel
from ...communication.management import GlobalComm
class EnvCreate(PrimitiveWithInfer):
@ -1047,3 +1048,121 @@ class TensorsQueueClear(PrimitiveWithInfer):
def infer_dtype(self, handle_type):
validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
return mstype.int64
class MuxSend(PrimitiveWithInfer):
r"""
Send tensors to the specified dest_rank.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Note:
Send and Receive must be used in combination.
Send must be used between servers.
Args:
dest_rank (int): A required integer identifying the destination rank.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.depend = ops.Depend()
>>> self.send = ops.Send(dest_rank=8, group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.dest_rank = dest_rank
self.group = group
self.add_prim_attr("fusion", 1)
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, x_shape):
self.add_prim_attr("shape", x_shape)
return []
def infer_dtype(self, x_dtype):
return x_dtype[0]
class MuxReceive(PrimitiveWithInfer):
r"""
receive tensors from src_rank.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Note:
Send and Receive must be used in combination.
Receive must be used between servers.
Args:
shape (list[int]): A required list identifying the shape of the tensor to be received.
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
int8, int16, int32, float16, float32.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.recv = ops.Receive(shape=[2, 8], dtype=np.float32, group="hccl_world_group")
>>>
>>> def construct(self):
>>> out = self.recv()
>>> return out
>>>
>>> net = Net()
>>> output = net()
"""
@prim_attr_register
def __init__(self, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.shape = shape
self.dtype = dtype
self.group = group
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
args = {"dtype": dtype}
self.add_prim_attr('side_effect_mem', True)
self.add_prim_attr("fusion", 1)
validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
def infer_shape(self, x_shape=None):
return tuple(self.get_attr_dict()['shape'])
def infer_dtype(self, x_dtype=None):
out_type = []
for _ in self.shape:
out_type.append(self.dtype)
return tuple(out_type)