parent 2eb47e4296
commit 9ddd2b7669
@@ -149,12 +149,17 @@ class _LayerInputCheck:

 @constexpr
-def _check_past_none_input_none(use_past, param_name, func_name, input_tensor, default_value=None):
+def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
     """ If the past is True, check whether the inputs is None"""
-    if not use_past and input_tensor is not default_value:
-        raise ValueError(f"{func_name} {param_name} should be {default_value}, if use_past is False.")
-    if use_past and input_tensor is default_value:
-        raise ValueError(f"{func_name} {param_name} should not be {default_value}, if use_past is True.")
+    if not use_past:
+        if is_tensor:
+            raise TypeError(f"{func_name} {param_name} should be {default_value}, if use_past is False, but found "
+                            f"a tensor")
+        if not is_default:
+            raise TypeError(f"{func_name} {param_name} should be {default_value}, if use_past is False.")
+    else:
+        if not is_tensor:
+            raise TypeError(f"{func_name} {param_name} should be a tensor, if use_past is True")
     return True
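For context, a minimal pure-Python sketch (not from the repository) of the calling convention the new signature expects: the caller derives the is_tensor/is_default flags up front and passes only those booleans. A stand-in Tensor class is used here so the snippet runs without MindSpore.

class Tensor:
    """Stand-in for mindspore.Tensor so this sketch runs without MindSpore."""


def _check_past_none_input_none(use_past, param_name, func_name, default_value, is_tensor, is_default):
    """Same logic as the new version above, repeated so the example is self-contained."""
    if not use_past:
        if is_tensor:
            raise TypeError(f"{func_name} {param_name} should be {default_value}, "
                            f"if use_past is False, but found a tensor")
        if not is_default:
            raise TypeError(f"{func_name} {param_name} should be {default_value}, if use_past is False.")
    else:
        if not is_tensor:
            raise TypeError(f"{func_name} {param_name} should be a tensor, if use_past is True")
    return True


# Caller-side pattern used later in this commit: compute the flags first, then pass them in.
key_past = Tensor()
assert _check_past_none_input_none(True, "key_past", "MultiHeadAttention", None,
                                   isinstance(key_past, Tensor), key_past is None)
assert _check_past_none_input_none(False, "key_past", "MultiHeadAttention", None,
                                   isinstance(None, Tensor), None is None)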
@@ -254,7 +254,7 @@ default_embedding_parallel_config = EmbeddingOpParallelConfig()


 class FeedForward(Cell):
-    """
+    r"""
     The multilayer perceptron with two linear layers with dropout applied at final output. The first linear
     will project the input dimension from hidden_size to ffn_hidden_size, the second linear will project the
     dimension from ffn_hidden_size to hidden_size. The first linear is sharded on the relative dimension,
@@ -263,7 +263,7 @@ class FeedForward(Cell):
     .. math::
             Dropout((xW_1+b_1)W_2 + b_2)

-    where the W_1, W_2, b_1 and b_2 are trainable parameters.
+    where the :math:`W_1, W_2, b_1` and :math:`b_2` are trainable parameters.

     Args:
         hidden_size (int): The dimension of the inputs.
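To make the formula above concrete, here is a small NumPy sketch (illustrative shapes only, activation omitted as in the formula; this is not the MindSpore API):

import numpy as np

batch_times_seq, hidden_size, ffn_hidden_size = 8, 16, 64
x = np.random.randn(batch_times_seq, hidden_size).astype(np.float32)

w1 = np.random.randn(hidden_size, ffn_hidden_size).astype(np.float32)   # hidden_size -> ffn_hidden_size
b1 = np.zeros(ffn_hidden_size, dtype=np.float32)
w2 = np.random.randn(ffn_hidden_size, hidden_size).astype(np.float32)   # ffn_hidden_size -> hidden_size
b2 = np.zeros(hidden_size, dtype=np.float32)

out = (x @ w1 + b1) @ w2 + b2                    # (xW_1 + b_1)W_2 + b_2
keep = (np.random.rand(*out.shape) > 0.1) / 0.9  # inverted dropout applied at the final output
out = out * keep
assert out.shape == (batch_times_seq, hidden_size)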
@@ -652,6 +652,7 @@ class MultiHeadAttention(Cell):
         if self.is_parallel_mode and batch_size % parallel_config.data_parallel != 0:
             raise ValueError(f"The batch size {batch_size} must be a "
                              f"multiple of parallel_config.data_parallel {parallel_config.data_parallel}.")
+        self.is_first_iteration = True
         # Output layer
         self.projection = _Linear(in_channels=hidden_size,
                                   out_channels=hidden_size,
@@ -780,7 +781,7 @@ class MultiHeadAttention(Cell):
             # The first graph with the input size of (bs, seq_length)
             if self.is_first_iteration:
                 # Get the valid input length without padding
-                valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(1, 1, -1)), self.dtype)
+                valid_length_vector = F.cast(self.less(self.range, batch_valid_length.view(-1, 1, 1)), self.dtype)
                 # Cover the key and value numbers corresponding to the padding position
                 key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2))
                 value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3))
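For intuition about the corrected reshape, a NumPy-only sketch (hypothetical sizes; position_range plays the role of self.range): view(-1, 1, 1) puts the batch dimension first, so each sample's valid length is broadcast against the position range and cached key/value entries beyond that length are zeroed.

import numpy as np

batch, heads, seq, size_per_head = 2, 2, 6, 4
batch_valid_length = np.array([3, 5])                # valid token count per sample
position_range = np.arange(seq).reshape(1, 1, seq)   # analogue of self.range

# (1, 1, seq) < (batch, 1, 1) broadcasts to a per-sample mask of shape (batch, 1, seq)
valid = (position_range < batch_valid_length.reshape(-1, 1, 1)).astype(np.float32)

key = np.random.randn(batch, heads, size_per_head, seq)
value = np.random.randn(batch, heads, seq, size_per_head)
key_present = key * np.expand_dims(valid, 2)         # mask key along its last (seq) axis
value_present = value * np.expand_dims(valid, 3)     # mask value along its seq axis
assert np.allclose(key_present[0, :, :, 3:], 0)      # positions past length 3 are zeroed for sample 0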
@@ -823,26 +824,54 @@ class MultiHeadAttention(Cell):
     def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                       value_past=None, batch_valid_length=None):
         r"""Check inputs"""
-        _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
-                           [[self.batch_size, self.src_seq_length, self.hidden_size],
-                            [self.batch_size * self.src_seq_length, self.hidden_size]])
-        _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
-                           [[self.batch_size, self.tgt_seq_length, self.hidden_size],
-                            [self.batch_size * self.tgt_seq_length, self.hidden_size]])
-        _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
-                           [[self.batch_size, self.tgt_seq_length, self.hidden_size],
-                            [self.batch_size * self.tgt_seq_length, self.hidden_size]])
-        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
-                           [self.batch_size, self.src_seq_length, self.tgt_seq_length])
+        if not self.use_past or (self.use_past and self.is_first_iteration):
+            _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
+                               [[self.batch_size, self.src_seq_length, self.hidden_size],
+                                [self.batch_size * self.src_seq_length, self.hidden_size]])
+            _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
+                               [[self.batch_size, self.tgt_seq_length, self.hidden_size],
+                                [self.batch_size * self.tgt_seq_length, self.hidden_size]])
+            _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
+                               [[self.batch_size, self.tgt_seq_length, self.hidden_size],
+                                [self.batch_size * self.tgt_seq_length, self.hidden_size]])
+            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                               [self.batch_size, self.src_seq_length, self.tgt_seq_length])
+        else:
+            _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
+                               [self.batch_size, 1, self.hidden_size])
+            _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
+                               [self.batch_size, 1, self.hidden_size])
+            _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
+                               [self.batch_size, 1, self.hidden_size])
+            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                               [self.batch_size, 1, self.tgt_seq_length])

         _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)

-        _check_past_none_input_none(self.use_past, "key_past", self.cls_name, key_past)
-        _check_past_none_input_none(self.use_past, "value_past", self.cls_name, value_past)
-        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
+        key_is_tensor = isinstance(key_past, Tensor)
+        value_is_tensor = isinstance(value_past, Tensor)
+        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
+        key_is_default = key_past is None
+        value_is_default = value_past is None
+        batch_is_default = batch_valid_length is None
+        _check_past_none_input_none(self.use_past, "key_past", self.cls_name, None, key_is_tensor,
+                                    key_is_default)
+        _check_past_none_input_none(self.use_past, "value_past", self.cls_name, None, value_is_tensor,
+                                    value_is_default)
+        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
+                                    batch_valid_length_is_tensor, batch_is_default)
         if self.use_past:
             _check_shape_equal(F.shape(key_past), "key_past", self.cls_name,
                                [self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length])
             _check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name)
             _check_shape_equal(F.shape(value_past), "value_past", self.cls_name,
                                [self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head])
             _check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name)
             _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
             _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
         return True

     def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
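A brief illustrative summary (hypothetical sizes, plain Python) of the two input regimes the reworked check distinguishes: full sequences on the first pass, a single new token per sample on later incremental steps.

# Hypothetical sizes, for illustration only.
batch_size, src_seq_length, tgt_seq_length, hidden_size = 2, 16, 16, 64

# First iteration, or use_past=False: full sequences.
query_shape_full = (batch_size, src_seq_length, hidden_size)
attention_mask_shape_full = (batch_size, src_seq_length, tgt_seq_length)

# Later incremental steps with use_past=True: one token per sample.
query_shape_incremental = (batch_size, 1, hidden_size)
attention_mask_shape_incremental = (batch_size, 1, tgt_seq_length)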
@@ -966,7 +995,7 @@ class TransformerEncoderLayer(Cell):
         - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
           [batch_size * seq_length, hidden_size].
         - **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length].
-        - **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
+        - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
           past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
         - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used
           for incremental prediction when the use_past is True. Default None.
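As a hedged usage sketch (not from the diff), the two control inputs described above could be built roughly like this, assuming MindSpore is installed:

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

batch_size = 2  # illustrative value

# Shape [1] after this change: a single flag that clears the cached key/value states.
init_reset = Tensor([True], mstype.bool_)

# Shape [batch_size]: how many positions have already been computed for each sample.
batch_valid_length = Tensor(np.zeros((batch_size,), dtype=np.int32), mstype.int32)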
@@ -1180,18 +1209,32 @@ class TransformerEncoderLayer(Cell):
     def _check_input(self, x, input_mask, init_reset, batch_valid_length):
         r"""Check inputs"""
-        _check_shape_equal(F.shape(x), "x", self.cls_name,
-                           [[self.batch_size, self.seq_length, self.hidden_size],
-                            [self.batch_size * self.seq_length, self.hidden_size]])
+        if not self.use_past or (self.use_past and self.is_first_iteration):
+            _check_shape_equal(F.shape(x), "x", self.cls_name,
+                               [[self.batch_size, self.seq_length, self.hidden_size],
+                                [self.batch_size * self.seq_length, self.hidden_size]])
+            _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
+                               [self.batch_size, self.seq_length, self.seq_length])
+        else:
+            _check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
+            _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
+                               [self.batch_size, 1, self.seq_length])
         _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
-        _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
-                           [self.batch_size, self.seq_length, self.seq_length])
         _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
-        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
-        if init_reset is not True:
+
+        init_reset_is_tensor = isinstance(init_reset, Tensor)
+        init_reset_is_default = init_reset is True
+        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
+        batch_is_default = batch_valid_length is None
+        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
+                                    init_reset_is_default)
+        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
+                                    batch_valid_length_is_tensor, batch_is_default)
+
+        if self.use_past:
             _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
             _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
-        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
-        if batch_valid_length is not None:
             _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
             _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
         return True
@@ -1235,7 +1278,7 @@ class TransformerDecoderLayer(Cell):
          [batch_size * seq_length, hidden_size].
        - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
          src_seq_length], where tgt_seq_length is the length of the decoder.
-       - **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
+       - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used
          for incremental prediction when the use_past is True. Default None.
@@ -1323,6 +1366,8 @@ class TransformerDecoderLayer(Cell):
             raise ValueError(
                 f"ffn_hidden_size must be divisibled by the model parallel way {parallel_config.model_parallel}, "
                 f"but found {ffn_hidden_size}")
+        if use_past is True:
+            raise ValueError(f"The {self.cls_name} does not support use_past=True.")
         self.batch_size = batch_size
         self.use_past = use_past
         self.softmax_compute_type = softmax_compute_type
@@ -1392,8 +1437,8 @@ class TransformerDecoderLayer(Cell):
             self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ()))
             self.slice = P.StridedSlice().shard(((1, 1, 1, 1),))
             size_per_head = int(hidden_size / num_heads)
-            self.key_shape = (batch_size, num_heads, size_per_head, src_seq_length)
-            self.value_shape = (batch_size, num_heads, src_seq_length, size_per_head)
+            self.key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
+            self.value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head)
             # parameters saving key and value states
             self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past")
             self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past")
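A quick NumPy check (hypothetical sizes chosen so the two lengths differ) of the corrected cache shapes: the decoder layer caches its own target-side sequence, so the sequence dimension follows tgt_seq_length rather than src_seq_length.

import numpy as np

batch_size, num_heads, hidden_size = 2, 4, 16
src_seq_length, tgt_seq_length = 10, 7          # encoder vs decoder lengths
size_per_head = hidden_size // num_heads

key_shape = (batch_size, num_heads, size_per_head, tgt_seq_length)
value_shape = (batch_size, num_heads, tgt_seq_length, size_per_head)

key_past = np.zeros(key_shape, dtype=np.float16)
value_past = np.zeros(value_shape, dtype=np.float16)
assert key_past.shape[-1] == tgt_seq_length
assert value_past.shape[2] == tgt_seq_length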
@@ -1494,11 +1539,18 @@ class TransformerDecoderLayer(Cell):
     def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
         r"""Check inputs"""
-        _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
-                           [[self.batch_size, self.tgt_seq_length, self.hidden_size],
-                            [self.batch_size * self.tgt_seq_length, self.hidden_size]])
-        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
-                           [self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
+        if not self.use_past or (self.use_past and self.is_first_iteration):
+            _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
+                               [[self.batch_size, self.tgt_seq_length, self.hidden_size],
+                                [self.batch_size * self.tgt_seq_length, self.hidden_size]])
+            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                               [self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
+        else:
+            _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
+                               [self.batch_size, 1, self.hidden_size])
+            _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
+                               [self.batch_size, 1, self.tgt_seq_length])
         _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
         _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
         if encoder_output is not None:
@@ -1513,11 +1565,19 @@ class TransformerDecoderLayer(Cell):
             _check_input_dtype(F.dtype(memory_mask), "memory_mask",
                                [mstype.float32, mstype.float16], self.cls_name)

-        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, init_reset, True)
-        if init_reset is not True:
+        init_reset_is_tensor = isinstance(init_reset, Tensor)
+        init_reset_is_default = init_reset is True
+        batch_valid_length_is_tensor = isinstance(batch_valid_length, Tensor)
+        batch_is_default = batch_valid_length is None
+        _check_past_none_input_none(self.use_past, "init_reset", self.cls_name, True, init_reset_is_tensor,
+                                    init_reset_is_default)
+        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
+                                    batch_valid_length_is_tensor, batch_is_default)
+
+        if self.use_past:
             _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
             _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
-        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
-        if batch_valid_length is not None:
             _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
             _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
         return True
@@ -1605,7 +1665,7 @@ class TransformerEncoder(Cell):
        - **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
          [batch_size * seq_length, hidden_size]
        - **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
-       - **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
+       - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used
          for incremental prediction when the use_past is True. Default None.
@@ -1781,7 +1841,7 @@ class TransformerDecoder(Cell):
        - **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
          src_seq_length] where tgt_seq_length is the length of the decoder. the output of the encoder with shape
          [batch_size, seq_length, hidden_size],
-       - **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
+       - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index.
          Used for incremental prediction when the use_past is True. Default None.
@@ -1982,7 +2042,7 @@ class Transformer(Cell):
          src_seq_length]
          where tgt_seq_length is the length of the decoder. the output of the encoder with shape [batch_size,
          seq_length, hidden_size], this should be none if the decoder layer is 0.
-       - **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
+       - **init_reset** (Tensor) - A bool tensor with shape [1], used to clear the past key parameter and
          past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
        - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. Used
          for incremental prediction when the use_past is True. Default None.
@@ -2089,7 +2149,7 @@ class Transformer(Cell):
             raise ValueError(f"Transformer doest support encoder layer {encoder_layers} and decoder"
                              f"layer {decoder_layers}, please use TransformerDecoder")
         if encoder_layers > 0 and decoder_layers > 0 and use_past is True:
-            raise ValueError("The transformer with encoder and decoder does not support use_past=True.")
+            raise ValueError(f"The {self.cls_name} with encoder and decoder does not support use_past=True.")
         if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,):
             raise RuntimeError(f"The {self.cls_name} does not support auto parallel mode now.")
         # The shard setting of Transformer is set within the TransformerEncoderLayer