From 2655ccdee1c1d2d9dbc1d6fca257c58a3b4ae259 Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Thu, 25 Nov 2021 14:49:29 +0800 Subject: [PATCH] using notequal instead of cast-to-bool for performance --- mindspore/parallel/nn/moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mindspore/parallel/nn/moe.py b/mindspore/parallel/nn/moe.py index 028d484c1bd..0b4e0c90a46 100644 --- a/mindspore/parallel/nn/moe.py +++ b/mindspore/parallel/nn/moe.py @@ -380,6 +380,7 @@ class SwitchRouter(Cell): self.mul7 = P.Mul().shard(((dp, 1), (dp, 1))) self.mul8 = P.Mul().shard(((dp, 1, 1), (dp, 1, 1))) self.mul9 = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) + self.not_equal = P.NotEqual().shard(((dp, 1, 1, 1), ())) self.cumsum = _CumSum(config=parallel_config) self.less = P.Less().shard(((dp, 1, 1), ())) @@ -449,5 +450,7 @@ class SwitchRouter(Cell): combine_tensor = self.mul9(self.expand2(combine_tensor, -1), self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity, self.on_value, self.off_value)) - dispatch_tensor = self.cast(combine_tensor, mstype.bool_) + # dispatch_tensor is of boolean type. NotEqual is used here instead of Cast because 'Cast to bool' has + # bad performance. + dispatch_tensor = self.not_equal(combine_tensor, 0.0) return dispatch_tensor, combine_tensor, loss