!43407 [AutoParallel] Make the batch error if the inputs' batch sizes are different
Merge pull request !43407 from huangxinjing/fix_shape_infer_error
This commit is contained in:
commit e54b62bd30
@@ -12,7 +12,7 @@
    - **hidden_size** (int) - The dimension of the inputs.
    - **ffn_hidden_size** (int) - The intermediate hidden size.
    - **dropout_rate** (float) - The dropout rate for the output of the second linear layer.
    - **hidden_act** (str, nn.Cell) - The activation of the feedforward layer. Its value can be 'relu', 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid', 'logsigmoid' and so on. A custom activation function can also be passed in. If the user wants to run this network in parallel mode, the custom activation must provide the `activation_shard` class method. Please see the examples of the class `mindspore.nn.transformer.FeedForward`. Default: gelu.
    - **expert_num** (int) - The number of experts used in the linear layer. For the expert_num > 1 case, BatchMatMul is used, and the first dimension of BatchMatMul indicates expert_num. Default: 1.
    - **expert_group_size** (int) - The number of tokens received by each data parallel group. Default: None. This parameter only takes effect in auto parallel mode when strategy propagation is not used.
    - **param_init_type** (dtype.Number) - The parameter initialization type. Its value should be mstype.float32 or mstype.float16. Default: mstype.float32.
@@ -356,7 +356,7 @@ class FeedForward(Cell):
        hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
-           If user want to run the net in the parallel mode, the custom activation must also provide
+           If user wants to run the net in the parallel mode, the custom activation must also provide
            the `activation_shard` function. Please see examples. Default: gelu.
        expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
            and the first dimension in BatchMatMul indicates expert_num. Default: 1.
@@ -409,8 +409,9 @@ class FeedForward(Cell):
        >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
        >>> output = model(tensor)
        >>> print(output.shape)
        (2, 20, 15)
        >>> # Example 3 using custom hidden activation with activation_shard
-       >>> # If user want to run on the SEMI/AUTO parallel mode, the custom activation must provide
+       >>> # If user wants to run on the SEMI/AUTO parallel mode, the custom activation must provide
        >>> # a class function named activation_shard. It accepts the argument parallel_config (OpParallelConfig,
        >>> # MoEParallelConfig) and sets the shard for the primitives used in the construct.
        >>> class MyActivationWithShard(nn.Cell):

@@ -427,6 +428,7 @@ class FeedForward(Cell):
        >>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
        >>> output = model(tensor)
        >>> print(output.shape)
        (2, 20, 15)
        """
    @_LogActionOnce(logger=logger, key='FeedForward',
                    no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
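For reference, a minimal sketch of what such a custom activation with `activation_shard` could look like. This is an illustration written for this note, not the exact class from the docstring example; the GeLU primitive and the (data_parallel, model_parallel) sharding layout are assumptions:

import mindspore.nn as nn
import mindspore.ops as ops

class MyActivationWithShard(nn.Cell):
    """A GeLU wrapper that can be sharded when the net runs in SEMI/AUTO parallel mode."""
    def __init__(self):
        super(MyActivationWithShard, self).__init__()
        self.gelu = ops.GeLU()

    def construct(self, x):
        return self.gelu(x)

    def activation_shard(self, parallel_config):
        # Assumed layout: shard the single GeLU input as (data_parallel, model_parallel),
        # mirroring the strategy used by the surrounding FeedForward layer.
        self.gelu.shard(((parallel_config.data_parallel, parallel_config.model_parallel),))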
@@ -1115,11 +1117,7 @@ class MultiHeadAttention(Cell):
        self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
                           value_past, batch_valid_length)
        ori_shape = F.shape(query_tensor)
-       batch_size = None
-       if len(F.shape(query_tensor)) == 2:
-           batch_size = F.shape(query_tensor)[0] // self.src_seq_length
-       else:
-           batch_size = F.shape(query_tensor)[0]
+       batch_size = self._get_batch_size_from_query(query_tensor)
        query_tensor, key_tensor, value_tensor = self._convert_to_2d_tensor(query_tensor,
                                                                            key_tensor,
                                                                            value_tensor,
@@ -1136,18 +1134,21 @@ class MultiHeadAttention(Cell):
        query = self.transpose(
            F.reshape(
                query,
-               (batch_size, -1, self.n_head, self.size_per_head)),
+               (batch_size, self._get_seq_length_under_incremental(self.src_seq_length),
+                self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        # the returned shape is [bs, size_per_head, seq_length, num_heads]
        key = self.transpose(
            F.reshape(
-               key, (batch_size, -1, self.n_head, self.size_per_head)),
+               key, (batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
+                     self.n_head, self.size_per_head)),
            (0, 2, 3, 1))
        # the returned shape is [bs, num_heads, seq_length, size_per_head]
        value = self.transpose(
            F.reshape(
                value,
-               (batch_size, -1, self.n_head, self.size_per_head)),
+               (batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
+                self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
        # support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
        if attention_mask is not None and len(F.shape(attention_mask)) == 3:
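The reason the reshape change above surfaces mismatched batch sizes: with -1 the sequence dimension is inferred, so a wrong batch size can be silently absorbed into it, while the explicit (incremental-aware) sequence length makes the shape check fail immediately. A minimal numpy sketch with hypothetical shapes:

import numpy as np

n_head, size_per_head, seq_length = 3, 5, 20
query_batch, key_batch = 3, 6            # mismatched batch sizes (hypothetical)

# key after the dense projection, flattened to 2-D: (key_batch * seq_length, hidden)
key_2d = np.ones((key_batch * seq_length, n_head * size_per_head))

# Old behaviour: -1 lets reshape infer a wrong sequence length and continue.
old = key_2d.reshape(query_batch, -1, n_head, size_per_head)
print(old.shape)                         # (3, 40, 3, 5): seq length silently becomes 40

# New behaviour: the expected sequence length is stated explicitly, so the
# mismatch surfaces right away as a reshape error.
try:
    key_2d.reshape(query_batch, seq_length, n_head, size_per_head)
except ValueError as err:
    print(err)                           # cannot reshape array of size 1800 into shape (3,20,3,5)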
@@ -1201,6 +1202,24 @@ class MultiHeadAttention(Cell):
        output = F.cast(output, ori_dtype)
        return output, layer_present

+    def _get_batch_size_from_query(self, query):
+        r"""Get the batch size from query tensor"""
+        batch_size = None
+        # For the incremental prediction, the seq length for the input is 1.
+        if len(F.shape(query)) == 2 and self.is_first_iteration:
+            batch_size = F.shape(query)[0] // self.src_seq_length
+        else:
+            batch_size = F.shape(query)[0]
+        return batch_size
+
+    def _get_seq_length_under_incremental(self, length):
+        r"""Return the length of the tensor.
+            For the incremental prediction, the seq length for the input is 1.
+        """
+        if self.is_first_iteration:
+            return length
+        return 1
+
    def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                      value_past=None, batch_valid_length=None):
        r"""Check inputs"""
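A quick illustration, in plain Python with hypothetical shapes, of what `_get_batch_size_from_query` computes for the flattened 2-D layout versus the 3-D and incremental cases:

src_seq_length = 20

def get_batch_size(query_shape, is_first_iteration):
    # A 2-D query on the first iteration is the flattened (batch * seq_length, hidden) layout.
    if len(query_shape) == 2 and is_first_iteration:
        return query_shape[0] // src_seq_length
    # A 3-D query, or the single-token incremental step, carries the batch in dim 0.
    return query_shape[0]

print(get_batch_size((40, 15), True))     # 40 // 20 -> 2
print(get_batch_size((2, 20, 15), True))  # -> 2
print(get_batch_size((2, 15), False))     # incremental step, one token per sample -> 2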
@@ -1384,7 +1403,7 @@ class TransformerEncoderLayer(Cell):
        hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
-           If user want to run the net in the parallel mode, the custom activation must also provide
+           If user wants to run the net in the parallel mode, the custom activation must also provide
            the `activation_shard` function. Please see the examples of the
            class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
        use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two
@@ -1811,7 +1830,7 @@ class TransformerDecoderLayer(Cell):
        hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
-           If user want to run the net in the parallel mode, the custom activation must also provide
+           If user wants to run the net in the parallel mode, the custom activation must also provide
            the `activation_shard` function. Please see the examples of the
            class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
        moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
@@ -2319,7 +2338,7 @@ class TransformerEncoder(Cell):
        hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
-           If user want to run the net in the parallel mode, the custom activation must also provide
+           If user wants to run the net in the parallel mode, the custom activation must also provide
            the `activation_shard` function. Please see the examples of the
            class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
        post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default: False.
@@ -2582,7 +2601,7 @@ class TransformerDecoder(Cell):
        hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
-           If user want to run the net in the parallel mode, the custom activation must also provide
+           If user wants to run the net in the parallel mode, the custom activation must also provide
            the `activation_shard` function. Please see the examples of the
            class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
        lambda_func(function): A function that can determine the fusion index,
@@ -2827,7 +2846,7 @@ class Transformer(Cell):
        hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
            'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
            'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
-           If user want to run the net in the parallel mode, the custom activation must also provide
+           If user wants to run the net in the parallel mode, the custom activation must also provide
            the `activation_shard` function. Please see the examples of the
            class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
        post_layernorm_residual(bool): Whether to apply the residual addition before the layernorm. Default: False.
@@ -544,3 +544,79 @@ class TestBasicWarningValidator:
        # Force to rebuild the logger
        test_transformer_decoder()
        self.check_warning_log()
+
+
+def test_attention_with_wrong_batch_3d_inputs():
+    """
+    Feature: Test Transformer batch error when the input's batch size is different
+    Description: Test that the batch size differs between the input tensors. The inputs are 3d.
+    Expectation: Raise a reshape error exception
+    """
+    model = MultiHeadAttention(hidden_size=15, src_seq_length=20, tgt_seq_length=20,
+                               batch_size=None, num_heads=3)
+    from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((5, 20, 15)), dtype.float16)
+    attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
+
+    with pytest.raises(ValueError):
+        _cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
+
+
+def test_attention_with_wrong_batch_2d_inputs():
+    """
+    Feature: Test Transformer batch error when the input's batch size is different
+    Description: Test that the batch size differs between the input tensors. The inputs are 2d.
+    Expectation: Raise a reshape error exception
+    """
+    model = MultiHeadAttention(hidden_size=15, src_seq_length=20, tgt_seq_length=20,
+                               batch_size=None, num_heads=3)
+    from_tensor = Tensor(np.ones((60, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((100, 15)), dtype.float16)
+    attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
+
+    with pytest.raises(ValueError):
+        _cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
+
+
+def test_incremental_prediction_first_iterator():
+    """
+    Feature: Test MultiHeadAttention with incremental prediction
+    Description: Test MultiHeadAttention with incremental prediction in the first iteration
+    Expectation: No Expectation
+    """
+    # Step 1: set is_first_iteration=True and feed the state for the full sequence length.
+    # The memory parameters for saving the key and value states must be prepared first.
+    from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float16)
+    attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
+    key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), dtype.float16)
+    value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), dtype.float16)
+    batch_valid_length = Tensor(np.ones((2,)), dtype.int32)
+
+    model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
+                               num_heads=3, use_past=True)
+    model.add_flags_recursive(is_first_iteration=True)
+    model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
+          batch_valid_length)
+
+
+def test_incremental_prediction_second_iterator():
+    """
+    Feature: Test MultiHeadAttention with incremental prediction
+    Description: Test MultiHeadAttention with incremental prediction in the second iteration
+    Expectation: No Expectation
+    """
+    model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
+                               num_heads=3, use_past=True)
+    key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), dtype.float16)
+    value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), dtype.float16)
+    batch_valid_length = Tensor(np.ones((2,)), dtype.int32)
+    # The key/value past tensors above stand in for the full memory states that are
+    # generated with is_first_iteration=True in the first step.
+    from_tensor = Tensor(np.ones((2, 1, 15)), dtype.float32)
+    to_tensor = Tensor(np.ones((2, 1, 15)), dtype.float16)
+    attention_mask = Tensor(np.ones((2, 1, 20)), dtype.float16)
+    # Step 2: set is_first_iteration=False and pass a single token to run the prediction
+    # rather than the full sequence.
+    model.add_flags_recursive(is_first_iteration=False)
+    model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
+          batch_valid_length)