Improve the transformer APIs
Support attention_mask being None, fix the attention input checks, and add (WIP) support for custom activation cells.
parent 587091ea47
commit 908975d458
@ -10,7 +10,7 @@
|
|||
如果query tensor、key tensor和value tensor相同,则上述即为自注意力机制的计算过程。
|
||||
|
||||
参数:
|
||||
- **batch_size** (int) - 表示训练批次的大小。
|
||||
- **batch_size** (int) - 表示增量预测时输入张量的批量大小,应该是正整数。当进行训练或预测时,该参数将不起作用,用户可将None传递给此参数。
|
||||
- **src_seq_length** (int) - 表示query向量的序列长度。
|
||||
- **tgt_seq_length** (int) - 表示key向量和value向量的序列长度。
|
||||
- **hidden_size** (int) - 表示输入的向量大小。
|
||||
|
@ -28,7 +28,7 @@
|
|||
- **query_tensor** (Tensor) - use_past为False或is_first_iteration为True时,表示shape为(batch_size, src_seq_length, hidden_size)或(batch_size * src_seq_length, hidden_size)的query向量。否则,shape必须为(batch_size, 1, hidden_size)。
|
||||
- **key_tensor** (Tensor) - use_past为False或is_first_iteration为True时,表示shape为(batch_size, tgt_seq_length, hidden_size)或(batch_size * tgt_seq_length, hidden_size)的key向量。否则,shape必须为(batch_size, 1, hidden_size)。
|
||||
- **value_tensor** (Tensor) - use_past为False或is_first_iteration为True时,表示shape为(batch_size, tgt_seq_length, hidden_size)或(batch_size * tgt_seq_length, hidden_size)的value向量。否则,shape必须为(batch_size, 1, hidden_size)。
|
||||
- **attention_mask** (Tensor) - use_past为False或is_first_iteration为True时,表示shape为(batch_size, src_seq_length, tgt_seq_length)的注意力掩码矩阵。否则,shape必须为(batch_size, 1, tgt_seq_length)。
|
||||
- **attention_mask** (Tensor) - use_past为False或is_first_iteration为True时,表示shape为(batch_size, src_seq_length, tgt_seq_length)的注意力掩码矩阵,或者为None,None表示在Softmax计算中将不会进行掩码。否则,shape必须为(batch_size, 1, tgt_seq_length)。
|
||||
- **key_past** (Tensor) - shape为(batch_size, num_heads, size_per_head, tgt_seq_length)的Float16 tensor,表示过去所计算的key向量。
|
||||
当use_past为True时,需要传入非None值用于增量预测。默认值为None。
|
||||
- **value_past** (Tensor) - shape为(batch_size, num_heads, tgt_seq_length, size_per_head)的Float16 tensor,表示过去所计算的value向量。
|
||||
|
|
|
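A minimal usage sketch of the interface change documented above; the shapes and values are illustrative and the call mirrors the English example later in this commit. Passing None for attention_mask skips the masking in the softmax, and batch_size=None is accepted for training and non-incremental prediction.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.nn.transformer import MultiHeadAttention
>>> model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20,
...                            tgt_seq_length=20, num_heads=3)
>>> query = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> key = Tensor(np.ones((2, 20, 15)), mstype.float16)
>>> attn_out, layer_present = model(query, key, key, None)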
@ -7,7 +7,7 @@
|
|||
这是一个实验接口,可能会被更改或者删除。
|
||||
|
||||
参数:
|
||||
- **batch_size** (int) - 表示输入的批次大小。
|
||||
- **batch_size** (int) - 表示增量预测时输入张量的批量大小,应该是正整数。当进行训练或预测时,该参数将不起作用,用户可将None传递给此参数。
|
||||
- **encoder_layers** (int) - 表示 `TransformerEncoderLayer` 的层数。
|
||||
- **decoder_layers** (int) - 表示 `TransformerDecoderLayer` 的层数。
|
||||
- **hidden_size** (int) - 表示输入向量的大小。
|
||||
|
@ -29,10 +29,10 @@
|
|||
|
||||
输入:
|
||||
- **encoder_inputs** (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的输入Tensor。
|
||||
- **encoder_masks** (Tensor) - shape为[batch_size, seq_length, seq_length]的解码器的注意力掩码。
|
||||
- **encoder_masks** (Tensor) - shape为[batch_size, seq_length, seq_length]的编码器的注意力掩码。或者为None,None表示将不会在编码器的self attention中的Softmax计算中引入掩码计算。
|
||||
- **decoder_inputs** (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的编码器的输出。如果解码器层数为0,则此值应为None。
|
||||
- **decoder_masks** (Tensor) - shape为[batch_size, seq_length, seq_length]的解码器的注意力掩码。
|
||||
- **memory_mask** (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码,其中tgt_seq_length表示解码器的长度。如果解码器层为0,则shape为[batch_size, seq_length, hidden_size]的编码器的输出应为None。
|
||||
- **decoder_masks** (Tensor) - shape为[batch_size, seq_length, seq_length]的解码器的注意力掩码。或者为None,None表示将不会在解码器中的self attention中的Softmax计算中引入掩码计算。
|
||||
- **memory_mask** (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码,其中tgt_seq_length表示解码器的长度。或者为None,None表示将不会在cross attention中的Softmax计算中引入掩码计算。
|
||||
- **init_reset** (Tensor) - shape为[1]的bool tensor,用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。
|
||||
- **batch_valid_length** (Tensor) - shape为[batch_size]的Int32 tensor,表示过去所计算的索引。当use_past为True时,它用于增量预测。默认值为None。
|
||||
|
||||
|
|
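A hedged sketch of calling Transformer with the new None masks; the constructor arguments mirror the unit tests in this commit, while the None masks follow the documentation above, and the shapes are illustrative.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.nn.transformer import Transformer
>>> model = Transformer(batch_size=None, encoder_layers=1, decoder_layers=2, hidden_size=64,
...                     ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
>>> encoder_input = Tensor(np.ones((2, 20, 64)), mstype.float32)
>>> decoder_input = Tensor(np.ones((2, 10, 64)), mstype.float32)
>>> outputs = model(encoder_input, None, decoder_input, None, None)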
|
@ -3,7 +3,7 @@
|
|||
Transformer中的解码器模块,为多层堆叠的 `TransformerDecoderLayer` ,包括多头自注意力层、交叉注意力层和前馈层。
|
||||
|
||||
参数:
|
||||
- **batch_size** (int) - 表示输入Tensor的批次大小。
|
||||
- **batch_size** (int) - 表示增量预测时输入张量的批量大小,应该是正整数。当进行训练或预测时,该参数将不起作用,用户可将None传递给此参数。
|
||||
- **num_layers** (int) - 表示 `TransformerDecoderLayer` 的层数。
|
||||
- **hidden_size** (int) - 表示输入的隐藏大小。
|
||||
- **ffn_hidden_size** (int) - 表示前馈层中bottleneck的隐藏大小。
|
||||
|
@ -25,12 +25,12 @@
|
|||
|
||||
输入:
|
||||
- **hidden_stats** (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的输入tensor。
|
||||
- **attention_mask** (Tensor) - shape为[batch_size, seq_length, seq_length]的解码器的注意力掩码。
|
||||
- **attention_mask** (Tensor) - shape为[batch_size, seq_length, seq_length]的解码器的注意力掩码。或者为None,None表示将不会在self attention中的Softmax计算中引入掩码计算。
|
||||
- **encoder_output** (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的编码器的输出。
|
||||
|
||||
.. note::当网络位于最外层时,此参数不能通过None传递。默认值为None。
|
||||
|
||||
- **memory_mask** (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码,其中tgt_seq_length表示解码器的长度。注:当网络位于最外层时,此参数不能通过None传递。默认值为None。
|
||||
- **memory_mask** (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码,其中tgt_seq_length表示解码器的长度。或者为None,None表示将不会在cross attention中的Softmax计算中引入掩码计算。
|
||||
- **init_reset** (Tensor) - shape为[1]的bool tensor,用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。
|
||||
- **batch_valid_length** (Tensor) - shape为[batch_size]的Int32 tensor,表示过去所计算的索引。当use_past为True时,它用于增量预测。默认值为None。
|
||||
|
||||
|
|
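The same None-mask convention applies to the stacked decoder; a short illustrative sketch follows, with argument values chosen as assumptions rather than taken from the source.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.nn.transformer import TransformerDecoder
>>> model = TransformerDecoder(batch_size=None, num_layers=1, hidden_size=64, ffn_hidden_size=64,
...                            num_heads=2, src_seq_length=20, tgt_seq_length=10)
>>> hidden_stats = Tensor(np.ones((2, 10, 64)), mstype.float32)
>>> encoder_output = Tensor(np.ones((2, 20, 64)), mstype.float32)
>>> output, layer_present = model(hidden_stats, None, encoder_output, None)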
|
@ -3,7 +3,7 @@
|
|||
Transformer的解码器层。Transformer的解码器层上的单层的实现,包括自注意力层、交叉注意力层和前馈层。当encoder_output为None时,交叉注意力将无效。
|
||||
|
||||
参数:
|
||||
- **batch_size** (int) - 表示输入Tensor的批次大小。
|
||||
- **batch_size** (int) - 表示增量预测时输入张量的批量大小,应该是正整数。当进行训练或预测时,该参数将不起作用,用户可将None传递给此参数。
|
||||
- **hidden_size** (int) - 表示输入的隐藏大小。
|
||||
- **src_seq_length** (int) - 表示输入源序列长度。
|
||||
- **tgt_seq_length** (int) - 表示输入目标序列长度。
|
||||
|
@ -22,9 +22,9 @@
|
|||
|
||||
输入:
|
||||
- **hidden_stats** (Tensor) - shape为[batch_size, tgt_seq_length, hidden_size]或[batch_size * tgt_seq_length, hidden_size]的输入tensor。
|
||||
- **decoder_mask** (Tensor) - shape为[batch_size, src_seq_length, seq_length]的解码器的注意力掩码。
|
||||
- **decoder_mask** (Tensor) - shape为[batch_size, src_seq_length, seq_length]的解码器的注意力掩码。或者为None,None表示将不会在self attention中的Softmax计算中引入掩码计算。
|
||||
- **encoder_output** (Tensor) - shape为[batch_size, seq_length, hidden_size]或[batch_size * seq_length, hidden_size]的编码器的输出。注:当网络位于最外层时,此参数不能通过None传递。默认值为None。
|
||||
- **memory_mask** (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码,其中tgt_seq_length表示解码器的长度。注:当网络位于最外层时,此参数不能通过None传递。默认值为None。
|
||||
- **memory_mask** (Tensor) - shape为[batch, tgt_seq_length, src_seq_length]的交叉注意力的memory掩码,其中tgt_seq_length表示解码器的长度。或者为None,None表示将不会在cross attention中的Softmax计算中引入掩码计算。
|
||||
- **init_reset** (Tensor) - shape为[1]的bool tensor,用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。
|
||||
- **batch_valid_length** (Tensor) - shape为[batch_size]的Int32 tensor,表示过去所计算的索引。当use_past为True时,它用于增量预测。默认值为None。
|
||||
|
||||
|
|
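A minimal sketch matching the new unit test for this layer: both the decoder mask and the memory mask may be None, in which case neither self attention nor cross attention applies a mask.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.nn.transformer import TransformerDecoderLayer
>>> model = TransformerDecoderLayer(batch_size=None, hidden_size=64, ffn_hidden_size=64, num_heads=2,
...                                 src_seq_length=20, tgt_seq_length=10)
>>> decoder_input = Tensor(np.ones((2, 10, 64)), mstype.float32)
>>> encoder_output = Tensor(np.ones((2, 20, 64)), mstype.float32)
>>> output, layer_present = model(decoder_input, None, encoder_output, None)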
|
@ -3,7 +3,7 @@
|
|||
Transformer中的编码器模块,具有多层堆叠的 `TransformerEncoderLayer` ,包括多头自注意力层和前馈层。
|
||||
|
||||
参数:
|
||||
- **batch_size** (int) - 表示输入tensor的批次大小。
|
||||
- **batch_size** (int) - 表示增量预测时输入张量的批量大小,应该是正整数。当进行训练或预测时,该参数将不起作用,用户可将None传递给此参数。
|
||||
- **num_layers** (int) - 表示 `TransformerEncoderLayer` 的层。
|
||||
- **hidden_size** (int) - 表示输入的隐藏大小。
|
||||
- **ffn_hidden_size** (int) - 表示前馈层中bottleneck的隐藏大小。
|
||||
|
@ -24,7 +24,7 @@
|
|||
|
||||
输入:
|
||||
- **hidden_states** (Tensor) - Tensor。如果use_past为False或者is_first_iteration为True,shape为[batch_size, seq_length, hidden_size]或者[batch_size * seq_length, hidden_size]。否则,shape应为[batch_size, 1, hidden_size]。
|
||||
- **attention_mask** (Tensor) - Tensor,表示shape为[batch_size, seq_length, seq_length]的注意力掩码。
|
||||
- **attention_mask** (Tensor) - Tensor。use_past为False或者is_first_iteration为True时,表示shape为[batch_size, seq_length, seq_length]的注意力掩码,或者为None,None表示在Softmax计算中将不会进行掩码。否则,shape应为[batch_size, 1, hidden_size]。
|
||||
- **init_reset** (Tensor) - shape为[1]的bool tensor,用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。
|
||||
- **batch_valid_length** (Tensor) - shape为[batch_size]的Int32 tensor,表示过去所计算的索引。当use_past为True时,它用于增量预测。默认值为None。
|
||||
|
||||
|
|
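An illustrative sketch of the stacked encoder with attention_mask=None; the argument values are assumptions chosen to match the layer tests in this commit.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.nn.transformer import TransformerEncoder
>>> model = TransformerEncoder(batch_size=None, num_layers=2, hidden_size=8, ffn_hidden_size=64,
...                            seq_length=16, num_heads=2)
>>> hidden_states = Tensor(np.ones((2, 16, 8)), mstype.float32)
>>> output, layer_present = model(hidden_states, None)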
|
@ -3,7 +3,7 @@
|
|||
Transformer的编码器层。Transformer的编码器层上的单层的实现,包括多头注意力层和前馈层。
|
||||
|
||||
参数:
|
||||
- **batch_size** (int) - 表示输入Tensor的批次大小。
|
||||
- **batch_size** (int) - 表示增量预测时输入张量的批量大小,应该是正整数。当进行训练或预测时,该参数将不起作用,用户可将None传递给此参数。
|
||||
- **hidden_size** (int) - 表示输入的隐藏大小。
|
||||
- **seq_length** (int) - 表示输入序列长度。
|
||||
- **ffn_hidden_size** (int) - 表示前馈层中bottleneck的隐藏大小。
|
||||
|
@ -21,7 +21,7 @@
|
|||
|
||||
输入:
|
||||
- **x** (Tensor) - Float Tensor。如果use_past为False或者is_first_iteration为True,shape应为[batch_size, seq_length, hidden_size]或者[batch_size * seq_length, hidden_size]。否则,shape应为[batch_size, 1, hidden_size]。
|
||||
- **input_mask** (Tensor) - Float tensor。use_past为False或者is_first_iteration为True时,表示shape为[batch_size, seq_length, seq_length]的注意力掩码。否则,shape应为[batch_size, 1, hidden_size]。
|
||||
- **input_mask** (Tensor) - Float tensor。use_past为False或者is_first_iteration为True时,表示shape为[batch_size, seq_length, seq_length]的注意力掩码,或者为None,None表示在Softmax计算中将不会进行掩码。否则,shape应为[batch_size, 1, hidden_size]。
|
||||
- **init_reset** (Tensor) - shape为[1]的bool tensor,用于清除增量预测中使用的past key参数和past value参数。仅当use_past为True时有效。默认值为True。
|
||||
- **batch_valid_length** (Tensor) - shape为[batch_size]的Int32 tensor,表示过去所计算的索引。当use_past为True时,它用于增量预测。默认值为None。
|
||||
|
||||
|
|
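A minimal sketch following the new unit tests: the layer accepts a 2D input of shape (batch_size * seq_length, hidden_size) together with input_mask=None.

>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>> from mindspore.nn.transformer import TransformerEncoderLayer
>>> model = TransformerEncoderLayer(batch_size=None, hidden_size=8, ffn_hidden_size=64,
...                                 seq_length=16, num_heads=2)
>>> x = Tensor(np.ones((32, 8)), mstype.float32)
>>> output, layer_present = model(x, None)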
|
@ -37,7 +37,8 @@ from mindspore._checkparam import Validator
|
|||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig
|
||||
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig, MoEParallelConfig
|
||||
from mindspore import log as logger
|
||||
|
||||
__all__ = [
|
||||
"FixedSparseAttention"
|
||||
|
@ -154,6 +155,30 @@ class _LayerInputCheck:
|
|||
f"but got {input_shape[dim]}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
|
||||
"""
|
||||
Check whether the input shape is equal to the expected shape. The value on the 0-th dimension is viewed as the
batch size, and the batch size will not be checked.
|
||||
"""
|
||||
length, hidden = target_shape
|
||||
if isinstance(input_shape, tuple):
|
||||
input_shape = list(input_shape)
|
||||
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
|
||||
[len(target_shape), len(target_shape) + 1])
|
||||
if input_shape[-1] != hidden:
raise ValueError(f"For {func_name}, the last dimension of {param_name} shape must be {hidden}, "
f"but got the last dimension {input_shape[-1]} in {input_shape}.")
if input_shape[0] == 0:
raise ValueError(f"For {func_name}, the first dimension of {param_name} shape must be greater than 0, "
f"but got the first dimension {input_shape[0]} in {input_shape}.")
|
||||
if len(input_shape) == 2 and input_shape[0] % length != 0:
|
||||
raise ValueError(f"For {func_name}, the first dimension of {param_name} shape should be divisible "
|
||||
f"by {length}, "
|
||||
f"but got the first dimension {input_shape[0]} in {input_shape}.")
|
||||
return True
|
||||
|
||||
|
||||
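A small sketch, not part of the commit, of what the new batch-agnostic check accepts; it assumes the helper remains importable from mindspore.nn.transformer.layers and uses a target shape of (seq_length=20, hidden=16).

from mindspore.nn.transformer.layers import _LayerInputCheck

# The leading batch dimension is not validated; only the trailing dimensions are.
_LayerInputCheck.check_shape_equal_without_batch((8, 20, 16), "x", "Demo", (20, 16))   # 3D input passes
_LayerInputCheck.check_shape_equal_without_batch((160, 16), "x", "Demo", (20, 16))     # 2D passes since 160 % 20 == 0
# A shape such as (8, 20, 32) would raise ValueError because the last dimension is not 16.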
@constexpr
|
||||
def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
|
||||
|
@ -192,6 +217,11 @@ def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_valu
|
|||
_LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
|
||||
_LayerInputCheck.check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape)
|
||||
|
||||
|
||||
class _Dropout(nn.Cell):
|
||||
r"""
|
||||
A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
|
||||
|
@ -392,7 +422,6 @@ class _Linear(Cell):
|
|||
@_args_type_validator_check(in_channels=Validator.check_positive_int,
|
||||
out_channels=Validator.check_positive_int,
|
||||
has_bias=Validator.check_bool,
|
||||
activation=_valid_type_checks([type(None), str], "Linear"),
|
||||
transpose_b=Validator.check_bool,
|
||||
expert_num=Validator.check_positive_int,
|
||||
outer_batch=Validator.check_positive_int,
|
||||
|
@ -449,7 +478,10 @@ class _Linear(Cell):
|
|||
self.bias.parallel_optimizer = False
|
||||
self.bias_add = P.Add()
|
||||
self.act_name = activation
|
||||
self.activation = get_activation(activation) if isinstance(activation, str) else activation
|
||||
if callable(activation):
|
||||
self.activation = activation()
|
||||
else:
|
||||
self.activation = get_activation(activation) if isinstance(activation, str) else activation
|
||||
self.activation_flag = self.activation is not None
|
||||
self.dtype = compute_dtype
|
||||
self.cast = P.Cast()
|
||||
|
@ -491,7 +523,7 @@ class _Linear(Cell):
|
|||
self.matmul.shard(strategy_matmul)
|
||||
if self.has_bias:
|
||||
self.bias_add.shard(strategy_bias)
|
||||
if self.activation_flag:
|
||||
if self.activation_flag and isinstance(self.act_name, str):
|
||||
# some operations has many primitives, need to manually set the shard
|
||||
if self.act_name.lower() == "leakyrelu":
|
||||
self.activation.select_op.shard((strategy_activation[0], strategy_activation[0]))
|
||||
|
@ -506,7 +538,26 @@ class _Linear(Cell):
|
|||
"or auto parallel mode.")
|
||||
else:
|
||||
getattr(self.activation, self.act_name).shard(strategy_activation)
|
||||
|
||||
elif self.activation_flag and isinstance(self.activation, Cell):
|
||||
if hasattr(self.activation, 'shard') and strategy_activation:
|
||||
shard_tuple = strategy_activation[0]
|
||||
if len(shard_tuple) == 2:
|
||||
parallel_config = OpParallelConfig(data_parallel=shard_tuple[0],
|
||||
model_parallel=shard_tuple[1])
|
||||
elif len(shard_tuple) == 4:
|
||||
parallel_config = MoEParallelConfig(data_parallel=shard_tuple[0],
|
||||
expert_parallel=shard_tuple[1],
|
||||
model_parallel=shard_tuple[2])
|
||||
else:
|
||||
raise ValueError("The user-defined activation function currently only supports the case where the "
|
||||
"input policy is 2 or 4, so that relevant policies can be extracted from it."
|
||||
"To avoid this error, you need to add the function of extracting "
|
||||
"'ParallelConfig' or 'OpParallelConfig' from the incoming strategy_activation ")
|
||||
self.activation.shard(parallel_config)
|
||||
else:
|
||||
logger.warning(f"The user passed the custom defined activation function {self.activation_flag}. "
|
||||
f"If the user want to enable shard for the activation cell, "
|
||||
f"the user should set the shard for each primitives in the cell.")
|
||||
return self
|
||||
|
||||
|
||||
|
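For reference, a sketch of a custom activation Cell that the new branch above can shard; it mirrors the MyActivation cell used by the unit tests added in this commit, so the body below is illustrative rather than library code.

import mindspore
from mindspore.ops import operations as ops

class MyActivation(mindspore.nn.Cell):
    def __init__(self):
        super(MyActivation, self).__init__()
        self.add = ops.Add()

    def construct(self, x):
        return self.add(x, 0.1)

    def shard(self, parallel_config):
        # _Linear.shard() calls this with the OpParallelConfig or MoEParallelConfig
        # rebuilt from strategy_activation in the branch above.
        self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ()))

Such a cell is passed as a class, for example hidden_act=MyActivation on a transformer layer, and _Linear instantiates it because the activation argument is callable.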
|
|
@ -37,8 +37,9 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.log import _LogActionOnce
|
||||
from mindspore.nn.transformer.layers import _LayerNorm, _Linear, _check_input_shape, \
|
||||
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \
|
||||
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value
|
||||
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig,\
|
||||
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \
|
||||
_check_shape_equal_without_batch
|
||||
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
|
||||
_Config, _check_config, MoEParallelConfig
|
||||
from mindspore.nn.transformer.moe import default_moe_config, MoE, _check_moe_config
|
||||
|
||||
|
@ -399,7 +400,6 @@ class FeedForward(Cell):
|
|||
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
|
||||
ffn_hidden_size=Validator.check_positive_int,
|
||||
dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_act=_valid_type_checks([str], "FeedForward"),
|
||||
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
|
||||
"FeedForward"),
|
||||
parallel_config=_valid_type_checks([OpParallelConfig, MoEParallelConfig],
|
||||
|
@ -718,7 +718,9 @@ class MultiHeadAttention(Cell):
|
|||
if query, key and value tensor is same, then it will be self attention.
|
||||
|
||||
Args:
|
||||
batch_size(int): The batch size of the input tensor.
|
||||
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can simply pass None to
the argument.
|
||||
src_seq_length(int): The sequence length of the query vector.
|
||||
tgt_seq_length(int): The sequence length of the key and value vector.
|
||||
hidden_size(int): The hidden size of the input.
|
||||
|
@ -751,9 +753,9 @@ class MultiHeadAttention(Cell):
|
|||
- **value_tensor** (Tensor) - The value vector with shape (batch_size, tgt_seq_length, hidden_size) or
|
||||
(batch_size * tgt_seq_length, hidden_size), if the use_past is False or is_first_iteration=True.
|
||||
Otherwise, must be (batch_size, 1, hidden_size)
|
||||
- **attention_mask** (Tensor) - The attention mask matrix with shape (batch_size, src_seq_length,
|
||||
tgt_seq_length), if the use_past is False or is_first_iteration=True. Otherwise,
|
||||
must be (batch_size, 1, tgt_seq_length)
|
||||
- **attention_mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask
|
||||
matrix should be (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask
|
||||
in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length)
|
||||
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
|
||||
The past calculated key vector. Used for incremental prediction when the use_past is True.
|
||||
Default None.
|
||||
|
@ -783,7 +785,7 @@ class MultiHeadAttention(Cell):
|
|||
>>> from mindspore.nn.transformer import MultiHeadAttention
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> from mindspore import Tensor
|
||||
>>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
|
||||
>>> model = MultiHeadAttention(batch_size=None, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
|
||||
... num_heads=3)
|
||||
>>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
|
||||
>>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16)
|
||||
|
@ -830,8 +832,7 @@ class MultiHeadAttention(Cell):
|
|||
"""
|
||||
@_LogActionOnce(logger=logger, key='MultiHeadAttention',
|
||||
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
|
||||
@_args_type_validator_check(batch_size=Validator.check_positive_int,
|
||||
hidden_size=Validator.check_positive_int,
|
||||
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
|
||||
num_heads=Validator.check_positive_int,
|
||||
src_seq_length=Validator.check_positive_int,
|
||||
tgt_seq_length=Validator.check_positive_int,
|
||||
|
@ -860,10 +861,13 @@ class MultiHeadAttention(Cell):
|
|||
parallel_config=default_dpmp_config):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
self._is_ascend = context.get_context('device_target') in ["Ascend"]
|
||||
self.dp = parallel_config.data_parallel
|
||||
self.is_parallel_mode = _get_parallel_mode() in (
|
||||
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
if batch_size:
|
||||
Validator.check_positive_int(batch_size)
|
||||
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
|
||||
_check_config(parallel_config)
|
||||
self.is_parallel_mode = _get_parallel_mode() in (
|
||||
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.src_seq_length = src_seq_length
|
||||
self.tgt_seq_length = tgt_seq_length
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -883,11 +887,6 @@ class MultiHeadAttention(Cell):
|
|||
"'parallel_config.model_parallel', but got the num_heads is {} "
|
||||
"and the parallel_config.model_parallel is {}."
|
||||
.format(num_heads, parallel_config.model_parallel))
|
||||
if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
|
||||
raise ValueError("For 'MultiHeadAttention', the class variable 'batch_size' must be a multiple of "
|
||||
"'parallel_config.data_parallel', but got the batch_size is {} "
|
||||
"and the parallel_config.data_parallel is {}."
|
||||
.format(batch_size, parallel_config.data_parallel))
|
||||
self.is_first_iteration = True
|
||||
# Output layer
|
||||
self.projection = _Linear(in_channels=hidden_size,
|
||||
|
@ -961,8 +960,6 @@ class MultiHeadAttention(Cell):
|
|||
self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1)))
|
||||
else:
|
||||
_check_config(parallel_config)
|
||||
self.is_parallel_mode = _get_parallel_mode() in (
|
||||
ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
|
||||
self.src_seq_length = src_seq_length
|
||||
self.tgt_seq_length = tgt_seq_length
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -982,11 +979,6 @@ class MultiHeadAttention(Cell):
|
|||
"'parallel_config.model_parallel', but got the num_heads is {} "
|
||||
"and the parallel_config.model_parallel is {}."
|
||||
.format(num_heads, parallel_config.model_parallel))
|
||||
if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
|
||||
raise ValueError("For 'MultiHeadAttention', the class variable 'batch_size' must be a multiple of "
|
||||
"'parallel_config.data_parallel', but got the batch_size is {} "
|
||||
"and the parallel_config.data_parallel is {}."
|
||||
.format(batch_size, parallel_config.data_parallel))
|
||||
self.is_first_iteration = True
|
||||
# Output layer
|
||||
self.projection = _Linear(in_channels=hidden_size,
|
||||
|
@ -1086,10 +1078,16 @@ class MultiHeadAttention(Cell):
|
|||
value_past=None, batch_valid_length=None):
|
||||
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
|
||||
value_past, batch_valid_length)
|
||||
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
attention_mask)
|
||||
ori_shape = F.shape(query_tensor)
|
||||
batch_size = None
|
||||
if len(F.shape(query_tensor)) == 2:
|
||||
batch_size = F.shape(query_tensor)[0] // self.src_seq_length
|
||||
else:
|
||||
batch_size = F.shape(query_tensor)[0]
|
||||
query_tensor, key_tensor, value_tensor = self._convert_to_2d_tensor(query_tensor,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
attention_mask)
|
||||
ori_dtype = F.dtype(query_tensor)
|
||||
query_tensor = F.cast(query_tensor, self.dtype)
|
||||
key_tensor = F.cast(key_tensor, self.dtype)
|
||||
|
@ -1116,7 +1114,7 @@ class MultiHeadAttention(Cell):
|
|||
(batch_size, -1, self.n_head, self.size_per_head)),
|
||||
(0, 2, 1, 3))
|
||||
# support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
|
||||
if len(F.shape(attention_mask)) == 3:
|
||||
if attention_mask is not None and len(F.shape(attention_mask)) == 3:
|
||||
# expand attention mask from [bs, seq, seq] -> [bs, 1, seq, seq]
|
||||
attention_mask = self.expand_dims(attention_mask, 1)
|
||||
# key and value for current token(s)
|
||||
|
@ -1171,17 +1169,15 @@ class MultiHeadAttention(Cell):
|
|||
value_past=None, batch_valid_length=None):
|
||||
r"""Check inputs"""
|
||||
if not self.use_past or (self.use_past and self.is_first_iteration):
|
||||
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
|
||||
[[self.batch_size, self.src_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.src_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
|
||||
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
|
||||
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
|
||||
_check_shape_equal_without_batch(F.shape(query_tensor), "query_tensor", self.cls_name,
|
||||
[self.src_seq_length, self.hidden_size])
|
||||
_check_shape_equal_without_batch(F.shape(key_tensor), "key_tensor", self.cls_name,
|
||||
[self.tgt_seq_length, self.hidden_size])
|
||||
_check_shape_equal_without_batch(F.shape(value_tensor), "value_tensor", self.cls_name,
|
||||
[self.tgt_seq_length, self.hidden_size])
|
||||
if attention_mask is not None:
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[F.shape(attention_mask)[0], self.src_seq_length, self.tgt_seq_length])
|
||||
else:
|
||||
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
|
||||
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
|
||||
|
@ -1189,13 +1185,16 @@ class MultiHeadAttention(Cell):
|
|||
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
|
||||
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])
|
||||
if attention_mask is not None:
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])
|
||||
|
||||
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
|
||||
if attention_mask is not None:
|
||||
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
|
||||
self.cls_name)
|
||||
|
||||
key_is_tensor = isinstance(key_past, Tensor)
|
||||
value_is_tensor = isinstance(value_past, Tensor)
|
||||
|
@ -1228,7 +1227,8 @@ class MultiHeadAttention(Cell):
|
|||
key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
|
||||
value_shape = F.shape(value_tensor)
|
||||
value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
|
||||
return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape
|
||||
|
||||
return query_tensor, key_tensor, value_tensor
|
||||
|
||||
def _merge_heads(self, x):
|
||||
"""
|
||||
|
@ -1286,30 +1286,31 @@ class MultiHeadAttention(Cell):
|
|||
score = self.batch_matmul(query, key)
|
||||
|
||||
ori_dtype = P.DType()(score)
|
||||
score = P.Cast()(score, self.softmax_dtype)
|
||||
attention_scores = P.Cast()(score, self.softmax_dtype)
|
||||
|
||||
# for input size of (bs, 1) namely the second graph,
|
||||
# the shape of attention_mask matrix should be (bs, 1, 1, seq_length)
|
||||
if self.use_past and not self.is_first_iteration:
|
||||
# Calculate the current total token
|
||||
current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
|
||||
(F.shape(query)[0], 1, 1, self.seq_length),
|
||||
(1, 1, 1, 1)),
|
||||
0), mstype.float32), (1, 2, 3))
|
||||
# Get the precise position index
|
||||
index = self.sub1(F.cast(current_index, mstype.int32), 1)
|
||||
index = F.reshape(index, (-1, 1, 1))
|
||||
# Calculate the attention_mask matrix via the position index
|
||||
attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
|
||||
attention_mask = self.expand_dims(attention_mask, 2)
|
||||
if attention_mask is not None:
|
||||
if self.use_past and not self.is_first_iteration:
|
||||
# Calculate the current total token
|
||||
current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
|
||||
(F.shape(query)[0], 1, 1,
|
||||
self.seq_length),
|
||||
(1, 1, 1, 1)),
|
||||
0), mstype.float32), (1, 2, 3))
|
||||
# Get the precise position index
|
||||
index = self.sub1(F.cast(current_index, mstype.int32), 1)
|
||||
index = F.reshape(index, (-1, 1, 1))
|
||||
# Calculate the attention_mask matrix via the position index
|
||||
attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32)
|
||||
attention_mask = self.expand_dims(attention_mask, 2)
|
||||
# Minus 10000 for the position where masked to exclude them from softmax
|
||||
multiplu_out = self.sub(
|
||||
P.Cast()(F.tuple_to_array((1.0,)), P.DType()(attention_scores)),
|
||||
P.Cast()(attention_mask, P.DType()(attention_scores)))
|
||||
|
||||
# Minus 10000 for the position where masked to exclude them from softmax
|
||||
multiplu_out = self.sub(
|
||||
P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
|
||||
P.Cast()(attention_mask, P.DType()(score)))
|
||||
|
||||
adder = self.mul(multiplu_out, self.multiply_data)
|
||||
attention_scores = self.add(adder, score)
|
||||
adder = self.mul(multiplu_out, self.multiply_data)
|
||||
attention_scores = self.add(adder, attention_scores)
|
||||
|
||||
# attention probs
|
||||
attention_probs = self._softmax(attention_scores)
|
||||
|
@ -1328,7 +1329,9 @@ class TransformerEncoderLayer(Cell):
|
|||
encoder layer, including multihead attention and feedforward layer.
|
||||
|
||||
Args:
|
||||
batch_size(int): The batch size of the input tensor.
|
||||
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can simply pass None to
the argument.
|
||||
hidden_size(int): The hidden size of the input.
|
||||
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
|
||||
num_heads(int): The number of the heads.
|
||||
|
@ -1362,8 +1365,9 @@ class TransformerEncoderLayer(Cell):
|
|||
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
|
||||
should be [batch_size, 1, hidden_size]
|
||||
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length],
|
||||
if the use_past is False or is_first_iteration=True. Otherwise, should be [batch_size, 1, hidden_size]
|
||||
- **input_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True,
the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
|
||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
||||
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
|
||||
|
@ -1430,14 +1434,12 @@ class TransformerEncoderLayer(Cell):
|
|||
"""
|
||||
@_LogActionOnce(logger=logger, key='TransformerEncoderLayer',
|
||||
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
|
||||
@_args_type_validator_check(batch_size=Validator.check_positive_int,
|
||||
hidden_size=Validator.check_positive_int,
|
||||
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
|
||||
num_heads=Validator.check_positive_int,
|
||||
ffn_hidden_size=Validator.check_positive_int,
|
||||
seq_length=Validator.check_positive_int,
|
||||
attention_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_act=_valid_type_checks([str], "TransformerEncoderLayer"),
|
||||
post_layernorm_residual=Validator.check_bool,
|
||||
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
|
||||
"TransformerEncoderLayer"),
|
||||
|
@ -1465,6 +1467,9 @@ class TransformerEncoderLayer(Cell):
|
|||
moe_config=default_moe_config,
|
||||
parallel_config=default_dpmp_config):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
if batch_size or use_past:
|
||||
Validator.check_positive_int(batch_size)
|
||||
self.batch_size = batch_size
|
||||
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
|
||||
_check_config(parallel_config)
|
||||
if num_heads % parallel_config.model_parallel != 0:
|
||||
|
@ -1488,7 +1493,6 @@ class TransformerEncoderLayer(Cell):
|
|||
self.use_past = use_past
|
||||
self.seq_length = seq_length
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_size = batch_size
|
||||
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
|
||||
|
@ -1564,7 +1568,6 @@ class TransformerEncoderLayer(Cell):
|
|||
self.use_past = use_past
|
||||
self.seq_length = seq_length
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_size = batch_size
|
||||
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
|
||||
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
|
@ -1623,7 +1626,7 @@ class TransformerEncoderLayer(Cell):
|
|||
raise RuntimeError(f"The {self.cls_name} only support sharding propagation or "
|
||||
f"semi-auto parallel mode now.")
|
||||
|
||||
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
|
||||
def construct(self, x, input_mask=None, init_reset=True, batch_valid_length=None):
|
||||
self._check_input(x, input_mask, init_reset, batch_valid_length)
|
||||
x_shape = F.shape(x)
|
||||
x = F.reshape(x, (-1, x_shape[-1]))
|
||||
|
@ -1699,17 +1702,19 @@ class TransformerEncoderLayer(Cell):
|
|||
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
|
||||
r"""Check inputs"""
|
||||
if not self.use_past or (self.use_past and self.is_first_iteration):
|
||||
_check_shape_equal(F.shape(x), "x", self.cls_name,
|
||||
[[self.batch_size, self.seq_length, self.hidden_size],
|
||||
[self.batch_size * self.seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
|
||||
[self.batch_size, self.seq_length, self.seq_length])
|
||||
_check_shape_equal_without_batch(F.shape(x), "x", self.cls_name,
|
||||
[self.seq_length, self.hidden_size])
|
||||
if input_mask is not None:
|
||||
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
|
||||
[F.shape(input_mask)[0], self.seq_length, self.seq_length])
|
||||
else:
|
||||
_check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
|
||||
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
|
||||
[self.batch_size, 1, self.seq_length])
|
||||
if input_mask is not None:
|
||||
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
|
||||
[F.shape(input_mask)[0], 1, self.seq_length])
|
||||
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
|
||||
if input_mask is not None:
|
||||
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
|
||||
|
||||
init_reset_is_tensor = isinstance(init_reset, Tensor)
|
||||
init_reset_is_default = init_reset is True
|
||||
|
@ -1738,7 +1743,9 @@ class TransformerDecoderLayer(Cell):
|
|||
hidden_size(int): The hidden size of the input.
|
||||
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
|
||||
num_heads(int): The number of the heads.
|
||||
batch_size(int): The batch size of the input tensor.
|
||||
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can simply pass None to
the argument.
|
||||
src_seq_length(int): The input source sequence length.
|
||||
tgt_seq_length(int): The input target sequence length.
|
||||
attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1.
|
||||
|
@ -1764,13 +1771,13 @@ class TransformerDecoderLayer(Cell):
|
|||
- **hidden_stats** (Tensor) - The input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
|
||||
[batch_size * tgt_seq_length, hidden_size].
|
||||
- **decoder_mask** (Tensor) - The attention mask for decoder with shape [batch_size, src_seq_length,
|
||||
seq_length].
|
||||
seq_length] or None. None means there will be no mask in softmax computation in self attention.
|
||||
- **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
|
||||
or [batch_size * seq_length, hidden_size].
|
||||
Note this args can not be passed by None when the net is in outermost layer. Default None.
|
||||
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||
src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by
|
||||
None when the net is in outermost layer. Default None.
|
||||
src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
|
||||
means there will be no mask in softmax computation in cross attention. Default None.
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
|
||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
||||
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
|
||||
|
@ -1815,15 +1822,13 @@ class TransformerDecoderLayer(Cell):
|
|||
"""
|
||||
@_LogActionOnce(logger=logger, key='TransformerDecoderLayer',
|
||||
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
|
||||
@_args_type_validator_check(batch_size=Validator.check_positive_int,
|
||||
hidden_size=Validator.check_positive_int,
|
||||
@_args_type_validator_check(hidden_size=Validator.check_positive_int,
|
||||
num_heads=Validator.check_positive_int,
|
||||
ffn_hidden_size=Validator.check_positive_int,
|
||||
src_seq_length=Validator.check_positive_int,
|
||||
tgt_seq_length=Validator.check_positive_int,
|
||||
attention_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_act=_valid_type_checks([str], "TransformerDecoderLayer"),
|
||||
post_layernorm_residual=Validator.check_bool,
|
||||
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
|
||||
"TransformerDecoderLayer"),
|
||||
|
@ -1854,6 +1859,8 @@ class TransformerDecoderLayer(Cell):
|
|||
_check_moe_config(moe_config, parallel_config)
|
||||
self.use_moe = (moe_config.expert_num > 1)
|
||||
config_to_attention = parallel_config.dpmp if self.use_moe else parallel_config
|
||||
if batch_size or use_past:
|
||||
Validator.check_positive_int(batch_size)
|
||||
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
|
||||
_check_config(parallel_config)
|
||||
if num_heads % parallel_config.model_parallel != 0:
|
||||
|
@ -2144,28 +2151,30 @@ class TransformerDecoderLayer(Cell):
|
|||
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
|
||||
r"""Check inputs"""
|
||||
if not self.use_past or (self.use_past and self.is_first_iteration):
|
||||
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
|
||||
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
|
||||
_check_shape_equal_without_batch(F.shape(hidden_states), "hidden_states", self.cls_name,
|
||||
[self.tgt_seq_length, self.hidden_size])
|
||||
if attention_mask is not None:
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[F.shape(attention_mask)[0], self.tgt_seq_length, self.tgt_seq_length])
|
||||
|
||||
else:
|
||||
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
|
||||
[self.batch_size, 1, self.hidden_size])
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[self.batch_size, 1, self.tgt_seq_length])
|
||||
if attention_mask is not None:
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[self.batch_size, 1, self.tgt_seq_length])
|
||||
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
|
||||
if attention_mask is not None:
|
||||
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
|
||||
self.cls_name)
|
||||
if encoder_output is not None:
|
||||
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
|
||||
[[self.batch_size, self.src_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.src_seq_length, self.hidden_size]])
|
||||
_check_shape_equal_without_batch(F.shape(encoder_output), "encoder_output", self.cls_name,
|
||||
[self.src_seq_length, self.hidden_size])
|
||||
_check_input_dtype(F.dtype(encoder_output), "encoder_output",
|
||||
[mstype.float32, mstype.float16], self.cls_name)
|
||||
if memory_mask is not None:
|
||||
_check_shape_equal(F.shape(memory_mask), "memory_mask", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.src_seq_length])
|
||||
_check_shape_equal_without_batch(F.shape(memory_mask), "memory_mask", self.cls_name,
|
||||
[self.tgt_seq_length, self.src_seq_length])
|
||||
_check_input_dtype(F.dtype(memory_mask), "memory_mask",
|
||||
[mstype.float32, mstype.float16], self.cls_name)
|
||||
|
||||
|
@ -2240,7 +2249,9 @@ class TransformerEncoder(Cell):
|
|||
attention and feedforward layer.
|
||||
|
||||
Args:
|
||||
batch_size(int): The batch size of the input tensor.
|
||||
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can simply pass None to
the argument.
|
||||
num_layers(int): The layers of the `TransformerEncoderLayer`
|
||||
hidden_size(int): The hidden size of the input.
|
||||
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
|
||||
|
@ -2284,7 +2295,9 @@ class TransformerEncoder(Cell):
|
|||
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise,
|
||||
should be [batch_size, 1, hidden_size].
|
||||
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
|
||||
- **attention_mask** (Tensor) - Float Tensor. If the use_past is False or is_first_iteration=True,
the attention mask matrix should be [batch_size, seq_length, seq_length], or None. None means there will
be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size]
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
|
||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
||||
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
|
||||
|
@ -2361,7 +2374,6 @@ class TransformerEncoder(Cell):
|
|||
offset=Validator.check_non_negative_int,
|
||||
attention_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_act=_valid_type_checks([str], "TransformerEncoder"),
|
||||
post_layernorm_residual=Validator.check_bool,
|
||||
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
|
||||
"TransformerEncoder"),
|
||||
|
@ -2490,7 +2502,9 @@ class TransformerDecoder(Cell):
|
|||
|
||||
Args:
|
||||
num_layers(int): The layers of the `TransformerDecoderLayer`.
|
||||
batch_size(int): The batch size of the input tensor.
|
||||
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can simply pass None to
the argument.
|
||||
hidden_size(int): The hidden size of the input.
|
||||
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
|
||||
src_seq_length(int): The input source sequence length.
|
||||
|
@ -2528,13 +2542,14 @@ class TransformerDecoder(Cell):
|
|||
- **hidden_stats** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size]
|
||||
- **attention_mask** (Tensor) - The attention mask for decoder with shape
|
||||
[batch_size, seq_length, seq_length]
|
||||
[batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax
|
||||
computation in self attention.
|
||||
- **encoder_output** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
|
||||
or [batch_size * seq_length, hidden_size]. Note this args can not be passed by None when the net is in
|
||||
outermost layer. Default None.
|
||||
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||
src_seq_length] where tgt_seq_length is the length of the decoder. Note this args can not be passed by
|
||||
None when the net is in outermost layer. Default None.
|
||||
src_seq_length] where tgt_seq_length is the length of the decoder. The user can also pass None. None
|
||||
means there will be no mask in softmax computation in cross attention. Default None.
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
|
||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
||||
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
|
||||
|
@ -2591,7 +2606,6 @@ class TransformerDecoder(Cell):
|
|||
offset=Validator.check_non_negative_int,
|
||||
attention_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_act=_valid_type_checks([str], "TransformerDecoder"),
|
||||
post_layernorm_residual=Validator.check_bool,
|
||||
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
|
||||
"TransformerDecoder"),
|
||||
|
@ -2736,7 +2750,9 @@ class Transformer(Cell):
|
|||
|
||||
Args:
|
||||
hidden_size(int): The hidden size of the input.
|
||||
batch_size(int): The batch size of the input tensor.
|
||||
batch_size(int): The batch size of the input tensor when doing incremental prediction. Should be a positive
value. When doing training or prediction, the argument has no effect and the user can simply pass None to
the argument.
|
||||
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
|
||||
src_seq_length(int): The seq_length of the encoder's input tensor.
|
||||
tgt_seq_length(int): The seq_length of the decoder's input tensor.
|
||||
|
@ -2772,15 +2788,17 @@ class Transformer(Cell):
|
|||
- **encoder_inputs** (Tensor) - The input tensor with shape [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size].
|
||||
- **encoder_masks** (Tensor) - The attention mask for the encoder with shape
|
||||
[batch_size, seq_length, seq_length].
|
||||
[batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
|
||||
in self attention of the encoder module.
|
||||
- **decoder_inputs** (Tensor) - The output of the encoder with shape [batch_size, seq_length, hidden_size]
|
||||
or [batch_size * seq_length, hidden_size], this should be none if the decoder layer is 0.
|
||||
- **decoder_masks** (Tensor) - The attention mask for decoder with shape
|
||||
[batch_size, seq_length, seq_length]
|
||||
[batch_size, seq_length, seq_length] or None. None means there will be no mask in softmax computation
|
||||
in self attention of the decoder module.
|
||||
- **memory_mask** (Tensor) - The memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||
src_seq_length]
|
||||
where tgt_seq_length is the length of the decoder. The output of the encoder with shape [batch_size,
|
||||
seq_length, hidden_size], this should be none if the decoder layer is 0.
|
||||
seq_length, hidden_size]. This should be None if the decoder layer is 0 or the user wants no mask.
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
|
||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
||||
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
|
||||
|
@ -2854,7 +2872,6 @@ class Transformer(Cell):
|
|||
tgt_seq_length=Validator.check_positive_int,
|
||||
attention_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_dropout_rate=Validator.check_non_negative_float,
|
||||
hidden_act=_valid_type_checks([str], "Transformer"),
|
||||
post_layernorm_residual=Validator.check_bool,
|
||||
layernorm_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
|
||||
"Transformer"),
|
||||
|
|
|
@ -18,15 +18,28 @@ import shutil
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype
|
||||
from mindspore.ops import operations as ops
|
||||
from mindspore.parallel.nn import MultiHeadAttention, FeedForward, TransformerEncoderLayer, TransformerEncoder, \
|
||||
TransformerDecoder, TransformerDecoderLayer, Transformer, CrossEntropyLoss, AttentionMask, FixedSparseAttention
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
|
||||
|
||||
class MyActivation(mindspore.nn.Cell):
|
||||
def __init__(self):
|
||||
super(MyActivation, self).__init__()
|
||||
self.add = ops.Add()
|
||||
|
||||
def construct(self, x):
|
||||
|
||||
return self.add(x, 0.1)
|
||||
|
||||
def shard(self, parallel_config):
|
||||
self.add.shard(((parallel_config.data_parallel, parallel_config.model_parallel), ()))
|
||||
|
||||
|
||||
def test_transformer_encoder_only():
|
||||
model = Transformer(batch_size=2,
|
||||
src_seq_length=20,
|
||||
|
@ -208,18 +221,42 @@ def test_multihead_attention():
|
|||
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
|
||||
|
||||
|
||||
def test_multihead_attention_wrong_batch():
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, None, 4])
|
||||
def test_multihead_attention_wrong_batch(batch_size):
|
||||
"""
|
||||
Feature: Test MultiHeadAttention with wrong batch for training
|
||||
Description: Test the batch size to be any int or None
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = MultiHeadAttention(hidden_size=15,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=20,
|
||||
batch_size=2,
|
||||
batch_size=batch_size,
|
||||
num_heads=3)
|
||||
from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
|
||||
to_tensor = Tensor(np.ones((3, 20, 15)), dtype.float16)
|
||||
attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
|
||||
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('from_tensor,to_tensor', [(Tensor(np.ones((20, 15)), dtype.float32),
|
||||
Tensor(np.ones((20, 15)), dtype.float16)),
|
||||
(Tensor(np.ones((3, 20, 15)), dtype.float32),
|
||||
Tensor(np.ones((3, 20, 15)), dtype.float16))])
|
||||
def test_multihead_attention_no_mask_2d_or_3d_shape(from_tensor, to_tensor):
|
||||
"""
|
||||
Feature: Test MultiHeadAttention no mask
|
||||
Description: Test MultiHeadAttention no mask and 2d as inputs.
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = MultiHeadAttention(hidden_size=15,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=20,
|
||||
batch_size=None,
|
||||
num_heads=3)
|
||||
|
||||
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, None)
|
||||
|
||||
|
||||
def test_multihead_attention_fp32_dtype():
|
||||
|
@ -240,6 +277,154 @@ def test_multihead_attention_fp32_dtype():
|
|||
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, None, 4])
|
||||
def test_transformerencoder_wrong_batch(batch_size):
|
||||
"""
|
||||
Feature: Test TransformerEncoderLayer with wrong batch for training
|
||||
Description: Test the batch size to be any int or None
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = TransformerEncoderLayer(batch_size=batch_size, hidden_size=8, ffn_hidden_size=64, seq_length=16,
|
||||
num_heads=2)
|
||||
encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)
|
||||
|
||||
model(encoder_input_value, encoder_input_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('attention_mask', [Tensor(np.ones((2, 16, 16)), dtype.float16),
|
||||
None])
|
||||
def test_transformerencoder_no_mask(attention_mask):
|
||||
"""
|
||||
Feature: Test TransformerEncoderLayer with no mask
|
||||
Description: Test the attention mask is None
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = TransformerEncoderLayer(batch_size=None, hidden_size=8, ffn_hidden_size=64, seq_length=16,
|
||||
num_heads=2)
|
||||
encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
|
||||
|
||||
model(encoder_input_value, attention_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shape', [(2, 16, 8), (32, 8)])
|
||||
def test_transformerencoder_2d_or_3d_shape(shape):
|
||||
"""
|
||||
Feature: Test TransformerEncoderLayer with 2d or 3d inputs
|
||||
Description: Test the attention mask is None
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = TransformerEncoderLayer(batch_size=None, hidden_size=8, ffn_hidden_size=64, seq_length=16,
|
||||
num_heads=2)
|
||||
encoder_input_value = Tensor(np.ones(shape), dtype.float32)
|
||||
|
||||
model(encoder_input_value, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('batch_size', [1, 2, None, 4])
|
||||
def test_transformerdecoder_wrong_batch(batch_size):
|
||||
"""
|
||||
Feature: Test TransformerDecoderLayer with wrong batch for training
|
||||
Description: Test the batch size to be any int or None
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = TransformerDecoderLayer(batch_size=batch_size, hidden_size=64, ffn_hidden_size=64, num_heads=2,
|
||||
src_seq_length=20, tgt_seq_length=10)
|
||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
|
||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
|
||||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
|
||||
memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
|
||||
model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('decoder_input_mask,memory_mask',
|
||||
[(None, None), (Tensor(np.ones((2, 10, 10)), dtype.float16), None),
|
||||
(None, Tensor(np.ones((2, 10, 20)), dtype.float16))])
|
||||
def test_transformerdecoder_mask(decoder_input_mask, memory_mask):
|
||||
"""
|
||||
Feature: Test TransformerDecoderLayer with empty mask
|
||||
Description: Test the mask is None
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = TransformerDecoderLayer(batch_size=4, hidden_size=64, ffn_hidden_size=64, num_heads=2,
|
||||
src_seq_length=20, tgt_seq_length=10)
|
||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
|
||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
|
||||
model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
|
||||
|
||||
|
||||
def test_transformerdecoder_custom_activation():
|
||||
"""
|
||||
Feature: Test TransformerDecoderLayer custom activation
|
||||
Description: Test TransformerDecoderLayer custom activation
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = TransformerDecoderLayer(batch_size=4, hidden_size=64, ffn_hidden_size=64, num_heads=2,
|
||||
hidden_act=MyActivation,
|
||||
src_seq_length=20, tgt_seq_length=10)
|
||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
|
||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
|
||||
model(decoder_input_value, None, encoder_input_value, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('encoder_shape,decoder_shape', [((2, 20, 64), (2, 10, 64)),
|
||||
((20, 64), (10, 64))])
|
||||
def test_transformerdecoder_2d_or_3d_shape(encoder_shape, decoder_shape):
|
||||
"""
|
||||
Feature: Test TransformerDecoderLayer with 2d or 3d inputs
|
||||
Description: Test the attention mask is None
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = TransformerDecoderLayer(batch_size=None, hidden_size=64, ffn_hidden_size=64, num_heads=2,
|
||||
src_seq_length=20, tgt_seq_length=10)
|
||||
encoder_input_value = Tensor(np.ones(encoder_shape), dtype.float32)
|
||||
decoder_input_value = Tensor(np.ones(decoder_shape), dtype.float32)
|
||||
model(decoder_input_value, None, encoder_input_value, None)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('hidden_act', [MyActivation, None, "relu"])
|
||||
def test_transformer_hidden_act(hidden_act):
|
||||
"""
|
||||
Feature: Test Transformer hidden activation with activation or None
|
||||
Description: Test the transformer hidden activation
|
||||
Expectation: No exception
|
||||
"""
|
||||
model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
|
||||
hidden_act=hidden_act,
|
||||
ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
|
||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
|
||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
|
||||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
|
||||
memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
|
||||
model(encoder_input_value, encoder_input_mask, decoder_input_value,
|
||||
decoder_input_mask, memory_mask)
|
||||
|
||||
|
||||
def test_transformer_hidden_act_with_wrong_hidden_act_wrong_lambda_func():
|
||||
"""
|
||||
Feature: Test Transformer hidden activation with a wrong lambda activation
Description: Test the transformer hidden activation with a lambda function, which is not a supported type
Expectation: Raise TypeError
|
||||
"""
|
||||
with pytest.raises(TypeError):
|
||||
Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
|
||||
hidden_act=lambda x: x,
|
||||
ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
|
||||
|
||||
|
||||
def test_transformer_hidden_act_with_wrong_hidden_act_wrong_str():
|
||||
"""
|
||||
Feature: Test Transformer hidden activation with wrong activation
|
||||
Description: Test the transformer hidden activation
|
||||
Expectation: Raise KeyError
|
||||
"""
|
||||
with pytest.raises(KeyError):
|
||||
Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64,
|
||||
hidden_act="no_string",
|
||||
ffn_hidden_size=64, src_seq_length=20, tgt_seq_length=10)
|
||||
|
||||
|
||||
def test_feedforward_layer():
|
||||
model = FeedForward(hidden_size=15,
|
||||
ffn_hidden_size=30,
|
||||
|
|
|
@ -624,17 +624,15 @@ def test_transformer_wrong_dp_no_error():
|
|||
def test_transformer_wrong_semi_auto_dp_error():
|
||||
"""
|
||||
Feature: test transformer api
|
||||
Description: Test transformer exception scene
|
||||
Description: Test that the batch size is no longer checked against parallel_config.data_parallel
|
||||
Expectation: No exception raised.
|
||||
"""
|
||||
set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
|
||||
check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
|
||||
with pytest.raises(ValueError):
|
||||
net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
|
||||
decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
|
||||
parallel_config=check_config)
|
||||
del net
|
||||
Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
|
||||
decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
|
||||
parallel_config=check_config)
|
||||
|
||||
|
||||
def test_encoder():
|
||||
|
|