fix docs and remove assert of Transformer apis

This commit is contained in:
lvyufeng 2023-02-25 16:10:58 +08:00
parent dc7c03cc46
commit a48c647f5b
9 changed files with 355 additions and 313 deletions

View File

@ -14,37 +14,37 @@ mindspore.nn.MultiheadAttention
参数:
- **embed_dim** (int) - 模型的总维数。
- **num_heads** (int) - 并行注意力头的数量。``num_heads`` 需要能够被 ``embed_dim`` 整除(每个头的维数为 ``embed_dim // num_heads``)。
- **dropout** (float) - 应用到输入 ``attn_output_weights`` 上的随机丢弃比例. 默认: ``0.0`` (不丢弃)。
- **has_bias** (bool) - 是否给输入、输出投射层添加偏置。默认: ``True``
- **add_bias_kv** (bool) - 是否给key、value序列的0维添加偏置。默认 ``False``
- **add_zero_attn** (bool) - 是否给key、value序列的1维添加0。默认 ``False``
- **kdim** (int) - key的总特征数。默认 ``None`` (即 ``kdim=embed_dim``)。
- **vdim** (int) - value的总特征数。默认``None`` (即 ``vdim=embed_dim``)。
- **batch_first** (bool) - 如果为 ``True``则输入输出Tensor的shape为 (batch, seq, feature)否则shape为(seq, batch, feature)。 默认 ``False``
- **num_heads** (int) - 并行注意力头的数量。`num_heads` 需要能够被 `embed_dim` 整除(每个头的维数为 `embed_dim // num_heads`)。
- **dropout** (float) - 应用到输入 `attn_output_weights` 上的随机丢弃比例. 默认值: ``0.0``
- **has_bias** (bool) - 是否给输入、输出投射层添加偏置。默认 ``True``
- **add_bias_kv** (bool) - 是否给key、value序列的0维添加偏置。默认 ``False``
- **add_zero_attn** (bool) - 是否给key、value序列的1维添加0。默认 ``False``
- **kdim** (int) - key的总特征数。默认值: ``None`` (即 `kdim=embed_dim`)。
- **vdim** (int) - value的总特征数。默认值:``None`` (即 `vdim=embed_dim`)。
- **batch_first** (bool) - 如果为 ``True``则输入输出Tensor的shape为 (batch, seq, feature)否则shape为(seq, batch, feature)。 默认 ``False``
输入:
- **query** (Tensor) - Query矩阵。当输入非Batch数据时Shape为 :math:`(L, E_q)` 。当输入Batch数据参数 ``batch_first=False``Shape为 :math:`(L, N, E_q)`
``batch_first=True``Shape为 :math:`(N, L, E_q)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size:math:`E_q` 为Query矩阵的维数 ``embed_dim``
- **query** (Tensor) - Query矩阵。当输入非Batch数据时Shape为 :math:`(L, E_q)` 。当输入Batch数据参数 `batch_first=False`Shape为 :math:`(L, N, E_q)`
`batch_first=True`Shape为 :math:`(N, L, E_q)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size:math:`E_q` 为Query矩阵的维数 `embed_dim`
注意力机制通过Query与Key-Value运算以生成最终输出。详情请见"Attention Is All You Need"。
- **key** (Tensor) - Key矩阵。当输入非Batch数据时Shape为 :math:`(S, E_k)` 。当输入Batch数据参数 ``batch_first=False``Shape为 :math:`(S, N, E_k)`
``batch_first=True``Shape为 :math:`(N, S, E_k)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size:math:`E_k` 为Key矩阵的维数 ``kdim``。详情请见:"Attention Is All You Need"。
- **value** (Tensor) - Value矩阵。当输入非Batch数据时Shape为 :math:`(S, E_v)` 。当输入Batch数据参数 ``batch_first=False``Shape为 :math:`(S, N, E_v)`
``batch_first=True``Shape为 :math:`(N, S, E_v)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size:math:`E_v` 为Key矩阵的维数 ``vdim``。详情请见:"Attention Is All You Need"。
- **key_padding_mask** (Tensor) - 如果指定此值则表示Shape为 :math:`(N, S)`的掩码将被用于 ``key``。当输入非Batch数据时Shape为 :math:`(S)`
如果输入Tensor为Bool类型``key`` 中对应为 ``True`` 的位置将在Attention计算时被忽略。如果输入Tensor为Float类型则将直接与 ``key`` 相加。默认``None``
- **need_weights** (bool) - 是否需要返回 ``attn_output_weights``,如果为 ``True``,则输出包含 ``attn_output_weights``。默认:``True``
- **attn_mask** (Tensor) - 如果指定此值则表示Shape为 :math:`(L, S)`:math:`(N\cdot\text{num\_heads}, L, S)` 的掩码将被用于Attention计算。其中 :math:`N` 为batch size
:math:`L` 为目标序列长度,:math:`S` 为源序列长度。如果输入为2维矩阵则将自动沿batch维广播至3维矩阵。若为3维矩阵则允许沿batch维使用不同的掩码。如果输入Tensor为Bool类型则值为 ``True`` 对应位置允许被注意力计算。如果输入Tensor为Float类型则将直接与注意力权重相加。默认``None``
- **average_attn_weights** (bool) - 如果为 ``True`` 则返回值 ``attn_weights`` 为注意力头的平均值。如果为 ``False``,则 ``attn_weights`` 分别返回每个注意力头的值。
本参数仅在 ``need_weights=True`` 时生效。默认 ``True``
- **key** (Tensor) - Key矩阵。当输入非Batch数据时Shape为 :math:`(S, E_k)` 。当输入Batch数据参数 `batch_first=False`Shape为 :math:`(S, N, E_k)`
`batch_first=True`Shape为 :math:`(N, S, E_k)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size:math:`E_k` 为Key矩阵的维数 `kdim`。详情请见:"Attention Is All You Need"。
- **value** (Tensor) - Value矩阵。当输入非Batch数据时Shape为 :math:`(S, E_v)` 。当输入Batch数据参数 `batch_first=False`Shape为 :math:`(S, N, E_v)`
`batch_first=True`Shape为 :math:`(N, S, E_v)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size:math:`E_v` 为Key矩阵的维数 `vdim`。详情请见:"Attention Is All You Need"。
- **key_padding_mask** (Tensor, optional) - 如果指定此值则表示Shape为 :math:`(N, S)`的掩码将被用于 `key`。当输入非Batch数据时Shape为 :math:`(S)`
如果输入Tensor为Bool类型`key` 中对应为 ``True`` 的位置将在Attention计算时被忽略。如果输入Tensor为Float类型则将直接与 `key` 相加。默认值``None``
- **need_weights** (bool) - 是否需要返回 `attn_output_weights`,如果为 ``True``,则输出包含 `attn_output_weights`。默认``True``
- **attn_mask** (Tensor, optional) - 如果指定此值则表示Shape为 :math:`(L, S)`:math:`(N\cdot\text{num\_heads}, L, S)` 的掩码将被用于Attention计算。其中 :math:`N` 为batch size
:math:`L` 为目标序列长度,:math:`S` 为源序列长度。如果输入为2维矩阵则将自动沿batch维广播至3维矩阵。若为3维矩阵则允许沿batch维使用不同的掩码。如果输入Tensor为Bool类型则值为 ``True`` 对应位置允许被注意力计算。如果输入Tensor为Float类型则将直接与注意力权重相加。默认``None``
- **average_attn_weights** (bool) - 如果为 ``True`` 则返回值 `attn_weights` 为注意力头的平均值。如果为 ``False``,则 ``attn_weights`` 分别返回每个注意力头的值。
本参数仅在 `need_weights=True` 时生效。默认值 ``True``
输出:
Tuple表示一个包含(`attn_output`, `attn_output_weights`)的元组。
- **attn_output** - 注意力机制的输出。当输入非Batch数据时Shape为 :math:`(L, E)` 。当输入Batch数据 参数 ``batch_first=False``Shape为 :math:`(L, N, E)`
``batch_first=True``Shape为 :math:`(N, L, E)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size :math:`E` 为模型的总维数 ``embed_dim``
- **attn_output_weights** - 仅当 ``need_weights=True`` 时返回。如果 ``average_attn_weights=True``,则返回值 ``attn_weights`` 为注意力头的平均值。当输入非Batch数据时
- **attn_output** - 注意力机制的输出。当输入非Batch数据时Shape为 :math:`(L, E)` 。当输入Batch数据 参数 `batch_first=False`Shape为 :math:`(L, N, E)`
`batch_first=True`Shape为 :math:`(N, L, E)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size :math:`E` 为模型的总维数 `embed_dim`
- **attn_output_weights** - 仅当 ``need_weights=True`` 时返回。如果 `average_attn_weights=True`,则返回值 `attn_weights` 为注意力头的平均值。当输入非Batch数据时
Shape为 :math:`(L, S)` 当输入Batch数据时Shape为 :math:`(N, L, S)`。其中 :math:`N` 为batch size :math:`L` 为目标序列的长度,:math:`S` 为源序列长度。
如果 ``average_attn_weights=False`` 分别返回每个注意力头的值。当输入非Batch数据时Shape为 :math:`(\text{num\_heads}, L, S)` 当输入Batch数据时Shape为
如果 `average_attn_weights=False` 分别返回每个注意力头的值。当输入非Batch数据时Shape为 :math:`(\text{num\_heads}, L, S)` 当输入Batch数据时Shape为
:math:`(N, \text{num\_heads}, L, S)`

View File

@ -6,28 +6,28 @@ mindspore.nn.Transformer
Transformer模块包括编码器和解码器。本模块与原论文的实现不同原论文在LayerNorm前使用了残差模块。且默认的隐藏层激活函数为 `gelu` 。详情可见 `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_
参数:
- **d_model** (int) - Encoder或Decoder输入的特征数。默认512。
- **nhead** (int) - 注意力头的数量。
- **num_encoder_layers** (int) - Encoder的层数。默认6。
- **num_decoder_layers** (int) - Decoder的层数。默认6。
- **dim_feedforward** (int) - FeedForward层的维数。默认2048。
- **dropout** (float) - 随机丢弃比例。默认0.1。
- **activation** (str, Cell) - Encoder或Decoder中间层的激活函数可以输入字符串、函数接口或激活函数层实例。支持"relu"、"gelu"。默认:"relu"。
- **custom_encoder** (Cell) - 自定义Encoder层。默认None。
- **custom_decoder** (Cell) - 自定义Decoder层。默认None。
- **layer_norm_eps** (float) - LayerNorm层的eps值默认1e-5。
- **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature)反之Shape为(seq, batch, feature)。默认: ``False``
- **norm_first** (bool) - 如果为 ``True``则LayerNorm层位于MultiheadAttention层和FeedForward层之前反之位于其后。默认 ``False``
- **d_model** (int) - Encoder或Decoder输入的特征数。默认``512``
- **nhead** (int) - 注意力头的数量。默认值:``8``
- **num_encoder_layers** (int) - Encoder的层数。默认``6``
- **num_decoder_layers** (int) - Decoder的层数。默认``6``
- **dim_feedforward** (int) - FeedForward层的维数。默认``2048``
- **dropout** (float) - 随机丢弃比例。默认``0.1``
- **activation** (Union[str, callable, Cell]) - Encoder或Decoder中间层的激活函数可以输入字符串``"relu"````"gelu"``)、函数接口(``ops.relu````ops.gelu``)或激活函数层实例(``nn.ReLU()````nn.GELU()``)。默认值:``"relu"``
- **custom_encoder** (Cell) - 自定义Encoder层。默认``None``
- **custom_decoder** (Cell) - 自定义Decoder层。默认``None``
- **layer_norm_eps** (float) - LayerNorm层的eps值默认``1e-5``
- **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature)反之Shape为(seq, batch, feature)。默认 ``False``
- **norm_first** (bool) - 如果为 ``True``则LayerNorm层位于MultiheadAttention层和FeedForward层之前反之位于其后。默认 ``False``
输入:
- **src** (Tensor) - 源序列。
- **tgt** (Tensor) - 目标序列。
- **src_mask** (Tensor) - 源序列的掩码矩阵 (可选)。默认None。
- **tgt_mask** (Tensor) - 目标序列的掩码矩阵 (可选)。默认None。
- **memory_mask** (Tensor) - memory序列的掩码矩阵 (可选)。默认None。
- **src_key_padding_mask** (Tensor) - 源序列Key矩阵的掩码矩阵 (可选)。默认None。
- **tgt_key_padding_mask** (Tensor) - 目标序列Key矩阵的掩码矩阵 (可选)。默认None。
- **memory_key_padding_mask** (Tensor) - memory序列Key矩阵的掩码矩阵 (可选)。默认None。
- **src_mask** (Tensor, 可选) - 源序列的掩码矩阵。默认``None``
- **tgt_mask** (Tensor, 可选) - 目标序列的掩码矩阵。默认``None``
- **memory_mask** (Tensor, 可选) - memory序列的掩码矩阵。默认``None``
- **src_key_padding_mask** (Tensor, 可选) - 源序列Key矩阵的掩码矩阵。默认``None``
- **tgt_key_padding_mask** (Tensor, 可选) - 目标序列Key矩阵的掩码矩阵。默认``None``
- **memory_key_padding_mask** (Tensor, 可选) - memory序列Key矩阵的掩码矩阵。默认``None``
输出:
Tensor。

View File

@ -8,15 +8,15 @@ mindspore.nn.TransformerDecoder
参数:
- **decoder_layer** (Cell) - TransformerDecoderLayer()的实例。
- **num_layers** (int) - 解码器层数。
- **norm** (Cell) - 自定义LayerNorm层(可选)
- **norm** (Cell, 可选) - 自定义LayerNorm层。
输入:
- **tgt** (Tensor) - 目标序列。
- **memory** (Tensor) - TransformerEncoder的最后一层输出序列。
- **tgt_mask** (Tensor) - 目标序列的掩码矩阵 (可选)。默认None。
- **memory_mask** (Tensor) - memory序列的掩码矩阵 (可选)。默认None。
- **tgt_key_padding_mask** (Tensor) - 目标序列Key矩阵的掩码矩阵 (可选)。默认None。
- **memory_key_padding_mask** (Tensor) - memory序列Key矩阵的掩码矩阵 (可选)。默认None。
- **tgt_mask** (Tensor, 可选) - 目标序列的掩码矩阵。默认``None``
- **memory_mask** (Tensor, 可选) - memory序列的掩码矩阵。默认``None``
- **tgt_key_padding_mask** (Tensor, 可选) - 目标序列Key矩阵的掩码矩阵。默认``None``
- **memory_key_padding_mask** (Tensor, 可选) - memory序列Key矩阵的掩码矩阵。默认``None``
输出:
Tensor。

View File

@ -8,20 +8,20 @@ mindspore.nn.TransformerDecoderLayer
参数:
- **d_model** (int) - 输入的特征数。
- **nhead** (int) - 注意力头的数量。
- **dim_feedforward** (int) - FeedForward层的维数。默认2048。
- **dropout** (float) - 随机丢弃比例。默认0.1。
- **activation** (str, Cell) - 中间层的激活函数,可以输入字符串、函数接口或激活函数层实例。支持"relu"、"gelu"。默认:"relu"。
- **layer_norm_eps** (float) - LayerNorm层的eps值默认1e-5。
- **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature)反之Shape为(seq, batch, feature)。默认: ``False``
- **norm_first** (bool) - 如果为 ``True`` 则LayerNorm层位于Self Attention层、MultiheadAttention层和FeedForward层之前反之位于其后。默认 ``False``
- **dim_feedforward** (int) - FeedForward层的维数。默认``2048``
- **dropout** (float) - 随机丢弃比例。默认``0.1``
- **activation** (Union[str, callable, Cell]) - 中间层的激活函数,可以输入字符串``"relu"````"gelu"``)、函数接口(``ops.relu````ops.gelu``)或激活函数层实例(``nn.ReLU()````nn.GELU()``)。默认值:``"relu"``
- **layer_norm_eps** (float) - LayerNorm层的eps值默认``1e-5``
- **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature)反之Shape为(seq, batch, feature)。默认 ``False``
- **norm_first** (bool) - 如果为 ``True`` 则LayerNorm层位于Self Attention层、MultiheadAttention层和FeedForward层之前反之位于其后。默认 ``False``
输入:
- **tgt** (Tensor) - 目标序列。
- **memory** (Tensor) - TransformerEncoder的最后一层输出序列。
- **tgt_mask** (Tensor) - 目标序列的掩码矩阵 (可选)。默认None。
- **memory_mask** (Tensor) - memory序列的掩码矩阵 (可选)。默认None。
- **tgt_key_padding_mask** (Tensor) - 目标序列Key矩阵的掩码矩阵 (可选)。默认None。
- **memory_key_padding_mask** (Tensor) - memory序列Key矩阵的掩码矩阵 (可选)。默认None。
- **tgt_mask** (Tensor, 可选) - 目标序列的掩码矩阵。默认``None``
- **memory_mask** (Tensor, 可选) - memory序列的掩码矩阵。默认``None``
- **tgt_key_padding_mask** (Tensor, 可选) - 目标序列Key矩阵的掩码矩阵。默认``None``
- **memory_key_padding_mask** (Tensor, 可选) - memory序列Key矩阵的掩码矩阵∂。默认值``None``
输出:
Tensor。

View File

@ -8,12 +8,12 @@ mindspore.nn.TransformerEncoder
参数:
- **encoder_layer** (Cell) - TransformerEncoderLayer()的实例。
- **num_layers** (int) - 编码器层数。
- **norm** (Cell) - 自定义LayerNorm层(可选)
- **norm** (Cell, 可选) - 自定义LayerNorm层。
输入:
- **src** (Tensor) - 源序列。
- **src_mask** (Tensor) - 源序列的掩码矩阵 (可选)。默认None。
- **src_key_padding_mask** (Tensor) - 源序列Key矩阵的掩码矩阵 (可选)。默认None。
- **src_mask** (Tensor, 可选) - 源序列的掩码矩阵。默认``None``
- **src_key_padding_mask** (Tensor, 可选) - 源序列Key矩阵的掩码矩阵。默认``None``
输出:
Tensor。

View File

@ -8,17 +8,17 @@ mindspore.nn.TransformerEncoderLayer
参数:
- **d_model** (int) - 输入的特征数。
- **nhead** (int) - 注意力头的数量。
- **dim_feedforward** (int) - FeedForward层的维数。默认2048。
- **dropout** (float) - 随机丢弃比例。默认0.1。
- **activation** (str, Cell) - 中间层的激活函数,可以输入字符串、函数接口或激活函数层实例。支持'relu''gelu'。默认:'relu'。
- **layer_norm_eps** (float) - LayerNorm层的eps值默认1e-5。
- **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature)反之Shape为(seq, batch, feature)。默认: ``False``
- **norm_first** (bool) - 如果为 ``True`` 则LayerNorm层位于MultiheadAttention层和FeedForward层之前反之位于其后。默认 ``False``
- **dim_feedforward** (int) - FeedForward层的维数。默认``2048``
- **dropout** (float) - 随机丢弃比例。默认``0.1``
- **activation** (Union[str, callable, Cell]) - 中间层的激活函数,可以输入字符串``"relu"````"gelu"``)、函数接口(``ops.relu````ops.gelu``)或激活函数层实例(``nn.ReLU()````nn.GELU()``)。默认值:``"relu"``
- **layer_norm_eps** (float) - LayerNorm层的eps值默认``1e-5``
- **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature)反之Shape为(seq, batch, feature)。默认 ``False``
- **norm_first** (bool) - 如果为 ``True`` 则LayerNorm层位于MultiheadAttention层和FeedForward层之前反之位于其后。默认 ``False``
输入:
- **src** (Tensor) - 源序列。
- **src_mask** (Tensor) - 源序列的掩码矩阵 (可选)。默认None。
- **src_key_padding_mask** (Tensor) - 源序列Key矩阵的掩码矩阵 (可选)。默认None。
- **src_mask** (Tensor, 可选) - 源序列的掩码矩阵。默认``None``
- **src_key_padding_mask** (Tensor, 可选) - 源序列Key矩阵的掩码矩阵。默认``None``
输出:
Tensor。

View File

@ -58,65 +58,66 @@ class MultiheadAttention(Cell):
if query, key and value tensor is same, then it will be self attention.
Args:
embed_dim (int): Total dimension of the model.
num_heads (int): Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout (float): Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
embed_dim (int): Total dimension of MultiheadAttention.
num_heads (int): Number of attention heads. Note that `embed_dim` will be split
across `num_heads` (i.e. each head will have dimension `embed_dim // num_heads`).
dropout (float): Dropout probability of `attn_output_weights`. Default: ``0.0``.
has_bias (bool): Whether adds bias to input / output projection layers. Default: ``True``.
add_bias_kv (bool): Whether adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at dim=1.
add_bias_kv (bool): Whether adds bias to the key and value sequences at axis=0. Default: ``False``.
add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at axis=1.
Default: ``False``.
kdim (int): Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim (int): Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
kdim (int): Total number of features for keys. Default: ``None`` (`kdim=embed_dim`).
vdim (int): Total number of features for values. Default: ``None`` (`vdim=embed_dim`).
batch_first (bool): If ``True``, then the input and output shape are (batch, seq, feature),
else (seq, batch, feature). Default: ``False``.
Inputs:
- **query** (Tensor): Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)`
when ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L`is the
target sequence length, :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension
``embed_dim``. Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
- **key** (Tensor): Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)`
when ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the
source sequence length, :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension
``kdim``. See "Attention Is All You Need" for more details.
- **value** (Tensor): Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)`
when ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the
source sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension
``vdim``. See "Attention Is All You Need" for more details.
- **key_padding_mask** (Tensor): If specified, a mask of shape :math:`(N, S)` indicating which elements
within ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`,
shape should be :math:`(S)`. Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
- **need_weights** (bool): If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
- **query** (Tensor): The query embeddings. If `query` is unbatched, the shape is :math:`(L, E_q)`,
otherwise the shape is :math:`(L, N, E_q)` when `batch_first=False` or :math:`(N, L, E_q)` when
`batch_first=True`, where :math:`L`is the target sequence length, :math:`N` is the batch size,
and :math:`E_q` is the query embedding dimension `embed_dim`. Queries are compared against
key-value pairs to produce the output. See "Attention Is All You Need" for more details.
- **key** (Tensor): The key embeddings. If `key` is unbatched, the shape is :math:`(S, E_k)`, otherwise
the shape is :math:`(S, N, E_k)` when `batch_first=False` or :math:`(N, S, E_k)` when
`batch_first=True`, where :math:`S` is the source sequence length, :math:`N` is the batch size,
and :math:`E_k` is the key embedding dimension `kdim`. See "Attention Is All You Need" for more details.
- **value** (Tensor): The value embeddings. If `value` is unbatched, the shape is :math:`(S, E_v)`,
otherwise the shape is :math:`(S, N, E_v)` when `batch_first=False` or :math:`(N, S, E_v)` when
`batch_first=True`, where :math:`S` is the source sequence length, :math:`N` is the batch size,
and :math:`E_v` is the value embedding dimension `vdim`. See "Attention Is All You Need" for more details.
- **key_padding_mask** (Tensor, optional): If specified, a mask of shape :math:`(N, S)` indicating which
elements within `key` to ignore for the purpose of attention (i.e. treat as "padding").
For unbatched `query`, shape should be :math:`(S)`. Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding `key` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding `key` value.
- **need_weights** (bool): Whether returns `attn_output_weights` in addition to `attn_outputs`.
Default: ``True``.
- **attn_mask** (Tensor): If specified, a 2D or 3D mask preventing attention to certain positions. Must
be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
- **attn_mask** (Tensor, optional): If specified, a 2D or 3D mask preventing attention to certain positions.
Must be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the
batch size, :math:`L` is the target sequence length, and :math:`S` is the source sequence length.
A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry
in the batch. Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates
that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that
the corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
- **average_attn_weights** (bool): If true, indicates that the returned ``attn_weights`` should be averaged
across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only
has an effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
- **average_attn_weights** (bool): If true, indicates that the returned `attn_weights` should be averaged
across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only
has an effect when `need_weights=True`. Default: ``True`` (i.e. average weights across heads)
Outputs:
Tuple, a tuple contains(`attn_output`, `attn_output_weights`)
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
- **attn_output** - Attention outputs. If input is unbatched, the output shape is:math:`(L, E)`, otherwise
the output shape is :math:`(L, N, E)` when `batch_first=False` or :math:`(N, L, E)` when
`batch_first=True`, where :math:`L` is the target sequence length, :math:`N` is the batch size,
and :math:`E` is the embedding dimension `embed_dim`.
- **attn_output_weights** - Only returned when `need_weights=True`. If `average_attn_weights=True`,
returns attention weights averaged across heads with shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)` when input is batched, where :math:`N` is the batch size, :math:`L` is
the target sequence length, and :math:`S` is the source sequence length.
If `average_attn_weights=False`, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or
:math:`(N, \text{num\_heads}, L, S)`.
:math:`(N, \text{num\_heads}, L, S)` when input is batched.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -129,7 +130,7 @@ class MultiheadAttention(Cell):
def __init__(self, embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False,
add_zero_attn=False, kdim=None, vdim=None, batch_first=False):
super(MultiheadAttention, self).__init__()
super().__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
@ -139,7 +140,8 @@ class MultiheadAttention(Cell):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if self.head_dim * num_heads != self.embed_dim:
raise ValueError("The init argument 'embed_dim' must be divisible by 'num_heads'.")
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim)), 'q_proj_weight')
@ -186,7 +188,7 @@ class MultiheadAttention(Cell):
"only bool and floating types of key_padding_mask are supported")
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
# k_is_v and q_is_k preprocess in __call__ since Graph mode do not support `is`
if self.k_is_v:
if self.q_is_k:
query = key = value = query.swapaxes(1, 0)
@ -232,22 +234,24 @@ class TransformerEncoderLayer(Cell):
encoder layer, including multihead attention and feedward layer.
Args:
d_model (int): the number of expected features in the input (required).
nhead (int): the number of heads in the multiheadattention models (required).
dim_feedforward (int): the dimension of the feedforward network model (default=2048).
dropout (float): the dropout value (default=0.1).
activation (str, Cell): the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps (float): the eps value in layer normalization components (default=1e-5).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first (bool): if ``True``, layer norm is done prior to attention and feedforward
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
d_model (int): The number of features in the input tensor.
nhead (int): The number of heads in the MultiheadAttention modules.
dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
dropout (float): The dropout value. Default: ``0.1``.
activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``.
layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
(batch, seq, feature), otherwise the shape is (seq, batch, feature). Default: ``False``.
norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
operations, respectively. Default: ``False``.
Inputs:
- **src** (Tensor): the sequence to the encoder layer (required).
- **src_mask** (Tensor): the mask for the src sequence (optional).
- **src_key_padding_mask** (Tensor): the mask for the src keys per batch (optional).
- **src** (Tensor): the sequence to the encoder layer.
- **src_mask** (Tensor, optional): the mask for the src sequence. Default: ``None``.
- **src_key_padding_mask** (Tensor, optional): the mask for the src keys per batch.
Default: ``None``.
Outputs:
Tensor.
@ -267,11 +271,11 @@ class TransformerEncoderLayer(Cell):
__constants__ = ['batch_first', 'norm_first']
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False,
norm_first: bool = False):
super(TransformerEncoderLayer, self).__init__()
activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
batch_first: bool = False, norm_first: bool = False):
super().__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
# feedforward layer
self.linear1 = _Linear(d_model, dim_feedforward)
self.dropout = Dropout(1-dropout)
self.linear2 = _Linear(dim_feedforward, d_model)
@ -282,18 +286,21 @@ class TransformerEncoderLayer(Cell):
self.dropout1 = Dropout(1-dropout)
self.dropout2 = Dropout(1-dropout)
# Legacy string support for activation function.
if not isinstance(activation, str) and not isinstance(activation, Cell) \
and not callable(activation):
raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
f" but get {activation}.")
if isinstance(activation, Cell) and (not isinstance(activation, ReLU) or \
not isinstance(activation, GELU)):
raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
f" but get {activation}.")
if callable(activation) and (activation is not ops.relu or \
activation is not ops.gelu):
raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
f" but get {activation}.")
# string inputs of activation
if isinstance(activation, str):
activation = _get_activation_fn(activation)
# We can't test self.activation in forward() in TorchScript,
# so stash some information about it instead.
if activation is ops.relu or isinstance(activation, ReLU):
self.activation_relu_or_gelu = 1
elif activation is ops.gelu or isinstance(activation, GELU):
self.activation_relu_or_gelu = 2
else:
self.activation_relu_or_gelu = 0
self.activation = activation
def construct(self, src: Tensor, src_mask: Optional[Tensor] = None,
@ -332,26 +339,28 @@ class TransformerDecoderLayer(Cell):
decoder layer, including self-attention, cross attention and feedward layer.
Args:
d_model (int): the number of expected features in the input (required).
nhead (int): the number of heads in the multiheadattention models (required).
dim_feedforward (int): the dimension of the feedforward network model (default=2048).
dropout (float): the dropout value (default=0.1).
activation (str, Cell): the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps (float): the eps value in layer normalization components (default=1e-5).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first (bool): if ``True``, layer norm is done prior to self attention, multihead
attention and feedforward operations, respectively. Otherwise it's done after.
Default: ``False`` (after).
d_model (int): The number of expected features in the input tensor.
nhead (int): The number of heads in the MultiheadAttention modules.
dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
dropout (float): The dropout value. Default: ``0.1``.
activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``
layer_norm_eps (float): The epsilon value in LayerNorm modules. Default: ``1e-5``.
batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
(batch, seq, feature), otherwise the shape is (seq, batch, feature). Default: ``False``.
norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
operations, respectively. Default: ``False``.
Inputs:
- **tgt** (Tensor): the sequence to the decoder layer (required).
- **memory** (Tensor): the sequence from the last layer of the encoder (required).
- **tgt_mask** (Tensor): the mask for the tgt sequence (optional).
- **memory_mask** (Tensor): the mask for the memory sequence (optional).
- **tgt_key_padding_mask** (Tensor): the mask for the tgt keys per batch (optional).
- **memory_key_padding_mask** (Tensor): the mask for the memory keys per batch (optional).
- **tgt** (Tensor): The sequence to the decoder layer.
- **memory** (Tensor): The sequence from the last layer of the encoder.
- **tgt_mask** (Tensor, optional): The mask of the tgt sequence. Default: ``None``.
- **memory_mask** (Tensor, optional): The mask of the memory sequence. Default: ``None``.
- **tgt_key_padding_mask** (Tensor, optional): The mask of the tgt keys per batch.
Default: ``None``.
- **memory_key_padding_mask** (Tensor, optional): The mask of the memory keys per batch.
Default: ``None``.
Outputs:
Tensor.
@ -364,7 +373,7 @@ class TransformerDecoderLayer(Cell):
>>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
>>> out = decoder_layer(tgt, memory)
>>> # Alternatively, when ``batch_first`` is ``True``:
>>> # Alternatively, when `batch_first` is ``True``:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
>>> memory = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(32, 20, 512), mindspore.float32)
@ -373,12 +382,12 @@ class TransformerDecoderLayer(Cell):
__constants__ = ['batch_first', 'norm_first']
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False,
norm_first: bool = False):
super(TransformerDecoderLayer, self).__init__()
activation: Union[str, Cell, callable] = 'relu', layer_norm_eps: float = 1e-5,
batch_first: bool = False, norm_first: bool = False):
super().__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
# feedforward layer
self.linear1 = _Linear(d_model, dim_feedforward)
self.dropout = Dropout(1-dropout)
self.linear2 = _Linear(dim_feedforward, d_model)
@ -391,11 +400,22 @@ class TransformerDecoderLayer(Cell):
self.dropout2 = Dropout(1-dropout)
self.dropout3 = Dropout(1-dropout)
# Legacy string support for activation function.
if not isinstance(activation, str) and not isinstance(activation, Cell) \
and not callable(activation):
raise ValueError(f"The argument 'activation' must be str, callable or Cell instance,"
f" but get {activation}.")
if isinstance(activation, Cell) and (not isinstance(activation, ReLU) or \
not isinstance(activation, GELU)):
raise ValueError(f"The argument 'activation' must be nn.ReLU or nn.GELU instance,"
f" but get {activation}.")
if callable(activation) and (activation is not ops.relu or \
activation is not ops.gelu):
raise ValueError(f"The argument 'activation' must be ops.relu or ops.gelu instance,"
f" but get {activation}.")
# string inputs of activation
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
else:
self.activation = activation
activation = _get_activation_fn(activation)
self.activation = activation
def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
@ -438,14 +458,15 @@ class TransformerEncoder(Cell):
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
Args:
encoder_layer (Cell): an instance of the TransformerEncoderLayer() class (required).
num_layers (int): the number of sub-encoder-layers in the encoder (required).
norm (Cell): the layer normalization component (optional).
encoder_layer (Cell): An instance of the TransformerEncoderLayer() class.
num_layers (int): The number of encoder-layers in the encoder.
norm (Cell, optional): The layer normalization module.
Inputs:
- **src** (Tensor): the sequence to the encoder (required).
- **src_mask** (Tensor): the mask for the src sequence (optional).
- **src_key_padding_mask** (Tensor): the mask for the src keys per batch (optional).
- **src** (Tensor): The sequence to the encoder.
- **src_mask** (Tensor, optional): The mask of the src sequence. Default: ``None``.
- **src_key_padding_mask** (Tensor, optional): the mask of the src keys per batch .
Default: ``None``.
Outputs:
Tensor.
@ -467,7 +488,7 @@ class TransformerEncoder(Cell):
self.num_layers = num_layers
self.norm = norm
def construct(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None):
def construct(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != mindspore.bool_ and not ops.is_floating_point(src_key_padding_mask):
@ -476,7 +497,7 @@ class TransformerEncoder(Cell):
output = src
src_key_padding_mask_for_layers = src_key_padding_mask
for mod in self.layers:
output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
output = mod(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask_for_layers)
if self.norm is not None:
output = self.norm(output)
@ -490,17 +511,19 @@ class TransformerDecoder(Cell):
attention, cross attention and feedforward layer.
Args:
decoder_layer (Cell): an instance of the TransformerDecoderLayer() class (required).
num_layers (int): the number of sub-decoder-layers in the decoder (required).
norm (Cell): the layer normalization component (optional).
decoder_layer (Cell): An instance of the TransformerDecoderLayer() class.
num_layers (int): The number of decoder-layers in the decoder.
norm (Cell): The layer normalization module.
Inputs:
- **tgt** (Tensor): the sequence to the decoder (required).
- **memory** (Tensor): the sequence from the last layer of the encoder (required).
- **tgt_mask** (Tensor): the mask for the tgt sequence (optional).
- **memory_mask** (Tensor): the mask for the memory sequence (optional).
- **tgt_key_padding_mask** (Tensor): the mask for the tgt keys per batch (optional).
- **memory_key_padding_mask** (Tensor): the mask for the memory keys per batch (optional).
- **tgt** (Tensor): The sequence to the decoder.
- **memory** (Tensor): The sequence from the last layer of the encoder.
- **tgt_mask** (Tensor, optional): the mask of the tgt sequence. Default: ``None``.
- **memory_mask** (Tensor, optional): the mask of the memory sequence. Default: ``None``.
- **tgt_key_padding_mask** (Tensor, optional): the mask of the tgt keys per batch.
Default: ``None``.
- **memory_key_padding_mask** (Tensor, optional): the mask of the memory keys per batch.
Default: ``None``.
Outputs:
Tensor.
@ -547,31 +570,36 @@ class Transformer(Cell):
The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
Args:
d_model (int): the number of expected features in the encoder/decoder inputs (default=512).
nhead (int): the number of heads in the multiheadattention models (default=8).
num_encoder_layers (int): the number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers (int): the number of sub-decoder-layers in the decoder (default=6).
dim_feedforward (int): the dimension of the feedforward network model (default=2048).
dropout (float): the dropout value (default=0.1).
activation (str, Cell): the activation function of encoder/decoder intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
custom_encoder (Cell): custom encoder (default=None).
custom_decoder (Cell): custom decoder (default=None).
layer_norm_eps (float): the eps value in layer normalization components (default=1e-5).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
other attention and feedforward operations, otherwise after. Default: ``False`` (after).
d_model (int): The number of expected features in the inputs tensor. Default: ``512``.
nhead (int): The number of heads in the MultiheadAttention modules. Default: ``8``.
num_encoder_layers (int): The number of encoder-layers in the encoder. Default: ``6``.
num_decoder_layers (int): The number of decoder-layers in the decoder. Default: ``6``.
dim_feedforward (int): The dimension of the feedforward layer. Default: ``2048``.
dropout (float): The dropout value. Default: ``0.1``.
activation (Union[str, callable, Cell]): The activation function of the intermediate layer,
can be a string (`"relu"` or `"gelu"`), Cell instance (`nn.ReLU()` or `nn.GELU()`) or
a callable (`ops.relu` or `ops.gelu`). Default: ``"relu"``
custom_encoder (Cell): Custom encoder. Default: ``None``.
custom_decoder (Cell): Custom decoder. Default: ``None``.
layer_norm_eps (float): the epsilion value in layer normalization module. Default: ``1e-5``.
batch_first (bool): If `batch_first = True`, then the shape of input and output tensors is
(batch, seq, feature), otherwise the shape is (seq, batch, feature). Default: ``False``.
norm_first (bool): If `norm_first = True`, layer norm is done prior to attention and feedforward
operations, respectively. Default: ``False``.
Inputs:
- **src** (Tensor): the sequence to the encoder (required).
- **tgt** (Tensor): the sequence to the decoder (required).
- **src_mask** (Tensor): the additive mask for the src sequence (optional).
- **tgt_mask** (Tensor): the additive mask for the tgt sequence (optional).
- **memory_mask** (Tensor): the additive mask for the encoder output (optional).
- **src_key_padding_mask** (Tensor): the ByteTensor mask for src keys per batch (optional).
- **tgt_key_padding_mask** (Tensor): the ByteTensor mask for tgt keys per batch (optional).
- **memory_key_padding_mask** (Tensor): the ByteTensor mask for memory keys per batch (optional).
- **src** (Tensor): The source sequence to the encoder.
- **tgt** (Tensor): The target sequence to the decoder.
- **src_mask** (Tensor, optional): The mask of the src sequence. Default: ``None``.
- **tgt_mask** (Tensor, optional): The mask of the tgt sequence. Default: ``None``.
- **memory_mask** (Tensor, optional): The additive mask of the encoder output.
Default: ``None``.
- **src_key_padding_mask** (Tensor, optional): The mask of src keys per batch.
Default: ``None``.
- **tgt_key_padding_mask** (Tensor, optional): The mask of tgt keys per batch.
Default: ``None``.
- **memory_key_padding_mask** (Tensor, optional): The mask of memory keys per batch.
Default: ``None``.
Outputs:
Tensor.
@ -585,7 +613,7 @@ class Transformer(Cell):
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Cell] = 'relu', custom_encoder: Optional[Cell] = None,
activation: Union[str, Cell, callable] = 'relu', custom_encoder: Optional[Cell] = None,
custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5,
batch_first: bool = False, norm_first: bool = False):
super(Transformer, self).__init__()
@ -606,7 +634,9 @@ class Transformer(Cell):
decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
self._reset_parameters()
for _, p in self.parameters_and_names():
if p.ndim > 1:
p.set_data(initializer('xavier_uniform', p.shape, p.dtype))
self.d_model = d_model
self.nhead = nhead
@ -617,27 +647,24 @@ class Transformer(Cell):
memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None):
is_batched = src.ndim == 3
if not self.batch_first and src.shape[1] != tgt.shape[1] and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
if self.batch_first and src.shape[0] != tgt.shape[0] and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
if self.batch_first:
src_batch_size = src.shape[0]
tgt_batch_size = src.shape[0]
else:
src_batch_size = src.shape[1]
tgt_batch_size = src.shape[1]
if src_batch_size != tgt_batch_size and is_batched:
raise ValueError("The number of batch size for 'src' and 'tgt' must be equal.")
if src.shape[-1] != self.d_model or tgt.shape[-1] != self.d_model:
raise RuntimeError("the feature number of src and tgt must be equal to d_model")
raise ValueError("The number of features for 'src' and 'tgt' must be equal to `d_model`.")
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
memory = self.encoder(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
return output
def _reset_parameters(self):
r"""Initiate parameters in the transformer model."""
for _, p in self.parameters_and_names():
if p.ndim > 1:
p.set_data(initializer('xavier_uniform', p.shape, p.dtype))
def _get_activation_fn(activation: str):
if activation == "relu":
@ -645,7 +672,7 @@ def _get_activation_fn(activation: str):
if activation == "gelu":
return ops.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
raise ValueError(f"The activation must be relu/gelu, but get {activation}")
def _get_clones(module, N):

View File

@ -5804,12 +5804,21 @@ def linear(x, w, b):
def _in_projection(q, k, v, w_q, w_k, w_v, b_q=None, b_k=None, b_v=None):
"""in projection function"""
Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
w_q_shape, w_k_shape, w_v_shape = w_q.shape, w_k.shape, w_v.shape
b_q_shape, b_k_shape, b_v_shape = b_q.shape, b_k.shape, b_v.shape
if w_q_shape != (Eq, Eq):
raise ValueError(f"Expecting query weights shape of {(Eq, Eq)}, but got {w_q_shape}")
if w_k_shape != (Eq, Ek):
raise ValueError(f"Expecting key weights shape of {(Eq, Ek)}, but got {w_k_shape}")
if w_v_shape != (Eq, Ev):
raise ValueError(f"Expecting value weights shape of {(Eq, Ev)}, but got {w_v_shape}")
if b_q is not None and b_q_shape != (Eq,):
raise ValueError(f"Expecting query bias shape of {(Eq,)}, but got {b_q_shape}")
if b_k is not None and b_k_shape != (Eq,):
raise ValueError(f"Expecting key bias shape of {(Eq,)}, but got {b_k_shape}")
if b_v is not None and b_v_shape != (Eq,):
raise ValueError(f"Expecting value bias shape of {(Eq,)}, but got {b_v_shape}")
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
@ -5842,7 +5851,8 @@ def _scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_ca
query = query / scaling_factor
if is_causal:
L = query.shape[-2], S = key.shape[-2]
L = query.shape[-2]
S = key.shape[-2]
attn_mask = ops.ones((L, S), mstype.bool_).tril()
attn = ops.matmul(query, key.swapaxes(-2, -1) / scaling_factor)
@ -5858,48 +5868,46 @@ def _scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_ca
def _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads):
"""
Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
and returns if the input is batched or not.
Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
Check the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
and returns whether the input is batched.
"""
# Shape check.
if query.ndim == 3:
# Batched Inputs
is_batched = True
assert key.ndim == 3 and value.ndim == 3, \
("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
f" but found {key.ndim}-D and {value.ndim}-D tensors respectively")
if key_padding_mask is not None:
assert key_padding_mask.ndim == 2, \
("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
f" but found {key_padding_mask.ndim}-D tensor instead")
if attn_mask is not None:
assert attn_mask.ndim in (2, 3), \
("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.ndim}-D tensor instead")
if key.ndim != 3 or value.ndim != 3:
raise ValueError(f"For batched `query`, the `key` and `value` must be 3D tensor, "
f"but got `key` with {key.ndim}D and `value` with {value.ndim}D.")
if key_padding_mask is not None and key_padding_mask.ndim != 2:
raise ValueError(f"For batched `query`, the `key_padding_mask` must be `None` or 2D, "
f"but got `key_padding_mask` with {key_padding_mask.ndim}D.")
if attn_mask is not None and attn_mask.ndim not in (2, 3):
raise ValueError(f"For batched `query`, the `attn_mask` must be `None`, 2-D or 3-D, "
f"but got `attn_mask` with{attn_mask.ndim}D.")
elif query.ndim == 2:
# Unbatched Inputs
is_batched = False
assert key.ndim == 2 and value.ndim == 2, \
("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
f" but found {key.ndim}-D and {value.ndim}-D tensors respectively")
if key.ndim != 2 or value.ndim != 2:
raise ValueError(f"For batched `query`, the `key` and `value` must be 2D tensor, "
f"but got `key` with {key.ndim}D and `value` with {value.ndim}D.")
if key_padding_mask is not None:
assert key_padding_mask.ndim == 1, \
("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
f" but found {key_padding_mask.ndim}-D tensor instead")
if key_padding_mask is not None and key_padding_mask.ndim != 1:
raise ValueError(f"For batched `query`, the `key_padding_mask` must be `None` or 1D, "
f"but got `key_padding_mask` with {key_padding_mask.ndim}D.")
if attn_mask is not None:
assert attn_mask.ndim in (2, 3), \
("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.ndim}-D tensor instead")
if attn_mask.ndim not in (2, 3):
raise ValueError(f"For batched `query`, the `attn_mask` must be `None`, 2-D or 3-D, "
f"but got `attn_mask` with{attn_mask.ndim}D.")
if attn_mask.ndim == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
assert attn_mask.shape == expected_shape, \
(f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
if attn_mask.shape != expected_shape:
raise ValueError(f"The shape of `attn_mask` must to be {expected_shape}, "
f"but got {attn_mask.shape}.")
else:
raise AssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {query.ndim}-D query tensor")
raise ValueError(f"The `query` should be unbatched 2D or batched 3D tensor, "
f"but got `query` with {query.ndim}D.")
return is_batched
@ -5925,28 +5933,34 @@ def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_head
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != mstype.bool_ and not ops.is_floating_point(key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
raise ValueError("The `key_padding_mask` only supports bool and floating dtypes.")
if embed_dim != embed_dim_to_check:
raise ValueError(f"The `embed_dim` should be {embed_dim_to_check}, but got {embed_dim}.")
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if head_dim * num_heads != embed_dim:
raise ValueError(f"The `embed_dim` {embed_dim} can not be divisible by `num_heads` {num_heads}.")
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
# allow MHA to have different embedding dims when separate projection weights are used
if key.shape[:2] != value.shape[:2]:
raise ValueError(f"The sequence length and batch dims of `key`: {key.shape[:2]} do not match "
f"`value`: {value.shape[:2]}.")
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
if key.shape != value.shape:
raise ValueError(f"The shape of `key` {key.shape} does not match `value` {value.shape}.")
# compute in-projection
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
if in_proj_weight is None:
raise ValueError("`use_separate_proj_weight` is ``False`` but `in_proj_weight` got ``None``.")
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias, k_is_v, q_is_k)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if q_proj_weight is None:
raise ValueError("`use_separate_proj_weight` is ``True`` but `q_proj_weight` got ``None``.")
if k_proj_weight is None:
raise ValueError("`use_separate_proj_weight` is ``True`` but `k_proj_weight` got ``None``.")
if v_proj_weight is None:
raise ValueError("`use_separate_proj_weight` is ``True`` but `v_proj_weight` got ``None``.")
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
@ -5958,27 +5972,30 @@ def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_head
if attn_mask.dtype == mstype.uint8:
attn_mask = attn_mask.astype(mstype.bool_)
else:
assert ops.is_floating_point(attn_mask) or attn_mask.dtype == mstype.bool_, \
f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
# ensure attn_mask's dim is 3
if not ops.is_floating_point(attn_mask) and attn_mask.dtype != mstype.bool_:
raise ValueError(f"`attn_mask` only support float, byte, and bool types, "
f"but got not {attn_mask.dtype}.")
# ensure attn_mask's ndim is 3
if attn_mask.ndim == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, "
"but should be {correct_2d_size}.")
raise ValueError(f"The shape of the `attn_mask` should be {correct_2d_size}, "
f"but got {attn_mask.shape}.")
attn_mask = attn_mask.expand_dims(0)
elif attn_mask.ndim == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, "
"but should be {correct_3d_size}.")
raise ValueError(f"The shape of the `attn_mask` should be {correct_3d_size}, "
f"but got {attn_mask.shape}.")
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.ndim} is not supported")
raise ValueError(f"The ndim of `attn_mask` only support 2 or 3, "
f"but got {attn_mask.ndim}.")
# add bias along batch dimension
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
if static_k is not None:
raise ValueError("The bias_k cannot be added to static_k.")
if static_v is not None:
raise ValueError("The bias_v cannot be added to static_v.")
k = ops.cat([k, bias_k.repeat(1, bsz, 1)])
v = ops.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
@ -5986,31 +6003,32 @@ def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_head
if key_padding_mask is not None:
key_padding_mask = ops.pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
assert bias_v is None
if bias_k is not None or bias_v is not None:
raise ValueError("The bias_k and bias_v should be ``None``"
"at the same time.")
# reshape q, k, v for multihead attention and make em batch first
q = q.view(tgt_len, bsz * num_heads, head_dim).swapaxes(0, 1)
if static_k is None:
k = k.view(k.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.shape[0] == bsz * num_heads, \
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert static_k.shape[2] == head_dim, \
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
if static_k.shape[0] != bsz * num_heads:
raise ValueError(f"The shape[0] of `static_k` should be {bsz * num_heads}, "
f"but got {static_k.shape[0]}")
if static_k.shape[2] != head_dim:
raise ValueError(f"The shape[2] of `static_k` should be {head_dim}, "
f"but got {static_k.shape[2]}")
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.shape[0] == bsz * num_heads, \
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert static_v.shape[2] == head_dim, \
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
if static_v.shape[0] != bsz * num_heads:
raise ValueError(f"The shape[0] of `static_v` should be {bsz * num_heads}, "
f"but got {static_v.shape[0]}")
if static_v.shape[2] != head_dim:
raise ValueError(f"The shape[2] of `static_v` should be {head_dim}, "
f"but got {static_v.shape[2]}")
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = ops.cat([k, ops.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
@ -6020,14 +6038,14 @@ def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_head
if key_padding_mask is not None:
key_padding_mask = ops.pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
src_len = k.shape[1]
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
if key_padding_mask.shape != (bsz, src_len):
raise ValueError(f"The shape of `key_padding_mask` should be {(bsz, src_len)}, "
f"but got {key_padding_mask.shape}.")
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
@ -6036,7 +6054,6 @@ def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_head
else:
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
# convert mask to float
if attn_mask is not None and attn_mask.dtype == mstype.bool_:
new_attn_mask = ops.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill(attn_mask, float("-inf"))
@ -6059,13 +6076,11 @@ def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_head
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.shape[1])
# optionally average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.sum(axis=1) / num_heads
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, attn_output_weights

View File

@ -93,7 +93,7 @@ def test_transformerencoder_square_input(training, jit):
src_mask = Tensor([[0, 1], [0, 0]]).to(ms.bool_)
def forward(x, mask):
result = model(x, mask=mask)
result = model(x, src_mask=mask)
return result
if jit: