forked from mindspore-Ecosystem/mindspore
!21809 Improved Transformer Structure and Add Args Check
Merge pull request !21809 from huangxinjing/transformer_improved
This commit is contained in commit 0d839fa7c6
@@ -17,5 +17,10 @@ Parallel Networks.
This is an experimental interface that is subject to change and/or deletion.
"""
from .transformer import *
from .loss import *
from .config import *

__all__ = []
__all__.extend(transformer.__all__)
__all__.extend(loss.__all__)
__all__.extend(config.__all__)
@@ -0,0 +1,162 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel Config for the Parallel Training
This is an experimental interface that is subject to change and/or deletion.
"""
from mindspore._checkparam import Validator
from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode

__all__ = [
    "OpParallelConfig"
]


class _Config:
    r"""A basic class of the configuration"""

    def __str__(self):
        info = "[ParallelConfig]" + '\n'
        for k, v in self.__dict__.items():
            var_info = "{}:{}\n".format(k, v)
            info += var_info
        return info


class OpParallelConfig(_Config):
    r"""
    OpParallelConfig for setting the data parallel and model parallel way.

    Args:
        data_parallel (int): The data parallel way. Default: 1
        model_parallel (int): The model parallel way. Default: 1
    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.nn.parallel import OpParallelConfig
        >>> config = OpParallelConfig(data_parallel=1, model_parallel=1)
    """

    def __init__(self, data_parallel=1, model_parallel=1):
        Validator.check_positive_int(data_parallel, "data_parallel")
        Validator.check_positive_int(model_parallel, "model_parallel")
        self._data_parallel = data_parallel
        self._model_parallel = model_parallel

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, value):
        Validator.check_positive_int(value, "data_parallel")
        self._data_parallel = value

    @property
    def model_parallel(self):
        return self._model_parallel

    @model_parallel.setter
    def model_parallel(self, value):
        Validator.check_positive_int(value, "model_parallel")
        self._model_parallel = value
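
# Editor's illustration (not part of the diff): a minimal sketch of how the properties of the
# OpParallelConfig class above behave; it reuses the class and the Validator imported in this file.
_cfg = OpParallelConfig(data_parallel=2, model_parallel=4)
print(_cfg)                   # _Config.__str__ lists each field on its own line
_cfg.model_parallel = 8       # the setter re-runs Validator.check_positive_int
try:
    _cfg.data_parallel = 0    # non-positive values are rejected
except ValueError as err:
    print(err)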


class _PipeLineConfig(_Config):
    r"""
    _PipeLineConfig for setting the pipeline stage and the micro batch number.

    Args:
        pipeline_stage (int): The number of the pipeline stages. Default: 1
        micro_batch_num (int): The number of the micro batches. Default: 1
    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> config = _PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
    """

    def __init__(self, pipeline_stage=1, micro_batch_num=1):
        Validator.check_positive_int(pipeline_stage, "pipeline_stage")
        Validator.check_positive_int(micro_batch_num, "micro_batch_num")
        self._pipeline_stage = pipeline_stage
        self._micro_batch_num = micro_batch_num

    @property
    def pipeline_stage(self):
        return self._pipeline_stage

    @pipeline_stage.setter
    def pipeline_stage(self, value):
        Validator.check_positive_int(value, "pipeline_stage")
        self._pipeline_stage = value
        context.set_auto_parallel_context(pipeline_stages=value)

    @property
    def micro_batch_num(self):
        return self._micro_batch_num

    @micro_batch_num.setter
    def micro_batch_num(self, value):
        Validator.check_positive_int(value, "micro_batch_num")
        self._micro_batch_num = value
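
# Editor's illustration (not part of the diff): the pipeline_stage setter above also pushes the
# value into the auto-parallel context, which _check_config below reads back; a minimal sketch
# assuming a configured MindSpore environment.
_pp_cfg = _PipeLineConfig(pipeline_stage=2, micro_batch_num=4)
_pp_cfg.pipeline_stage = 4    # re-validated, then written via context.set_auto_parallel_context
print(context.get_auto_parallel_context("pipeline_stages"))    # -> 4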


# In case the user doesn't pass a config as args.
default_dpmp_config = OpParallelConfig()


def _check_config(config):
    """
    Check that micro_batch_num >= pipeline_stage and that the parallel settings in the config
    agree with the auto_parallel_context.
    """
    # the config pipeline_stage is same with context.pipeline_stage
    pipeline_stage = context.get_auto_parallel_context("pipeline_stages")
    if hasattr(config, 'pipeline_stage') and pipeline_stage != config.pipeline_stage:
        raise ValueError(
            f"The pipeline stage {pipeline_stage} in auto_parallel_context is not equal to the pipeline_stage "
            f"{config.pipeline_stage}"
            f" in the config.")

    # make sure the following is in auto parallel mode
    is_auto_parallel = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    if not is_auto_parallel:
        return

    device_num = D.get_group_size()
    optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")

    # dp * mp * pipeline_stage <= device_num
    if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
        raise ValueError(f"The product of the data parallel {config.data_parallel}, "
                         f"model parallel {config.model_parallel} "
                         f"and pipeline stages {pipeline_stage} "
                         f"should be no larger than device_num {device_num}.")

    # the config optimizer_shard is same with context.optimizer_shard
    if hasattr(config, "optimizer_shard") and optimizer_shard != config.optimizer_shard:
        raise ValueError(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the "
                         f"optimizer_shard {config.optimizer_shard} in the config.")

    # pipeline_stage <= micro_batch_num
    if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
            and config.pipeline_stage > config.micro_batch_num:
        raise ValueError(
            f"The micro_batch_num {config.micro_batch_num} should be no less than the pipeline_stage "
            f"{config.pipeline_stage}.")

@@ -0,0 +1,131 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel Loss for the Parallel Training
This is an experimental interface that is subject to change and/or deletion.
"""
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn import Cell
from mindspore.nn.loss.loss import _check_is_tensor
from mindspore.nn.parallel.transformer.transformer import _check_input_dtype, _check_input_shape
from .config import default_dpmp_config, OpParallelConfig

__all__ = ["CrossEntropyLoss"]


class CrossEntropyLoss(Cell):
    """
    Calculate the cross entropy loss.
    Args:
        parallel_config (OpParallelConfig): The parallel configuration. Default: `default_dpmp_config`.
    Inputs:
        - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
          the backbone.
        - **labels** (Tensor) - Tensor of shape (N, ). The ground truth label of the sample.
        - **input_mask** (Tensor) - Tensor of shape (N, ). input_mask indicates whether there are padded inputs;
          padded positions will not be counted into the loss.
    Returns:
        loss: Tensor, the corresponding cross entropy loss.

    Examples:
        >>> loss = nn.parallel.CrossEntropyLoss()
        >>> logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), mindspore.float32)
        >>> labels_np = np.array([1]).astype(np.int32)
        >>> input_mask = Tensor(np.ones(1).astype(np.float32))
        >>> labels = Tensor(labels_np)
        >>> output = loss(logits, labels, input_mask)
        >>> print(output.shape)
        (1,)
    """

    def __init__(self, parallel_config=default_dpmp_config):
        super(CrossEntropyLoss, self).__init__()
        if not isinstance(parallel_config, OpParallelConfig):
            raise TypeError("Input args parallel_config must be the type OpParallelConfig.")
        dp = parallel_config.data_parallel
        mp = parallel_config.model_parallel
        self.sum = P.ReduceSum().shard(((dp, mp),))
        self.onehot = P.OneHot().shard(((dp, mp), (), ()))
        # on/off value for onehot, for smooth labeling, modify the off_value
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(
            ((dp, mp),))
        self.eps_const = Tensor(1e-24, mstype.float32)
        self.sub = P.Sub().shard(((dp, mp), (dp, 1)))
        self.exp = P.Exp().shard(((dp, mp),))
        self.div = P.RealDiv().shard(((dp, mp), (dp, 1)))
        self.log = P.Log().shard(((dp, mp),))
        self.add = P.TensorAdd().shard(((dp, mp), ()))
        self.mul = P.Mul().shard(
            ((dp, mp), (dp, mp)))
        self.neg = P.Neg().shard(((dp, mp),))
        self.sum2 = P.ReduceSum().shard(((1,),))

        self.mul2 = P.Mul().shard(((1,), (1,)))
        self.add2 = P.TensorAdd()
        self.div2 = P.RealDiv()

    def construct(self, logits, label, input_mask):
        r"""
        Compute loss using logits, label and input mask
        """
        self._check_input(logits, label, input_mask)

        # [bs*seq_length, vocab_size]
        logits = F.cast(logits, mstype.float32)
        # LogSoftmax for logits over last dimension
        _, logit_max = self.max(logits)
        logit_sub = self.sub(logits, logit_max)
        logit_exp = self.exp(logit_sub)
        exp_sum = self.sum(logit_exp, -1)
        exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
        softmax_result = self.div(logit_exp, exp_sum)
        log_softmax_result = self.log(self.add(softmax_result, self.eps_const))

        # Flatten label to [bs*seq_length]
        label = P.Reshape()(label, (-1,))
        # Get onehot label [bs*seq_length, vocab_size]
        one_hot_label = self.onehot(label, F.shape(logits)[-1], self.on_value,
                                    self.off_value)
        # Cross-Entropy loss
        loss = self.mul(log_softmax_result, one_hot_label)
        loss_unsum = self.neg(loss)
        loss_reduce = self.sum(loss_unsum, -1)
        # input_mask indicates whether there are padded inputs; padded positions are not counted into the loss
        input_mask = P.Reshape()(input_mask, (-1,))
        numerator = self.sum2(self.mul2(loss_reduce, input_mask))

        denominator = self.add2(
            self.sum2(input_mask),
            P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32))
        loss = self.div2(numerator, denominator)
        return loss

    def _check_input(self, logits, label, input_mask):
        r"""Check the input tensor shape and type"""
        _check_is_tensor('logits', logits, self.cls_name)
        _check_is_tensor('label', label, self.cls_name)
        _check_is_tensor('input_mask', input_mask, self.cls_name)
        _check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
        _check_input_shape(F.shape(logits), "logits", self.cls_name, 2)
        _check_input_shape(F.shape(label), "label", self.cls_name, 1)
        _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 1)
        return True
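
# Editor's illustration (not part of the diff): a NumPy sketch of the arithmetic performed by
# CrossEntropyLoss.construct above (stable log-softmax, one-hot pick, mask-weighted mean),
# usable as a serial reference for the sharded version.
import numpy as np


def _reference_cross_entropy(logits, labels, input_mask, eps=1e-24):
    logits = logits.astype(np.float32)
    shifted = logits - logits.max(axis=-1, keepdims=True)             # self.max / self.sub
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    log_softmax = np.log(softmax + eps)                               # self.add / self.log
    one_hot = np.eye(logits.shape[-1], dtype=np.float32)[labels]      # self.onehot
    per_token = -(log_softmax * one_hot).sum(axis=-1)                 # self.mul / self.neg / self.sum
    return (per_token * input_mask).sum() / (input_mask.sum() + 1e-5)


print(_reference_cross_entropy(np.array([[3., 5., 6., 9., 12., 33., 42., 12., 32., 72.]]),
                               np.array([1]), np.ones(1, dtype=np.float32)))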

@@ -17,6 +17,7 @@ Transformer Networks
This is an experimental interface that is subject to change and/or deletion.
"""
from .transformer import *
from .layers import *

__all__ = []
__all__.extend(transformer.__all__)
@@ -0,0 +1,222 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The basic layer of the Transformer Networks. This is an experimental interface that is subject to
change and/or deletion.
"""

from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore._extends import cell_attr_register
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense


class _LayerNorm(Cell):
    r"""
    A self-defined layer norm operation using reduce sum and reduce mean.

    Args:
        normalized_shape (tuple): The shape of the input tensor.
        eps (float): The epsilon value of the denominator. Default: 1e-5.
        param_init_type (dtype): The parameter initialization type. Default: mstype.float32.
    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.

    Outputs:
        Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.
    """

    def __init__(self, normalized_shape, eps=1e-5, param_init_type=mstype.float32):
        super(_LayerNorm, self).__init__()
        if param_init_type not in [mstype.float32, mstype.float16]:
            raise TypeError(f"param_init_type should be in [mstype.float32, mstype.float16], "
                            f"but found {param_init_type}")
        self.gamma = Parameter(initializer('ones', normalized_shape, param_init_type), name="gamma",
                               parallel_optimizer=False)
        self.beta = Parameter(initializer('zeros', normalized_shape, param_init_type), name="beta",
                              parallel_optimizer=False)
        self.mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.sub1 = P.Sub()
        self.sub2 = P.Sub()
        self.add = P.TensorAdd()
        self.eps = eps
        self.mul = P.Mul()
        self.add2 = P.TensorAdd()
        self.real_div = P.RealDiv()

    def construct(self, x):
        r"""
        x : batch x seq_length x hidden_size
        """
        mean = self.mean(x, -1)
        diff = self.sub1(x, mean)
        variance = self.mean(self.square(diff), -1)
        variance_eps = self.sqrt(self.add(variance, self.eps))
        output = self.real_div(diff, variance_eps)
        output = self.add2(self.mul(output, self.gamma), self.beta)
        return output

    def shard(self, strategy):
        r"""
        Set the shard strategy for the layer norm. The strategy size should be equal to the inputs.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy (tuple): The strategy for the layer norm inputs. Should be the same shape as the inputs.
        Examples:
            >>> net = nn.parallel.transformer.LayerNorm(normalized_shape=(1024, 10))
            >>> net.shard(((10, 2, 1),))
        """
        self.mean.shard(strategy)
        self.square.shard(strategy)
        self.sqrt.shard(strategy)
        self.sub1.shard((strategy[0], strategy[0]))
        self.sub2.shard((strategy[0], strategy[0]))
        self.add.shard((strategy[0], ()))
        self.mul.shard((strategy[0], (1,)))
        self.add2.shard((strategy[0], (1,)))
        self.real_div.shard((strategy[0], strategy[0]))

        return self
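
# Editor's illustration (not part of the diff): the same normalization written in NumPy, as a
# serial reference for what _LayerNorm.construct above computes over the last axis.
import numpy as np


def _reference_layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)                            # self.mean
    variance = ((x - mean) ** 2).mean(axis=-1, keepdims=True)        # self.square + self.mean
    return gamma * (x - mean) / np.sqrt(variance + eps) + beta       # self.real_div / self.mul / self.add2


_x = np.random.randn(2, 16, 8).astype(np.float32)
print(_reference_layer_norm(_x, np.ones(8, np.float32), np.zeros(8, np.float32)).shape)  # (2, 16, 8)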


class _Linear(Dense):
    r"""
    The dense connected layer. Once the parallel mode is enabled, the input shape should be a
    3-D tensor.

    Applies a dense connected layer to the input. This layer implements the operation as:

    .. math::
        \text{outputs} = \text{activation}(\text{X} * \text{kernel} + \text{bias}),

    where :math:`X` is the input tensors, :math:`\text{activation}` is the activation function passed as the activation
    argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
    with the same data type as the :math:`X` created by the layer (only if has_bias is True).

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        activation (str): The activation function applied to the output of the fully connected layer,
            e.g. 'ReLU'. Default: None.
        compute_dtype (mstype): The computation type. Default: mstype.float16
    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
          to :math:`in\_channels` in `Inputs`.

    Outputs:
        Tensor of shape :math:`(*, out\_channels)`.

    Raises:
        TypeError: If `in_channels` or `out_channels` is not an int.
        TypeError: If `has_bias` is not a bool.
        TypeError: If `activation` is not one of str, Cell, Primitive, None.
        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
            is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
        ValueError: If length of shape of `bias_init` is not equal to 1
            or shape[0] of `bias_init` is not equal to `out_channels`.

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @cell_attr_register(attrs=['has_bias', 'in_channels', 'out_channels', 'shard_output', 'activation'])
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 activation=None,
                 transpose_b=True,
                 param_init_type=mstype.float32,
                 compute_dtype=mstype.float16):
        super(_Linear, self).__init__(in_channels=in_channels,
                                      out_channels=out_channels,
                                      weight_init=weight_init,
                                      bias_init=bias_init,
                                      has_bias=has_bias,
                                      activation=activation)
        if param_init_type not in [mstype.float32, mstype.float16]:
            raise TypeError(f"param_init_type should be in [mstype.float32, mstype.float16], "
                            f"but found {param_init_type}")
        if activation and not isinstance(activation, str):
            raise ValueError("Activation can only be str, but found type {}".format(type(activation)))
        if isinstance(weight_init, Tensor):
            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
                    weight_init.shape[1] != in_channels:
                raise ValueError("Weight init shape error.")
        if transpose_b:
            weight_shape = [out_channels, in_channels]
        else:
            weight_shape = [in_channels, out_channels]
        self.weight = Parameter(initializer(weight_init, weight_shape, param_init_type), name="weight")
        self.matmul = P.MatMul(transpose_b=transpose_b)
        self.bias = None
        if self.has_bias:
            if isinstance(bias_init, Tensor):
                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("Bias init shape error.")
            self.bias = Parameter(initializer(bias_init, [out_channels], param_init_type), name="bias")
            self.bias_add = P.BiasAdd()
        self.act_name = activation
        self.dtype = compute_dtype
        self.cast = P.Cast()

    def construct(self, x):
        out_shape = P.Shape()(x)[:-1] + (self.out_channels,)
        x = P.Reshape()(x, (-1, self.in_channels))
        weight = self.cast(self.weight, self.dtype)
        x = self.matmul(x, weight)
        if self.has_bias:
            x = self.bias_add(x, self.cast(self.bias, self.dtype))
        output = P.Reshape()(x, out_shape)
        if self.activation_flag:
            output = self.activation(output)
        return output

    def shard(self, strategy_matmul, strategy_bias=None, strategy_activation=None):
        r"""
        Set the shard strategies for the linear. The strategy size should be equal to the inputs.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy_matmul (tuple): The strategy for the matmul. Should be the same shape as the inputs.
            strategy_bias (tuple): The strategy for the bias_add. Should be the same shape as the inputs.
            strategy_activation (tuple): The strategy for the activation. Should be the same shape as
                the inputs.
        """
        self.matmul.shard(strategy_matmul)
        if self.has_bias:
            self.bias_add.shard(strategy_bias)
        if self.activation_flag:
            getattr(self.activation, self.act_name).shard(strategy_activation)

        return self
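
# Editor's illustration (not part of the diff): a NumPy sketch of the reshape/matmul/bias flow in
# _Linear.construct above with transpose_b=True, i.e. the weight stored as (out_channels, in_channels)
# and the leading dimensions restored afterwards.
import numpy as np


def _reference_linear(x, weight, bias=None):
    out_shape = x.shape[:-1] + (weight.shape[0],)     # P.Shape()(x)[:-1] + (self.out_channels,)
    flat = x.reshape(-1, weight.shape[1])             # P.Reshape to (-1, in_channels)
    out = flat @ weight.T                             # P.MatMul(transpose_b=True)
    if bias is not None:
        out = out + bias                              # P.BiasAdd
    return out.reshape(out_shape)


_x = np.ones((2, 16, 8), dtype=np.float32)
_w = np.ones((32, 8), dtype=np.float32)               # (out_channels, in_channels)
print(_reference_linear(_x, _w, bias=np.zeros(32, dtype=np.float32)).shape)  # (2, 16, 32)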

File diff suppressed because it is too large
@@ -17,38 +17,40 @@ import numpy as np
from mindspore import Tensor
from mindspore.common import dtype
from mindspore.nn.parallel import MultiHeadAttention, FeedForward, TransformerEncoderLayer, TransformerEncoder, \
    TransformerDecoder, TransformerDecoderLayer, Transformer
    TransformerDecoder, TransformerDecoderLayer, Transformer, CrossEntropyLoss, AttentionMask
from mindspore.common.api import _executor


def test_transformer_encoder_only():
    model = Transformer(encoder_layers=2,
    model = Transformer(batch_size=2,
                        src_seq_length=20,
                        tgt_seq_length=0,
                        encoder_layers=2,
                        decoder_layers=0,
                        hidden_size=64,
                        ffn_hidden_size=64,
                        src_seq_length=16,
                        tgt_seq_length=32)
                        ffn_hidden_size=64)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
    encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)

    _executor.compile(model, encoder_input_value, encoder_input_mask)


def test_encoder_and_decoder():
    model = Transformer(encoder_layers=1,
    model = Transformer(batch_size=2,
                        src_seq_length=20,
                        tgt_seq_length=10,
                        encoder_layers=1,
                        decoder_layers=2,
                        hidden_size=64,
                        ffn_hidden_size=64,
                        src_seq_length=20,
                        tgt_seq_length=20)
                        ffn_hidden_size=64)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)
    encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)

    decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
    decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
    memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)

    _executor.compile(model, encoder_input_value, encoder_input_mask,
                      decoder_input_value,
@@ -57,14 +59,15 @@ def test_encoder_and_decoder():


def test_transformer_encoder():
    model = TransformerEncoder(num_layers=2,
    model = TransformerEncoder(batch_size=2,
                               seq_length=16,
                               num_layers=2,
                               hidden_size=8,
                               ffn_hidden_size=64,
                               seq_length=16,
                               num_heads=2)

    encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
    encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), dtype.float16)
    encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)

    _executor.compile(model,
                      encoder_input_value,
@@ -72,11 +75,11 @@ def test_transformer_encoder():


def test_transformer_encoder_layer():
    model = TransformerEncoderLayer(hidden_size=8, ffn_hidden_size=64, seq_length=16,
    model = TransformerEncoderLayer(batch_size=2, hidden_size=8, ffn_hidden_size=64, seq_length=16,
                                    num_heads=2)

    encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
    encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), dtype.float16)
    encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)

    _executor.compile(model,
                      encoder_input_value,
@@ -84,11 +87,13 @@ def test_transformer_encoder_layer():


def test_transformer_encoder_layer_post_ture():
    model = TransformerEncoderLayer(hidden_size=8, ffn_hidden_size=64, seq_length=16,
    model = TransformerEncoderLayer(batch_size=2,
                                    seq_length=16,
                                    hidden_size=8, ffn_hidden_size=64,
                                    num_heads=2, post_layernorm_residual=True)

    encoder_input_value = Tensor(np.ones((2, 16, 8)), dtype.float32)
    encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), dtype.float16)
    encoder_input_mask = Tensor(np.ones((2, 16, 16)), dtype.float16)

    _executor.compile(model,
                      encoder_input_value,
@@ -97,16 +102,18 @@ def test_transformer_encoder_layer_post_ture():

def test_transformer_decoder():
    model = TransformerDecoder(num_layers=1,
                               batch_size=2,
                               src_seq_length=20,
                               tgt_seq_length=10,
                               hidden_size=64,
                               ffn_hidden_size=64,
                               num_heads=2,
                               seq_length=10)
                               num_heads=2)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)

    decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
    decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
    memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)

    _executor.compile(model, decoder_input_value, decoder_input_mask,
                      encoder_input_value,
@@ -115,16 +122,18 @@ def test_transformer_decoder():

def test_transformer_decoder_layer():
    model = TransformerDecoderLayer(
        batch_size=2,
        src_seq_length=20,
        tgt_seq_length=10,
        hidden_size=64,
        ffn_hidden_size=64,
        num_heads=2,
        seq_length=10)
        num_heads=2)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), dtype.float32)

    decoder_input_value = Tensor(np.ones((2, 10, 64)), dtype.float32)
    decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), dtype.float16)
    memory_mask = Tensor(np.ones((2, 1, 10, 20)), dtype.float16)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), dtype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), dtype.float16)

    _executor.compile(model, decoder_input_value, decoder_input_mask,
                      encoder_input_value,
@@ -133,12 +142,15 @@ def test_transformer_decoder_layer():

def test_multihead_attention():
    model = MultiHeadAttention(hidden_size=15,
                               src_seq_length=20,
                               tgt_seq_length=20,
                               batch_size=2,
                               num_heads=3)
    from_tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)
    to_tensor = Tensor(np.ones((2, 20, 15)), dtype.float16)
    attention_mask = Tensor(np.ones((2, 1, 20, 20)), dtype.float16)
    attention_mask = Tensor(np.ones((2, 20, 20)), dtype.float16)

    _executor.compile(model, from_tensor, to_tensor, attention_mask)
    _executor.compile(model, from_tensor, to_tensor, to_tensor, attention_mask)


def test_feedforward_layer():
@@ -149,3 +161,18 @@ def test_feedforward_layer():
    tensor = Tensor(np.ones((2, 20, 15)), dtype.float32)

    _executor.compile(model, tensor)


def test_cross_entroy():
    model = CrossEntropyLoss()
    logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), dtype.float32)
    labels_np = np.array([1]).astype(np.int32)
    input_mask = Tensor(np.ones(1).astype(np.float32))
    labels = Tensor(labels_np)
    _executor.compile(model, logits, labels, input_mask)


def test_attention_mask():
    model = AttentionMask(seq_length=19)
    inputs = Tensor(np.ones((2, 19)), dtype.float32)
    _executor.compile(model, inputs)
@@ -13,14 +13,21 @@
# limitations under the License.

import numpy as np

import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.context import set_auto_parallel_context, ParallelMode
from mindspore.ops import composite as C
from mindspore.nn.parallel import TransformerEncoder, TransformerDecoder, Transformer, TransformerParallelConfig,\
    VocabEmbedding
from mindspore.ops import functional as F
import mindspore.ops as P
from mindspore.nn.parallel import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
    VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig
from mindspore.nn import Dense as Linear
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from mindspore.train import Model
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
@@ -48,39 +55,159 @@ class Dataset(MindData):
        self.index = 0


def test_transformer_model():
    class NetWithLoss(nn.Cell):
        def __init__(self, network):
            super(NetWithLoss, self).__init__()
            self.loss = VirtualLoss()
            self.network = network
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
                                              micro_batch_num=4, vocab_emb_dp=False)

        def construct(self, x1, x2, x3, x4, x5):

class NetWithLossFiveInputs(nn.Cell):
    def __init__(self, network):
        super(NetWithLossFiveInputs, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x1, x2, x3, x4, x5):
        predict, _, _ = self.network(x1, x2, x3, x4, x5)
        return self.loss(predict)


def run_total_transformer_model_head(e_layer,
                                     d_layer,
                                     arg_parallel_config):
    dp = arg_parallel_config.data_parallel
    mp = arg_parallel_config.model_parallel
    pp = arg_parallel_config.pipeline_stage
    if dp * mp * pp != 1:
        set_auto_parallel_context(device_num=8,
                                  full_batch=True,
                                  global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class Net(nn.Cell):
        def __init__(self, en_layer, de_layer, parallel_config):
            super(Net, self).__init__()
            self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
                                            parallel_config=config.embedding_dp_mp_config)
            self.network = Transformer(encoder_layers=en_layer,
                                       decoder_layers=de_layer,
                                       batch_size=2,
                                       src_seq_length=20,
                                       tgt_seq_length=10,
                                       hidden_size=64,
                                       num_heads=8,
                                       ffn_hidden_size=64,
                                       parallel_config=parallel_config)
            self.head = Linear(in_channels=64, out_channels=200)
            self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)

        def construct(self, x1, x2, x3, x4, x5, y, mask):
            predict, _, _ = self.network(x1, x2, x3, x4, x5)
            return self.loss(predict)
            predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
            return self.loss(predict, y, mask)

    config = TransformerParallelConfig(dp=1, mp=8)
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    seq = 20
    if d_layer > 0:
        seq = 10
    label = Tensor(np.ones((2 * seq,)), mstype.int32)
    input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
    net = Net(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask, label, input_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model():
    set_auto_parallel_context(device_num=8, global_rank=0,
                              full_batch=True,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = Transformer(encoder_layers=1,
                      decoder_layers=2,
                      batch_size=2,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    params = net.trainable_params()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)
    net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_model_head_parallel_only_encoder():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)


def test_transformer_model_head_parallel():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)


def test_transformer_model_head_parallel_decoder():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
    with pytest.raises(ValueError):
        run_total_transformer_model_head(e_layer=0, d_layer=1, arg_parallel_config=local_config)


def test_transformer_model_head_stand_alone():
    local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
    run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)


def test_pipeline_single_transformer():
    set_auto_parallel_context(device_num=32,
                              full_batch=True,
                              pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
                              parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
                      src_seq_length=20,
                      tgt_seq_length=10,
                      encoder_layers=2,
                      decoder_layers=2,
                      hidden_size=64,
                      num_heads=8,
                      ffn_hidden_size=64,
                      src_seq_length=20,
                      tgt_seq_length=20,
                      parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 1, 10, 20)), mstype.float16)
    net = NetWithLoss(net)
                      parallel_config=pipeline_config)

    encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
    decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
    net = NetWithLossFiveInputs(net)
    net = PipelineCell(net, pipeline_config.micro_batch_num)
    net = _VirtualDatasetCell(net)
    params = net.infer_param_pipeline_stage()
    optimizer = AdamWeightDecay(params)
    dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
                      memory_mask)

    model = Model(net)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=1000)
    net_with_grad = _TrainPipelineWithLossScaleCell(net, optimizer=optimizer,
                                                    scale_sense=update_cell)
    model = Model(net_with_grad)

    model.train(1, dataset, dataset_sink_mode=False)

@@ -96,17 +223,19 @@ def test_encoder():
            predict, _ = self.network(x1, x2)
            return self.loss(predict)

    config = TransformerParallelConfig(dp=1, mp=8)
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    set_auto_parallel_context(device_num=8,
                              full_batch=True,
                              global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = TransformerEncoder(num_layers=2,
                             batch_size=2,
                             seq_length=16,
                             hidden_size=8,
                             ffn_hidden_size=64,
                             seq_length=16,
                             num_heads=8,
                             parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
    encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), mstype.float16)
    encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)

    net = NetWithLoss(net)

@@ -128,19 +257,22 @@ def test_decoder():
            predict, _, _ = self.network(x1, x2, x3, x4)
            return self.loss(predict)

    config = TransformerParallelConfig(dp=1, mp=8)
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    set_auto_parallel_context(device_num=8,
                              full_batch=True,
                              global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
    net = TransformerDecoder(num_layers=1,
                             batch_size=8,
                             hidden_size=16,
                             ffn_hidden_size=8,
                             num_heads=8,
                             seq_length=10,
                             src_seq_length=20,
                             tgt_seq_length=10,
                             parallel_config=config)

    encoder_input_value = Tensor(np.ones((2, 20, 16)), mstype.float32)
    decoder_input_value = Tensor(np.ones((2, 10, 16)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((2, 1, 10, 20)), mstype.float16)
    encoder_input_value = Tensor(np.ones((8, 20, 16)), mstype.float32)
    decoder_input_value = Tensor(np.ones((8, 10, 16)), mstype.float32)
    decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
    memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)

    net = NetWithLoss(net)

@@ -151,7 +283,6 @@ def test_decoder():


def test_vocabembedding_dp_true():
    config = TransformerParallelConfig(dp=1, mp=8)
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
@@ -164,15 +295,7 @@ def test_vocabembedding_dp_true():
            predict, _ = self.network(x1)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x1):
            return grad_all(self.network)(x1)

    net = VocabEmbedding(vocab_size=100, embedding_size=16, parallel_config=config)
    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net)
    encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
    dataset = Dataset(encoder_input_value)
@@ -182,7 +305,6 @@ def test_vocabembedding_dp_true():


def test_vocabembedding_dp_false():
    config = TransformerParallelConfig(dp=1, mp=8, vocab_emb_dp=False)
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
@@ -195,18 +317,109 @@ def test_vocabembedding_dp_false():
            predict, _ = self.network(x1)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        def __init__(self, network):
            super(GradWrap, self).__init__()
            self.network = network

        def construct(self, x1):
            return grad_all(self.network)(x1)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config)
    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net)
    encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
    dataset = Dataset(encoder_input_value)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_parallel_cross_entroy_loss_semi_auto_parallel():
    set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)

    class NetWithLoss(nn.Cell):
        def __init__(self, network, config_setting):
            super(NetWithLoss, self).__init__()
            self.loss = CrossEntropyLoss(config_setting)
            self.network = network

        def construct(self, x1, x2, x3):
            predict, _ = self.network(x1)
            predict = P.Reshape()(predict, (-1, 16))
            return self.loss(predict, x2, x3)

    net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
    net = NetWithLoss(net, config.dp_mp_config)
    embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
    labels = Tensor(np.ones((2 * 64,)), mstype.int32)
    input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
    dataset = Dataset(embed_ids, labels, input_mask)

    model = Model(net)
    model.train(1, dataset, dataset_sink_mode=False)


def test_transformer_parallel_config():
    parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)

    with pytest.raises(TypeError):
        parallel_test_config.data_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.pipeline_stage = False

    with pytest.raises(ValueError):
        parallel_test_config.pipeline_stage = 0

    with pytest.raises(TypeError):
        parallel_test_config.micro_batch_num = False

    with pytest.raises(ValueError):
        parallel_test_config.micro_batch_num = 0

    with pytest.raises(TypeError):
        parallel_test_config.gradient_aggregation_group = False

    with pytest.raises(ValueError):
        parallel_test_config.gradient_aggregation_group = 0

    with pytest.raises(TypeError):
        parallel_test_config.recompute = 1

    parallel_test_config.recompute = False

    assert not parallel_test_config.recompute


def test_parallel_config():
    parallel_test_config = OpParallelConfig(data_parallel=1, model_parallel=3)

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    assert parallel_test_config.model_parallel == 3


def test_embedding_parallel_config():
    parallel_test_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=3, vocab_emb_dp=False)

    with pytest.raises(ValueError):
        parallel_test_config.data_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.model_parallel = False

    with pytest.raises(ValueError):
        parallel_test_config.model_parallel = 0

    with pytest.raises(TypeError):
        parallel_test_config.vocab_emb_dp = 0

    assert not parallel_test_config.vocab_emb_dp