forked from mindspore-Ecosystem/mindspore
add operator HostAllGather and HostReduceScatter
This commit is contained in:
parent
93fc82b8f7
commit
2f8e7ff693
|
@ -55,7 +55,9 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
|
|||
const char kNameAllReduce[] = "AllReduce";
|
||||
const char kNameBroadcast[] = "Broadcast";
|
||||
const char kNameAllgather[] = "AllGather";
|
||||
const char kNameHostAllgather[] = "HostAllGather";
|
||||
const char kNameReduceScatter[] = "ReduceScatter";
|
||||
const char kNameHostReduceScatter[] = "HostReduceScatter";
|
||||
const char kNameReduceSum[] = "ReduceSum";
|
||||
const char kNameIsFinite[] = "isFinite";
|
||||
const char kNameReciprocal[] = "Reciprocal";
|
||||
|
|
|
@ -45,8 +45,10 @@ constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
|
|||
constexpr auto kGetNextOpName = "GetNext";
|
||||
constexpr auto kAllReduceOpName = "AllReduce";
|
||||
constexpr auto kAllGatherOpName = "AllGather";
|
||||
constexpr auto kHostAllGatherOpName = "HostAllGather";
|
||||
constexpr auto kBroadcastOpName = "Broadcast";
|
||||
constexpr auto kReduceScatterOpName = "ReduceScatter";
|
||||
constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
|
||||
constexpr auto kMemCpyAsyncOpName = "memcpy_async";
|
||||
constexpr auto kTopKOpName = "TopK";
|
||||
constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
|
||||
|
|
|
@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.ops import functional as F
|
||||
from .. import operations as P
|
||||
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||
from ..operations.comm_ops import (AllGather, AllReduce, _AlltoAll, Broadcast,
|
||||
from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast,
|
||||
_GetTensorSlice, _MirrorOperator, ReduceOp,
|
||||
ReduceScatter, _VirtualDiv)
|
||||
ReduceScatter, HostReduceScatter, _VirtualDiv)
|
||||
from .grad_base import bprop_getters
|
||||
|
||||
|
||||
|
@ -79,6 +79,21 @@ def get_bprop_all_gather(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(HostAllGather)
|
||||
def get_bprop_host_all_gather(self):
|
||||
"""Generate bprop for HostAllGather"""
|
||||
host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group)
|
||||
if self.instance_name:
|
||||
instance_name = "grad" + self.instance_name
|
||||
host_all_gather_grad.set_prim_instance_name(instance_name)
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = host_all_gather_grad(dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(ReduceScatter)
|
||||
def get_bprop_reduce_scatter(self):
|
||||
"""Generate bprop for ReduceScatter"""
|
||||
|
@ -97,6 +112,24 @@ def get_bprop_reduce_scatter(self):
|
|||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(HostReduceScatter)
|
||||
def get_bprop_host_reduce_scatter(self):
|
||||
"""Generate bprop for HostReduceScatter"""
|
||||
host_reduce_scatter_grad = HostAllGather(self.group)
|
||||
if self.instance_name:
|
||||
instance_name = "grad" + self.instance_name
|
||||
host_reduce_scatter_grad.set_prim_instance_name(instance_name)
|
||||
|
||||
if self.op != ReduceOp.SUM:
|
||||
raise RuntimeError("The hostreducescatter bprop only support ReduceOp.SUM until now.")
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dx = host_reduce_scatter_grad(dout)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(_AlltoAll)
|
||||
def get_bprop_all_to_all(self):
|
||||
"""Generate bprop for AlltoAll."""
|
||||
|
|
|
@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
|||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
|
||||
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
|
||||
_MirrorOperator, ReduceOp, _VirtualDataset,
|
||||
_VirtualDiv, _GetTensorSlice)
|
||||
_VirtualDiv, _GetTensorSlice,
|
||||
HostAllGather, HostReduceScatter)
|
||||
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
|
||||
TensorSummary, HistogramSummary, Print)
|
||||
from .control_ops import ControlDepend, GeSwitch, Merge
|
||||
|
@ -217,8 +218,10 @@ __all__ = [
|
|||
'UnsortedSegmentSum',
|
||||
'UnsortedSegmentMin',
|
||||
"AllGather",
|
||||
"HostAllGather",
|
||||
"AllReduce",
|
||||
"ReduceScatter",
|
||||
"HostReduceScatter",
|
||||
"Broadcast",
|
||||
"ReduceOp",
|
||||
'ScalarCast',
|
||||
|
|
|
@ -169,6 +169,72 @@ class AllGather(PrimitiveWithInfer):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class HostAllGather(PrimitiveWithInfer):
|
||||
"""
|
||||
Gathers tensors from the specified communication group on host.
|
||||
|
||||
Note:
|
||||
Tensor must have the same shape and format in all processes participating in the collective.
|
||||
|
||||
Args:
|
||||
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
|
||||
|
||||
Raises:
|
||||
TypeError: If group is not a list nor tuple, or elements of group are not int.
|
||||
ValueError: If the local rank id of the calling process not in group,
|
||||
or rank_id from group not in [0, 7].
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
|
||||
Outputs:
|
||||
Tensor. If the number of devices in the group is N,
|
||||
then the shape of output is :math:`(N, x_1, x_2, ..., x_R)`.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.communication import init
|
||||
>>> import mindspore.ops.operations as P
|
||||
>>> init('nccl')
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> return self.hostallgather(x)
|
||||
>>>
|
||||
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
|
||||
>>> net = Net()
|
||||
>>> output = net(input_)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, group=None):
|
||||
if group is None:
|
||||
raise ValueError(f"For '{self.name}' group must be set.")
|
||||
validator.check_value_type('group', group, (tuple, list), self.name)
|
||||
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
|
||||
for r in group:
|
||||
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
|
||||
validator.check_value_type("rank_id", r, (int,), self.name)
|
||||
self.group_size = len(group)
|
||||
self.rank = get_rank()
|
||||
validator.check('rank', self.rank, 'group', self.group, Rel.IN, self.name)
|
||||
self.add_prim_attr('group', group)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
|
||||
x_shape[0] = x_shape[0] * self.group_size
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ReduceScatter(PrimitiveWithInfer):
|
||||
"""
|
||||
Reduces and scatters tensors from the specified communication group.
|
||||
|
@ -226,6 +292,68 @@ class ReduceScatter(PrimitiveWithInfer):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class HostReduceScatter(PrimitiveWithInfer):
|
||||
"""
|
||||
Reduces and scatters tensors from the specified communication group on host.
|
||||
|
||||
Note:
|
||||
Tensor must have the same shape and format in all processes participating in the collective.
|
||||
|
||||
Args:
|
||||
op (str): Specifies an operation used for element-wise reductions,
|
||||
like sum, max, avg. Default: ReduceOp.SUM.
|
||||
group (Union[tuple[int],list[int]]): The rand_ids of communication group to work on.
|
||||
|
||||
Raise:
|
||||
TypeError: If op is not a string and group is not a list nor tuple,
|
||||
or elements of group are not int.
|
||||
ValueError: If the first dimension of input can not be divided by rank size,
|
||||
or group is not set, or rank_id not in [1, 7].
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.communication import init
|
||||
>>> import mindspore.ops.operations as P
|
||||
>>> init('nccl')
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> return self.hostreducescatter(x)
|
||||
>>>
|
||||
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
|
||||
>>> net = Net()
|
||||
>>> output = net(input_)
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, op=ReduceOp.SUM, group=None):
|
||||
if group is None:
|
||||
raise ValueError(f"For '{self.name}' group must be set.")
|
||||
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
|
||||
validator.check_value_type('group', group, (tuple, list), self.name)
|
||||
validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
|
||||
for r in group:
|
||||
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
|
||||
validator.check_value_type("rank_id", r, (int,), self.name)
|
||||
self.op = op
|
||||
self.group_size = len(group)
|
||||
self.add_prim_attr('group', group)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
if x_shape[0] % self.group_size != 0:
|
||||
raise ValueError(f"For '{self.name}' the first dimension of x should be divided by group_size.")
|
||||
x_shape[0] = int(x_shape[0]/self.group_size)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
|
||||
return x_dtype
|
||||
|
||||
def __call__(self, tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Broadcast(PrimitiveWithInfer):
|
||||
"""
|
||||
Broadcasts the tensor to the whole group.
|
||||
|
|
|
@ -26,6 +26,7 @@ from mindspore.nn import Momentum
|
|||
from mindspore.nn import ReLU
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
|
||||
from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter
|
||||
from mindspore.ops.operations.comm_ops import Broadcast
|
||||
|
||||
# pylint: disable=W0212
|
||||
|
@ -86,6 +87,21 @@ class AllGatherNet(nn.Cell):
|
|||
return self.relu(x)
|
||||
|
||||
|
||||
class HostAllGatherNet(nn.Cell):
|
||||
"""HostAllGatherNet definition"""
|
||||
|
||||
def __init__(self, input_channel, output_channel):
|
||||
super(HostAllGatherNet, self).__init__()
|
||||
self.dense = Dense(input_channel, output_channel)
|
||||
self.hostallgather = HostAllGather((0, 1))
|
||||
self.relu = ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.dense(x)
|
||||
x = self.hostallgather(x)
|
||||
return self.relu(x)
|
||||
|
||||
|
||||
class ReduceScatterNet(nn.Cell):
|
||||
"""ReduceScatterNet definition"""
|
||||
|
||||
|
@ -101,6 +117,21 @@ class ReduceScatterNet(nn.Cell):
|
|||
return self.relu(x)
|
||||
|
||||
|
||||
class HostReduceScatterNet(nn.Cell):
|
||||
"""HostReduceScatterNet definition"""
|
||||
|
||||
def __init__(self, input_channel, out_channel, op):
|
||||
super(HostReduceScatterNet, self).__init__()
|
||||
self.dense = Dense(input_channel, out_channel)
|
||||
self.hostreducescatter = HostReduceScatter(op, (0, 1))
|
||||
self.relu = ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.dense(x)
|
||||
x = self.hostreducescatter(x)
|
||||
return self.relu(x)
|
||||
|
||||
|
||||
class AlltoAllNet(nn.Cell):
|
||||
"""AlltoAllNet definition"""
|
||||
|
||||
|
@ -154,6 +185,21 @@ def test_allgather():
|
|||
_executor.compile(network, input_tensor, label_tensor)
|
||||
|
||||
|
||||
def test_hostallgather():
|
||||
"""test_hostallgather"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
|
||||
label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32))
|
||||
network = HostAllGatherNet(2, 1)
|
||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
|
||||
learning_rate=0.1,
|
||||
momentum=0.9)
|
||||
network = WithLossCell(network, loss_fn)
|
||||
network = TrainOneStepCell(network, optimizer)
|
||||
_executor.compile(network, input_tensor, label_tensor)
|
||||
|
||||
|
||||
def run_reducescatter(op):
|
||||
"""run_reducescatter"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
@ -175,6 +221,21 @@ def test_reducescatter():
|
|||
run_reducescatter(ReduceOp.SUM)
|
||||
|
||||
|
||||
def test_hostreducescatter():
|
||||
"""test_hostreducescatter"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
|
||||
label_tensor = Tensor(np.array([[1.2]], dtype=np.float32))
|
||||
network = HostReduceScatterNet(2, 1, ReduceOp.SUM)
|
||||
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
|
||||
optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
|
||||
learning_rate=0.1,
|
||||
momentum=0.9)
|
||||
network = WithLossCell(network, loss_fn)
|
||||
network = TrainOneStepCell(network, optimizer)
|
||||
_executor.compile(network, input_tensor, label_tensor)
|
||||
|
||||
|
||||
def test_broadcast():
|
||||
"""test_broadcast"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
|
Loading…
Reference in New Issue