forked from mindspore-Ecosystem/mindspore
!23537 Update pangu reshape and softmax.
Merge pull request !23537 from linqingke/pangu
This commit is contained in:
commit
7cde7731b0
|
@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore._extends import cell_attr_register
|
||||
from mindspore.nn.cell import Cell
|
||||
import mindspore.nn as nn
|
||||
from mindspore import nn
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore._checkparam import Validator
|
||||
|
@ -87,12 +87,66 @@ def _valid_value_checks(types, class_name):
|
|||
return validator_check_func
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_shape(input_shape, param_name, func_name, target_len):
|
||||
if len(input_shape) != target_len:
|
||||
raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
|
||||
class _LayerInputCheck:
|
||||
"""
|
||||
A input check class for the inputs of the transformer model.
|
||||
"""
|
||||
@staticmethod
|
||||
def check_shape_length(input_shape, param_name, func_name, target_len):
|
||||
"""
|
||||
Check the input shape's length is equal to the expected shape
|
||||
:param input_shape(list): a list of the tensor shapes.
|
||||
:param param_name(str): the name of the checked parameter.
|
||||
:param func_name(str): the name of the function.
|
||||
:param target_len: the expected length of the shape.
|
||||
:return:
|
||||
"""
|
||||
if not isinstance(target_len, list):
|
||||
target_len = [target_len]
|
||||
matched = False
|
||||
for item in target_len:
|
||||
if len(input_shape) == item:
|
||||
matched = True
|
||||
if not matched:
|
||||
raise ValueError(f"{func_name} {param_name} shape length should be one of {target_len} dimension, "
|
||||
f"but got shape {input_shape}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_shape_equal(input_shape, param_name, func_name, target_shape):
|
||||
"""
|
||||
Check the input shape's is equal to the expected shape
|
||||
:param input_shape(list): a list of the tensor shapes.
|
||||
:param param_name(str): the name of the checked parameter.
|
||||
:param func_name(str): the name of the function.
|
||||
:param target_shape: the expected shape.
|
||||
:return:
|
||||
"""
|
||||
if not isinstance(target_shape[0], list):
|
||||
target_shape = [target_shape]
|
||||
if isinstance(input_shape, tuple):
|
||||
input_shape = list(input_shape)
|
||||
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
|
||||
[len(item) for item in target_shape])
|
||||
matched = False
|
||||
for item in target_shape:
|
||||
if item == input_shape:
|
||||
matched = True
|
||||
break
|
||||
|
||||
if not matched:
|
||||
raise ValueError(f"{func_name} {param_name} shape should be one of {target_shape},"
|
||||
f"but got {input_shape}")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value):
|
||||
if input_shape[dim] != target_value:
|
||||
raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value},"
|
||||
f"but got {input_shape[dim]}")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_past_none_input_none(use_past, param_name, func_name, input_tensor, default_value=None):
|
||||
|
@ -104,28 +158,27 @@ def _check_past_none_input_none(use_past, param_name, func_name, input_tensor, d
|
|||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
|
||||
if len(input_shape) != len(target_shape):
|
||||
raise ValueError(f"{func_name} {param_name} shape should be {target_shape},"
|
||||
f"but got {input_shape}")
|
||||
for i in range(len(input_shape)):
|
||||
if input_shape[i] != target_shape[i]:
|
||||
raise ValueError(f"{func_name} {param_name} shape should be {target_shape},"
|
||||
f"but got {input_shape}")
|
||||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||
Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_shape(input_shape, param_name, func_name, target_len):
|
||||
# check the input length
|
||||
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
|
||||
# check the input length
|
||||
_LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
|
||||
if input_shape[dim] != target_value:
|
||||
raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value},"
|
||||
f"but got {input_shape[dim]}")
|
||||
_LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)
|
||||
|
||||
|
||||
class _LayerNorm(Cell):
|
||||
|
@ -147,6 +200,11 @@ class _LayerNorm(Cell):
|
|||
super(_LayerNorm, self).__init__()
|
||||
if param_init_type not in [mstype.float32, mstype.float16]:
|
||||
raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}")
|
||||
if normalized_shape[0] <= 1024:
|
||||
self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
|
||||
begin_params_axis=-1,
|
||||
epsilon=eps)
|
||||
self.is_self_defined = normalized_shape[0] > 1024
|
||||
self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
|
||||
parallel_optimizer=False)
|
||||
self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
|
||||
|
@ -166,12 +224,15 @@ class _LayerNorm(Cell):
|
|||
r"""
|
||||
x : batch x seq_length x hidden_size
|
||||
"""
|
||||
if self.is_self_defined:
|
||||
mean = self.mean(x, -1)
|
||||
diff = self.sub1(x, mean)
|
||||
variance = self.mean(self.square(diff), -1)
|
||||
variance_eps = self.sqrt(self.add(variance, self.eps))
|
||||
output = self.real_div(diff, variance_eps)
|
||||
output = self.add2(self.mul(output, self.gamma), self.beta)
|
||||
else:
|
||||
output, _, _ = self.layer_norm(x, self.gamma, self.beta)
|
||||
return output
|
||||
|
||||
def shard(self, strategy):
|
||||
|
@ -188,6 +249,7 @@ class _LayerNorm(Cell):
|
|||
>>> net = mindspore.parallel.nn.transformer.LayerNorm(normalized_shape=(1024, 10))
|
||||
>>> net.shard(((10, 2, 1),))
|
||||
"""
|
||||
if self.is_self_defined:
|
||||
self.mean.shard(strategy)
|
||||
self.square.shard(strategy)
|
||||
self.sqrt.shard(strategy)
|
||||
|
@ -197,6 +259,8 @@ class _LayerNorm(Cell):
|
|||
self.mul.shard((strategy[0], (1,)))
|
||||
self.add2.shard((strategy[0], (1,)))
|
||||
self.real_div.shard((strategy[0], strategy[0]))
|
||||
else:
|
||||
self.layer_norm.shard((strategy[0], (1,), (1,)))
|
||||
|
||||
return self
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ class MoE(Cell):
|
|||
|
||||
|
||||
def construct(self, input_tensor):
|
||||
bs = self.shape(input_tensor)[0]
|
||||
input_shape = F.shape(input_tensor)
|
||||
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
|
||||
bs_and_dmodel = self.shape(input_tensor)
|
||||
tokens_per_device = bs_and_dmodel[0] / self.expert_parallel
|
||||
|
@ -148,7 +148,7 @@ class MoE(Cell):
|
|||
expert_capacity))
|
||||
# expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
|
||||
expert_input = self.transpose2(expert_input, (2, 0, 3, 1))
|
||||
expert_input = self.reshape(expert_input, (self.expert_dim, self.expert_parallel * expert_capacity,
|
||||
expert_input = self.reshape(expert_input, (self.expert_dim * self.expert_parallel * expert_capacity,
|
||||
self.hidden_size))
|
||||
|
||||
# expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size)
|
||||
|
@ -170,7 +170,7 @@ class MoE(Cell):
|
|||
# combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
|
||||
combined_output = self.transpose5(combined_output, (0, 2, 1))
|
||||
combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
|
||||
combined_output = self.reshape(combined_output, (bs, -1, self.hidden_size))
|
||||
combined_output = self.reshape(combined_output, input_shape)
|
||||
|
||||
aux_loss = self.mul(self.aux_loss_factor, aux_loss)
|
||||
return combined_output, aux_loss
|
||||
|
|
|
@ -281,10 +281,12 @@ class FeedForward(Cell):
|
|||
default args.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
|
||||
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
|
||||
Float tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.
|
||||
Tensor, the output of this layer after mapping.
|
||||
The shape is `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
|
||||
|
||||
Raises:
|
||||
ValueError: `hidden_act` is not a string.
|
||||
|
@ -344,11 +346,11 @@ class FeedForward(Cell):
|
|||
if expert_num > 1:
|
||||
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
|
||||
strategy_bias=((ep, 1, mp), (mp,)),
|
||||
strategy_activation=((ep, 1, mp),))
|
||||
strategy_activation=((ep, mp),))
|
||||
else:
|
||||
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
|
||||
strategy_bias=((dp, mp), (mp,)),
|
||||
strategy_activation=((dp, 1, mp),))
|
||||
strategy_activation=((dp, mp),))
|
||||
# Project back to hidden_size
|
||||
self.projection = _Linear(in_channels=output_size,
|
||||
out_channels=input_size,
|
||||
|
@ -363,17 +365,17 @@ class FeedForward(Cell):
|
|||
strategy_bias=((dp, 1), (1,)))
|
||||
self.projection.bias.parallel_optimizer = False
|
||||
self.dropout = nn.Dropout(1 - dropout_rate)
|
||||
self.dropout.dropout.shard(((dp, 1, 1),))
|
||||
self.dropout.dropout.shard(((dp, 1),))
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
_check_input_shape(F.shape(x), "x", self.cls_name, 3)
|
||||
_check_input_shape(F.shape(x), "x", self.cls_name, [2, 3])
|
||||
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
x = self.cast(x, mstype.float16)
|
||||
# returned shape is [bs, seq_length, ffn_hidden_size]
|
||||
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
|
||||
hidden = self.mapping(x)
|
||||
output = self.projection(hidden)
|
||||
# returned shape is [bs, seq_length, hidden_size]
|
||||
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
|
@ -556,9 +558,12 @@ class MultiHeadAttention(Cell):
|
|||
an instance of `OpParallelConfig` with default args.
|
||||
|
||||
Inputs:
|
||||
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size).
|
||||
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size).
|
||||
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size).
|
||||
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size) or
|
||||
(batch_size * src_seq_length, hidden_size).
|
||||
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size) or
|
||||
(batch_size * src_seq_length, hidden_size).
|
||||
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size) or
|
||||
(batch_size * src_seq_length, hidden_size).
|
||||
- **attention_mask** (Tensor) - the attention mask matrix with shape (batch_size, src_seq_length,
|
||||
tgt_seq_length).
|
||||
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
|
||||
|
@ -574,7 +579,7 @@ class MultiHeadAttention(Cell):
|
|||
Tuple, a tuple contains(`output`, `layer_present`)
|
||||
|
||||
- **output** (Tensor) - Tensor, the float tensor of the output of the layer with
|
||||
shape (batch_size, src_seq_length, hidden_size)
|
||||
shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size)
|
||||
|
||||
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
|
||||
((batch_size, num_heads, size_per_head, tgt_seq_length),
|
||||
|
@ -683,11 +688,11 @@ class MultiHeadAttention(Cell):
|
|||
self.scale_factor = Tensor(math.sqrt(self.size_per_head))
|
||||
self.use_past = use_past
|
||||
self.dropout = nn.Dropout(1 - hidden_dropout_rate)
|
||||
self.dropout.dropout.shard(((parallel_config.data_parallel, 1, 1),))
|
||||
self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
|
||||
self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
|
||||
self.prob_dropout.dropout.shard(
|
||||
((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
|
||||
self.softmax = nn.Softmax()
|
||||
self.softmax = nn.Softmax().to_float(softmax_compute_type)
|
||||
self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
|
||||
self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),))
|
||||
|
||||
|
@ -737,14 +742,11 @@ class MultiHeadAttention(Cell):
|
|||
value_past=None, batch_valid_length=None):
|
||||
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
|
||||
value_past, batch_valid_length)
|
||||
query_tensor_original_shape = F.shape(query_tensor)
|
||||
query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
|
||||
|
||||
key_tensor_original_shape = F.shape(key_tensor)
|
||||
key_tensor = F.reshape(key_tensor, (-1, key_tensor_original_shape[-1]))
|
||||
|
||||
value_tensor_original_shape = F.shape(value_tensor)
|
||||
value_tensor = F.reshape(value_tensor, (-1, value_tensor_original_shape[-1]))
|
||||
batch_size = F.shape(attention_mask)[0]
|
||||
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
attention_mask)
|
||||
|
||||
# multi head attention: query, key, value are derived from the same inputs
|
||||
query = self.dense1(query_tensor)
|
||||
|
@ -754,18 +756,18 @@ class MultiHeadAttention(Cell):
|
|||
query = self.transpose(
|
||||
F.reshape(
|
||||
query,
|
||||
(-1, query_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
||||
(batch_size, -1, self.n_head, self.size_per_head)),
|
||||
(0, 2, 1, 3))
|
||||
# the returned shape is [bs, num_heads, size_per_head, seq_length]
|
||||
# the returned shape is [bs, size_per_head, seq_length, num_heads]
|
||||
key = self.transpose(
|
||||
F.reshape(
|
||||
key, (-1, key_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
||||
key, (batch_size, -1, self.n_head, self.size_per_head)),
|
||||
(0, 2, 3, 1))
|
||||
# the returned shape is [bs, num_heads, seq_length, size_per_head]
|
||||
value = self.transpose(
|
||||
F.reshape(
|
||||
value,
|
||||
(-1, value_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
||||
(batch_size, -1, self.n_head, self.size_per_head)),
|
||||
(0, 2, 1, 3))
|
||||
# support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
|
||||
if len(F.shape(attention_mask)) == 3:
|
||||
|
@ -810,22 +812,26 @@ class MultiHeadAttention(Cell):
|
|||
|
||||
layer_present = (key_present, value_present)
|
||||
# multi head attention considering attention mask
|
||||
# the return shape is [bs, seq_length, hidden_size]
|
||||
# the return shape is [bs * seq_length, hidden_size]
|
||||
attention = self._attn(query, key, value, attention_mask)
|
||||
# Output
|
||||
output = self.projection(attention)
|
||||
output = self.dropout(output)
|
||||
output = F.reshape(output, ori_shape)
|
||||
return output, layer_present
|
||||
|
||||
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
|
||||
value_past=None, batch_valid_length=None):
|
||||
r"""Check inputs"""
|
||||
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
|
||||
[self.batch_size, self.src_seq_length, self.hidden_size])
|
||||
[[self.batch_size, self.src_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.src_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
||||
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
||||
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
|
||||
|
||||
|
@ -839,20 +845,30 @@ class MultiHeadAttention(Cell):
|
|||
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
|
||||
return True
|
||||
|
||||
def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
|
||||
"""convert a nd tensor to a 2d tensor"""
|
||||
query_shape = F.shape(query_tensor)
|
||||
query_tensor = F.reshape(query_tensor, (-1, query_shape[-1]))
|
||||
key_shape = F.shape(key_tensor)
|
||||
key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
|
||||
value_shape = F.shape(value_tensor)
|
||||
value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
|
||||
return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape
|
||||
|
||||
def _merge_heads(self, x):
|
||||
"""
|
||||
convert a 4d input to a 3d output
|
||||
convert a 4d input to a 2d output
|
||||
|
||||
Inputs:
|
||||
x: input tensor
|
||||
|
||||
Output:
|
||||
x_merge: the 3d output
|
||||
x_merge: the 2d output
|
||||
"""
|
||||
x = self.merger_head_transpose(
|
||||
x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head
|
||||
x_shape = P.Shape()(x)
|
||||
new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
|
||||
new_shape = (-1, x_shape[-2] * x_shape[-1])
|
||||
x_merge = self.reshape(x, new_shape)
|
||||
return x_merge
|
||||
|
||||
|
@ -947,7 +963,8 @@ class TransformerEncoderLayer(Cell):
|
|||
an instance of `OpParallelConfig` with default args.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size].
|
||||
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size].
|
||||
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length].
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
||||
|
@ -958,7 +975,7 @@ class TransformerEncoderLayer(Cell):
|
|||
Tuple, a tuple contains(`output`, `layer_present`).
|
||||
|
||||
- **output** (Tensor) - The float tensor of the output of the layer with
|
||||
shape (batch_size, seq_length, hidden_size).
|
||||
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size).
|
||||
|
||||
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
|
||||
((batch_size, num_heads, size_per_head, seq_length),
|
||||
|
@ -1034,9 +1051,9 @@ class TransformerEncoderLayer(Cell):
|
|||
self.hidden_size = hidden_size
|
||||
self.batch_size = batch_size
|
||||
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),))
|
||||
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
|
||||
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),))
|
||||
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
|
||||
|
||||
self.attention = MultiHeadAttention(batch_size=batch_size,
|
||||
src_seq_length=seq_length,
|
||||
|
@ -1067,7 +1084,8 @@ class TransformerEncoderLayer(Cell):
|
|||
hidden_act=hidden_act,
|
||||
parallel_config=parallel_config)
|
||||
self.post_layernorm_residual = post_layernorm_residual
|
||||
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
||||
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
|
||||
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
||||
self.dtype = mstype.float16
|
||||
self.key_past = None
|
||||
self.value_past = None
|
||||
|
@ -1089,6 +1107,8 @@ class TransformerEncoderLayer(Cell):
|
|||
|
||||
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
|
||||
self._check_input(x, input_mask, init_reset, batch_valid_length)
|
||||
x_shape = F.shape(x)
|
||||
x = F.reshape(x, (-1, x_shape[-1]))
|
||||
input_x = self.layernorm1(x)
|
||||
input_x = F.cast(input_x, self.dtype)
|
||||
|
||||
|
@ -1137,10 +1157,23 @@ class TransformerEncoderLayer(Cell):
|
|||
mlp_logit = F.depend(mlp_logit, value_update)
|
||||
mlp_logit = F.depend(mlp_logit, key_update)
|
||||
|
||||
# if shape is 3d, we reshape the inputs of the add
|
||||
if len(x_shape) == 3:
|
||||
output_x = P.Reshape()(output_x, x_shape)
|
||||
mlp_logit = P.Reshape()(mlp_logit, x_shape)
|
||||
x = P.Reshape()(x, x_shape)
|
||||
|
||||
if self.post_layernorm_residual:
|
||||
output = self.add_3d(output_x, mlp_logit)
|
||||
else:
|
||||
output = self.add_3d(x, mlp_logit)
|
||||
else:
|
||||
if self.post_layernorm_residual:
|
||||
output = self.add(output_x, mlp_logit)
|
||||
else:
|
||||
output = self.add(x, mlp_logit)
|
||||
output = F.reshape(output, x_shape)
|
||||
|
||||
if self.use_moe is True:
|
||||
return output, layer_present, aux_loss
|
||||
return output, layer_present
|
||||
|
@ -1148,7 +1181,8 @@ class TransformerEncoderLayer(Cell):
|
|||
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
|
||||
r"""Check inputs"""
|
||||
_check_shape_equal(F.shape(x), "x", self.cls_name,
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
[[self.batch_size, self.seq_length, self.hidden_size],
|
||||
[self.batch_size * self.seq_length, self.hidden_size]])
|
||||
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
|
||||
[self.batch_size, self.seq_length, self.seq_length])
|
||||
|
@ -1193,10 +1227,12 @@ class TransformerDecoderLayer(Cell):
|
|||
an instance of `OpParallelConfig` with default args.
|
||||
|
||||
Inputs:
|
||||
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size].
|
||||
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
|
||||
[batch_size * tgt_seq_length, hidden_size].
|
||||
- **decoder_mask** (Tensor) - the attention mask for decoder with shape [batch_size, src_seq_length,
|
||||
seq_length].
|
||||
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size].
|
||||
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size].
|
||||
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||
src_seq_length], where tgt_seq_length is the length of the decoder.
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
||||
|
@ -1207,7 +1243,8 @@ class TransformerDecoderLayer(Cell):
|
|||
Outputs:
|
||||
Tuple, a tuple contains(`output`, `layer_present`)
|
||||
|
||||
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size]
|
||||
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size] or
|
||||
[batch * seq_length, hidden_size].
|
||||
- **layer_present** (Tensor) - A tuple, where each tuple is the tensor of the projected key and value
|
||||
vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
|
||||
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
|
||||
|
@ -1295,10 +1332,10 @@ class TransformerDecoderLayer(Cell):
|
|||
self.use_past = use_past
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.layernorm1 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type)
|
||||
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),))
|
||||
self.layernorm2 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type)
|
||||
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),))
|
||||
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
|
||||
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
|
||||
|
||||
self.attention = MultiHeadAttention(hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
|
@ -1323,9 +1360,9 @@ class TransformerDecoderLayer(Cell):
|
|||
use_past=use_past,
|
||||
param_init_type=param_init_type,
|
||||
parallel_config=parallel_config)
|
||||
self.cross_attention_layernorm = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(
|
||||
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
|
||||
layernorm_compute_type)
|
||||
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1, 1),))
|
||||
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
|
||||
self.use_moe = (moe_config.expert_num > 1)
|
||||
if self.use_moe is True:
|
||||
self.output = MoE(hidden_size=hidden_size,
|
||||
|
@ -1344,7 +1381,8 @@ class TransformerDecoderLayer(Cell):
|
|||
param_init_type=param_init_type,
|
||||
parallel_config=parallel_config)
|
||||
self.post_layernorm_residual = post_layernorm_residual
|
||||
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
||||
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
|
||||
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
||||
self.dtype = mstype.float16
|
||||
self.key_past = None
|
||||
self.value_past = None
|
||||
|
@ -1369,7 +1407,9 @@ class TransformerDecoderLayer(Cell):
|
|||
memory_mask=None,
|
||||
init_reset=True, batch_valid_length=None):
|
||||
self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
|
||||
# the returned shape is [bs, seq_length, embedding_size]
|
||||
# the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size]
|
||||
hidden_shape = F.shape(hidden_stats)
|
||||
hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1]))
|
||||
input_x = self.layernorm1(hidden_stats)
|
||||
input_x = F.cast(input_x, self.dtype)
|
||||
|
||||
|
@ -1431,10 +1471,23 @@ class TransformerDecoderLayer(Cell):
|
|||
mlp_logit = F.depend(mlp_logit, value_update)
|
||||
mlp_logit = F.depend(mlp_logit, key_update)
|
||||
|
||||
# if shape is 3d, we reshape the inputs of the add
|
||||
if len(hidden_shape) == 3:
|
||||
output_x = P.Reshape()(output_x, hidden_shape)
|
||||
mlp_logit = P.Reshape()(mlp_logit, hidden_shape)
|
||||
x = P.Reshape()(x, hidden_shape)
|
||||
|
||||
if self.post_layernorm_residual:
|
||||
output = self.add_3d(output_x, mlp_logit)
|
||||
else:
|
||||
output = self.add_3d(x, mlp_logit)
|
||||
else:
|
||||
if self.post_layernorm_residual:
|
||||
output = self.add(output_x, mlp_logit)
|
||||
else:
|
||||
output = self.add(x, mlp_logit)
|
||||
output = F.reshape(output, hidden_shape)
|
||||
|
||||
if self.use_moe is True:
|
||||
return output, layer_present, aux_loss
|
||||
return output, layer_present
|
||||
|
@ -1442,14 +1495,16 @@ class TransformerDecoderLayer(Cell):
|
|||
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
|
||||
r"""Check inputs"""
|
||||
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
||||
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
|
||||
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
|
||||
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
|
||||
if encoder_output is not None:
|
||||
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
|
||||
[self.batch_size, self.src_seq_length, self.hidden_size])
|
||||
[[self.batch_size, self.src_seq_length, self.hidden_size],
|
||||
[self.batch_size * self.src_seq_length, self.hidden_size]])
|
||||
_check_input_dtype(F.dtype(encoder_output), "encoder_output",
|
||||
[mstype.float32, mstype.float16], self.cls_name)
|
||||
if memory_mask is not None:
|
||||
|
@ -1547,7 +1602,8 @@ class TransformerEncoder(Cell):
|
|||
an instance of `TransformerOpParallelConfig` with default args.
|
||||
|
||||
Inputs:
|
||||
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size]
|
||||
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size]
|
||||
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
|
||||
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
|
||||
|
@ -1558,7 +1614,7 @@ class TransformerEncoder(Cell):
|
|||
Tuple, a tuple contains(`output`, `layer_present`)
|
||||
|
||||
- **output** (Tensor) - The float tensor of the output of the layer with
|
||||
shape (batch_size, seq_length, hidden_size)
|
||||
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size)
|
||||
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple contains the Tensor the
|
||||
projected key and value vector with shape ((batch_size, num_heads, size_per_head, seq_length),
|
||||
and (batch_size, num_heads, seq_length, size_per_head)).
|
||||
|
@ -1717,9 +1773,11 @@ class TransformerDecoder(Cell):
|
|||
an instance of `TransformerOpParallelConfig` with default args.
|
||||
|
||||
Inputs:
|
||||
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size]
|
||||
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size]
|
||||
- **attention_mask** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
|
||||
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size]
|
||||
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size]
|
||||
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||
src_seq_length] where tgt_seq_length is the length of the decoder. the output of the encoder with shape
|
||||
[batch_size, seq_length, hidden_size],
|
||||
|
@ -1731,7 +1789,8 @@ class TransformerDecoder(Cell):
|
|||
Outputs:
|
||||
Tuple, a tuple contains(`output`, `layer_present`)
|
||||
|
||||
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size]
|
||||
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] or
|
||||
[batch * tgt_seq_length, hidden_size]
|
||||
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the projected
|
||||
key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
|
||||
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
|
||||
|
@ -1912,9 +1971,11 @@ class Transformer(Cell):
|
|||
an instance of `TransformerOpParallelConfig` with default args.
|
||||
|
||||
Inputs:
|
||||
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size].
|
||||
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size].
|
||||
- **encoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length].
|
||||
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size],
|
||||
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
|
||||
[batch_size * seq_length, hidden_size],
|
||||
this should be none if the decoder layer is 0.
|
||||
- **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
|
||||
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||
|
@ -1930,8 +1991,9 @@ class Transformer(Cell):
|
|||
Tuple, a tuple contains(`output`, `encoder_layer_present`, `encoder_layer_present`)
|
||||
|
||||
- **output** (Tensor) - If there is only encoder, the output logit of the encoder layer. The shape is
|
||||
[batch, src_seq_length, hidden_size], if there are encoder and decoders, the output is from the
|
||||
decoder layer. The shape is [batch, tgt_seq_length, hidden_size].
|
||||
[batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size], if there are encoder and
|
||||
decoders, the output is from the decoder layer. The shape is [batch, tgt_seq_length, hidden_size] or
|
||||
[batch * tgt_seq_length, hidden_size].
|
||||
- **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor the
|
||||
projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
|
||||
src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)).
|
||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.context import set_auto_parallel_context, ParallelMode
|
|||
from mindspore.ops import composite as C
|
||||
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
|
||||
from mindspore.nn.optim import AdamWeightDecay
|
||||
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell
|
||||
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
|
||||
from mindspore.train import Model
|
||||
from tests.dataset_mock import MindData
|
||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||
|
@ -77,7 +77,7 @@ def test_transformer_model():
|
|||
ffn_hidden_size=64,
|
||||
moe_config=moe_config,
|
||||
parallel_config=config)
|
||||
|
||||
net = _VirtualDatasetCell(net)
|
||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
|
||||
|
@ -92,3 +92,35 @@ def test_transformer_model():
|
|||
model = Model(net_with_grad)
|
||||
|
||||
model.train(1, dataset, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def test_transformer_model_2d():
|
||||
set_auto_parallel_context(device_num=16, global_rank=0,
|
||||
full_batch=True, enable_alltoall=True,
|
||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
net = Transformer(encoder_layers=1,
|
||||
decoder_layers=1,
|
||||
batch_size=2,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=10,
|
||||
hidden_size=64,
|
||||
num_heads=8,
|
||||
ffn_hidden_size=64,
|
||||
moe_config=moe_config,
|
||||
parallel_config=config)
|
||||
net = _VirtualDatasetCell(net)
|
||||
|
||||
encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||
decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
|
||||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||
net = NetWithLossFiveInputs(net)
|
||||
params = net.trainable_params()
|
||||
optimizer = AdamWeightDecay(params)
|
||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||
memory_mask)
|
||||
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
|
||||
model = Model(net_with_grad)
|
||||
|
||||
model.train(1, dataset, dataset_sink_mode=False)
|
||||
|
|
|
@ -56,6 +56,28 @@ class Dataset(MindData):
|
|||
self.index = 0
|
||||
|
||||
|
||||
class TransformerNet(nn.Cell):
|
||||
def __init__(self, en_layer, de_layer, parallel_config):
|
||||
super(TransformerNet, self).__init__()
|
||||
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
|
||||
parallel_config=config.embedding_dp_mp_config)
|
||||
self.network = Transformer(encoder_layers=en_layer,
|
||||
decoder_layers=de_layer,
|
||||
batch_size=2,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=10,
|
||||
hidden_size=64,
|
||||
num_heads=8,
|
||||
ffn_hidden_size=64,
|
||||
parallel_config=parallel_config)
|
||||
self.head = Linear(in_channels=64, out_channels=200)
|
||||
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
|
||||
|
||||
def construct(self, x1, x2, x3, x4, x5, y, mask):
|
||||
predict, _, _ = self.network(x1, x2, x3, x4, x5)
|
||||
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
|
||||
return self.loss(predict, y, mask)
|
||||
|
||||
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
|
||||
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
|
||||
micro_batch_num=4, vocab_emb_dp=False)
|
||||
|
@ -84,28 +106,6 @@ def run_total_transformer_model_head(e_layer,
|
|||
full_batch=True,
|
||||
global_rank=0, parallel_mode=mode)
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, en_layer, de_layer, parallel_config):
|
||||
super(Net, self).__init__()
|
||||
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
|
||||
parallel_config=config.embedding_dp_mp_config)
|
||||
self.network = Transformer(encoder_layers=en_layer,
|
||||
decoder_layers=de_layer,
|
||||
batch_size=2,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=10,
|
||||
hidden_size=64,
|
||||
num_heads=8,
|
||||
ffn_hidden_size=64,
|
||||
parallel_config=parallel_config)
|
||||
self.head = Linear(in_channels=64, out_channels=200)
|
||||
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
|
||||
|
||||
def construct(self, x1, x2, x3, x4, x5, y, mask):
|
||||
predict, _, _ = self.network(x1, x2, x3, x4, x5)
|
||||
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
|
||||
return self.loss(predict, y, mask)
|
||||
|
||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
|
||||
|
@ -116,7 +116,8 @@ def run_total_transformer_model_head(e_layer,
|
|||
seq = 10
|
||||
label = Tensor(np.ones((2 * seq,)), mstype.int32)
|
||||
input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
|
||||
net = Net(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
|
||||
net = TransformerNet(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
|
||||
net = _VirtualDatasetCell(net)
|
||||
params = net.trainable_params()
|
||||
optimizer = AdamWeightDecay(params)
|
||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||
|
@ -147,6 +148,38 @@ def test_transformer_model():
|
|||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||
net = NetWithLossFiveInputs(net)
|
||||
net = _VirtualDatasetCell(net)
|
||||
params = net.trainable_params()
|
||||
optimizer = AdamWeightDecay(params)
|
||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||
memory_mask)
|
||||
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
|
||||
model = Model(net_with_grad)
|
||||
|
||||
model.train(1, dataset, dataset_sink_mode=False)
|
||||
|
||||
|
||||
def test_transformer_model_2d_inputs():
|
||||
set_auto_parallel_context(device_num=8, global_rank=0,
|
||||
full_batch=True,
|
||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||
net = Transformer(encoder_layers=1,
|
||||
decoder_layers=2,
|
||||
batch_size=2,
|
||||
src_seq_length=20,
|
||||
tgt_seq_length=10,
|
||||
hidden_size=64,
|
||||
num_heads=8,
|
||||
ffn_hidden_size=64,
|
||||
parallel_config=config)
|
||||
|
||||
encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
|
||||
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||
decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
|
||||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||
net = NetWithLossFiveInputs(net)
|
||||
net = _VirtualDatasetCell(net)
|
||||
params = net.trainable_params()
|
||||
optimizer = AdamWeightDecay(params)
|
||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||
|
@ -177,6 +210,7 @@ def test_transformer_model_int64_inputs():
|
|||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||
net = NetWithLossFiveInputs(net)
|
||||
net = _VirtualDatasetCell(net)
|
||||
params = net.trainable_params()
|
||||
optimizer = AdamWeightDecay(params)
|
||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||
|
@ -330,6 +364,8 @@ def test_encoder():
|
|||
|
||||
net = NetWithLoss(net)
|
||||
|
||||
net = _VirtualDatasetCell(net)
|
||||
|
||||
dataset = Dataset(encoder_input_value, encoder_input_mask)
|
||||
|
||||
model = Model(net)
|
||||
|
@ -367,6 +403,8 @@ def test_decoder():
|
|||
|
||||
net = NetWithLoss(net)
|
||||
|
||||
net = _VirtualDatasetCell(net)
|
||||
|
||||
dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
|
||||
|
||||
model = Model(net)
|
||||
|
@ -388,6 +426,7 @@ def test_vocabembedding_dp_true():
|
|||
|
||||
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
||||
net = NetWithLoss(net)
|
||||
net = _VirtualDatasetCell(net)
|
||||
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
|
||||
dataset = Dataset(encoder_input_value)
|
||||
|
||||
|
@ -410,6 +449,7 @@ def test_vocabembedding_dp_false():
|
|||
|
||||
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
||||
net = NetWithLoss(net)
|
||||
net = _VirtualDatasetCell(net)
|
||||
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
|
||||
dataset = Dataset(encoder_input_value)
|
||||
|
||||
|
@ -484,6 +524,7 @@ def test_sparse_attention_parallel_dp():
|
|||
num_heads=8,
|
||||
block_size=64,
|
||||
parallel_config=sparse_attention_config)
|
||||
net = _VirtualDatasetCell(net)
|
||||
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
||||
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
||||
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
||||
|
@ -509,6 +550,7 @@ def test_parallel_cross_entroy_loss_semi_auto_parallel():
|
|||
|
||||
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
||||
net = NetWithLoss(net, config.dp_mp_config)
|
||||
net = _VirtualDatasetCell(net)
|
||||
embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
|
||||
labels = Tensor(np.ones((2 * 64,)), mstype.int32)
|
||||
input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
|
||||
|
|
Loading…
Reference in New Issue