From f7949289e6a362524658eae493d2d23520ad9a8c Mon Sep 17 00:00:00 2001 From: wangshengnan123 Date: Thu, 14 Apr 2022 11:15:42 +0800 Subject: [PATCH] rm constant sharding --- mindspore/python/mindspore/nn/transformer/moe.py | 8 +++++--- tests/ut/python/parallel/test_parallel_moe.py | 16 +++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/mindspore/python/mindspore/nn/transformer/moe.py b/mindspore/python/mindspore/nn/transformer/moe.py index d78beab3c7c..580b88f8024 100644 --- a/mindspore/python/mindspore/nn/transformer/moe.py +++ b/mindspore/python/mindspore/nn/transformer/moe.py @@ -417,6 +417,7 @@ class TopkRouter(Cell): self.div1 = P.RealDiv() self.div2 = P.RealDiv() self.add = P.Add() + self.add1 = P.Add() self.add2 = P.Add() self.add3 = P.Add() self.add4 = P.Add() @@ -453,8 +454,8 @@ class TopkRouter(Cell): self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),)) self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),)) self.mul = P.Mul().shard(((dp, 1), (dp, 1))) - self.mul2 = P.Mul().shard(((1,), ())) - self.mul3 = P.Mul().shard(((1,), ())) + self.mul2 = P.Mul().shard(((), ())) + self.mul3 = P.Mul().shard(((), ())) self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1))) self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1))) self.mul6 = P.Mul().shard(((dp, 1), (dp, 1))) @@ -465,6 +466,7 @@ class TopkRouter(Cell): self.div1 = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1))) self.div2 = P.RealDiv().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1))) + self.add1 = P.Add().shard(((dp, 1, 1), ())) self.add2 = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.add3 = P.Add().shard(((dp, 1), (dp, 1))) self.add4 = P.Add().shard(((dp, 1, 1, 1), ())) @@ -537,7 +539,7 @@ class TopkRouter(Cell): # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim) expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value) # renormalize the rest prob to be of sum 1 - router_prob_normal = self.div1(router_prob, self.add(self.reduce_sum_keep(router_prob, -1), 1e-9)) + router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9)) # the balance loss is computed at each routing step loss += self._auxiliary_loss(expert_mask, router_prob_normal) diff --git a/tests/ut/python/parallel/test_parallel_moe.py b/tests/ut/python/parallel/test_parallel_moe.py index 653284126f3..b127b6da5ec 100644 --- a/tests/ut/python/parallel/test_parallel_moe.py +++ b/tests/ut/python/parallel/test_parallel_moe.py @@ -66,6 +66,20 @@ class NetWithLossFiveInputs(nn.Cell): return self.loss(predict) +class NetWithLossMoe(nn.Cell): + def __init__(self, network): + super(NetWithLossMoe, self).__init__() + self.network = network + self.add = P.Add().shard(((), ())) + self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((1, 1),)) + + def construct(self, x1, x2, x3, x4, x5): + predict, _, _, moe_loss = self.network(x1, x2, x3, x4, x5) + predict = P.Reshape()(predict, (-1, 1)) + predict = self.reduce_mean(predict) + return self.add(predict, moe_loss) + + def test_transformer_model(): """ Feature: Test Transformer+MoE, with All2All enabled. @@ -91,7 +105,7 @@ def test_transformer_model(): decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32) decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16) memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16) - net = NetWithLossFiveInputs(net) + net = NetWithLossMoe(net) params = net.trainable_params() optimizer = AdamWeightDecay(params) dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,