!35699 Fix MultiHeadAttention dtype Error

Merge pull request !35699 from huangxinjing/fix_attention_dtype
Author: i-robot · Date: 2022-06-20 11:19:32 +00:00 · Committed by: Gitee
Commit: 2919e625e3
3 changed files with 43 additions and 9 deletions
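
The change below stops hard-casting the attention's internal _Linear layers with .to_float() and instead passes compute_dtype through, so a non-default compute dtype such as float32 no longer triggers the reported dtype error. A minimal usage sketch of the configuration this targets, modeled on the unit test added at the bottom of this diff (assuming the mindspore.nn.transformer import path of this era; shapes and sizes are illustrative, not prescribed):

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.nn.transformer import MultiHeadAttention

# Attention cell configured to compute in float32 instead of the
# default float16 compute dtype.
model = MultiHeadAttention(hidden_size=15,
                           src_seq_length=20,
                           tgt_seq_length=20,
                           compute_dtype=mstype.float32,
                           batch_size=2,
                           num_heads=3)

from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
attention_mask = Tensor(np.ones((2, 20, 20)), mstype.float32)

# Per the added test, this combination should now compile and run
# without a dtype mismatch.
output, layer_present = model(from_tensor, to_tensor, to_tensor, attention_mask)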


@@ -383,12 +383,15 @@ class _Linear(Cell):
         x = P.Reshape()(x, (-1, self.in_channels))
         if self.expert_flag:
             x = P.Reshape()(x, (self.outer_batch, self.expert_num, -1, self.in_channels))
+        ori_dtype = F.dtype(x)
         weight = self.cast(self.weight, self.dtype)
+        x = self.cast(x, self.dtype)
         x = self.matmul(x, weight)
         if self.has_bias:
             x = self.bias_add(x, self.cast(self.bias, self.dtype))
         if self.activation_flag:
             x = self.activation(x)
+        x = F.cast(x, ori_dtype)
         output = P.Reshape()(x, out_shape)
         return output
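
In effect, _Linear.construct now brackets the dense computation with casts: it records the caller's dtype, casts the activation and parameters to the configured compute dtype, and casts the result back on the way out. A framework-free NumPy sketch of that cast-and-restore pattern (hypothetical helper, not the MindSpore source; transpose_b and expert routing are omitted):

import numpy as np

def linear_cast_and_restore(x, weight, bias, compute_dtype=np.float16):
    """Dense layer computed in compute_dtype, returned in the input dtype."""
    ori_dtype = x.dtype                   # remember the caller's dtype
    xc = x.astype(compute_dtype)          # cast activation to the compute dtype
    wc = weight.astype(compute_dtype)     # cast parameters to the compute dtype
    bc = bias.astype(compute_dtype)
    y = xc @ wc + bc                      # matmul + bias in the compute dtype
    return y.astype(ori_dtype)            # restore the original dtype

x = np.random.randn(4, 8).astype(np.float32)
w = np.random.randn(8, 16).astype(np.float32)
b = np.zeros(16, dtype=np.float32)
assert linear_cast_and_restore(x, w, b).dtype == np.float32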


@@ -877,7 +877,8 @@ class MultiHeadAttention(Cell):
             self.projection = _Linear(in_channels=hidden_size,
                                       out_channels=hidden_size,
                                       transpose_b=False,
-                                      param_init_type=param_init_type).to_float(compute_dtype)
+                                      compute_dtype=compute_dtype,
+                                      param_init_type=param_init_type)
             self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
                                   strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
                                                    (parallel_config.model_parallel, 1)))
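
The remaining hunks in this file repeat one substitution: the .to_float(compute_dtype) call, which makes the whole cell force-cast its float inputs to the compute dtype and emit its output in that dtype, is dropped in favor of an explicit compute_dtype argument so each _Linear performs the cast-and-restore shown above. A condensed before/after sketch of the construction (arguments abbreviated; the real call sites follow below):

# Before: blanket input cast via Cell.to_float(); the layer's output
# comes back in compute_dtype regardless of what the caller passed in.
projection = _Linear(in_channels=hidden_size,
                     out_channels=hidden_size,
                     transpose_b=False,
                     param_init_type=param_init_type).to_float(compute_dtype)

# After: _Linear is told the compute dtype and casts internally,
# returning the result in the caller's original dtype.
projection = _Linear(in_channels=hidden_size,
                     out_channels=hidden_size,
                     transpose_b=False,
                     compute_dtype=compute_dtype,
                     param_init_type=param_init_type)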
@@ -910,15 +911,18 @@ class MultiHeadAttention(Cell):
             # Query
             self.dense1 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             # Key
             self.dense2 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             # Value
             self.dense3 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             self.dtype = compute_dtype
             self.softmax_dtype = softmax_compute_type
@@ -972,7 +976,8 @@ class MultiHeadAttention(Cell):
             self.projection = _Linear(in_channels=hidden_size,
                                       out_channels=hidden_size,
                                       transpose_b=False,
-                                      param_init_type=param_init_type).to_float(compute_dtype)
+                                      compute_dtype=compute_dtype,
+                                      param_init_type=param_init_type)
             self.projection.shard(strategy_bias=((parallel_config.data_parallel, 1), (1,)),
                                   strategy_matmul=((parallel_config.data_parallel, parallel_config.model_parallel),
                                                    (parallel_config.model_parallel, 1)))
@@ -1019,14 +1024,16 @@ class MultiHeadAttention(Cell):
             # Query
             self.dense1 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             self.dense1.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                               strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                              (parallel_config.model_parallel,)))
             # Key
             self.dense2 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             self.dense2.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                               strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                              (parallel_config.model_parallel,)))
@@ -1034,7 +1041,8 @@ class MultiHeadAttention(Cell):
             # Value
             self.dense3 = _Linear(hidden_size,
                                   hidden_size,
-                                  param_init_type=param_init_type).to_float(compute_dtype)
+                                  compute_dtype=compute_dtype,
+                                  param_init_type=param_init_type)
             self.dense3.shard(strategy_matmul=((parallel_config.data_parallel, 1), (parallel_config.model_parallel, 1)),
                               strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
                                              (parallel_config.model_parallel,)))
@@ -1066,7 +1074,10 @@ class MultiHeadAttention(Cell):
                                                                         key_tensor,
                                                                         value_tensor,
                                                                         attention_mask)
+        ori_dtype = F.dtype(query_tensor)
+        query_tensor = F.cast(query_tensor, self.dtype)
+        key_tensor = F.cast(key_tensor, self.dtype)
+        value_tensor = F.cast(value_tensor, self.dtype)
         # multi head attention: query, key, value are derived from the same inputs
         query = self.dense1(query_tensor)
         key = self.dense2(key_tensor)
@@ -1137,6 +1148,7 @@ class MultiHeadAttention(Cell):
         output = self.projection(attention)
         output = self.dropout(output)
         output = F.reshape(output, ori_shape)
+        output = F.cast(output, ori_dtype)
         return output, layer_present
 
     def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
@@ -2060,6 +2072,7 @@ class TransformerDecoderLayer(Cell):
         if encoder_output is not None:
             middle_output = self.cross_attention_layernorm(x)
             middle_output = F.cast(middle_output, self.dtype)
+            encoder_output = F.cast(encoder_output, self.dtype)
             cross_attn_output, cross_layer_present = self.cross_attention(middle_output, encoder_output,
                                                                           encoder_output,
                                                                           memory_mask, self.key_past,


@@ -217,6 +217,24 @@ def test_multihead_attention_wrong_batch():
     _cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
 
 
+def test_multihead_attention_fp32_dtype():
+    """
+    Feature: Test MultiHeadAttention with float32 as compute dtype
+    Description: Test using float32 as computation for linear layer.
+    Expectation: No exception
+    """
+    model = MultiHeadAttention(hidden_size=15,
+                               src_seq_length=20,
+                               tgt_seq_length=20,
+                               compute_dtype=dtype.float32,
+                               batch_size=2,
+                               num_heads=3)
+    from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
+    attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float32)
+    _cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
+
+
 def test_feedforward_layer():
     model = FeedForward(hidden_size=15,
                         ffn_hidden_size=30,