forked from mindspore-Ecosystem/mindspore
!40893 add muxsend/receive prim
Merge pull request !40893 from VectorSL/add-prim-muxsend/recev
This commit is contained in:
commit
0b71a52aea
|
@ -84,6 +84,8 @@ constexpr auto kHostAllGatherOpName = "HostAllGather";
|
|||
constexpr auto kBroadcastOpName = "Broadcast";
|
||||
constexpr auto kReceiveOpName = "Receive";
|
||||
constexpr auto kHcomSendOpName = "Send";
|
||||
constexpr auto kMuxReceiveOpName = "MuxReceive";
|
||||
constexpr auto kMuxSendOpName = "MuxSend";
|
||||
constexpr auto kReduceScatterOpName = "ReduceScatter";
|
||||
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
|
||||
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
|
||||
|
|
|
@ -901,7 +901,7 @@ bool AnfAlgo::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &t
|
|||
bool AnfAlgo::IsCommunicationOp(const AnfNodePtr &node) {
|
||||
static const std::set<std::string> kCommunicationOpNames = {kAllReduceOpName, kAllGatherOpName, kBroadcastOpName,
|
||||
kReduceScatterOpName, kHcomSendOpName, kReceiveOpName,
|
||||
kAllToAllVOpName};
|
||||
kAllToAllVOpName, kMuxReceiveOpName, kMuxSendOpName};
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
|
|
|
@ -23,6 +23,7 @@ from ..._checkparam import Validator as validator
|
|||
from ...common import dtype as mstype
|
||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||
from ..._checkparam import Rel
|
||||
from ...communication.management import GlobalComm
|
||||
|
||||
|
||||
class EnvCreate(PrimitiveWithInfer):
|
||||
|
@ -1047,3 +1048,121 @@ class TensorsQueueClear(PrimitiveWithInfer):
|
|||
def infer_dtype(self, handle_type):
|
||||
validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
|
||||
return mstype.int64
|
||||
|
||||
|
||||
class MuxSend(PrimitiveWithInfer):
|
||||
r"""
|
||||
Send tensors to the specified dest_rank.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Note:
|
||||
Send and Receive must be used in combination.
|
||||
Send must be used between servers.
|
||||
|
||||
Args:
|
||||
dest_rank (int): A required integer identifying the destination rank.
|
||||
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore.communication import init
|
||||
>>> from mindspore import Tensor
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> init()
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.depend = ops.Depend()
|
||||
>>> self.send = ops.Send(dest_rank=8, group="hccl_world_group")
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> out = self.depend(x, self.send(x))
|
||||
>>> return out
|
||||
>>>
|
||||
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
|
||||
>>> net = Net()
|
||||
>>> output = net(input_)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
self.dest_rank = dest_rank
|
||||
self.group = group
|
||||
self.add_prim_attr("fusion", 1)
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
self.add_prim_attr("shape", x_shape)
|
||||
return []
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
return x_dtype[0]
|
||||
|
||||
|
||||
class MuxReceive(PrimitiveWithInfer):
|
||||
r"""
|
||||
receive tensors from src_rank.
|
||||
|
||||
.. warning::
|
||||
This is an experimental prototype that is subject to change and/or deletion.
|
||||
|
||||
Note:
|
||||
Send and Receive must be used in combination.
|
||||
Receive must be used between servers.
|
||||
|
||||
Args:
|
||||
shape (list[int]): A required list identifying the shape of the tensor to be received.
|
||||
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
|
||||
int8, int16, int32, float16, float32.
|
||||
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.ops as ops
|
||||
>>> import mindspore.nn as nn
|
||||
>>> from mindspore.communication import init
|
||||
>>> from mindspore import Tensor
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> init()
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.recv = ops.Receive(shape=[2, 8], dtype=np.float32, group="hccl_world_group")
|
||||
>>>
|
||||
>>> def construct(self):
|
||||
>>> out = self.recv()
|
||||
>>> return out
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> output = net()
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.group = group
|
||||
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
|
||||
args = {"dtype": dtype}
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.add_prim_attr("fusion", 1)
|
||||
validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
|
||||
|
||||
def infer_shape(self, x_shape=None):
|
||||
return tuple(self.get_attr_dict()['shape'])
|
||||
|
||||
def infer_dtype(self, x_dtype=None):
|
||||
out_type = []
|
||||
for _ in self.shape:
|
||||
out_type.append(self.dtype)
|
||||
return tuple(out_type)
|
||||
|
|
Loading…
Reference in New Issue