forked from mindspore-Ecosystem/mindspore
!746 reducescatter backward operator
Merge pull request !746 from lirongzhen1/bp_reducescatter
commit 5b3327d103

@@ -29,7 +29,6 @@
 namespace mindspore {
 namespace parallel {

 // Get the target node's weight for sorting.
 double GetWeights(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;

@@ -67,11 +67,29 @@ def get_bprop_broad_cast(self):
 @bprop_getters.register(AllGather)
 def get_bprop_all_gather(self):
     """Generate bprop for AllGather"""
-    reduce_scatter_grad = ReduceScatter(ReduceOp.SUM, self.group)
+    all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
     if self.instance_name:
         instance_name = "grad" + self.instance_name
-        reduce_scatter_grad.set_prim_instance_name(instance_name)
+        all_gather_grad.set_prim_instance_name(instance_name)

     def bprop(x, out, dout):
-        dx = reduce_scatter_grad(dout)
+        dx = all_gather_grad(dout)
         return (dx,)

     return bprop


+@bprop_getters.register(ReduceScatter)
+def get_bprop_reduce_scatter(self):
+    """Generate bprop for ReduceScatter"""
+    reduce_scatter_grad = AllGather(self.group)
+    if self.instance_name:
+        instance_name = "grad" + self.instance_name
+        reduce_scatter_grad.set_prim_instance_name(instance_name)
+
+    if self.op != ReduceOp.SUM:
+        raise RuntimeError("The reducescatter bprop only supports ReduceOp.SUM for now.")
+
+    def bprop(x, out, dout):
+        dx = reduce_scatter_grad(dout)
+        return (dx,)
+
+    return bprop
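The two bprops above are duals of each other: AllGather concatenates the per-rank shards onto every rank, so its vector-Jacobian product must sum the incoming gradients across ranks and hand each rank back only its own slice, which is exactly ReduceScatter with ReduceOp.SUM; conversely, the gradient of ReduceScatter(SUM) is AllGather, which is why the bprop above rejects every other reduce op. A minimal single-process sketch (NumPy only; the helper names and the list-of-ranks simulation are illustrative, not MindSpore APIs) that checks the adjoint identity:

import numpy as np

world_size = 4

def all_gather(shards):
    # Every rank contributes one shard; every rank receives the concatenation.
    gathered = np.concatenate(shards, axis=0)
    return [gathered.copy() for _ in range(world_size)]

def reduce_scatter_sum(tensors):
    # Sum the per-rank tensors elementwise, then give rank i the i-th slice.
    total = np.sum(tensors, axis=0)
    return np.split(total, world_size, axis=0)

# Forward AllGather, backward ReduceScatter(SUM):
x = [np.random.rand(2, 3) for _ in range(world_size)]     # one shard per rank
out = all_gather(x)                                        # each out is (8, 3)
dout = [np.random.rand(8, 3) for _ in range(world_size)]  # upstream gradients
dx = reduce_scatter_sum(dout)                              # each dx is (2, 3)

# Adjoint identity: <AllGather(x), dout> == <x, ReduceScatter_SUM(dout)>
lhs = sum(np.vdot(o, d) for o, d in zip(out, dout))
rhs = sum(np.vdot(a, b) for a, b in zip(x, dx))
assert np.isclose(lhs, rhs)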
@@ -14,7 +14,7 @@
 """ test Communicate """
 import numpy as np
-from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce, AllGather, _AlltoAll, ReduceOp, ReduceScatter
 from mindspore.ops.operations.comm_ops import Broadcast
 from mindspore.communication.management import HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, GlobalComm, init
 from mindspore.communication._comm_helper import Backend

@@ -78,6 +78,19 @@ class AllGatherNet(nn.Cell):
         x = self.allgather(x)
         return self.relu(x)

+class ReduceScatterNet(nn.Cell):
+    """ReduceScatterNet definition"""
+    def __init__(self, input_channel, out_channel, op):
+        super(ReduceScatterNet, self).__init__()
+        self.dense = Dense(input_channel, out_channel)
+        self.reducescatter = ReduceScatter(op)
+        self.relu = ReLU()
+
+    def construct(self, x):
+        x = self.dense(x)
+        x = self.reducescatter(x)
+        return self.relu(x)
+
 class AlltoAllNet(nn.Cell):
     """AlltoAllNet definition"""
     def __init__(self, input_channel, out_channel):

@@ -126,6 +139,25 @@ def test_allgather():
     network = TrainOneStepCell(network, optimizer)
     _executor.compile(network, input_tensor, label_tensor)

+def run_reducescatter(op):
+    """run_reducescatter"""
+    context.set_context(mode=context.GRAPH_MODE)
+    input_tensor = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
+    label_tensor = Tensor(np.array([[1.2], [2.2]], dtype=np.float32))
+    network = ReduceScatterNet(2, 1, op)
+    loss_fn = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = Momentum(filter(lambda x: x.requires_grad, network.get_parameters()),
+                         learning_rate=0.1,
+                         momentum=0.9)
+    network = WithLossCell(network, loss_fn)
+    network = TrainOneStepCell(network, optimizer)
+    _executor.compile(network, input_tensor, label_tensor)
+
+def test_reducescatter():
+    """test_reducescatter"""
+    context.set_context(mode=context.GRAPH_MODE)
+    run_reducescatter(ReduceOp.SUM)
+
 def test_broadcast():
     """test_broadcast"""
     context.set_context(mode=context.GRAPH_MODE)
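The test above only checks that the graph compiles; no collective actually runs. A hedged sketch of standalone ReduceScatter usage (assuming a real multi-device job where the distributed launcher has prepared the NCCL environment, and that dim 0 of the input is divisible by the group rank size):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.communication.management import init
from mindspore.ops.operations.comm_ops import ReduceScatter, ReduceOp

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
init()  # join the world group set up by the launcher

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.reduce_scatter = ReduceScatter(ReduceOp.SUM)

    def construct(self, x):
        # Rank i keeps slice i (along dim 0) of the elementwise
        # sum of x across all ranks in the group.
        return self.reduce_scatter(x)

net = Net()
x = Tensor(np.ones([8, 2]).astype(np.float32))
print(net(x))  # with 4 ranks: a (2, 2) tensor filled with 4.0 on each rank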