!26783 [Auto parallel] [MoE] Using NotEqual instead of Cast-to-bool for performance

Merge pull request !26783 from Xiaoda/112-change-moe-cast-to-notequal
This commit is contained in:
i-robot 2021-11-26 01:47:49 +00:00 committed by Gitee
commit 981278f6da
1 changed file with 4 additions and 1 deletion

View File

@@ -380,6 +380,7 @@ class SwitchRouter(Cell):
self.mul7 = P.Mul().shard(((dp, 1), (dp, 1)))
self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ()))
self.cumsum = _CumSum(config=parallel_config)
self.less = P.Less().shard(((dp, 1, 1), ()))
@@ -449,5 +450,7 @@ class SwitchRouter(Cell):
combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
self.on_value, self.off_value))
dispatch_tensor = self.cast(combine_tensor, mstype.bool_)
# dispatch_tensor is of boolean type. NotEqual is used here instead of Cast because 'Cast to bool' has
# poor performance.
dispatch_tensor = self.not_equal(combine_tensor, 0.0)
return dispatch_tensor, combine_tensor, loss