Improve the transformer APIs

Support attention mask being None
Fix attention check
WIP: custom activation support
Support attention
huangxinjing 2022-08-20 09:16:31 +08:00
parent 587091ea47
commit 908975d458
10 changed files with 397 additions and 146 deletions

View File

@ -10,7 +10,7 @@
If the query tensor, key tensor and value tensor are the same, the above describes the computation of self-attention.
Args:
- **batch_size** (int) - The batch size for training.
- **batch_size** (int) - The batch size of the input tensor when doing incremental prediction; it should be a positive integer. When doing training or prediction, the argument has no effect and the user can pass None.
- **src_seq_length** (int) - The sequence length of the query vector.
- **tgt_seq_length** (int) - The sequence length of the key and value vectors.
- **hidden_size** (int) - The hidden size of the input vector.
@ -28,7 +28,7 @@
- **query_tensor** (Tensor) - The query vector with shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size) when use_past is False or is_first_iteration is True. Otherwise, the shape must be (batch_size, 1, hidden_size).
- **key_tensor** (Tensor) - The key vector with shape (batch_size, tgt_seq_length, hidden_size) or (batch_size * tgt_seq_length, hidden_size) when use_past is False or is_first_iteration is True. Otherwise, the shape must be (batch_size, 1, hidden_size).
- **value_tensor** (Tensor) - The value vector with shape (batch_size, tgt_seq_length, hidden_size) or (batch_size * tgt_seq_length, hidden_size) when use_past is False or is_first_iteration is True. Otherwise, the shape must be (batch_size, 1, hidden_size).
- **attention_mask** (Tensor) - The attention mask matrix with shape (batch_size, src_seq_length, tgt_seq_length) when use_past is False or is_first_iteration is True. Otherwise, the shape must be (batch_size, 1, tgt_seq_length).
- **attention_mask** (Tensor) - The attention mask matrix with shape (batch_size, src_seq_length, tgt_seq_length) when use_past is False or is_first_iteration is True, or None. None means no mask is applied in the softmax computation. Otherwise, the shape must be (batch_size, 1, tgt_seq_length).
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length), the previously computed key vector.
A non-None value must be passed when use_past is True, for incremental prediction. Default: None.
- **value_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, tgt_seq_length, size_per_head), the previously computed value vector.
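A minimal usage sketch of the relaxed interface, adapted from the updated docstring example and the new tests in this commit (the shapes and dtypes are illustrative; batch_size=None and attention_mask=None are the newly supported cases):

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.nn.transformer import MultiHeadAttention

# batch_size is only required for incremental prediction, so None is accepted for training
model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20,
                           tgt_seq_length=20, num_heads=3)
query = Tensor(np.ones((2, 20, 15)), mstype.float32)
key = Tensor(np.ones((2, 20, 15)), mstype.float16)
# attention_mask=None skips the masking step before the softmax
attn_out, layer_present = model(query, key, key, None)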

View File

@ -7,7 +7,7 @@
This is an experimental interface that may be changed or removed.
Args:
- **batch_size** (int) - The batch size of the input.
- **batch_size** (int) - The batch size of the input tensor when doing incremental prediction; it should be a positive integer. When doing training or prediction, the argument has no effect and the user can pass None.
- **encoder_layers** (int) - The number of layers of `TransformerEncoderLayer`.
- **decoder_layers** (int) - The number of layers of `TransformerDecoderLayer`.
- **hidden_size** (int) - The hidden size of the input vector.
@ -29,10 +29,10 @@
Inputs:
- **encoder_inputs** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size].
- **encoder_masks** (Tensor) - The attention mask for the encoder with shape [batch_size, seq_length, seq_length].
- **encoder_masks** (Tensor) - The attention mask for the encoder with shape [batch_size, seq_length, seq_length], or None. None means no mask is applied in the softmax computation of the encoder's self attention.
- **decoder_inputs** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size]. This should be None if the number of decoder layers is 0.
- **decoder_masks** (Tensor) - The attention mask for the decoder with shape [batch_size, seq_length, seq_length].
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length, src_seq_length], where tgt_seq_length is the length of the decoder. If the number of decoder layers is 0, the output of the encoder with shape [batch_size, seq_length, hidden_size] should be None.
- **decoder_masks** (Tensor) - The attention mask for the decoder with shape [batch_size, seq_length, seq_length], or None. None means no mask is applied in the softmax computation of the decoder's self attention.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length, src_seq_length], where tgt_seq_length is the length of the decoder, or None. None means no mask is applied in the softmax computation of the cross attention.
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and past value parameters used in incremental prediction. Only valid when use_past is True. Default: True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the previously computed tokens. Used for incremental prediction when use_past is True. Default: None.
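As an illustration of the updated mask semantics, a sketch that mirrors the new test_transformer_hidden_act test but passes None for every mask; treating None as "no masking" follows the documentation above and is an assumption, not a snippet taken verbatim from the commit:

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.nn.transformer import Transformer

model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
                    ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
encoder_inputs = Tensor(np.ones((2, 20, 64)), mstype.float32)
decoder_inputs = Tensor(np.ones((2, 10, 64)), mstype.float32)
# None for encoder_masks, decoder_masks and memory_mask disables masking in the
# corresponding softmax computations
outputs = model(encoder_inputs, None, decoder_inputs, None, None)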

View File

@ -3,7 +3,7 @@
The decoder module of the Transformer, consisting of multiple stacked `TransformerDecoderLayer`, including multi-head self-attention, cross attention and feedforward layers.
Args:
- **batch_size** (int) - The batch size of the input Tensor.
- **batch_size** (int) - The batch size of the input tensor when doing incremental prediction; it should be a positive integer. When doing training or prediction, the argument has no effect and the user can pass None.
- **num_layers** (int) - The number of layers of `TransformerDecoderLayer`.
- **hidden_size** (int) - The hidden size of the input.
- **ffn_hidden_size** (int) - The hidden size of the bottleneck in the feedforward layer.
@ -25,12 +25,12 @@
Inputs:
- **hidden_stats** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size].
- **attention_mask** (Tensor) - The attention mask for the decoder with shape [batch_size, seq_length, seq_length].
- **attention_mask** (Tensor) - The attention mask for the decoder with shape [batch_size, seq_length, seq_length], or None. None means no mask is applied in the softmax computation of the self attention.
- **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size].
.. note:: This argument cannot be passed as None when the network is the outermost layer. Default: None.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length, src_seq_length], where tgt_seq_length is the length of the decoder. This argument cannot be passed as None when the network is the outermost layer. Default: None.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length, src_seq_length], where tgt_seq_length is the length of the decoder, or None. None means no mask is applied in the softmax computation of the cross attention.
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and past value parameters used in incremental prediction. Only valid when use_past is True. Default: True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the previously computed tokens. Used for incremental prediction when use_past is True. Default: None.

View File

@ -3,7 +3,7 @@
The decoder layer of the Transformer. A single-layer implementation of the Transformer decoder, including a self-attention layer, a cross attention layer and a feedforward layer. When encoder_output is None, the cross attention has no effect.
Args:
- **batch_size** (int) - The batch size of the input Tensor.
- **batch_size** (int) - The batch size of the input tensor when doing incremental prediction; it should be a positive integer. When doing training or prediction, the argument has no effect and the user can pass None.
- **hidden_size** (int) - The hidden size of the input.
- **src_seq_length** (int) - The sequence length of the source input.
- **tgt_seq_length** (int) - The sequence length of the target input.
@ -22,9 +22,9 @@
Inputs:
- **hidden_stats** (Tensor) - The input tensor with shape [batch_size, tgt_seq_length, hidden_size] or [batch_size * tgt_seq_length, hidden_size].
- **decoder_mask** (Tensor) - The attention mask for the decoder with shape [batch_size, src_seq_length, seq_length].
- **decoder_mask** (Tensor) - The attention mask for the decoder with shape [batch_size, src_seq_length, seq_length], or None. None means no mask is applied in the softmax computation of the self attention.
- **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size]. Note: this argument cannot be passed as None when the network is the outermost layer. Default: None.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length, src_seq_length], where tgt_seq_length is the length of the decoder. This argument cannot be passed as None when the network is the outermost layer. Default: None.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length, src_seq_length], where tgt_seq_length is the length of the decoder, or None. None means no mask is applied in the softmax computation of the cross attention.
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and past value parameters used in incremental prediction. Only valid when use_past is True. Default: True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the previously computed tokens. Used for incremental prediction when use_past is True. Default: None.
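A condensed sketch of the behaviour described above, following the new test_transformerdecoder_mask and test_transformerdecoder_2d_or_3d_shape tests added later in this commit:

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.nn.transformer import TransformerDecoderLayer

model = TransformerDecoderLayer(batch_size=None, hidden_size=64, ffn_hidden_size=64,
                                num_heads=2, src_seq_length=20, tgt_seq_length=10)
encoder_output = Tensor(np.ones((2, 20, 64)), mstype.float32)
decoder_input = Tensor(np.ones((2, 10, 64)), mstype.float32)
# decoder_mask=None and memory_mask=None skip masking in self attention and cross attention
output, layer_present = model(decoder_input, None, encoder_output, None)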

View File

@ -3,7 +3,7 @@
The encoder module of the Transformer, consisting of multiple stacked `TransformerEncoderLayer`, including multi-head self-attention and feedforward layers.
Args:
- **batch_size** (int) - The batch size of the input tensor.
- **batch_size** (int) - The batch size of the input tensor when doing incremental prediction; it should be a positive integer. When doing training or prediction, the argument has no effect and the user can pass None.
- **num_layers** (int) - The number of layers of `TransformerEncoderLayer`.
- **hidden_size** (int) - The hidden size of the input.
- **ffn_hidden_size** (int) - The hidden size of the bottleneck in the feedforward layer.
@ -24,7 +24,7 @@
Inputs:
- **hidden_states** (Tensor) - Tensor. If use_past is False or is_first_iteration is True, the shape is [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size]. Otherwise, the shape should be [batch_size, 1, hidden_size].
- **attention_mask** (Tensor) - Tensor, the attention mask with shape [batch_size, seq_length, seq_length].
- **attention_mask** (Tensor) - Tensor. When use_past is False or is_first_iteration is True, the attention mask with shape [batch_size, seq_length, seq_length], or None. None means no mask is applied in the softmax computation. Otherwise, the shape should be [batch_size, 1, hidden_size].
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and past value parameters used in incremental prediction. Only valid when use_past is True. Default: True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the previously computed tokens. Used for incremental prediction when use_past is True. Default: None.

View File

@ -3,7 +3,7 @@
The encoder layer of the Transformer. A single-layer implementation of the Transformer encoder, including a multi-head attention layer and a feedforward layer.
Args:
- **batch_size** (int) - The batch size of the input Tensor.
- **batch_size** (int) - The batch size of the input tensor when doing incremental prediction; it should be a positive integer. When doing training or prediction, the argument has no effect and the user can pass None.
- **hidden_size** (int) - The hidden size of the input.
- **seq_length** (int) - The sequence length of the input.
- **ffn_hidden_size** (int) - The hidden size of the bottleneck in the feedforward layer.
@ -21,7 +21,7 @@
Inputs:
- **x** (Tensor) - Float Tensor. If use_past is False or is_first_iteration is True, the shape should be [batch_size, seq_length, hidden_size] or [batch_size * seq_length, hidden_size]. Otherwise, the shape should be [batch_size, 1, hidden_size].
- **input_mask** (Tensor) - Float tensor. When use_past is False or is_first_iteration is True, the attention mask with shape [batch_size, seq_length, seq_length]. Otherwise, the shape should be [batch_size, 1, hidden_size].
- **input_mask** (Tensor) - Float tensor. When use_past is False or is_first_iteration is True, the attention mask with shape [batch_size, seq_length, seq_length], or None. None means no mask is applied in the softmax computation. Otherwise, the shape should be [batch_size, 1, hidden_size].
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key and past value parameters used in incremental prediction. Only valid when use_past is True. Default: True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the previously computed tokens. Used for incremental prediction when use_past is True. Default: None.
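A short sketch following the new test_transformerencoder_no_mask test added later in this commit:

import numpy as np
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.nn.transformer import TransformerEncoderLayer

model = TransformerEncoderLayer(batch_size=None, hidden_size=8, ffn_hidden_size=64,
                                seq_length=16, num_heads=2)
x = Tensor(np.ones((2, 16, 8)), mstype.float32)
# input_mask=None: no mask is applied before the softmax
output, layer_present = model(x, None)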

View File

@ -37,7 +37,8 @@ from mindspore._checkparam import Validator
from mindspore.ops.primitive import constexpr
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
from mindspore import log as logger
__all__ = [
"FixedSparseAttention"
@ -154,6 +155,30 @@ class _LayerInputCheck:
f"but got {input_shape[dim]}")
return True
@staticmethod
def check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
"""
Check whether the input shape is equal to the expected shape; the value on the 0-th axis is viewed as
the batch, and the batch size will not be checked.
"""
length, hidden = target_shape
if isinstance(input_shape, tuple):
input_shape = list(input_shape)
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
[len(target_shape), len(target_shape) + 1])
if input_shape[-1] != hidden:
raise ValueError(f"For {func_name}, the last dimension of {param_name} shape must be {hidden}, "
f"but got the last dimension {input_shape[-1]} in {input_shape}.")
if input_shape[0] == 0:
raise ValueError(f"For {func_name}, the first dimension of {param_name} shape must be greater than 0, "
f"but got the first dimension {input_shape[0]} in {input_shape}.")
if len(input_shape) == 2 and input_shape[0] % length != 0:
raise ValueError(f"For {func_name}, the first dimension of {param_name} shape should be divisible "
f"by {length}, "
f"but got the first dimension {input_shape[0]} in {input_shape}.")
return True
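Roughly, the new helper pins only the trailing [length, hidden] dimensions and lets the leading batch axis vary; a small illustration (direct calls to the static method, shown only to explain the accepted shapes):

# a 3-D shape (batch, length, hidden): any positive batch size passes
_LayerInputCheck.check_shape_equal_without_batch((7, 20, 15), "query_tensor", "MultiHeadAttention", [20, 15])
# a 2-D shape (batch * length, hidden): the first dimension must be a positive multiple of length
_LayerInputCheck.check_shape_equal_without_batch((40, 15), "query_tensor", "MultiHeadAttention", [20, 15])
# (40, 16) would raise ValueError because the last dimension does not match hidden=15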
@constexpr
def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
@ -192,6 +217,11 @@ def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_valu
_LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)
@constexpr
def _check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
_LayerInputCheck.check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape)
class _Dropout(nn.Cell):
r"""
A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
@ -392,7 +422,6 @@ class _Linear(Cell):
@_args_type_validator_check(in_channels=Validator.check_positive_int,
out_channels=Validator.check_positive_int,
has_bias=Validator.check_bool,
activation=_valid_type_checks([type(None), str], "Linear"),
transpose_b=Validator.check_bool,
expert_num=Validator.check_positive_int,
outer_batch=Validator.check_positive_int,
@ -449,7 +478,10 @@ class _Linear(Cell):
self.bias.parallel_optimizer = False
self.bias_add = P.Add()
self.act_name = activation
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if callable(activation):
self.activation = activation()
else:
self.activation = get_activation(activation) if isinstance(activation, str) else activation
self.activation_flag = self.activation is not None
self.dtype = compute_dtype
self.cast = P.Cast()
@ -491,7 +523,7 @@ class _Linear(Cell):
self.matmul.shard(strategy_matmul)
if self.has_bias:
self.bias_add.shard(strategy_bias)
if self.activation_flag:
if self.activation_flag and isinstance(self.act_name, str):
# some operations has many primitives, need to manually set the shard
if self.act_name.lower() == "leakyrelu":
self.activation.select_op.shard((strategy_activation[0], strategy_activation[0]))
@ -506,7 +538,26 @@ class _Linear(Cell):
"or auto parallel mode.")
else:
getattr(self.activation, self.act_name).shard(strategy_activation)
elif self.activation_flag and isinstance(self.activation, Cell):
if hasattr(self.activation, 'shard') and strategy_activation:
shard_tuple = strategy_activation[0]
if len(shard_tuple) == 2:
parallel_config = OpParallelConfig(data_parallel=shard_tuple[0],
model_parallel=shard_tuple[1])
elif len(shard_tuple) == 4:
parallel_config = MoEParallelConfig(data_parallel=shard_tuple[0],
expert_parallel=shard_tuple[1],
model_parallel=shard_tuple[2])
else:
raise ValueError("The user-defined activation function currently only supports the case where the "
"input policy is 2 or 4, so that relevant policies can be extracted from it."
"To avoid this error, you need to add the function of extracting "
"'ParallelConfig' or 'OpParallelConfig' from the incoming strategy_activation ")
self.activation.shard(parallel_config)
else:
logger.warning(f"The user passed the custom defined activation function {self.activation_flag}. "
f"If the user want to enable shard for the activation cell, "
f"the user should set the shard for each primitives in the cell.")
return self
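For context, this branch expects the custom activation to be a Cell class: _Linear instantiates it and, when a shard strategy is given, extracts an OpParallelConfig or MoEParallelConfig from the strategy tuple and forwards it to the cell's optional shard method. A minimal sketch of such a cell (the name MyGELU and the GeLU primitive are illustrative, mirroring the MyActivation test cell added later in this commit):

import mindspore.nn as nn
from mindspore.ops import operations as P

class MyGELU(nn.Cell):
    """A hypothetical user-defined activation, passed as hidden_act=MyGELU."""
    def __init__(self):
        super(MyGELU, self).__init__()
        self.gelu = P.GeLU()

    def construct(self, x):
        return self.gelu(x)

    def shard(self, parallel_config):
        # shard the single primitive using the extracted data/model parallel numbers
        self.gelu.shard(((parallel_config.data_parallel, parallel_config.model_parallel),))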

View File

@ -37,8 +37,9 @@ from mindspore.context import ParallelMode
from mindspore.log import _LogActionOnce
from mindspore.nn.transformer.layers import _LayerNorm, _Linear, _check_input_shape, \
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig,\
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \
_check_shape_equal_without_batch
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
_Config, _check_config, MoEParallelConfig
from mindspore.nn.transformer.moe import default_moe_config, MoE, _check_moe_config
@ -399,7 +400,6 @@ class FeedForward(Cell):
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
ffn_hidden_size=Validator.check_positive_int,
dropout_rate=Validator.check_non_negative_float,
hidden_act=_valid_type_checks([str], "FeedForward"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
"FeedForward"),
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
@ -718,7 +718,9 @@ class MultiHeadAttention(Cell):
if query, key and value tensor is same, then it will be self attention.
Args:
batch_size(int): The batch size of the input tensor.
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can just pass None to
the argument.
src_seq_length(int): The sequence length of the query vector.
tgt_seq_length(int): The sequence length of the key and value vector.
hidden_size(int): The hidden size of the input.
@ -751,9 +753,9 @@ class MultiHeadAttention(Cell):
- **value_tensor** (Tensor) - The value vector with shape (batch_size, tgt_seq_length, hidden_size) or
(batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True.
Otherwise, must be (batch_size, 1, hidden_size)
- **attention_mask** (Tensor) - The attention mask matrix with shape (batch_size, src_seq_length,
tgt_seq_length), if the use_past is False or is_first_iteration=True. Otherwise,
must be (batch_size, 1, tgt_seq_length)
- **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask
matrix should be (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask
in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length)
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
The past calculated key vector. Used for incremental prediction when the use_past is True.
Default None.
@ -783,7 +785,7 @@ class MultiHeadAttention(Cell):
>>> from mindspore.nn.transformer import MultiHeadAttention
>>> from mindspore import dtype as mstype
>>> from mindspore import Tensor
>>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
>>> model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
... num_heads=3)
>>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16)
@ -830,8 +832,7 @@ class MultiHeadAttention(Cell):
"""
@_LogActionOnce(logger=logger, key='MultiHeadAttention',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
@_args_type_validator_check(batch_size=Validator.check_positive_int,
hidden_size=Validator.check_positive_int,
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
num_heads=Validator.check_positive_int,
src_seq_length=Validator.check_positive_int,
tgt_seq_length=Validator.check_positive_int,
@ -860,10 +861,13 @@ class MultiHeadAttention(Cell):
parallel_config=default_dpmp_config):
super(MultiHeadAttention, self).__init__()
self._is_ascend = context.get_context('device_target') in ["Ascend"]
self.dp = parallel_config.data_parallel
self.is_parallel_mode = _get_parallel_mode() in (
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if batch_size:
Validator.check_positive_int(batch_size)
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
_check_config(parallel_config)
self.is_parallel_mode = _get_parallel_mode() in (
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.src_seq_length = src_seq_length
self.tgt_seq_length = tgt_seq_length
self.hidden_size = hidden_size
@ -883,11 +887,6 @@ class MultiHeadAttention(Cell):
"'parallel_config.model_parallel', but got the num_heads is {} "
"and the parallel_config.model_parallel is {}."
.format(num_heads, parallel_config.model_parallel))
if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
raise ValueError("For 'MultiHeadAttention', the class variable 'batch_size' must be a multiple of "
"'parallel_config.data_parallel', but got the batch_size is {} "
"and the parallel_config.data_parallel is {}."
.format(batch_size, parallel_config.data_parallel))
self.is_first_iteration = True
# Output layer
self.projection = _Linear(in_channels=hidden_size,
@ -961,8 +960,6 @@ class MultiHeadAttention(Cell):
self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
else:
_check_config(parallel_config)
self.is_parallel_mode = _get_parallel_mode() in (
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.src_seq_length = src_seq_length
self.tgt_seq_length = tgt_seq_length
self.hidden_size = hidden_size
@ -982,11 +979,6 @@ class MultiHeadAttention(Cell):
"'parallel_config.model_parallel', but got the num_heads is {} "
"and the parallel_config.model_parallel is {}."
.format(num_heads, parallel_config.model_parallel))
if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
raise ValueError("For 'MultiHeadAttention', the class variable 'batch_size' must be a multiple of "
"'parallel_config.data_parallel', but got the batch_size is {} "
"and the parallel_config.data_parallel is {}."
.format(batch_size, parallel_config.data_parallel))
self.is_first_iteration = True
# Output layer
self.projection = _Linear(in_channels=hidden_size,
@ -1086,10 +1078,16 @@ class MultiHeadAttention(Cell):
value_past=None, batch_valid_length=None):
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
value_past, batch_valid_length)
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
key_tensor,
value_tensor,
attention_mask)
ori_shape = F.shape(query_tensor)
batch_size = None
if len(F.shape(query_tensor)) == 2:
batch_size = F.shape(query_tensor)[0] // self.src_seq_length
else:
batch_size = F.shape(query_tensor)[0]
query_tensor, key_tensor, value_tensor = self._convert_to_2d_tensor(query_tensor,
key_tensor,
value_tensor,
attention_mask)
ori_dtype = F.dtype(query_tensor)
query_tensor = F.cast(query_tensor, self.dtype)
key_tensor = F.cast(key_tensor, self.dtype)
@ -1116,7 +1114,7 @@ class MultiHeadAttention(Cell):
(batch_size, -1, self.n_head, self.size_per_head)),
(0, 2, 1, 3))
# support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
if len(F.shape(attention_mask)) == 3:
if attention_mask is not None and len(F.shape(attention_mask)) == 3:
# expand attention mask from [bs, seq, seq] -> [bs, 1, seq, seq]
attention_mask = self.expand_dims(attention_mask, 1)
# key and value for current token(s)
@ -1171,17 +1169,15 @@ class MultiHeadAttention(Cell):
value_past=None, batch_valid_length=None):
r"""Check inputs"""
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[[self.batch_size, self.src_seq_length, self.hidden_size],
[self.batch_size * self.src_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
_check_shape_equal_without_batch(F.shape(query_tensor), "query_tensor", self.cls_name,
[self.src_seq_length, self.hidden_size])
_check_shape_equal_without_batch(F.shape(key_tensor), "key_tensor", self.cls_name,
[self.tgt_seq_length, self.hidden_size])
_check_shape_equal_without_batch(F.shape(value_tensor), "value_tensor", self.cls_name,
[self.tgt_seq_length, self.hidden_size])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[F.shape(attention_mask)[0], self.src_seq_length, self.tgt_seq_length])
else:
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
@ -1189,13 +1185,16 @@ class MultiHeadAttention(Cell):
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
if attention_mask is not None:
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
self.cls_name)
key_is_tensor = isinstance(key_past, Tensor)
value_is_tensor = isinstance(value_past, Tensor)
@ -1228,7 +1227,8 @@ class MultiHeadAttention(Cell):
key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
value_shape = F.shape(value_tensor)
value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape
return query_tensor, key_tensor, value_tensor
def _merge_heads(self, x):
"""
@ -1286,30 +1286,31 @@ class MultiHeadAttention(Cell):
score = self.batch_matmul(query, key)
ori_dtype = P.DType()(score)
score = P.Cast()(score, self.softmax_dtype)
attention_scores = P.Cast()(score, self.softmax_dtype)
# for input size of (bs, 1) namely the second graph,
# the shape of attention_mask matrix should be (bs, 1, 1, seq_length)
if self.use_past and not self.is_first_iteration:
# Calculate the current total token
current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
(F.shape(query)[0], 1, 1, self.seq_length),
(1, 1, 1, 1)),
0), mstype.float32), (1, 2, 3))
# Get the precise position index
index = self.sub1(F.cast(current_index, mstype.int32), 1)
index = F.reshape(index, (-1, 1, 1))
# Calculate the attention_mask matrix via the position index
attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
attention_mask = self.expand_dims(attention_mask, 2)
if attention_mask is not None:
if self.use_past and not self.is_first_iteration:
# Calculate the current total token
current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
(F.shape(query)[0], 1, 1,
self.seq_length),
(1, 1, 1, 1)),
0), mstype.float32), (1, 2, 3))
# Get the precise position index
index = self.sub1(F.cast(current_index, mstype.int32), 1)
index = F.reshape(index, (-1, 1, 1))
# Calculate the attention_mask matrix via the position index
attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
attention_mask = self.expand_dims(attention_mask, 2)
# Minus 10000 for the position where masked to exclude them from softmax
multiplu_out = self.sub(
P.Cast()(F.tuple_to_array((1.0,)), P.DType()(attention_scores)),
P.Cast()(attention_mask, P.DType()(attention_scores)))
# Minus 10000 for the position where masked to exclude them from softmax
multiplu_out = self.sub(
P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
P.Cast()(attention_mask, P.DType()(score)))
adder = self.mul(multiplu_out, self.multiply_data)
attention_scores = self.add(adder, score)
adder = self.mul(multiplu_out, self.multiply_data)
attention_scores = self.add(adder, attention_scores)
# attention probs
attention_probs = self._softmax(attention_scores)
@ -1328,7 +1329,9 @@ class TransformerEncoderLayer(Cell):
encoder layer, including multihead attention and feedward layer.
Args:
batch_size(int): The batch size of the input tensor.
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can just pass None to
the argument.
hidden_size(int): The hidden size of the input.
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
num_heads(int): The number of the heads.
@ -1362,8 +1365,9 @@ class TransformerEncoderLayer(Cell):
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
should be [batch_size, 1, hidden_size]
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length],
if the use_past is False or is_first_iteration=True. Otherwise, should be [batch_size, 1, hidden_size]
- **input_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True,
the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@ -1430,14 +1434,12 @@ class TransformerEncoderLayer(Cell):
"""
@_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
@_args_type_validator_check(batch_size=Validator.check_positive_int,
hidden_size=Validator.check_positive_int,
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
num_heads=Validator.check_positive_int,
ffn_hidden_size=Validator.check_positive_int,
seq_length=Validator.check_positive_int,
attention_dropout_rate=Validator.check_non_negative_float,
hidden_dropout_rate=Validator.check_non_negative_float,
hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"),
post_layernorm_residual=Validator.check_bool,
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerEncoderLayer"),
@ -1465,6 +1467,9 @@ class TransformerEncoderLayer(Cell):
moe_config=default_moe_config,
parallel_config=default_dpmp_config):
super(TransformerEncoderLayer, self).__init__()
if batch_size or use_past:
Validator.check_positive_int(batch_size)
self.batch_size = batch_size
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
_check_config(parallel_config)
if num_heads % parallel_config.model_parallel != 0:
@ -1488,7 +1493,6 @@ class TransformerEncoderLayer(Cell):
self.use_past = use_past
self.seq_length = seq_length
self.hidden_size = hidden_size
self.batch_size = batch_size
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
@ -1564,7 +1568,6 @@ class TransformerEncoderLayer(Cell):
self.use_past = use_past
self.seq_length = seq_length
self.hidden_size = hidden_size
self.batch_size = batch_size
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
@ -1623,7 +1626,7 @@ class TransformerEncoderLayer(Cell):
raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
f"semi-auto parallel mode now.")
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
def construct(self, x, input_mask=None, init_reset=True, batch_valid_length=None):
self._check_input(x, input_mask, init_reset, batch_valid_length)
x_shape = F.shape(x)
x = F.reshape(x, (-1, x_shape[-1]))
@ -1699,17 +1702,19 @@ class TransformerEncoderLayer(Cell):
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
r"""Check inputs"""
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal(F.shape(x), "x", self.cls_name,
[[self.batch_size, self.seq_length, self.hidden_size],
[self.batch_size * self.seq_length, self.hidden_size]])
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[self.batch_size, self.seq_length, self.seq_length])
_check_shape_equal_without_batch(F.shape(x), "x", self.cls_name,
[self.seq_length, self.hidden_size])
if input_mask is not None:
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[F.shape(input_mask)[0], self.seq_length, self.seq_length])
else:
_check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[self.batch_size, 1, self.seq_length])
if input_mask is not None:
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[F.shape(input_mask)[0], 1, self.seq_length])
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
if input_mask is not None:
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
init_reset_is_tensor = isinstance(init_reset, Tensor)
init_reset_is_default = init_reset is True
@ -1738,7 +1743,9 @@ class TransformerDecoderLayer(Cell):
hidden_size(int): The hidden size of the input.
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
num_heads(int): The number of the heads.
batch_size(int): The batch size of the input tensor.
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can just pass None to
the argument.
src_seq_length(int): The input source sequence length.
tgt_seq_length(int): The input target sequence length.
attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
@ -1764,13 +1771,13 @@ class TransformerDecoderLayer(Cell):
- **hidden_stats** (Tensor) - The input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
[batch_size * tgt_seq_length, hidden_size].
- **decoder_mask** (Tensor) - The attention mask for decoder with shape [batch_size, src_seq_length,
seq_length].
seq_length] or None. None means there will be no mask in softmax computation in self attention.
- **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
or [batch_size * seq_length, hidden_size].
Note this args can not be passed by None when the net is in outermost layer. Default None.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by
None when the net is in outermost layer. Default None.
src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
means there will be no mask in softmax computation in cross attention. Default None.
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@ -1815,15 +1822,13 @@ class TransformerDecoderLayer(Cell):
"""
@_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
@_args_type_validator_check(batch_size=Validator.check_positive_int,
hidden_size=Validator.check_positive_int,
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
num_heads=Validator.check_positive_int,
ffn_hidden_size=Validator.check_positive_int,
src_seq_length=Validator.check_positive_int,
tgt_seq_length=Validator.check_positive_int,
attention_dropout_rate=Validator.check_non_negative_float,
hidden_dropout_rate=Validator.check_non_negative_float,
hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"),
post_layernorm_residual=Validator.check_bool,
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerDecoderLayer"),
@ -1854,6 +1859,8 @@ class TransformerDecoderLayer(Cell):
_check_moe_config(moe_config, parallel_config)
self.use_moe = (moe_config.expert_num > 1)
config_to_attention = parallel_config.dpmp if self.use_moe else parallel_config
if batch_size or use_past:
Validator.check_positive_int(batch_size)
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
_check_config(parallel_config)
if num_heads % parallel_config.model_parallel != 0:
@ -2144,28 +2151,30 @@ class TransformerDecoderLayer(Cell):
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
r"""Check inputs"""
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
_check_shape_equal_without_batch(F.shape(hidden_states), "hidden_states", self.cls_name,
[self.tgt_seq_length, self.hidden_size])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[F.shape(attention_mask)[0], self.tgt_seq_length, self.tgt_seq_length])
else:
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
[self.batch_size, 1, self.hidden_size])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, 1, self.tgt_seq_length])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, 1, self.tgt_seq_length])
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
if attention_mask is not None:
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
self.cls_name)
if encoder_output is not None:
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
[[self.batch_size, self.src_seq_length, self.hidden_size],
[self.batch_size * self.src_seq_length, self.hidden_size]])
_check_shape_equal_without_batch(F.shape(encoder_output), "encoder_output", self.cls_name,
[self.src_seq_length, self.hidden_size])
_check_input_dtype(F.dtype(encoder_output), "encoder_output",
[mstype.float32, mstype.float16], self.cls_name)
if memory_mask is not None:
_check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.src_seq_length])
_check_shape_equal_without_batch(F.shape(memory_mask), "memory_mask", self.cls_name,
[self.tgt_seq_length, self.src_seq_length])
_check_input_dtype(F.dtype(memory_mask), "memory_mask",
[mstype.float32, mstype.float16], self.cls_name)
@ -2240,7 +2249,9 @@ class TransformerEncoder(Cell):
attention and feedforward layer.
Args:
batch_size(int): The batch size of the input tensor.
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can just pass None to
the argument.
num_layers(int): The layers of the `TransformerEncoderLayer`
hidden_size(int): The hidden size of the input.
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
@ -2284,7 +2295,9 @@ class TransformerEncoder(Cell):
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
should be [batch_size, 1, hidden_size].
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
- **attention_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True,
the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@ -2361,7 +2374,6 @@ class TransformerEncoder(Cell):
offset=Validator.check_non_negative_int,
attention_dropout_rate=Validator.check_non_negative_float,
hidden_dropout_rate=Validator.check_non_negative_float,
hidden_act=_valid_type_checks([str], "TransformerEncoder"),
post_layernorm_residual=Validator.check_bool,
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerEncoder"),
@ -2490,7 +2502,9 @@ class TransformerDecoder(Cell):
Args:
num_layers(int): The layers of the `TransformerDecoderLayer`.
batch_size(int): The batch size of the input tensor.
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can just pass None to
the argument.
hidden_size(int): The hidden size of the input.
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
src_seq_length(int): The input source sequence length.
@ -2528,13 +2542,14 @@ class TransformerDecoder(Cell):
- **hidden_stats** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size]
- **attention_mask** (Tensor) - The attention mask for decoder with shape
[batch_size, seq_length, seq_length]
[batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax
computation in self attention.
- **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
or [batch_size * seq_length, hidden_size]. Note this args can not be passed by None when the net is in
outermost layer. Default None.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by
None when the net is in outermost layer. Default None.
src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
means there will be no mask in softmax computation in cross attention. Default None.
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@ -2591,7 +2606,6 @@ class TransformerDecoder(Cell):
offset=Validator.check_non_negative_int,
attention_dropout_rate=Validator.check_non_negative_float,
hidden_dropout_rate=Validator.check_non_negative_float,
hidden_act=_valid_type_checks([str], "TransformerDecoder"),
post_layernorm_residual=Validator.check_bool,
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
"TransformerDecoder"),
@ -2736,7 +2750,9 @@ class Transformer(Cell):
Args:
hidden_size(int): The hidden size of the input.
batch_size(int): The batch size of the input tensor.
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can just pass None to
the argument.
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
src_seq_length(int): The seq_length of the encoder's input tensor.
tgt_seq_length(int): The seq_length of the decoder's input tensor.
@ -2772,15 +2788,17 @@ class Transformer(Cell):
- **encoder_inputs** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size].
- **encoder_masks** (Tensor) - The attention mask for decoder with shape
[batch_size, seq_length, seq_length].
[batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
in self attention of the encoder module.
- **decoder_inputs** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
or [batch_size * seq_length, hidden_size], this should be none if the decoder layer is 0.
- **decoder_masks** (Tensor) - The attention mask for decoder with shape
[batch_size, seq_length, seq_length]
[batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
in self attention of the decoder module.
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length]
where tgt_seq_length is the length of the decoder. The output of the encoder with shape [batch_size,
seq_length, hidden_size], this should be none if the decoder layer is 0.
seq_length, hidden_size], this should be none if the decoder layer is 0 or the user wants no mask.
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
@ -2854,7 +2872,6 @@ class Transformer(Cell):
tgt_seq_length=Validator.check_positive_int,
attention_dropout_rate=Validator.check_non_negative_float,
hidden_dropout_rate=Validator.check_non_negative_float,
hidden_act=_valid_type_checks([str], "Transformer"),
post_layernorm_residual=Validator.check_bool,
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
"Transformer"),

View File

@ -18,15 +18,28 @@ import shutil
import numpy as np
import pytest
import mindspore
from mindspore import Tensor
from mindspore.common import dtype
from mindspore.ops import operations as ops
from mindspore.parallel.nn import MultiHeadAttention, FeedForward, TransformerEncoderLayer, TransformerEncoder, \
TransformerDecoder, TransformerDecoderLayer, Transformer, CrossEntropyLoss, AttentionMask, FixedSparseAttention
from mindspore.common.api import _cell_graph_executor
class MyActivation(mindspore.nn.Cell):
def __init__(self):
super(MyActivation, self).__init__()
self.add = ops.Add()
def construct(self, x):
return self.add(x, 0.1)
def shard(self, parallel_config):
self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ()))
def test_transformer_encoder_only():
model = Transformer(batch_size=2,
src_seq_length=20,
@ -208,18 +221,42 @@ def test_multihead_attention():
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
def test_multihead_attention_wrong_batch():
@pytest.mark.parametrize('batch_size', [1, 2, None, 4])
def test_multihead_attention_wrong_batch(batch_size):
"""
Feature: Test MultiHeadAttention with wrong batch for training
Description: Test the batch size to be any int or None
Expectation: No exception
"""
model = MultiHeadAttention(hidden_size=15,
src_seq_length=20,
tgt_seq_length=20,
batch_size=2,
batch_size=batch_size,
num_heads=3)
from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
to_tensor = Tensor(np.ones((3, 20, 15)), dtype.float16)
attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
with pytest.raises(ValueError):
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
@pytest.mark.parametrize('from_tensor,to_tensor', [(Tensor(np.ones((20, 15)), dtype.float32),
Tensor(np.ones((20, 15)), dtype.float16)),
(Tensor(np.ones((3, 20, 15)), dtype.float32),
Tensor(np.ones((3, 20, 15)), dtype.float16))])
def test_multihead_attention_no_mask_2d_or_3d_shape(from_tensor, to_tensor):
"""
Feature: Test MultiHeadAttention no mask
Description: Test MultiHeadAttention with no mask and 2d or 3d inputs.
Expectation: No exception
"""
model = MultiHeadAttention(hidden_size=15,
src_seq_length=20,
tgt_seq_length=20,
batch_size=None,
num_heads=3)
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, None)
def test_multihead_attention_fp32_dtype():
@ -240,6 +277,154 @@ def test_multihead_attention_fp32_dtype():
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
@pytest.mark.parametrize('batch_size', [1, 2, None, 4])
def test_transformerencoder_wrong_batch(batch_size):
"""
Feature: Test TransformerEncoderLayer with wrong batch for training
Description: Test the batch size to be any int or None
Expectation: No exception
"""
model = TransformerEncoderLayer(batch_size=batch_size, hidden_size=8, ffn_hidden_size=64, seq_length=16,
num_heads=2)
encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)
model(encoder_input_value, encoder_input_mask)
@pytest.mark.parametrize('attention_mask', [Tensor(np.ones((2, 16, 16)), dtype.float16),
None])
def test_transformerencoder_no_mask(attention_mask):
"""
Feature: Test TransformerEncoderLayer with no mask
Description: Test the attention mask is None
Expectation: No exception
"""
model = TransformerEncoderLayer(batch_size=None, hidden_size=8, ffn_hidden_size=64, seq_length=16,
num_heads=2)
encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
model(encoder_input_value, attention_mask)
@pytest.mark.parametrize('shape', [(2, 16, 8), (32, 8)])
def test_transformerencoder_2d_or_3d_shape(shape):
"""
Feature: Test TransformerEncoderLayer with 2d or 3d inputs
Description: Test 2d or 3d input shapes with the attention mask set to None
Expectation: No exception
"""
model = TransformerEncoderLayer(batch_size=None, hidden_size=8, ffn_hidden_size=64, seq_length=16,
num_heads=2)
encoder_input_value = Tensor(np.ones(shape), dtype.float32)
model(encoder_input_value, None)
@pytest.mark.parametrize('batch_size', [1, 2, None, 4])
def test_transformerdecoder_wrong_batch(batch_size):
"""
Feature: Test TransformerDecoderLayer with wrong batch for training
Description: Test the batch size to be any int or None
Expectation: No exception
"""
model = TransformerDecoderLayer(batch_size=batch_size, hidden_size=64, ffn_hidden_size=64, num_heads=2,
src_seq_length=20, tgt_seq_length=10)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
@pytest.mark.parametrize('decoder_input_mask,memory_mask',
[(None, None), (Tensor(np.ones((2, 10, 10)), dtype.float16), None),
(None, Tensor(np.ones((2, 10, 20)), dtype.float16))])
def test_transformerdecoder_mask(decoder_input_mask, memory_mask):
"""
Feature: Test TransformerDecoderLayer with empty mask
Description: Test the mask is None
Expectation: No exception
"""
model = TransformerDecoderLayer(batch_size=4, hidden_size=64, ffn_hidden_size=64, num_heads=2,
src_seq_length=20, tgt_seq_length=10)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
def test_transformerdecoder_custom_activation():
"""
Feature: Test TransformerDecoderLayer custom activation
Description: Test TransformerDecoderLayer custom activation
Expectation: No exception
"""
model = TransformerDecoderLayer(batch_size=4, hidden_size=64, ffn_hidden_size=64, num_heads=2,
hidden_act=MyActivation,
src_seq_length=20, tgt_seq_length=10)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
model(decoder_input_value, None, encoder_input_value, None)
@pytest.mark.parametrize('encoder_shape,decoder_shape', [((2, 20, 64), (2, 10, 64)),
((20, 64), (10, 64))])
def test_transformerdecoder_2d_or_3d_shape(encoder_shape, decoder_shape):
"""
Feature: Test TransformerDecoderLayer with 2d or 3d inputs
Description: Test 2d or 3d input shapes with the attention mask set to None
Expectation: No exception
"""
model = TransformerDecoderLayer(batch_size=None, hidden_size=64, ffn_hidden_size=64, num_heads=2,
src_seq_length=20, tgt_seq_length=10)
encoder_input_value = Tensor(np.ones(encoder_shape), dtype.float32)
decoder_input_value = Tensor(np.ones(decoder_shape), dtype.float32)
model(decoder_input_value, None, encoder_input_value, None)
@pytest.mark.parametrize('hidden_act', [MyActivation, None, "relu"])
def test_transformer_hidden_act(hidden_act):
"""
Feature: Test Transformer hidden activation with activation or None
Description: Test the transformer hidden activation
Expectation: No exception
"""
model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
hidden_act=hidden_act,
ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
model(encoder_input_value, encoder_input_mask, decoder_input_value,
decoder_input_mask, memory_mask)
def test_transformer_hidden_act_with_wrong_hidden_act_wrong_lambda_func():
"""
Feature: Test Transformer hidden activation with an invalid callable
Description: Test the transformer hidden activation with a lambda function
Expectation: Raise TypeError
"""
with pytest.raises(TypeError):
Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
hidden_act=lambda x: x,
ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
def test_transformer_hidden_act_with_wrong_hidden_act_wrong_str():
"""
Feature: Test Transformer hidden activation with wrong activation
Description: Test the transformer hidden activation with an unknown activation string
Expectation: Raise KeyError
"""
with pytest.raises(KeyError):
Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
hidden_act="no_string",
ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
def test_feedforward_layer():
model = FeedForward(hidden_size=15,
ffn_hidden_size=30,

View File

@ -624,17 +624,15 @@ def test_transformer_wrong_dp_no_error():
def test_transformer_wrong_semi_auto_dp_error():
"""
Feature: test transformer api
Description: Test transformer exception scene
Description: Test transformer parallel batch with no check
Expectation: No exception.
"""
set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
with pytest.raises(ValueError):
net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
parallel_config=check_config)
del net
Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
parallel_config=check_config)
def test_encoder():