forked from mindspore-Ecosystem/mindspore
!28774 Add overflow optimizer for Transformer
Merge pull request !28774 from huangxinjing/fx_transformer_overflow
This commit is contained in:
commit
a47044716d
|
@ -857,7 +857,7 @@ class MultiHeadAttention(Cell):
|
|||
((parallel_config.data_parallel, 1, 1, 1),
|
||||
(parallel_config.data_parallel, parallel_config.model_parallel, 1, 1)))
|
||||
# Normalize factor for attention, sqrt(dk) as widely used
|
||||
self.scale_factor = Tensor(math.sqrt(self.size_per_head))
|
||||
self.scale_factor = Tensor(math.sqrt(math.sqrt(self.size_per_head)))
|
||||
self.use_past = use_past
|
||||
self.dropout = _Dropout(1 - hidden_dropout_rate)
|
||||
self.dropout.shard(((parallel_config.data_parallel, 1),))
|
||||
|
@ -1086,11 +1086,10 @@ class MultiHeadAttention(Cell):
|
|||
"""
|
||||
# Normalize query and key before MatMul, default off
|
||||
# Attention score [bs, num_heads, seq_length, seq_length]
|
||||
factor = P.Cast()(self.scale_factor, P.DType()(query))
|
||||
query = self.real_div(query, factor)
|
||||
key = self.real_div(key, factor)
|
||||
score = self.batch_matmul(query, key)
|
||||
# Normalize after query and key MatMul
|
||||
score = self.real_div(
|
||||
score,
|
||||
P.Cast()(self.scale_factor, P.DType()(score)))
|
||||
|
||||
ori_dtype = P.DType()(score)
|
||||
score = P.Cast()(score, self.softmax_dtype)
|
||||
|
|
Loading…
Reference in New Issue