rm constant sharding

commit f7949289e6
parent 797d3336d6
Author: wangshengnan123
Date:   2022-04-14 11:15:42 +08:00

2 changed files with 20 additions and 4 deletions
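
In MindSpore's semi-auto-parallel API, each inner tuple passed to a primitive's .shard() gives one split factor per dimension of the corresponding input, so a 0-dimensional scalar operand takes the empty tuple () rather than a 1-D strategy such as (1,). That convention is what this commit restores for the scalar constants in TopkRouter. A minimal sketch of the convention (the value of dp here is illustrative, not part of this commit):

    from mindspore.ops import operations as P

    dp = 2  # illustrative data-parallel degree

    # Tensor-by-tensor: one split factor per dimension of each input.
    mul = P.Mul().shard(((dp, 1), (dp, 1)))

    # Tensor-plus-scalar: the scalar contributes no dimensions, so its
    # strategy is the empty tuple.
    add1 = P.Add().shard(((dp, 1, 1), ()))

    # Scalar-by-scalar: both strategies are empty, as for mul2/mul3 below.
    mul2 = P.Mul().shard(((), ()))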

View File

@@ -417,6 +417,7 @@ class TopkRouter(Cell):
         self.div1 = P.RealDiv()
         self.div2 = P.RealDiv()
         self.add = P.Add()
+        self.add1 = P.Add()
         self.add2 = P.Add()
         self.add3 = P.Add()
         self.add4 = P.Add()
@@ -453,8 +454,8 @@ class TopkRouter(Cell):
             self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
             self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
             self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
-            self.mul2 = P.Mul().shard(((1,), ()))
-            self.mul3 = P.Mul().shard(((1,), ()))
+            self.mul2 = P.Mul().shard(((), ()))
+            self.mul3 = P.Mul().shard(((), ()))
             self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
             self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
             self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
@@ -465,6 +466,7 @@ class TopkRouter(Cell):
             self.div1 = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
             self.div2 = P.RealDiv().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
             self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1)))
+            self.add1 = P.Add().shard(((dp, 1, 1), ()))
             self.add2 = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
             self.add3 = P.Add().shard(((dp, 1), (dp, 1)))
             self.add4 = P.Add().shard(((dp, 1, 1, 1), ()))
@@ -537,7 +539,7 @@ class TopkRouter(Cell):
         # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
         expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)
         # renormalize the rest prob to be of sum 1
-        router_prob_normal = self.div1(router_prob, self.add(self.reduce_sum_keep(router_prob, -1), 1e-9))
+        router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))
         # the balance loss is computed at each routing step
         loss += self._auxiliary_loss(expert_mask, router_prob_normal)
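
The new add1 exists because the epsilon 1e-9 in the renormalization is a scalar: the original add carried the tensor-tensor strategy ((dp, 1, 1), (dp, 1, 1)), which does not match a scalar second operand. The arithmetic itself is a per-token normalization; a plain-NumPy sketch with illustrative shapes (not MindSpore code):

    import numpy as np

    router_prob = np.random.rand(2, 10, 4)            # (dp_group, tokens_per_group, expert_dim)
    denom = router_prob.sum(-1, keepdims=True) + 1e-9
    router_prob_normal = router_prob / denom          # each token's probs now sum to ~1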

View File

@@ -66,6 +66,20 @@ class NetWithLossFiveInputs(nn.Cell):
         return self.loss(predict)
 
 
+class NetWithLossMoe(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLossMoe, self).__init__()
+        self.network = network
+        self.add = P.Add().shard(((), ()))
+        self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((1, 1),))
+
+    def construct(self, x1, x2, x3, x4, x5):
+        predict, _, _, moe_loss = self.network(x1, x2, x3, x4, x5)
+        predict = P.Reshape()(predict, (-1, 1))
+        predict = self.reduce_mean(predict)
+        return self.add(predict, moe_loss)
+
+
 def test_transformer_model():
     """
     Feature: Test Transformer+MoE, with All2All enabled.
@@ -91,7 +105,7 @@ def test_transformer_model():
     decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
     decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
     memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
-    net = NetWithLossFiveInputs(net)
+    net = NetWithLossMoe(net)
     params = net.trainable_params()
     optimizer = AdamWeightDecay(params)
     dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
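
NetWithLossMoe differs from NetWithLossFiveInputs in that it consumes the auxiliary load-balancing loss the MoE network returns and folds it into the scalar training loss, which is why its P.Add() is sharded as ((), ()) (both operands are scalars after the full reduce_mean). For the sharding strategies to take effect at all, the parallel context must be configured before the network is built; a hedged sketch of typical setup (the device count is an assumption, not taken from this commit):

    from mindspore import context
    from mindspore.context import ParallelMode

    context.set_auto_parallel_context(device_num=8,
                                      parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)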