!23537 Update pangu reshape and softmax.

Merge pull request !23537 from linqingke/pangu
This commit is contained in:
i-robot 2021-09-24 08:12:13 +00:00 committed by Gitee
commit 7cde7731b0
5 changed files with 330 additions and 130 deletions

View File

@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore._extends import cell_attr_register
from mindspore.nn.cell import Cell
import mindspore.nn as nn
from mindspore import nn
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import functional as F
from mindspore._checkparam import Validator
@ -87,12 +87,66 @@ def _valid_value_checks(types, class_name):
return validator_check_func
@constexpr
def _check_input_shape(input_shape, param_name, func_name, target_len):
if len(input_shape) != target_len:
raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
class _LayerInputCheck:
"""
A input check class for the inputs of the transformer model.
"""
@staticmethod
def check_shape_length(input_shape, param_name, func_name, target_len):
"""
Check the input shape's length is equal to the expected shape
:param input_shape(list): a list of the tensor shapes.
:param param_name(str): the name of the checked parameter.
:param func_name(str): the name of the function.
:param target_len: the expected length of the shape.
:return:
"""
if not isinstance(target_len, list):
target_len = [target_len]
matched = False
for item in target_len:
if len(input_shape) == item:
matched = True
if not matched:
raise ValueError(f"{func_name} {param_name} shape length should be one of {target_len} dimension, "
f"but got shape {input_shape}")
return True
@staticmethod
def check_shape_equal(input_shape, param_name, func_name, target_shape):
"""
Check the input shape's is equal to the expected shape
:param input_shape(list): a list of the tensor shapes.
:param param_name(str): the name of the checked parameter.
:param func_name(str): the name of the function.
:param target_shape: the expected shape.
:return:
"""
if not isinstance(target_shape[0], list):
target_shape = [target_shape]
if isinstance(input_shape, tuple):
input_shape = list(input_shape)
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
[len(item) for item in target_shape])
matched = False
for item in target_shape:
if item == input_shape:
matched = True
break
if not matched:
raise ValueError(f"{func_name} {param_name} shape should be one of {target_shape},"
f"but got {input_shape}")
return True
@staticmethod
def check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value):
if input_shape[dim] != target_value:
raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value},"
f"but got {input_shape[dim]}")
return True
@constexpr
def _check_past_none_input_none(use_past, param_name, func_name, input_tensor, default_value=None):
@ -104,28 +158,27 @@ def _check_past_none_input_none(use_past, param_name, func_name, input_tensor, d
return True
@constexpr
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
if len(input_shape) != len(target_shape):
raise ValueError(f"{func_name} {param_name} shape should be {target_shape},"
f"but got {input_shape}")
for i in range(len(input_shape)):
if input_shape[i] != target_shape[i]:
raise ValueError(f"{func_name} {param_name} shape should be {target_shape},"
f"but got {input_shape}")
return True
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
@constexpr
def _check_input_shape(input_shape, param_name, func_name, target_len):
# check the input length
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len)
@constexpr
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
# check the input length
_LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape)
@constexpr
def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
if input_shape[dim] != target_value:
raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value},"
f"but got {input_shape[dim]}")
_LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)
class _LayerNorm(Cell):
@ -147,6 +200,11 @@ class _LayerNorm(Cell):
super(_LayerNorm, self).__init__()
if param_init_type not in [mstype.float32, mstype.float16]:
raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}")
if normalized_shape[0] <= 1024:
self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
begin_params_axis=-1,
epsilon=eps)
self.is_self_defined = normalized_shape[0] > 1024
self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
parallel_optimizer=False)
self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
@ -166,12 +224,15 @@ class _LayerNorm(Cell):
r"""
x : batch x seq_length x hidden_size
"""
if self.is_self_defined:
mean = self.mean(x, -1)
diff = self.sub1(x, mean)
variance = self.mean(self.square(diff), -1)
variance_eps = self.sqrt(self.add(variance, self.eps))
output = self.real_div(diff, variance_eps)
output = self.add2(self.mul(output, self.gamma), self.beta)
else:
output, _, _ = self.layer_norm(x, self.gamma, self.beta)
return output
def shard(self, strategy):
@ -188,6 +249,7 @@ class _LayerNorm(Cell):
>>> net = mindspore.parallel.nn.transformer.LayerNorm(normalized_shape=(1024, 10))
>>> net.shard(((10, 2, 1),))
"""
if self.is_self_defined:
self.mean.shard(strategy)
self.square.shard(strategy)
self.sqrt.shard(strategy)
@ -197,6 +259,8 @@ class _LayerNorm(Cell):
self.mul.shard((strategy[0], (1,)))
self.add2.shard((strategy[0], (1,)))
self.real_div.shard((strategy[0], strategy[0]))
else:
self.layer_norm.shard((strategy[0], (1,), (1,)))
return self

View File

@ -125,7 +125,7 @@ class MoE(Cell):
def construct(self, input_tensor):
bs = self.shape(input_tensor)[0]
input_shape = F.shape(input_tensor)
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
bs_and_dmodel = self.shape(input_tensor)
tokens_per_device = bs_and_dmodel[0] / self.expert_parallel
@ -148,7 +148,7 @@ class MoE(Cell):
expert_capacity))
# expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
expert_input = self.transpose2(expert_input, (2, 0, 3, 1))
expert_input = self.reshape(expert_input, (self.expert_dim, self.expert_parallel * expert_capacity,
expert_input = self.reshape(expert_input, (self.expert_dim * self.expert_parallel * expert_capacity,
self.hidden_size))
# expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size)
@ -170,7 +170,7 @@ class MoE(Cell):
# combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
combined_output = self.transpose5(combined_output, (0, 2, 1))
combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
combined_output = self.reshape(combined_output, (bs, -1, self.hidden_size))
combined_output = self.reshape(combined_output, input_shape)
aux_loss = self.mul(self.aux_loss_factor, aux_loss)
return combined_output, aux_loss

View File

@ -281,10 +281,12 @@ class FeedForward(Cell):
default args.
Inputs:
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
Float tensor.
Outputs:
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.
Tensor, the output of this layer after mapping.
The shape is `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
Raises:
ValueError: `hidden_act` is not a string.
@ -344,11 +346,11 @@ class FeedForward(Cell):
if expert_num > 1:
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
strategy_bias=((ep, 1, mp), (mp,)),
strategy_activation=((ep, 1, mp),))
strategy_activation=((ep, mp),))
else:
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
strategy_bias=((dp, mp), (mp,)),
strategy_activation=((dp, 1, mp),))
strategy_activation=((dp, mp),))
# Project back to hidden_size
self.projection = _Linear(in_channels=output_size,
out_channels=input_size,
@ -363,17 +365,17 @@ class FeedForward(Cell):
strategy_bias=((dp, 1), (1,)))
self.projection.bias.parallel_optimizer = False
self.dropout = nn.Dropout(1 - dropout_rate)
self.dropout.dropout.shard(((dp, 1, 1),))
self.dropout.dropout.shard(((dp, 1),))
self.cast = P.Cast()
def construct(self, x):
_check_input_shape(F.shape(x), "x", self.cls_name, 3)
_check_input_shape(F.shape(x), "x", self.cls_name, [2, 3])
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
x = self.cast(x, mstype.float16)
# returned shape is [bs, seq_length, ffn_hidden_size]
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
hidden = self.mapping(x)
output = self.projection(hidden)
# returned shape is [bs, seq_length, hidden_size]
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
output = self.dropout(output)
return output
@ -556,9 +558,12 @@ class MultiHeadAttention(Cell):
an instance of `OpParallelConfig` with default args.
Inputs:
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size).
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size).
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size).
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size) or
(batch_size * src_seq_length, hidden_size).
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size) or
(batch_size * src_seq_length, hidden_size).
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size) or
(batch_size * src_seq_length, hidden_size).
- **attention_mask** (Tensor) - the attention mask matrix with shape (batch_size, src_seq_length,
tgt_seq_length).
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
@ -574,7 +579,7 @@ class MultiHeadAttention(Cell):
Tuple, a tuple contains(`output`, `layer_present`)
- **output** (Tensor) - Tensor, the float tensor of the output of the layer with
shape (batch_size, src_seq_length, hidden_size)
shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size)
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
((batch_size, num_heads, size_per_head, tgt_seq_length),
@ -683,11 +688,11 @@ class MultiHeadAttention(Cell):
self.scale_factor = Tensor(math.sqrt(self.size_per_head))
self.use_past = use_past
self.dropout = nn.Dropout(1 - hidden_dropout_rate)
self.dropout.dropout.shard(((parallel_config.data_parallel, 1, 1),))
self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
self.prob_dropout.dropout.shard(
((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
self.softmax = nn.Softmax()
self.softmax = nn.Softmax().to_float(softmax_compute_type)
self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),))
@ -737,14 +742,11 @@ class MultiHeadAttention(Cell):
value_past=None, batch_valid_length=None):
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
value_past, batch_valid_length)
query_tensor_original_shape = F.shape(query_tensor)
query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
key_tensor_original_shape = F.shape(key_tensor)
key_tensor = F.reshape(key_tensor, (-1, key_tensor_original_shape[-1]))
value_tensor_original_shape = F.shape(value_tensor)
value_tensor = F.reshape(value_tensor, (-1, value_tensor_original_shape[-1]))
batch_size = F.shape(attention_mask)[0]
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
key_tensor,
value_tensor,
attention_mask)
# multi head attention: query, key, value are derived from the same inputs
query = self.dense1(query_tensor)
@ -754,18 +756,18 @@ class MultiHeadAttention(Cell):
query = self.transpose(
F.reshape(
query,
(-1, query_tensor_original_shape[1], self.n_head, self.size_per_head)),
(batch_size, -1, self.n_head, self.size_per_head)),
(0, 2, 1, 3))
# the returned shape is [bs, num_heads, size_per_head, seq_length]
# the returned shape is [bs, size_per_head, seq_length, num_heads]
key = self.transpose(
F.reshape(
key, (-1, key_tensor_original_shape[1], self.n_head, self.size_per_head)),
key, (batch_size, -1, self.n_head, self.size_per_head)),
(0, 2, 3, 1))
# the returned shape is [bs, num_heads, seq_length, size_per_head]
value = self.transpose(
F.reshape(
value,
(-1, value_tensor_original_shape[1], self.n_head, self.size_per_head)),
(batch_size, -1, self.n_head, self.size_per_head)),
(0, 2, 1, 3))
# support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
if len(F.shape(attention_mask)) == 3:
@ -810,22 +812,26 @@ class MultiHeadAttention(Cell):
layer_present = (key_present, value_present)
# multi head attention considering attention mask
# the return shape is [bs, seq_length, hidden_size]
# the return shape is [bs * seq_length, hidden_size]
attention = self._attn(query, key, value, attention_mask)
# Output
output = self.projection(attention)
output = self.dropout(output)
output = F.reshape(output, ori_shape)
return output, layer_present
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
value_past=None, batch_valid_length=None):
r"""Check inputs"""
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[self.batch_size, self.src_seq_length, self.hidden_size])
[[self.batch_size, self.src_seq_length, self.hidden_size],
[self.batch_size * self.src_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.hidden_size])
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.hidden_size])
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
@ -839,20 +845,30 @@ class MultiHeadAttention(Cell):
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
return True
def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
"""convert a nd tensor to a 2d tensor"""
query_shape = F.shape(query_tensor)
query_tensor = F.reshape(query_tensor, (-1, query_shape[-1]))
key_shape = F.shape(key_tensor)
key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
value_shape = F.shape(value_tensor)
value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape
def _merge_heads(self, x):
"""
convert a 4d input to a 3d output
convert a 4d input to a 2d output
Inputs:
x: input tensor
Output:
x_merge: the 3d output
x_merge: the 2d output
"""
x = self.merger_head_transpose(
x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head
x_shape = P.Shape()(x)
new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
new_shape = (-1, x_shape[-2] * x_shape[-1])
x_merge = self.reshape(x, new_shape)
return x_merge
@ -947,7 +963,8 @@ class TransformerEncoderLayer(Cell):
an instance of `OpParallelConfig` with default args.
Inputs:
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size].
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size].
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length].
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
@ -958,7 +975,7 @@ class TransformerEncoderLayer(Cell):
Tuple, a tuple contains(`output`, `layer_present`).
- **output** (Tensor) - The float tensor of the output of the layer with
shape (batch_size, seq_length, hidden_size).
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size).
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
((batch_size, num_heads, size_per_head, seq_length),
@ -1034,9 +1051,9 @@ class TransformerEncoderLayer(Cell):
self.hidden_size = hidden_size
self.batch_size = batch_size
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),))
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),))
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
self.attention = MultiHeadAttention(batch_size=batch_size,
src_seq_length=seq_length,
@ -1067,7 +1084,8 @@ class TransformerEncoderLayer(Cell):
hidden_act=hidden_act,
parallel_config=parallel_config)
self.post_layernorm_residual = post_layernorm_residual
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
self.dtype = mstype.float16
self.key_past = None
self.value_past = None
@ -1089,6 +1107,8 @@ class TransformerEncoderLayer(Cell):
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
self._check_input(x, input_mask, init_reset, batch_valid_length)
x_shape = F.shape(x)
x = F.reshape(x, (-1, x_shape[-1]))
input_x = self.layernorm1(x)
input_x = F.cast(input_x, self.dtype)
@ -1137,10 +1157,23 @@ class TransformerEncoderLayer(Cell):
mlp_logit = F.depend(mlp_logit, value_update)
mlp_logit = F.depend(mlp_logit, key_update)
# if shape is 3d, we reshape the inputs of the add
if len(x_shape) == 3:
output_x = P.Reshape()(output_x, x_shape)
mlp_logit = P.Reshape()(mlp_logit, x_shape)
x = P.Reshape()(x, x_shape)
if self.post_layernorm_residual:
output = self.add_3d(output_x, mlp_logit)
else:
output = self.add_3d(x, mlp_logit)
else:
if self.post_layernorm_residual:
output = self.add(output_x, mlp_logit)
else:
output = self.add(x, mlp_logit)
output = F.reshape(output, x_shape)
if self.use_moe is True:
return output, layer_present, aux_loss
return output, layer_present
@ -1148,7 +1181,8 @@ class TransformerEncoderLayer(Cell):
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
r"""Check inputs"""
_check_shape_equal(F.shape(x), "x", self.cls_name,
[self.batch_size, self.seq_length, self.hidden_size])
[[self.batch_size, self.seq_length, self.hidden_size],
[self.batch_size * self.seq_length, self.hidden_size]])
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[self.batch_size, self.seq_length, self.seq_length])
@ -1193,10 +1227,12 @@ class TransformerDecoderLayer(Cell):
an instance of `OpParallelConfig` with default args.
Inputs:
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size].
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
[batch_size * tgt_seq_length, hidden_size].
- **decoder_mask** (Tensor) - the attention mask for decoder with shape [batch_size, src_seq_length,
seq_length].
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size].
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size].
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length], where tgt_seq_length is the length of the decoder.
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
@ -1207,7 +1243,8 @@ class TransformerDecoderLayer(Cell):
Outputs:
Tuple, a tuple contains(`output`, `layer_present`)
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size]
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size] or
[batch * seq_length, hidden_size].
- **layer_present** (Tensor) - A tuple, where each tuple is the tensor of the projected key and value
vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
@ -1295,10 +1332,10 @@ class TransformerDecoderLayer(Cell):
self.use_past = use_past
self.hidden_size = hidden_size
self.layernorm1 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type)
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),))
self.layernorm2 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),))
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
self.attention = MultiHeadAttention(hidden_size=hidden_size,
num_heads=num_heads,
@ -1323,9 +1360,9 @@ class TransformerDecoderLayer(Cell):
use_past=use_past,
param_init_type=param_init_type,
parallel_config=parallel_config)
self.cross_attention_layernorm = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
layernorm_compute_type)
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1, 1),))
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
self.use_moe = (moe_config.expert_num > 1)
if self.use_moe is True:
self.output = MoE(hidden_size=hidden_size,
@ -1344,7 +1381,8 @@ class TransformerDecoderLayer(Cell):
param_init_type=param_init_type,
parallel_config=parallel_config)
self.post_layernorm_residual = post_layernorm_residual
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
self.dtype = mstype.float16
self.key_past = None
self.value_past = None
@ -1369,7 +1407,9 @@ class TransformerDecoderLayer(Cell):
memory_mask=None,
init_reset=True, batch_valid_length=None):
self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
# the returned shape is [bs, seq_length, embedding_size]
# the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size]
hidden_shape = F.shape(hidden_stats)
hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1]))
input_x = self.layernorm1(hidden_stats)
input_x = F.cast(input_x, self.dtype)
@ -1431,10 +1471,23 @@ class TransformerDecoderLayer(Cell):
mlp_logit = F.depend(mlp_logit, value_update)
mlp_logit = F.depend(mlp_logit, key_update)
# if shape is 3d, we reshape the inputs of the add
if len(hidden_shape) == 3:
output_x = P.Reshape()(output_x, hidden_shape)
mlp_logit = P.Reshape()(mlp_logit, hidden_shape)
x = P.Reshape()(x, hidden_shape)
if self.post_layernorm_residual:
output = self.add_3d(output_x, mlp_logit)
else:
output = self.add_3d(x, mlp_logit)
else:
if self.post_layernorm_residual:
output = self.add(output_x, mlp_logit)
else:
output = self.add(x, mlp_logit)
output = F.reshape(output, hidden_shape)
if self.use_moe is True:
return output, layer_present, aux_loss
return output, layer_present
@ -1442,14 +1495,16 @@ class TransformerDecoderLayer(Cell):
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
r"""Check inputs"""
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.hidden_size])
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
if encoder_output is not None:
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
[self.batch_size, self.src_seq_length, self.hidden_size])
[[self.batch_size, self.src_seq_length, self.hidden_size],
[self.batch_size * self.src_seq_length, self.hidden_size]])
_check_input_dtype(F.dtype(encoder_output), "encoder_output",
[mstype.float32, mstype.float16], self.cls_name)
if memory_mask is not None:
@ -1547,7 +1602,8 @@ class TransformerEncoder(Cell):
an instance of `TransformerOpParallelConfig` with default args.
Inputs:
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size]
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size]
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
@ -1558,7 +1614,7 @@ class TransformerEncoder(Cell):
Tuple, a tuple contains(`output`, `layer_present`)
- **output** (Tensor) - The float tensor of the output of the layer with
shape (batch_size, seq_length, hidden_size)
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size)
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple contains the Tensor the
projected key and value vector with shape ((batch_size, num_heads, size_per_head, seq_length),
and (batch_size, num_heads, seq_length, size_per_head)).
@ -1717,9 +1773,11 @@ class TransformerDecoder(Cell):
an instance of `TransformerOpParallelConfig` with default args.
Inputs:
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size]
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size]
- **attention_mask** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size]
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size]
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
src_seq_length] where tgt_seq_length is the length of the decoder. the output of the encoder with shape
[batch_size, seq_length, hidden_size],
@ -1731,7 +1789,8 @@ class TransformerDecoder(Cell):
Outputs:
Tuple, a tuple contains(`output`, `layer_present`)
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size]
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] or
[batch * tgt_seq_length, hidden_size]
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the projected
key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
@ -1912,9 +1971,11 @@ class Transformer(Cell):
an instance of `TransformerOpParallelConfig` with default args.
Inputs:
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size].
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size].
- **encoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length].
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size],
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
[batch_size * seq_length, hidden_size],
this should be none if the decoder layer is 0.
- **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
@ -1930,8 +1991,9 @@ class Transformer(Cell):
Tuple, a tuple contains(`output`, `encoder_layer_present`, `encoder_layer_present`)
- **output** (Tensor) - If there is only encoder, the output logit of the encoder layer. The shape is
[batch, src_seq_length, hidden_size], if there are encoder and decoders, the output is from the
decoder layer. The shape is [batch, tgt_seq_length, hidden_size].
[batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size], if there are encoder and
decoders, the output is from the decoder layer. The shape is [batch, tgt_seq_length, hidden_size] or
[batch * tgt_seq_length, hidden_size].
- **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor the
projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)).

View File

@ -20,7 +20,7 @@ from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
from mindspore.train import Model
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
@ -77,7 +77,7 @@ def test_transformer_model():
ffn_hidden_size=64,
moe_config=moe_config,
parallel_config=config)
net = _VirtualDatasetCell(net)
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
@ -92,3 +92,35 @@ def test_transformer_model():
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_2d():
set_auto_parallel_context(device_num=16, global_rank=0,
full_batch=True, enable_alltoall=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = Transformer(encoder_layers=1,
decoder_layers=1,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
moe_config=moe_config,
parallel_config=config)
net = _VirtualDatasetCell(net)
encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask)
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)

View File

@ -56,6 +56,28 @@ class Dataset(MindData):
self.index = 0
class TransformerNet(nn.Cell):
def __init__(self, en_layer, de_layer, parallel_config):
super(TransformerNet, self).__init__()
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
parallel_config=config.embedding_dp_mp_config)
self.network = Transformer(encoder_layers=en_layer,
decoder_layers=de_layer,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
parallel_config=parallel_config)
self.head = Linear(in_channels=64, out_channels=200)
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
def construct(self, x1, x2, x3, x4, x5, y, mask):
predict, _, _ = self.network(x1, x2, x3, x4, x5)
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
return self.loss(predict, y, mask)
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
micro_batch_num=4, vocab_emb_dp=False)
@ -84,28 +106,6 @@ def run_total_transformer_model_head(e_layer,
full_batch=True,
global_rank=0, parallel_mode=mode)
class Net(nn.Cell):
def __init__(self, en_layer, de_layer, parallel_config):
super(Net, self).__init__()
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
parallel_config=config.embedding_dp_mp_config)
self.network = Transformer(encoder_layers=en_layer,
decoder_layers=de_layer,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
parallel_config=parallel_config)
self.head = Linear(in_channels=64, out_channels=200)
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
def construct(self, x1, x2, x3, x4, x5, y, mask):
predict, _, _ = self.network(x1, x2, x3, x4, x5)
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
return self.loss(predict, y, mask)
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
@ -116,7 +116,8 @@ def run_total_transformer_model_head(e_layer,
seq = 10
label = Tensor(np.ones((2 * seq,)), mstype.int32)
input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
net = Net(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
net = TransformerNet(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
net = _VirtualDatasetCell(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
@ -147,6 +148,38 @@ def test_transformer_model():
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
net = _VirtualDatasetCell(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask)
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_2d_inputs():
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
parallel_config=config)
encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
net = _VirtualDatasetCell(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
@ -177,6 +210,7 @@ def test_transformer_model_int64_inputs():
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
net = _VirtualDatasetCell(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
@ -330,6 +364,8 @@ def test_encoder():
net = NetWithLoss(net)
net = _VirtualDatasetCell(net)
dataset = Dataset(encoder_input_value, encoder_input_mask)
model = Model(net)
@ -367,6 +403,8 @@ def test_decoder():
net = NetWithLoss(net)
net = _VirtualDatasetCell(net)
dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
model = Model(net)
@ -388,6 +426,7 @@ def test_vocabembedding_dp_true():
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
net = NetWithLoss(net)
net = _VirtualDatasetCell(net)
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
dataset = Dataset(encoder_input_value)
@ -410,6 +449,7 @@ def test_vocabembedding_dp_false():
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
net = NetWithLoss(net)
net = _VirtualDatasetCell(net)
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
dataset = Dataset(encoder_input_value)
@ -484,6 +524,7 @@ def test_sparse_attention_parallel_dp():
num_heads=8,
block_size=64,
parallel_config=sparse_attention_config)
net = _VirtualDatasetCell(net)
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
@ -509,6 +550,7 @@ def test_parallel_cross_entroy_loss_semi_auto_parallel():
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
net = NetWithLoss(net, config.dp_mp_config)
net = _VirtualDatasetCell(net)
embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
labels = Tensor(np.ones((2 * 64,)), mstype.int32)
input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)