!21809 Improved Transformer Structure and Add Args Check

Merge pull request !21809 from huangxinjing/transformer_improved
i-robot 2021-08-17 12:05:25 +00:00 committed by Gitee
commit 0d839fa7c6
8 changed files with 2008 additions and 605 deletions

View File

@@ -17,5 +17,10 @@ Parallel Networks.
This is an experimental interface that is subject to change and/or deletion.
"""
from .transformer import *
from .loss import *
from .config import *
__all__ = []
__all__.extend(transformer.__all__)
__all__.extend(loss.__all__)
__all__.extend(config.__all__)

View File

@@ -0,0 +1,162 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel Config for the Parallel Training
This is an experimental interface that is subject to change and/or deletion.
"""
from mindspore._checkparam import Validator
from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
__all__ = [
"OpParallelConfig"
]
class _Config:
r""" A basic class of the configure"""
def __str__(self):
info = "[ParallelConfig]" + '\n'
for k, v in self.__dict__.items():
var_info = "{}:{}\n".format(k, v)
info += var_info
return info
class OpParallelConfig(_Config):
r"""
OpParallelConfig for setting data parallelism and model parallelism.
Args:
data_parallel (int): The number of data parallel splits. Default: 1.
model_parallel (int): The number of model parallel splits. Default: 1.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> from mindspore.nn.parallel import OpParallelConfig
>>> config=OpParallelConfig(data_parallel=1, model_parallel=1)
"""
def __init__(self, data_parallel=1, model_parallel=1):
Validator.check_positive_int(data_parallel, "data_parallel")
Validator.check_positive_int(model_parallel, "model_parallel")
self._data_parallel = data_parallel
self._model_parallel = model_parallel
@property
def data_parallel(self):
return self._data_parallel
@data_parallel.setter
def data_parallel(self, value):
Validator.check_positive_int(value, "data_parallel")
self._data_parallel = value
@property
def model_parallel(self):
return self._model_parallel
@model_parallel.setter
def model_parallel(self, value):
Validator.check_positive_int(value, "model_parallel")
self._model_parallel = value
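As a quick reference, here is a minimal usage sketch of OpParallelConfig, mirroring the docstring example and the config unit tests further down; the concrete parallel sizes and the top-level import path are illustrative assumptions:

from mindspore.nn.parallel import OpParallelConfig

# Split the data across 2 devices and the model across 4 devices (illustrative sizes).
config = OpParallelConfig(data_parallel=2, model_parallel=4)
# The property setters re-run Validator.check_positive_int, so invalid values are rejected.
config.model_parallel = 8
try:
    config.data_parallel = 0   # not a positive int, raises ValueError
except ValueError:
    pass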
class _PipeLineConfig(_Config):
r"""
_PipeLineConfig for setting the pipeline stage number and the micro batch number.
Args:
pipeline_stage (int): The number of pipeline stages. Default: 1.
micro_batch_num (int): The number of micro batches used in pipeline parallelism. Default: 1.
Supported Platforms:
``Ascend`` ``GPU``
Examples:
>>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
"""
def __init__(self, pipeline_stage=1, micro_batch_num=1):
Validator.check_positive_int(pipeline_stage, "pipeline_stage")
Validator.check_positive_int(micro_batch_num, "micro_batch_num")
self._pipeline_stage = pipeline_stage
self._micro_batch_num = micro_batch_num
@property
def pipeline_stage(self):
return self._pipeline_stage
@pipeline_stage.setter
def pipeline_stage(self, value):
Validator.check_positive_int(value, "pipeline_stage")
self._pipeline_stage = value
context.set_auto_parallel_context(pipeline_stages=value)
@property
def micro_batch_num(self):
return self._micro_batch_num
@micro_batch_num.setter
def micro_batch_num(self, value):
Validator.check_positive_int(value, "micro_batch_num")
self._micro_batch_num = value
# In case the user doesn't pass a config as args.
default_dpmp_config = OpParallelConfig()
def _check_config(config):
"""
Check that the parallel config is consistent with the auto parallel context: the pipeline stage,
the device number, the optimizer shard setting and the constraint micro_batch_num >= pipeline_stage.
"""
# the pipeline_stage in the config must match context.pipeline_stages
pipeline_stage = context.get_auto_parallel_context("pipeline_stages")
if hasattr(config, 'pipeline_stage') and pipeline_stage != config.pipeline_stage:
raise ValueError(
f"The pipeline stage {pipeline_stage} in auto_parallel_context is not equal to the pipeline_stage "
f"{config.pipeline_stage}"
f" in the config.")
# make sure the following is in auto parallel mode
is_auto_parallel = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
if not is_auto_parallel:
return
device_num = D.get_group_size()
optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")
# dp * mp * pipeline_stage <= device_num
if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
raise ValueError(f"The product of the data parallel {config.data_parallel}, "
f"the model parallel {config.model_parallel} and "
f"the pipeline stages {pipeline_stage} "
f"should be no larger than the device_num {device_num}.")
# the optimizer_shard in the config must match context enable_parallel_optimizer
if hasattr(config, "optimizer_shard") and optimizer_shard != config.optimizer_shard:
raise ValueError(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the "
f"optimizer_shard {config.optimizer_shard} in the config.")
# pipeline_stage <= micro_batch_num
if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
and config.pipeline_stage > config.micro_batch_num:
raise ValueError(
f"The pipeline stage {config.pipeline_stage} should be less than or equal to the micro_batch_num "
f"{config.micro_batch_num}.")

View File

@@ -0,0 +1,131 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel Loss for the Parallel Training
This is an experimental interface that is subject to change and/or deletion.
"""
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn import Cell
from mindspore.nn.loss.loss import _check_is_tensor
from mindspore.nn.parallel.transformer.transformer import _check_input_dtype, _check_input_shape
from .config import default_dpmp_config, OpParallelConfig
__all__ = ["CrossEntropyLoss"]
class CrossEntropyLoss(Cell):
"""
Calculate the cross entropy loss.
Args:
parallel_config (OpParallelConfig): The parallel configuration. Default: `default_dpmp_config`.
Inputs:
- **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
the backbone.
- **labels** (Tensor) - Tensor of shape (N, ). The ground truth label of the sample.
- **input_mask** (Tensor) - Tensor of shape (N, ). Indicates which positions are padded; padded positions
are not counted in the loss.
Returns:
loss: Tensor, the corresponding cross entropy loss
Examples:
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> loss = nn.parallel.CrossEntropyLoss()
>>> logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), mindspore.float32)
>>> labels_np = np.array([1]).astype(np.int32)
>>> input_mask = Tensor(np.ones(1).astype(np.float32))
>>> labels = Tensor(labels_np)
>>> output = loss(logits, labels, input_mask)
>>> print(output.shape)
(1,)
"""
def __init__(self, parallel_config=default_dpmp_config):
super(CrossEntropyLoss, self).__init__()
if not isinstance(parallel_config, OpParallelConfig):
raise TypeError("Input args parallel_config must be the type OpParallelConfig.")
dp = parallel_config.data_parallel
mp = parallel_config.model_parallel
self.sum = P.ReduceSum().shard(((dp, mp),))
self.onehot = P.OneHot().shard(((dp, mp), (), ()))
# on/off value for onehot, for smooth labeling, modify the off_value
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(
((dp, mp),))
self.eps_const = Tensor(1e-24, mstype.float32)
self.sub = P.Sub().shard(((dp, mp), (dp, 1)))
self.exp = P.Exp().shard(((dp, mp),))
self.div = P.RealDiv().shard(((dp, mp), (dp, 1)))
self.log = P.Log().shard(((dp, mp),))
self.add = P.TensorAdd().shard(((dp, mp), ()))
self.mul = P.Mul().shard(
((dp, mp), (dp, mp)))
self.neg = P.Neg().shard(((dp, mp),))
self.sum2 = P.ReduceSum().shard(((1,),))
self.mul2 = P.Mul().shard(((1,), (1,)))
self.add2 = P.TensorAdd()
self.div2 = P.RealDiv()
def construct(self, logits, label, input_mask):
r"""
Compute loss using logits, label and input mask
"""
self._check_input(logits, label, input_mask)
# [bs*seq_length, vocab_size]
logits = F.cast(logits, mstype.float32)
# LogSoftmax for logits over last dimension
_, logit_max = self.max(logits)
logit_sub = self.sub(logits, logit_max)
logit_exp = self.exp(logit_sub)
exp_sum = self.sum(logit_exp, -1)
exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
softmax_result = self.div(logit_exp, exp_sum)
log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
# Flatten label to [bs*seq_length]
label = P.Reshape()(label, (-1,))
# Get onehot label [bs*seq_length, vocab_size]
one_hot_label = self.onehot(label, F.shape(logits)[-1], self.on_value,
self.off_value)
# Cross-Entropy loss
loss = self.mul(log_softmax_result, one_hot_label)
loss_unsum = self.neg(loss)
loss_reduce = self.sum(loss_unsum, -1)
# input_mask indicates which positions are padded; padded positions are not counted in the loss
input_mask = P.Reshape()(input_mask, (-1,))
numerator = self.sum2(self.mul2(loss_reduce, input_mask))
denominator = self.add2(
self.sum2(input_mask),
P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32))
loss = self.div2(numerator, denominator)
return loss
def _check_input(self, logits, label, input_mask):
r"""Check the input tensor shape and type"""
_check_is_tensor('logits', logits, self.cls_name)
_check_is_tensor('label', label, self.cls_name)
_check_is_tensor('input_mask', input_mask, self.cls_name)
_check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
_check_input_shape(F.shape(logits), "logits", self.cls_name, 2)
_check_input_shape(F.shape(label), "label", self.cls_name, 1)
_check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 1)
return True
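For orientation, a hedged end-to-end sketch of how the loss above is expected to be called, following the shapes in the docstring (logits (N, C), label (N,), input_mask (N,)); the sample data and the nn.parallel import path are assumptions taken from the docstring example:

import numpy as np
import mindspore
from mindspore import Tensor, nn

loss = nn.parallel.CrossEntropyLoss()                         # uses default_dpmp_config (dp=1, mp=1)
logits = Tensor(np.random.randn(4, 10), mindspore.float32)    # N=4 samples, C=10 classes
labels = Tensor(np.array([1, 0, 3, 9], np.int32))             # ground-truth class ids
input_mask = Tensor(np.array([1, 1, 1, 0], np.float32))       # the last sample is padding and is excluded
output = loss(logits, labels, input_mask)                     # loss averaged over the non-padded positions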

View File

@@ -17,6 +17,7 @@ Transformer Networks
This is an experimental interface that is subject to change and/or deletion.
"""
from .transformer import *
from .layers import *
__all__ = []
__all__.extend(transformer.__all__)

View File

@@ -0,0 +1,222 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The basic layer of the Transformer Networks. This is an experimental interface that is subject to
change and/or deletion.
"""
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore._extends import cell_attr_register
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
class _LayerNorm(Cell):
r"""
A self-defined layer normalization implemented with reduce mean, square and sqrt operations.
Args:
normalized_shape (tuple): The shape of the gamma and beta parameters, normally the last dimension of the input.
eps (float): The epsilon value added to the denominator. Default: 1e-5.
param_init_type (dtype): The parameter initialization type. Default: mstype.float32.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.
Outputs:
Tensor of shape :math:`(batch, seq_length, hidden_size)`.
"""
def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32):
super(_LayerNorm, self).__init__()
if param_init_type not in [mstype.float32, mstype.float16]:
raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}")
self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
parallel_optimizer=False)
self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
parallel_optimizer=False)
self.mean = P.ReduceMean(keep_dims=True)
self.square = P.Square()
self.sqrt = P.Sqrt()
self.sub1 = P.Sub()
self.sub2 = P.Sub()
self.add = P.TensorAdd()
self.eps = eps
self.mul = P.Mul()
self.add2 = P.TensorAdd()
self.real_div = P.RealDiv()
def construct(self, x):
r"""
x : batch x seq_length x hidden_size
"""
mean = self.mean(x, -1)
diff = self.sub1(x, mean)
variance = self.mean(self.square(diff), -1)
variance_eps = self.sqrt(self.add(variance, self.eps))
output = self.real_div(diff, variance_eps)
output = self.add2(self.mul(output, self.gamma), self.beta)
return output
def shard(self, strategy):
r"""
Set the shard strategy for the layer norm. The strategy size should match the inputs.
Note:
It is valid only in semi auto parallel or auto parallel mode.
In other parallel modes, strategies set here will be ignored.
Args:
strategy (tuple): The strategy for the layer norm. Should be the same shape as the inputs.
Examples:
>>> net = nn.parallel.transformer.LayerNorm(normalized_shape=(1024, 10))
>>> net.shard(((10, 2, 1),))
"""
self.mean.shard(strategy)
self.square.shard(strategy)
self.sqrt.shard(strategy)
self.sub1.shard((strategy[0], strategy[0]))
self.sub2.shard((strategy[0], strategy[0]))
self.add.shard((strategy[0], ()))
self.mul.shard((strategy[0], (1,)))
self.add2.shard((strategy[0], (1,)))
self.real_div.shard((strategy[0], strategy[0]))
return self
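For reference, the arithmetic in construct above is the standard layer-norm formula over the last axis; a plain NumPy sketch of the same computation (independent of MindSpore, shapes illustrative):

import numpy as np

def layer_norm_reference(x, gamma, beta, eps=1e-5):
    # Matches _LayerNorm.construct: normalize over the last (hidden) axis, then scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + eps) * gamma + beta

x = np.random.randn(2, 4, 8).astype(np.float32)    # (batch, seq_length, hidden_size)
out = layer_norm_reference(x, np.ones(8, np.float32), np.zeros(8, np.float32))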
class _Linear(Dense):
r"""
The dense connected layer. Once the parallel mode is enabled, the input should be a 3-D tensor.
Applies a dense connected layer to the input. This layer implements the operation as:
.. math::
\text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}),
where :math:`X` is the input tensors, :math:`\text{activation}` is the activation function passed as the activation
argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
with the same data type as the :math:`X` created by the layer (only if has_bias is True).
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): The activation function applied to the output of the fully connected layer,
e.g. 'relu'. Default: None.
compute_dtype (mstype): The computation type. Default: mstype.float16
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
to :math:`in\_channels` in `Inputs`.
Outputs:
Tensor of shape :math:`(*, out\_channels)`.
Raises:
TypeError: If `in_channels` or `out_channels` is not an int.
TypeError: If `has_bias` is not a bool.
TypeError: If `activation` is not one of str, Cell, Primitive, None.
ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
ValueError: If length of shape of `bias_init` is not equal to 1
or shape[0] of `bias_init` is not equal to `out_channels`.
Supported Platforms:
``Ascend`` ``GPU``
"""
@cell_attr_register(attrs=['has_bias', 'in_channels', 'out_channels', 'shard_output', 'activation'])
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None,
transpose_b=True,
param_init_type=mstype.float32,
compute_dtype=mstype.float16):
super(_Linear, self).__init__(in_channels=in_channels,
out_channels=out_channels,
weight_init=weight_init,
bias_init=bias_init,
has_bias=has_bias,
activation=activation)
if param_init_type not in [mstype.float32, mstype.float16]:
raise TypeError(f"param type should in [float32, float16], but found type {type(param_init_type)}")
if activation and not isinstance(activation, str):
raise ValueError("Activation can only be str, but found type {}".format(activation))
if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("Weight init shape error.")
if transpose_b:
weight_shape = [out_channels, in_channels]
else:
weight_shape = [in_channels, out_channels]
self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
self.matmul = P.MatMul(transpose_b=transpose_b)
self.bias = None
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("Bias init shape error.")
self.bias = Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias")
self.bias_add = P.BiasAdd()
self.act_name = activation
self.dtype = compute_dtype
self.cast = P.Cast()
def construct(self, x):
out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
x = P.Reshape()(x, (-1, self.in_channels))
weight = self.cast(self.weight, self.dtype)
x = self.matmul(x, weight)
if self.has_bias:
    x = self.bias_add(x, self.cast(self.bias, self.dtype))
output = P.Reshape()(x, out_shape)
if self.activation_flag:
output = self.activation(output)
return output
def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):
r"""
Set the shard strategy for the linear layer. The strategy size should match the inputs.
Note:
It is valid only in semi auto parallel or auto parallel mode.
In other parallel modes, strategies set here will be ignored.
Args:
strategy_matmul (tuple): The strategy for the matmul. Should be the same shape as the inputs.
strategy_bias (tuple): The strategy for the bias_add. Should be the same shape as the inputs.
strategy_activation (tuple): The strategy for the activation. Should be the same shape as
the inputs.
"""
self.matmul.shard(strategy_matmul)
if self.has_bias:
self.bias_add.shard(strategy_bias)
if self.activation_flag:
getattr(self.activation, self.act_name).shard(strategy_activation)
return self
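One plausible way the shard interface above gets used for a (dp, mp) layout, sketched under the assumption that transpose_b=True (weight shape (out_channels, in_channels)); the strategy tuples below are an illustrative choice, not the only valid one:

dp, mp = 2, 4                                    # illustrative data-parallel and model-parallel sizes
dense = _Linear(in_channels=1024, out_channels=4096, transpose_b=True)
# MatMul consumes x of shape (N, in) and weight of shape (out, in): split x rows by dp and
# weight rows by mp, giving an output laid out as (dp, mp); then shard bias_add to match.
dense.shard(strategy_matmul=((dp, 1), (mp, 1)),
            strategy_bias=((dp, mp), (mp,)))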

File diff suppressed because it is too large.

View File

@@ -17,38 +17,40 @@ import numpy as np
from mindspore import Tensor
from mindspore.common import dtype
from mindspore.nn.parallel import MultiHeadAttention, FeedForward, TransformerEncoderLayer, TransformerEncoder, \
TransformerDecoder, TransformerDecoderLayer, Transformer
TransformerDecoder, TransformerDecoderLayer, Transformer, CrossEntropyLoss, AttentionMask
from mindspore.common.api import _executor
def test_transformer_encoder_only():
model = Transformer(encoder_layers=2,
model = Transformer(batch_size=2,
src_seq_length=20,
tgt_seq_length=0,
encoder_layers=2,
decoder_layers=0,
hidden_size=64,
ffn_hidden_size=64,
src_seq_length=16,
tgt_seq_length=32)
ffn_hidden_size=64)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
_executor.compile(model, encoder_input_value, encoder_input_mask)
def test_encoder_and_decoder():
model = Transformer(encoder_layers=1,
model = Transformer(batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
encoder_layers=1,
decoder_layers=2,
hidden_size=64,
ffn_hidden_size=64,
src_seq_length=20,
tgt_seq_length=20)
ffn_hidden_size=64)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
_executor.compile(model, encoder_input_value, encoder_input_mask,
decoder_input_value,
@@ -57,14 +59,15 @@ def test_encoder_and_decoder():
def test_transformer_encoder():
model = TransformerEncoder(num_layers=2,
model = TransformerEncoder(batch_size=2,
seq_length=16,
num_layers=2,
hidden_size=8,
ffn_hidden_size=64,
seq_length=16,
num_heads=2)
encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), dtype.float16)
encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)
_executor.compile(model,
encoder_input_value,
@@ -72,11 +75,11 @@ def test_transformer_encoder():
def test_transformer_encoder_layer():
model = TransformerEncoderLayer(hidden_size=8, ffn_hidden_size=64, seq_length=16,
model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
num_heads=2)
encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), dtype.float16)
encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)
_executor.compile(model,
encoder_input_value,
@@ -84,11 +87,13 @@ def test_transformer_encoder_layer():
def test_transformer_encoder_layer_post_ture():
model = TransformerEncoderLayer(hidden_size=8, ffn_hidden_size=64, seq_length=16,
model = TransformerEncoderLayer(batch_size=2,
seq_length=16,
hidden_size=8, ffn_hidden_size=64,
num_heads=2, post_layernorm_residual=True)
encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), dtype.float16)
encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)
_executor.compile(model,
encoder_input_value,
@@ -97,16 +102,18 @@ def test_transformer_encoder_layer_post_ture():
def test_transformer_decoder():
model = TransformerDecoder(num_layers=1,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
ffn_hidden_size=64,
num_heads=2,
seq_length=10)
num_heads=2)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
_executor.compile(model, decoder_input_value, decoder_input_mask,
encoder_input_value,
@@ -115,16 +122,18 @@ def test_transformer_decoder():
def test_transformer_decoder_layer():
model = TransformerDecoderLayer(
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
ffn_hidden_size=64,
num_heads=2,
seq_length=10)
num_heads=2)
encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)
_executor.compile(model, decoder_input_value, decoder_input_mask,
encoder_input_value,
@@ -133,12 +142,15 @@ def test_transformer_decoder_layer():
def test_multihead_attention():
model = MultiHeadAttention(hidden_size=15,
src_seq_length=20,
tgt_seq_length=20,
batch_size=2,
num_heads=3)
from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float16)
attention_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)
_executor.compile(model, from_tensor, to_tensor, attention_mask)
_executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)
def test_feedforward_layer():
@@ -149,3 +161,18 @@ def test_feedforward_layer():
tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
_executor.compile(model, tensor)
def test_cross_entroy():
model = CrossEntropyLoss()
logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), dtype.float32)
labels_np = np.array([1]).astype(np.int32)
input_mask = Tensor(np.ones(1).astype(np.float32))
labels = Tensor(labels_np)
_executor.compile(model, logits, labels, input_mask)
def test_attention_mask():
model = AttentionMask(seq_length=19)
inputs = Tensor(np.ones((2, 19)), dtype.float32)
_executor.compile(model, inputs)

View File

@@ -13,14 +13,21 @@
# limitations under the License.
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.nn.parallel import TransformerEncoder, TransformerDecoder, Transformer, TransformerParallelConfig,\
VocabEmbedding
from mindspore.ops import functional as F
import mindspore.ops as P
from mindspore.nn.parallel import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig
from mindspore.nn import Dense as Linear
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from mindspore.train import Model
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
@@ -48,39 +55,159 @@ class Dataset(MindData):
self.index = 0
def test_transformer_model():
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
micro_batch_num=4, vocab_emb_dp=False)
def construct(self, x1, x2, x3, x4, x5):
class NetWithLossFiveInputs(nn.Cell):
def __init__(self, network):
super(NetWithLossFiveInputs, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x1, x2, x3, x4, x5):
predict, _, _ = self.network(x1, x2, x3, x4, x5)
return self.loss(predict)
def run_total_transformer_model_head(e_layer,
d_layer,
arg_parallel_config):
dp = arg_parallel_config.data_parallel
mp = arg_parallel_config.model_parallel
pp = arg_parallel_config.pipeline_stage
if dp * mp * pp != 1:
set_auto_parallel_context(device_num=8,
full_batch=True,
global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
class Net(nn.Cell):
def __init__(self, en_layer, de_layer, parallel_config):
super(Net, self).__init__()
self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
parallel_config=config.embedding_dp_mp_config)
self.network = Transformer(encoder_layers=en_layer,
decoder_layers=de_layer,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
parallel_config=parallel_config)
self.head = Linear(in_channels=64, out_channels=200)
self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
def construct(self, x1, x2, x3, x4, x5, y, mask):
predict, _, _ = self.network(x1, x2, x3, x4, x5)
return self.loss(predict)
predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
return self.loss(predict, y, mask)
config = TransformerParallelConfig(dp=1, mp=8)
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
seq = 20
if d_layer > 0:
seq = 10
label = Tensor(np.ones((2 * seq,)), mstype.int32)
input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
net = Net(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask, label, input_mask)
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model():
set_auto_parallel_context(device_num=8, global_rank=0,
full_batch=True,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = Transformer(encoder_layers=1,
decoder_layers=2,
batch_size=2,
src_seq_length=20,
tgt_seq_length=10,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
parallel_config=config)
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
params = net.trainable_params()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask)
net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_model_head_parallel_only_encoder():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)
def test_transformer_model_head_parallel():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)
def test_transformer_model_head_parallel_decoder():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
with pytest.raises(ValueError):
run_total_transformer_model_head(e_layer=0, d_layer=1, arg_parallel_config=local_config)
def test_transformer_model_head_stand_alone():
local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)
def test_pipeline_single_transformer():
set_auto_parallel_context(device_num=32,
full_batch=True,
pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
src_seq_length=20,
tgt_seq_length=10,
encoder_layers=2,
decoder_layers=2,
hidden_size=64,
num_heads=8,
ffn_hidden_size=64,
src_seq_length=20,
tgt_seq_length=20,
parallel_config=config)
encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 1, 10, 20)), mstype.float16)
net = NetWithLoss(net)
parallel_config=pipeline_config)
encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
net = NetWithLossFiveInputs(net)
net = PipelineCell(net, pipeline_config.micro_batch_num)
net = _VirtualDatasetCell(net)
params = net.infer_param_pipeline_stage()
optimizer = AdamWeightDecay(params)
dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
memory_mask)
model = Model(net)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=1000)
net_with_grad = _TrainPipelineWithLossScaleCell(net, optimizer=optimizer,
scale_sense=update_cell)
model = Model(net_with_grad)
model.train(1, dataset, dataset_sink_mode=False)
@@ -96,17 +223,19 @@ def test_encoder():
predict, _ = self.network(x1, x2)
return self.loss(predict)
config = TransformerParallelConfig(dp=1, mp=8)
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
set_auto_parallel_context(device_num=8,
full_batch=True,
global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = TransformerEncoder(num_layers=2,
batch_size=2,
seq_length=16,
hidden_size=8,
ffn_hidden_size=64,
seq_length=16,
num_heads=8,
parallel_config=config)
encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), mstype.float16)
encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
net = NetWithLoss(net)
@@ -128,19 +257,22 @@ def test_decoder():
predict, _, _ = self.network(x1, x2, x3, x4)
return self.loss(predict)
config = TransformerParallelConfig(dp=1, mp=8)
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
set_auto_parallel_context(device_num=8,
full_batch=True,
global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
net = TransformerDecoder(num_layers=1,
batch_size=8,
hidden_size=16,
ffn_hidden_size=8,
num_heads=8,
seq_length=10,
src_seq_length=20,
tgt_seq_length=10,
parallel_config=config)
encoder_input_value = Tensor(np.ones((2, 20, 16)), mstype.float32)
decoder_input_value = Tensor(np.ones((2, 10, 16)), mstype.float32)
decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((2, 1, 10, 20)), mstype.float16)
encoder_input_value = Tensor(np.ones((8, 20, 16)), mstype.float32)
decoder_input_value = Tensor(np.ones((8, 10, 16)), mstype.float32)
decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)
net = NetWithLoss(net)
@@ -151,7 +283,6 @@ def test_decoder():
def test_vocabembedding_dp_true():
config = TransformerParallelConfig(dp=1, mp=8)
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
class NetWithLoss(nn.Cell):
@@ -164,15 +295,7 @@ def test_vocabembedding_dp_true():
predict, _ = self.network(x1)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x1):
return grad_all(self.network)(x1)
net = VocabEmbedding(vocab_size=100, embedding_size=16, parallel_config=config)
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
net = NetWithLoss(net)
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
dataset = Dataset(encoder_input_value)
@@ -182,7 +305,6 @@ def test_vocabembedding_dp_true():
def test_vocabembedding_dp_false():
config = TransformerParallelConfig(dp=1, mp=8, vocab_emb_dp=False)
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
class NetWithLoss(nn.Cell):
@@ -195,18 +317,109 @@ def test_vocabembedding_dp_false():
predict, _ = self.network(x1)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x1):
return grad_all(self.network)(x1)
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config)
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
net = NetWithLoss(net)
encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
dataset = Dataset(encoder_input_value)
model = Model(net)
model.train(1, dataset, dataset_sink_mode=False)
def test_parallel_cross_entroy_loss_semi_auto_parallel():
set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
class NetWithLoss(nn.Cell):
def __init__(self, network, config_setting):
super(NetWithLoss, self).__init__()
self.loss = CrossEntropyLoss(config_setting)
self.network = network
def construct(self, x1, x2, x3):
predict, _ = self.network(x1)
predict = P.Reshape()(predict, (-1, 16))
return self.loss(predict, x2, x3)
net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
net = NetWithLoss(net, config.dp_mp_config)
embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
labels = Tensor(np.ones((2 * 64,)), mstype.int32)
input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
dataset = Dataset(embed_ids, labels, input_mask)
model = Model(net)
model.train(1, dataset, dataset_sink_mode=False)
def test_transformer_parallel_config():
parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)
with pytest.raises(TypeError):
parallel_test_config.data_parallel = False
with pytest.raises(ValueError):
parallel_test_config.data_parallel = 0
with pytest.raises(TypeError):
parallel_test_config.model_parallel = False
with pytest.raises(ValueError):
parallel_test_config.model_parallel = 0
with pytest.raises(TypeError):
parallel_test_config.pipeline_stage = False
with pytest.raises(ValueError):
parallel_test_config.pipeline_stage = 0
with pytest.raises(TypeError):
parallel_test_config.micro_batch_num = False
with pytest.raises(ValueError):
parallel_test_config.micro_batch_num = 0
with pytest.raises(TypeError):
parallel_test_config.gradient_aggregation_group = False
with pytest.raises(ValueError):
parallel_test_config.gradient_aggregation_group = 0
with pytest.raises(TypeError):
parallel_test_config.recompute = 1
parallel_test_config.recompute = False
assert not parallel_test_config.recompute
def test_parallel_config():
parallel_test_config = OpParallelConfig(data_parallel=1, model_parallel=3)
with pytest.raises(ValueError):
parallel_test_config.data_parallel = 0
with pytest.raises(TypeError):
parallel_test_config.model_parallel = False
with pytest.raises(ValueError):
parallel_test_config.model_parallel = 0
assert parallel_test_config.model_parallel == 3
def test_embedding_parallel_config():
parallel_test_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=3, vocab_emb_dp=False)
with pytest.raises(ValueError):
parallel_test_config.data_parallel = 0
with pytest.raises(TypeError):
parallel_test_config.model_parallel = False
with pytest.raises(ValueError):
parallel_test_config.model_parallel = 0
with pytest.raises(TypeError):
parallel_test_config.vocab_emb_dp = 0
assert not parallel_test_config.vocab_emb_dp