fix an error of configuring parallel

This commit is contained in:
Xiaoda Zhang 2022-02-28 17:03:53 +08:00
parent acc5567559
commit 81e5abe580
1 changed files with 1 additions and 1 deletions

View File

@ -153,7 +153,7 @@ class TransformerNet(nn.Cell):
ffn_hidden_size=64, ffn_hidden_size=64,
moe_config=moe_config, moe_config=moe_config,
parallel_config=parallel_config) parallel_config=parallel_config)
self.loss = CrossEntropyLoss(parallel_config=config.moe_parallel_config) self.loss = CrossEntropyLoss(parallel_config=parallel_config.moe_parallel_config)
def construct(self, x1, x2, x3, x4, x5, y, mask): def construct(self, x1, x2, x3, x4, x5, y, mask):
predict, _, _ = self.network(x1, x2, x3, x4, x5) predict, _, _ = self.network(x1, x2, x3, x4, x5)