diff --git a/mindspore/python/mindspore/nn/transformer/transformer.py b/mindspore/python/mindspore/nn/transformer/transformer.py
index bb4550721c5..f8b4c817b47 100644
--- a/mindspore/python/mindspore/nn/transformer/transformer.py
+++ b/mindspore/python/mindspore/nn/transformer/transformer.py
@@ -857,7 +857,7 @@ class MultiHeadAttention(Cell):
             ((parallel_config.data_parallel, 1, 1, 1),
              (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
         # Normalize factor for attention, sqrt(dk) as widely used
-        self.scale_factor = Tensor(math.sqrt(self.size_per_head))
+        self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
         self.use_past = use_past
         self.dropout = _Dropout(1 - hidden_dropout_rate)
         self.dropout.shard(((parallel_config.data_parallel, 1),))
@@ -1086,11 +1086,10 @@ class MultiHeadAttention(Cell):
         """
         # Normalize query and key before MatMul, default off
         # Attention score [bs, num_heads, seq_length, seq_length]
+        factor = P.Cast()(self.scale_factor, P.DType()(query))
+        query = self.real_div(query, factor)
+        key = self.real_div(key, factor)
         score = self.batch_matmul(query, key)
-        # Normalize after query and key MatMul
-        score = self.real_div(
-            score,
-            P.Cast()(self.scale_factor, P.DType()(score)))
         ori_dtype = P.DType()(score)
         score = P.Cast()(score, self.softmax_dtype)
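
Note on the change above: the patch replaces post-matmul scaling of the attention score by sqrt(dk) with pre-matmul scaling of query and key by sqrt(sqrt(dk)) each. The two orderings are mathematically equivalent, since (q / dk**0.25) @ (k / dk**0.25)^T = (q @ k^T) / sqrt(dk), and dividing before the matmul keeps intermediate magnitudes smaller, which is typically the point when the matmul runs in reduced precision such as float16 (the motivation is not stated in the hunk itself). The standalone NumPy sketch below checks the equivalence; the head dimension of 64 and the tensor shapes are illustrative assumptions, not values from the patch.

# Standalone NumPy sketch (not MindSpore code): checks that scaling query and
# key by dk ** 0.25 before the matmul reproduces the usual score / sqrt(dk).
import math
import numpy as np

size_per_head = 64                                   # assumed head dimension
rng = np.random.default_rng(0)
# [batch, num_heads, seq_length, size_per_head] and its transposed counterpart
query = rng.standard_normal((2, 8, 16, size_per_head))
key_t = rng.standard_normal((2, 8, size_per_head, 16))

# Old ordering: scale the attention score after the batched matmul.
score_after = (query @ key_t) / math.sqrt(size_per_head)

# New ordering: divide query and key each by sqrt(sqrt(dk)) before the matmul,
# so the product is still divided by sqrt(dk) overall.
factor = math.sqrt(math.sqrt(size_per_head))
score_before = (query / factor) @ (key_t / factor)

assert np.allclose(score_after, score_before)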