!49909 fix Transformer docs

Merge pull request !49909 from 吕昱峰(Nate.River)/code_docs_master
This commit is contained in:
i-robot 2023-03-07 09:14:09 +00:00 committed by Gitee
commit e2ac12b1ad
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 20 additions and 1 deletion

View File

@@ -123,9 +123,15 @@ class MultiheadAttention(Cell):
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> embed_dim, num_heads = 128, 8
>>> seq_length, batch_size = 10, 8
>>> query = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
>>> key = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
>>> value = Tensor(np.random.randn(seq_length, batch_size, embed_dim), mindspore.float32)
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> print(attn_output.shape)
(10, 8, 128)
"""
def __init__(self, embed_dim, num_heads, dropout=0., has_bias=True, add_bias_kv=False,
@@ -267,6 +273,8 @@ class TransformerEncoderLayer(Cell):
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> src = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
>>> out = encoder_layer(src)
>>> print(out.shape)
(32, 10, 512)
"""
__constants__ = ['batch_first', 'norm_first']
@@ -378,6 +386,8 @@ class TransformerDecoderLayer(Cell):
>>> memory = Tensor(np.random.rand(32, 10, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(32, 20, 512), mindspore.float32)
>>> out = decoder_layer(tgt, memory)
>>> print(out.shape)
(32, 20, 512)
"""
__constants__ = ['batch_first', 'norm_first']
@@ -479,6 +489,8 @@ class TransformerEncoder(Cell):
>>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
>>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> out = transformer_encoder(src)
>>> print(out.shape)
(10, 32, 512)
"""
__constants__ = ['norm']
@@ -537,6 +549,8 @@ class TransformerDecoder(Cell):
>>> memory = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
>>> out = transformer_decoder(tgt, memory)
>>> print(out.shape)
(20, 32, 512)
"""
__constants__ = ['norm']
@@ -604,11 +618,16 @@ class Transformer(Cell):
Outputs:
Tensor.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
>>> src = Tensor(np.random.rand(10, 32, 512), mindspore.float32)
>>> tgt = Tensor(np.random.rand(20, 32, 512), mindspore.float32)
>>> out = transformer_model(src, tgt)
>>> print(out.shape)
(20, 32, 512)
"""
def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6,