forked from mindspore-Ecosystem/mindspore
!23108 Fix spelling error for transformer
Merge pull request !23108 from huangxinjing/fix_spell_error
This commit is contained in: commit 0abff9ad65
@@ -184,7 +184,7 @@ class _LayerNorm(Cell):
  Args:
  strategy (tuple): The strategy for the dropout. Should be the same shape as the inputs.
  Examples:
- >>> net = nn.parallel.transformer.LayerNorm(normalized_shape=(1024, 10))
+ >>> net = mindspore.parallel.nn.transformer.LayerNorm(normalized_shape=(1024, 10))
  >>> net.shard(((10, 2, 1),))
  """
  self.mean.shard(strategy)
@@ -30,6 +30,8 @@ from mindspore.ops.primitive import constexpr
  from mindspore.nn.cell import Cell
  from mindspore._checkparam import Validator
  from mindspore import log as logger
+ from mindspore.parallel._utils import _get_parallel_mode
+ from mindspore.context import ParallelMode
  from .layers import _LayerNorm, _Linear, _check_input_shape, \
  _args_type_validator_check, _valid_type_checks, _valid_value_checks, \
  _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, Router
@@ -284,7 +286,11 @@ class FeedForward(Cell):
  will project the input dimension from hidden_size to ffn_hidden_size, the second linear will project the
  dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension,
  the second linear is sharded on the output dimension. The overview process can be
- `DROPOUT(FFN(FFN(x)))`
+
+ .. math::
+ Dropout((xW_1+b_1)W_2 + b_2))
+
+ where the W_1, W_2, b_1 and b_2 are trainable parameters.

  Args:
  hidden_size (int): The dimension of the inputs.
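The formula added to the FeedForward docstring corresponds to two dense projections with a dropout on the result (the module's hidden_act, configured separately, sits between the two projections and is omitted here). A minimal NumPy sketch of Dropout((xW_1 + b_1)W_2 + b_2), with made-up shapes for illustration:

import numpy as np

def feed_forward_sketch(x, w1, b1, w2, b2, keep_prob=0.9):
    h = x @ w1 + b1                                     # hidden_size -> ffn_hidden_size
    out = h @ w2 + b2                                   # ffn_hidden_size -> hidden_size
    mask = (np.random.rand(*out.shape) < keep_prob).astype(out.dtype)
    return out * mask / keep_prob                       # inverted dropout on the output

x = np.ones((2, 20, 15), dtype=np.float32)              # (batch, seq_length, hidden_size)
w1 = np.zeros((15, 64), np.float32); b1 = np.zeros(64, np.float32)
w2 = np.zeros((64, 15), np.float32); b2 = np.zeros(15, np.float32)
print(feed_forward_sketch(x, w1, b1, w2, b2).shape)      # (2, 20, 15)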
@@ -308,7 +314,7 @@ class FeedForward(Cell):

  Raises:
  ValueError: `hidden_act` is not a string.
- ValueError: `parallel_config` is not a subclass of OpParallelConfig.
+ TypeError: `parallel_config` is not a subclass of OpParallelConfig.
  ValueError: `ffn_hidden_size` is not a multiple of the model parallel way.
  ValueError: `hidden_size` is not a multiple of the model parallel way.

@@ -343,12 +349,12 @@ class FeedForward(Cell):
  dp = parallel_config.data_parallel
  mp = parallel_config.model_parallel
  if ffn_hidden_size % mp != 0:
- raise ValueError("ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
+ raise ValueError(f"ffn_hidden_size {ffn_hidden_size} should be a multiple of the model parallel way {mp}")
  if hidden_size % mp != 0:
- raise ValueError("hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
+ raise ValueError(f"hidden_size {hidden_size} should be a multiple of the model parallel way {mp}")
  if dropout_rate < 0 or dropout_rate >= 1:
- raise ValueError("dropout_rate probability should be a number in range [0, 1.0), "
+ raise ValueError(f"dropout_rate probability should be a number in range [0, 1.0), "
- "but got {}".format(dropout_rate))
+ "but got {dropout_rate}")
  input_size = hidden_size
  output_size = ffn_hidden_size
  # Here, 'ep' stands for expert parallel number, which is equal to data parallel number.
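These fixes add the missing f prefix so that {ffn_hidden_size}, {hidden_size} and {mp} are actually interpolated into the error messages. As a reminder of the Python behaviour involved: only string literals carrying the f prefix are formatted, and in an implicitly concatenated message every piece that contains a placeholder needs its own prefix (the names below are illustrative):

mp = 8
ffn_hidden_size = 30
print("should be a multiple of {mp}")       # braces are printed literally
print(f"should be a multiple of {mp}")      # prints: should be a multiple of 8
# With implicit concatenation, a continuation string without its own f prefix
# is kept verbatim, so the second placeholder below is NOT interpolated:
msg = (f"ffn_hidden_size {ffn_hidden_size} is invalid, "
       "but got {ffn_hidden_size}")
print(msg)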
@@ -360,6 +366,7 @@ class FeedForward(Cell):
  transpose_b=False,
  expert_num=expert_num,
  param_init_type=param_init_type)
+
  if expert_num > 1:
  self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
  strategy_bias=((ep, 1, mp), (mp,)),
@@ -368,7 +375,7 @@ class FeedForward(Cell):
  self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
  strategy_bias=((dp, mp), (mp,)),
  strategy_activation=((dp, 1, mp),))
- # Project back to embedding_size
+ # Project back to hidden_size
  self.projection = _Linear(in_channels=output_size,
  out_channels=input_size,
  transpose_b=False,
@@ -515,6 +522,7 @@ class MoE(Cell):
  aux_loss = self.mul(self.aux_loss_factor, aux_loss)
  return combined_output, aux_loss
+

  class AttentionMask(Cell):
  r"""
  Get the Lower triangular matrix from the input mask. The input mask is a 2D tensor (batch_size, seq_length)
@@ -535,14 +543,14 @@ class AttentionMask(Cell):
  Raises:
  TypeError: `seq_length` is not a int.
  ValueError: `seq_length` is not a positive value.
- ValueError: `parallel_config` is not a subclass of OpParallelConfig.
+ TypeError: `parallel_config` is not a subclass of OpParallelConfig.

  Supported Platforms:
  ``Ascend`` ``GPU``

  Examples:
  >>> mask = mindspore.parallel.nn.AttentionMask(seq_length=4)
- >>> mask_array = np.array([[1, 1, 1, 0]], np.int32)
+ >>> mask_array = np.array([[1, 1, 1, 0]], np.float32)
  >>> inputs = Tensor(mask_array)
  >>> res = mask(inputs)
  >>> print(res)
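The example above now passes a float32 mask. Conceptually, AttentionMask combines the (batch_size, seq_length) padding mask with a lower triangular matrix to build a (batch_size, seq_length, seq_length) causal mask; a rough NumPy sketch of that computation (an illustration, not the MindSpore kernel):

import numpy as np

def attention_mask_sketch(input_mask):
    # input_mask: (batch_size, seq_length), 1 for valid tokens, 0 for padding
    seq_length = input_mask.shape[1]
    # Pairwise validity: (batch, seq, seq), 1 where both positions are real tokens
    pad = input_mask[:, :, None] * input_mask[:, None, :]
    # Lower triangular part enforces causality (no attention to future positions)
    tril = np.tril(np.ones((seq_length, seq_length), dtype=input_mask.dtype))
    return pad * tril

mask = attention_mask_sketch(np.array([[1, 1, 1, 0]], np.float32))
print(mask.shape)   # (1, 4, 4)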
@@ -617,7 +625,7 @@ class VocabEmbedding(Cell):
  parallel_config.model_parallel
  ValueError: `vocab_size` is not a positive value.
  ValueError: `embedding_size` is not a positive value.
- ValueError: `parallel_config` is not a subclass of OpParallelConfig.
+ TypeError: `parallel_config` is not a subclass of OpParallelConfig.

  Supported Platforms:
  ``Ascend`` ``GPU``
@@ -661,9 +669,17 @@ class VocabEmbedding(Cell):


  class MultiHeadAttention(Cell):
- """
+ r"""
  This is an implementation of multihead attention in the paper `Attention is all you need
- <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
+ <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with source length, and the
+ key and value vector with target length, the attention will be performered as the following
+
+ .. math::
+ MultiHeadAttention(query, key, vector) = Concat(head_1, \dots, head_h)W^O
+
+ where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias.
+
+ if query, key and value tensor is same, then it will be self attention.

  Args:
  batch_size(int): The batch size of the input tensor.
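As a companion to the formula added above, a compact NumPy sketch of multi-head scaled dot-product attention; shapes are illustrative, and the projection matrices W_i^Q, W_i^K, W_i^V, W^O, bias terms, masking and dropout of the real layer are omitted:

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention_sketch(query, key, value, num_heads):
    batch, seq_q, hidden = query.shape
    size_per_head = hidden // num_heads

    def split_heads(x):
        # (batch, seq, hidden) -> (batch, num_heads, seq, size_per_head)
        return x.reshape(batch, x.shape[1], num_heads, size_per_head).transpose(0, 2, 1, 3)

    q, k, v = split_heads(query), split_heads(key), split_heads(value)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(size_per_head)
    context = softmax(scores) @ v                   # (batch, num_heads, seq_q, size_per_head)
    # Concat(head_1, ..., head_h): merge heads back to (batch, seq_q, hidden)
    return context.transpose(0, 2, 1, 3).reshape(batch, seq_q, hidden)

q = k = v = np.ones((2, 20, 16), np.float32)
print(multi_head_attention_sketch(q, k, v, num_heads=4).shape)   # (2, 20, 16)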
@@ -714,7 +730,7 @@ class MultiHeadAttention(Cell):
  ... num_heads=3)
  >>> from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
  >>> to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float16)
- >>> attention_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
+ >>> attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
  >>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask)
  >>> print(attn_out.shape)
  (2, 20, 15)
@@ -731,7 +747,7 @@ class MultiHeadAttention(Cell):
  tgt_seq_length=Validator.check_positive_int,
  attention_dropout_rate=Validator.check_non_negative_float,
  hidden_dropout_rate=Validator.check_non_negative_float,
- softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+ softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "MultiHeadAttention"),
  param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
  "MultiHeadAttention"),
@@ -746,7 +762,7 @@ class MultiHeadAttention(Cell):
  hidden_dropout_rate=0.1,
  attention_dropout_rate=0.1,
  compute_dtype=mstype.float16,
- softmax_comptue_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
  param_init_type=mstype.float32,
  use_past=False,
  parallel_config=default_dpmp_config):
@@ -757,11 +773,11 @@ class MultiHeadAttention(Cell):
  self.hidden_size = hidden_size
  self.batch_size = batch_size
  if hidden_dropout_rate < 0 or hidden_dropout_rate >= 1:
- raise ValueError("hidden_dropout_rate probability should be a number in range [0, 1.0), "
+ raise ValueError(f"hidden_dropout_rate probability should be a number in range [0, 1.0), "
- "but got {}".format(hidden_dropout_rate))
+ "but got {hidden_dropout_rate}")
  if attention_dropout_rate < 0 or attention_dropout_rate >= 1:
- raise ValueError("attention_dropout_rate probability should be a number in range [0, 1.0), "
+ raise ValueError(f"attention_dropout_rate probability should be a number in range [0, 1.0), "
- "but got {}".format(attention_dropout_rate))
+ "but got {attention_dropout_rate}")
  if hidden_size % num_heads != 0:
  raise ValueError(f"The hidden size {hidden_size} should be a multiple of num_heads {num_heads}")
  if num_heads % parallel_config.model_parallel != 0:
@@ -837,7 +853,7 @@ class MultiHeadAttention(Cell):
  strategy_bias=((parallel_config.data_parallel, parallel_config.model_parallel),
  (parallel_config.model_parallel,)))
  self.dtype = compute_dtype
- self.softmax_dtype = softmax_comptue_type
+ self.softmax_dtype = softmax_compute_type
  if self.use_past:
  # operators used for state reuse
  seq_range = np.arange(src_seq_length).reshape(1, 1, -1)
@@ -933,11 +949,10 @@ class MultiHeadAttention(Cell):

  layer_present = (key_present, value_present)
  # multi head attention considering attention mask
+ # [bs, seq_length, hidden_size]
  attention = self._attn(query, key, value, attention_mask)
- # [bs, seq_length, embedding_size]
- attention_merge = self._merge_heads(attention)
  # Output
- output = self.projection(attention_merge)
+ output = self.projection(attention)
  output = self.dropout(output)
  return output, layer_present

@@ -1038,7 +1053,8 @@ class MultiHeadAttention(Cell):
  attention_probs = self.prob_dropout(attention_probs)
  # Weighted sum output [bs, num_heads, seq_length, size_per_head]
  weighted_values = self.batch_matmul(attention_probs, value)
- return weighted_values
+ attention_merge = self._merge_heads(weighted_values)
+ return attention_merge


  class TransformerEncoderLayer(Cell):
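Together with the previous hunk, this moves the head-merging step from construct into _attn, so _attn already returns a (batch, seq_length, hidden_size) tensor and the projection consumes it directly. The merge itself is essentially a transpose followed by a reshape, roughly:

import numpy as np

def merge_heads_sketch(x):
    # x: (batch, num_heads, seq_length, size_per_head)
    batch, num_heads, seq_length, size_per_head = x.shape
    x = x.transpose(0, 2, 1, 3)                      # (batch, seq, heads, size_per_head)
    return x.reshape(batch, seq_length, num_heads * size_per_head)

weighted_values = np.ones((2, 4, 20, 8), np.float32)
print(merge_heads_sketch(weighted_values).shape)      # (2, 20, 32)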
@@ -1060,7 +1076,7 @@ class TransformerEncoderLayer(Cell):
  'hsigmoid', 'logsigmoid' and so on. Default: gelu.
  layernorm_compute_type(dtype.Number): The computation type of the layernorm.
  Can be dtype.float32 or dtype.float16. Default dtype.float16.
- softmax_comptue_type(dtype.Number): The computation type of the softmax in the attention.
+ softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
  Can be dtype.float32 or dtype.float16. Default mstype.float16.
  param_init_type(dtype.Number): The parameter initialization type of the module.
  Can be dtype.float32 or dtype.float16. Default dtype.float32.
@@ -1115,7 +1131,7 @@ class TransformerEncoderLayer(Cell):
  post_layernorm_residual=Validator.check_bool,
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerEncoderLayer"),
- softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+ softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerEncoderLayer"),
  param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerEncoderLayer"),
@@ -1132,7 +1148,7 @@ class TransformerEncoderLayer(Cell):
  hidden_dropout_rate=0.1,
  post_layernorm_residual=False,
  layernorm_compute_type=mstype.float32,
- softmax_comptue_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
  param_init_type=mstype.float32,
  hidden_act='gelu',
  use_past=False,
@@ -1142,8 +1158,16 @@ class TransformerEncoderLayer(Cell):
  _check_config(parallel_config)
  if num_heads % parallel_config.model_parallel != 0:
  raise ValueError(
- f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel},"
+ f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}, "
  f"but found {num_heads}")
+ if hidden_size % parallel_config.model_parallel != 0:
+ raise ValueError(
+ f"hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, "
+ f"but found {hidden_size}")
+ if ffn_hidden_size % parallel_config.model_parallel != 0:
+ raise ValueError(
+ f"ffn_hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, "
+ f"but found {ffn_hidden_size}")
  self.use_past = use_past
  self.seq_length = seq_length
  self.hidden_size = hidden_size
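The new checks surface the sharding constraint early: under model parallelism the weights are split across model_parallel devices, so num_heads, hidden_size and ffn_hidden_size must each divide evenly. The per-device partition arithmetic, with made-up numbers:

hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
model_parallel = 8
for name, value in [("num_heads", num_heads),
                    ("hidden_size", hidden_size),
                    ("ffn_hidden_size", ffn_hidden_size)]:
    if value % model_parallel != 0:
        raise ValueError(f"{name} must be divisible by the model parallel way "
                         f"{model_parallel}, but found {value}")
print(hidden_size // model_parallel)   # 128: the slice of the weight held by each device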
@@ -1160,7 +1184,7 @@ class TransformerEncoderLayer(Cell):
  num_heads=num_heads,
  hidden_dropout_rate=hidden_dropout_rate,
  attention_dropout_rate=attention_dropout_rate,
- softmax_comptue_type=softmax_comptue_type,
+ softmax_compute_type=softmax_compute_type,
  param_init_type=param_init_type,
  use_past=use_past,
  parallel_config=parallel_config)
@@ -1298,7 +1322,7 @@ class TransformerDecoderLayer(Cell):
  'hsigmoid', 'logsigmoid' and so on. Default: gelu.
  layernorm_compute_type(dtype.Number): The computation type of the layernorm.
  Can be dtype.float32 or dtype.float16. Default dtype.float16.
- softmax_comptue_type(dtype.Number): The computation type of the softmax in the attention.
+ softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
  Can be dtype.float32 or dtype.float16. Default mstype.float16.
  param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16.
  Default dtype.float32.
@@ -1337,8 +1361,8 @@ class TransformerDecoderLayer(Cell):
  ... src_seq_length=20, tgt_seq_length=10)
  >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
  >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
- >>> decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
+ >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
- >>> memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
+ >>> memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
  >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
  >>> print(output.shape)
  (2, 10, 64)
@@ -1364,7 +1388,7 @@ class TransformerDecoderLayer(Cell):
  post_layernorm_residual=Validator.check_bool,
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerDecoderLayer"),
- softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+ softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerDecoderLayer"),
  param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerDecoderLayer"),
@@ -1382,16 +1406,28 @@ class TransformerDecoderLayer(Cell):
  post_layernorm_residual=False,
  use_past=False,
  layernorm_compute_type=mstype.float32,
- softmax_comptue_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
  param_init_type=mstype.float32,
  hidden_act='gelu',
  moe_config=default_moe_config,
  parallel_config=default_dpmp_config):
  super(TransformerDecoderLayer, self).__init__()
  _check_config(parallel_config)
+ if num_heads % parallel_config.model_parallel != 0:
+ raise ValueError(
+ f"num heads must be divisibled by the model parallel way {parallel_config.model_parallel}, "
+ f"but found {num_heads}")
+ if hidden_size % parallel_config.model_parallel != 0:
+ raise ValueError(
+ f"hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, "
+ f"but found {hidden_size}")
+ if ffn_hidden_size % parallel_config.model_parallel != 0:
+ raise ValueError(
+ f"ffn_hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, "
+ f"but found {ffn_hidden_size}")
  self.batch_size = batch_size
  self.use_past = use_past
- self.softmax_comptue_type = softmax_comptue_type
+ self.softmax_compute_type = softmax_compute_type

  self.src_seq_length = src_seq_length
  self.tgt_seq_length = tgt_seq_length
@@ -1411,7 +1447,7 @@ class TransformerDecoderLayer(Cell):
  hidden_dropout_rate=hidden_dropout_rate,
  attention_dropout_rate=attention_dropout_rate,
  use_past=use_past,
- softmax_comptue_type=softmax_comptue_type,
+ softmax_compute_type=softmax_compute_type,
  param_init_type=param_init_type,
  parallel_config=parallel_config)
  # Cross attention with the output of encoder as memory tensor
@@ -1422,7 +1458,7 @@ class TransformerDecoderLayer(Cell):
  tgt_seq_length=src_seq_length,
  hidden_dropout_rate=hidden_dropout_rate,
  attention_dropout_rate=attention_dropout_rate,
- softmax_comptue_type=softmax_comptue_type,
+ softmax_compute_type=softmax_compute_type,
  use_past=use_past,
  param_init_type=param_init_type,
  parallel_config=parallel_config)
@@ -1614,7 +1650,8 @@ def _get_lambda_func(total_layer=None):

  class TransformerEncoder(Cell):
  r"""
- Transformer Encoder module with multi-layer stacled of `TransformerEncoderLayer`.
+ Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead self
+ attention and feedforward layer.

  Args:
  batch_size(int): The batch size of the input tensor.
@@ -1631,7 +1668,7 @@ class TransformerEncoder(Cell):
  'hsigmoid', 'logsigmoid' and so on. Default: gelu.
  layernorm_compute_type(dtype.Number): The computation type of the layernorm.
  Can be dtype.float32 or dtype.float16. Default dtype.float16.
- softmax_comptue_type(dtype.Number): The computation type of the softmax in the attention.
+ softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
  Can be dtype.float32 or dtype.float16. Default mstype.float16.
  param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16.
  Default dtype.float32.
@@ -1697,7 +1734,7 @@ class TransformerEncoder(Cell):
  post_layernorm_residual=Validator.check_bool,
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerEncoder"),
- softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+ softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerEncoder"),
  param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerEncoder"),
@@ -1716,7 +1753,7 @@ class TransformerEncoder(Cell):
  hidden_act='gelu',
  post_layernorm_residual=False,
  layernorm_compute_type=mstype.float32,
- softmax_comptue_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
  param_init_type=mstype.float32,
  lambda_func=None,
  offset=0,
@@ -1729,6 +1766,8 @@ class TransformerEncoder(Cell):
  self.use_moe = (moe_config.expert_num > 1)
  self.add = P.TensorAdd().shard(((), ()))
  self.aux_loss = Tensor(0.0, mstype.float32)
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
+ raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
  self.num_layers = num_layers
  self.blocks = nn.CellList()
  for i in range(num_layers):
@@ -1739,7 +1778,7 @@ class TransformerEncoder(Cell):
  attention_dropout_rate=attention_dropout_rate,
  hidden_dropout_rate=hidden_dropout_rate,
  layernorm_compute_type=layernorm_compute_type,
- softmax_comptue_type=softmax_comptue_type,
+ softmax_compute_type=softmax_compute_type,
  num_heads=num_heads,
  hidden_act=hidden_act,
  post_layernorm_residual=post_layernorm_residual,
@@ -1780,7 +1819,8 @@ class TransformerEncoder(Cell):

  class TransformerDecoder(Cell):
  r"""
- Transformer Decoder module with multi-layer stacled of `TransformerDecoderLayer`.
+ Transformer Decoder module with multi-layer stacked of `TransformerDecoderLayer`, including multihead self
+ attention, cross attention and feedforward layer.

  Args:
  batch_size(int): The batch size of the input tensor.
@@ -1798,7 +1838,7 @@ class TransformerDecoder(Cell):
  'hsigmoid', 'logsigmoid' and so on. Default: gelu.
  layernorm_compute_type(dtype.Number): The computation type of the layernorm.
  Can be dtype.float32 or dtype.float16. Default dtype.float16.
- softmax_comptue_type(dtype.Number): The computation type of the softmax in the attention.
+ softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
  Can be dtype.float32 or dtype.float16. Default mstype.float16.
  param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16.
  Default dtype.float32.
@@ -1826,6 +1866,7 @@ class TransformerDecoder(Cell):
  past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
  - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index.
  Used for incremental prediction when the use_past is True. Default None.

  Outputs:
  Tuple, a tuple contains(`output`, `layer_present`)
+
@@ -1844,8 +1885,8 @@ class TransformerDecoder(Cell):
  ... num_heads=2, src_seq_length=20, tgt_seq_length=10)
  >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
  >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
- >>> decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
+ >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
- >>> memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
+ >>> memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
  >>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
  >>> print(output.shape)
  (2, 10, 64)
@@ -1876,7 +1917,7 @@ class TransformerDecoder(Cell):
  post_layernorm_residual=Validator.check_bool,
  layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerDecoder"),
- softmax_comptue_type=_valid_value_checks([mstype.float32, mstype.float16],
+ softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerDecoder"),
  param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
  "TransformerDecoder"),
@@ -1895,7 +1936,7 @@ class TransformerDecoder(Cell):
  hidden_dropout_rate=0.1,
  post_layernorm_residual=False,
  layernorm_compute_type=mstype.float32,
- softmax_comptue_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
  param_init_type=mstype.float32,
  hidden_act='gelu',
  lambda_func=None,
@@ -1908,6 +1949,8 @@ class TransformerDecoder(Cell):

  self.add = P.TensorAdd().shard(((), ()))
  self.aux_loss = Tensor(0.0, mstype.float32)
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
+ raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
  self.num_layers = num_layers
  self.blocks = nn.CellList()
  self.use_moe = (moe_config.expert_num > 1)
@@ -1921,7 +1964,7 @@ class TransformerDecoder(Cell):
  hidden_dropout_rate=hidden_dropout_rate,
  num_heads=num_heads,
  layernorm_compute_type=layernorm_compute_type,
- softmax_comptue_type=softmax_comptue_type,
+ softmax_compute_type=softmax_compute_type,
  hidden_act=hidden_act,
  use_past=use_past,
  param_init_type=param_init_type,
@@ -1969,7 +2012,7 @@
  class Transformer(Cell):
  r"""
  Transformer module including encoder and decoder. The difference with the original implements is the module use
- the residual addition before the layernormalization. And the default hidden act is `gelu`.
+ the residual addition before the layer normalization. And the default hidden act is `gelu`.
  The detials can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.

  Note:
@@ -2037,13 +2080,13 @@ class Transformer(Cell):
  ``Ascend`` ``GPU``

  Examples:
- >>> model = Transformer(encoder_layers=1, decoder_layers=2, hidden_size=64, ffn_hidden_size=64,
+ >>> model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64, ffn_hidden_size=64,
  ... src_seq_length=20, tgt_seq_length=10)
  >>> encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
- >>> encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
+ >>> encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
  >>> decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
- >>> decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
+ >>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
- >>> memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
+ >>> memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
  >>> output, en_past, de_past = model(encoder_input_value, encoder_input_mask, decoder_input_value,
  ... decoder_input_mask, memory_mask)
  >>> print(output.shape)
@@ -2079,11 +2122,11 @@ class Transformer(Cell):
  hidden_dropout_rate=Validator.check_non_negative_float,
  hidden_act=_valid_type_checks([str], "Transformer"),
  post_layernorm_residual=Validator.check_bool,
- layernorm_compute_type=_valid_type_checks([mstype.float32, mstype.float16],
+ layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "Transformer"),
- softmax_comptue_type=_valid_type_checks([mstype.float32, mstype.float16],
+ softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
  "Transformer"),
- param_init_type=_valid_type_checks([mstype.float32, mstype.float16], "Transformer"),
+ param_init_type=_valid_value_checks([mstype.float32, mstype.float16], "Transformer"),
  parallel_config=_valid_type_checks([TransformerOpParallelConfig], "Transformer"),
  use_past=Validator.check_bool)
  def __init__(self,
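Switching the dtype arguments from _valid_type_checks to _valid_value_checks matters because mstype.float32 and mstype.float16 are dtype values to compare against, not Python classes to isinstance-check. Roughly the intended distinction, as a hypothetical illustration rather than the MindSpore helpers themselves:

def check_value_in(name, value, allowed):
    # value check: the argument must be one of a fixed set of values
    if value not in allowed:
        raise TypeError(f"{name} must be one of {allowed}, but got {value}")

def check_type_in(name, value, allowed_types):
    # type check: the argument must be an instance of one of the allowed types
    if not isinstance(value, tuple(allowed_types)):
        raise TypeError(f"{name} must be an instance of {allowed_types}, "
                        f"but got {type(value)}")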
@@ -2100,7 +2143,7 @@ class Transformer(Cell):
  hidden_act='gelu',
  post_layernorm_residual=False,
  layernorm_compute_type=mstype.float32,
- softmax_comptue_type=mstype.float32,
+ softmax_compute_type=mstype.float32,
  param_init_type=mstype.float32,
  lambda_func=None,
  use_past=False,
@@ -2118,7 +2161,9 @@ class Transformer(Cell):
  f"layer {decoder_layers}, please use TransformerDecoder")
  if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
  raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
- # The shard setting of Transformer is set within the class StackedTransformer
+ if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
+ raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
+ # The shard setting of Transformer is set within the TransformerEncoderLayer
  if not lambda_func:
  lambda_func = _get_lambda_func(total_layer=encoder_layers + decoder_layers)

@@ -2136,7 +2181,7 @@ class Transformer(Cell):
  hidden_dropout_rate=hidden_dropout_rate,
  hidden_act=hidden_act,
  layernorm_compute_type=layernorm_compute_type,
- softmax_comptue_type=softmax_comptue_type,
+ softmax_compute_type=softmax_compute_type,
  post_layernorm_residual=post_layernorm_residual,
  param_init_type=param_init_type,
  lambda_func=lambda_func,
@@ -2162,7 +2207,7 @@ class Transformer(Cell):
  hidden_act=hidden_act,
  post_layernorm_residual=post_layernorm_residual,
  layernorm_compute_type=layernorm_compute_type,
- softmax_comptue_type=softmax_comptue_type,
+ softmax_compute_type=softmax_compute_type,
  lambda_func=lambda_func,
  use_past=use_past,
  param_init_type=param_init_type,
@@ -398,7 +398,7 @@ class PanGUAlphaWithLoss(Cell):
  self.not_equal = P.NotEqual().shard(((dp, 1), ()))
  self.batch_size = config.batch_size
  self.len = config.seq_length
- self.expand = P.ExpandDims().shard(((dp, 1, 1),))
+ self.slice2 = P.StridedSlice().shard(((dp, 1, 1),))
  self.micro_batch_step = 1
  if config.parallel_config.pipeline_stage > 1:
  self.micro_batch_step = config.parallel_config.micro_batch_num
@@ -407,13 +407,14 @@ class PanGUAlphaWithLoss(Cell):
  r"""Forward process of the pangu alpha model"""
  tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))
  input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1))
- encoder_attention_masks = attention_mask
+ decoder_attention_masks = self.slice2(attention_mask, (0, 0, 0), (self.batch_size, self.len, self.len),
+ (1, 1, 1))
  input_mask = F.cast(self.not_equal(tokens, self.eod_token),
  mstype.float32)

  logits = self.network(tokens,
  input_position,
- encoder_attention_masks)
+ decoder_attention_masks)
  # Get label corresponding to input tokens
  labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
  (1, 1))
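In the PanGu-alpha loss cell, the ExpandDims on the attention mask is replaced by a StridedSlice that crops the mask handed to the network to (batch_size, seq_length, seq_length). Assuming a hypothetical incoming mask that is at least that large, the new slice2 call behaves like ordinary NumPy slicing:

import numpy as np

batch_size, seq_length = 2, 4
# Hypothetical incoming mask; the real shape depends on the data pipeline.
attention_mask = np.ones((batch_size, seq_length + 1, seq_length + 1), np.float16)
# StridedSlice with begin=(0, 0, 0), end=(batch_size, seq_length, seq_length),
# strides=(1, 1, 1) keeps the leading (batch, seq, seq) block:
decoder_attention_masks = attention_mask[:batch_size, :seq_length, :seq_length]
print(decoder_attention_masks.shape)   # (2, 4, 4)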
@@ -74,14 +74,15 @@ class NetWithLossFiveInputs(nn.Cell):

  def run_total_transformer_model_head(e_layer,
  d_layer,
- arg_parallel_config):
+ arg_parallel_config,
+ mode=ParallelMode.SEMI_AUTO_PARALLEL):
  dp = arg_parallel_config.data_parallel
  mp = arg_parallel_config.model_parallel
  pp = arg_parallel_config.pipeline_stage
  if dp * mp * pp != 1:
  set_auto_parallel_context(device_num=8,
  full_batch=True,
- global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
+ global_rank=0, parallel_mode=mode)

  class Net(nn.Cell):
  def __init__(self, en_layer, de_layer, parallel_config):
@@ -208,6 +209,13 @@ def test_transformer_model_head_stand_alone():
  run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)


+ def test_transformer_model_auto_parallel_no_support():
+ local_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1)
+ with pytest.raises(RuntimeError):
+ run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config,
+ mode=ParallelMode.AUTO_PARALLEL)
+
+
  def test_pipeline_single_transformer():
  set_auto_parallel_context(device_num=32,
  full_batch=True,
@@ -405,6 +413,7 @@ def test_sparse_attention_parallel_mp():
  model = Model(net)
  model.train(1, dataset, dataset_sink_mode=False)
+

  def test_sparse_attention_parallel_mix():
  set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  set_algo_parameters(fully_use_devices=False)
@@ -423,6 +432,7 @@ def test_sparse_attention_parallel_mix():
  model = Model(net)
  model.train(1, dataset, dataset_sink_mode=False)
+

  def test_sparse_attention_parallel_mix1():
  set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  set_algo_parameters(fully_use_devices=False)
@@ -441,6 +451,7 @@ def test_sparse_attention_parallel_mix1():
  model = Model(net)
  model.train(1, dataset, dataset_sink_mode=False)
+

  def test_sparse_attention_parallel_dp():
  set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  set_algo_parameters(fully_use_devices=False)
@@ -459,6 +470,7 @@ def test_sparse_attention_parallel_dp():
  model = Model(net)
  model.train(1, dataset, dataset_sink_mode=False)
+

  def test_parallel_cross_entroy_loss_semi_auto_parallel():
  set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)

@@ -496,7 +508,7 @@ def test_transformer_args():

  with pytest.raises(TypeError):
  Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
- tgt_seq_length=20, softmax_comptue_type=mstype.int64)
+ tgt_seq_length=20, softmax_compute_type=mstype.int64)

  with pytest.raises(TypeError):
  Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
@@ -510,6 +522,9 @@ def test_transformer_args():
  Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  tgt_seq_length=20, hidden_dropout_rate=mstype.int64)

+ Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
+ tgt_seq_length=20, softmax_compute_type=mstype.float16)
+

  def test_transformer_parallel_config():
  parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)