forked from mindspore-Ecosystem/mindspore
!35699 Fix MultiHeadAttention dtype Error
Merge pull request !35699 from huangxinjing/fix_attention_dtype
commit 2919e625e3
@@ -383,12 +383,15 @@ class _Linear(Cell):
         x = P.Reshape()(x, (-1, self.in_channels))
         if self.expert_flag:
             x = P.Reshape()(x, (self.outer_batch, self.expert_num, -1, self.in_channels))
+        ori_dtype = F.dtype(x)
         weight = self.cast(self.weight, self.dtype)
+        x = self.cast(x, self.dtype)
         x = self.matmul(x, weight)
         if self.has_bias:
             x = self.bias_add(x, self.cast(self.bias, self.dtype))
         if self.activation_flag:
             x = self.activation(x)
+        x = F.cast(x, ori_dtype)
         output = P.Reshape()(x, out_shape)
         return output
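The _Linear.construct change above is the core of the fix: remember the caller's dtype, cast the activation and the weight to the configured compute dtype for the matmul, then cast the result back so the output dtype matches the input. A minimal NumPy sketch of that round trip (illustrative only, not MindSpore code; dense_forward and its arguments are made-up names):

import numpy as np

def dense_forward(x, weight, bias, compute_dtype=np.float16):
    """Illustrative sketch: run the matmul in compute_dtype, return the caller's dtype."""
    ori_dtype = x.dtype                           # remember the original dtype (e.g. float32)
    xc = x.astype(compute_dtype)                  # cast activation to the compute dtype
    wc = weight.astype(compute_dtype)             # cast weight to the compute dtype
    out = xc @ wc + bias.astype(compute_dtype)    # linear layer runs in the compute dtype
    return out.astype(ori_dtype)                  # cast back before handing the result out

x = np.ones((4, 8), dtype=np.float32)
w = np.ones((8, 16), dtype=np.float32)
b = np.zeros(16, dtype=np.float32)
assert dense_forward(x, w, b).dtype == np.float32   # output dtype matches the input dtype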
@@ -877,7 +877,8 @@ class MultiHeadAttention(Cell):
             self.projection = _Linear(in_channels=hidden_size,
                                       out_channels=hidden_size,
                                       transpose_b=False,
-                                      param_init_type=param_init_type).to_float(compute_dtype)
+                                      compute_dtype=compute_dtype,
+                                      param_init_type=param_init_type)
             self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
                                   strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
                                                    (parallel_config.model_parallel, 1)))
@@ -910,15 +911,18 @@ class MultiHeadAttention(Cell):
             # Query
             self.dense1 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             # Key
             self.dense2 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             # Value
             self.dense3 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)

             self.dtype = compute_dtype
             self.softmax_dtype = softmax_compute_type
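In the constructor hunks, the external .to_float(compute_dtype) wrapper is replaced by an explicit compute_dtype argument to _Linear, so each dense layer owns its own casting and float32 works as a compute dtype just like float16. A rough Python stand-in for that constructor pattern (TinyLinear is hypothetical, not the MindSpore class; it only mirrors the split between param_init_type for parameter storage and compute_dtype for computation):

import numpy as np

class TinyLinear:
    """Hypothetical stand-in: the compute dtype is a constructor argument, not applied from outside."""
    def __init__(self, in_channels, out_channels, compute_dtype=np.float16, param_init_type=np.float32):
        # parameters are stored in param_init_type; the matmul runs in compute_dtype
        self.weight = np.ones((in_channels, out_channels), dtype=param_init_type)
        self.compute_dtype = compute_dtype

    def __call__(self, x):
        ori_dtype = x.dtype
        out = x.astype(self.compute_dtype) @ self.weight.astype(self.compute_dtype)
        return out.astype(ori_dtype)

# Either compute dtype works; nothing outside the layer assumes float16.
x = np.ones((2, 15), dtype=np.float32)
assert TinyLinear(15, 15, compute_dtype=np.float16)(x).dtype == np.float32
assert TinyLinear(15, 15, compute_dtype=np.float32)(x).dtype == np.float32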
@@ -972,7 +976,8 @@ class MultiHeadAttention(Cell):
             self.projection = _Linear(in_channels=hidden_size,
                                       out_channels=hidden_size,
                                       transpose_b=False,
-                                      param_init_type=param_init_type).to_float(compute_dtype)
+                                      compute_dtype=compute_dtype,
+                                      param_init_type=param_init_type)
             self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
                                   strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
                                                    (parallel_config.model_parallel, 1)))
@@ -1019,14 +1024,16 @@ class MultiHeadAttention(Cell):
             # Query
             self.dense1 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             self.dense1.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                               strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                              (parallel_config.model_parallel,)))
             # Key
             self.dense2 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             self.dense2.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                               strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                              (parallel_config.model_parallel,)))
@@ -1034,7 +1041,8 @@ class MultiHeadAttention(Cell):
             # Value
             self.dense3 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             self.dense3.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                               strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                              (parallel_config.model_parallel,)))
@@ -1066,7 +1074,10 @@ class MultiHeadAttention(Cell):
                                                                     key_tensor,
                                                                     value_tensor,
                                                                     attention_mask)
-
+        ori_dtype = F.dtype(query_tensor)
+        query_tensor = F.cast(query_tensor, self.dtype)
+        key_tensor = F.cast(key_tensor, self.dtype)
+        value_tensor = F.cast(value_tensor, self.dtype)
         # multi head attention: query, key, value are derived from the same inputs
         query = self.dense1(query_tensor)
         key = self.dense2(key_tensor)
@@ -1137,6 +1148,7 @@ class MultiHeadAttention(Cell):
         output = self.projection(attention)
         output = self.dropout(output)
         output = F.reshape(output, ori_shape)
+        output = F.cast(output, ori_dtype)
         return output, layer_present

     def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
@@ -2060,6 +2072,7 @@ class TransformerDecoderLayer(Cell):
         if encoder_output is not None:
             middle_output = self.cross_attention_layernorm(x)
             middle_output = F.cast(middle_output, self.dtype)
+            encoder_output = F.cast(encoder_output, self.dtype)
             cross_attn_output, cross_layer_present = self.cross_attention(middle_output, encoder_output,
                                                                           encoder_output,
                                                                           memory_mask, self.key_past,
@@ -217,6 +217,24 @@ def test_multihead_attention_wrong_batch():
     _cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)


+def test_multihead_attention_fp32_dtype():
+    """
+    Feature: Test MultiHeadAttention with float32 as compute dtype
+    Description: Test using float32 as computation for linear layer.
+    Expectation: No exception
+    """
+    model = MultiHeadAttention(hidden_size=15,
+                               src_seq_length=20,
+                               tgt_seq_length=20,
+                               compute_dtype=dtype.float32,
+                               batch_size=2,
+                               num_heads=3)
+    from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
+    attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float32)
+    _cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
+
+
 def test_feedforward_layer():
     model = FeedForward(hidden_size=15,
                         ffn_hidden_size=30,
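The new test compiles the layer once with float32 as the compute dtype. A parametrized variant could exercise both compute dtypes in one test; the sketch below is hypothetical, and the import paths are assumptions that may differ between MindSpore versions:

import numpy as np
import pytest
from mindspore import Tensor, dtype
# Assumed import locations; adjust to the MindSpore version under test.
from mindspore.common.api import _cell_graph_executor
from mindspore.nn.transformer import MultiHeadAttention


@pytest.mark.parametrize("compute_dtype", [dtype.float16, dtype.float32])
def test_multihead_attention_compute_dtype(compute_dtype):
    """Hypothetical variant of the new test: compile with either compute dtype."""
    model = MultiHeadAttention(hidden_size=15,
                               src_seq_length=20,
                               tgt_seq_length=20,
                               compute_dtype=compute_dtype,
                               batch_size=2,
                               num_heads=3)
    from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
    to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
    attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float32)
    _cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)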