!49032 add new Transformer api

Merge pull request !49032 from 吕昱峰(Nate.River)/transformer
i-robot 2023-02-21 02:07:09 +00:00 committed by Gitee
commit d9476a264f
12 changed files with 1308 additions and 4 deletions

View File

@@ -84,6 +84,21 @@ Compared with the previous version, the added, deleted, and supported-platform-changed `mindspore.nn` APIs
mindspore.nn.LSTM
mindspore.nn.LSTMCell
Transformer Layer
-----------------
.. mscnplatformautosummary::
:toctree: nn
:nosignatures:
:template: classtemplate.rst
mindspore.nn.MultiheadAttention
mindspore.nn.TransformerEncoderLayer
mindspore.nn.TransformerDecoderLayer
mindspore.nn.TransformerEncoder
mindspore.nn.TransformerDecoder
mindspore.nn.Transformer
Embedding Layer
-----------------

View File

@@ -0,0 +1,50 @@
mindspore.nn.MultiheadAttention
========================================
.. py:class:: mindspore.nn.MultiheadAttention(embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False)
An implementation of multi-head attention as described in the paper `Attention Is All You Need <https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query, key and value vectors, the attention is computed as follows:
.. math::
MultiHeadAttention(query, key, value) = Dropout(Concat(head_1, \dots, head_h)W^O)
where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. Note that the output projection layer uses a bias parameter.
If query, key and value are the same, the above is the computation of self-attention.
Args:
- **embed_dim** (int) - Total dimension of the model.
- **num_heads** (int) - Number of parallel attention heads. ``embed_dim`` must be divisible by ``num_heads`` (each head has dimension ``embed_dim // num_heads``).
- **dropout** (float) - Dropout probability applied to ``attn_output_weights``. Default: ``0.0`` (no dropout).
- **has_bias** (bool) - Whether to add bias to the input and output projection layers. Default: ``True``.
- **add_bias_kv** (bool) - Whether to add bias to the key and value sequences at dim=0. Default: ``False``.
- **add_zero_attn** (bool) - Whether to append a batch of zeros to the key and value sequences at dim=1. Default: ``False``.
- **kdim** (int) - Total number of features for keys. Default: ``None`` (i.e. ``kdim=embed_dim``).
- **vdim** (int) - Total number of features for values. Default: ``None`` (i.e. ``vdim=embed_dim``).
- **batch_first** (bool) - If ``True``, the input and output Tensors have shape (batch, seq, feature); otherwise the shape is (seq, batch, feature). Default: ``False``.
Inputs:
- **query** (Tensor) - The query matrix. Shape :math:`(L, E_q)` for unbatched input; for batched input, shape :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size and :math:`E_q` is the query embedding dimension ``embed_dim``.
Attention compares the query against the key-value pairs to produce the output. See "Attention Is All You Need" for details.
- **key** (Tensor) - The key matrix. Shape :math:`(S, E_k)` for unbatched input; for batched input, shape :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size and :math:`E_k` is the key embedding dimension ``kdim``. See "Attention Is All You Need" for details.
- **value** (Tensor) - The value matrix. Shape :math:`(S, E_v)` for unbatched input; for batched input, shape :math:`(S, N, E_v)` when ``batch_first=False``
or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source sequence length, :math:`N` is the batch size and :math:`E_v` is the value embedding dimension ``vdim``. See "Attention Is All You Need" for details.
- **key_padding_mask** (Tensor) - If specified, a mask of shape :math:`(N, S)` applied to ``key``; for unbatched input the shape is :math:`(S)`.
If the Tensor is of Bool type, the positions in ``key`` corresponding to ``True`` are ignored in the attention computation. If the Tensor is of Float type, it is added directly to the corresponding ``key`` value. Default: ``None``.
- **need_weights** (bool) - Whether to return ``attn_output_weights``; if ``True``, the output includes ``attn_output_weights``. Default: ``True``.
- **attn_mask** (Tensor) - If specified, a mask of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)` applied to the attention computation, where :math:`N` is the batch size,
:math:`L` is the target sequence length and :math:`S` is the source sequence length. A 2D mask is broadcast along the batch dimension to 3D; a 3D mask allows a different mask for each entry in the batch. If the Tensor is of Bool type, positions with value ``True`` are not allowed to attend. If the Tensor is of Float type, it is added directly to the attention weights. Default: ``None``.
- **average_attn_weights** (bool) - If ``True``, the returned ``attn_weights`` are averaged over the attention heads. If ``False``, ``attn_weights`` are returned separately for each head.
This parameter only takes effect when ``need_weights=True``. Default: ``True``.
Outputs:
Tuple, a tuple containing (`attn_output`, `attn_output_weights`).
- **attn_output** - Output of the attention computation. Shape :math:`(L, E)` for unbatched input; for batched input, shape :math:`(L, N, E)` when ``batch_first=False``
or :math:`(N, L, E)` when ``batch_first=True``, where :math:`L` is the target sequence length, :math:`N` is the batch size and :math:`E` is the model dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, the weights are averaged over the heads: shape :math:`(L, S)` for unbatched input
and :math:`(N, L, S)` for batched input, where :math:`N` is the batch size, :math:`L` is the target sequence length and :math:`S` is the source sequence length.
If ``average_attn_weights=False``, the weights are returned per head: shape :math:`(\text{num\_heads}, L, S)` for unbatched input
and :math:`(N, \text{num\_heads}, L, S)` for batched input.
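A minimal usage sketch (not part of the original doc page), mirroring the Examples block of the MultiheadAttention docstring added by this PR; the concrete sizes, dtype and random inputs are illustrative assumptions:
import numpy as np
import mindspore
from mindspore import Tensor, nn
embed_dim, num_heads = 128, 8          # illustrative sizes
seq_len, batch_size = 10, 8
# batch_first=False (default), so tensors are laid out as (seq, batch, feature)
query = Tensor(np.random.randn(seq_len, batch_size, embed_dim), mindspore.float32)
key = Tensor(np.random.randn(seq_len, batch_size, embed_dim), mindspore.float32)
value = Tensor(np.random.randn(seq_len, batch_size, embed_dim), mindspore.float32)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape)            # (10, 8, 128), same layout as query
print(attn_output_weights.shape)    # (8, 10, 10): averaged over heads by default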

View File

@@ -0,0 +1,33 @@
mindspore.nn.Transformer
========================================
.. py:class:: mindspore.nn.Transformer(d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell] = 'relu', custom_encoder: Optional[Cell] = None, custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False)
Transformer module consisting of an encoder and a decoder. This implementation differs from the original paper in that the residual addition is applied before layer normalization, and the default hidden-layer activation function is `relu`. See `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_ for details.
Args:
- **d_model** (int) - Number of features in the Encoder/Decoder inputs. Default: 512.
- **nhead** (int) - Number of attention heads. Default: 8.
- **num_encoder_layers** (int) - Number of Encoder layers. Default: 6.
- **num_decoder_layers** (int) - Number of Decoder layers. Default: 6.
- **dim_feedforward** (int) - Dimension of the FeedForward layer. Default: 2048.
- **dropout** (float) - Dropout ratio. Default: 0.1.
- **activation** (str, Cell) - Activation function of the Encoder/Decoder intermediate layer; can be a string, a function interface or an activation layer instance. Supported: "relu", "gelu". Default: "relu".
- **custom_encoder** (Cell) - Custom Encoder. Default: None.
- **custom_decoder** (Cell) - Custom Decoder. Default: None.
- **layer_norm_eps** (float) - The eps value of the LayerNorm layers. Default: 1e-5.
- **batch_first** (bool) - If ``True``, the input and output shape is (batch, seq, feature); otherwise it is (seq, batch, feature). Default: ``False``.
- **norm_first** (bool) - If ``True``, the LayerNorm layer is applied before the MultiheadAttention and FeedForward layers; otherwise it is applied after them. Default: ``False``.
Inputs:
- **src** (Tensor) - The source sequence.
- **tgt** (Tensor) - The target sequence.
- **src_mask** (Tensor) - The mask of the source sequence (optional). Default: None.
- **tgt_mask** (Tensor) - The mask of the target sequence (optional). Default: None.
- **memory_mask** (Tensor) - The mask of the memory sequence (optional). Default: None.
- **src_key_padding_mask** (Tensor) - The key padding mask of the source sequence (optional). Default: None.
- **tgt_key_padding_mask** (Tensor) - The key padding mask of the target sequence (optional). Default: None.
- **memory_key_padding_mask** (Tensor) - The key padding mask of the memory sequence (optional). Default: None.
Outputs:
Tensor.
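A minimal usage sketch (not part of the original doc page), mirroring the Examples block of the Transformer docstring added by this PR; the layer counts and random inputs are illustrative assumptions:
import numpy as np
import mindspore
from mindspore import Tensor, nn
# 16 heads and 12 encoder layers, as in the docstring example; inputs are (seq, batch, d_model)
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
out = transformer_model(src, tgt)
print(out.shape)   # (20, 32, 512): the decoder output follows the target sequence length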

View File

@@ -0,0 +1,22 @@
mindspore.nn.TransformerDecoder
========================================
.. py:class:: mindspore.nn.TransformerDecoder(decoder_layer, num_layers, norm=None)
Transformer decoder. A stack of multiple `TransformerDecoderLayer` layers, each consisting of a Self-Attention layer, a MultiheadAttention (cross-attention) layer and a FeedForward layer.
Args:
- **decoder_layer** (Cell) - An instance of TransformerDecoderLayer().
- **num_layers** (int) - Number of decoder layers.
- **norm** (Cell) - Custom LayerNorm layer (optional).
Inputs:
- **tgt** (Tensor) - The target sequence.
- **memory** (Tensor) - The output sequence of the last layer of the TransformerEncoder.
- **tgt_mask** (Tensor) - The mask of the target sequence (optional). Default: None.
- **memory_mask** (Tensor) - The mask of the memory sequence (optional). Default: None.
- **tgt_key_padding_mask** (Tensor) - The key padding mask of the target sequence (optional). Default: None.
- **memory_key_padding_mask** (Tensor) - The key padding mask of the memory sequence (optional). Default: None.
Outputs:
Tensor.
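A minimal usage sketch (not part of the original doc page), based on the Examples block of the TransformerDecoder docstring in this PR; the shapes and random inputs are illustrative assumptions:
import numpy as np
import mindspore
from mindspore import Tensor, nn
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)   # encoder output: (src_len, batch, d_model)
tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)      # target sequence: (tgt_len, batch, d_model)
out = transformer_decoder(tgt, memory)                            # out.shape == (20, 32, 512)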

View File

@@ -0,0 +1,27 @@
mindspore.nn.TransformerDecoderLayer
========================================
.. py:class:: mindspore.nn.TransformerDecoderLayer(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False)
Transformer decoder layer. A single layer of the Transformer decoder, consisting of a Self-Attention layer, a MultiheadAttention (cross-attention) layer and a FeedForward layer.
Args:
- **d_model** (int) - Number of features in the input.
- **nhead** (int) - Number of attention heads.
- **dim_feedforward** (int) - Dimension of the FeedForward layer. Default: 2048.
- **dropout** (float) - Dropout ratio. Default: 0.1.
- **activation** (str, Cell) - Activation function of the intermediate layer; can be a string, a function interface or an activation layer instance. Supported: "relu", "gelu". Default: "relu".
- **layer_norm_eps** (float) - The eps value of the LayerNorm layers. Default: 1e-5.
- **batch_first** (bool) - If ``True``, the input and output shape is (batch, seq, feature); otherwise it is (seq, batch, feature). Default: ``False``.
- **norm_first** (bool) - If ``True``, the LayerNorm layer is applied before the Self-Attention, MultiheadAttention and FeedForward layers; otherwise it is applied after them. Default: ``False``.
Inputs:
- **tgt** (Tensor) - The target sequence.
- **memory** (Tensor) - The output sequence of the last layer of the TransformerEncoder.
- **tgt_mask** (Tensor) - The mask of the target sequence (optional). Default: None.
- **memory_mask** (Tensor) - The mask of the memory sequence (optional). Default: None.
- **tgt_key_padding_mask** (Tensor) - The key padding mask of the target sequence (optional). Default: None.
- **memory_key_padding_mask** (Tensor) - The key padding mask of the memory sequence (optional). Default: None.
Outputs:
Tensor.
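A minimal usage sketch (not part of the original doc page), based on the Examples block of the TransformerDecoderLayer docstring in this PR; the shapes and random inputs are illustrative assumptions:
import numpy as np
import mindspore
from mindspore import Tensor, nn
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)   # (src_len, batch, d_model)
tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)      # (tgt_len, batch, d_model)
out = decoder_layer(tgt, memory)                                  # (20, 32, 512)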

View File

@@ -0,0 +1,19 @@
mindspore.nn.TransformerEncoder
========================================
.. py:class:: mindspore.nn.TransformerEncoder(encoder_layer, num_layers, norm=None)
Transformer encoder module. A stack of multiple `TransformerEncoderLayer` layers, each consisting of a MultiheadAttention layer and a FeedForward layer. It can be used to build a BERT (https://arxiv.org/abs/1810.04805) model.
Args:
- **encoder_layer** (Cell) - An instance of TransformerEncoderLayer().
- **num_layers** (int) - Number of encoder layers.
- **norm** (Cell) - Custom LayerNorm layer (optional).
Inputs:
- **src** (Tensor) - The source sequence.
- **mask** (Tensor) - The mask of the source sequence (optional). Default: None.
- **src_key_padding_mask** (Tensor) - The key padding mask of the source sequence (optional). Default: None.
Outputs:
Tensor.
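A minimal usage sketch (not part of the original doc page), based on the Examples block of the TransformerEncoder docstring in this PR; the shapes and random inputs are illustrative assumptions:
import numpy as np
import mindspore
from mindspore import Tensor, nn
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)   # (src_len, batch, d_model)
out = transformer_encoder(src)                                 # (10, 32, 512)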

View File

@@ -0,0 +1,24 @@
mindspore.nn.TransformerEncoderLayer
========================================
.. py:class:: mindspore.nn.TransformerEncoderLayer(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False)
Transformer encoder layer. A single layer of the Transformer encoder, consisting of a MultiheadAttention layer and a FeedForward layer.
Args:
- **d_model** (int) - Number of features in the input.
- **nhead** (int) - Number of attention heads.
- **dim_feedforward** (int) - Dimension of the FeedForward layer. Default: 2048.
- **dropout** (float) - Dropout ratio. Default: 0.1.
- **activation** (str, Cell) - Activation function of the intermediate layer; can be a string, a function interface or an activation layer instance. Supported: "relu", "gelu". Default: "relu".
- **layer_norm_eps** (float) - The eps value of the LayerNorm layers. Default: 1e-5.
- **batch_first** (bool) - If ``True``, the input and output shape is (batch, seq, feature); otherwise it is (seq, batch, feature). Default: ``False``.
- **norm_first** (bool) - If ``True``, the LayerNorm layer is applied before the MultiheadAttention and FeedForward layers; otherwise it is applied after them. Default: ``False``.
Inputs:
- **src** (Tensor) - The source sequence.
- **src_mask** (Tensor) - The mask of the source sequence (optional). Default: None.
- **src_key_padding_mask** (Tensor) - The key padding mask of the source sequence (optional). Default: None.
Outputs:
Tensor.
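A minimal usage sketch (not part of the original doc page), based on the Examples block of the TransformerEncoderLayer docstring in this PR, including the batch_first variant; the shapes and random inputs are illustrative assumptions:
import numpy as np
import mindspore
from mindspore import Tensor, nn
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)   # (seq, batch, feature)
out = encoder_layer(src)                                       # (10, 32, 512)
# with batch_first=True the layout becomes (batch, seq, feature)
encoder_layer_bf = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
src_bf = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
out_bf = encoder_layer_bf(src_bf)                              # (32, 10, 512)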

View File

@@ -53,7 +53,7 @@ Wrapper Layer
mindspore.nn.WithEvalCell
mindspore.nn.WithLossCell
Convolutional Neural Network Layer
Convolutional Layer
----------------------------------
.. msplatformautosummary::
@@ -69,7 +69,7 @@ Convolutional Neural Network Layer
mindspore.nn.Conv3dTranspose
mindspore.nn.Unfold
Recurrent Neural Network Layer
Recurrent Layer
------------------------------
.. msplatformautosummary::
@@ -84,6 +84,21 @@ Recurrent Neural Network Layer
mindspore.nn.LSTM
mindspore.nn.LSTMCell
Transformer Layer
---------------------------
.. msplatformautosummary::
:toctree: nn
:nosignatures:
:template: classtemplate.rst
mindspore.nn.MultiheadAttention
mindspore.nn.TransformerEncoderLayer
mindspore.nn.TransformerDecoderLayer
mindspore.nn.TransformerEncoder
mindspore.nn.TransformerDecoder
mindspore.nn.Transformer
Embedding Layer
---------------
@@ -96,7 +111,7 @@ Embedding Layer
mindspore.nn.EmbeddingLookup
mindspore.nn.MultiFieldEmbeddingLookup
Nonlinear Activation Function Layer
Nonlinear Activation Layer
-----------------------------------
.. msplatformautosummary::

View File

@@ -20,7 +20,7 @@ The high-level components(Cells) used to construct the neural network.
from __future__ import absolute_import
from mindspore.nn.layer import activation, normalization, container, conv, basic, embedding, pooling, \
image, math, combined, timedistributed, thor_layer, rnns, rnn_cells, padding, dense
image, math, combined, timedistributed, thor_layer, rnns, rnn_cells, padding, dense, transformer
from mindspore.nn.layer.activation import *
from mindspore.nn.layer.normalization import *
from mindspore.nn.layer.container import *
@@ -35,6 +35,7 @@ from mindspore.nn.layer.image import *
from mindspore.nn.layer.math import *
from mindspore.nn.layer.combined import *
from mindspore.nn.layer.timedistributed import *
from mindspore.nn.layer.transformer import *
from mindspore.nn.layer.channel_shuffle import ChannelShuffle
from mindspore.nn.layer.thor_layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor
from mindspore.nn.layer.padding import ConstantPad1d, ConstantPad2d, ConstantPad3d, ReflectionPad1d, \
@@ -55,6 +56,7 @@ __all__.extend(image.__all__)
__all__.extend(math.__all__)
__all__.extend(combined.__all__)
__all__.extend(timedistributed.__all__)
__all__.extend(transformer.__all__)
__all__.extend(thor_layer.__all__)
__all__.extend(padding.__all__)
__all__.extend(channel_shuffle.__all__)

View File

@@ -0,0 +1,652 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer,
TransformerEncoder, TransformerDecoder, Transformer.
"""
import copy
import math
from typing import Union, Optional
import mindspore
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \
HeUniform, Uniform, _calculate_fan_in_and_fan_out
from mindspore.ops.function.nn_func import multi_head_attention_forward
from mindspore.nn.cell import Cell
from .basic import Dense, Dropout
from .activation import ReLU, GELU
from .normalization import LayerNorm
from .container import CellList
__all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderLayer',
'TransformerEncoder', 'TransformerDecoder', 'Transformer']
class _Linear(Dense):
def __init__(self, in_channels, out_channels, has_bias=True):
fan_in, _ = _calculate_fan_in_and_fan_out((out_channels, in_channels))
bound = 1 / math.sqrt(fan_in)
super().__init__(in_channels, out_channels, weight_init=HeUniform(math.sqrt(5)),
bias_init=Uniform(bound), has_bias=has_bias, activation=None)
class MultiheadAttention(Cell):
r"""
This is an implementation of multihead attention in the paper `Attention is all you need
<https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with target length, and the
key and value vectors with source length, the attention will be performed as follows
.. math::
MultiHeadAttention(query, key, value) = Concat(head_1, \dots, head_h)W^O
where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The output projection uses a bias by default.
If the query, key and value tensors are the same, this becomes self-attention.
Args:
embed_dim (int): Total dimension of the model.
num_heads (int): Number of parallel attention heads. Note that ``embed_dim`` will be split
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
dropout (float): Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
has_bias (bool): Whether adds bias to input / output projection layers. Default: ``True``.
add_bias_kv (bool): Whether adds bias to the key and value sequences at dim=0. Default: ``False``.
add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at dim=1.
Default: ``False``.
kdim (int): Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
vdim (int): Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Inputs:
- **query** (Tensor): Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)`
when ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the
target sequence length, :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension
``embed_dim``. Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
- **key** (Tensor): Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)`
when ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the
source sequence length, :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension
``kdim``. See "Attention Is All You Need" for more details.
- **value** (Tensor): Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)`
when ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the
source sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension
``vdim``. See "Attention Is All You Need" for more details.
- **key_padding_mask** (Tensor): If specified, a mask of shape :math:`(N, S)` indicating which elements
within ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`,
shape should be :math:`(S)`. Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
- **need_weights** (bool): If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
- **attn_mask** (Tensor): If specified, a 2D or 3D mask preventing attention to certain positions. Must
be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
- **average_attn_weights** (bool): If true, indicates that the returned ``attn_weights`` should be averaged
across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only
has an effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Outputs:
Tuple, a tuple containing (`attn_output`, `attn_output_weights`).
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or
:math:`(N, \text{num\_heads}, L, S)`.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
def __init__(self, embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False,
add_zero_attn=False, kdim=None, vdim=None, batch_first=False):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim)), 'q_proj_weight')
self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim)), 'k_proj_weight')
self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim)), 'v_proj_weight')
self.in_proj_weight = None
else:
self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim)), 'in_proj_weight')
self.q_proj_weight = None
self.k_proj_weight = None
self.v_proj_weight = None
if has_bias:
self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim)), 'in_proj_bias')
else:
self.in_proj_bias = None
self.out_proj = _Linear(embed_dim, embed_dim, has_bias=has_bias)
if add_bias_kv:
self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_k')
self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_v')
else:
self.bias_k = self.bias_v = None
self.add_zero_attn = add_zero_attn
self.k_is_v = False
self.q_is_k = False
def __call__(self, *args, **kwargs):
query = kwargs.get('query', args[0])
key = kwargs.get('key', args[1])
value = kwargs.get('value', args[2])
self.k_is_v = key is value
self.q_is_k = query is key
return super().__call__(*args, **kwargs)
def construct(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True):
is_batched = query.ndim == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != mindspore.bool_ and not ops.is_floating_point(key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
if self.batch_first and is_batched:
# make sure that the transpose op does not affect the "is" property
if self.k_is_v:
if self.q_is_k:
query = key = value = query.swapaxes(1, 0)
else:
query, key = [x.swapaxes(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.swapaxes(1, 0) for x in (query, key, value)]
if not self._qkv_same_embed_dim:
attn_output, attn_output_weights = multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights,
k_is_v=self.k_is_v, q_is_k=self.q_is_k)
else:
attn_output, attn_output_weights = multi_head_attention_forward(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout, self.out_proj.weight, self.out_proj.bias,
training=self.training,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask, average_attn_weights=average_attn_weights,
k_is_v=self.k_is_v, q_is_k=self.q_is_k)
if self.batch_first and is_batched:
attn_output = attn_output.swapaxes(1, 0)
if need_weights:
return attn_output, attn_output_weights
return (attn_output,)
class TransformerEncoderLayer(Cell):
r"""
Transformer Encoder Layer. This is an implementation of a single layer of the transformer
encoder, including a multihead attention layer and a feedforward layer.
Args:
d_model (int): the number of expected features in the input (required).
nhead (int): the number of heads in the multiheadattention models (required).
dim_feedforward (int): the dimension of the feedforward network model (default=2048).
dropout (float): the dropout value (default=0.1).
activation (str, Cell): the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps (float): the eps value in layer normalization components (default=1e-5).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first (bool): if ``True``, layer norm is done prior to attention and feedforward
operations, respectively. Otherwise it's done after. Default: ``False`` (after).
Inputs:
- **src** (Tensor): the sequence to the encoder layer (required).
- **src_mask** (Tensor): the mask for the src sequence (optional).
- **src_key_padding_mask** (Tensor): the mask for the src keys per batch (optional).
Outputs:
Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> out = encoder_layer(src)
>>> # Alternatively, when batch_first=True:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
>>> out = encoder_layer(src)
"""
__constants__ = ['batch_first', 'norm_first']
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False,
norm_first: bool = False):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
self.linear1 = _Linear(d_model, dim_feedforward)
self.dropout = Dropout(1-dropout)
self.linear2 = _Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.dropout1 = Dropout(1-dropout)
self.dropout2 = Dropout(1-dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
activation = _get_activation_fn(activation)
# We can't test self.activation in forward() in TorchScript,
# so stash some information about it instead.
if activation is ops.relu or isinstance(activation, ReLU):
self.activation_relu_or_gelu = 1
elif activation is ops.gelu or isinstance(activation, GELU):
self.activation_relu_or_gelu = 2
else:
self.activation_relu_or_gelu = 0
self.activation = activation
def construct(self, src: Tensor, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != mindspore.bool_ and not ops.is_floating_point(src_key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
x = src
if self.norm_first:
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
return x
def _sa_block(self, x, attn_mask, key_padding_mask):
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False)[0]
return self.dropout1(x)
def _ff_block(self, x):
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class TransformerDecoderLayer(Cell):
r"""
Transformer Decoder Layer. This is an implementation of a single layer of the transformer
decoder, including self-attention, cross-attention and a feedforward layer.
Args:
d_model (int): the number of expected features in the input (required).
nhead (int): the number of heads in the multiheadattention models (required).
dim_feedforward (int): the dimension of the feedforward network model (default=2048).
dropout (float): the dropout value (default=0.1).
activation (str, Cell): the activation function of the intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
layer_norm_eps (float): the eps value in layer normalization components (default=1e-5).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first (bool): if ``True``, layer norm is done prior to self attention, multihead
attention and feedforward operations, respectively. Otherwise it's done after.
Default: ``False`` (after).
Inputs:
- **tgt** (Tensor): the sequence to the decoder layer (required).
- **memory** (Tensor): the sequence from the last layer of the encoder (required).
- **tgt_mask** (Tensor): the mask for the tgt sequence (optional).
- **memory_mask** (Tensor): the mask for the memory sequence (optional).
- **tgt_key_padding_mask** (Tensor): the mask for the tgt keys per batch (optional).
- **memory_key_padding_mask** (Tensor): the mask for the memory keys per batch (optional).
Outputs:
Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
>>> out = decoder_layer(tgt, memory)
>>> # Alternatively, when ``batch_first`` is ``True``:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
>>> memory = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(32, 20, 512), mindspore.float32)
>>> out = decoder_layer(tgt, memory)
"""
__constants__ = ['batch_first', 'norm_first']
def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False,
norm_first: bool = False):
super(TransformerDecoderLayer, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
# Implementation of Feedforward model
self.linear1 = _Linear(d_model, dim_feedforward)
self.dropout = Dropout(1-dropout)
self.linear2 = _Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.dropout1 = Dropout(1-dropout)
self.dropout2 = Dropout(1-dropout)
self.dropout3 = Dropout(1-dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
else:
self.activation = activation
def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None):
x = tgt
if self.norm_first:
x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
x = x + self._ff_block(self.norm3(x))
else:
x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask))
x = self.norm3(x + self._ff_block(x))
return x
def _sa_block(self, x, attn_mask, key_padding_mask):
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False)[0]
return self.dropout1(x)
def _mha_block(self, x, mem, attn_mask, key_padding_mask):
x = self.multihead_attn(x, mem, mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False)[0]
return self.dropout2(x)
def _ff_block(self, x):
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
class TransformerEncoder(Cell):
r"""
Transformer Encoder module consisting of multiple stacked `TransformerEncoderLayer` layers, each including
a multihead self-attention layer and a feedforward layer. Users can build the
BERT (https://arxiv.org/abs/1810.04805) model with the corresponding parameters.
Args:
encoder_layer (Cell): an instance of the TransformerEncoderLayer() class (required).
num_layers (int): the number of sub-encoder-layers in the encoder (required).
norm (Cell): the layer normalization component (optional).
Inputs:
- **src** (Tensor): the sequence to the encoder (required).
- **mask** (Tensor): the mask for the src sequence (optional).
- **src_key_padding_mask** (Tensor): the mask for the src keys per batch (optional).
Outputs:
Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> out = transformer_encoder(src)
"""
__constants__ = ['norm']
def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoder, self).__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def construct(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != mindspore.bool_ and not ops.is_floating_point(src_key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
output = src
src_key_padding_mask_for_layers = src_key_padding_mask
for mod in self.layers:
output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(Cell):
r"""
Transformer Decoder module consisting of multiple stacked `TransformerDecoderLayer` layers, each including
multihead self-attention, cross-attention and a feedforward layer.
Args:
decoder_layer (Cell): an instance of the TransformerDecoderLayer() class (required).
num_layers (int): the number of sub-decoder-layers in the decoder (required).
norm (Cell): the layer normalization component (optional).
Inputs:
- **tgt** (Tensor): the sequence to the decoder (required).
- **memory** (Tensor): the sequence from the last layer of the encoder (required).
- **tgt_mask** (Tensor): the mask for the tgt sequence (optional).
- **memory_mask** (Tensor): the mask for the memory sequence (optional).
- **tgt_key_padding_mask** (Tensor): the mask for the tgt keys per batch (optional).
- **memory_key_padding_mask** (Tensor): the mask for the memory keys per batch (optional).
Outputs:
Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
>>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
>>> out = transformer_decoder(tgt, memory)
"""
__constants__ = ['norm']
def __init__(self, decoder_layer, num_layers, norm=None):
super(TransformerDecoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None):
output = tgt
for mod in self.layers:
output = mod(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
if self.norm is not None:
output = self.norm(output)
return output
class Transformer(Cell):
r"""
Transformer module including encoder and decoder. The difference with the original implementation is that the module uses
residual addition before the layer normalization, and the default hidden activation is `relu`.
The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
Args:
d_model (int): the number of expected features in the encoder/decoder inputs (default=512).
nhead (int): the number of heads in the multiheadattention models (default=8).
num_encoder_layers (int): the number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers (int): the number of sub-decoder-layers in the decoder (default=6).
dim_feedforward (int): the dimension of the feedforward network model (default=2048).
dropout (float): the dropout value (default=0.1).
activation (str, Cell): the activation function of encoder/decoder intermediate layer, can be a string
("relu" or "gelu") or a unary callable. Default: relu
custom_encoder (Cell): custom encoder (default=None).
custom_decoder (Cell): custom decoder (default=None).
layer_norm_eps (float): the eps value in layer normalization components (default=1e-5).
batch_first (bool): If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before
other attention and feedforward operations, otherwise after. Default: ``False`` (after).
Inputs:
- **src** (Tensor): the sequence to the encoder (required).
- **tgt** (Tensor): the sequence to the decoder (required).
- **src_mask** (Tensor): the additive mask for the src sequence (optional).
- **tgt_mask** (Tensor): the additive mask for the tgt sequence (optional).
- **memory_mask** (Tensor): the additive mask for the encoder output (optional).
- **src_key_padding_mask** (Tensor): the ByteTensor mask for src keys per batch (optional).
- **tgt_key_padding_mask** (Tensor): the ByteTensor mask for tgt keys per batch (optional).
- **memory_key_padding_mask** (Tensor): the ByteTensor mask for memory keys per batch (optional).
Outputs:
Tensor.
Examples:
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
>>> out = transformer_model(src, tgt)
"""
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,
num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
activation: Union[str, Cell] = 'relu', custom_encoder: Optional[Cell] = None,
custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5,
batch_first: bool = False, norm_first: bool = False):
super(Transformer, self).__init__()
if custom_encoder is not None:
self.encoder = custom_encoder
else:
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first, norm_first)
encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
if custom_decoder is not None:
self.decoder = custom_decoder
else:
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
activation, layer_norm_eps, batch_first, norm_first)
decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
self.batch_first = batch_first
def construct(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None):
is_batched = src.ndim == 3
if not self.batch_first and src.shape[1] != tgt.shape[1] and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
if self.batch_first and src.shape[0] != tgt.shape[0] and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
if src.shape[-1] != self.d_model or tgt.shape[-1] != self.d_model:
raise RuntimeError("the feature number of src and tgt must be equal to d_model")
memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask)
return output
def _reset_parameters(self):
r"""Initiate parameters in the transformer model."""
for _, p in self.parameters_and_names():
if p.ndim > 1:
p.set_data(initializer('xavier_uniform', p.shape, p.dtype))
def _get_activation_fn(activation: str):
if activation == "relu":
return ops.relu
if activation == "gelu":
return ops.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
def _get_clones(module, N):
return CellList([copy.deepcopy(module) for i in range(N)])

View File

@@ -5807,6 +5807,283 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
return triplet_margin_loss_op(anchor, positive, negative, margin_tensor)
def linear(x, w, b):
out = ops.matmul(x, w.swapaxes(-1, -2))
if b is not None:
out = out + b
return out
def _in_projection(q, k, v, w_q, w_k, w_v, b_q=None, b_k=None, b_v=None):
"""in projection function"""
Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1]
assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}"
assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}"
assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
def _in_projection_packed(q, k, v, w, b, k_is_v, q_is_k):
"""in projecktion packed function"""
E = q.shape[-1]
if k_is_v:
if q_is_k:
# self-attention
return linear(q, w, b).tensor_split(3, axis=-1)
# encoder-decoder attention
w_q, w_kv = w.split([E, E * 2])
if b is None:
b_q = b_kv = None
else:
b_q, b_kv = b.split([E, E * 2])
return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).tensor_split(2, axis=-1)
w_q, w_k, w_v = w.tensor_split(3)
if b is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.tensor_split(3)
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
def _scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, is_training):
"""scaled dot product attention"""
embed_size = query.shape[-1]
scaling_factor = Tensor(embed_size, mstype.float32).sqrt().sqrt()
query = query / scaling_factor
if is_causal:
L, S = query.shape[-2], key.shape[-2]
attn_mask = ops.ones((L, S), mstype.bool_).tril()
# query and key are each divided by sqrt(sqrt(embed_size)), so the product is scaled by 1/sqrt(embed_size)
attn = ops.matmul(query, key.swapaxes(-2, -1) / scaling_factor)
if attn_mask is not None:
attn = attn + attn_mask
attn = ops.softmax(attn, -1)
if dropout_p > 0. and is_training:
attn = ops.dropout(attn, dropout_p)
output = ops.matmul(attn, value)
return (output, attn)
def _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads):
"""
Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
and returns if the input is batched or not.
Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
"""
# Shape check.
if query.ndim == 3:
# Batched Inputs
is_batched = True
assert key.ndim == 3 and value.ndim == 3, \
("For batched (3-D) `query`, expected `key` and `value` to be 3-D"
f" but found {key.ndim}-D and {value.ndim}-D tensors respectively")
if key_padding_mask is not None:
assert key_padding_mask.ndim == 2, \
("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
f" but found {key_padding_mask.ndim}-D tensor instead")
if attn_mask is not None:
assert attn_mask.ndim in (2, 3), \
("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.ndim}-D tensor instead")
elif query.ndim == 2:
# Unbatched Inputs
is_batched = False
assert key.ndim == 2 and value.ndim == 2, \
("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
f" but found {key.ndim}-D and {value.ndim}-D tensors respectively")
if key_padding_mask is not None:
assert key_padding_mask.ndim == 1, \
("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
f" but found {key_padding_mask.ndim}-D tensor instead")
if attn_mask is not None:
assert attn_mask.ndim in (2, 3), \
("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
f" but found {attn_mask.ndim}-D tensor instead")
if attn_mask.ndim == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
assert attn_mask.shape == expected_shape, \
(f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}")
else:
raise AssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {query.ndim}-D query tensor")
return is_batched
def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight,
in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
out_proj_bias, training=True, key_padding_mask=None, attn_mask=None,
use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=True,
is_causal=False, k_is_v=False, q_is_k=False):
"""multi head attetion forward function"""
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
if not is_batched:
query = query.expand_dims(1)
key = key.expand_dims(1)
value = value.expand_dims(1)
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.expand_dims(0)
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != mstype.bool_ and not ops.is_floating_point(key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert key.shape[:2] == value.shape[:2], \
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
else:
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
# compute in-projection
if not use_separate_proj_weight:
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias, k_is_v, q_is_k)
else:
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = in_proj_bias.tensor_split(3)
q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
# prep attention mask
if attn_mask is not None:
if attn_mask.dtype == mstype.uint8:
attn_mask = attn_mask.astype(mstype.bool_)
else:
assert ops.is_floating_point(attn_mask) or attn_mask.dtype == mstype.bool_, \
f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
# ensure attn_mask's dim is 3
if attn_mask.ndim == 2:
correct_2d_size = (tgt_len, src_len)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, "
"but should be {correct_2d_size}.")
attn_mask = attn_mask.expand_dims(0)
elif attn_mask.ndim == 3:
correct_3d_size = (bsz * num_heads, tgt_len, src_len)
if attn_mask.shape != correct_3d_size:
raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, "
"but should be {correct_3d_size}.")
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.ndim} is not supported")
# add bias along batch dimension
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
k = ops.cat([k, bias_k.repeat(1, bsz, 1)])
v = ops.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = ops.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = ops.pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
assert bias_v is None
# reshape q, k, v for multihead attention and make em batch first
q = q.view(tgt_len, bsz * num_heads, head_dim).swapaxes(0, 1)
if static_k is None:
k = k.view(k.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_k.shape[0] == bsz * num_heads, \
f"expecting static_k.shape[0] of {bsz * num_heads}, but got {static_k.shape[0]}"
assert static_k.shape[2] == head_dim, \
f"expecting static_k.shape[2] of {head_dim}, but got {static_k.shape[2]}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert static_v.shape[0] == bsz * num_heads, \
f"expecting static_v.shape[0] of {bsz * num_heads}, but got {static_v.shape[0]}"
assert static_v.shape[2] == head_dim, \
f"expecting static_v.shape[2] of {head_dim}, but got {static_v.shape[2]}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = ops.cat([k, ops.zeros(zero_attn_shape, dtype=k.dtype)], axis=1)
v = ops.cat([v, ops.zeros(zero_attn_shape, dtype=v.dtype)], axis=1)
if attn_mask is not None:
attn_mask = ops.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = ops.pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
src_len = k.shape[1]
# merge key padding and attention masks
if key_padding_mask is not None:
assert key_padding_mask.shape == (bsz, src_len), \
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
if attn_mask is None:
attn_mask = key_padding_mask
elif attn_mask.dtype == mstype.bool_:
attn_mask = attn_mask.logical_or(key_padding_mask)
else:
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
# convert mask to float
if attn_mask is not None and attn_mask.dtype == mstype.bool_:
new_attn_mask = ops.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask = new_attn_mask.masked_fill(attn_mask, float("-inf"))
attn_mask = new_attn_mask
if attn_mask is not None:
if attn_mask.shape[0] == 1:
attn_mask = attn_mask.expand_dims(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output, attn_output_weights = _scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal, training)
attn_output = attn_output.transpose(2, 0, 1, 3).view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.shape[1])
# optionally average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.sum(axis=1) / num_heads
if not is_batched:
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
attn_output_weights = attn_output_weights.squeeze(0)
return attn_output, attn_output_weights
__all__ = [
'adaptive_avg_pool1d',
'adaptive_avg_pool2d',

View File

@@ -0,0 +1,168 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import Tensor, ops
from mindspore.nn import MultiheadAttention, TransformerEncoderLayer, \
TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, Transformer
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('dtype', [ms.float16, ms.float32])
@pytest.mark.parametrize('jit', [False, True])
def test_multihead_attention_pynative(dtype, jit):
"""
Feature: MultiheadAttention
Description: Verify the result of MultiheadAttention
Expectation: success
"""
embed_dim = 128
num_heads = 8
sl = 10
bs = 8
model = MultiheadAttention(embed_dim, num_heads).to_float(dtype)
q = Tensor(np.random.randn(sl, bs, embed_dim), dtype)
k = Tensor(np.random.randn(sl, bs, embed_dim), dtype)
v = Tensor(np.random.randn(sl, bs, embed_dim), dtype)
def forward(q, k, v):
out = model(q, k, v)
return out
if jit:
forward = ms.jit(forward)
out = forward(q, k, v)
assert q.shape == out[0].shape
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('training', [True, False])
@pytest.mark.parametrize('jit', [False, True])
def test_transformerencoder_square_input(training, jit):
"""
Feature: TransformerEncoder
Description: Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
batch size == sequence length
Expectation: success
"""
model = TransformerEncoder(
TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
num_layers=2)
# set constant weights of the model
for _, p in model.parameters_and_names():
x = p.data
sz = x.view(-1).shape[0]
shape = x.shape
x = ops.cos(ops.arange(0, sz).astype(ms.float32).view(shape))
p.set_data(x)
if training:
model = model.set_train()
else:
model = model.set_train(False)
x = ops.arange(0, 16).reshape(2, 2, 4).astype(ms.float32)
src_mask = Tensor([[0, 1], [0, 0]]).to(ms.bool_)
def forward(x, mask):
result = model(x, mask=mask)
return result
if jit:
forward = ms.jit(forward)
result = forward(x, src_mask)
ref_output = ms.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
[[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]],
ms.float32)
assert tuple(result.shape) == tuple(ref_output.shape)
assert np.allclose(result.asnumpy(), ref_output.asnumpy(), rtol=1e-7, atol=1e-5)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('jit', [False, True])
def test_transformerdecoder(jit):
"""
Feature: TransformerDecoder
Description: Test shape (batch size, sequence length, embedding dimension)
Expectation: success
"""
decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8)
transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6)
memory = Tensor(np.random.rand(10, 32, 512), ms.float32)
tgt = Tensor(np.random.rand(20, 32, 512), ms.float32)
def forward(tgt, memory):
out = transformer_decoder(tgt, memory)
return out
if jit:
forward = ms.jit(forward)
result = forward(tgt, memory)
assert result.shape == tgt.shape
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('jit', [False, True])
def test_transformer(jit):
"""
Feature: Transformer
Description: Test shape (batch size, sequence length, embedding dimension)
Expectation: success
"""
transformer_model = Transformer(nhead=16, num_encoder_layers=12)
src = Tensor(np.random.rand(10, 32, 512), ms.float32)
tgt = Tensor(np.random.rand(20, 32, 512), ms.float32)
def forward(src, tgt):
out = transformer_model(src, tgt)
return out
if jit:
forward = ms.jit(forward)
result = forward(src, tgt)
assert result.shape == tgt.shape