From 0b4648881b10481beef2908a0e3f6d73feca8b19 Mon Sep 17 00:00:00 2001
From: lirongzhen1
Date: Sun, 26 Apr 2020 22:27:59 +0800
Subject: [PATCH] add reducescatter bprop

---
 .../auto_parallel/rec_core/rec_partition.cc |  1 -
 mindspore/ops/_grad/grad_comm_ops.py        | 20 ++++++++++-
 tests/ut/python/communication/test_comm.py  | 34 ++++++++++++++++++-
 3 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
index 81e0eaa2dd8..eafe4784a42 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -29,7 +29,6 @@
 
 namespace mindspore {
 namespace parallel {
-
 // Get the target node's weight for sorting.
 double GetWeights(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 3a31c8aeec2..97b8b3fdf30 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -67,11 +67,29 @@ def get_bprop_broad_cast(self):
 @bprop_getters.register(AllGather)
 def get_bprop_all_gather(self):
     """Generate bprop for AllGather"""
-    reduce_scatter_grad = ReduceScatter(ReduceOp.SUM, self.group)
+    all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
+    if self.instance_name:
+        instance_name = "grad" + self.instance_name
+        all_gather_grad.set_prim_instance_name(instance_name)
+
+    def bprop(x, out, dout):
+        dx = all_gather_grad(dout)
+        return (dx,)
+
+    return bprop
+
+
+@bprop_getters.register(ReduceScatter)
+def get_bprop_reduce_scatter(self):
+    """Generate bprop for ReduceScatter"""
+    reduce_scatter_grad = AllGather(self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
         reduce_scatter_grad.set_prim_instance_name(instance_name)
+    if self.op != ReduceOp.SUM:
+        raise RuntimeError("The ReduceScatter bprop currently only supports ReduceOp.SUM.")
+
     def bprop(x, out, dout):
         dx = reduce_scatter_grad(dout)
         return (dx,)
diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py
index 38fd7199fd4..981603b687a 100644
--- a/tests/ut/python/communication/test_comm.py
+++ b/tests/ut/python/communication/test_comm.py
@@ -14,7 +14,7 @@
 
 """ test Communicate """
 import numpy as np
-from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast
 from mindspore.communication.management import HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, GlobalComm, init
 from mindspore.communication._comm_helper import Backend
@@ -78,6 +78,19 @@ class AllGatherNet(nn.Cell):
         x = self.allgather(x)
         return self.relu(x)
 
+class ReduceScatterNet(nn.Cell):
+    """ReduceScatterNet definition"""
+    def __init__(self, input_channel, out_channel, op):
+        super(ReduceScatterNet, self).__init__()
+        self.dense = Dense(input_channel, out_channel)
+        self.reducescatter = ReduceScatter(op)
+        self.relu = ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.reducescatter(x)
+        return self.relu(x)
+
 class AlltoAllNet(nn.Cell):
     """AlltoAllNet definition"""
     def __init__(self, input_channel, out_channel):
@@ -126,6 +139,25 @@ def test_allgather():
     network = TrainOneStepCell(network, optimizer)
     _executor.compile(network, input_tensor, label_tensor)
 
+def run_reducescatter(op):
+    """run_reducescatter"""
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
+    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
+    network = ReduceScatterNet(2, 1, op)
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
+                         learning_rate=0.1,
+                         momentum=0.9)
+    network = WithLossCell(network, loss_fn)
+    network = TrainOneStepCell(network, optimizer)
+    _executor.compile(network, input_tensor, label_tensor)
+
+def test_reducescatter():
+    """test_reducescatter"""
+    context.set_context(mode=context.GRAPH_MODE)
+    run_reducescatter(ReduceOp.SUM)
+
 def test_broadcast():
     """test_broadcast"""
     context.set_context(mode=context.GRAPH_MODE)
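
Note: the gradient pairing this patch registers — AllGather's bprop is ReduceScatter(ReduceOp.SUM), and ReduceScatter(SUM)'s bprop is AllGather — follows from the two collectives being adjoint (transposed) linear maps of each other. The sketch below checks that adjoint identity in plain NumPy, with no MindSpore dependency; WORLD_SIZE, the shard shapes, and the helper names are illustrative and not part of the patch.

import numpy as np

WORLD_SIZE = 2  # illustrative two-rank "group"; not taken from the patch

def all_gather(shards):
    # Every rank receives the concatenation of all ranks' shards.
    gathered = np.concatenate(shards, axis=0)
    return [gathered.copy() for _ in range(WORLD_SIZE)]

def reduce_scatter_sum(tensors):
    # Element-wise sum across ranks, then each rank keeps one slice of the result.
    total = np.sum(tensors, axis=0)
    return np.split(total, WORLD_SIZE, axis=0)

rng = np.random.default_rng(0)
x = [rng.standard_normal((3, 4)) for _ in range(WORLD_SIZE)]   # per-rank input shards
gy = [rng.standard_normal((6, 4)) for _ in range(WORLD_SIZE)]  # per-rank output grads

# Adjoint identity behind get_bprop_all_gather:
#   <AllGather(x), gy> == <x, ReduceScatter_SUM(gy)>
lhs = sum(np.vdot(y, g) for y, g in zip(all_gather(x), gy))
rhs = sum(np.vdot(xi, gi) for xi, gi in zip(x, reduce_scatter_sum(gy)))
assert np.isclose(lhs, rhs)

# And the direction behind get_bprop_reduce_scatter:
#   <ReduceScatter_SUM(t), gz> == <t, AllGather(gz)>
t = [rng.standard_normal((6, 4)) for _ in range(WORLD_SIZE)]
gz = [rng.standard_normal((3, 4)) for _ in range(WORLD_SIZE)]
lhs2 = sum(np.vdot(s, g) for s, g in zip(reduce_scatter_sum(t), gz))
rhs2 = sum(np.vdot(ti, gi) for ti, gi in zip(t, all_gather(gz)))
assert np.isclose(lhs2, rhs2)
print("both adjoint identities hold")

The same reasoning explains the RuntimeError guard in get_bprop_reduce_scatter: for ReduceOp.MAX, MIN, or PROD the reduction is not linear, so its gradient is not a plain AllGather, and only the SUM case is supported here.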