Fix spell error

This commit is contained in:
huangxinjing 2021-09-17 20:47:31 +08:00
parent a17915b2c4
commit b787c5c8c8
4 changed files with 139 additions and 129 deletions

View File

@ -376,14 +376,14 @@ class FixedSparseAttention(nn.Cell):
only supports 64, 128 for now
Inputs:
- **q** - Tensor uery (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
- **q** (Tensor) - Tensor query (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
queries to query the context.
- **k** - Tensor key (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
- **k** (Tensor) - Tensor key (:class:`mstype.fp16` [batch_size, seq_length, hidden_size]): Sequence of
queries to query the context.
- **v** - Tensor value (:class:`mstype.fp16` [batch size, sequence length, Embedding Size]): Sequence of
queries to query the context.
- **attention_mask** - Tensor the mask of (:class:`mstype.fp32` [batch_size, seq_length, seq_length]):
Lower triangular matrix to pass masked information.
- **v** (Tensor) - Tensor value (:class:`mstype.fp16` [batch size, sequence length, Embedding Size]):
Sequence of queries to query the context.
- **attention_mask** (Tensor) - Float Tensor, the mask of (:class:`mstype.fp32`, :class:`mstype.fp16`
[batch_size, seq_length, seq_length]): Lower triangular matrix to pass masked information.
Outputs:
A Tensor. The output of the attention with shape [batch_size, seq_length, hidden_size]
@ -396,10 +396,10 @@ class FixedSparseAttention(nn.Cell):
... num_heads=8,
... size_per_head=64,
... block_size=64)
>>> q = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
>>> k = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
>>> v = Tensor(np.ones((2, 1024, 8*64)), dtype.float16)
>>> attention_mask = Tensor(np.ones((2, 1024, 1024)), dtype.float16)
>>> q = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> k = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> v = Tensor(np.ones((2, 1024, 8*64)), mstype.float16)
>>> attention_mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
>>> output = model(q, k, v, attention_mask)
>>> print(output.shape)
(2, 1024, 512)
@ -550,7 +550,7 @@ class FixedSparseAttention(nn.Cell):
_check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.seq_length, self.seq_length])
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32], self.cls_name)
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
q, k, v = self._transpose_inputs(q, k, v)
local_mask, global_mask = self._generate_attention_mask(attention_mask)

View File

@ -34,7 +34,7 @@ class CrossEntropyLoss(Cell):
Args:
parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
a instance of `OpParallelConfig` with default args.
an instance of `OpParallelConfig` with default args.
Inputs:
- **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. the output logits of
@ -48,8 +48,9 @@ class CrossEntropyLoss(Cell):
Outputs:
Tensor. the corresponding cross entropy loss
Exapmes:
>>> loss = mindspore.parallel.nn.CrossEntropyLoss()
Examples:
>>> from mindspore.parallel.nn import CrossEntropyLoss
>>> loss = CrossEntropyLoss()
>>> logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), mindspore.float32)
>>> labels_np = np.array([1]).astype(np.int32)
>>> input_mask = Tensor(np.ones(1).astype(np.float32))
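For reference, a self-contained version of the example above with the imports it relies on (a sketch only; the printed loss value is not part of this diff):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.parallel.nn import CrossEntropyLoss

loss = CrossEntropyLoss()
logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), mindspore.float32)
labels = Tensor(np.array([1]).astype(np.int32))
input_mask = Tensor(np.ones(1).astype(np.float32))
output = loss(logits, labels, input_mask)  # scalar cross entropy loss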
@ -88,9 +89,6 @@ class CrossEntropyLoss(Cell):
self.div2 = P.RealDiv()
def construct(self, logits, label, input_mask):
r"""
Compute loss using logits, label and input mask
"""
self._check_input(logits, label, input_mask)
# the shape is [bs*seq_length, vocab_size]

View File

@ -41,7 +41,7 @@ class _Config:
class OpParallelConfig(_Config):
r"""
OpParallelConfig for the setting the data parallel and model parallel.
OpParallelConfig for setting the data parallel and model parallel.
Args:
data_parallel (int): The data parallel way. Default: 1
@ -81,7 +81,7 @@ class OpParallelConfig(_Config):
class _PipeLineConfig(_Config):
r"""
PPConfig for the setting the data parallel, model parallel
PPConfig for setting the data parallel and model parallel
Args:
pipeline_stage (int): The number of the pipeline stages. Default: 1
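A minimal construction sketch for the OpParallelConfig described above, assuming it is exported from `mindspore.parallel.nn` like the other classes in this changeset:

from mindspore.parallel.nn import OpParallelConfig

# Both ways default to 1; this splits the data over 4 devices and the weights over 2.
dp_mp = OpParallelConfig(data_parallel=4, model_parallel=2)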

View File

@ -53,7 +53,7 @@ __all__ = [
class EmbeddingOpParallelConfig(_Config):
r"""
EmbeddingOpParallelConfig for the setting the data parallel or row slice for the embedding table.
EmbeddingOpParallelConfig for setting the data parallel or row slice of the embedding table.
Args:
data_parallel (int): The data parallel way. Default: 1
@ -100,7 +100,7 @@ class EmbeddingOpParallelConfig(_Config):
@property
def dp_mp_config(self):
r"""
To obtain the DPMPlConfig for the setting the data parallel, model parallel
To obtain the DPMPlConfig for setting the data parallel and model parallel
Supported Platforms:
``Ascend`` ``GPU``
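A short sketch of the dp_mp_config property documented above (the import path is assumed to match the rest of this changeset, and `model_parallel` is assumed to be the keyword for the row-slice way):

from mindspore.parallel.nn import EmbeddingOpParallelConfig

embedding_config = EmbeddingOpParallelConfig(data_parallel=2, model_parallel=2)
# dp_mp_config exposes only the data/model parallel part for layers that need just that.
dp_mp = embedding_config.dp_mp_config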
@ -114,21 +114,21 @@ class EmbeddingOpParallelConfig(_Config):
class TransformerOpParallelConfig(_Config):
r"""
TransformerOpParallelConfig for the setting the global data parallel, model parallel and fusion group.
TransformerOpParallelConfig for setting the global data parallel, model parallel and fusion group.
The parallel configure setting.
Note:
Except the recompute argument, other arguments will not be effective when the user doesn't set
auto_parallel_context to `SEMI_AUTO_PARALLEL` or `AUTO_PARALLEL`.
The micro_batch_num must be greater then or equal to pipeline_stage. The data_parallel\*model_parallel
\*pipeline_stage must be equal to the device. When setting the pipeline stage and
The micro_batch_num must be greater than or equal to pipeline_stage. The data_parallel\*model_parallel
\*pipeline_stage must be less than or equal to the number of devices. When setting the pipeline stage and
optimizer_shard, the config will overwrite the auto_parallel_context.
Args:
data_parallel (int): The data parallel way. Default: 1.
model_parallel (int): The model parallel way. Default: 1.
pipeline_stage (int): The number of the pipeline stage. Should be a positive value. Default: 1.
micro_batch_num (int): The micore size of the batches for the pipeline training. Default: 1.
micro_batch_num (int): The number of micro batches for the pipeline training. Default: 1.
optimizer_shard (bool): Whether to enable optimizer shard. Default False.
gradient_aggregation_group (int): The fusion group size of the optimizer state sharding. Default: 4.
recompute (bool): Enable recomputation of the transformer block or not. Default: False.
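A construction sketch consistent with the constraints in the note above; the 8-device split is an assumed example and the import path follows the rest of this changeset:

from mindspore.parallel.nn import TransformerOpParallelConfig

# data_parallel * model_parallel * pipeline_stage = 2 * 2 * 2 <= 8 devices,
# and micro_batch_num (4) >= pipeline_stage (2).
parallel_config = TransformerOpParallelConfig(data_parallel=2,
                                              model_parallel=2,
                                              pipeline_stage=2,
                                              micro_batch_num=4,
                                              optimizer_shard=False,
                                              recompute=True)

Remember that, except for recompute, these settings only take effect under `SEMI_AUTO_PARALLEL` or `AUTO_PARALLEL`, as noted above.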
@ -221,7 +221,7 @@ class TransformerOpParallelConfig(_Config):
@property
def embedding_dp_mp_config(self):
r"""
To obtain the EmbeddingParallelConfig for the setting the data parallel, model parallel amd embedding
To obtain the EmbeddingParallelConfig for setting the data parallel, model parallel and embedding
parallel.
Supported Platforms:
@ -236,7 +236,7 @@ class TransformerOpParallelConfig(_Config):
@property
def dp_mp_config(self):
r"""
To obtain the EmbeddingParallelConfig for the setting the data parallel, model parallel amd embedding
To obtain the EmbeddingParallelConfig for setting the data parallel, model parallel and embedding
parallel.
Supported Platforms:
@ -274,10 +274,11 @@ class FeedForward(Cell):
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
and the first dimension in BatchMatMul indicates expert_num. Default: 1.
param_init_type (dtype.Number): The parameter initialization type. Can be dtype.float32 or dtype.float16.
param_init_type (dtype.Number): The parameter initialization type. Should be dtype.float32 or dtype.float16.
Default: dtype.float32.
parallel_config(OpParallelConfig): The config of parallel setting, see `OpParallelConfig`.
Default `default_dpmp_config`, a instance of `OpParallelConfig` with default
args.
Default `default_dpmp_config`, an instance of `OpParallelConfig` with
default args.
Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
@ -296,7 +297,7 @@ class FeedForward(Cell):
Examples:
>>> model = FeedForward(hidden_size=15, ffn_hidden_size=30, dropout_rate=0.1)
>>> tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
@ -383,19 +384,19 @@ class AttentionMask(Cell):
with 1 and 0. 1 indicates the current position is a valid token, otherwise not.
Args:
seq_length(int): the sequence length of the input tensor.
parallel_config(OpParallelConfig): the parallel configure. Default `default_dpmp_config`,
a instance of `OpParallelConfig` with default args.
seq_length(int): The sequence length of the input tensor.
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
an instance of `OpParallelConfig` with default args.
Inputs:
- **input_mask** (Tensor) - the mask indicating whether each position is a valid input with
- **input_mask** (Tensor) - The mask indicating whether each position is a valid input with
(batch_size, seq_length).
Outputs:
Tensor. the attention mask matrix with shape (batch_size, seq_length, seq_length).
Tensor. The attention mask matrix with shape (batch_size, seq_length, seq_length).
Raises:
TypeError: `seq_length` is not a int.
TypeError: `seq_length` is not an integer.
ValueError: `seq_length` is not a positive value.
TypeError: `parallel_config` is not a subclass of OpParallelConfig.
@ -403,15 +404,16 @@ class AttentionMask(Cell):
``Ascend`` ``GPU``
Examples:
>>> mask = mindspore.parallel.nn.AttentionMask(seq_length=4)
>>> from mindspore.parallel.nn import AttentionMask
>>> mask = AttentionMask(seq_length=4)
>>> mask_array = np.array([[1, 1, 1, 0]], np.float32)
>>> inputs = Tensor(mask_array)
>>> res = mask(inputs)
>>> print(res)
Tensor(shape=[1, 4, 4], dtype=Float32,value=[[[1, 0, 0, 0],
[[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[0, 0, 0, 0]]])
[0, 0, 0, 0]]]
"""
@_args_type_validator_check(seq_length=Validator.check_positive_int,
@ -452,8 +454,8 @@ class VocabEmbedding(Cell):
"""
The embedding lookup table from the 0-th dim of the parameter table. When the parallel_config.vocab_emb_dp is
True and in the `AUTO_PARALLEL_MODE`, the embedding lookup will be a `parallel_config.data_parallel`
data parallel way, or will shard the parameter at the 0-th dimension in `parallel_config.model_parallel`, so called
row slice of the embedding table
data parallel way, or will shard the parameter at the 0-th dimension in `parallel_config.model_parallel`, so-called
row slice of the embedding table.
Args:
vocab_size (int): Size of the dictionary of embeddings.
@ -461,11 +463,11 @@ class VocabEmbedding(Cell):
param_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
Refer to class `initializer` for the values of string when a string
is specified. Default: 'normal'.
parallel_config(EmbeddingOpParallelConfig): the parallel config of network. Default
`default_embedding_parallel_config`, a instance of `EmbeddingOpParallelConfig` with default args.
parallel_config(EmbeddingOpParallelConfig): The parallel config of network. Default
`default_embedding_parallel_config`, an instance of `EmbeddingOpParallelConfig` with default args.
Inputs:
**input_ids** (Tensor) - the tokenized inputs with datatype int32 with shape (batch_size, seq_length)
**input_ids** (Tensor) - The tokenized inputs with datatype int32 with shape (batch_size, seq_length)
Outputs:
Tuple, a tuple contains (`output`, `embedding_table`)
@ -486,7 +488,7 @@ class VocabEmbedding(Cell):
Examples:
>>> model = VocabEmbedding(vocab_size=30, embedding_size=30)
>>> tensor = Tensor(np.ones((20, 15)), dtype.int32)
>>> tensor = Tensor(np.ones((20, 15)), mstype.int32)
>>> output, table = model(tensor)
>>> print(output.shape)
(20, 15, 30)
@ -526,7 +528,7 @@ class MultiHeadAttention(Cell):
r"""
This is an implementation of multihead attention in the paper `Attention is all you need
<https://arxiv.org/pdf/1706.03762v5.pdf>`_. Given the query vector with source length, and the
key and value vector with target length, the attention will be performered as the following
key and value vector with target length, the attention will be performed as the following
.. math::
MultiHeadAttention(query, key, vector) = Concat(head_1, \dots, head_h)W^O
@ -543,13 +545,15 @@ class MultiHeadAttention(Cell):
num_heads(int): The number of the heads.
hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1
attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1
compute_dtype(dtype.Number): The computation type. Default dtype.float16. The computation of the
softmax will be converted to the float32.
compute_dtype(dtype.Number): The computation type of the dense layers. Default dtype.float16.
Should be dtype.float32 or dtype.float16.
param_init_type(dtype.Number): The parameter initialization type of the module. Default dtype.float32.
Can be dtype.float32 or dtype.float16.
Should be dtype.float32 or dtype.float16.
softmax_compute_type(dtype.Number): The computation type of the softmax module. Default dtype.float32.
Should be dtype.float32 or dtype.float16.
use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
a instance of `OpParallelConfig` with default args.
an instance of `OpParallelConfig` with default args.
Inputs:
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size).
@ -572,7 +576,7 @@ class MultiHeadAttention(Cell):
- **output** (Tensor) - Tensor, the float tensor of the output of the layer with
shape (batch_size, src_seq_length, hidden_size)
- **layer_present** (Tuple) - A tuple of the Tensor the projected key and value vector with
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
((batch_size, num_heads, size_per_head, tgt_seq_length),
(batch_size, num_heads, tgt_seq_length, size_per_head)).
@ -582,9 +586,9 @@ class MultiHeadAttention(Cell):
Examples:
>>> model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
... num_heads=3)
>>> from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
>>> to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float16)
>>> attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
>>> from_tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> to_tensor = Tensor(np.ones((2, 20, 15)), mstype.float16)
>>> attention_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
>>> attn_out, past = model(from_tensor, to_tensor, to_tensor, attention_mask)
>>> print(attn_out.shape)
(2, 20, 15)
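The validators added in the hunk below restrict `compute_dtype`, `softmax_compute_type` and `param_init_type` to float16 or float32. A hedged sketch of passing them explicitly, with the keyword names taken from the Args list above:

from mindspore import dtype as mstype
from mindspore.parallel.nn import MultiHeadAttention

model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
                           num_heads=3, compute_dtype=mstype.float16,
                           softmax_compute_type=mstype.float32, param_init_type=mstype.float32)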
@ -601,6 +605,8 @@ class MultiHeadAttention(Cell):
tgt_seq_length=Validator.check_positive_int,
attention_dropout_rate=Validator.check_non_negative_float,
hidden_dropout_rate=Validator.check_non_negative_float,
compute_dtype=_valid_value_checks([mstype.float32, mstype.float16],
"MultiHeadAttention"),
softmax_compute_type=_valid_value_checks([mstype.float32, mstype.float16],
"MultiHeadAttention"),
param_init_type=_valid_value_checks([mstype.float32, mstype.float16],
@ -915,7 +921,7 @@ class MultiHeadAttention(Cell):
class TransformerEncoderLayer(Cell):
r"""
Transformer Encoder Layer. This is an implementation of the single layer of the transformer
encoder layer including multihead attention and feedward layer.
encoder layer, including multihead attention and feedforward layer.
Args:
batch_size(int): The batch size of the input tensor.
@ -930,22 +936,22 @@ class TransformerEncoderLayer(Cell):
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
layernorm_compute_type(dtype.Number): The computation type of the layernorm.
Can be dtype.float32 or dtype.float16. Default dtype.float16.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
Can be dtype.float32 or dtype.float16. Default mstype.float16.
Should be dtype.float32 or dtype.float16. Default mstype.float32.
param_init_type(dtype.Number): The parameter initialization type of the module.
Can be dtype.float32 or dtype.float16. Default dtype.float32.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
a instance of `OpParallelConfig` with default args.
an instance of `OpParallelConfig` with default args.
Inputs:
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size].
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length].
- **init_reset** (Tensor) - A bool tensor with shape [batch_size,], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. Used
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the past calculation. Used
for incremental prediction when the use_past is True. Default None.
Outputs:
@ -954,7 +960,7 @@ class TransformerEncoderLayer(Cell):
- **output** (Tensor) - The float tensor of the output of the layer with
shape (batch_size, seq_length, hidden_size).
- **layer_present** (Tuple) - A tuple of the Tensor the projected key and value vector with
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
((batch_size, num_heads, size_per_head, seq_length),
(batch_size, num_heads, seq_length, size_per_head)).
@ -964,8 +970,8 @@ class TransformerEncoderLayer(Cell):
Examples:
>>> model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
... num_heads=2)
>>> encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
>>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)
>>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
>>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
>>> output, past = model(encoder_input_value, encoder_input_mask)
>>> print(output.shape)
(2, 16, 8)
@ -1159,7 +1165,7 @@ class TransformerEncoderLayer(Cell):
class TransformerDecoderLayer(Cell):
r"""
Transformer Decoder Layer. This is an implementation of the single layer of the transformer
decoder layer including self-attention, cross attention and feedward layer. When the encoder_output is None,
decoder layer, including self-attention, cross attention and feedforward layer. When the encoder_output is None,
the cross attention will not be effective.
Args:
@ -1176,15 +1182,15 @@ class TransformerDecoderLayer(Cell):
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
layernorm_compute_type(dtype.Number): The computation type of the layernorm.
Can be dtype.float32 or dtype.float16. Default dtype.float16.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
Can be dtype.float32 or dtype.float16. Default mstype.float16.
param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16.
Default dtype.float32.
Should be dtype.float32 or dtype.float16. Default mstype.float32.
param_init_type(dtype.Number): The parameter initialization type of the module.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
a instance of `OpParallelConfig` with default args.
an instance of `OpParallelConfig` with default args.
Inputs:
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size].
@ -1193,18 +1199,18 @@ class TransformerDecoderLayer(Cell):
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size].
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length], where tgt_seq_length is the length of the decoder.
- **init_reset** (Tensor) - A bool tensor with shape [batch_size,], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
- **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. Used
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the past calculation. Used
for incremental prediction when the use_past is True. Default None.
Outputs:
Tuple, a tuple contains (`output`, `layer_present`)
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size]
- **layer_present** (Tensor) - A tuple, where each tuple is the tensor the projected key and value
- **layer_present** (Tensor) - A tuple, where each tuple is the tensor of the projected key and value
vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
(batch_size, num_heads, tgt_seq_length, size_per_head), and the projected key and value vector
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
in cross attention with shape (batch_size, num_heads, size_per_head, src_seq_length),
(batch_size, num_heads, src_seq_length, size_per_head)).
@ -1214,10 +1220,10 @@ class TransformerDecoderLayer(Cell):
Examples:
>>> model = TransformerDecoderLayer(batch_size=2, hidden_size=64, ffn_hidden_size=64, num_heads=2,
... src_seq_length=20, tgt_seq_length=10)
>>> encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
>>> decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
>>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
>>> memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
>>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
>>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
>>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
>>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
>>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
>>> print(output.shape)
(2, 10, 64)
@ -1477,7 +1483,7 @@ def _get_lambda_func(total_layer=None):
Args:
network(Cell) - Represents the transformer block
layer_id(int) - Means the layer index for the current module, counts from zero.
offset(int) - Means the layer_index needs a offset, if there are other modules in the net.
offset(int) - Means the layer_index needs an offset, if there are other modules in the net.
layers(int) - The total layers used for the model.
"""
# override the layers
@ -1522,30 +1528,30 @@ class TransformerEncoder(Cell):
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
layernorm_compute_type(dtype.Number): The computation type of the layernorm.
Can be dtype.float32 or dtype.float16. Default dtype.float16.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
Can be dtype.float32 or dtype.float16. Default mstype.float16.
param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16.
Default dtype.float32.
Should be dtype.float32 or dtype.float16. Default mstype.float32.
param_init_type(dtype.Number): The parameter initialization type of the module.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
use_past(bool): Use the past state to compute, used for incremental prediction. Default False.
lambda_func: A function can specific the fusion index, pipeline stages and recompute attribute. If the user
wants to specific the pipeline stage and gradient aggregation fusion, the user can pass a function
lambda_func: A function that determines the fusion index, pipeline stages and recompute attribute (see the sketch after this argument list). If the user
wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a function
that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)`
represents the transformer block, `layer_id(int)` means the layer index for the current module, counts from
zero, `offset(int)` means the layer_index needs a offset, if there are other modules in the net. The
zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net. The
default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id, to not
overlap with the encoder layer.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`,
a instance of `TransformerOpParallelConfig` with default args.
an instance of `TransformerOpParallelConfig` with default args.
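A minimal sketch of a `lambda_func`, following the contract in the argument list above. It assumes the fields read from `parallel_config` (pipeline_stage, gradient_aggregation_group, recompute) are those of `TransformerOpParallelConfig`, that the block exposes the standard MindSpore Cell facilities `pipeline_stage`, `set_comm_fusion` and `recompute`, and that `layers` divides evenly by the stage count:

def set_parallel_configure_for_layer(network, layer_id, offset, parallel_config, layers):
    # Default placement from the docstring: (layer_id + offset) // (layers / pipeline_stage).
    network.pipeline_stage = (layer_id + offset) // (layers // parallel_config.pipeline_stage)
    # Group gradient aggregation fusion by the configured fusion group size.
    network.set_comm_fusion((layer_id + offset) // parallel_config.gradient_aggregation_group + 1)
    if parallel_config.recompute:
        network.recompute()

Such a function would then be passed as `lambda_func=set_parallel_configure_for_layer` when building the encoder.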
Inputs:
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size]
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
- **init_reset** (Tensor) - A bool tensor with shape [batch_size,], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
- **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. Used
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the past calculation. Used
for incremental prediction when the use_past is True. Default None.
Outputs:
@ -1563,8 +1569,8 @@ class TransformerEncoder(Cell):
Examples:
>>> model = TransformerEncoder(batch_size=2, num_layers=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
... num_heads=2)
>>> encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
>>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)
>>> encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
>>> encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
>>> output, past = model(encoder_input_value, encoder_input_mask)
>>> print(output.shape)
(2, 16, 8)
@ -1692,23 +1698,23 @@ class TransformerDecoder(Cell):
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
layernorm_compute_type(dtype.Number): The computation type of the layernorm.
Can be dtype.float32 or dtype.float16. Default dtype.float16.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
Can be dtype.float32 or dtype.float16. Default mstype.float16.
param_init_type: The parameter initialization type of the module. Can be dtype.float32 or dtype.float16.
Default dtype.float32.
Should be dtype.float32 or dtype.float16. Default mstype.float32.
param_init_type(dtype.Number): The parameter initialization type of the module.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
offset(int): The initial layer index for the `decoder`. Used for setting the fusion id and stage id, to not
overlap with the encoder layer.
lambda_func: A function can specific the fusion index, pipeline stages and recompute attribute. If the user
wants to specific the pipeline stage and gradient aggregation fusion, the user can pass a function
lambda_func: A function that determines the fusion index, pipeline stages and recompute attribute. If the user
wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a function
that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)`
represents the transformer block, `layer_id(int)` means the layer index for the current module, counts from
zero, `offset(int)` means the layer_index needs a offset, if there are other modules in the net. The
zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net. The
default setting for the pipeline is: `(layer_id + offset) // (layers / pipeline_stage)`.
Default: None
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`,
a instance of `TransformerOpParallelConfig` with default args.
an instance of `TransformerOpParallelConfig` with default args.
Inputs:
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size]
@ -1717,18 +1723,18 @@ class TransformerDecoder(Cell):
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length] where tgt_seq_length is the length of the decoder. the output of the encoder with shape
[batch_size, seq_length, hidden_size],
- **init_reset** (Tensor) - A bool tensor with shape [batch_size,], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
- **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index.
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the past calculation.
Used for incremental prediction when the use_past is True. Default None.
Outputs:
Tuple, a tuple contains (`output`, `layer_present`)
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size]
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor the projected
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the projected
key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
(batch_size, num_heads, tgt_seq_length, size_per_head), and the projected key and value vector
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
in cross attention with shape (batch_size, num_heads, size_per_head, src_seq_length),
(batch_size, num_heads, src_seq_length, size_per_head)).
@ -1738,10 +1744,10 @@ class TransformerDecoder(Cell):
Examples:
>>> model = TransformerDecoder(batch_size=2, num_layers=1, hidden_size=64, ffn_hidden_size=64,
... num_heads=2, src_seq_length=20, tgt_seq_length=10)
>>> encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
>>> decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
>>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
>>> memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
>>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
>>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
>>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
>>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
>>> output, past = model(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
>>> print(output.shape)
(2, 10, 64)
@ -1868,7 +1874,7 @@ class Transformer(Cell):
r"""
Transformer module including encoder and decoder. The difference from the original implementation is that the module uses
the residual addition before the layer normalization. And the default hidden act is `gelu`.
The detials can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
The details can be found in `Attention is all you need <https://arxiv.org/pdf/1706.03762v5.pdf>`_.
Note:
This is an experimental interface that is subject to change and/or deletion.
@ -1881,37 +1887,43 @@ class Transformer(Cell):
ffn_hidden_size(int): The hidden size of bottleneck in the feedforward layer.
src_seq_length(int): The seq_length of the encoder's input tensor.
tgt_seq_length(int): The seq_length of the decoder's input tensor.
num_heads(int): The number of the heads.
num_heads(int): The number of the heads. Default: 2.
hidden_dropout_rate(float): The dropout rate of the final output of the layer. Default:0.1
attention_dropout_rate(float): The dropout rate of the attention scores. Default:0.1
post_layernorm_residual(bool): Whether to do the residual addition before the layernorm. Default False.
layernorm_compute_type(dtype.Number): The computation type of the layernorm.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
softmax_compute_type(dtype.Number): The computation type of the softmax in the attention.
Should be dtype.float32 or dtype.float16. Default mstype.float32.
param_init_type(dtype.Number): The parameter initialization type of the module.
Should be dtype.float32 or dtype.float16. Default dtype.float32.
hidden_act(str): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. Default: gelu.
lambda_func: A function can specific the fusion index, pipeline stages and recompute attribute. If the user
wants to specific the pipeline stage and gradient aggregation fusion, the user can pass a function
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert).
lambda_func: A function that determines the fusion index, pipeline stages and recompute attribute. If the user
wants to determine the pipeline stage and gradient aggregation fusion, the user can pass a function
that accepts `network`, `layer_id`, `offset`, `parallel_config`, `layers`. The `network(Cell)`
represents the transformer block, `layer_id(int)` means the layer index for the current module, counts from
zero, `offset(int)` means the layer_index needs a offset, if there are other modules in the net. The
zero, `offset(int)` means the layer_index needs an offset, if there are other modules in the net. The
default setting for the pipeline is: `(layer_id + offset) // ((encoder_layers + decoder_layers)
/ pipeline_stage)`.
parallel_config(TransformerOpParallelConfig): The parallel configure. Default `default_transformer_config`,
a instance of `TransformerOpParallelConfig` with default args.
an instance of `TransformerOpParallelConfig` with default args.
Inputs:
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size].
- **encoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length].
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size],
this can be none if the decoder layer is 0.
- **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, 1,
seq_length, seq_length]
this should be none if the decoder layer is 0.
- **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length]
where tgt_seq_length is the length of the decoder. the output of the encoder with shape [batch_size,
seq_length, hidden_size], this can be none if the decoder layer is 0.
- **init_reset** (Tensor) - A bool tensor with shape [batch_size,], used to clear the past key parameter and
seq_length, hidden_size], this should be none if the decoder layer is 0.
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
- **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. Used
- **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the index of the past calculation. Used
for incremental prediction when the use_past is True. Default None.
Outputs:
@ -1922,10 +1934,10 @@ class Transformer(Cell):
decoder layer. The shape is [batch, tgt_seq_length, hidden_size].
- **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the
projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head).
src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)).
- **decoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor
the projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head), and the
of the projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
tgt_seq_length), (batch_size, num_heads, tgt_seq_length, size_per_head)), and the
projected key and value vector in cross attention with shape
(batch_size, num_heads, size_per_head, src_seq_length),
(batch_size, num_heads, src_seq_length, size_per_head)). If the decoder is not set, the
@ -1937,11 +1949,11 @@ class Transformer(Cell):
Examples:
>>> model = Transformer(batch_size=2, encoder_layers=1, decoder_layers=2, hidden_size=64, ffn_hidden_size=64,
... src_seq_length=20, tgt_seq_length=10)
>>> encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
>>> encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
>>> decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
>>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
>>> memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
>>> encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
>>> encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
>>> decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
>>> decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
>>> memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
>>> output, en_past, de_past = model(encoder_input_value, encoder_input_mask, decoder_input_value,
... decoder_input_mask, memory_mask)
>>> print(output.shape)
@ -2094,7 +2106,7 @@ class Transformer(Cell):
output = encoder_output
if self.decoder is not None:
# decoder mask can be created outside of the model
# decoder mask should be created outside of the model
if self.use_moe is True:
decoder_output, decoder_layer_present, decoder_aux_loss = self.decoder(decoder_inputs, decoder_masks,
encoder_output, memory_mask,