!28774 Add overflow optimizer for Transformer

Merge pull request !28774 from huangxinjing/fx_transformer_overflow
i-robot 2022-01-12 03:17:58 +00:00 committed by Gitee
commit a47044716d
1 changed file with 4 additions and 5 deletions

@@ -857,7 +857,7 @@ class MultiHeadAttention(Cell):
                 ((parallel_config.data_parallel, 1, 1, 1),
                  (parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
             # Normalize factor for attention, sqrt(dk) as widely used
-            self.scale_factor = Tensor(math.sqrt(self.size_per_head))
+            self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
             self.use_past = use_past
             self.dropout = _Dropout(1 - hidden_dropout_rate)
             self.dropout.shard(((parallel_config.data_parallel, 1),))
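Why the first hunk changes the factor: with scale_factor = dk ** 0.25 (written as sqrt(sqrt(size_per_head))), dividing query and key each by the factor before the matmul applies the same overall 1 / sqrt(dk) normalization as the old post-matmul division, since (q / f) @ (k / f).T == (q @ k.T) / f ** 2. A minimal sketch of the equivalence, using plain NumPy as a stand-in for the MindSpore ops (shapes and values are illustrative assumptions, not from the commit):

import math
import numpy as np

dk = 64                                     # size_per_head (illustrative)
q = np.random.randn(2, 8, dk).astype(np.float32)
k = np.random.randn(2, 8, dk).astype(np.float32)

f = math.sqrt(math.sqrt(dk))                # the new scale_factor, dk ** 0.25
# Old scheme: scale the score after the matmul by sqrt(dk).
old = (q @ k.transpose(0, 2, 1)) / math.sqrt(dk)
# New scheme: scale query and key before the matmul by dk ** 0.25 each.
new = (q / f) @ (k / f).transpose(0, 2, 1)

assert np.allclose(old, new, atol=1e-5)     # same attention scores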
@@ -1086,11 +1086,10 @@ class MultiHeadAttention(Cell):
         """
         # Normalize query and key before MatMul, default off
         # Attention score [bs, num_heads, seq_length, seq_length]
+        factor = P.Cast()(self.scale_factor, P.DType()(query))
+        query = self.real_div(query, factor)
+        key = self.real_div(key, factor)
         score = self.batch_matmul(query, key)
-        # Normalize after query and key MatMul
-        score = self.real_div(
-            score,
-            P.Cast()(self.scale_factor, P.DType()(score)))
         ori_dtype = P.DType()(score)
         score = P.Cast()(score, self.softmax_dtype)
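Why the second hunk reorders the division: in float16 the raw q @ k.T accumulation can exceed the fp16 maximum (about 65504) and saturate to inf before the old post-matmul division ever runs; dividing both operands by dk ** 0.25 first keeps the products in range, which is the overflow fix the title refers to. A hedged NumPy sketch of the failure mode (magnitudes chosen only to trigger overflow; not the MindSpore kernels):

import numpy as np

dk = 64
q = np.full((1, dk), 40.0, dtype=np.float16)
k = np.full((1, dk), 40.0, dtype=np.float16)

# Old scheme: 40 * 40 * 64 = 102400 > 65504, so the fp16 score is already
# inf by the time the division by sqrt(dk) happens.
old = (q @ k.T) / np.float16(dk ** 0.5)
# New scheme: operands shrink to ~14.14 each, the sum stays near 12800.
f = np.float16(dk ** 0.25)
new = (q / f) @ (k / f).T

print(np.isinf(old).any(), np.isinf(new).any())   # True False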