forked from mindspore-Ecosystem/mindspore
!40893 add muxsend/receive prim
Merge pull request !40893 from VectorSL/add-prim-muxsend/recev
This commit is contained in:
commit
0b71a52aea
|
@ -84,6 +84,8 @@ constexpr auto kHostAllGatherOpName = "HostAllGather";
|
||||||
constexpr auto kBroadcastOpName = "Broadcast";
|
constexpr auto kBroadcastOpName = "Broadcast";
|
||||||
constexpr auto kReceiveOpName = "Receive";
|
constexpr auto kReceiveOpName = "Receive";
|
||||||
constexpr auto kHcomSendOpName = "Send";
|
constexpr auto kHcomSendOpName = "Send";
|
||||||
|
constexpr auto kMuxReceiveOpName = "MuxReceive";
|
||||||
|
constexpr auto kMuxSendOpName = "MuxSend";
|
||||||
constexpr auto kReduceScatterOpName = "ReduceScatter";
|
constexpr auto kReduceScatterOpName = "ReduceScatter";
|
||||||
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
|
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
|
||||||
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
|
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
|
||||||
|
|
|
@ -899,9 +899,9 @@ bool AnfAlgo::IsInplaceNode(const mindspore::AnfNodePtr &kernel, const string &t
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AnfAlgo::IsCommunicationOp(const AnfNodePtr &node) {
|
bool AnfAlgo::IsCommunicationOp(const AnfNodePtr &node) {
|
||||||
static const std::set<std::string> kCommunicationOpNames = {kAllReduceOpName, kAllGatherOpName, kBroadcastOpName,
|
static const std::set<std::string> kCommunicationOpNames = {kAllReduceOpName, kAllGatherOpName, kBroadcastOpName,
|
||||||
kReduceScatterOpName, kHcomSendOpName, kReceiveOpName,
|
kReduceScatterOpName, kHcomSendOpName, kReceiveOpName,
|
||||||
kAllToAllVOpName};
|
kAllToAllVOpName, kMuxReceiveOpName, kMuxSendOpName};
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
if (!node->isa<CNode>()) {
|
if (!node->isa<CNode>()) {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -23,6 +23,7 @@ from ..._checkparam import Validator as validator
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
|
from ...communication.management import GlobalComm
|
||||||
|
|
||||||
|
|
||||||
class EnvCreate(PrimitiveWithInfer):
|
class EnvCreate(PrimitiveWithInfer):
|
||||||
|
@ -1047,3 +1048,121 @@ class TensorsQueueClear(PrimitiveWithInfer):
|
||||||
def infer_dtype(self, handle_type):
|
def infer_dtype(self, handle_type):
|
||||||
validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
|
validator.check_type_name("handle", handle_type, (mstype.int64), self.name)
|
||||||
return mstype.int64
|
return mstype.int64
|
||||||
|
|
||||||
|
|
||||||
|
class MuxSend(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Send tensors to the specified dest_rank.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Send and Receive must be used in combination.
|
||||||
|
Send must be used between servers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dest_rank (int): A required integer identifying the destination rank.
|
||||||
|
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mindspore.ops as ops
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
|
>>> from mindspore.communication import init
|
||||||
|
>>> from mindspore import Tensor
|
||||||
|
>>> import numpy as np
|
||||||
|
>>>
|
||||||
|
>>> init()
|
||||||
|
>>> class Net(nn.Cell):
|
||||||
|
>>> def __init__(self):
|
||||||
|
>>> super(Net, self).__init__()
|
||||||
|
>>> self.depend = ops.Depend()
|
||||||
|
>>> self.send = ops.Send(dest_rank=8, group="hccl_world_group")
|
||||||
|
>>>
|
||||||
|
>>> def construct(self, x):
|
||||||
|
>>> out = self.depend(x, self.send(x))
|
||||||
|
>>> return out
|
||||||
|
>>>
|
||||||
|
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
|
||||||
|
>>> net = Net()
|
||||||
|
>>> output = net(input_)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
|
self.dest_rank = dest_rank
|
||||||
|
self.group = group
|
||||||
|
self.add_prim_attr("fusion", 1)
|
||||||
|
self.add_prim_attr('side_effect_mem', True)
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape):
|
||||||
|
self.add_prim_attr("shape", x_shape)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype):
|
||||||
|
return x_dtype[0]
|
||||||
|
|
||||||
|
|
||||||
|
class MuxReceive(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
receive tensors from src_rank.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
This is an experimental prototype that is subject to change and/or deletion.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Send and Receive must be used in combination.
|
||||||
|
Receive must be used between servers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape (list[int]): A required list identifying the shape of the tensor to be received.
|
||||||
|
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
|
||||||
|
int8, int16, int32, float16, float32.
|
||||||
|
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import mindspore.ops as ops
|
||||||
|
>>> import mindspore.nn as nn
|
||||||
|
>>> from mindspore.communication import init
|
||||||
|
>>> from mindspore import Tensor
|
||||||
|
>>> import numpy as np
|
||||||
|
>>>
|
||||||
|
>>> init()
|
||||||
|
>>> class Net(nn.Cell):
|
||||||
|
>>> def __init__(self):
|
||||||
|
>>> super(Net, self).__init__()
|
||||||
|
>>> self.recv = ops.Receive(shape=[2, 8], dtype=np.float32, group="hccl_world_group")
|
||||||
|
>>>
|
||||||
|
>>> def construct(self):
|
||||||
|
>>> out = self.recv()
|
||||||
|
>>> return out
|
||||||
|
>>>
|
||||||
|
>>> net = Net()
|
||||||
|
>>> output = net()
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
|
||||||
|
self.shape = shape
|
||||||
|
self.dtype = dtype
|
||||||
|
self.group = group
|
||||||
|
valid_type = [mstype.float16, mstype.float32, mstype.int32, mstype.int8, mstype.uint8]
|
||||||
|
args = {"dtype": dtype}
|
||||||
|
self.add_prim_attr('side_effect_mem', True)
|
||||||
|
self.add_prim_attr("fusion", 1)
|
||||||
|
validator.check_scalar_or_tensor_types_same(args, valid_type, self.name)
|
||||||
|
|
||||||
|
def infer_shape(self, x_shape=None):
|
||||||
|
return tuple(self.get_attr_dict()['shape'])
|
||||||
|
|
||||||
|
def infer_dtype(self, x_dtype=None):
|
||||||
|
out_type = []
|
||||||
|
for _ in self.shape:
|
||||||
|
out_type.append(self.dtype)
|
||||||
|
return tuple(out_type)
|
||||||
|
|
Loading…
Reference in New Issue