!43407 [AutoParallel] Raise a batch error if the inputs' batch sizes differ

Merge pull request !43407 from huangxinjing/fix_shape_infer_error
i-robot 2022-10-27 03:50:07 +00:00 committed by Gitee
commit e54b62bd30
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 111 additions and 16 deletions

View File

@@ -12,7 +12,7 @@
- **hidden_size** (int) - The dimension of the input.
- **ffn_hidden_size** (int) - The intermediate hidden size.
- **dropout_rate** (float) - The dropout rate for the output of the second linear layer.
- **hidden_act** (str, nn.Cell) - The activation of the feedforward layer. Its value can be 'relu', 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid', 'logsigmoid' and so on. The user can pass in a custom activation function. If the user wants to run this network in parallel mode, the custom activation must provide the `activation_shard` class method. See the examples of the class `mindspore.nn.transformer.FeedForward`. Default: gelu.
- **hidden_act** (str, nn.Cell) - The activation of the feedforward layer. Its value can be 'relu', 'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish', 'hsigmoid', 'logsigmoid' and so on. The user can pass in a custom activation function. If the user wants to run this network in parallel mode, the custom activation must provide the `activation_shard` class method. See the examples of the class `mindspore.nn.transformer.FeedForward`. Default: gelu.
- **expert_num** (int) - The number of experts used in the linear layer. For the case expert_num > 1, BatchMatMul is used, and the first dimension of BatchMatMul indicates expert_num. Default: 1.
- **expert_group_size** (int) - The number of tokens received by each data-parallel group. Default: None. This argument only takes effect in auto parallel mode without strategy propagation.
- **param_init_type** (dtype.Number) - The parameter initialization type. Its value should be mstype.float32 or mstype.float16. Default: mstype.float32.
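The `hidden_act` entry above requires a custom activation to expose an `activation_shard` class method when the network runs in parallel mode. A minimal sketch of such an activation follows (the class name, the ReLU primitive, and the sharding strategy are illustrative assumptions, not part of this change; the authoritative example lives in the `mindspore.nn.transformer.FeedForward` docstring):

import mindspore.nn as nn
import mindspore.ops as ops

class MyShardableActivation(nn.Cell):
    """A custom activation that can be sharded in SEMI/AUTO parallel mode."""
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()

    def construct(self, x):
        return self.relu(x)

    def activation_shard(self, parallel_config):
        # parallel_config is an OpParallelConfig (or MoEParallelConfig);
        # shard the single ReLU input along (data_parallel, model_parallel).
        self.relu.shard(((parallel_config.data_parallel, parallel_config.model_parallel),))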

View File

@@ -356,7 +356,7 @@ class FeedForward(Cell):
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
If user want to run the net in the parallel mode, the custom activation must also provide
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see examples. Default: gelu.
expert_num (int): The number of experts used in Linear. For the case expert_num > 1, BatchMatMul is used
and the first dimension in BatchMatMul indicates expert_num. Default: 1.
@@ -409,8 +409,9 @@
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
>>> # Example 3 using custom hidden activation with activation_shard
>>> # If user wants to run on the SEMI/AUTO parallel mode, the custom activation must provide
>>> # If user wants to run on the SEMI/AUTO parallel mode, the custom activation must provide
>>> # a class function named activation_shard. It accepts the argument parallel_config (OpParallelConfig,
>>> # MoEParallelConfig) and sets the shard for the primitives used in the construct.
>>> class MyActivationWithShard(nn.Cell):
@@ -427,6 +428,7 @@
>>> tensor = Tensor(np.ones((2, 20, 15)), mstype.float32)
>>> output = model(tensor)
>>> print(output.shape)
(2, 20, 15)
"""
@_LogActionOnce(logger=logger, key='FeedForward',
no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
@@ -1115,11 +1117,7 @@ class MultiHeadAttention(Cell):
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
value_past, batch_valid_length)
ori_shape = F.shape(query_tensor)
batch_size = None
if len(F.shape(query_tensor)) == 2:
batch_size = F.shape(query_tensor)[0] // self.src_seq_length
else:
batch_size = F.shape(query_tensor)[0]
batch_size = self._get_batch_size_from_query(query_tensor)
query_tensor, key_tensor, value_tensor = self._convert_to_2d_tensor(query_tensor,
key_tensor,
value_tensor,
@@ -1136,18 +1134,21 @@
query = self.transpose(
F.reshape(
query,
(batch_size, -1, self.n_head, self.size_per_head)),
(batch_size, self._get_seq_length_under_incremental(self.src_seq_length),
self.n_head, self.size_per_head)),
(0, 2, 1, 3))
# the returned shape is [bs, num_heads, size_per_head, seq_length]
key = self.transpose(
F.reshape(
key, (batch_size, -1, self.n_head, self.size_per_head)),
key, (batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
self.n_head, self.size_per_head)),
(0, 2, 3, 1))
# the returned shape is [bs, num_heads, seq_length, size_per_head]
value = self.transpose(
F.reshape(
value,
(batch_size, -1, self.n_head, self.size_per_head)),
(batch_size, self._get_seq_length_under_incremental(self.tgt_seq_length),
self.n_head, self.size_per_head)),
(0, 2, 1, 3))
# supported input shapes are [bs, seq, seq] or [bs, heads, seq, seq]
if attention_mask is not None and len(F.shape(attention_mask)) == 3:
@@ -1201,6 +1202,24 @@
output = F.cast(output, ori_dtype)
return output, layer_present
def _get_batch_size_from_query(self, query):
r"""Get the batch size from query tensor"""
batch_size = None
# For the incremental prediction, the seq length for the input is 1.
if len(F.shape(query)) == 2 and self.is_first_iteration:
batch_size = F.shape(query)[0] // self.src_seq_length
else:
batch_size = F.shape(query)[0]
return batch_size
def _get_seq_length_under_incremental(self, length):
r"""Return the length of the tensor.
For the incremental prediction, the seq length for the input is 1.
"""
if self.is_first_iteration:
return length
return 1
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
value_past=None, batch_valid_length=None):
r"""Check inputs"""
@@ -1384,7 +1403,7 @@ class TransformerEncoderLayer(Cell):
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
If user want to run the net in the parallel mode, the custom activation must also provide
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see the examples of the
class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two
@@ -1811,7 +1830,7 @@ class TransformerDecoderLayer(Cell):
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
If user want to run the net in the parallel mode, the custom activation must also provide
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see the examples of the
class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
moe_config(MoEConfig): The configuration of MoE (Mixture of Expert). Default is an instance of MoEConfig
@@ -2319,7 +2338,7 @@ class TransformerEncoder(Cell):
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
If user want to run the net in the parallel mode, the custom activation must also provide
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see the examples of the
class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
post_layernorm_residual(bool): Do the residual add before the layernorm. Default: False.
@@ -2582,7 +2601,7 @@ class TransformerDecoder(Cell):
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
If user want to run the net in the parallel mode, the custom activation must also provide
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see the examples of the
class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
lambda_func(function): A function that determines the fusion index,
@@ -2827,7 +2846,7 @@ class Transformer(Cell):
hidden_act (str, nn.Cell): The activation of the internal feedforward layer. Supports 'relu',
'relu6', 'tanh', 'gelu', 'fast_gelu', 'elu', 'sigmoid', 'prelu', 'leakyrelu', 'hswish',
'hsigmoid', 'logsigmoid' and so on. User can provide a custom activation to the argument.
If user want to run the net in the parallel mode, the custom activation must also provide
If user wants to run the net in the parallel mode, the custom activation must also provide
the `activation_shard` function. Please see the examples of the
class:`mindspore.nn.transformer.FeedForward`. Default: gelu.
post_layernorm_residual(bool): Do the residual add before the layernorm. Default: False.
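The substantive change in this file replaces the `-1` placeholder in the query/key/value reshapes with the explicit source or target sequence length (via the new `_get_seq_length_under_incremental` helper). With `-1`, a batch-size mismatch between the query and the key/value tensors could be silently absorbed into the inferred dimension; with the explicit length the element counts no longer agree, so the reshape fails instead of producing a wrong shape. A NumPy sketch of the effect (the shapes are illustrative and chosen so that the old behaviour happens to succeed):

import numpy as np

n_head, size_per_head, seq_length, hidden = 3, 5, 20, 15
query = np.ones((3 * seq_length, hidden))   # batch size 3, already flattened to 2D
key = np.ones((6 * seq_length, hidden))     # batch size 6: deliberately mismatched

batch_size = query.shape[0] // seq_length   # 3, derived from the query only

# Old behaviour: -1 absorbs the mismatch and yields a wrong seq length of 40.
key_old = key.reshape(batch_size, -1, n_head, size_per_head)
print(key_old.shape)                        # (3, 40, 3, 5) -- silently wrong

# New behaviour: the explicit seq length makes the element counts disagree,
# so the reshape raises instead of returning a wrongly shaped tensor.
try:
    key.reshape(batch_size, seq_length, n_head, size_per_head)
except ValueError as err:
    print("reshape error:", err)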

View File

@@ -544,3 +544,79 @@ class TestBasicWarningValidator:
# Force to rebuild the logger
test_transformer_decoder()
self.check_warning_log()
def test_attention_with_wrong_batch_3d_inputs():
"""
Feature: Test the Transformer batch error when the inputs' batch sizes differ
Description: Test that the batch sizes differ between the input tensors. The inputs are 3D.
Expectation: Raise a reshape error exception
"""
model = MultiHeadAttention(hidden_size=15, src_seq_length=20, tgt_seq_length=20,
batch_size=None, num_heads=3)
from_tensor = Tensor(np.ones((3, 20, 15)), dtype.float32)
to_tensor = Tensor(np.ones((5, 20, 15)), dtype.float16)
attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
with pytest.raises(ValueError):
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
def test_attention_with_wrong_batch_2d_inputs():
"""
Feature: Test the Transformer batch error when the inputs' batch sizes differ
Description: Test that the batch sizes differ between the input tensors. The inputs are 2D.
Expectation: Raise a reshape error exception
"""
model = MultiHeadAttention(hidden_size=15, src_seq_length=20, tgt_seq_length=20,
batch_size=None, num_heads=3)
from_tensor = Tensor(np.ones((60, 15)), dtype.float32)
to_tensor = Tensor(np.ones((100, 15)), dtype.float16)
attention_mask = Tensor(np.ones((3, 20, 20)), dtype.float16)
with pytest.raises(ValueError):
_cell_graph_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
def test_incremental_prediction_first_iterator():
"""
Feature: Test MultiHeadAttention with incremental prediction
Description: Test MultiHeadAttention with incremental prediction in the first iteration
Expectation: No exception is raised
"""
# Step 1: set is_first_iteration=True and input the full-sequence-length state.
# The memory parameters for saving the key and value states need to be prepared first.
from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float16)
attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), dtype.float16)
value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), dtype.float16)
batch_valid_length = Tensor(np.ones((2,)), dtype.int32)
model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
num_heads=3, use_past=True)
model.add_flags_recursive(is_first_iteration=True)
model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
batch_valid_length)
def test_incremental_prediction_second_iterator():
"""
Feature: Test MultiHeadAttention with incremental prediction
Description: Test MultiHeadAttention with incremental prediction in the second iteration
Expectation: No exception is raised
"""
model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20, tgt_seq_length=20,
num_heads=3, use_past=True)
key_past = Tensor(np.zeros(shape=(2, 3, 5, 20)), dtype.float16)
value_past = Tensor(np.zeros(shape=(2, 3, 20, 5)), dtype.float16)
batch_valid_length = Tensor(np.ones((2,)), dtype.int32)
# For the second iteration, the inputs contain only a single token (seq length is 1).
from_tensor = Tensor(np.ones((2, 1, 15)), dtype.float32)
to_tensor = Tensor(np.ones((2, 1, 15)), dtype.float16)
attention_mask = Tensor(np.ones((2, 1, 20)), dtype.float16)
# Step 2: set is_first_iteration=False and pass a single token to run the prediction rather than the
# full sequence.
model.add_flags_recursive(is_first_iteration=False)
model(from_tensor, to_tensor, to_tensor, attention_mask, key_past, value_past,
batch_valid_length)
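The two incremental-prediction tests above exercise the same two-phase protocol. A combined sketch of the flow on a single model instance, reusing the toy shapes from the tests (the import paths are an assumption, since the test file's import block is not shown in this diff):

import numpy as np
from mindspore import Tensor
from mindspore.common import dtype
from mindspore.nn.transformer import MultiHeadAttention

model = MultiHeadAttention(batch_size=2, hidden_size=15, src_seq_length=20,
                           tgt_seq_length=20, num_heads=3, use_past=True)
key_past = Tensor(np.zeros((2, 3, 5, 20)), dtype.float16)
value_past = Tensor(np.zeros((2, 3, 20, 5)), dtype.float16)
batch_valid_length = Tensor(np.ones((2,)), dtype.int32)

# Phase 1: is_first_iteration=True, feed the full sequence to fill the key/value caches.
model.add_flags_recursive(is_first_iteration=True)
full_query = Tensor(np.ones((2, 20, 15)), dtype.float32)
full_kv = Tensor(np.ones((2, 20, 15)), dtype.float16)
full_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
model(full_query, full_kv, full_kv, full_mask, key_past, value_past, batch_valid_length)

# Phase 2: is_first_iteration=False, feed a single token per step.
model.add_flags_recursive(is_first_iteration=False)
step_query = Tensor(np.ones((2, 1, 15)), dtype.float32)
step_kv = Tensor(np.ones((2, 1, 15)), dtype.float16)
step_mask = Tensor(np.ones((2, 1, 20)), dtype.float16)
model(step_query, step_kv, step_kv, step_mask, key_past, value_past, batch_valid_length)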