From 2f8e7ff693a57e2e7c132048a3c56835b267ec74 Mon Sep 17 00:00:00 2001
From: Yi Huaijie
Date: Tue, 26 May 2020 15:19:25 +0800
Subject: [PATCH] add operator HostAllGather and HostReduceScatter

---
 mindspore/ccsrc/transform/convert.cc       |   2 +
 mindspore/ccsrc/utils/utils.h              |   2 +
 mindspore/ops/_grad/grad_comm_ops.py       |  37 +++++-
 mindspore/ops/operations/__init__.py       |   5 +-
 mindspore/ops/operations/comm_ops.py       | 128 +++++++++++++++++++++
 tests/ut/python/communication/test_comm.py |  61 ++++++++++
 6 files changed, 232 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc
index 56b1e149ac9..aee0654c458 100644
--- a/mindspore/ccsrc/transform/convert.cc
+++ b/mindspore/ccsrc/transform/convert.cc
@@ -55,7 +55,9 @@ const char kNameSimpleMeanGrad[] = "SimpleMeanGrad";
 const char kNameAllReduce[] = "AllReduce";
 const char kNameBroadcast[] = "Broadcast";
 const char kNameAllgather[] = "AllGather";
+const char kNameHostAllgather[] = "HostAllGather";
 const char kNameReduceScatter[] = "ReduceScatter";
+const char kNameHostReduceScatter[] = "HostReduceScatter";
 const char kNameReduceSum[] = "ReduceSum";
 const char kNameIsFinite[] = "isFinite";
 const char kNameReciprocal[] = "Reciprocal";
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index ae420cfaec5..bc06d61a67a 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -45,8 +45,10 @@ constexpr auto kAtomicAddrCleanOpName = "AtomicAddrClean";
 constexpr auto kGetNextOpName = "GetNext";
 constexpr auto kAllReduceOpName = "AllReduce";
 constexpr auto kAllGatherOpName = "AllGather";
+constexpr auto kHostAllGatherOpName = "HostAllGather";
 constexpr auto kBroadcastOpName = "Broadcast";
 constexpr auto kReduceScatterOpName = "ReduceScatter";
+constexpr auto kHostReduceScatterOpName = "HostReduceScatter";
 constexpr auto kMemCpyAsyncOpName = "memcpy_async";
 constexpr auto kTopKOpName = "TopK";
 constexpr auto kExtractImagePatchesOpName = "ExtractImagePatches";
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 97b8b3fdf30..057d150be1e 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -18,9 +18,9 @@ import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from .. import operations as P
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
-from ..operations.comm_ops import (AllGather, AllReduce, _AlltoAll, Broadcast,
+from ..operations.comm_ops import (AllGather, HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, ReduceOp,
-                                   ReduceScatter, _VirtualDiv)
+                                   ReduceScatter, HostReduceScatter, _VirtualDiv)
 from .grad_base import bprop_getters
 
 
@@ -79,6 +79,21 @@ def get_bprop_all_gather(self):
     return bprop
 
 
+@bprop_getters.register(HostAllGather)
+def get_bprop_host_all_gather(self):
+    """Generate bprop for HostAllGather"""
+    host_all_gather_grad = HostReduceScatter(ReduceOp.SUM, self.group)
+    if self.instance_name:
+        instance_name = "grad" + self.instance_name
+        host_all_gather_grad.set_prim_instance_name(instance_name)
+
+    def bprop(x, out, dout):
+        dx = host_all_gather_grad(dout)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(ReduceScatter)
 def get_bprop_reduce_scatter(self):
     """Generate bprop for ReduceScatter"""
@@ -97,6 +112,24 @@ def get_bprop_reduce_scatter(self):
     return bprop
 
 
+@bprop_getters.register(HostReduceScatter)
+def get_bprop_host_reduce_scatter(self):
+    """Generate bprop for HostReduceScatter"""
+    host_reduce_scatter_grad = HostAllGather(self.group)
+    if self.instance_name:
+        instance_name = "grad" + self.instance_name
+        host_reduce_scatter_grad.set_prim_instance_name(instance_name)
+
+    if self.op != ReduceOp.SUM:
+        raise RuntimeError("The HostReduceScatter bprop only supports ReduceOp.SUM for now.")
+
+    def bprop(x, out, dout):
+        dx = host_reduce_scatter_grad(dout)
+        return (dx,)
+
+    return bprop
+
+
 @bprop_getters.register(_AlltoAll)
 def get_bprop_all_to_all(self):
     """Generate bprop for AlltoAll."""
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index e933fa97013..4b842b707e0 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -32,7 +32,8 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
                        _MirrorOperator, ReduceOp, _VirtualDataset,
-                       _VirtualDiv, _GetTensorSlice)
+                       _VirtualDiv, _GetTensorSlice,
+                       HostAllGather, HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
@@ -217,8 +218,10 @@ __all__ = [
     'UnsortedSegmentSum',
     'UnsortedSegmentMin',
     "AllGather",
+    "HostAllGather",
     "AllReduce",
     "ReduceScatter",
+    "HostReduceScatter",
     "Broadcast",
     "ReduceOp",
     'ScalarCast',
diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py
index f5c005e8198..19e94a1f5fc 100644
--- a/mindspore/ops/operations/comm_ops.py
+++ b/mindspore/ops/operations/comm_ops.py
@@ -169,6 +169,72 @@ class AllGather(PrimitiveWithInfer):
         raise NotImplementedError
 
 
+class HostAllGather(PrimitiveWithInfer):
+    """
+    Gathers tensors from the specified communication group on host.
+
+    Note:
+        Tensor must have the same shape and format in all processes participating in the collective.
+
+    Args:
+        group (Union[tuple[int],list[int]]): The rank ids of the communication group to work on.
+
+    Raises:
+        TypeError: If group is not a list or a tuple, or any element of group is not an int.
+        ValueError: If the local rank id of the calling process is not in group,
+                    or a rank_id from group is not in [0, 7].
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Outputs:
+        Tensor. If the number of devices in the group is N,
+        then the shape of output is :math:`(N*x_1, x_2, ..., x_R)`.
+
+    Examples:
+        >>> from mindspore.communication import init
+        >>> import mindspore.ops.operations as P
+        >>> init('nccl')
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.hostallgather = P.HostAllGather(group=(0, 1, 2, 3))
+        >>>
+        >>>     def construct(self, x):
+        >>>         return self.hostallgather(x)
+        >>>
+        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+        >>> net = Net()
+        >>> output = net(input_)
+    """
+
+    @prim_attr_register
+    def __init__(self, group=None):
+        if group is None:
+            raise ValueError(f"For '{self.name}' group must be set.")
+        validator.check_value_type('group', group, (tuple, list), self.name)
+        validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
+        for r in group:
+            validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
+            validator.check_value_type("rank_id", r, (int,), self.name)
+        self.group_size = len(group)
+        self.rank = get_rank()
+        validator.check('rank', self.rank, 'group', self.group, Rel.IN, self.name)
+        self.add_prim_attr('group', group)
+
+    def infer_shape(self, x_shape):
+        validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name)
+        x_shape[0] = x_shape[0] * self.group_size
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        return x_dtype
+
+    def __call__(self, tensor):
+        raise NotImplementedError
+
+
 class ReduceScatter(PrimitiveWithInfer):
     """
     Reduces and scatters tensors from the specified communication group.
@@ -226,6 +292,68 @@ class ReduceScatter(PrimitiveWithInfer):
         raise NotImplementedError
 
 
+class HostReduceScatter(PrimitiveWithInfer):
+    """
+    Reduces and scatters tensors from the specified communication group on host.
+
+    Note:
+        Tensor must have the same shape and format in all processes participating in the collective.
+
+    Args:
+        op (str): Specifies an operation used for element-wise reductions,
+                  like sum, max, avg. Default: ReduceOp.SUM.
+        group (Union[tuple[int],list[int]]): The rank ids of the communication group to work on.
+
+    Raises:
+        TypeError: If op is not a string, group is not a list or a tuple,
+                   or any element of group is not an int.
+        ValueError: If the first dimension of the input cannot be divided by the group size,
+                    or group is not set, or a rank_id from group is not in [0, 7].
+
+    Examples:
+        >>> from mindspore.communication import init
+        >>> import mindspore.ops.operations as P
+        >>> init('nccl')
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.hostreducescatter = P.HostReduceScatter(ReduceOp.SUM, group=[0, 1, 2, 3])
+        >>>
+        >>>     def construct(self, x):
+        >>>         return self.hostreducescatter(x)
+        >>>
+        >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
+        >>> net = Net()
+        >>> output = net(input_)
+    """
+    @prim_attr_register
+    def __init__(self, op=ReduceOp.SUM, group=None):
+        if group is None:
+            raise ValueError(f"For '{self.name}' group must be set.")
+        validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name)
+        validator.check_value_type('group', group, (tuple, list), self.name)
+        validator.check_integer("group size", len(group), 2, Rel.GE, self.name)
+        for r in group:
+            validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name)
+            validator.check_value_type("rank_id", r, (int,), self.name)
+        self.op = op
+        self.group_size = len(group)
+        self.add_prim_attr('group', group)
+
+    def infer_shape(self, x_shape):
+        if x_shape[0] % self.group_size != 0:
+            raise ValueError(f"For '{self.name}' the first dimension of x should be divisible by group_size.")
+        x_shape[0] = int(x_shape[0]/self.group_size)
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, target_dtypes, self.name)
+        return x_dtype
+
+    def __call__(self, tensor):
+        raise NotImplementedError
+
+
 class Broadcast(PrimitiveWithInfer):
     """
     Broadcasts the tensor to the whole group.
diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py
index 7688adb41a5..f3530cb2612 100644
--- a/tests/ut/python/communication/test_comm.py
+++ b/tests/ut/python/communication/test_comm.py
@@ -26,6 +26,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import ReLU
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
+from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast
 
 # pylint: disable=W0212
@@ -86,6 +87,21 @@ class AllGatherNet(nn.Cell):
         return self.relu(x)
 
 
+class HostAllGatherNet(nn.Cell):
+    """HostAllGatherNet definition"""
+
+    def __init__(self, input_channel, output_channel):
+        super(HostAllGatherNet, self).__init__()
+        self.dense = Dense(input_channel, output_channel)
+        self.hostallgather = HostAllGather((0, 1))
+        self.relu = ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.hostallgather(x)
+        return self.relu(x)
+
+
 class ReduceScatterNet(nn.Cell):
     """ReduceScatterNet definition"""
 
@@ -101,6 +117,21 @@ class ReduceScatterNet(nn.Cell):
         return self.relu(x)
 
 
+class HostReduceScatterNet(nn.Cell):
+    """HostReduceScatterNet definition"""
+
+    def __init__(self, input_channel, out_channel, op):
+        super(HostReduceScatterNet, self).__init__()
+        self.dense = Dense(input_channel, out_channel)
+        self.hostreducescatter = HostReduceScatter(op, (0, 1))
+        self.relu = ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.hostreducescatter(x)
+        return self.relu(x)
+
+
 class AlltoAllNet(nn.Cell):
     """AlltoAllNet definition"""
 
@@ -154,6 +185,21 @@ def test_allgather():
     _executor.compile(network, input_tensor, label_tensor)
 
 
+def test_hostallgather():
+    """test_hostallgather"""
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
+    label_tensor = Tensor(np.array([[1.2], [2.2], [3.2], [4.2]], dtype=np.float32))
+    network = HostAllGatherNet(2, 1)
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
+                         learning_rate=0.1,
+                         momentum=0.9)
+    network = WithLossCell(network, loss_fn)
+    network = TrainOneStepCell(network, optimizer)
+    _executor.compile(network, input_tensor, label_tensor)
+
+
 def run_reducescatter(op):
     """run_reducescatter"""
     context.set_context(mode=context.GRAPH_MODE)
@@ -175,6 +221,21 @@ def test_reducescatter():
     run_reducescatter(ReduceOp.SUM)
 
 
+def test_hostreducescatter():
+    """test_hostreducescatter"""
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
+    label_tensor = Tensor(np.array([[1.2]], dtype=np.float32))
+    network = HostReduceScatterNet(2, 1, ReduceOp.SUM)
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
+                         learning_rate=0.1,
+                         momentum=0.9)
+    network = WithLossCell(network, loss_fn)
+    network = TrainOneStepCell(network, optimizer)
+    _executor.compile(network, input_tensor, label_tensor)
+
+
 def test_broadcast():
     """test_broadcast"""
     context.set_context(mode=context.GRAPH_MODE)
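
A minimal usage sketch of the two new primitives, not taken from the patch itself: it assumes a two-rank host group (rank ids 0 and 1), the init('nccl') call used in the docstring examples above, and an illustrative cell name HostCollectiveNet. Shapes follow the infer_shape rules added by the patch: HostAllGather multiplies the first dimension by the group size, HostReduceScatter divides it.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.communication import init
    from mindspore.ops.operations.comm_ops import HostAllGather, HostReduceScatter, ReduceOp


    class HostCollectiveNet(nn.Cell):
        """Illustrative cell: gather on host, then reduce-scatter back to per-rank slices."""

        def __init__(self):
            super(HostCollectiveNet, self).__init__()
            # Both primitives require the tuple of participating rank ids explicitly.
            self.hostallgather = HostAllGather(group=(0, 1))
            self.hostreducescatter = HostReduceScatter(ReduceOp.SUM, group=(0, 1))

        def construct(self, x):
            # Per-rank x of shape (2, 8) -> gathered (4, 8) -> scattered back to (2, 8).
            gathered = self.hostallgather(x)
            return self.hostreducescatter(gathered)


    if __name__ == "__main__":
        # __call__ raises NotImplementedError, so the ops only run when compiled in
        # graph mode, with one process launched per rank in the group.
        context.set_context(mode=context.GRAPH_MODE)
        init('nccl')
        net = HostCollectiveNet()
        output = net(Tensor(np.ones([2, 8]).astype(np.float32)))

The registered bprops mirror this symmetry: the gradient of HostAllGather is a summing HostReduceScatter over the same group, and vice versa.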