!23746 Fix Past None Type Check

Merge pull request !23746 from huangxinjing/fix_past_none_check
i-robot 2021-09-26 06:12:47 +00:00 committed by Gitee
commit 889fd81d77
2 changed files with 113 additions and 48 deletions


@@ -149,12 +149,17 @@ class _LayerInputCheck:
@constexpr
def _check_past_none_input_none(use_past, param_name, func_name, input_tensor, default_value=None):
def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
""" If the past is True, check whether the inputs is None"""
if not use_past and input_tensor is not default_value:
raise ValueError(f"{func_name} {param_name} should be {default_value}, if use_past is False.")
if use_past and input_tensor is default_value:
raise ValueError(f"{func_name} {param_name} should not be {default_value}, if use_past is True.")
if not use_past:
if is_tensor:
raise TypeError(f"{func_name} {param_name} should be {default_value}, if use_pat is False, but found "
f"a tensor")
if not is_default:
raise TypeError(f"{func_name} {param_name} should be {default_value}, if use_pat is False.")
else:
if not is_tensor:
raise TypeError(f"{func_name} {param_name} should be tensor, if use_pat is True")
return True
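
For readers following the change, here is a minimal, self-contained sketch of the new calling convention, mirroring the call sites later in this diff: the caller computes `is_tensor` / `is_default` with plain `isinstance` / `is` checks and passes only booleans into the check, so the function itself never has to inspect a Tensor. The `Tensor` stand-in, the omitted `@constexpr` decorator, and the parameter values below are illustrative only.

```python
# Illustrative sketch only: a local stand-in replaces mindspore.Tensor and the
# @constexpr decorator is omitted so the snippet runs without MindSpore.
class Tensor:
    """Stand-in for mindspore.Tensor."""

def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
    """If use_past is True the input must be a tensor; otherwise it must stay at its default."""
    if not use_past:
        if is_tensor:
            raise TypeError(f"{func_name} {param_name} should be {default_value} if use_past is False, "
                            f"but found a tensor")
        if not is_default:
            raise TypeError(f"{func_name} {param_name} should be {default_value} if use_past is False.")
    elif not is_tensor:
        raise TypeError(f"{func_name} {param_name} should be a tensor if use_past is True")
    return True

# Caller side (as in the MultiHeadAttention changes below): the booleans are
# computed outside and only plain Python values are passed in.
key_past = None
_check_past_none_input_none(False, "key_past", "MultiHeadAttention", None,
                            isinstance(key_past, Tensor), key_past is None)
```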


@@ -254,7 +254,7 @@ default_embedding_parallel_config = EmbeddingOpParallelConfig()
class FeedForward(Cell):
"""
r"""
    The multilayer perceptron with two linear layers and dropout applied at the final output. The first linear
    layer projects the input dimension from hidden_size to ffn_hidden_size, and the second linear layer projects the
    dimension from ffn_hidden_size back to hidden_size. The first linear layer is sharded on the relative dimension,
@@ -263,7 +263,7 @@ class FeedForward(Cell):
.. math::
        Dropout((xW_1+b_1)W_2 + b_2)
where the W_1, W_2, b_1 and b_2 are trainable parameters.
where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters.
Args:
hidden_size (int): The dimension of the inputs.
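
As an aside for readers of the FeedForward docstring above, the following NumPy sketch renders the formula Dropout((xW_1 + b_1)W_2 + b_2) literally; the sizes and keep probability are made up, and any activation between the two projections is elided here just as it is in the formula itself.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, ffn_hidden_size, keep_prob = 4, 16, 0.9        # illustrative sizes

x = rng.standard_normal((2, hidden_size))                   # (batch, hidden_size)
w1 = rng.standard_normal((hidden_size, ffn_hidden_size))    # first projection
b1 = np.zeros(ffn_hidden_size)
w2 = rng.standard_normal((ffn_hidden_size, hidden_size))    # second projection
b2 = np.zeros(hidden_size)

hidden = x @ w1 + b1                          # hidden_size -> ffn_hidden_size
out = hidden @ w2 + b2                        # ffn_hidden_size -> hidden_size
mask = rng.random(out.shape) < keep_prob
out = np.where(mask, out / keep_prob, 0.0)    # inverted dropout on the final output
print(out.shape)                              # (2, 4)
```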
@@ -652,6 +652,7 @@ class MultiHeadAttention(Cell):
if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
raise ValueError(f"The batch size {batch_size} must be a "
f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
self.is_first_iteration = True
# Output layer
self.projection = _Linear(in_channels=hidden_size,
out_channels=hidden_size,
@@ -780,7 +781,7 @@ class MultiHeadAttention(Cell):
# The first graph with the input size of (bs, seq_length)
if self.is_first_iteration:
# Get the valid input length without padding
valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(-1, 1, 1)), self.dtype)
# Cover the key and value numbers corresponding to the padding position
key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
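
A quick NumPy sketch of why the reshape moves from (1, 1, -1) to (-1, 1, 1): assuming self.range holds arange(seq_length) laid out as (1, 1, seq_length), the per-sample valid lengths have to sit on the batch axis to broadcast correctly; with the old view they would land on the sequence axis instead. The sizes below are illustrative.

```python
import numpy as np

batch_size, seq_length = 2, 6                      # illustrative sizes
range_ = np.arange(seq_length).reshape(1, 1, -1)   # assumed layout of self.range
batch_valid_length = np.array([3, 5], np.int32)    # valid tokens per sample

# New layout: one length per batch element, broadcast over the sequence axis.
mask = (range_ < batch_valid_length.reshape(-1, 1, 1)).astype(np.float16)
print(mask.shape)    # (2, 1, 6)
print(mask[0, 0])    # [1. 1. 1. 0. 0. 0.]  -> first sample keeps 3 positions
print(mask[1, 0])    # [1. 1. 1. 1. 1. 0.]  -> second sample keeps 5 positions
```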
@@ -823,26 +824,54 @@ class MultiHeadAttention(Cell):
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
value_past=None, batch_valid_length=None):
r"""Check inputs"""
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[[self.batch_size, self.src_seq_length, self.hidden_size],
[self.batch_size * self.src_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[[self.batch_size, self.src_seq_length, self.hidden_size],
[self.batch_size * self.src_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
else:
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[self.batch_size, 1, self.hidden_size])
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
[self.batch_size, 1, self.hidden_size])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[self.batch_size, 1, self.hidden_size])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, 1, self.tgt_seq_length])
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
_check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
_check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
key_is_tensor = isinstance(key_past, Tensor)
value_is_tensor = isinstance(value_past, Tensor)
batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
key_is_default = key_past is None
value_is_default = value_past is None
batch_is_default = batch_valid_length is None
_check_past_none_input_none(self.use_past, "key_past", self.cls_name, None, key_is_tensor,
key_is_default)
_check_past_none_input_none(self.use_past, "value_past", self.cls_name, None, value_is_tensor,
value_is_default)
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
batch_valid_length_is_tensor, batch_is_default)
if self.use_past:
_check_shape_equal(F.shape(key_past), "key_past", self.cls_name,
[self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length])
_check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(value_past), "value_past", self.cls_name,
[self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head])
_check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
_check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
return True
def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
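
To make the two branches of _check_inputs concrete, here is a hypothetical shape summary for use_past=True (all sizes below are made up, not defaults of the cell): the first iteration consumes the full prompt, later iterations consume a single token while key_past / value_past carry the cached states.

```python
# Hypothetical shapes matching the checks above; batch_size, the sequence
# length, hidden_size and n_head are illustrative values only.
batch_size, seq_length, hidden_size, n_head = 2, 16, 8, 2
size_per_head = hidden_size // n_head

first_iteration = {
    "query_tensor":   (batch_size, seq_length, hidden_size),
    "attention_mask": (batch_size, seq_length, seq_length),
}
later_iterations = {
    "query_tensor":       (batch_size, 1, hidden_size),
    "attention_mask":     (batch_size, 1, seq_length),
    "key_past":           (batch_size, n_head, size_per_head, seq_length),
    "value_past":         (batch_size, n_head, seq_length, size_per_head),
    "batch_valid_length": (batch_size,),
}
print(first_iteration)
print(later_iterations)
```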
@@ -966,7 +995,7 @@ class TransformerEncoderLayer(Cell):
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size].
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length].
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the valid length of each
          sequence calculated so far. Used for incremental prediction when use_past is True. Default None.
@@ -1180,18 +1209,32 @@ class TransformerEncoderLayer(Cell):
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
r"""Check inputs"""
_check_shape_equal(F.shape(x), "x", self.cls_name,
[[self.batch_size, self.seq_length, self.hidden_size],
[self.batch_size * self.seq_length, self.hidden_size]])
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal(F.shape(x), "x", self.cls_name,
[[self.batch_size, self.seq_length, self.hidden_size],
[self.batch_size * self.seq_length, self.hidden_size]])
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[self.batch_size, self.seq_length, self.seq_length])
else:
_check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[self.batch_size, 1, self.seq_length])
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[self.batch_size, self.seq_length, self.seq_length])
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
_check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
if init_reset is not True:
init_reset_is_tensor = isinstance(init_reset, Tensor)
init_reset_is_default = init_reset is True
batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
batch_is_default = batch_valid_length is None
_check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
init_reset_is_default)
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
batch_valid_length_is_tensor, batch_is_default)
if self.use_past:
_check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
_check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
if batch_valid_length is not None:
_check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
_check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
return True
@@ -1235,7 +1278,7 @@ class TransformerDecoderLayer(Cell):
[batch_size * seq_length, hidden_size].
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length], where tgt_seq_length is the length of the decoder.
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the valid length of each
          sequence calculated so far. Used for incremental prediction when use_past is True. Default None.
@@ -1323,6 +1366,8 @@ class TransformerDecoderLayer(Cell):
raise ValueError(
f"ffn_hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, "
f"but found {ffn_hidden_size}")
if use_past is True:
raise ValueError(f"The {self.cls_name} does not support use_past=True.")
self.batch_size = batch_size
self.use_past = use_past
self.softmax_compute_type = softmax_compute_type
@@ -1392,8 +1437,8 @@ class TransformerDecoderLayer(Cell):
self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
size_per_head = int(hidden_size / num_heads)
self.key_shape = (batch_size, num_heads, size_per_head, src_seq_length)
self.value_shape = (batch_size, num_heads, src_seq_length, size_per_head)
self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
self.value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head)
# parameters saving key and value states
self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
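
In NumPy terms, the corrected cache shapes look as follows; the point of the fix is that the decoder layer's self-attention cache must span tgt_seq_length rather than src_seq_length, matching the layouts the attention checks above expect. Sizes are illustrative.

```python
import numpy as np

batch_size, num_heads, tgt_seq_length, hidden_size = 2, 2, 16, 8   # illustrative sizes
size_per_head = hidden_size // num_heads

# After the fix both caches are allocated over the target (decoder) length:
# keys store the per-head dimension before the length, values the reverse.
key_past = np.zeros((batch_size, num_heads, size_per_head, tgt_seq_length), np.float16)
value_past = np.zeros((batch_size, num_heads, tgt_seq_length, size_per_head), np.float16)
print(key_past.shape, value_past.shape)   # (2, 2, 4, 16) (2, 2, 16, 4)
```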
@@ -1494,11 +1539,18 @@ class TransformerDecoderLayer(Cell):
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
r"""Check inputs"""
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
else:
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
[self.batch_size, 1, self.hidden_size])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, 1, self.tgt_seq_length])
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
if encoder_output is not None:
@@ -1513,11 +1565,19 @@ class TransformerDecoderLayer(Cell):
_check_input_dtype(F.dtype(memory_mask), "memory_mask",
[mstype.float32, mstype.float16], self.cls_name)
_check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
if init_reset is not True:
init_reset_is_tensor = isinstance(init_reset, Tensor)
init_reset_is_default = init_reset is True
batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
batch_is_default = batch_valid_length is None
_check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
init_reset_is_default)
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
batch_valid_length_is_tensor, batch_is_default)
if self.use_past:
_check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
_check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
if batch_valid_length is not None:
_check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
_check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
return True
@@ -1605,7 +1665,7 @@ class TransformerEncoder(Cell):
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size]
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the valid length of each
          sequence calculated so far. Used for incremental prediction when use_past is True. Default None.
@@ -1781,7 +1841,7 @@ class TransformerDecoder(Cell):
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length] where tgt_seq_length is the length of the decoder. the output of the encoder with shape
[batch_size, seq_length, hidden_size],
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the valid length of each
          sequence calculated so far. Used for incremental prediction when use_past is True. Default None.
@@ -1982,7 +2042,7 @@ class Transformer(Cell):
src_seq_length]
where tgt_seq_length is the length of the decoder. the output of the encoder with shape [batch_size,
          seq_length, hidden_size], this should be None if the decoder layer is 0.
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
- **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size], the valid length of each
          sequence calculated so far. Used for incremental prediction when use_past is True. Default None.
@@ -2089,7 +2149,7 @@ class Transformer(Cell):
raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
f"layer {decoder_layers}, please use TransformerDecoder")
if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
raise ValueError(f"The {self.cls_name} with encoder and decoder does not support use_past=True.")
if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
# The shard setting of Transformer is set within the TransformerEncoderLayer