rm constant sharding

parent 797d3336d6
commit f7949289e6
@@ -417,6 +417,7 @@ class TopkRouter(Cell):
         self.div1 = P.RealDiv()
         self.div2 = P.RealDiv()
         self.add = P.Add()
+        self.add1 = P.Add()
         self.add2 = P.Add()
         self.add3 = P.Add()
         self.add4 = P.Add()
@@ -453,8 +454,8 @@ class TopkRouter(Cell):
         self.reduce_mean2 = P.ReduceMean(keep_dims=False).shard(((dp, 1, 1),))
         self.reduce_mean3 = P.ReduceMean(keep_dims=False).shard(((dp, 1),))
         self.mul = P.Mul().shard(((dp, 1), (dp, 1)))
-        self.mul2 = P.Mul().shard(((1,), ()))
-        self.mul3 = P.Mul().shard(((1,), ()))
+        self.mul2 = P.Mul().shard(((), ()))
+        self.mul3 = P.Mul().shard(((), ()))
         self.mul4 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
         self.mul5 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
         self.mul6 = P.Mul().shard(((dp, 1), (dp, 1)))
@@ -465,6 +466,7 @@ class TopkRouter(Cell):
         self.div1 = P.RealDiv().shard(((dp, 1, 1), (dp, 1, 1)))
         self.div2 = P.RealDiv().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
         self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1)))
+        self.add1 = P.Add().shard(((dp, 1, 1), ()))
         self.add2 = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
         self.add3 = P.Add().shard(((dp, 1), (dp, 1)))
         self.add4 = P.Add().shard(((dp, 1, 1, 1), ()))
@@ -537,7 +539,7 @@ class TopkRouter(Cell):
         # expert_mask's shape: (dp_group, tokens_per_group, self.expert_dim)
         expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)
         # renormalize the rest prob to be of sum 1
-        router_prob_normal = self.div1(router_prob, self.add(self.reduce_sum_keep(router_prob, -1), 1e-9))
+        router_prob_normal = self.div1(router_prob, self.add1(self.reduce_sum_keep(router_prob, -1), 1e-9))

         # the balance loss is computed at each routing step
         loss += self._auxiliary_loss(expert_mask, router_prob_normal)
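Note (not part of the diff): a minimal sketch of the sharding pattern this commit switches to. In MindSpore, Primitive.shard() takes one strategy tuple per input, with one entry per tensor dimension; the commit drops the former (1,) strategy on scalar constant operands (e.g. the 1e-9 added during renormalization) and uses the empty tuple () instead, routing that case through the new self.add1. The dp value and shapes below are illustrative assumptions, not taken from the file.

    from mindspore.ops import operations as P

    dp = 2  # assumed data-parallel degree, for illustration only

    # Tensor-tensor elementwise op: one strategy per input, one number per dimension.
    mul = P.Mul().shard(((dp, 1), (dp, 1)))

    # Tensor-scalar op after this commit: the scalar operand has no dimensions,
    # so its strategy is the empty tuple () rather than (1,).
    add1 = P.Add().shard(((dp, 1, 1), ()))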
@@ -66,6 +66,20 @@ class NetWithLossFiveInputs(nn.Cell):
         return self.loss(predict)


+class NetWithLossMoe(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLossMoe, self).__init__()
+        self.network = network
+        self.add = P.Add().shard(((), ()))
+        self.reduce_mean = P.ReduceMean(keep_dims=False).shard(((1, 1),))
+
+    def construct(self, x1, x2, x3, x4, x5):
+        predict, _, _, moe_loss = self.network(x1, x2, x3, x4, x5)
+        predict = P.Reshape()(predict, (-1, 1))
+        predict = self.reduce_mean(predict)
+        return self.add(predict, moe_loss)
+
+
 def test_transformer_model():
     """
     Feature: Test Transformer+MoE, with All2All enabled.
@@ -91,7 +105,7 @@ def test_transformer_model():
     decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
     decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
     memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
-    net = NetWithLossFiveInputs(net)
+    net = NetWithLossMoe(net)
     params = net.trainable_params()
     optimizer = AdamWeightDecay(params)
     dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
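Note (not part of the diff): a hedged smoke-test sketch of the new wrapper, assuming the NetWithLossMoe class added above is in scope. The dummy network and shapes below are illustrative assumptions, not the test's real Transformer; the point is that the wrapped network's fourth output (the MoE balance loss) is folded into the scalar training loss.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    class _DummyMoeNet(nn.Cell):
        """Stand-in for the Transformer+MoE network: echoes its inputs and
        returns a fixed scalar as the MoE balance loss (fourth output)."""
        def __init__(self):
            super(_DummyMoeNet, self).__init__()
            self.moe_loss = Tensor(0.05, mstype.float32)

        def construct(self, x1, x2, x3, x4, x5):
            return x1, x2, x3, self.moe_loss

    net = NetWithLossMoe(_DummyMoeNet())
    x = Tensor(np.ones((2, 3)), mstype.float32)
    out = net(x, x, x, x, x)  # scalar: mean(predict) + moe_loss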