From 245bcdf47dd381f134e88ce63ff13ee3a326989d Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Fri, 17 Feb 2023 10:12:26 +0800 Subject: [PATCH] add new Transformer api --- docs/api/api_python/mindspore.nn.rst | 15 + .../nn/mindspore.nn.MultiheadAttention.rst | 50 ++ .../nn/mindspore.nn.Transformer.rst | 33 + .../nn/mindspore.nn.TransformerDecoder.rst | 22 + .../mindspore.nn.TransformerDecoderLayer.rst | 27 + .../nn/mindspore.nn.TransformerEncoder.rst | 19 + .../mindspore.nn.TransformerEncoderLayer.rst | 24 + docs/api/api_python_en/mindspore.nn.rst | 21 +- .../python/mindspore/nn/layer/__init__.py | 4 +- .../python/mindspore/nn/layer/transformer.py | 652 ++++++++++++++++++ .../python/mindspore/ops/function/nn_func.py | 277 ++++++++ tests/st/nn/test_transformer.py | 168 +++++ 12 files changed, 1308 insertions(+), 4 deletions(-) create mode 100644 docs/api/api_python/nn/mindspore.nn.MultiheadAttention.rst create mode 100644 docs/api/api_python/nn/mindspore.nn.Transformer.rst create mode 100644 docs/api/api_python/nn/mindspore.nn.TransformerDecoder.rst create mode 100644 docs/api/api_python/nn/mindspore.nn.TransformerDecoderLayer.rst create mode 100644 docs/api/api_python/nn/mindspore.nn.TransformerEncoder.rst create mode 100644 docs/api/api_python/nn/mindspore.nn.TransformerEncoderLayer.rst create mode 100644 mindspore/python/mindspore/nn/layer/transformer.py create mode 100644 tests/st/nn/test_transformer.py diff --git a/docs/api/api_python/mindspore.nn.rst b/docs/api/api_python/mindspore.nn.rst index 0c4fb5815ec..c0a239a2389 100644 --- a/docs/api/api_python/mindspore.nn.rst +++ b/docs/api/api_python/mindspore.nn.rst @@ -84,6 +84,21 @@ MindSpore中 `mindspore.nn` 接口与上一版本相比,新增、删除和支 mindspore.nn.LSTM mindspore.nn.LSTMCell +Transformer层 +----------------- + +.. mscnplatformautosummary:: + :toctree: nn + :nosignatures: + :template: classtemplate.rst + + mindspore.nn.MultiheadAttention + mindspore.nn.TransformerEncoderLayer + mindspore.nn.TransformerDecoderLayer + mindspore.nn.TransformerEncoder + mindspore.nn.TransformerDecoder + mindspore.nn.Transformer + 嵌入层 ----------------- diff --git a/docs/api/api_python/nn/mindspore.nn.MultiheadAttention.rst b/docs/api/api_python/nn/mindspore.nn.MultiheadAttention.rst new file mode 100644 index 00000000000..e5361ba608e --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.MultiheadAttention.rst @@ -0,0 +1,50 @@ +mindspore.nn.MultiheadAttention +======================================== + +.. py:class:: mindspore.nn.MultiheadAttention(embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False) + + 论文 `Attention Is All You Need `_ 中所述的多头注意力的实现。给定query向量,key向量和value,注意力计算流程如下: + + .. math:: + MultiHeadAttention(query, key, vector) = Dropout(Concat(head_1, \dots, head_h)W^O) + + 其中, :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)` 。注意:输出层的投影计算中带有偏置参数。 + + 如果query、key和value相同,则上述即为自注意力机制的计算过程。 + + 参数: + - **embed_dim** (int) - 模型的总维数。 + - **num_heads** (int) - 并行注意力头的数量。``num_heads`` 需要能够被 ``embed_dim`` 整除(每个头的维数为 ``embed_dim // num_heads``)。 + - **dropout** (float) - 应用到输入 ``attn_output_weights``上的随机丢弃比例. 默认: ``0.0`` (不丢弃)。 + - **has_bias** (bool) - 是否给输入、输出投射层添加偏置。默认: ``True``。 + - **add_bias_kv** (bool) - 是否给key、value序列的0维添加偏置。默认: ``False``。 + - **add_zero_attn** (bool) - 是否给key、value序列的1维添加0。默认: ``False``。 + - **kdim** (int) - key的总特征数。默认: ``None`` (即 ``kdim=embed_dim``)。 + - **vdim** (int) - value的总特征数。默认:``None`` (即 ``vdim=embed_dim``)。 + - **batch_first** (bool) - 如果为 ``True``,则输入输出Tensor的shape为 (batch, seq, feature),否则shape为(seq, batch, feature)。 默认: ``False`` 。 + + 输入: + - **query** (Tensor) - Query矩阵。当输入非Batch数据时,Shape为: :math:`(L, E_q)` 。当输入Batch数据,参数 ``batch_first=False`` 时,Shape为 :math:`(L, N, E_q)` , + 当 ``batch_first=True`` 时,Shape为 :math:`(N, L, E_q)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size,:math:`E_q` 为Query矩阵的维数 ``embed_dim``。 + 注意力机制通过Query与Key-Value运算以生成最终输出。详情请见:"Attention Is All You Need"。 + - **key** (Tensor) - Key矩阵。当输入非Batch数据时,Shape为: :math:`(S, E_k)` 。当输入Batch数据,参数 ``batch_first=False`` 时,Shape为 :math:`(S, N, E_k)` , + 当 ``batch_first=True`` 时,Shape为 :math:`(N, S, E_k)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size,:math:`E_k` 为Key矩阵的维数 ``kdim``。详情请见:"Attention Is All You Need"。 + - **value** (Tensor) - Value矩阵。当输入非Batch数据时,Shape为: :math:`(S, E_v)` 。当输入Batch数据,参数 ``batch_first=False`` 时,Shape为 :math:`(S, N, E_v)` , + 当 ``batch_first=True`` 时,Shape为 :math:`(N, S, E_v)`。其中, :math:`S` 为源序列的长度, :math:`N` 为batch size,:math:`E_v` 为Key矩阵的维数 ``vdim``。详情请见:"Attention Is All You Need"。 + - **key_padding_mask** (Tensor) - 如果指定此值,则表示Shape为 :math:`(N, S)`的掩码将被用于 ``key``。当输入非Batch数据时,Shape为: :math:`(S)` 。 + 如果输入Tensor为Bool类型,则 ``key`` 中对应为 ``True`` 的位置将在Attention计算时被忽略。如果输入Tensor为Float类型,则将直接与 ``key`` 相加。默认:``None``。 + - **need_weights** (bool) - 是否需要返回 ``attn_output_weights``,如果为 ``True``,则输出包含``attn_output_weights``。默认:``True``。 + - **attn_mask** (Tensor) - 如果指定此值,则表示Shape为 :math:`(L, S)` 或 :math:`(N\cdot\text{num\_heads}, L, S)` 的掩码将被用于Attention计算。其中 :math:`N` 为batch size, + :math:`L` 为目标序列长度,:math:`S` 为源序列长度。如果输入为2维矩阵,则将自动沿batch维广播至3维矩阵。若为3维矩阵,则允许沿batch维使用不同的掩码。如果输入Tensor为Bool类型,则值为 ``True`` 对应位置允许被注意力计算。如果输入Tensor为Float类型,则将直接与注意力权重相加。默认:``None``。 + - **average_attn_weights** (bool) - 如果为 ``True``, 则返回值 ``attn_weights`` 为注意力头的平均值。如果为 ``False``,则``attn_weights``分别返回每个注意力头的值。 + 本参数仅在 ``need_weights=True`` 时生效。默认: ``True`` 。 + + 输出: + Tuple,表示一个包含(`attn_output`, `attn_output_weights`)的元组。 + + - **attn_output** - 注意力机制的输出。当输入非Batch数据时,Shape为: :math:`(L, E)` 。当输入Batch数据, 参数 ``batch_first=False`` 时,Shape为 :math:`(L, N, E)` , + 当 ``batch_first=True`` 时,Shape为 :math:`(N, L, E)`。其中, :math:`L` 为目标序列的长度, :math:`N` 为batch size, :math:`E` 为模型的总维数 ``embed_dim``。 + - **attn_output_weights** - 仅当 ``need_weights=True`` 时返回。如果 ``average_attn_weights=True``,则返回值 ``attn_weights`` 为注意力头的平均值。当输入非Batch数据时, + Shape为: :math:`(L, S)` ,当输入Batch数据时,Shape为 :math:`(N, L, S)`。其中 :math:`N` 为batch size, :math:`L` 为目标序列的长度,:math:`S` 为源序列长度。 + 如果 ``average_attn_weights=False`` ,分别返回每个注意力头的值。当输入非Batch数据时,Shape为: :math:`(\text{num\_heads}, L, S)` ,当输入Batch数据时,Shape为 + :math:`(N, \text{num\_heads}, L, S)`。 diff --git a/docs/api/api_python/nn/mindspore.nn.Transformer.rst b/docs/api/api_python/nn/mindspore.nn.Transformer.rst new file mode 100644 index 00000000000..4a899fc6a24 --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.Transformer.rst @@ -0,0 +1,33 @@ +mindspore.nn.Transformer +======================================== + +.. py:class:: mindspore.nn.Transformer(d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell] = 'relu', custom_encoder: Optional[Cell] = None, custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False) + + Transformer模块,包括编码器和解码器。本模块与原论文的实现不同,原论文在LayerNorm前使用了残差模块。且默认的隐藏层激活函数为 `gelu` 。详情可见 `Attention is all you need `_ 。 + + 参数: + - **d_model** (int) - Encoder或Decoder输入的特征数。默认:512。 + - **nhead** (int) - 注意力头的数量。 + - **num_encoder_layers** (int) - Encoder的层数。默认:6。 + - **num_decoder_layers** (int) - Decoder的层数。默认:6。 + - **dim_feedforward** (int) - FeedForward层的维数。默认:2048。 + - **dropout** (float) - 随机丢弃比例。默认:0.1。 + - **activation** (str, Cell) - Encoder或Decoder中间层的激活函数,可以输入字符串、函数接口或激活函数层实例。支持"relu"、"gelu"。默认:"relu"。 + - **custom_encoder** (Cell) - 自定义Encoder层。默认:None。 + - **custom_decoder** (Cell) - 自定义Decoder层。默认:None。 + - **layer_norm_eps** (float) - LayerNorm层的eps值,默认:1e-5。 + - **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature),反之,Shape为(seq, batch, feature)。默认: ``False``。 + - **norm_first** (bool) - 如果为 ``True``,则LayerNorm层位于MultiheadAttention层和FeedForward层之前,反之,位于其后。默认: ``False``。 + + 输入: + - **src** (Tensor) - 源序列。 + - **tgt** (Tensor) - 目标序列。 + - **src_mask** (Tensor) - 源序列的掩码矩阵 (可选)。默认:None。 + - **tgt_mask** (Tensor) - 目标序列的掩码矩阵 (可选)。默认:None。 + - **memory_mask** (Tensor) - memory序列的掩码矩阵 (可选)。默认:None。 + - **src_key_padding_mask** (Tensor) - 源序列Key矩阵的掩码矩阵 (可选)。默认:None。 + - **tgt_key_padding_mask** (Tensor) - 目标序列Key矩阵的掩码矩阵 (可选)。默认:None。 + - **memory_key_padding_mask** (Tensor) - memory序列Key矩阵的掩码矩阵 (可选)。默认:None。 + + 输出: + Tensor。 diff --git a/docs/api/api_python/nn/mindspore.nn.TransformerDecoder.rst b/docs/api/api_python/nn/mindspore.nn.TransformerDecoder.rst new file mode 100644 index 00000000000..c693ae373ac --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.TransformerDecoder.rst @@ -0,0 +1,22 @@ +mindspore.nn.TransformerDecoder +======================================== + +.. py:class:: mindspore.nn.TransformerDecoder(decoder_layer, num_layers, norm=None) + + Transformer的解码器。多层 `TransformerDecoderLayer` 的堆叠,包括Self Attention层、MultiheadAttention层和FeedForward层。 + + 参数: + - **decoder_layer** (Cell) - TransformerDecoderLayer()的实例。 + - **num_layers** (int) - 解码器层数。 + - **norm** (Cell) - 自定义LayerNorm层(可选)。 + + 输入: + - **tgt** (Tensor) - 目标序列。 + - **memory** (Tensor) - TransformerEncoder的最后一层输出序列。 + - **tgt_mask** (Tensor) - 目标序列的掩码矩阵 (可选)。默认:None。 + - **memory_mask** (Tensor) - memory序列的掩码矩阵 (可选)。默认:None。 + - **tgt_key_padding_mask** (Tensor) - 目标序列Key矩阵的掩码矩阵 (可选)。默认:None。 + - **memory_key_padding_mask** (Tensor) - memory序列Key矩阵的掩码矩阵 (可选)。默认:None。 + + 输出: + Tensor。 diff --git a/docs/api/api_python/nn/mindspore.nn.TransformerDecoderLayer.rst b/docs/api/api_python/nn/mindspore.nn.TransformerDecoderLayer.rst new file mode 100644 index 00000000000..5582ffde15d --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.TransformerDecoderLayer.rst @@ -0,0 +1,27 @@ +mindspore.nn.TransformerDecoderLayer +======================================== + +.. py:class:: mindspore.nn.TransformerDecoderLayer(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False) + + Transformer的解码器层。Transformer解码器的单层实现,包括Self Attention层、MultiheadAttention层和FeedForward层。 + + 参数: + - **d_model** (int) - 输入的特征数。 + - **nhead** (int) - 注意力头的数量。 + - **dim_feedforward** (int) - FeedForward层的维数。默认:2048。 + - **dropout** (float) - 随机丢弃比例。默认:0.1。 + - **activation** (str, Cell) - 中间层的激活函数,可以输入字符串、函数接口或激活函数层实例。支持"relu"、"gelu"。默认:"relu"。 + - **layer_norm_eps** (float) - LayerNorm层的eps值,默认:1e-5。 + - **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature),反之,Shape为(seq, batch, feature)。默认: ``False``。 + - **norm_first** (bool) - 如果为 ``True``, 则LayerNorm层位于Self Attention层、MultiheadAttention层和FeedForward层之前,反之,位于其后。默认: ``False``。 + + 输入: + - **tgt** (Tensor) - 目标序列。 + - **memory** (Tensor) - TransformerEncoder的最后一层输出序列。 + - **tgt_mask** (Tensor) - 目标序列的掩码矩阵 (可选)。默认:None。 + - **memory_mask** (Tensor) - memory序列的掩码矩阵 (可选)。默认:None。 + - **tgt_key_padding_mask** (Tensor) - 目标序列Key矩阵的掩码矩阵 (可选)。默认:None。 + - **memory_key_padding_mask** (Tensor) - memory序列Key矩阵的掩码矩阵 (可选)。默认:None。 + + 输出: + Tensor。 diff --git a/docs/api/api_python/nn/mindspore.nn.TransformerEncoder.rst b/docs/api/api_python/nn/mindspore.nn.TransformerEncoder.rst new file mode 100644 index 00000000000..ae89fe0bd7f --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.TransformerEncoder.rst @@ -0,0 +1,19 @@ +mindspore.nn.TransformerEncoder +======================================== + +.. py:class:: mindspore.nn.TransformerEncoder(encoder_layer, num_layers, norm=None) + + Transformer编码器模块,多层 `TransformerEncoderLayer` 的堆叠,包括MultiheadAttention层和FeedForward层。可以使用此模块构造BERT(https://arxiv.org/abs/1810.04805)模型。 + + 参数: + - **encoder_layer** (Cell) - TransformerEncoderLayer()的实例。 + - **num_layers** (int) - 编码器层数。 + - **norm** (Cell) - 自定义LayerNorm层(可选)。 + + 输入: + - **src** (Tensor) - 源序列。 + - **src_mask** (Tensor) - 源序列的掩码矩阵 (可选)。默认:None。 + - **src_key_padding_mask** (Tensor) - 源序列Key矩阵的掩码矩阵 (可选)。默认:None。 + + 输出: + Tensor。 diff --git a/docs/api/api_python/nn/mindspore.nn.TransformerEncoderLayer.rst b/docs/api/api_python/nn/mindspore.nn.TransformerEncoderLayer.rst new file mode 100644 index 00000000000..c8e7e102631 --- /dev/null +++ b/docs/api/api_python/nn/mindspore.nn.TransformerEncoderLayer.rst @@ -0,0 +1,24 @@ +mindspore.nn.TransformerEncoderLayer +======================================== + +.. py:class:: mindspore.nn.TransformerEncoderLayer(d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False) + + Transformer的编码器层。Transformer编码器的单层实现,包括MultiheadAttention层和FeedForward层。 + + 参数: + - **d_model** (int) - 输入的特征数。 + - **nhead** (int) - 注意力头的数量。 + - **dim_feedforward** (int) - FeedForward层的维数。默认:2048。 + - **dropout** (float) - 随机丢弃比例。默认:0.1。 + - **activation** (str, Cell) - 中间层的激活函数,可以输入字符串、函数接口或激活函数层实例。支持'relu','gelu'。默认:'relu'。 + - **layer_norm_eps** (float) - LayerNorm层的eps值,默认:1e-5。 + - **batch_first** (bool) - 如果为 ``True`` 则输入输出Shape为(batch, seq, feature),反之,Shape为(seq, batch, feature)。默认: ``False``。 + - **norm_first** (bool) - 如果为 ``True``, 则LayerNorm层位于MultiheadAttention层和FeedForward层之前,反之,位于其后。默认: ``False``。 + + 输入: + - **src** (Tensor) - 源序列。 + - **src_mask** (Tensor) - 源序列的掩码矩阵 (可选)。默认:None。 + - **src_key_padding_mask** (Tensor) - 源序列Key矩阵的掩码矩阵 (可选)。默认:None。 + + 输出: + Tensor。 diff --git a/docs/api/api_python_en/mindspore.nn.rst b/docs/api/api_python_en/mindspore.nn.rst index e7cde17ca53..37875eaed5b 100644 --- a/docs/api/api_python_en/mindspore.nn.rst +++ b/docs/api/api_python_en/mindspore.nn.rst @@ -53,7 +53,7 @@ Wrapper Layer mindspore.nn.WithEvalCell mindspore.nn.WithLossCell -Convolutional Neural Network Layer +Convolutional Layer ---------------------------------- .. msplatformautosummary:: @@ -69,7 +69,7 @@ Convolutional Neural Network Layer mindspore.nn.Conv3dTranspose mindspore.nn.Unfold -Recurrent Neural Network Layer +Recurrent Layer ------------------------------ .. msplatformautosummary:: @@ -84,6 +84,21 @@ Recurrent Neural Network Layer mindspore.nn.LSTM mindspore.nn.LSTMCell +Transformer Layer +--------------------------- + +.. mscnplatformautosummary:: + :toctree: nn + :nosignatures: + :template: classtemplate.rst + + mindspore.nn.MultiheadAttention + mindspore.nn.TransformerEncoderLayer + mindspore.nn.TransformerDecoderLayer + mindspore.nn.TransformerEncoder + mindspore.nn.TransformerDecoder + mindspore.nn.Transformer + Embedding Layer --------------- @@ -96,7 +111,7 @@ Embedding Layer mindspore.nn.EmbeddingLookup mindspore.nn.MultiFieldEmbeddingLookup -Nonlinear Activation Function Layer +Nonlinear Activation Layer ----------------------------------- .. msplatformautosummary:: diff --git a/mindspore/python/mindspore/nn/layer/__init__.py b/mindspore/python/mindspore/nn/layer/__init__.py index af9eb845516..83fd3b10ea6 100644 --- a/mindspore/python/mindspore/nn/layer/__init__.py +++ b/mindspore/python/mindspore/nn/layer/__init__.py @@ -20,7 +20,7 @@ The high-level components(Cells) used to construct the neural network. from __future__ import absolute_import from mindspore.nn.layer import activation, normalization, container, conv, basic, embedding, pooling, \ - image, math, combined, timedistributed, thor_layer, rnns, rnn_cells, padding, dense + image, math, combined, timedistributed, thor_layer, rnns, rnn_cells, padding, dense, transformer from mindspore.nn.layer.activation import * from mindspore.nn.layer.normalization import * from mindspore.nn.layer.container import * @@ -35,6 +35,7 @@ from mindspore.nn.layer.image import * from mindspore.nn.layer.math import * from mindspore.nn.layer.combined import * from mindspore.nn.layer.timedistributed import * +from mindspore.nn.layer.transformer import * from mindspore.nn.layer.channel_shuffle import ChannelShuffle from mindspore.nn.layer.thor_layer import DenseThor, Conv2dThor, EmbeddingThor, EmbeddingLookupThor from mindspore.nn.layer.padding import ConstantPad1d, ConstantPad2d, ConstantPad3d, ReflectionPad1d, \ @@ -55,6 +56,7 @@ __all__.extend(image.__all__) __all__.extend(math.__all__) __all__.extend(combined.__all__) __all__.extend(timedistributed.__all__) +__all__.extend(transformer.__all__) __all__.extend(thor_layer.__all__) __all__.extend(padding.__all__) __all__.extend(channel_shuffle.__all__) diff --git a/mindspore/python/mindspore/nn/layer/transformer.py b/mindspore/python/mindspore/nn/layer/transformer.py new file mode 100644 index 00000000000..b226de14c70 --- /dev/null +++ b/mindspore/python/mindspore/nn/layer/transformer.py @@ -0,0 +1,652 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer, +TransformerEncoder, TransformerDecoder, Transformer. +""" +import copy +import math +from typing import Union, Optional +import mindspore +import mindspore.ops as ops +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer, XavierNormal, XavierUniform, \ + HeUniform, Uniform, _calculate_fan_in_and_fan_out +from mindspore.ops.function.nn_func import multi_head_attention_forward +from mindspore.nn.cell import Cell +from .basic import Dense, Dropout +from .activation import ReLU, GELU +from .normalization import LayerNorm +from .container import CellList + +__all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderLayer', + 'TransformerEncoder', 'TransformerDecoder', 'Transformer'] + + +class _Linear(Dense): + def __init__(self, in_channels, out_channels, has_bias=True): + fan_in, _ = _calculate_fan_in_and_fan_out((out_channels, in_channels)) + bound = 1 / math.sqrt(fan_in) + super().__init__(in_channels, out_channels, weight_init=HeUniform(math.sqrt(5)), + bias_init=Uniform(bound), has_bias=has_bias, activation=None) + + +class MultiheadAttention(Cell): + r""" + This is an implementation of multihead attention in the paper `Attention is all you need + `_. Given the query vector with source length, and the + key and value vector with target length, the attention will be performed as the following + + .. math:: + MultiHeadAttention(query, key, vector) = Concat(head_1, \dots, head_h)W^O + + where :math:`head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)`. The default is with a bias. + + if query, key and value tensor is same, then it will be self attention. + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout (float): Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + has_bias (bool): Whether adds bias to input / output projection layers. Default: ``True``. + add_bias_kv (bool): Whether adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn (bool): Whether adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim (int): Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim (int): Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first (bool): If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Inputs: + - **query** (Tensor): Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` + when ``batch_first=False`` or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L`is the + target sequence length, :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension + ``embed_dim``. Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + - **key** (Tensor): Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` + when ``batch_first=False`` or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the + source sequence length, :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension + ``kdim``. See "Attention Is All You Need" for more details. + - **value** (Tensor): Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` + when ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the + source sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension + ``vdim``. See "Attention Is All You Need" for more details. + - **key_padding_mask** (Tensor): If specified, a mask of shape :math:`(N, S)` indicating which elements + within ``key`` to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, + shape should be :math:`(S)`. Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + - **need_weights** (bool): If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + - **attn_mask** (Tensor): If specified, a 2D or 3D mask preventing attention to certain positions. Must + be of shape :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + - **average_attn_weights** (bool): If true, indicates that the returned ``attn_weights`` should be averaged + across heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only + has an effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + + Outputs: + Tuple, a tuple contains(`attn_output`, `attn_output_weights`) + + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or + :math:`(N, \text{num\_heads}, L, S)`. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + + def __init__(self, embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False, + add_zero_attn=False, kdim=None, vdim=None, batch_first=False): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, embed_dim)), 'q_proj_weight') + self.k_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.kdim)), 'k_proj_weight') + self.v_proj_weight = Parameter(initializer(XavierUniform(), (embed_dim, self.vdim)), 'v_proj_weight') + self.in_proj_weight = None + else: + self.in_proj_weight = Parameter(initializer(XavierUniform(), (3 * embed_dim, embed_dim)), 'in_proj_weight') + self.q_proj_weight = None + self.k_proj_weight = None + self.v_proj_weight = None + + if has_bias: + self.in_proj_bias = Parameter(initializer('zeros', (3 * embed_dim)), 'in_proj_bias') + else: + self.in_proj_bias = None + self.out_proj = _Linear(embed_dim, embed_dim, has_bias=has_bias) + + if add_bias_kv: + self.bias_k = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_k') + self.bias_v = Parameter(initializer(XavierNormal(), (1, 1, embed_dim)), 'bias_v') + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + self.k_is_v = False + self.q_is_k = False + + def __call__(self, *args, **kwargs): + query = kwargs.get('query', args[0]) + key = kwargs.get('key', args[1]) + value = kwargs.get('value', args[2]) + self.k_is_v = key is value + self.q_is_k = query is key + return super().__call__(*args, **kwargs) + + def construct(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True): + is_batched = query.ndim == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != mindspore.bool_ and not ops.is_floating_point(key_padding_mask): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported") + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if self.k_is_v: + if self.q_is_k: + query = key = value = query.swapaxes(1, 0) + else: + query, key = [x.swapaxes(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.swapaxes(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights, + k_is_v=self.k_is_v, q_is_k=self.q_is_k) + else: + attn_output, attn_output_weights = multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, average_attn_weights=average_attn_weights, + k_is_v=self.k_is_v, q_is_k=self.q_is_k) + + if self.batch_first and is_batched: + attn_output = attn_output.swapaxes(1, 0) + if need_weights: + return attn_output, attn_output_weights + return (attn_output,) + + +class TransformerEncoderLayer(Cell): + r""" + Transformer Encoder Layer. This is an implementation of the single layer of the transformer + encoder layer, including multihead attention and feedward layer. + + Args: + d_model (int): the number of expected features in the input (required). + nhead (int): the number of heads in the multiheadattention models (required). + dim_feedforward (int): the dimension of the feedforward network model (default=2048). + dropout (float): the dropout value (default=0.1). + activation (str, Cell): the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps (float): the eps value in layer normalization components (default=1e-5). + batch_first (bool): If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first (bool): if ``True``, layer norm is done prior to attention and feedforward + operations, respectively. Otherwise it's done after. Default: ``False`` (after). + + Inputs: + - **src** (Tensor): the sequence to the encoder layer (required). + - **src_mask** (Tensor): the mask for the src sequence (optional). + - **src_key_padding_mask** (Tensor): the mask for the src keys per batch (optional). + + Outputs: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32) + >>> out = encoder_layer(src) + >>> # Alternatively, when batch_first=True: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True) + >>> src = Tensor(np.random.rand(32, 10, 512), mindspore.float32) + >>> out = encoder_layer(src) + """ + __constants__ = ['batch_first', 'norm_first'] + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False, + norm_first: bool = False): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first) + # Implementation of Feedforward model + self.linear1 = _Linear(d_model, dim_feedforward) + self.dropout = Dropout(1-dropout) + self.linear2 = _Linear(dim_feedforward, d_model) + + self.norm_first = norm_first + self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps) + self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps) + self.dropout1 = Dropout(1-dropout) + self.dropout2 = Dropout(1-dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + + # We can't test self.activation in forward() in TorchScript, + # so stash some information about it instead. + if activation is ops.relu or isinstance(activation, ReLU): + self.activation_relu_or_gelu = 1 + elif activation is ops.gelu or isinstance(activation, GELU): + self.activation_relu_or_gelu = 2 + else: + self.activation_relu_or_gelu = 0 + self.activation = activation + + def construct(self, src: Tensor, src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None): + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != mindspore.bool_ and not ops.is_floating_point(src_key_padding_mask): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported") + + x = src + if self.norm_first: + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) + x = self.norm2(x + self._ff_block(x)) + + return x + + def _sa_block(self, x, attn_mask, key_padding_mask): + x = self.self_attn(x, x, x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False)[0] + return self.dropout1(x) + + def _ff_block(self, x): + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerDecoderLayer(Cell): + r""" + Transformer Decoder Layer. This is an implementation of the single layer of the transformer + decoder layer, including self-attention, cross attention and feedward layer. + + Args: + d_model (int): the number of expected features in the input (required). + nhead (int): the number of heads in the multiheadattention models (required). + dim_feedforward (int): the dimension of the feedforward network model (default=2048). + dropout (float): the dropout value (default=0.1). + activation (str, Cell): the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps (float): the eps value in layer normalization components (default=1e-5). + batch_first (bool): If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first (bool): if ``True``, layer norm is done prior to self attention, multihead + attention and feedforward operations, respectively. Otherwise it's done after. + Default: ``False`` (after). + + Inputs: + - **tgt** (Tensor): the sequence to the decoder layer (required). + - **memory** (Tensor): the sequence from the last layer of the encoder (required). + - **tgt_mask** (Tensor): the mask for the tgt sequence (optional). + - **memory_mask** (Tensor): the mask for the memory sequence (optional). + - **tgt_key_padding_mask** (Tensor): the mask for the tgt keys per batch (optional). + - **memory_key_padding_mask** (Tensor): the mask for the memory keys per batch (optional). + + Outputs: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32) + >>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32) + >>> out = decoder_layer(tgt, memory) + >>> # Alternatively, when ``batch_first`` is ``True``: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True) + >>> memory = Tensor(np.random.rand(32, 10, 512), mindspore.float32) + >>> tgt = Tensor(np.random.rand(32, 20, 512), mindspore.float32) + >>> out = decoder_layer(tgt, memory) + """ + __constants__ = ['batch_first', 'norm_first'] + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: Union[str, Cell] = 'relu', layer_norm_eps: float = 1e-5, batch_first: bool = False, + norm_first: bool = False): + super(TransformerDecoderLayer, self).__init__() + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first) + self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first) + # Implementation of Feedforward model + self.linear1 = _Linear(d_model, dim_feedforward) + self.dropout = Dropout(1-dropout) + self.linear2 = _Linear(dim_feedforward, d_model) + + self.norm_first = norm_first + self.norm1 = LayerNorm((d_model,), epsilon=layer_norm_eps) + self.norm2 = LayerNorm((d_model,), epsilon=layer_norm_eps) + self.norm3 = LayerNorm((d_model,), epsilon=layer_norm_eps) + self.dropout1 = Dropout(1-dropout) + self.dropout2 = Dropout(1-dropout) + self.dropout3 = Dropout(1-dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + self.activation = _get_activation_fn(activation) + else: + self.activation = activation + + def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None): + x = tgt + if self.norm_first: + x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask) + x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask) + x = x + self._ff_block(self.norm3(x)) + else: + x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask)) + x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)) + x = self.norm3(x + self._ff_block(x)) + + return x + + def _sa_block(self, x, attn_mask, key_padding_mask): + x = self.self_attn(x, x, x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False)[0] + return self.dropout1(x) + + def _mha_block(self, x, mem, attn_mask, key_padding_mask): + x = self.multihead_attn(x, mem, mem, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False)[0] + return self.dropout2(x) + + def _ff_block(self, x): + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +class TransformerEncoder(Cell): + r""" + Transformer Encoder module with multi-layer stacked of `TransformerEncoderLayer`, including multihead self + attention and feedforward layer. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer (Cell): an instance of the TransformerEncoderLayer() class (required). + num_layers (int): the number of sub-encoder-layers in the encoder (required). + norm (Cell): the layer normalization component (optional). + + Inputs: + - **src** (Tensor): the sequence to the encoder (required). + - **mask** (Tensor): the mask for the src sequence (optional). + - **src_key_padding_mask** (Tensor): the mask for the src keys per batch (optional). + + Outputs: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) + >>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32) + >>> out = transformer_encoder(src) + """ + __constants__ = ['norm'] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def construct(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None): + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != mindspore.bool_ and not ops.is_floating_point(src_key_padding_mask): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported") + output = src + src_key_padding_mask_for_layers = src_key_padding_mask + for mod in self.layers: + output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask_for_layers) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(Cell): + r""" + Transformer Decoder module with multi-layer stacked of `TransformerDecoderLayer`, including multihead self + attention, cross attention and feedforward layer. + + Args: + decoder_layer (Cell): an instance of the TransformerDecoderLayer() class (required). + num_layers (int): the number of sub-decoder-layers in the decoder (required). + norm (Cell): the layer normalization component (optional). + + Inputs: + - **tgt** (Tensor): the sequence to the decoder (required). + - **memory** (Tensor): the sequence from the last layer of the encoder (required). + - **tgt_mask** (Tensor): the mask for the tgt sequence (optional). + - **memory_mask** (Tensor): the mask for the memory sequence (optional). + - **tgt_key_padding_mask** (Tensor): the mask for the tgt keys per batch (optional). + - **memory_key_padding_mask** (Tensor): the mask for the memory keys per batch (optional). + + Outputs: + Tensor. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32) + >>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32) + >>> out = transformer_decoder(tgt, memory) + """ + __constants__ = ['norm'] + + def __init__(self, decoder_layer, num_layers, norm=None): + super(TransformerDecoder, self).__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def construct(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None): + output = tgt + + for mod in self.layers: + output = mod(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class Transformer(Cell): + r""" + Transformer module including encoder and decoder. The difference with the original implements is the module use + the residual addition before the layer normalization. And the default hidden act is `gelu`. + The details can be found in `Attention is all you need `_. + + Args: + d_model (int): the number of expected features in the encoder/decoder inputs (default=512). + nhead (int): the number of heads in the multiheadattention models (default=8). + num_encoder_layers (int): the number of sub-encoder-layers in the encoder (default=6). + num_decoder_layers (int): the number of sub-decoder-layers in the decoder (default=6). + dim_feedforward (int): the dimension of the feedforward network model (default=2048). + dropout (float): the dropout value (default=0.1). + activation (str, Cell): the activation function of encoder/decoder intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + custom_encoder (Cell): custom encoder (default=None). + custom_decoder (Cell): custom decoder (default=None). + layer_norm_eps (float): the eps value in layer normalization components (default=1e-5). + batch_first (bool): If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + norm_first (bool): if ``True``, encoder and decoder layers will perform LayerNorms before + other attention and feedforward operations, otherwise after. Default: ``False`` (after). + + Inputs: + - **src** (Tensor): the sequence to the encoder (required). + - **tgt** (Tensor): the sequence to the decoder (required). + - **src_mask** (Tensor): the additive mask for the src sequence (optional). + - **tgt_mask** (Tensor): the additive mask for the tgt sequence (optional). + - **memory_mask** (Tensor): the additive mask for the encoder output (optional). + - **src_key_padding_mask** (Tensor): the ByteTensor mask for src keys per batch (optional). + - **tgt_key_padding_mask** (Tensor): the ByteTensor mask for tgt keys per batch (optional). + - **memory_key_padding_mask** (Tensor): the ByteTensor mask for memory keys per batch (optional). + + Outputs: + Tensor. + + Examples: + >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) + >>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32) + >>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32) + >>> out = transformer_model(src, tgt) + """ + + def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, + num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: Union[str, Cell] = 'relu', custom_encoder: Optional[Cell] = None, + custom_decoder: Optional[Cell] = None, layer_norm_eps: float = 1e-5, + batch_first: bool = False, norm_first: bool = False): + super(Transformer, self).__init__() + + if custom_encoder is not None: + self.encoder = custom_encoder + else: + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, + activation, layer_norm_eps, batch_first, norm_first) + encoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps) + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + if custom_decoder is not None: + self.decoder = custom_decoder + else: + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, + activation, layer_norm_eps, batch_first, norm_first) + decoder_norm = LayerNorm((d_model,), epsilon=layer_norm_eps) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + self.batch_first = batch_first + + def construct(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None): + is_batched = src.ndim == 3 + if not self.batch_first and src.shape[1] != tgt.shape[1] and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + if self.batch_first and src.shape[0] != tgt.shape[0] and is_batched: + raise RuntimeError("the batch number of src and tgt must be equal") + + if src.shape[-1] != self.d_model or tgt.shape[-1] != self.d_model: + raise RuntimeError("the feature number of src and tgt must be equal to d_model") + + memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) + output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask) + return output + + def _reset_parameters(self): + r"""Initiate parameters in the transformer model.""" + + for _, p in self.parameters_and_names(): + if p.ndim > 1: + p.set_data(initializer('xavier_uniform', p.shape, p.dtype)) + + +def _get_activation_fn(activation: str): + if activation == "relu": + return ops.relu + if activation == "gelu": + return ops.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +def _get_clones(module, N): + return CellList([copy.deepcopy(module) for i in range(N)]) diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py index 952a6596505..d52772459a4 100644 --- a/mindspore/python/mindspore/ops/function/nn_func.py +++ b/mindspore/python/mindspore/ops/function/nn_func.py @@ -5795,6 +5795,283 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-06, return triplet_margin_loss_op(anchor, positive, negative, margin_tensor) +def linear(x, w, b): + out = ops.matmul(x, w.swapaxes(-1, -2)) + if b is not None: + out = out + b + return out + + +def _in_projection(q, k, v, w_q, w_k, w_v, b_q=None, b_k=None, b_v=None): + """in projection function""" + Eq, Ek, Ev = q.shape[-1], k.shape[-1], v.shape[-1] + assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" + assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" + assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" + assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" + assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" + assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _in_projection_packed(q, k, v, w, b, k_is_v, q_is_k): + """in projecktion packed function""" + E = q.shape[-1] + if k_is_v: + if q_is_k: + # self-attention + return linear(q, w, b).tensor_split(3, axis=-1) + # encoder-decoder attention + w_q, w_kv = w.split([E, E * 2]) + if b is None: + b_q = b_kv = None + else: + b_q, b_kv = b.split([E, E * 2]) + return (linear(q, w_q, b_q),) + linear(k, w_kv, b_kv).tensor_split(2, axis=-1) + w_q, w_k, w_v = w.tensor_split(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.tensor_split(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +def _scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, is_training): + """scaled dot product attention""" + embed_size = query.shape[-1] + scaling_factor = Tensor(embed_size, mstype.float32).sqrt().sqrt() + query = query / scaling_factor + + if is_causal: + L = query.shape[-2], S = key.shape[-2] + attn_mask = ops.ones((L, S), mstype.bool_).tril() + + attn = ops.matmul(query, key.swapaxes(-2, -1) / scaling_factor) + if attn_mask is not None: + attn = attn + attn_mask + attn = ops.softmax(attn, -1) + if dropout_p > 0. and is_training: + attn = ops.dropout(attn, dropout_p) + output = ops.matmul(attn, value) + + return (output, attn) + + +def _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads): + """ + Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask` + and returns if the input is batched or not. + Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor. + """ + # Shape check. + if query.ndim == 3: + # Batched Inputs + is_batched = True + assert key.ndim == 3 and value.ndim == 3, \ + ("For batched (3-D) `query`, expected `key` and `value` to be 3-D" + f" but found {key.ndim}-D and {value.ndim}-D tensors respectively") + if key_padding_mask is not None: + assert key_padding_mask.ndim == 2, \ + ("For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D" + f" but found {key_padding_mask.ndim}-D tensor instead") + if attn_mask is not None: + assert attn_mask.ndim in (2, 3), \ + ("For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.ndim}-D tensor instead") + elif query.ndim == 2: + # Unbatched Inputs + is_batched = False + assert key.ndim == 2 and value.ndim == 2, \ + ("For unbatched (2-D) `query`, expected `key` and `value` to be 2-D" + f" but found {key.ndim}-D and {value.ndim}-D tensors respectively") + + if key_padding_mask is not None: + assert key_padding_mask.ndim == 1, \ + ("For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D" + f" but found {key_padding_mask.ndim}-D tensor instead") + + if attn_mask is not None: + assert attn_mask.ndim in (2, 3), \ + ("For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D" + f" but found {attn_mask.ndim}-D tensor instead") + if attn_mask.ndim == 3: + expected_shape = (num_heads, query.shape[0], key.shape[0]) + assert attn_mask.shape == expected_shape, \ + (f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}") + else: + raise AssertionError( + f"query should be unbatched 2D or batched 3D tensor but received {query.ndim}-D query tensor") + + return is_batched + + +def multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, + in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + out_proj_bias, training=True, key_padding_mask=None, attn_mask=None, + use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, + v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=True, + is_causal=False, k_is_v=False, q_is_k=False): + """multi head attetion forward function""" + is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads) + + if not is_batched: + query = query.expand_dims(1) + key = key.expand_dims(1) + value = value.expand_dims(1) + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.expand_dims(0) + + tgt_len, bsz, embed_dim = query.shape + src_len, _, _ = key.shape + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != mstype.bool_ and not ops.is_floating_point(key_padding_mask): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported") + assert embed_dim == embed_dim_to_check, \ + f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" + if use_separate_proj_weight: + # allow MHA to have different embedding dimensions when separate projection weights are used + assert key.shape[:2] == value.shape[:2], \ + f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" + else: + assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" + + # compute in-projection + if not use_separate_proj_weight: + assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None" + q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias, k_is_v, q_is_k) + else: + assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" + assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" + assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" + if in_proj_bias is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = in_proj_bias.tensor_split(3) + q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) + + # prep attention mask + if attn_mask is not None: + if attn_mask.dtype == mstype.uint8: + attn_mask = attn_mask.astype(mstype.bool_) + else: + assert ops.is_floating_point(attn_mask) or attn_mask.dtype == mstype.bool_, \ + f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" + # ensure attn_mask's dim is 3 + if attn_mask.ndim == 2: + correct_2d_size = (tgt_len, src_len) + if attn_mask.shape != correct_2d_size: + raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, " + "but should be {correct_2d_size}.") + attn_mask = attn_mask.expand_dims(0) + elif attn_mask.ndim == 3: + correct_3d_size = (bsz * num_heads, tgt_len, src_len) + if attn_mask.shape != correct_3d_size: + raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, " + "but should be {correct_3d_size}.") + else: + raise RuntimeError(f"attn_mask's dimension {attn_mask.ndim} is not supported") + + # add bias along batch dimension + if bias_k is not None and bias_v is not None: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + k = ops.cat([k, bias_k.repeat(1, bsz, 1)]) + v = ops.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = ops.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = ops.pad(key_padding_mask, (0, 1)) + else: + assert bias_k is None + assert bias_v is None + + # reshape q, k, v for multihead attention and make em batch first + q = q.view(tgt_len, bsz * num_heads, head_dim).swapaxes(0, 1) + if static_k is None: + k = k.view(k.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_k.shape[0] == bsz * num_heads, \ + f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" + assert static_k.shape[2] == head_dim, \ + f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" + k = static_k + if static_v is None: + v = v.view(v.shape[0], bsz * num_heads, head_dim).swapaxes(0, 1) + else: + # TODO finish disentangling control flow so we don't do in-projections when statics are passed + assert static_v.shape[0] == bsz * num_heads, \ + f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" + assert static_v.shape[2] == head_dim, \ + f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" + v = static_v + + # add zero attention along batch dimension (now first) + if add_zero_attn: + zero_attn_shape = (bsz * num_heads, 1, head_dim) + k = ops.cat([k, ops.zeros(zero_attn_shape, dtype=k.dtype)], axis=1) + v = ops.cat([v, ops.zeros(zero_attn_shape, dtype=v.dtype)], axis=1) + if attn_mask is not None: + attn_mask = ops.pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = ops.pad(key_padding_mask, (0, 1)) + + # update source sequence length after adjustments + src_len = k.shape[1] + + # merge key padding and attention masks + if key_padding_mask is not None: + assert key_padding_mask.shape == (bsz, src_len), \ + f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \ + expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) + if attn_mask is None: + attn_mask = key_padding_mask + elif attn_mask.dtype == mstype.bool_: + attn_mask = attn_mask.logical_or(key_padding_mask) + else: + attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) + + # convert mask to float + if attn_mask is not None and attn_mask.dtype == mstype.bool_: + new_attn_mask = ops.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + if attn_mask is not None: + if attn_mask.shape[0] == 1: + attn_mask = attn_mask.expand_dims(0) + else: + attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + + q = q.view(bsz, num_heads, tgt_len, head_dim) + k = k.view(bsz, num_heads, src_len, head_dim) + v = v.view(bsz, num_heads, src_len, head_dim) + + attn_output, attn_output_weights = _scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal, training) + attn_output = attn_output.transpose(2, 0, 1, 3).view(bsz * tgt_len, embed_dim) + + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = attn_output.view(tgt_len, bsz, attn_output.shape[1]) + + # optionally average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + if average_attn_weights: + attn_output_weights = attn_output_weights.sum(axis=1) / num_heads + + if not is_batched: + # squeeze the output if input was unbatched + attn_output = attn_output.squeeze(1) + attn_output_weights = attn_output_weights.squeeze(0) + return attn_output, attn_output_weights + + __all__ = [ 'adaptive_avg_pool1d', 'adaptive_avg_pool2d', diff --git a/tests/st/nn/test_transformer.py b/tests/st/nn/test_transformer.py new file mode 100644 index 00000000000..c010a5d9f18 --- /dev/null +++ b/tests/st/nn/test_transformer.py @@ -0,0 +1,168 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore as ms +from mindspore import Tensor, ops +from mindspore.nn import MultiheadAttention, TransformerEncoderLayer, \ + TransformerEncoder, TransformerDecoderLayer, TransformerDecoder, Transformer + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('dtype', [ms.float16, ms.float32]) +@pytest.mark.parametrize('jit', [False, True]) +def test_multihead_attention_pynative(dtype, jit): + """ + Feature: MultiheadAttention + Description: Verify the result of AMultiheadAttentionvgPool3d + Expectation: success + """ + embed_dim = 128 + num_heads = 8 + sl = 10 + bs = 8 + model = MultiheadAttention(embed_dim, num_heads).to_float(dtype) + q = Tensor(np.random.randn(sl, bs, embed_dim), dtype) + k = Tensor(np.random.randn(sl, bs, embed_dim), dtype) + v = Tensor(np.random.randn(sl, bs, embed_dim), dtype) + + def forward(q, k, v): + out = model(q, k, v) + return out + + if jit: + forward = ms.jit(forward) + + out = forward(q, k, v) + assert q.shape == out[0].shape + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('training', [True, False]) +@pytest.mark.parametrize('jit', [False, True]) +def test_transformerencoder_square_input(training, jit): + """ + Feature: TransformerEncoder + Description: Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has + batch size == sequence length + Expectation: success + """ + model = TransformerEncoder( + TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True), + num_layers=2) + + # set constant weights of the model + for _, p in model.parameters_and_names(): + x = p.data + sz = x.view(-1).shape[0] + shape = x.shape + x = ops.cos(ops.arange(0, sz).astype(ms.float32).view(shape)) + p.set_data(x) + + if training: + model = model.set_train() + else: + model = model.set_train(False) + x = ops.arange(0, 16).reshape(2, 2, 4).astype(ms.float32) + src_mask = Tensor([[0, 1], [0, 0]]).to(ms.bool_) + + def forward(x, mask): + result = model(x, mask=mask) + return result + + if jit: + forward = ms.jit(forward) + + result = forward(x, src_mask) + ref_output = ms.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351], + [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]], + [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689], + [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]], + ms.float32) + assert tuple(result.shape) == tuple(ref_output.shape) + np.allclose(result.asnumpy(), ref_output.asnumpy(), rtol=1e-7, atol=1e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('jit', [False, True]) +def test_transformerdecoder(jit): + """ + Feature: TransformerDecoder + Description: Test shape (batch size, sequence length, embedding dimension) + Expectation: success + """ + decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8) + transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6) + memory = Tensor(np.random.rand(10, 32, 512), ms.float32) + tgt = Tensor(np.random.rand(20, 32, 512), ms.float32) + + def forward(tgt, memory): + out = transformer_decoder(tgt, memory) + return out + + if jit: + forward = ms.jit(forward) + + result = forward(tgt, memory) + assert result.shape == tgt.shape + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.env_onecard +@pytest.mark.parametrize('jit', [False, True]) +def test_transformer(jit): + """ + Feature: Transformer + Description: Test shape (batch size, sequence length, embedding dimension) + Expectation: success + """ + transformer_model = Transformer(nhead=16, num_encoder_layers=12) + src = Tensor(np.random.rand(10, 32, 512), ms.float32) + tgt = Tensor(np.random.rand(20, 32, 512), ms.float32) + + def forward(src, tgt): + out = transformer_model(src, tgt) + return out + + if jit: + forward = ms.jit(forward) + + result = forward(src, tgt) + assert result.shape == tgt.shape