forked from mindspore-Ecosystem/mindspore
update pangu reshape and softmax performance.
Add layer norm judge Fix layer norm name error Fix input tyoe check Fix ut test Add 3d supports
This commit is contained in:
parent
04a4eecca6
commit
acde7febef
|
@ -26,7 +26,7 @@ import mindspore.common.dtype as mstype
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from mindspore.nn.cell import Cell
|
from mindspore.nn.cell import Cell
|
||||||
import mindspore.nn as nn
|
from mindspore import nn
|
||||||
from mindspore.nn.layer.activation import get_activation
|
from mindspore.nn.layer.activation import get_activation
|
||||||
from mindspore.ops import functional as F
|
from mindspore.ops import functional as F
|
||||||
from mindspore._checkparam import Validator
|
from mindspore._checkparam import Validator
|
||||||
|
@ -87,11 +87,65 @@ def _valid_value_checks(types, class_name):
|
||||||
return validator_check_func
|
return validator_check_func
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
class _LayerInputCheck:
|
||||||
def _check_input_shape(input_shape, param_name, func_name, target_len):
|
"""
|
||||||
if len(input_shape) != target_len:
|
A input check class for the inputs of the transformer model.
|
||||||
raise ValueError(f"{func_name} {param_name} should be {target_len}d, but got shape {input_shape}")
|
"""
|
||||||
return True
|
@staticmethod
|
||||||
|
def check_shape_length(input_shape, param_name, func_name, target_len):
|
||||||
|
"""
|
||||||
|
Check the input shape's length is equal to the expected shape
|
||||||
|
:param input_shape(list): a list of the tensor shapes.
|
||||||
|
:param param_name(str): the name of the checked parameter.
|
||||||
|
:param func_name(str): the name of the function.
|
||||||
|
:param target_len: the expected length of the shape.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not isinstance(target_len, list):
|
||||||
|
target_len = [target_len]
|
||||||
|
matched = False
|
||||||
|
for item in target_len:
|
||||||
|
if len(input_shape) == item:
|
||||||
|
matched = True
|
||||||
|
if not matched:
|
||||||
|
raise ValueError(f"{func_name} {param_name} shape length should be one of {target_len} dimension, "
|
||||||
|
f"but got shape {input_shape}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_shape_equal(input_shape, param_name, func_name, target_shape):
|
||||||
|
"""
|
||||||
|
Check the input shape's is equal to the expected shape
|
||||||
|
:param input_shape(list): a list of the tensor shapes.
|
||||||
|
:param param_name(str): the name of the checked parameter.
|
||||||
|
:param func_name(str): the name of the function.
|
||||||
|
:param target_shape: the expected shape.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if not isinstance(target_shape[0], list):
|
||||||
|
target_shape = [target_shape]
|
||||||
|
if isinstance(input_shape, tuple):
|
||||||
|
input_shape = list(input_shape)
|
||||||
|
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name,
|
||||||
|
[len(item) for item in target_shape])
|
||||||
|
matched = False
|
||||||
|
for item in target_shape:
|
||||||
|
if item == input_shape:
|
||||||
|
matched = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not matched:
|
||||||
|
raise ValueError(f"{func_name} {param_name} shape should be one of {target_shape},"
|
||||||
|
f"but got {input_shape}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value):
|
||||||
|
if input_shape[dim] != target_value:
|
||||||
|
raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value},"
|
||||||
|
f"but got {input_shape[dim]}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
|
@ -104,28 +158,27 @@ def _check_past_none_input_none(use_past, param_name, func_name, input_tensor, d
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
|
||||||
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
|
|
||||||
if len(input_shape) != len(target_shape):
|
|
||||||
raise ValueError(f"{func_name} {param_name} shape should be {target_shape},"
|
|
||||||
f"but got {input_shape}")
|
|
||||||
for i in range(len(input_shape)):
|
|
||||||
if input_shape[i] != target_shape[i]:
|
|
||||||
raise ValueError(f"{func_name} {param_name} shape should be {target_shape},"
|
|
||||||
f"but got {input_shape}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||||
Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_input_shape(input_shape, param_name, func_name, target_len):
|
||||||
|
# check the input length
|
||||||
|
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len)
|
||||||
|
|
||||||
|
|
||||||
|
@constexpr
|
||||||
|
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
|
||||||
|
# check the input length
|
||||||
|
_LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
|
def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
|
||||||
if input_shape[dim] != target_value:
|
_LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)
|
||||||
raise ValueError(f"{cls_name} {param_name} at {dim} shape should be {target_value},"
|
|
||||||
f"but got {input_shape[dim]}")
|
|
||||||
|
|
||||||
|
|
||||||
class _LayerNorm(Cell):
|
class _LayerNorm(Cell):
|
||||||
|
@ -147,6 +200,11 @@ class _LayerNorm(Cell):
|
||||||
super(_LayerNorm, self).__init__()
|
super(_LayerNorm, self).__init__()
|
||||||
if param_init_type not in [mstype.float32, mstype.float16]:
|
if param_init_type not in [mstype.float32, mstype.float16]:
|
||||||
raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}")
|
raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}")
|
||||||
|
if normalized_shape[0] <= 1024:
|
||||||
|
self.layer_norm = P.LayerNorm(begin_norm_axis=-1,
|
||||||
|
begin_params_axis=-1,
|
||||||
|
epsilon=eps)
|
||||||
|
self.is_self_defined = normalized_shape[0] > 1024
|
||||||
self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
|
self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
|
||||||
parallel_optimizer=False)
|
parallel_optimizer=False)
|
||||||
self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
|
self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
|
||||||
|
@ -166,12 +224,15 @@ class _LayerNorm(Cell):
|
||||||
r"""
|
r"""
|
||||||
x : batch x seq_length x hidden_size
|
x : batch x seq_length x hidden_size
|
||||||
"""
|
"""
|
||||||
mean = self.mean(x, -1)
|
if self.is_self_defined:
|
||||||
diff = self.sub1(x, mean)
|
mean = self.mean(x, -1)
|
||||||
variance = self.mean(self.square(diff), -1)
|
diff = self.sub1(x, mean)
|
||||||
variance_eps = self.sqrt(self.add(variance, self.eps))
|
variance = self.mean(self.square(diff), -1)
|
||||||
output = self.real_div(diff, variance_eps)
|
variance_eps = self.sqrt(self.add(variance, self.eps))
|
||||||
output = self.add2(self.mul(output, self.gamma), self.beta)
|
output = self.real_div(diff, variance_eps)
|
||||||
|
output = self.add2(self.mul(output, self.gamma), self.beta)
|
||||||
|
else:
|
||||||
|
output, _, _ = self.layer_norm(x, self.gamma, self.beta)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def shard(self, strategy):
|
def shard(self, strategy):
|
||||||
|
@ -188,15 +249,18 @@ class _LayerNorm(Cell):
|
||||||
>>> net = mindspore.parallel.nn.transformer.LayerNorm(normalized_shape=(1024, 10))
|
>>> net = mindspore.parallel.nn.transformer.LayerNorm(normalized_shape=(1024, 10))
|
||||||
>>> net.shard(((10, 2, 1),))
|
>>> net.shard(((10, 2, 1),))
|
||||||
"""
|
"""
|
||||||
self.mean.shard(strategy)
|
if self.is_self_defined:
|
||||||
self.square.shard(strategy)
|
self.mean.shard(strategy)
|
||||||
self.sqrt.shard(strategy)
|
self.square.shard(strategy)
|
||||||
self.sub1.shard((strategy[0], strategy[0]))
|
self.sqrt.shard(strategy)
|
||||||
self.sub2.shard((strategy[0], strategy[0]))
|
self.sub1.shard((strategy[0], strategy[0]))
|
||||||
self.add.shard((strategy[0], ()))
|
self.sub2.shard((strategy[0], strategy[0]))
|
||||||
self.mul.shard((strategy[0], (1,)))
|
self.add.shard((strategy[0], ()))
|
||||||
self.add2.shard((strategy[0], (1,)))
|
self.mul.shard((strategy[0], (1,)))
|
||||||
self.real_div.shard((strategy[0], strategy[0]))
|
self.add2.shard((strategy[0], (1,)))
|
||||||
|
self.real_div.shard((strategy[0], strategy[0]))
|
||||||
|
else:
|
||||||
|
self.layer_norm.shard((strategy[0], (1,), (1,)))
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ class MoE(Cell):
|
||||||
|
|
||||||
|
|
||||||
def construct(self, input_tensor):
|
def construct(self, input_tensor):
|
||||||
bs = self.shape(input_tensor)[0]
|
input_shape = F.shape(input_tensor)
|
||||||
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
|
input_tensor = self.reshape(input_tensor, (-1, self.hidden_size))
|
||||||
bs_and_dmodel = self.shape(input_tensor)
|
bs_and_dmodel = self.shape(input_tensor)
|
||||||
tokens_per_device = bs_and_dmodel[0] / self.expert_parallel
|
tokens_per_device = bs_and_dmodel[0] / self.expert_parallel
|
||||||
|
@ -148,7 +148,7 @@ class MoE(Cell):
|
||||||
expert_capacity))
|
expert_capacity))
|
||||||
# expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
|
# expert_input's shape: (self.expert_dim, self.expert_parallel, expert_capacity, self.hidden_size)
|
||||||
expert_input = self.transpose2(expert_input, (2, 0, 3, 1))
|
expert_input = self.transpose2(expert_input, (2, 0, 3, 1))
|
||||||
expert_input = self.reshape(expert_input, (self.expert_dim, self.expert_parallel * expert_capacity,
|
expert_input = self.reshape(expert_input, (self.expert_dim * self.expert_parallel * expert_capacity,
|
||||||
self.hidden_size))
|
self.hidden_size))
|
||||||
|
|
||||||
# expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size)
|
# expert_output's shape: (self.expert_dim, self.expert_parallel*expert_capacity, self.hidden_size)
|
||||||
|
@ -170,7 +170,7 @@ class MoE(Cell):
|
||||||
# combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
|
# combined_output's shape: (self.expert_parallel, tokens_per_device, self.hidden_size)
|
||||||
combined_output = self.transpose5(combined_output, (0, 2, 1))
|
combined_output = self.transpose5(combined_output, (0, 2, 1))
|
||||||
combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
|
combined_output = self.reshape(combined_output, (bs_and_dmodel[0], bs_and_dmodel[1]))
|
||||||
combined_output = self.reshape(combined_output, (bs, -1, self.hidden_size))
|
combined_output = self.reshape(combined_output, input_shape)
|
||||||
|
|
||||||
aux_loss = self.mul(self.aux_loss_factor, aux_loss)
|
aux_loss = self.mul(self.aux_loss_factor, aux_loss)
|
||||||
return combined_output, aux_loss
|
return combined_output, aux_loss
|
||||||
|
|
|
@ -281,10 +281,12 @@ class FeedForward(Cell):
|
||||||
default args.
|
default args.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **x** (Tensor) - should be `[batch, seq_length, hidden_size]`. Float tensor.
|
- **x** (Tensor) - should be `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
|
||||||
|
Float tensor.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tensor, the output of this layer after mapping. The shape is `[batch, seq_length, hidden_size]`.
|
Tensor, the output of this layer after mapping.
|
||||||
|
The shape is `[batch, seq_length, hidden_size] or [batch * seq_length, hidden_size]`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: `hidden_act` is not a string.
|
ValueError: `hidden_act` is not a string.
|
||||||
|
@ -344,11 +346,11 @@ class FeedForward(Cell):
|
||||||
if expert_num > 1:
|
if expert_num > 1:
|
||||||
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
|
self.mapping.shard(strategy_matmul=((ep, 1, 1), (ep, 1, mp)),
|
||||||
strategy_bias=((ep, 1, mp), (mp,)),
|
strategy_bias=((ep, 1, mp), (mp,)),
|
||||||
strategy_activation=((ep, 1, mp),))
|
strategy_activation=((ep, mp),))
|
||||||
else:
|
else:
|
||||||
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
|
self.mapping.shard(strategy_matmul=((dp, 1), (1, mp)),
|
||||||
strategy_bias=((dp, mp), (mp,)),
|
strategy_bias=((dp, mp), (mp,)),
|
||||||
strategy_activation=((dp, 1, mp),))
|
strategy_activation=((dp, mp),))
|
||||||
# Project back to hidden_size
|
# Project back to hidden_size
|
||||||
self.projection = _Linear(in_channels=output_size,
|
self.projection = _Linear(in_channels=output_size,
|
||||||
out_channels=input_size,
|
out_channels=input_size,
|
||||||
|
@ -363,17 +365,17 @@ class FeedForward(Cell):
|
||||||
strategy_bias=((dp, 1), (1,)))
|
strategy_bias=((dp, 1), (1,)))
|
||||||
self.projection.bias.parallel_optimizer = False
|
self.projection.bias.parallel_optimizer = False
|
||||||
self.dropout = nn.Dropout(1 - dropout_rate)
|
self.dropout = nn.Dropout(1 - dropout_rate)
|
||||||
self.dropout.dropout.shard(((dp, 1, 1),))
|
self.dropout.dropout.shard(((dp, 1),))
|
||||||
self.cast = P.Cast()
|
self.cast = P.Cast()
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
_check_input_shape(F.shape(x), "x", self.cls_name, 3)
|
_check_input_shape(F.shape(x), "x", self.cls_name, [2, 3])
|
||||||
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
x = self.cast(x, mstype.float16)
|
x = self.cast(x, mstype.float16)
|
||||||
# returned shape is [bs, seq_length, ffn_hidden_size]
|
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
|
||||||
hidden = self.mapping(x)
|
hidden = self.mapping(x)
|
||||||
output = self.projection(hidden)
|
output = self.projection(hidden)
|
||||||
# returned shape is [bs, seq_length, hidden_size]
|
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
|
||||||
output = self.dropout(output)
|
output = self.dropout(output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@ -556,9 +558,12 @@ class MultiHeadAttention(Cell):
|
||||||
an instance of `OpParallelConfig` with default args.
|
an instance of `OpParallelConfig` with default args.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size).
|
- **query_tensor** (Tensor) - the query vector with shape (batch_size, src_seq_length, hidden_size) or
|
||||||
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size).
|
(batch_size * src_seq_length, hidden_size).
|
||||||
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size).
|
- **key_tensor** (Tensor) - the key vector with shape (batch_size, tgt_seq_length, hidden_size) or
|
||||||
|
(batch_size * src_seq_length, hidden_size).
|
||||||
|
- **value_tensor** (Tensor) - the value vector with shape (batch_size, tgt_seq_length, hidden_size) or
|
||||||
|
(batch_size * src_seq_length, hidden_size).
|
||||||
- **attention_mask** (Tensor) - the attention mask matrix with shape (batch_size, src_seq_length,
|
- **attention_mask** (Tensor) - the attention mask matrix with shape (batch_size, src_seq_length,
|
||||||
tgt_seq_length).
|
tgt_seq_length).
|
||||||
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
|
- **key_past** (Tensor) - Float16 tensor with shape (batch_size, num_heads, size_per_head, tgt_seq_length).
|
||||||
|
@ -574,7 +579,7 @@ class MultiHeadAttention(Cell):
|
||||||
Tuple, a tuple contains(`output`, `layer_present`)
|
Tuple, a tuple contains(`output`, `layer_present`)
|
||||||
|
|
||||||
- **output** (Tensor) - Tensor, the float tensor of the output of the layer with
|
- **output** (Tensor) - Tensor, the float tensor of the output of the layer with
|
||||||
shape (batch_size, src_seq_length, hidden_size)
|
shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size)
|
||||||
|
|
||||||
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
|
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
|
||||||
((batch_size, num_heads, size_per_head, tgt_seq_length),
|
((batch_size, num_heads, size_per_head, tgt_seq_length),
|
||||||
|
@ -683,11 +688,11 @@ class MultiHeadAttention(Cell):
|
||||||
self.scale_factor = Tensor(math.sqrt(self.size_per_head))
|
self.scale_factor = Tensor(math.sqrt(self.size_per_head))
|
||||||
self.use_past = use_past
|
self.use_past = use_past
|
||||||
self.dropout = nn.Dropout(1 - hidden_dropout_rate)
|
self.dropout = nn.Dropout(1 - hidden_dropout_rate)
|
||||||
self.dropout.dropout.shard(((parallel_config.data_parallel, 1, 1),))
|
self.dropout.dropout.shard(((parallel_config.data_parallel, 1),))
|
||||||
self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
|
self.prob_dropout = nn.Dropout(1 - attention_dropout_rate)
|
||||||
self.prob_dropout.dropout.shard(
|
self.prob_dropout.dropout.shard(
|
||||||
((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
|
((parallel_config.data_parallel, parallel_config.model_parallel, 1, 1),))
|
||||||
self.softmax = nn.Softmax()
|
self.softmax = nn.Softmax().to_float(softmax_compute_type)
|
||||||
self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
|
self.softmax.softmax.shard(((parallel_config.data_parallel, parallel_config.model_parallel, 1),))
|
||||||
self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),))
|
self.expand_dims = P.ExpandDims().shard(((parallel_config.data_parallel, 1, 1),))
|
||||||
|
|
||||||
|
@ -737,14 +742,11 @@ class MultiHeadAttention(Cell):
|
||||||
value_past=None, batch_valid_length=None):
|
value_past=None, batch_valid_length=None):
|
||||||
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
|
self._check_inputs(query_tensor, key_tensor, value_tensor, attention_mask, key_past,
|
||||||
value_past, batch_valid_length)
|
value_past, batch_valid_length)
|
||||||
query_tensor_original_shape = F.shape(query_tensor)
|
batch_size = F.shape(attention_mask)[0]
|
||||||
query_tensor = F.reshape(query_tensor, (-1, query_tensor_original_shape[-1]))
|
query_tensor, key_tensor, value_tensor, batch_size, ori_shape = self._convert_to_2d_tensor(query_tensor,
|
||||||
|
key_tensor,
|
||||||
key_tensor_original_shape = F.shape(key_tensor)
|
value_tensor,
|
||||||
key_tensor = F.reshape(key_tensor, (-1, key_tensor_original_shape[-1]))
|
attention_mask)
|
||||||
|
|
||||||
value_tensor_original_shape = F.shape(value_tensor)
|
|
||||||
value_tensor = F.reshape(value_tensor, (-1, value_tensor_original_shape[-1]))
|
|
||||||
|
|
||||||
# multi head attention: query, key, value are derived from the same inputs
|
# multi head attention: query, key, value are derived from the same inputs
|
||||||
query = self.dense1(query_tensor)
|
query = self.dense1(query_tensor)
|
||||||
|
@ -754,18 +756,18 @@ class MultiHeadAttention(Cell):
|
||||||
query = self.transpose(
|
query = self.transpose(
|
||||||
F.reshape(
|
F.reshape(
|
||||||
query,
|
query,
|
||||||
(-1, query_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
(batch_size, -1, self.n_head, self.size_per_head)),
|
||||||
(0, 2, 1, 3))
|
(0, 2, 1, 3))
|
||||||
# the returned shape is [bs, num_heads, size_per_head, seq_length]
|
# the returned shape is [bs, size_per_head, seq_length, num_heads]
|
||||||
key = self.transpose(
|
key = self.transpose(
|
||||||
F.reshape(
|
F.reshape(
|
||||||
key, (-1, key_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
key, (batch_size, -1, self.n_head, self.size_per_head)),
|
||||||
(0, 2, 3, 1))
|
(0, 2, 3, 1))
|
||||||
# the returned shape is [bs, num_heads, seq_length, size_per_head]
|
# the returned shape is [bs, num_heads, seq_length, size_per_head]
|
||||||
value = self.transpose(
|
value = self.transpose(
|
||||||
F.reshape(
|
F.reshape(
|
||||||
value,
|
value,
|
||||||
(-1, value_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
(batch_size, -1, self.n_head, self.size_per_head)),
|
||||||
(0, 2, 1, 3))
|
(0, 2, 1, 3))
|
||||||
# support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
|
# support input shape is [bs, seq, seq] or [bs, heads, seq, seq]
|
||||||
if len(F.shape(attention_mask)) == 3:
|
if len(F.shape(attention_mask)) == 3:
|
||||||
|
@ -810,22 +812,26 @@ class MultiHeadAttention(Cell):
|
||||||
|
|
||||||
layer_present = (key_present, value_present)
|
layer_present = (key_present, value_present)
|
||||||
# multi head attention considering attention mask
|
# multi head attention considering attention mask
|
||||||
# the return shape is [bs, seq_length, hidden_size]
|
# the return shape is [bs * seq_length, hidden_size]
|
||||||
attention = self._attn(query, key, value, attention_mask)
|
attention = self._attn(query, key, value, attention_mask)
|
||||||
# Output
|
# Output
|
||||||
output = self.projection(attention)
|
output = self.projection(attention)
|
||||||
output = self.dropout(output)
|
output = self.dropout(output)
|
||||||
|
output = F.reshape(output, ori_shape)
|
||||||
return output, layer_present
|
return output, layer_present
|
||||||
|
|
||||||
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
|
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
|
||||||
value_past=None, batch_valid_length=None):
|
value_past=None, batch_valid_length=None):
|
||||||
r"""Check inputs"""
|
r"""Check inputs"""
|
||||||
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
|
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
|
||||||
[self.batch_size, self.src_seq_length, self.hidden_size])
|
[[self.batch_size, self.src_seq_length, self.hidden_size],
|
||||||
|
[self.batch_size * self.src_seq_length, self.hidden_size]])
|
||||||
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
|
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
|
||||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||||
|
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||||
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
|
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
|
||||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||||
|
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||||
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
|
[self.batch_size, self.src_seq_length, self.tgt_seq_length])
|
||||||
|
|
||||||
|
@ -839,20 +845,30 @@ class MultiHeadAttention(Cell):
|
||||||
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
|
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, batch_valid_length)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _convert_to_2d_tensor(self, query_tensor, key_tensor, value_tensor, attention_mask):
|
||||||
|
"""convert a nd tensor to a 2d tensor"""
|
||||||
|
query_shape = F.shape(query_tensor)
|
||||||
|
query_tensor = F.reshape(query_tensor, (-1, query_shape[-1]))
|
||||||
|
key_shape = F.shape(key_tensor)
|
||||||
|
key_tensor = F.reshape(key_tensor, (-1, key_shape[-1]))
|
||||||
|
value_shape = F.shape(value_tensor)
|
||||||
|
value_tensor = F.reshape(value_tensor, (-1, value_shape[-1]))
|
||||||
|
return query_tensor, key_tensor, value_tensor, F.shape(attention_mask)[0], query_shape
|
||||||
|
|
||||||
def _merge_heads(self, x):
|
def _merge_heads(self, x):
|
||||||
"""
|
"""
|
||||||
convert a 4d input to a 3d output
|
convert a 4d input to a 2d output
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
x: input tensor
|
x: input tensor
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
x_merge: the 3d output
|
x_merge: the 2d output
|
||||||
"""
|
"""
|
||||||
x = self.merger_head_transpose(
|
x = self.merger_head_transpose(
|
||||||
x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head
|
x, (0, 2, 1, 3)) # bs, seq_length, head, size_per_head
|
||||||
x_shape = P.Shape()(x)
|
x_shape = P.Shape()(x)
|
||||||
new_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
|
new_shape = (-1, x_shape[-2] * x_shape[-1])
|
||||||
x_merge = self.reshape(x, new_shape)
|
x_merge = self.reshape(x, new_shape)
|
||||||
return x_merge
|
return x_merge
|
||||||
|
|
||||||
|
@ -947,7 +963,8 @@ class TransformerEncoderLayer(Cell):
|
||||||
an instance of `OpParallelConfig` with default args.
|
an instance of `OpParallelConfig` with default args.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size].
|
- **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or
|
||||||
|
[batch_size * seq_length, hidden_size].
|
||||||
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length].
|
- **input_mask** (Tensor) - Float Tensor, attention mask with shape [batch_size, seq_length, seq_length].
|
||||||
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
||||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True.
|
||||||
|
@ -958,7 +975,7 @@ class TransformerEncoderLayer(Cell):
|
||||||
Tuple, a tuple contains(`output`, `layer_present`).
|
Tuple, a tuple contains(`output`, `layer_present`).
|
||||||
|
|
||||||
- **output** (Tensor) - The float tensor of the output of the layer with
|
- **output** (Tensor) - The float tensor of the output of the layer with
|
||||||
shape (batch_size, seq_length, hidden_size).
|
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size).
|
||||||
|
|
||||||
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
|
- **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with
|
||||||
((batch_size, num_heads, size_per_head, seq_length),
|
((batch_size, num_heads, size_per_head, seq_length),
|
||||||
|
@ -1034,9 +1051,9 @@ class TransformerEncoderLayer(Cell):
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||||
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),))
|
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
|
||||||
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||||
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),))
|
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
|
||||||
|
|
||||||
self.attention = MultiHeadAttention(batch_size=batch_size,
|
self.attention = MultiHeadAttention(batch_size=batch_size,
|
||||||
src_seq_length=seq_length,
|
src_seq_length=seq_length,
|
||||||
|
@ -1067,7 +1084,8 @@ class TransformerEncoderLayer(Cell):
|
||||||
hidden_act=hidden_act,
|
hidden_act=hidden_act,
|
||||||
parallel_config=parallel_config)
|
parallel_config=parallel_config)
|
||||||
self.post_layernorm_residual = post_layernorm_residual
|
self.post_layernorm_residual = post_layernorm_residual
|
||||||
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
|
||||||
|
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
||||||
self.dtype = mstype.float16
|
self.dtype = mstype.float16
|
||||||
self.key_past = None
|
self.key_past = None
|
||||||
self.value_past = None
|
self.value_past = None
|
||||||
|
@ -1089,6 +1107,8 @@ class TransformerEncoderLayer(Cell):
|
||||||
|
|
||||||
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
|
def construct(self, x, input_mask, init_reset=True, batch_valid_length=None):
|
||||||
self._check_input(x, input_mask, init_reset, batch_valid_length)
|
self._check_input(x, input_mask, init_reset, batch_valid_length)
|
||||||
|
x_shape = F.shape(x)
|
||||||
|
x = F.reshape(x, (-1, x_shape[-1]))
|
||||||
input_x = self.layernorm1(x)
|
input_x = self.layernorm1(x)
|
||||||
input_x = F.cast(input_x, self.dtype)
|
input_x = F.cast(input_x, self.dtype)
|
||||||
|
|
||||||
|
@ -1137,10 +1157,23 @@ class TransformerEncoderLayer(Cell):
|
||||||
mlp_logit = F.depend(mlp_logit, value_update)
|
mlp_logit = F.depend(mlp_logit, value_update)
|
||||||
mlp_logit = F.depend(mlp_logit, key_update)
|
mlp_logit = F.depend(mlp_logit, key_update)
|
||||||
|
|
||||||
if self.post_layernorm_residual:
|
# if shape is 3d, we reshape the inputs of the add
|
||||||
output = self.add(output_x, mlp_logit)
|
if len(x_shape) == 3:
|
||||||
|
output_x = P.Reshape()(output_x, x_shape)
|
||||||
|
mlp_logit = P.Reshape()(mlp_logit, x_shape)
|
||||||
|
x = P.Reshape()(x, x_shape)
|
||||||
|
|
||||||
|
if self.post_layernorm_residual:
|
||||||
|
output = self.add_3d(output_x, mlp_logit)
|
||||||
|
else:
|
||||||
|
output = self.add_3d(x, mlp_logit)
|
||||||
else:
|
else:
|
||||||
output = self.add(x, mlp_logit)
|
if self.post_layernorm_residual:
|
||||||
|
output = self.add(output_x, mlp_logit)
|
||||||
|
else:
|
||||||
|
output = self.add(x, mlp_logit)
|
||||||
|
output = F.reshape(output, x_shape)
|
||||||
|
|
||||||
if self.use_moe is True:
|
if self.use_moe is True:
|
||||||
return output, layer_present, aux_loss
|
return output, layer_present, aux_loss
|
||||||
return output, layer_present
|
return output, layer_present
|
||||||
|
@ -1148,7 +1181,8 @@ class TransformerEncoderLayer(Cell):
|
||||||
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
|
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
|
||||||
r"""Check inputs"""
|
r"""Check inputs"""
|
||||||
_check_shape_equal(F.shape(x), "x", self.cls_name,
|
_check_shape_equal(F.shape(x), "x", self.cls_name,
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
[[self.batch_size, self.seq_length, self.hidden_size],
|
||||||
|
[self.batch_size * self.seq_length, self.hidden_size]])
|
||||||
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
|
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
|
||||||
[self.batch_size, self.seq_length, self.seq_length])
|
[self.batch_size, self.seq_length, self.seq_length])
|
||||||
|
@ -1193,10 +1227,12 @@ class TransformerDecoderLayer(Cell):
|
||||||
an instance of `OpParallelConfig` with default args.
|
an instance of `OpParallelConfig` with default args.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size].
|
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, tgt_seq_length, hidden_size] or
|
||||||
|
[batch_size * tgt_seq_length, hidden_size].
|
||||||
- **decoder_mask** (Tensor) - the attention mask for decoder with shape [batch_size, src_seq_length,
|
- **decoder_mask** (Tensor) - the attention mask for decoder with shape [batch_size, src_seq_length,
|
||||||
seq_length].
|
seq_length].
|
||||||
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size].
|
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
|
||||||
|
[batch_size * seq_length, hidden_size].
|
||||||
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||||
src_seq_length], where tgt_seq_length is the length of the decoder.
|
src_seq_length], where tgt_seq_length is the length of the decoder.
|
||||||
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
||||||
|
@ -1207,7 +1243,8 @@ class TransformerDecoderLayer(Cell):
|
||||||
Outputs:
|
Outputs:
|
||||||
Tuple, a tuple contains(`output`, `layer_present`)
|
Tuple, a tuple contains(`output`, `layer_present`)
|
||||||
|
|
||||||
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size]
|
- **output** (Tensor) - the output logit of this layer. The shape is [batch, seq_length, hidden_size] or
|
||||||
|
[batch * seq_length, hidden_size].
|
||||||
- **layer_present** (Tensor) - A tuple, where each tuple is the tensor of the projected key and value
|
- **layer_present** (Tensor) - A tuple, where each tuple is the tensor of the projected key and value
|
||||||
vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
|
vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
|
||||||
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
|
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
|
||||||
|
@ -1295,10 +1332,10 @@ class TransformerDecoderLayer(Cell):
|
||||||
self.use_past = use_past
|
self.use_past = use_past
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
self.layernorm1 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type)
|
self.layernorm1 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||||
self.layernorm1.shard(((parallel_config.data_parallel, 1, 1),))
|
self.layernorm1.shard(((parallel_config.data_parallel, 1),))
|
||||||
self.layernorm2 = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(layernorm_compute_type)
|
self.layernorm2 = _LayerNorm((hidden_size,)).to_float(layernorm_compute_type)
|
||||||
self.layernorm2.shard(((parallel_config.data_parallel, 1, 1),))
|
self.layernorm2.shard(((parallel_config.data_parallel, 1),))
|
||||||
|
|
||||||
self.attention = MultiHeadAttention(hidden_size=hidden_size,
|
self.attention = MultiHeadAttention(hidden_size=hidden_size,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
|
@ -1323,9 +1360,9 @@ class TransformerDecoderLayer(Cell):
|
||||||
use_past=use_past,
|
use_past=use_past,
|
||||||
param_init_type=param_init_type,
|
param_init_type=param_init_type,
|
||||||
parallel_config=parallel_config)
|
parallel_config=parallel_config)
|
||||||
self.cross_attention_layernorm = _LayerNorm((hidden_size,), parallel_config.data_parallel).to_float(
|
self.cross_attention_layernorm = _LayerNorm((hidden_size,)).to_float(
|
||||||
layernorm_compute_type)
|
layernorm_compute_type)
|
||||||
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1, 1),))
|
self.cross_attention_layernorm.shard(((parallel_config.data_parallel, 1),))
|
||||||
self.use_moe = (moe_config.expert_num > 1)
|
self.use_moe = (moe_config.expert_num > 1)
|
||||||
if self.use_moe is True:
|
if self.use_moe is True:
|
||||||
self.output = MoE(hidden_size=hidden_size,
|
self.output = MoE(hidden_size=hidden_size,
|
||||||
|
@ -1344,7 +1381,8 @@ class TransformerDecoderLayer(Cell):
|
||||||
param_init_type=param_init_type,
|
param_init_type=param_init_type,
|
||||||
parallel_config=parallel_config)
|
parallel_config=parallel_config)
|
||||||
self.post_layernorm_residual = post_layernorm_residual
|
self.post_layernorm_residual = post_layernorm_residual
|
||||||
self.add = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
self.add = P.Add().shard(((parallel_config.data_parallel, 1), (parallel_config.data_parallel, 1)))
|
||||||
|
self.add_3d = P.Add().shard(((parallel_config.data_parallel, 1, 1), (parallel_config.data_parallel, 1, 1)))
|
||||||
self.dtype = mstype.float16
|
self.dtype = mstype.float16
|
||||||
self.key_past = None
|
self.key_past = None
|
||||||
self.value_past = None
|
self.value_past = None
|
||||||
|
@ -1369,7 +1407,9 @@ class TransformerDecoderLayer(Cell):
|
||||||
memory_mask=None,
|
memory_mask=None,
|
||||||
init_reset=True, batch_valid_length=None):
|
init_reset=True, batch_valid_length=None):
|
||||||
self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
|
self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
|
||||||
# the returned shape is [bs, seq_length, embedding_size]
|
# the returned shape is [bs, seq_length, embedding_size] or [bs * seq_length, embedding_size]
|
||||||
|
hidden_shape = F.shape(hidden_stats)
|
||||||
|
hidden_stats = F.reshape(hidden_stats, (-1, hidden_shape[-1]))
|
||||||
input_x = self.layernorm1(hidden_stats)
|
input_x = self.layernorm1(hidden_stats)
|
||||||
input_x = F.cast(input_x, self.dtype)
|
input_x = F.cast(input_x, self.dtype)
|
||||||
|
|
||||||
|
@ -1431,10 +1471,23 @@ class TransformerDecoderLayer(Cell):
|
||||||
mlp_logit = F.depend(mlp_logit, value_update)
|
mlp_logit = F.depend(mlp_logit, value_update)
|
||||||
mlp_logit = F.depend(mlp_logit, key_update)
|
mlp_logit = F.depend(mlp_logit, key_update)
|
||||||
|
|
||||||
if self.post_layernorm_residual:
|
# if shape is 3d, we reshape the inputs of the add
|
||||||
output = self.add(output_x, mlp_logit)
|
if len(hidden_shape) == 3:
|
||||||
|
output_x = P.Reshape()(output_x, hidden_shape)
|
||||||
|
mlp_logit = P.Reshape()(mlp_logit, hidden_shape)
|
||||||
|
x = P.Reshape()(x, hidden_shape)
|
||||||
|
|
||||||
|
if self.post_layernorm_residual:
|
||||||
|
output = self.add_3d(output_x, mlp_logit)
|
||||||
|
else:
|
||||||
|
output = self.add_3d(x, mlp_logit)
|
||||||
else:
|
else:
|
||||||
output = self.add(x, mlp_logit)
|
if self.post_layernorm_residual:
|
||||||
|
output = self.add(output_x, mlp_logit)
|
||||||
|
else:
|
||||||
|
output = self.add(x, mlp_logit)
|
||||||
|
output = F.reshape(output, hidden_shape)
|
||||||
|
|
||||||
if self.use_moe is True:
|
if self.use_moe is True:
|
||||||
return output, layer_present, aux_loss
|
return output, layer_present, aux_loss
|
||||||
return output, layer_present
|
return output, layer_present
|
||||||
|
@ -1442,14 +1495,16 @@ class TransformerDecoderLayer(Cell):
|
||||||
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
|
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
|
||||||
r"""Check inputs"""
|
r"""Check inputs"""
|
||||||
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
|
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
|
||||||
[self.batch_size, self.tgt_seq_length, self.hidden_size])
|
[[self.batch_size, self.tgt_seq_length, self.hidden_size],
|
||||||
|
[self.batch_size * self.tgt_seq_length, self.hidden_size]])
|
||||||
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
|
||||||
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
|
[self.batch_size, self.tgt_seq_length, self.tgt_seq_length])
|
||||||
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
|
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
|
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
|
||||||
if encoder_output is not None:
|
if encoder_output is not None:
|
||||||
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
|
_check_shape_equal(F.shape(encoder_output), "encoder_output", self.cls_name,
|
||||||
[self.batch_size, self.src_seq_length, self.hidden_size])
|
[[self.batch_size, self.src_seq_length, self.hidden_size],
|
||||||
|
[self.batch_size * self.src_seq_length, self.hidden_size]])
|
||||||
_check_input_dtype(F.dtype(encoder_output), "encoder_output",
|
_check_input_dtype(F.dtype(encoder_output), "encoder_output",
|
||||||
[mstype.float32, mstype.float16], self.cls_name)
|
[mstype.float32, mstype.float16], self.cls_name)
|
||||||
if memory_mask is not None:
|
if memory_mask is not None:
|
||||||
|
@ -1547,7 +1602,8 @@ class TransformerEncoder(Cell):
|
||||||
an instance of `TransformerOpParallelConfig` with default args.
|
an instance of `TransformerOpParallelConfig` with default args.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size]
|
- **hidden_states** (Tensor) - Tensor, shape should be [batch_size, seq_length, hidden_size] or
|
||||||
|
[batch_size * seq_length, hidden_size]
|
||||||
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
|
- **attention_mask** (Tensor) - Tensor, attention mask with shape [batch_size, seq_length, seq_length]
|
||||||
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
- **init_reset** (Tensor) - A bool tensor with shape [batch_size], used to clear the past key parameter and
|
||||||
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
|
past value parameter used in the incremental prediction. Only valid when use_past is True. Default True
|
||||||
|
@ -1558,7 +1614,7 @@ class TransformerEncoder(Cell):
|
||||||
Tuple, a tuple contains(`output`, `layer_present`)
|
Tuple, a tuple contains(`output`, `layer_present`)
|
||||||
|
|
||||||
- **output** (Tensor) - The float tensor of the output of the layer with
|
- **output** (Tensor) - The float tensor of the output of the layer with
|
||||||
shape (batch_size, seq_length, hidden_size)
|
shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size)
|
||||||
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple contains the Tensor the
|
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple contains the Tensor the
|
||||||
projected key and value vector with shape ((batch_size, num_heads, size_per_head, seq_length),
|
projected key and value vector with shape ((batch_size, num_heads, size_per_head, seq_length),
|
||||||
and (batch_size, num_heads, seq_length, size_per_head)).
|
and (batch_size, num_heads, seq_length, size_per_head)).
|
||||||
|
@ -1717,9 +1773,11 @@ class TransformerDecoder(Cell):
|
||||||
an instance of `TransformerOpParallelConfig` with default args.
|
an instance of `TransformerOpParallelConfig` with default args.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size]
|
- **hidden_stats** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
|
||||||
|
[batch_size * seq_length, hidden_size]
|
||||||
- **attention_mask** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
|
- **attention_mask** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
|
||||||
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size]
|
- **encoder_output** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
|
||||||
|
[batch_size * seq_length, hidden_size]
|
||||||
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||||
src_seq_length] where tgt_seq_length is the length of the decoder. the output of the encoder with shape
|
src_seq_length] where tgt_seq_length is the length of the decoder. the output of the encoder with shape
|
||||||
[batch_size, seq_length, hidden_size],
|
[batch_size, seq_length, hidden_size],
|
||||||
|
@ -1731,7 +1789,8 @@ class TransformerDecoder(Cell):
|
||||||
Outputs:
|
Outputs:
|
||||||
Tuple, a tuple contains(`output`, `layer_present`)
|
Tuple, a tuple contains(`output`, `layer_present`)
|
||||||
|
|
||||||
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size]
|
- **output** (Tensor) - The output logit of this layer. The shape is [batch, tgt_seq_length, hidden_size] or
|
||||||
|
[batch * tgt_seq_length, hidden_size]
|
||||||
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the projected
|
- **layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor of the projected
|
||||||
key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
|
key and value vector in self attention with shape ((batch_size, num_heads, size_per_head, tgt_seq_length),
|
||||||
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
|
(batch_size, num_heads, tgt_seq_length, size_per_head), and of the projected key and value vector
|
||||||
|
@ -1912,9 +1971,11 @@ class Transformer(Cell):
|
||||||
an instance of `TransformerOpParallelConfig` with default args.
|
an instance of `TransformerOpParallelConfig` with default args.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size].
|
- **encoder_inputs** (Tensor) - the input tensor with shape [batch_size, seq_length, hidden_size] or
|
||||||
|
[batch_size * seq_length, hidden_size].
|
||||||
- **encoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length].
|
- **encoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length].
|
||||||
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size],
|
- **decoder_inputs** (Tensor) - the output of the encoder with shape [batch_size, seq_length, hidden_size] or
|
||||||
|
[batch_size * seq_length, hidden_size],
|
||||||
this should be none if the decoder layer is 0.
|
this should be none if the decoder layer is 0.
|
||||||
- **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
|
- **decoder_masks** (Tensor) - the attention mask for decoder with shape [batch_size, seq_length, seq_length]
|
||||||
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
- **memory_mask** (Tensor) - the memory mask of the cross attention with shape [batch, tgt_seq_length,
|
||||||
|
@ -1930,8 +1991,9 @@ class Transformer(Cell):
|
||||||
Tuple, a tuple contains(`output`, `encoder_layer_present`, `encoder_layer_present`)
|
Tuple, a tuple contains(`output`, `encoder_layer_present`, `encoder_layer_present`)
|
||||||
|
|
||||||
- **output** (Tensor) - If there is only encoder, the output logit of the encoder layer. The shape is
|
- **output** (Tensor) - If there is only encoder, the output logit of the encoder layer. The shape is
|
||||||
[batch, src_seq_length, hidden_size], if there are encoder and decoders, the output is from the
|
[batch, src_seq_length, hidden_size] or [batch * src_seq_length, hidden_size], if there are encoder and
|
||||||
decoder layer. The shape is [batch, tgt_seq_length, hidden_size].
|
decoders, the output is from the decoder layer. The shape is [batch, tgt_seq_length, hidden_size] or
|
||||||
|
[batch * tgt_seq_length, hidden_size].
|
||||||
- **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor the
|
- **encoder_layer_present** (Tuple) - A tuple with size of num_layers, where each tuple is the tensor the
|
||||||
projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
|
projected key and value vector in self attention with shape ((batch_size, num_heads, size_per_head,
|
||||||
src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)).
|
src_seq_length), (batch_size, num_heads, src_seq_length, size_per_head)).
|
||||||
|
|
|
@ -20,7 +20,7 @@ from mindspore.context import set_auto_parallel_context, ParallelMode
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
|
from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
|
||||||
from mindspore.nn.optim import AdamWeightDecay
|
from mindspore.nn.optim import AdamWeightDecay
|
||||||
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell
|
from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from tests.dataset_mock import MindData
|
from tests.dataset_mock import MindData
|
||||||
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
from tests.ut.python.ops.test_math_ops import VirtualLoss
|
||||||
|
@ -77,7 +77,7 @@ def test_transformer_model():
|
||||||
ffn_hidden_size=64,
|
ffn_hidden_size=64,
|
||||||
moe_config=moe_config,
|
moe_config=moe_config,
|
||||||
parallel_config=config)
|
parallel_config=config)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
|
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
|
||||||
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
|
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
|
||||||
|
@ -92,3 +92,35 @@ def test_transformer_model():
|
||||||
model = Model(net_with_grad)
|
model = Model(net_with_grad)
|
||||||
|
|
||||||
model.train(1, dataset, dataset_sink_mode=False)
|
model.train(1, dataset, dataset_sink_mode=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer_model_2d():
|
||||||
|
set_auto_parallel_context(device_num=16, global_rank=0,
|
||||||
|
full_batch=True, enable_alltoall=True,
|
||||||
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||||
|
net = Transformer(encoder_layers=1,
|
||||||
|
decoder_layers=1,
|
||||||
|
batch_size=2,
|
||||||
|
src_seq_length=20,
|
||||||
|
tgt_seq_length=10,
|
||||||
|
hidden_size=64,
|
||||||
|
num_heads=8,
|
||||||
|
ffn_hidden_size=64,
|
||||||
|
moe_config=moe_config,
|
||||||
|
parallel_config=config)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
|
|
||||||
|
encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
|
||||||
|
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||||
|
decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
|
||||||
|
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||||
|
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||||
|
net = NetWithLossFiveInputs(net)
|
||||||
|
params = net.trainable_params()
|
||||||
|
optimizer = AdamWeightDecay(params)
|
||||||
|
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||||
|
memory_mask)
|
||||||
|
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
|
||||||
|
model = Model(net_with_grad)
|
||||||
|
|
||||||
|
model.train(1, dataset, dataset_sink_mode=False)
|
||||||
|
|
|
@ -56,6 +56,28 @@ class Dataset(MindData):
|
||||||
self.index = 0
|
self.index = 0
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerNet(nn.Cell):
|
||||||
|
def __init__(self, en_layer, de_layer, parallel_config):
|
||||||
|
super(TransformerNet, self).__init__()
|
||||||
|
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
|
||||||
|
parallel_config=config.embedding_dp_mp_config)
|
||||||
|
self.network = Transformer(encoder_layers=en_layer,
|
||||||
|
decoder_layers=de_layer,
|
||||||
|
batch_size=2,
|
||||||
|
src_seq_length=20,
|
||||||
|
tgt_seq_length=10,
|
||||||
|
hidden_size=64,
|
||||||
|
num_heads=8,
|
||||||
|
ffn_hidden_size=64,
|
||||||
|
parallel_config=parallel_config)
|
||||||
|
self.head = Linear(in_channels=64, out_channels=200)
|
||||||
|
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
|
||||||
|
|
||||||
|
def construct(self, x1, x2, x3, x4, x5, y, mask):
|
||||||
|
predict, _, _ = self.network(x1, x2, x3, x4, x5)
|
||||||
|
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
|
||||||
|
return self.loss(predict, y, mask)
|
||||||
|
|
||||||
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
|
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
|
||||||
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
|
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
|
||||||
micro_batch_num=4, vocab_emb_dp=False)
|
micro_batch_num=4, vocab_emb_dp=False)
|
||||||
|
@ -84,28 +106,6 @@ def run_total_transformer_model_head(e_layer,
|
||||||
full_batch=True,
|
full_batch=True,
|
||||||
global_rank=0, parallel_mode=mode)
|
global_rank=0, parallel_mode=mode)
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
def __init__(self, en_layer, de_layer, parallel_config):
|
|
||||||
super(Net, self).__init__()
|
|
||||||
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
|
|
||||||
parallel_config=config.embedding_dp_mp_config)
|
|
||||||
self.network = Transformer(encoder_layers=en_layer,
|
|
||||||
decoder_layers=de_layer,
|
|
||||||
batch_size=2,
|
|
||||||
src_seq_length=20,
|
|
||||||
tgt_seq_length=10,
|
|
||||||
hidden_size=64,
|
|
||||||
num_heads=8,
|
|
||||||
ffn_hidden_size=64,
|
|
||||||
parallel_config=parallel_config)
|
|
||||||
self.head = Linear(in_channels=64, out_channels=200)
|
|
||||||
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
|
|
||||||
|
|
||||||
def construct(self, x1, x2, x3, x4, x5, y, mask):
|
|
||||||
predict, _, _ = self.network(x1, x2, x3, x4, x5)
|
|
||||||
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
|
|
||||||
return self.loss(predict, y, mask)
|
|
||||||
|
|
||||||
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
|
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
|
||||||
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||||
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
|
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
|
||||||
|
@ -116,7 +116,8 @@ def run_total_transformer_model_head(e_layer,
|
||||||
seq = 10
|
seq = 10
|
||||||
label = Tensor(np.ones((2 * seq,)), mstype.int32)
|
label = Tensor(np.ones((2 * seq,)), mstype.int32)
|
||||||
input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
|
input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
|
||||||
net = Net(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
|
net = TransformerNet(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
params = net.trainable_params()
|
params = net.trainable_params()
|
||||||
optimizer = AdamWeightDecay(params)
|
optimizer = AdamWeightDecay(params)
|
||||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||||
|
@ -147,6 +148,38 @@ def test_transformer_model():
|
||||||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||||
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||||
net = NetWithLossFiveInputs(net)
|
net = NetWithLossFiveInputs(net)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
|
params = net.trainable_params()
|
||||||
|
optimizer = AdamWeightDecay(params)
|
||||||
|
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||||
|
memory_mask)
|
||||||
|
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
|
||||||
|
model = Model(net_with_grad)
|
||||||
|
|
||||||
|
model.train(1, dataset, dataset_sink_mode=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transformer_model_2d_inputs():
|
||||||
|
set_auto_parallel_context(device_num=8, global_rank=0,
|
||||||
|
full_batch=True,
|
||||||
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
|
||||||
|
net = Transformer(encoder_layers=1,
|
||||||
|
decoder_layers=2,
|
||||||
|
batch_size=2,
|
||||||
|
src_seq_length=20,
|
||||||
|
tgt_seq_length=10,
|
||||||
|
hidden_size=64,
|
||||||
|
num_heads=8,
|
||||||
|
ffn_hidden_size=64,
|
||||||
|
parallel_config=config)
|
||||||
|
|
||||||
|
encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
|
||||||
|
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
|
||||||
|
decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
|
||||||
|
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||||
|
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||||
|
net = NetWithLossFiveInputs(net)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
params = net.trainable_params()
|
params = net.trainable_params()
|
||||||
optimizer = AdamWeightDecay(params)
|
optimizer = AdamWeightDecay(params)
|
||||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||||
|
@ -177,6 +210,7 @@ def test_transformer_model_int64_inputs():
|
||||||
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
|
||||||
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
|
||||||
net = NetWithLossFiveInputs(net)
|
net = NetWithLossFiveInputs(net)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
params = net.trainable_params()
|
params = net.trainable_params()
|
||||||
optimizer = AdamWeightDecay(params)
|
optimizer = AdamWeightDecay(params)
|
||||||
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
|
||||||
|
@ -330,6 +364,8 @@ def test_encoder():
|
||||||
|
|
||||||
net = NetWithLoss(net)
|
net = NetWithLoss(net)
|
||||||
|
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
|
|
||||||
dataset = Dataset(encoder_input_value, encoder_input_mask)
|
dataset = Dataset(encoder_input_value, encoder_input_mask)
|
||||||
|
|
||||||
model = Model(net)
|
model = Model(net)
|
||||||
|
@ -367,6 +403,8 @@ def test_decoder():
|
||||||
|
|
||||||
net = NetWithLoss(net)
|
net = NetWithLoss(net)
|
||||||
|
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
|
|
||||||
dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
|
dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
|
||||||
|
|
||||||
model = Model(net)
|
model = Model(net)
|
||||||
|
@ -388,6 +426,7 @@ def test_vocabembedding_dp_true():
|
||||||
|
|
||||||
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
||||||
net = NetWithLoss(net)
|
net = NetWithLoss(net)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
|
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
|
||||||
dataset = Dataset(encoder_input_value)
|
dataset = Dataset(encoder_input_value)
|
||||||
|
|
||||||
|
@ -410,6 +449,7 @@ def test_vocabembedding_dp_false():
|
||||||
|
|
||||||
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
||||||
net = NetWithLoss(net)
|
net = NetWithLoss(net)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
|
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
|
||||||
dataset = Dataset(encoder_input_value)
|
dataset = Dataset(encoder_input_value)
|
||||||
|
|
||||||
|
@ -484,6 +524,7 @@ def test_sparse_attention_parallel_dp():
|
||||||
num_heads=8,
|
num_heads=8,
|
||||||
block_size=64,
|
block_size=64,
|
||||||
parallel_config=sparse_attention_config)
|
parallel_config=sparse_attention_config)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
||||||
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
||||||
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
|
||||||
|
@ -509,6 +550,7 @@ def test_parallel_cross_entroy_loss_semi_auto_parallel():
|
||||||
|
|
||||||
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
|
||||||
net = NetWithLoss(net, config.dp_mp_config)
|
net = NetWithLoss(net, config.dp_mp_config)
|
||||||
|
net = _VirtualDatasetCell(net)
|
||||||
embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
|
embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
|
||||||
labels = Tensor(np.ones((2 * 64,)), mstype.int32)
|
labels = Tensor(np.ones((2 * 64,)), mstype.int32)
|
||||||
input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
|
input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
|
||||||
|
|
Loading…
Reference in New Issue